Mainflux.mainflux/consumers/writers/postgres/consumer.go

214 lines
5.1 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jmoiron/sqlx" // required for DB access
"github.com/mainflux/mainflux/consumers"
"github.com/mainflux/mainflux/pkg/errors"
mfjson "github.com/mainflux/mainflux/pkg/transformers/json"
"github.com/mainflux/mainflux/pkg/transformers/senml"
)
var (
errInvalidMessage = errors.New("invalid message representation")
errSaveMessage = errors.New("failed to save message to postgres database")
errTransRollback = errors.New("failed to rollback transaction")
errNoTable = errors.New("relation does not exist")
)
var _ consumers.BlockingConsumer = (*postgresRepo)(nil)
type postgresRepo struct {
db *sqlx.DB
}
// New returns new PostgreSQL writer.
func New(db *sqlx.DB) consumers.BlockingConsumer {
return &postgresRepo{db: db}
}
func (pr postgresRepo) ConsumeBlocking(ctx context.Context, message interface{}) (err error) {
switch m := message.(type) {
case mfjson.Messages:
return pr.saveJSON(ctx, m)
default:
return pr.saveSenml(ctx, m)
}
}
func (pr postgresRepo) saveSenml(ctx context.Context, messages interface{}) (err error) {
msgs, ok := messages.([]senml.Message)
if !ok {
return errSaveMessage
}
q := `INSERT INTO messages (id, channel, subtopic, publisher, protocol,
name, unit, value, string_value, bool_value, data_value, sum,
time, update_time)
VALUES (:id, :channel, :subtopic, :publisher, :protocol, :name, :unit,
:value, :string_value, :bool_value, :data_value, :sum,
:time, :update_time);`
tx, err := pr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(errSaveMessage, err)
}
defer func() {
if err != nil {
if txErr := tx.Rollback(); txErr != nil {
err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr))
}
return
}
if err = tx.Commit(); err != nil {
err = errors.Wrap(errSaveMessage, err)
}
}()
for _, msg := range msgs {
id, err := uuid.NewV4()
if err != nil {
return err
}
m := senmlMessage{Message: msg, ID: id.String()}
if _, err := tx.NamedExec(q, m); err != nil {
pgErr, ok := err.(*pgconn.PgError)
if ok {
if pgErr.Code == pgerrcode.InvalidTextRepresentation {
return errors.Wrap(errSaveMessage, errInvalidMessage)
}
}
return errors.Wrap(errSaveMessage, err)
}
}
return err
}
func (pr postgresRepo) saveJSON(ctx context.Context, msgs mfjson.Messages) error {
if err := pr.insertJSON(ctx, msgs); err != nil {
if err == errNoTable {
if err := pr.createTable(msgs.Format); err != nil {
return err
}
return pr.insertJSON(ctx, msgs)
}
return err
}
return nil
}
func (pr postgresRepo) insertJSON(ctx context.Context, msgs mfjson.Messages) error {
tx, err := pr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(errSaveMessage, err)
}
defer func() {
if err != nil {
if txErr := tx.Rollback(); txErr != nil {
err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr))
}
return
}
if err = tx.Commit(); err != nil {
err = errors.Wrap(errSaveMessage, err)
}
}()
q := `INSERT INTO %s (id, channel, created, subtopic, publisher, protocol, payload)
VALUES (:id, :channel, :created, :subtopic, :publisher, :protocol, :payload);`
q = fmt.Sprintf(q, msgs.Format)
for _, m := range msgs.Data {
var dbmsg jsonMessage
dbmsg, err = toJSONMessage(m)
if err != nil {
return errors.Wrap(errSaveMessage, err)
}
if _, err = tx.NamedExec(q, dbmsg); err != nil {
pgErr, ok := err.(*pgconn.PgError)
if ok {
switch pgErr.Code {
case pgerrcode.InvalidTextRepresentation:
return errors.Wrap(errSaveMessage, errInvalidMessage)
case pgerrcode.UndefinedTable:
return errNoTable
}
}
return err
}
}
return nil
}
func (pr postgresRepo) createTable(name string) error {
q := `CREATE TABLE IF NOT EXISTS %s (
id UUID,
created BIGINT,
channel VARCHAR(254),
subtopic VARCHAR(254),
publisher VARCHAR(254),
protocol TEXT,
payload JSONB,
PRIMARY KEY (id)
)`
q = fmt.Sprintf(q, name)
_, err := pr.db.Exec(q)
return err
}
type senmlMessage struct {
senml.Message
ID string `db:"id"`
}
type jsonMessage struct {
ID string `db:"id"`
Channel string `db:"channel"`
Created int64 `db:"created"`
Subtopic string `db:"subtopic"`
Publisher string `db:"publisher"`
Protocol string `db:"protocol"`
Payload []byte `db:"payload"`
}
func toJSONMessage(msg mfjson.Message) (jsonMessage, error) {
id, err := uuid.NewV4()
if err != nil {
return jsonMessage{}, err
}
data := []byte("{}")
if msg.Payload != nil {
b, err := json.Marshal(msg.Payload)
if err != nil {
return jsonMessage{}, errors.Wrap(errSaveMessage, err)
}
data = b
}
m := jsonMessage{
ID: id.String(),
Channel: msg.Channel,
Created: msg.Created,
Subtopic: msg.Subtopic,
Publisher: msg.Publisher,
Protocol: msg.Protocol,
Payload: data,
}
return m, nil
}