diff --git a/writers/postgres/messages.go b/writers/postgres/messages.go index 89a5afe4..b903da4c 100644 --- a/writers/postgres/messages.go +++ b/writers/postgres/messages.go @@ -20,7 +20,8 @@ var ( // ErrInvalidMessage indicates that service received message that // doesn't fit required format. ErrInvalidMessage = errors.New("invalid message representation") - errSaveMessage = errors.New("faled to save message to postgress database") + errSaveMessage = errors.New("failed to save message to postgres database") + errTransRollback = errors.New("failed to rollback transaction") ) var _ writers.MessageRepository = (*postgresRepo)(nil) @@ -34,7 +35,7 @@ func New(db *sqlx.DB) writers.MessageRepository { return &postgresRepo{db: db} } -func (pr postgresRepo) Save(messages ...senml.Message) error { +func (pr postgresRepo) Save(messages ...senml.Message) (err error) { q := `INSERT INTO messages (id, channel, subtopic, publisher, protocol, name, unit, value, string_value, bool_value, data_value, sum, time, update_time) @@ -46,6 +47,19 @@ func (pr postgresRepo) Save(messages ...senml.Message) error { if err != nil { return errors.Wrap(errSaveMessage, err) } + defer func() { + if err != nil { + if txErr := tx.Rollback(); txErr != nil { + err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr)) + } + return + } + + if err = tx.Commit(); err != nil { + err = errors.Wrap(errSaveMessage, err) + } + return + }() for _, msg := range messages { dbth, err := toDBMessage(msg) @@ -65,10 +79,7 @@ func (pr postgresRepo) Save(messages ...senml.Message) error { return errors.Wrap(errSaveMessage, err) } } - if err := tx.Commit(); err != nil { - return errors.Wrap(errSaveMessage, err) - } - return nil + return err } type dbMessage struct {