Mainflux.mainflux/ws/api/endpoint_test.go

189 lines
4.8 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package api_test
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/websocket"
httpmock "github.com/mainflux/mainflux/http/mocks"
mflog "github.com/mainflux/mainflux/logger"
"github.com/mainflux/mainflux/things/policies"
"github.com/mainflux/mainflux/ws"
"github.com/mainflux/mainflux/ws/api"
"github.com/mainflux/mainflux/ws/mocks"
"github.com/stretchr/testify/assert"
)
// Identifiers used by the handshake tests.
const (
	// chanID is the channel ID that TestHandshake maps to thingKey when
	// building the mock things client.
	chanID = "30315311-56ba-484d-b500-c1e08305511f"
	// id is the channel ID the test cases put in the request URL.
	id = "1"
	// thingKey is the thing key used to authorize the handshake.
	thingKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
	// instanceID is handed to api.MakeHandler when building the test server.
	instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
)

// protocol is the URL scheme swapped in for the test server's scheme when dialing.
const protocol = "ws"

// msg is the payload written over the socket after a successful upgrade
// (it looks like a SenML record — the content is not inspected by these tests).
var msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`)
// newService wires a ws.Service to the given auth client and a mock pub/sub.
// The mock is returned alongside the service so callers can reach it directly.
func newService(cc policies.AuthServiceClient) (ws.Service, mocks.MockPubSub) {
	ps := mocks.NewPubSub()
	svc := ws.New(cc, ps)
	return svc, ps
}
// newHTTPServer starts an httptest server that serves the ws API handler for
// svc, using a mock logger and the package-level instanceID. The caller is
// responsible for closing the returned server.
func newHTTPServer(svc ws.Service) *httptest.Server {
	handler := api.MakeHandler(svc, mflog.NewMock(), instanceID)
	return httptest.NewServer(handler)
}
// makeURL builds the WebSocket URL of the channel messages endpoint on the
// test server at tsURL.
//
// The scheme of tsURL is replaced with protocol. A non-empty subtopic is
// appended as an extra path suffix and may itself contain slashes. Unless
// header is true, thingKey is sent as the query-escaped "authorization"
// parameter. When header is true, the caller is expected to send the key in
// the Authorization header.
//
// For an empty or "0" chanID, the URL (without subtopic) is still returned,
// together with an error. This lets callers deliberately exercise the server's
// handling of an invalid channel. An unparsable tsURL yields "" and an error.
func makeURL(tsURL, chanID, subtopic, thingKey string, header bool) (string, error) {
	u, err := url.Parse(tsURL)
	if err != nil {
		// Previously ignored: a nil u would panic on the Scheme assignment below.
		return "", fmt.Errorf("parsing test server URL %q: %w", tsURL, err)
	}
	u.Scheme = protocol

	rawURL := fmt.Sprintf("%s/channels/%s/messages", u, chanID)

	var chanErr error
	switch {
	case chanID == "0" || chanID == "":
		// The subtopic is intentionally dropped for invalid channels.
		chanErr = fmt.Errorf("invalid channel id")
	case subtopic != "":
		rawURL += "/" + subtopic
	}

	if !header {
		// Escape the key so reserved characters cannot corrupt the query string.
		rawURL += "?authorization=" + url.QueryEscape(thingKey)
	}
	return rawURL, chanErr
}
// handshake performs a WebSocket dial against the channel messages endpoint of
// the test server. When addHeader is set, thingKey travels in the
// Authorization header; otherwise makeURL embeds it as a query parameter.
// It returns exactly what websocket.DefaultDialer.Dial returns.
func handshake(tsURL, chanID, subtopic, thingKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
	hdr := make(http.Header)
	if addHeader {
		hdr.Add("Authorization", thingKey)
	}

	// makeURL reports an error for an empty or "0" channel but still returns a
	// URL. The error is deliberately ignored so such requests reach the server
	// and its rejection can be asserted by the caller.
	wsURL, _ := makeURL(tsURL, chanID, subtopic, thingKey, addHeader)
	return websocket.DefaultDialer.Dial(wsURL, hdr)
}
// TestHandshake dials the ws endpoint for a table of channel, subtopic and
// thing-key combinations and checks the HTTP status of the upgrade response
// together with the error returned by the dialer.
//
// When the upgrade succeeds (101), it also writes tc.msg over the connection
// and expects the write to succeed. Every connection and response body is
// closed at the end of its subtest.
func TestHandshake(t *testing.T) {
	thingsClient := httpmock.NewThingsClient(map[string]string{thingKey: chanID})
	svc, _ := newService(thingsClient)
	ts := newHTTPServer(svc)
	defer ts.Close()

	cases := []struct {
		desc     string
		chanID   string
		subtopic string
		header   bool
		thingKey string
		status   int
		err      error
		msg      []byte
	}{
		{
			desc:     "connect and send message",
			chanID:   id,
			subtopic: "",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      msg,
		},
		{
			desc:     "connect and send message with thingKey as query parameter",
			chanID:   id,
			subtopic: "",
			header:   false,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      msg,
		},
		{
			desc:     "connect and send message that cannot be published",
			chanID:   id,
			subtopic: "",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      []byte{},
		},
		{
			desc:     "connect and send message to subtopic",
			chanID:   id,
			subtopic: "subtopic",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      msg,
		},
		{
			desc:     "connect and send message to nested subtopic",
			chanID:   id,
			subtopic: "subtopic/nested",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      msg,
		},
		{
			desc:     "connect and send message to all subtopics",
			chanID:   id,
			subtopic: ">",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusSwitchingProtocols,
			msg:      msg,
		},
		{
			desc:     "connect to empty channel",
			chanID:   "",
			subtopic: "",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusBadRequest,
			err:      websocket.ErrBadHandshake,
			msg:      []byte{},
		},
		{
			desc:     "connect with empty thingKey",
			chanID:   id,
			subtopic: "",
			header:   true,
			thingKey: "",
			status:   http.StatusForbidden,
			err:      websocket.ErrBadHandshake,
			msg:      []byte{},
		},
		{
			desc:     "connect and send message to subtopic with invalid name",
			chanID:   id,
			subtopic: "sub/a*b/topic",
			header:   true,
			thingKey: thingKey,
			status:   http.StatusBadRequest,
			err:      websocket.ErrBadHandshake,
			msg:      msg,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			conn, res, err := handshake(ts.URL, tc.chanID, tc.subtopic, tc.thingKey, tc.header)
			if conn != nil {
				defer conn.Close()
			}
			// A transport-level dial failure yields a nil response. Stop this
			// subtest rather than panic on res.StatusCode.
			if !assert.NotNil(t, res, fmt.Sprintf("%s: expected a handshake response, got error %v\n", tc.desc, err)) {
				return
			}
			if res.Body != nil {
				defer res.Body.Close()
			}

			assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
			// gorilla/websocket reports a rejected upgrade as ErrBadHandshake.
			assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %v got %v\n", tc.desc, tc.err, err))

			if tc.status != http.StatusSwitchingProtocols || conn == nil {
				return
			}
			err = conn.WriteMessage(websocket.TextMessage, tc.msg)
			assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err))
		})
	}
}