// Copyright (c) Mainflux // SPDX-License-Identifier: Apache-2.0 package api_test import ( "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/gorilla/websocket" httpmock "github.com/mainflux/mainflux/http/mocks" mflog "github.com/mainflux/mainflux/logger" "github.com/mainflux/mainflux/things/policies" "github.com/mainflux/mainflux/ws" "github.com/mainflux/mainflux/ws/api" "github.com/mainflux/mainflux/ws/mocks" "github.com/stretchr/testify/assert" ) const ( chanID = "30315311-56ba-484d-b500-c1e08305511f" id = "1" thingKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529" protocol = "ws" instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" ) var msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`) func newService(cc policies.AuthServiceClient) (ws.Service, mocks.MockPubSub) { pubsub := mocks.NewPubSub() return ws.New(cc, pubsub), pubsub } func newHTTPServer(svc ws.Service) *httptest.Server { logger := mflog.NewMock() mux := api.MakeHandler(svc, logger, instanceID) return httptest.NewServer(mux) } func makeURL(tsURL, chanID, subtopic, thingKey string, header bool) (string, error) { u, _ := url.Parse(tsURL) u.Scheme = protocol if chanID == "0" || chanID == "" { if header { return fmt.Sprintf("%s/channels/%s/messages", u, chanID), fmt.Errorf("invalid channel id") } return fmt.Sprintf("%s/channels/%s/messages?authorization=%s", u, chanID, thingKey), fmt.Errorf("invalid channel id") } subtopicPart := "" if subtopic != "" { subtopicPart = fmt.Sprintf("/%s", subtopic) } if header { return fmt.Sprintf("%s/channels/%s/messages%s", u, chanID, subtopicPart), nil } return fmt.Sprintf("%s/channels/%s/messages%s?authorization=%s", u, chanID, subtopicPart, thingKey), nil } func handshake(tsURL, chanID, subtopic, thingKey string, addHeader bool) (*websocket.Conn, *http.Response, error) { header := http.Header{} if addHeader { header.Add("Authorization", thingKey) } url, _ := makeURL(tsURL, chanID, subtopic, thingKey, addHeader) conn, res, errRet := websocket.DefaultDialer.Dial(url, header) return conn, res, errRet } func TestHandshake(t *testing.T) { thingsClient := httpmock.NewThingsClient(map[string]string{thingKey: chanID}) svc, _ := newService(thingsClient) ts := newHTTPServer(svc) defer ts.Close() cases := []struct { desc string chanID string subtopic string header bool thingKey string status int err error msg []byte }{ { desc: "connect and send message", chanID: id, subtopic: "", header: true, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: msg, }, { desc: "connect and send message with thingKey as query parameter", chanID: id, subtopic: "", header: false, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: msg, }, { desc: "connect and send message that cannot be published", chanID: id, subtopic: "", header: true, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: []byte{}, }, { desc: "connect and send message to subtopic", chanID: id, subtopic: "subtopic", header: true, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: msg, }, { desc: "connect and send message to nested subtopic", chanID: id, subtopic: "subtopic/nested", header: true, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: msg, }, { desc: "connect and send message to all subtopics", chanID: id, subtopic: ">", header: true, thingKey: thingKey, status: http.StatusSwitchingProtocols, msg: msg, }, { desc: "connect to empty channel", chanID: "", subtopic: "", header: true, thingKey: thingKey, status: http.StatusBadRequest, msg: []byte{}, }, { desc: "connect with empty thingKey", chanID: id, subtopic: "", header: true, thingKey: "", status: http.StatusForbidden, msg: []byte{}, }, { desc: "connect and send message to subtopic with invalid name", chanID: id, subtopic: "sub/a*b/topic", header: true, thingKey: thingKey, status: http.StatusBadRequest, msg: msg, }, } for _, tc := range cases { conn, res, err := handshake(ts.URL, tc.chanID, tc.subtopic, tc.thingKey, tc.header) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode)) if tc.status == http.StatusSwitchingProtocols { assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) err = conn.WriteMessage(websocket.TextMessage, tc.msg) assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) } } }