diff --git a/README.md b/README.md index bb0e07bf..c6c64023 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![build][ci-badge]][ci-url] [![go report card][grc-badge]][grc-url] +[![coverage][cov-badge]][cov-url] [![license][license]](LICENSE) [![chat][gitter-badge]][gitter] @@ -61,5 +62,7 @@ Thank you for your interest in Mainflux and wish to contribute! [gitter-badge]: https://badges.gitter.im/Join%20Chat.svg [grc-badge]: https://goreportcard.com/badge/github.com/mainflux/mainflux [grc-url]: https://goreportcard.com/report/github.com/mainflux/mainflux +[cov-badge]: https://codecov.io/gh/mainflux/mainflux/branch/master/graph/badge.svg +[cov-url]: https://codecov.io/gh/mainflux/mainflux [license]: https://img.shields.io/badge/license-Apache%20v2.0-blue.svg [twitter]: https://twitter.com/mainflux diff --git a/manager/api/requests.go b/manager/api/requests.go index 37cd7a2e..bd937217 100644 --- a/manager/api/requests.go +++ b/manager/api/requests.go @@ -137,7 +137,7 @@ func (req connectionReq) validate() error { return manager.ErrUnauthorizedAccess } - if !govalidator.IsUUID(req.chanId) && !govalidator.IsUUID(req.clientId) { + if !govalidator.IsUUID(req.chanId) || !govalidator.IsUUID(req.clientId) { return manager.ErrNotFound } diff --git a/manager/api/transport.go b/manager/api/transport.go index c9f6b831..6de67cac 100644 --- a/manager/api/transport.go +++ b/manager/api/transport.go @@ -3,6 +3,8 @@ package api import ( "context" "encoding/json" + "errors" + "io" "net/http" kithttp "github.com/go-kit/kit/transport/http" @@ -12,6 +14,8 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) +var errUnsupportedContentType = errors.New("unsupported content type") + // MakeHandler returns a HTTP handler for API endpoints. func MakeHandler(svc manager.Service) http.Handler { opts := []kithttp.ServerOption{ @@ -147,6 +151,10 @@ func decodeIdentity(_ context.Context, r *http.Request) (interface{}, error) { } func decodeCredentials(_ context.Context, r *http.Request) (interface{}, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, errUnsupportedContentType + } + var user manager.User if err := json.NewDecoder(r.Body).Decode(&user); err != nil { return nil, err @@ -156,6 +164,10 @@ func decodeCredentials(_ context.Context, r *http.Request) (interface{}, error) } func decodeClientCreation(_ context.Context, r *http.Request) (interface{}, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, errUnsupportedContentType + } + var client manager.Client if err := json.NewDecoder(r.Body).Decode(&client); err != nil { return nil, err @@ -170,6 +182,10 @@ func decodeClientCreation(_ context.Context, r *http.Request) (interface{}, erro } func decodeClientUpdate(_ context.Context, r *http.Request) (interface{}, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, errUnsupportedContentType + } + var client manager.Client if err := json.NewDecoder(r.Body).Decode(&client); err != nil { return nil, err @@ -185,6 +201,10 @@ func decodeClientUpdate(_ context.Context, r *http.Request) (interface{}, error) } func decodeChannelCreation(_ context.Context, r *http.Request) (interface{}, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, errUnsupportedContentType + } + var channel manager.Channel if err := json.NewDecoder(r.Body).Decode(&channel); err != nil { return nil, err @@ -199,6 +219,10 @@ func decodeChannelCreation(_ context.Context, r *http.Request) (interface{}, err } func decodeChannelUpdate(_ context.Context, r *http.Request) (interface{}, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, errUnsupportedContentType + } + var channel manager.Channel if err := json.NewDecoder(r.Body).Decode(&channel); err != nil { return nil, err @@ -272,6 +296,12 @@ func encodeError(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(http.StatusNotFound) case manager.ErrConflict: w.WriteHeader(http.StatusConflict) + case errUnsupportedContentType: + w.WriteHeader(http.StatusUnsupportedMediaType) + case io.ErrUnexpectedEOF: + w.WriteHeader(http.StatusBadRequest) + case io.EOF: + w.WriteHeader(http.StatusBadRequest) default: switch err.(type) { case *json.SyntaxError: diff --git a/manager/api/transport_test.go b/manager/api/transport_test.go new file mode 100644 index 00000000..da271fe0 --- /dev/null +++ b/manager/api/transport_test.go @@ -0,0 +1,737 @@ +package api_test + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/mainflux/mainflux/manager" + "github.com/mainflux/mainflux/manager/api" + "github.com/mainflux/mainflux/manager/mocks" + "github.com/stretchr/testify/assert" +) + +const ( + contentType = "application/json; charset=utf-8" + invalidEmail = "userexample.com" + wrongID = "123e4567-e89b-12d3-a456-000000000042" +) + +var ( + user = manager.User{"user@example.com", "password"} + client = manager.Client{Type: "app", Name: "test_app", Payload: "test_payload"} + channel = manager.Channel{Name: "test"} +) + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + if tr.token != "" { + req.Header.Set("Authorization", tr.token) + } + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + return tr.client.Do(req) +} + +func newService() manager.Service { + users := mocks.NewUserRepository() + clients := mocks.NewClientRepository() + channels := mocks.NewChannelRepository(clients) + hasher := mocks.NewHasher() + idp := mocks.NewIdentityProvider() + + return manager.New(users, clients, channels, hasher, idp) +} + +func newServer(svc manager.Service) *httptest.Server { + mux := api.MakeHandler(svc) + return httptest.NewServer(mux) +} + +func toJSON(data interface{}) string { + jsonData, _ := json.Marshal(data) + return string(jsonData) +} + +func TestRegister(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + data := toJSON(user) + invalidData := toJSON(manager.User{Email: invalidEmail, Password: "password"}) + + cases := []struct { + desc string + req string + contentType string + status int + }{ + {"register new user", data, contentType, http.StatusCreated}, + {"register existing user", data, contentType, http.StatusConflict}, + {"register user with invalid email address", invalidData, contentType, http.StatusBadRequest}, + {"register user with invalid request format", "{", contentType, http.StatusBadRequest}, + {"register user with empty JSON request", "{}", contentType, http.StatusBadRequest}, + {"register user with empty request", "", contentType, http.StatusBadRequest}, + {"register user with missing content type", data, "", http.StatusUnsupportedMediaType}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodPost, + url: fmt.Sprintf("%s/users", ts.URL), + contentType: tc.contentType, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestLogin(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + tokenData := toJSON(map[string]string{"token": user.Email}) + data := toJSON(user) + invalidEmailData := toJSON(manager.User{Email: invalidEmail, Password: "password"}) + invalidData := toJSON(manager.User{"user@example.com", "invalid_password"}) + nonexistentData := toJSON(manager.User{"non-existentuser@example.com", "pass"}) + svc.Register(user) + + cases := []struct { + desc string + req string + contentType string + status int + res string + }{ + {"login with valid credentials", data, contentType, http.StatusCreated, tokenData}, + {"login with invalid credentials", invalidData, contentType, http.StatusForbidden, ""}, + {"login with invalid email address", invalidEmailData, contentType, http.StatusBadRequest, ""}, + {"login non-existent user", nonexistentData, contentType, http.StatusForbidden, ""}, + {"login with invalid request format", "{", contentType, http.StatusBadRequest, ""}, + {"login with empty JSON request", "{}", contentType, http.StatusBadRequest, ""}, + {"login with empty request", "", contentType, http.StatusBadRequest, ""}, + {"login with missing content type", data, "", http.StatusUnsupportedMediaType, ""}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodPost, + url: fmt.Sprintf("%s/tokens", ts.URL), + contentType: tc.contentType, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body, err := ioutil.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + token := strings.Trim(string(body), "\n") + + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res, token, fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, token)) + } +} + +func TestAddClient(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + data := toJSON(client) + invalidData := toJSON(manager.Client{ + Type: "foo", + Name: "invalid_client", + Payload: "some_payload", + }) + svc.Register(user) + + cases := []struct { + desc string + req string + contentType string + auth string + status int + }{ + {"add valid client", data, contentType, user.Email, http.StatusCreated}, + {"add client with invalid data", invalidData, contentType, user.Email, http.StatusBadRequest}, + {"add client with invalid auth token", data, contentType, "invalid_token", http.StatusForbidden}, + {"add client with invalid request format", "}", contentType, user.Email, http.StatusBadRequest}, + {"add client with empty JSON request", "{}", contentType, user.Email, http.StatusBadRequest}, + {"add client with empty request", "", contentType, user.Email, http.StatusBadRequest}, + {"add client with missing content type", data, "", user.Email, http.StatusUnsupportedMediaType}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodPost, + url: fmt.Sprintf("%s/clients", ts.URL), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestUpdateClient(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + data := toJSON(client) + invalidData := toJSON(manager.Client{ + Type: "foo", + Name: client.Name, + Payload: client.Payload, + }) + svc.Register(user) + id, _ := svc.AddClient(user.Email, client) + + cases := []struct { + desc string + req string + id string + contentType string + auth string + status int + }{ + {"update existing client", data, id, contentType, user.Email, http.StatusOK}, + {"update non-existent client", data, wrongID, contentType, user.Email, http.StatusNotFound}, + {"update client with invalid id", data, "1", contentType, user.Email, http.StatusNotFound}, + {"update client with invalid data", invalidData, id, contentType, user.Email, http.StatusBadRequest}, + {"update client with invalid user token", data, id, contentType, invalidEmail, http.StatusForbidden}, + {"update client with invalid data format", "{", id, contentType, user.Email, http.StatusBadRequest}, + {"update client with empty JSON request", "{}", id, contentType, user.Email, http.StatusBadRequest}, + {"update client with empty request", "", id, contentType, user.Email, http.StatusBadRequest}, + {"update client with missing content type", data, id, "", user.Email, http.StatusUnsupportedMediaType}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodPut, + url: fmt.Sprintf("%s/clients/%s", ts.URL, tc.id), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestViewClient(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + id, _ := svc.AddClient(user.Email, client) + + client.ID = id + client.Key = id + data := toJSON(client) + + cases := []struct { + desc string + id string + auth string + status int + res string + }{ + {"view existing client", id, user.Email, http.StatusOK, data}, + {"view non-existent client", wrongID, user.Email, http.StatusNotFound, ""}, + {"view client by passing invalid id", "1", user.Email, http.StatusNotFound, ""}, + {"view client by passing invalid token", id, invalidEmail, http.StatusForbidden, ""}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodGet, + url: fmt.Sprintf("%s/clients/%s", ts.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body, err := ioutil.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + data := strings.Trim(string(body), "\n") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res, data, fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, data)) + } +} + +func TestListClients(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + noClientsUser := manager.User{Email: "no_clients_user@example.com", Password: user.Password} + svc.Register(noClientsUser) + clients := []manager.Client{} + for i := 0; i < 10; i++ { + id, _ := svc.AddClient(user.Email, client) + client.ID = id + client.Key = id + clients = append(clients, client) + } + + cases := []struct { + desc string + auth string + status int + res []manager.Client + }{ + {"fetch list of clients", user.Email, http.StatusOK, clients}, + {"fetch empty list of clients", noClientsUser.Email, http.StatusOK, []manager.Client{}}, + {"fetch list of clients with invalid token", invalidEmail, http.StatusForbidden, nil}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodGet, + url: fmt.Sprintf("%s/clients", ts.URL), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var data map[string][]manager.Client + json.NewDecoder(res.Body).Decode(&data) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.ElementsMatch(t, tc.res, data["clients"], fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, data["clients"])) + } +} + +func TestRemoveClient(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + id, _ := svc.AddClient(user.Email, client) + + cases := []struct { + desc string + id string + auth string + status int + }{ + {"delete existing client", id, user.Email, http.StatusNoContent}, + {"delete non-existent client", wrongID, user.Email, http.StatusNoContent}, + {"delete client with invalid id", "1", user.Email, http.StatusNoContent}, + {"delete client with invalid token", id, invalidEmail, http.StatusForbidden}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodDelete, + url: fmt.Sprintf("%s/clients/%s", ts.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestCreateChannel(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + data := toJSON(channel) + svc.Register(user) + + cases := []struct { + desc string + req string + contentType string + auth string + status int + }{ + {"create new channel", data, contentType, user.Email, http.StatusCreated}, + {"create new channel with invalid token", data, contentType, invalidEmail, http.StatusForbidden}, + {"create new channel with invalid data format", "{", contentType, user.Email, http.StatusBadRequest}, + {"create new channel with empty JSON request", "{}", contentType, user.Email, http.StatusCreated}, + {"create new channel with empty request", "", contentType, user.Email, http.StatusBadRequest}, + {"create new channel with missing content type", data, "", user.Email, http.StatusUnsupportedMediaType}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodPost, + url: fmt.Sprintf("%s/channels", ts.URL), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestUpdateChannel(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + updateData := toJSON(map[string]string{ + "name": "updated_channel", + }) + svc.Register(user) + id, _ := svc.CreateChannel(user.Email, channel) + + cases := []struct { + desc string + req string + id string + contentType string + auth string + status int + }{ + {"update existing channel", updateData, id, contentType, user.Email, http.StatusOK}, + {"update non-existing channel", updateData, wrongID, contentType, user.Email, http.StatusNotFound}, + {"update channel with invalid token", updateData, id, contentType, invalidEmail, http.StatusForbidden}, + {"update channel with invalid id", updateData, "1", contentType, user.Email, http.StatusNotFound}, + {"update channel with invalid data format", "}", id, contentType, user.Email, http.StatusBadRequest}, + {"update channel with empty JSON object", "{}", id, contentType, user.Email, http.StatusOK}, + {"update channel with empty request", "", id, contentType, user.Email, http.StatusBadRequest}, + {"update channel with missing content type", updateData, id, "", user.Email, http.StatusUnsupportedMediaType}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodPut, + url: fmt.Sprintf("%s/channels/%s", ts.URL, tc.id), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestViewChannel(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + svc.Register(user) + id, _ := svc.CreateChannel(user.Email, channel) + channel.ID = id + data := toJSON(channel) + + cases := []struct { + desc string + id string + auth string + status int + res string + }{ + {"view existing channel", id, user.Email, http.StatusOK, data}, + {"view non-existent channel", wrongID, user.Email, http.StatusNotFound, ""}, + {"view channel with invalid id", "1", user.Email, http.StatusNotFound, ""}, + {"view channel with invalid token", id, invalidEmail, http.StatusForbidden, ""}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodGet, + url: fmt.Sprintf("%s/channels/%s", ts.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + data, err := ioutil.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body := strings.Trim(string(data), "\n") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res, body, fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, body)) + } +} + +func TestListChannels(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + svc.Register(user) + channels := []manager.Channel{} + for i := 0; i < 10; i++ { + id, _ := svc.CreateChannel(user.Email, channel) + channel.ID = id + channels = append(channels, channel) + } + + cases := []struct { + desc string + auth string + status int + res []manager.Channel + }{ + {"get a list of channels", user.Email, http.StatusOK, channels}, + {"get a list of channels with invalid token", invalidEmail, http.StatusForbidden, nil}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodGet, + url: fmt.Sprintf("%s/channels", ts.URL), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var body map[string][]manager.Channel + json.NewDecoder(res.Body).Decode(&body) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.ElementsMatch(t, tc.res, body["channels"], fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, body["channels"])) + } +} + +func TestRemoveChannel(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + svc.Register(user) + id, _ := svc.CreateChannel(user.Email, channel) + channel.ID = id + + cases := []struct { + desc string + id string + auth string + status int + }{ + {"remove existing channel", channel.ID, user.Email, http.StatusNoContent}, + {"remove non-existent channel", channel.ID, user.Email, http.StatusNoContent}, + {"remove channel with invalid id", wrongID, user.Email, http.StatusNoContent}, + {"remove channel with invalid token", channel.ID, invalidEmail, http.StatusForbidden}, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodDelete, + url: fmt.Sprintf("%s/channels/%s", ts.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestConnect(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + clientID, _ := svc.AddClient(user.Email, client) + chanID, _ := svc.CreateChannel(user.Email, channel) + + otherUser := manager.User{Email: "other_user@example.com", Password: "password"} + svc.Register(otherUser) + otherClientID, _ := svc.AddClient(otherUser.Email, client) + otherChanID, _ := svc.CreateChannel(otherUser.Email, channel) + + cases := []struct { + desc string + chanID string + clientID string + auth string + status int + }{ + {"connect existing client to existing channel", chanID, clientID, user.Email, http.StatusOK}, + {"connect existing client to non-existent channel", wrongID, clientID, user.Email, http.StatusNotFound}, + {"connect client with invalid id to channel", chanID, "1", user.Email, http.StatusNotFound}, + {"connect client to channel with invalid id", "1", clientID, user.Email, http.StatusNotFound}, + {"connect existing client to existing channel with invalid token", chanID, clientID, invalidEmail, http.StatusForbidden}, + {"connect client from owner to channel of other user", otherChanID, clientID, user.Email, http.StatusNotFound}, + {"connect client from other user to owner's channel", chanID, otherClientID, user.Email, http.StatusNotFound}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodPut, + url: fmt.Sprintf("%s/channels/%s/clients/%s", ts.URL, tc.chanID, tc.clientID), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestDisconnnect(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + clientID, _ := svc.AddClient(user.Email, client) + chanID, _ := svc.CreateChannel(user.Email, channel) + svc.Connect(user.Email, chanID, clientID) + otherUser := manager.User{Email: "other_user@example.com", Password: "password"} + svc.Register(otherUser) + otherClientID, _ := svc.AddClient(otherUser.Email, client) + otherChanID, _ := svc.CreateChannel(otherUser.Email, channel) + svc.Connect(otherUser.Email, otherChanID, otherClientID) + + cases := []struct { + desc string + chanID string + clientID string + auth string + status int + }{ + {"disconnect connected client from channel", chanID, clientID, user.Email, http.StatusNoContent}, + {"disconnect non-connected client from channel", chanID, clientID, user.Email, http.StatusNotFound}, + {"disconnect non-existent client from channel", chanID, "1", user.Email, http.StatusNotFound}, + {"disconnect client from non-existent channel", "1", clientID, user.Email, http.StatusNotFound}, + {"disconnect client from channel with invalid token", chanID, clientID, invalidEmail, http.StatusForbidden}, + {"disconnect owner's client from someone elses channel", otherChanID, clientID, user.Email, http.StatusNotFound}, + {"disconnect other's client from owner's channel", chanID, otherClientID, user.Email, http.StatusNotFound}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodDelete, + url: fmt.Sprintf("%s/channels/%s/clients/%s", ts.URL, tc.chanID, tc.clientID), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + } +} + +func TestIdentity(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + clientID, _ := svc.AddClient(user.Email, client) + + cases := []struct { + desc string + key string + status int + clientID string + }{ + {"get client id using existing client key", clientID, http.StatusOK, clientID}, + {"get client id using non-existent client key", "", http.StatusForbidden, ""}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodGet, + url: fmt.Sprintf("%s/access-grant", ts.URL), + token: tc.key, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + clientID := res.Header.Get("X-client-id") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.clientID, clientID, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.clientID, clientID)) + } +} + +func TestCanAccess(t *testing.T) { + svc := newService() + ts := newServer(svc) + defer ts.Close() + cli := ts.Client() + + svc.Register(user) + clientID, _ := svc.AddClient(user.Email, client) + notConnectedClientID, _ := svc.AddClient(user.Email, client) + chanID, _ := svc.CreateChannel(user.Email, channel) + svc.Connect(user.Email, chanID, clientID) + + cases := []struct { + desc string + chanID string + clientKey string + status int + clientID string + }{ + {"check access to existing channel given connected client", chanID, clientID, http.StatusOK, clientID}, + {"check access to existing channel given not connected client", chanID, notConnectedClientID, http.StatusForbidden, ""}, + {"check access to existing channel given non-existent client", chanID, "invalid_token", http.StatusForbidden, ""}, + {"check access to non-existent channel given existing client", "invalid_token", clientID, http.StatusForbidden, ""}, + } + + for _, tc := range cases { + req := testRequest{ + client: cli, + method: http.MethodGet, + url: fmt.Sprintf("%s/channels/%s/access-grant", ts.URL, tc.chanID), + token: tc.clientKey, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + clientID := res.Header.Get("X-client-id") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.clientID, clientID, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.clientID, clientID)) + } +} diff --git a/manager/manager_test.go b/manager/manager_test.go index b014f86f..269a9789 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -13,14 +13,14 @@ const wrong string = "wrong-value" var ( user manager.User = manager.User{"user@example.com", "password"} - client manager.Client = manager.Client{ID: "1", Type: "app", Name: "test", Key: "1"} - channel manager.Channel = manager.Channel{ID: "1", Name: "test", Clients: []manager.Client{client}} + client manager.Client = manager.Client{Type: "app", Name: "test"} + channel manager.Channel = manager.Channel{Name: "test", Clients: []manager.Client{}} ) func newService() manager.Service { users := mocks.NewUserRepository() clients := mocks.NewClientRepository() - channels := mocks.NewChannelRepository() + channels := mocks.NewChannelRepository(clients) hasher := mocks.NewHasher() idp := mocks.NewIdentityProvider() @@ -89,7 +89,8 @@ func TestUpdateClient(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.AddClient(key, client) + clientId, _ := svc.AddClient(key, client) + client.ID = clientId cases := map[string]struct { client manager.Client @@ -111,7 +112,8 @@ func TestViewClient(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.AddClient(key, client) + clientId, _ := svc.AddClient(key, client) + client.ID = clientId cases := map[string]struct { id string @@ -152,7 +154,8 @@ func TestRemoveClient(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.AddClient(key, client) + clientId, _ := svc.AddClient(key, client) + client.ID = clientId cases := map[string]struct { id string @@ -195,7 +198,8 @@ func TestUpdateChannel(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.CreateChannel(key, channel) + chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId cases := map[string]struct { channel manager.Channel @@ -217,7 +221,8 @@ func TestViewChannel(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.CreateChannel(key, channel) + chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId cases := map[string]struct { id string @@ -258,7 +263,8 @@ func TestRemoveChannel(t *testing.T) { svc := newService() svc.Register(user) key, _ := svc.Login(user) - svc.CreateChannel(key, channel) + chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId cases := map[string]struct { id string @@ -283,7 +289,9 @@ func TestConnect(t *testing.T) { key, _ := svc.Login(user) clientId, _ := svc.AddClient(key, client) + client.ID = clientId chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId cases := map[string]struct { key string @@ -291,9 +299,9 @@ func TestConnect(t *testing.T) { clientId string err error }{ - "connect client": {key, chanId, clientId, nil}, - "connect client with wrong credentials": {wrong, chanId, clientId, manager.ErrUnauthorizedAccess}, - "connect client to non-existing channel": {key, wrong, clientId, manager.ErrNotFound}, + "connect client": {key, channel.ID, client.ID, nil}, + "connect client with wrong credentials": {wrong, channel.ID, client.ID, manager.ErrUnauthorizedAccess}, + "connect client to non-existing channel": {key, wrong, client.ID, manager.ErrNotFound}, } for desc, tc := range cases { @@ -308,7 +316,9 @@ func TestDisconnect(t *testing.T) { key, _ := svc.Login(user) clientId, _ := svc.AddClient(key, client) + client.ID = clientId chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId svc.Connect(key, chanId, clientId) @@ -319,11 +329,11 @@ func TestDisconnect(t *testing.T) { clientId string err error }{ - {"disconnect connected client", key, chanId, clientId, nil}, - {"disconnect disconnected client", key, chanId, clientId, manager.ErrNotFound}, - {"disconnect client with wrong credentials", wrong, chanId, clientId, manager.ErrUnauthorizedAccess}, - {"disconnect client from non-existing channel", key, wrong, clientId, manager.ErrNotFound}, - {"disconnect non-existing client", key, chanId, wrong, manager.ErrNotFound}, + {"disconnect connected client", key, channel.ID, client.ID, nil}, + {"disconnect disconnected client", key, channel.ID, client.ID, manager.ErrNotFound}, + {"disconnect client with wrong credentials", wrong, channel.ID, client.ID, manager.ErrUnauthorizedAccess}, + {"disconnect client from non-existing channel", key, wrong, client.ID, manager.ErrNotFound}, + {"disconnect non-existing client", key, channel.ID, wrong, manager.ErrNotFound}, } for _, tc := range cases { @@ -357,8 +367,13 @@ func TestCanAccess(t *testing.T) { svc.Register(user) key, _ := svc.Login(user) - svc.AddClient(key, client) - svc.CreateChannel(key, channel) + clientId, _ := svc.AddClient(key, client) + client.ID = clientId + client.Key = clientId + + channel.Clients = []manager.Client{client} + chanId, _ := svc.CreateChannel(key, channel) + channel.ID = chanId cases := map[string]struct { key string diff --git a/manager/mocks/channels.go b/manager/mocks/channels.go index 09c878dd..467d71c3 100644 --- a/manager/mocks/channels.go +++ b/manager/mocks/channels.go @@ -2,7 +2,6 @@ package mocks import ( "fmt" - "strconv" "strings" "sync" @@ -15,12 +14,14 @@ type channelRepositoryMock struct { mu sync.Mutex counter int channels map[string]manager.Channel + clients manager.ClientRepository } // NewChannelRepository creates in-memory channel repository. -func NewChannelRepository() manager.ChannelRepository { +func NewChannelRepository(clients manager.ClientRepository) manager.ChannelRepository { return &channelRepositoryMock{ channels: make(map[string]manager.Channel), + clients: clients, } } @@ -29,7 +30,7 @@ func (crm *channelRepositoryMock) Save(channel manager.Channel) (string, error) defer crm.mu.Unlock() crm.counter += 1 - channel.ID = strconv.Itoa(crm.counter) + channel.ID = fmt.Sprintf("123e4567-e89b-12d3-a456-%012d", crm.counter) crm.channels[key(channel.Owner, channel.ID)] = channel @@ -85,10 +86,11 @@ func (crm *channelRepositoryMock) Connect(owner, chanId, clientId string) error return err } - // Since the current implementation has no way to retrieve a real client - // instance, the implementation will assume client always exist and create - // a dummy one, containing only the provided ID. - channel.Clients = append(channel.Clients, manager.Client{ID: clientId}) + client, err := crm.clients.One(owner, clientId) + if err != nil { + return err + } + channel.Clients = append(channel.Clients, client) return crm.Update(channel) } diff --git a/manager/mocks/clients.go b/manager/mocks/clients.go index 357b8188..0c5e0702 100644 --- a/manager/mocks/clients.go +++ b/manager/mocks/clients.go @@ -2,7 +2,6 @@ package mocks import ( "fmt" - "strconv" "strings" "sync" @@ -29,7 +28,7 @@ func (crm *clientRepositoryMock) Id() string { defer crm.mu.Unlock() crm.counter += 1 - return strconv.Itoa(crm.counter) + return fmt.Sprintf("123e4567-e89b-12d3-a456-%012d", crm.counter) } func (crm *clientRepositoryMock) Save(client manager.Client) error { diff --git a/manager/swagger.yaml b/manager/swagger.yaml index f06c122f..02e3ed4c 100644 --- a/manager/swagger.yaml +++ b/manager/swagger.yaml @@ -30,6 +30,8 @@ paths: description: Failed due to malformed JSON. 409: description: Failed due to using an existing email address. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" /tokens: @@ -53,7 +55,12 @@ paths: $ref: "#/definitions/Token" 400: description: | - Failed due to malformed JSON or using an invalid credentials. + Failed due to malformed JSON. + 403: + description: | + Failed due to using invalid credentials. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" /clients: @@ -83,6 +90,8 @@ paths: description: Failed due to malformed JSON. 403: description: Missing or invalid access token provided. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" get: @@ -160,6 +169,8 @@ paths: description: Missing or invalid access token provided. 404: description: Client does not exist. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" delete: @@ -206,6 +217,8 @@ paths: description: Failed due to malformed JSON. 403: description: Missing or invalid access token provided. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" get: @@ -283,6 +296,8 @@ paths: description: Missing or invalid access token provided. 404: description: Channel does not exist. + 415: + description: Missing or invalid content type. 500: $ref: "#/responses/ServiceError" delete: