diff --git a/pkg/sdk/go/users_test.go b/pkg/sdk/go/users_test.go index 2cd757dd..8751b063 100644 --- a/pkg/sdk/go/users_test.go +++ b/pkg/sdk/go/users_test.go @@ -497,10 +497,6 @@ func TestListMembers(t *testing.T) { } for _, tc := range cases { - fmt.Println() - fmt.Println(tc.desc) - fmt.Println() - repoCall := pRepo.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) repoCall1 := cRepo.On("Members", mock.Anything, tc.groupID, mock.Anything).Return(mfclients.MembersPage{Members: convertClients(tc.response)}, tc.err) membersPage, err := mfsdk.Members(tc.groupID, tc.page, tc.token) diff --git a/things/policies/mocks/channels.go b/things/policies/mocks/channels.go index 013d728c..009f6537 100644 --- a/things/policies/mocks/channels.go +++ b/things/policies/mocks/channels.go @@ -5,7 +5,6 @@ package mocks import ( "context" - "fmt" "strings" "sync" @@ -13,6 +12,8 @@ import ( "github.com/mainflux/mainflux/things/policies" ) +const separator = ":" + type cacheMock struct { mu sync.Mutex policies map[string]string @@ -25,34 +26,75 @@ func NewCache() policies.Cache { } } -func (ccm *cacheMock) Put(_ context.Context, policy policies.Policy) error { +func (ccm *cacheMock) Put(_ context.Context, policy policies.CachedPolicy) error { ccm.mu.Lock() defer ccm.mu.Unlock() - ccm.policies[fmt.Sprintf("%s:%s", policy.Subject, policy.Object)] = strings.Join(policy.Actions, ":") + key, value := kv(policy) + ccm.policies[key] = value + return nil } -func (ccm *cacheMock) Get(_ context.Context, policy policies.Policy) (policies.Policy, error) { +func (ccm *cacheMock) Get(_ context.Context, policy policies.CachedPolicy) (policies.CachedPolicy, error) { ccm.mu.Lock() defer ccm.mu.Unlock() - actions := ccm.policies[fmt.Sprintf("%s:%s", policy.Subject, policy.Object)] - if actions != "" { - return policies.Policy{ - Subject: policy.Subject, - Object: policy.Object, - Actions: strings.Split(actions, ":"), - }, nil + key, _ := kv(policy) + + val := ccm.policies[key] + if val == "" { + return policies.CachedPolicy{}, errors.ErrNotFound } - return policies.Policy{}, errors.ErrNotFound + thingID := extractThingID(val) + if thingID == "" { + return policies.CachedPolicy{}, errors.ErrNotFound + } + + policy.Actions = separateActions(val) + policy.ThingID = thingID + + return policy, nil } -func (ccm *cacheMock) Remove(_ context.Context, policy policies.Policy) error { +func (ccm *cacheMock) Remove(_ context.Context, policy policies.CachedPolicy) error { ccm.mu.Lock() defer ccm.mu.Unlock() - delete(ccm.policies, fmt.Sprintf("%s:%s", policy.Subject, policy.Object)) + key, _ := kv(policy) + + delete(ccm.policies, key) + return nil } + +// kv is used to create a key-value pair for caching. +func kv(p policies.CachedPolicy) (string, string) { + key := p.ThingKey + separator + p.ChannelID + val := strings.Join(p.Actions, separator) + + if p.ThingID != "" { + val += separator + p.ThingID + } + + return key, val +} + +// separateActions is used to separate the actions from the cache values. +func separateActions(actions string) []string { + return strings.Split(actions, separator) +} + +// extractThingID is used to extract the thingID from the cache values. +func extractThingID(actions string) string { + var lastIdx = strings.LastIndex(actions, separator) + + thingID := actions[lastIdx+1:] + // check if the thingID is a valid UUID + if len(thingID) != 36 { + return "" + } + + return thingID +} diff --git a/things/policies/policies.go b/things/policies/policies.go index 3b551b3d..b6d305d8 100644 --- a/things/policies/policies.go +++ b/things/policies/policies.go @@ -95,16 +95,23 @@ type Service interface { ListPolicies(ctx context.Context, token string, p Page) (PolicyPage, error) } +type CachedPolicy struct { + ThingID string + ThingKey string + ChannelID string + Actions []string +} + // Cache contains channel-thing connection caching interface. type Cache interface { // Put adds policy to cahce. - Put(ctx context.Context, policy Policy) error + Put(ctx context.Context, policy CachedPolicy) error // Get retrieves policy from cache. - Get(ctx context.Context, policy Policy) (Policy, error) + Get(ctx context.Context, policy CachedPolicy) (CachedPolicy, error) // Remove deletes a policy from cache. - Remove(ctx context.Context, policy Policy) error + Remove(ctx context.Context, policy CachedPolicy) error } // validate returns an error if policy representation is invalid. diff --git a/things/policies/redis/policies.go b/things/policies/redis/policies.go index 53836b5e..a3bb862d 100644 --- a/things/policies/redis/policies.go +++ b/things/policies/redis/policies.go @@ -5,7 +5,6 @@ package redis import ( "context" - "fmt" "strings" "time" @@ -25,47 +24,85 @@ type pcache struct { // NewCache returns redis policy cache implementation. func NewCache(client *redis.Client, duration time.Duration) policies.Cache { - return pcache{ + return &pcache{ client: client, keyDuration: duration, } } -func (pc pcache) Put(ctx context.Context, policy policies.Policy) error { - k, v := kv(policy) - if err := pc.client.Set(ctx, k, v, pc.keyDuration).Err(); err != nil { +func (pc *pcache) Put(ctx context.Context, policy policies.CachedPolicy) error { + key, value := kv(policy) + + if err := pc.client.Set(ctx, key, value, pc.keyDuration).Err(); err != nil { return errors.Wrap(errors.ErrCreateEntity, err) } + return nil } -func (pc pcache) Get(ctx context.Context, policy policies.Policy) (policies.Policy, error) { - k, _ := kv(policy) - res := pc.client.Get(ctx, k) +func (pc *pcache) Get(ctx context.Context, policy policies.CachedPolicy) (policies.CachedPolicy, error) { + key, _ := kv(policy) + res := pc.client.Get(ctx, key) // Nil response indicates non-existent key in Redis client. if res == nil || res.Err() == redis.Nil { - return policies.Policy{}, errors.ErrNotFound + return policies.CachedPolicy{}, errors.ErrNotFound } + if err := res.Err(); err != nil { - return policies.Policy{}, err + return policies.CachedPolicy{}, err } - actions, err := res.Result() + + val, err := res.Result() if err != nil { - return policies.Policy{}, err + return policies.CachedPolicy{}, err } - policy.Actions = strings.Split(actions, separator) + + thingID := extractThingID(val) + if thingID == "" { + return policies.CachedPolicy{}, errors.ErrNotFound + } + + policy.ThingID = thingID + policy.Actions = separateActions(val) + return policy, nil } -func (pc pcache) Remove(ctx context.Context, policy policies.Policy) error { - obj, _ := kv(policy) - if err := pc.client.Del(ctx, obj).Err(); err != nil { +func (pc *pcache) Remove(ctx context.Context, policy policies.CachedPolicy) error { + key, _ := kv(policy) + if err := pc.client.Del(ctx, key).Err(); err != nil { return errors.Wrap(errors.ErrRemoveEntity, err) } + return nil } -// Generates key-value pair for Redis client. -func kv(p policies.Policy) (string, string) { - return fmt.Sprintf("%s%s%s", p.Subject, separator, p.Object), strings.Join(p.Actions, separator) +// kv is used to create a key-value pair for caching. +func kv(p policies.CachedPolicy) (string, string) { + key := p.ThingKey + separator + p.ChannelID + val := strings.Join(p.Actions, separator) + + if p.ThingID != "" { + val += separator + p.ThingID + } + + return key, val +} + +// separateActions is used to separate the actions from the cache values. +func separateActions(actions string) []string { + return strings.Split(actions, separator) +} + +// extractThingID is used to extract the thingID from the cache values. +func extractThingID(actions string) string { + var lastIdx = strings.LastIndex(actions, separator) + + thingID := actions[lastIdx+1:] + // check if the thingID is a valid UUID + if len(thingID) != 36 { + return "" + } + + return thingID } diff --git a/things/policies/service.go b/things/policies/service.go index d08c2540..cb0e01f2 100644 --- a/things/policies/service.go +++ b/things/policies/service.go @@ -52,24 +52,32 @@ func NewService(auth upolicies.AuthServiceClient, p Repository, ccache Cache, id func (svc service) Authorize(ctx context.Context, ar AccessRequest) (Policy, error) { // Fetch from cache first. - policy := Policy{ - Subject: ar.Subject, - Object: ar.Object, + cpolicy := CachedPolicy{ + ThingKey: ar.Subject, + ChannelID: ar.Object, } - policy, err := svc.policyCache.Get(ctx, policy) + + cpolicy, err := svc.policyCache.Get(ctx, cpolicy) if err == nil { - for _, action := range policy.Actions { + for _, action := range cpolicy.Actions { if action == ar.Action { + var policy = Policy{ + Subject: cpolicy.ThingID, + } + return policy, nil } } + return Policy{}, errors.ErrAuthorization } + if !errors.Contains(err, errors.ErrNotFound) { return Policy{}, err } - // Fetch from repo as a fallback if not found in cache. + // Fetch from database as a fallback if not found in cache. + var policy Policy switch ar.Entity { case GroupEntityType: policy, err = svc.policies.EvaluateGroupAccess(ctx, ar) @@ -84,14 +92,18 @@ func (svc service) Authorize(ctx context.Context, ar AccessRequest) (Policy, err } case ThingEntityType: - policy, err := svc.policies.EvaluateMessagingAccess(ctx, ar) + policy, err = svc.policies.EvaluateMessagingAccess(ctx, ar) if err != nil { return Policy{}, err } - // Replace Subject since AccessRequest Subject is Thing Key, - // and Policy subject is Thing ID. - policy.Subject = ar.Subject - if err := svc.policyCache.Put(ctx, policy); err != nil { + + cpolicy = CachedPolicy{ + ThingID: policy.Subject, + ThingKey: ar.Subject, + ChannelID: ar.Object, + Actions: policy.Actions, + } + if err := svc.policyCache.Put(ctx, cpolicy); err != nil { return policy, err } @@ -126,7 +138,11 @@ func (svc service) AddPolicy(ctx context.Context, token string, external bool, p p.UpdatedAt = time.Now() p.UpdatedBy = userID - if err := svc.policyCache.Remove(ctx, p); err != nil { + var cpolicy = CachedPolicy{ + ThingKey: p.Subject, + ChannelID: p.Object, + } + if err := svc.policyCache.Remove(ctx, cpolicy); err != nil { return Policy{}, err } @@ -193,7 +209,11 @@ func (svc service) UpdatePolicy(ctx context.Context, token string, p Policy) (Po p.UpdatedAt = time.Now() p.UpdatedBy = userID - if err := svc.policyCache.Remove(ctx, p); err != nil { + var cpolicy = CachedPolicy{ + ThingKey: p.Subject, + ChannelID: p.Object, + } + if err := svc.policyCache.Remove(ctx, cpolicy); err != nil { return Policy{}, err } @@ -228,9 +248,14 @@ func (svc service) DeletePolicy(ctx context.Context, token string, p Policy) err return err } - if err := svc.policyCache.Remove(ctx, p); err != nil { + var cpolicy = CachedPolicy{ + ThingKey: p.Subject, + ChannelID: p.Object, + } + if err := svc.policyCache.Remove(ctx, cpolicy); err != nil { return err } + return svc.policies.Delete(ctx, p) }