Mainflux.mainflux/things/postgres/channels.go

492 lines
12 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"strings"
"github.com/gofrs/uuid"
"github.com/lib/pq"
"github.com/mainflux/mainflux/pkg/errors"
"github.com/mainflux/mainflux/things"
)
var (
	// ErrSaveChannel indicates error while saving to database.
	ErrSaveChannel = errors.New("save channel to db error")

	// ErrUpdateChannel indicates error while updating channel in database.
	ErrUpdateChannel = errors.New("update channel to db error")

	// ErrDeleteChannel indicates error while deleting channel in database.
	ErrDeleteChannel = errors.New("delete channel from db error")

	// ErrSelectChannel indicates error while reading channel from database.
	ErrSelectChannel = errors.New("select channel from db error")

	// ErrDeleteConnection indicates error while deleting connection in database.
	ErrDeleteConnection = errors.New("delete connection from db error")

	// ErrHasThing indicates error while checking connection in database.
	ErrHasThing = errors.New("check thing-channel connection in database error")

	// ErrScan indicates error in database scanner.
	ErrScan = errors.New("database scanner error")

	// ErrValue indicates error in database valuer.
	ErrValue = errors.New("database valuer error")
)
// Compile-time assertion that *channelRepository implements
// things.ChannelRepository.
var _ things.ChannelRepository = (*channelRepository)(nil)
// channelRepository is the PostgreSQL implementation of
// things.ChannelRepository; every query goes through the wrapped db handle.
type channelRepository struct {
db Database
}
// dbConnection carries the named parameters (:channel, :thing, :owner) used by
// the connections queries in Connect and Disconnect. The single Owner value is
// bound to both the channel_owner and thing_owner columns.
type dbConnection struct {
Channel string `db:"channel"`
Thing string `db:"thing"`
Owner string `db:"owner"`
}
// NewChannelRepository instantiates a PostgreSQL implementation of channel
// repository that issues all of its queries through db.
func NewChannelRepository(db Database) things.ChannelRepository {
	repo := &channelRepository{db: db}
	return repo
}
func (cr channelRepository) Save(ctx context.Context, channels ...things.Channel) ([]things.Channel, error) {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return nil, errors.Wrap(ErrSaveChannel, err)
}
q := `INSERT INTO channels (id, owner, name, metadata)
VALUES (:id, :owner, :name, :metadata);`
for _, channel := range channels {
dbch := toDBChannel(channel)
_, err = tx.NamedExecContext(ctx, q, dbch)
if err != nil {
tx.Rollback()
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errInvalid, errTruncation:
return []things.Channel{}, things.ErrMalformedEntity
case errDuplicate:
return []things.Channel{}, things.ErrConflict
}
}
return []things.Channel{}, errors.Wrap(ErrSaveChannel, err)
}
}
if err = tx.Commit(); err != nil {
return []things.Channel{}, errors.Wrap(ErrSaveChannel, err)
}
return channels, nil
}
func (cr channelRepository) Update(ctx context.Context, channel things.Channel) error {
q := `UPDATE channels SET name = :name, metadata = :metadata WHERE owner = :owner AND id = :id;`
dbch := toDBChannel(channel)
res, err := cr.db.NamedExecContext(ctx, q, dbch)
if err != nil {
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errInvalid, errTruncation:
return things.ErrMalformedEntity
}
}
return errors.Wrap(ErrUpdateChannel, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(ErrUpdateChannel, err)
}
if cnt == 0 {
return things.ErrNotFound
}
return nil
}
func (cr channelRepository) RetrieveByID(ctx context.Context, owner, id string) (things.Channel, error) {
q := `SELECT name, metadata FROM channels WHERE id = $1 AND owner = $2;`
dbch := dbChannel{
ID: id,
Owner: owner,
}
if err := cr.db.QueryRowxContext(ctx, q, id, owner).StructScan(&dbch); err != nil {
empty := things.Channel{}
pqErr, ok := err.(*pq.Error)
if err == sql.ErrNoRows || ok && errInvalid == pqErr.Code.Name() {
return empty, things.ErrNotFound
}
return empty, errors.Wrap(ErrSelectChannel, err)
}
return toChannel(dbch), nil
}
// RetrieveAll returns a page of channels owned by owner, ordered by ID.
// A non-empty name adds a case-insensitive substring filter (getNameQuery) and
// non-empty metadata adds a containment (@>) filter (getMetadataQuery).
// The page total counts every matching channel, ignoring offset and limit.
// All failures are wrapped with ErrSelectChannel.
func (cr channelRepository) RetrieveAll(ctx context.Context, owner string, offset, limit uint64, name string, metadata things.Metadata) (things.ChannelsPage, error) {
	nq, name := getNameQuery(name)
	m, mq, err := getMetadataQuery(metadata)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}

	q := fmt.Sprintf(`SELECT id, name, metadata FROM channels
WHERE owner = :owner %s%s ORDER BY id LIMIT :limit OFFSET :offset;`, mq, nq)
	params := map[string]interface{}{
		"owner":    owner,
		"limit":    limit,
		"offset":   offset,
		"name":     name,
		"metadata": m,
	}

	rows, err := cr.db.NamedQueryContext(ctx, q, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}
	defer rows.Close()

	// Non-nil so an empty page is distinguishable from a missing one.
	items := []things.Channel{}
	for rows.Next() {
		// Owner is not selected; it is known from the filter.
		dbch := dbChannel{Owner: owner}
		if err := rows.StructScan(&dbch); err != nil {
			return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
		}
		items = append(items, toChannel(dbch))
	}
	// rows.Next returns false both at the end of the result set and on an
	// iteration error; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}

	cq := fmt.Sprintf(`SELECT COUNT(*) FROM channels WHERE owner = :owner %s%s;`, nq, mq)
	count, err := total(ctx, cr.db, cq, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}

	return things.ChannelsPage{
		Channels: items,
		PageMetadata: things.PageMetadata{
			Total:  count,
			Offset: offset,
			Limit:  limit,
		},
	}, nil
}
// RetrieveByThing returns a page of channels owned by owner, ordered by ID.
// When connected is true the page holds channels connected to thing;
// otherwise it holds the owner's channels that are not connected to it.
// A thing ID that is not a valid UUID yields things.ErrNotFound without
// touching the database. The page total counts every matching channel,
// ignoring offset and limit. Database failures are wrapped with
// ErrSelectChannel.
func (cr channelRepository) RetrieveByThing(ctx context.Context, owner, thing string, offset, limit uint64, connected bool) (things.ChannelsPage, error) {
	// Verify if UUID format is valid to avoid internal Postgres error
	if _, err := uuid.FromString(thing); err != nil {
		return things.ChannelsPage{}, things.ErrNotFound
	}

	var q, qc string
	if connected {
		q = `SELECT id, name, metadata FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = :owner AND conn.thing_id = :thing
ORDER BY ch.id
LIMIT :limit
OFFSET :offset;`
		qc = `SELECT COUNT(*)
FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = $1 AND conn.thing_id = $2`
	} else {
		q = `SELECT id, name, metadata
FROM channels ch
WHERE ch.owner = :owner AND ch.id NOT IN
(SELECT id FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = :owner AND conn.thing_id = :thing)
ORDER BY ch.id
LIMIT :limit
OFFSET :offset;`
		qc = `SELECT COUNT(*)
FROM channels ch
WHERE ch.owner = $1 AND ch.id NOT IN
(SELECT id FROM channels ch
INNER JOIN connections conn
ON ch.id = conn.channel_id
WHERE ch.owner = $1 AND conn.thing_id = $2);`
	}

	params := map[string]interface{}{
		"owner":  owner,
		"thing":  thing,
		"limit":  limit,
		"offset": offset,
	}

	rows, err := cr.db.NamedQueryContext(ctx, q, params)
	if err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}
	defer rows.Close()

	items := []things.Channel{}
	for rows.Next() {
		// Owner is not selected; it is known from the filter.
		dbch := dbChannel{Owner: owner}
		if err := rows.StructScan(&dbch); err != nil {
			return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
		}
		items = append(items, toChannel(dbch))
	}
	// A false rows.Next may hide an iteration error; surface it instead of
	// returning a silently truncated page.
	if err := rows.Err(); err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}

	var count uint64
	if err := cr.db.GetContext(ctx, &count, qc, owner, thing); err != nil {
		return things.ChannelsPage{}, errors.Wrap(ErrSelectChannel, err)
	}

	return things.ChannelsPage{
		Channels: items,
		PageMetadata: things.PageMetadata{
			Total:  count,
			Offset: offset,
			Limit:  limit,
		},
	}, nil
}
// Remove deletes the channel identified by id and owned by owner.
// Removing a channel that does not exist is not an error, so the call is
// idempotent. An id that Postgres rejects as invalid text cannot match any
// row and is treated the same way. Any other database failure is wrapped
// with ErrDeleteChannel instead of being silently discarded.
func (cr channelRepository) Remove(ctx context.Context, owner, id string) error {
	dbch := dbChannel{
		ID:    id,
		Owner: owner,
	}
	q := `DELETE FROM channels WHERE id = :id AND owner = :owner`
	if _, err := cr.db.NamedExecContext(ctx, q, dbch); err != nil {
		if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == errInvalid {
			return nil
		}
		return errors.Wrap(ErrDeleteChannel, err)
	}
	return nil
}
func (cr channelRepository) Connect(ctx context.Context, owner string, chIDs, thIDs []string) error {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(ErrDeleteChannel, err)
}
q := `INSERT INTO connections (channel_id, channel_owner, thing_id, thing_owner)
VALUES (:channel, :owner, :thing, :owner);`
for _, chID := range chIDs {
for _, thID := range thIDs {
dbco := dbConnection{
Channel: chID,
Thing: thID,
Owner: owner,
}
_, err := tx.NamedExecContext(ctx, q, dbco)
if err != nil {
tx.Rollback()
pqErr, ok := err.(*pq.Error)
if ok {
switch pqErr.Code.Name() {
case errFK:
return things.ErrNotFound
case errDuplicate:
return things.ErrConflict
}
}
return errors.Wrap(ErrDeleteChannel, err)
}
}
}
if err = tx.Commit(); err != nil {
return errors.Wrap(ErrDeleteChannel, err)
}
return nil
}
func (cr channelRepository) Disconnect(ctx context.Context, owner, chanID, thingID string) error {
q := `DELETE FROM connections
WHERE channel_id = :channel AND channel_owner = :owner
AND thing_id = :thing AND thing_owner = :owner`
conn := dbConnection{
Channel: chanID,
Thing: thingID,
Owner: owner,
}
res, err := cr.db.NamedExecContext(ctx, q, conn)
if err != nil {
return errors.Wrap(ErrDeleteConnection, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(ErrDeleteConnection, err)
}
if cnt == 0 {
return things.ErrNotFound
}
return nil
}
// HasThing resolves the thing owning key and verifies that it is connected to
// channel chanID, returning the thing's ID on success. Every error is wrapped
// with ErrHasThing:
//   - an unknown key wraps things.ErrNotFound,
//   - a thing that is not connected wraps things.ErrUnauthorizedAccess,
//   - any other database failure wraps the underlying error.
func (cr channelRepository) HasThing(ctx context.Context, chanID, key string) (string, error) {
	var thingID string
	q := `SELECT id FROM things WHERE key = $1`
	if err := cr.db.QueryRowxContext(ctx, q, key).Scan(&thingID); err != nil {
		// No row means no thing has this key: report "not found" instead of
		// presenting it as an internal database failure.
		if err == sql.ErrNoRows {
			return "", errors.Wrap(ErrHasThing, things.ErrNotFound)
		}
		return "", errors.Wrap(ErrHasThing, err)
	}
	if err := cr.hasThing(ctx, chanID, thingID); err != nil {
		return "", errors.Wrap(ErrHasThing, err)
	}
	return thingID, nil
}
// HasThingByID checks that thing thingID is connected to channel chanID.
// It returns nil when connected, things.ErrUnauthorizedAccess when not, and an
// ErrHasThing-wrapped error if the lookup itself fails.
func (cr channelRepository) HasThingByID(ctx context.Context, chanID, thingID string) error {
	if err := cr.hasThing(ctx, chanID, thingID); err != nil {
		return err
	}
	return nil
}
func (cr channelRepository) hasThing(ctx context.Context, chanID, thingID string) error {
q := `SELECT EXISTS (SELECT 1 FROM connections WHERE channel_id = $1 AND thing_id = $2);`
exists := false
if err := cr.db.QueryRowxContext(ctx, q, chanID, thingID).Scan(&exists); err != nil {
return errors.Wrap(ErrHasThing, err)
}
if !exists {
return things.ErrUnauthorizedAccess
}
return nil
}
// dbMetadata type for handling metadata properly in database/sql.
// It is persisted as JSON; empty metadata is written as SQL NULL.
type dbMetadata map[string]interface{}

// Scan implements the database/sql scanner interface.
// A NULL column resets the metadata to nil, a []byte column is decoded as a
// JSON object, and any other source type resets it to an empty map and
// returns things.ErrScanMetadata.
func (m *dbMetadata) Scan(value interface{}) error {
	// Assignments go through *m: assigning to m itself would only change the
	// local receiver copy and leave the caller's value untouched.
	if value == nil {
		*m = nil
		return nil
	}

	b, ok := value.([]byte)
	if !ok {
		*m = dbMetadata{}
		return things.ErrScanMetadata
	}

	// json.Unmarshal reuses a non-nil map and keeps its existing entries, so
	// start from nil to avoid merging keys from a previously scanned value.
	*m = nil
	return json.Unmarshal(b, m)
}

// Value implements database/sql valuer interface.
// Nil or empty metadata is stored as NULL; otherwise its JSON encoding is
// returned.
func (m dbMetadata) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}
	b, err := json.Marshal(m)
	if err != nil {
		return nil, err
	}
	return b, nil
}
// dbChannel is the database representation of a channel. Its db tags name the
// columns of the channels table and the named parameters (:id, :owner, :name,
// :metadata) used by the queries in this file.
type dbChannel struct {
ID string `db:"id"`
Owner string `db:"owner"`
Name string `db:"name"`
Metadata dbMetadata `db:"metadata"`
}
// toDBChannel maps a domain channel onto its database representation,
// copying ID, owner, name and metadata.
func toDBChannel(ch things.Channel) dbChannel {
	var dbch dbChannel
	dbch.ID = ch.ID
	dbch.Owner = ch.Owner
	dbch.Name = ch.Name
	dbch.Metadata = ch.Metadata
	return dbch
}
// toChannel maps a database channel row back onto the domain type, copying
// ID, owner, name and metadata.
func toChannel(ch dbChannel) things.Channel {
	var out things.Channel
	out.ID = ch.ID
	out.Owner = ch.Owner
	out.Name = ch.Name
	out.Metadata = ch.Metadata
	return out
}
// likeEscaper escapes LIKE wildcard characters and the escape character
// itself so user-supplied names are matched literally. Backslash is
// PostgreSQL's default LIKE escape character.
var likeEscaper = strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

