Mainflux.mainflux/twins/mongodb/states.go

157 lines
3.5 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package mongodb
import (
"context"
"github.com/mainflux/mainflux/twins"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
const (
	// statesCollection is the name of the MongoDB collection that stores
	// twin states.
	statesCollection string = "states"
	// twinid is the document field used to select the states of one twin.
	twinid string = "twinid"
)
// stateRepository is the MongoDB-backed twins.StateRepository.
type stateRepository struct {
	// db is the database whose states collection is read and written.
	db *mongo.Database
}

// Compile-time check that *stateRepository satisfies twins.StateRepository.
var _ twins.StateRepository = (*stateRepository)(nil)
// NewStateRepository returns a twins.StateRepository that stores states in
// the states collection of db.
func NewStateRepository(db *mongo.Database) twins.StateRepository {
	repo := &stateRepository{db: db}
	return repo
}
// Save inserts st as a new document in the states collection and returns
// any error reported by the insert.
func (sr *stateRepository) Save(ctx context.Context, st twins.State) error {
	_, err := sr.db.Collection(statesCollection).InsertOne(ctx, st)
	return err
}
// Update overwrites the stored state whose "id" and twin id match st.ID and
// st.TwinID, setting its fields to the values in st.
//
// NOTE(review): the UpdateResult is discarded, so an update whose filter
// matches no document is not reported as an error.
func (sr *stateRepository) Update(ctx context.Context, st twins.State) error {
	match := bson.M{"id": st.ID, twinid: st.TwinID}
	change := bson.M{"$set": st}
	_, err := sr.db.Collection(statesCollection).UpdateOne(ctx, match, change)
	return err
}
// Count returns how many stored states belong to the twin tw. On failure it
// returns 0 together with the error.
func (sr *stateRepository) Count(ctx context.Context, tw twins.Twin) (int64, error) {
	byTwin := bson.M{twinid: tw.ID}
	n, err := sr.db.Collection(statesCollection).CountDocuments(ctx, byTwin)
	if err != nil {
		return 0, err
	}
	return n, nil
}
// RetrieveAll returns one page of the states that belong to the twin
// identified by twinID. The page metadata carries the total number of that
// twin's states plus the requested offset and limit.
//
// offset and limit are handed to MongoDB as the skip and limit of the find.
// NOTE(review): no sort order is requested, so which states land on which
// page depends on the server's natural order. A limit of 0 is passed
// through unchanged, and MongoDB treats that as "no limit".
func (sr *stateRepository) RetrieveAll(ctx context.Context, offset uint64, limit uint64, twinID string) (twins.StatesPage, error) {
	coll := sr.db.Collection(statesCollection)
	byTwin := bson.M{twinid: twinID}

	opts := options.Find().SetSkip(int64(offset)).SetLimit(int64(limit))
	cur, err := coll.Find(ctx, byTwin, opts)
	if err != nil {
		return twins.StatesPage{}, err
	}

	states, err := decodeStates(ctx, cur)
	if err != nil {
		return twins.StatesPage{}, err
	}

	total, err := coll.CountDocuments(ctx, byTwin)
	if err != nil {
		return twins.StatesPage{}, err
	}

	meta := twins.PageMetadata{
		Total:  uint64(total),
		Offset: offset,
		Limit:  limit,
	}
	return twins.StatesPage{States: states, PageMetadata: meta}, nil
}
// RetrieveLast returns the last state of the twin identified by twinID.
// "Last" means the final document the server yields for that twin, found by
// counting the twin's states and skipping all but one of them. When the
// twin has no states, a zero-value State and a nil error are returned.
//
// NOTE(review): no sort is applied, so this depends on the server's natural
// order, which MongoDB does not guarantee to be insertion order. The count
// and the find are also separate round-trips, so a concurrent insert or
// delete can change which document is returned.
func (sr *stateRepository) RetrieveLast(ctx context.Context, twinID string) (twins.State, error) {
	coll := sr.db.Collection(statesCollection)
	byTwin := bson.M{twinid: twinID}

	total, err := coll.CountDocuments(ctx, byTwin)
	if err != nil {
		return twins.State{}, err
	}

	// Skip every document except the final one; with no documents the skip
	// stays at zero and the find simply returns nothing.
	skip := int64(0)
	if total > 0 {
		skip = total - 1
	}
	opts := options.Find().SetSkip(skip).SetLimit(1)

	cur, err := coll.Find(ctx, byTwin, opts)
	if err != nil {
		return twins.State{}, err
	}

	states, err := decodeStates(ctx, cur)
	switch {
	case err != nil:
		return twins.State{}, err
	case len(states) == 0:
		return twins.State{}, nil
	default:
		return states[0], nil
	}
}
// decodeStates reads every remaining document from cur into a slice of
// states and closes the cursor before returning.
//
// When the cursor yields no documents it returns a nil slice and a nil
// error. A failure to decode a document, or an error reported by the cursor
// after iteration, is returned to the caller with a nil slice. Previously
// such errors were swallowed and reported as an empty, successful result.
func decodeStates(ctx context.Context, cur *mongo.Cursor) ([]twins.State, error) {
	// The error from Close is ignored; failures that affect the result are
	// surfaced through Decode or Err below.
	defer cur.Close(ctx)

	var results []twins.State
	for cur.Next(ctx) {
		var elem twins.State
		if err := cur.Decode(&elem); err != nil {
			return nil, err
		}
		results = append(results, elem)
	}
	// Next returns false on both exhaustion and failure; Err tells them apart.
	if err := cur.Err(); err != nil {
		return nil, err
	}
	return results, nil
}