// getNameQuery builds the optional case-insensitive "name contains" filter.
// It returns the SQL fragment to append (empty when name is empty) and the
// value to bind to :name. That value is the lower-cased name with LIKE
// wildcards escaped, wrapped in % so it matches anywhere in the name.
func getNameQuery(name string) (string, string) {
	name = strings.ToLower(name)
	if name == "" {
		return "", name
	}
	return ` AND LOWER(name) LIKE :name`, "%" + likeEscaper.Replace(name) + "%"
}
// getMetadataQuery builds the optional metadata containment (@>) filter.
// For empty metadata it returns the JSON "{}" with an empty SQL fragment;
// otherwise it returns the JSON encoding of m with the fragment
// " AND metadata @> :metadata". It fails only if m cannot be marshaled.
func getMetadataQuery(m things.Metadata) ([]byte, string, error) {
	if len(m) == 0 {
		return []byte("{}"), "", nil
	}
	encoded, err := json.Marshal(m)
	if err != nil {
		return nil, "", err
	}
	return encoded, ` AND metadata @> :metadata`, nil
}
func total(ctx context.Context, db Database, query string, params map[string]interface{}) (uint64, error) {
rows, err := db.NamedQueryContext(ctx, query, params)
if err != nil {
return 0, err
}
total := uint64(0)
if rows.Next() {
if err := rows.Scan(&total); err != nil {
return 0, err
}
}
return total, nil
}