103 lines
2.3 KiB
Go
103 lines
2.3 KiB
Go
// Copyright (c) Mainflux
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package ws_test
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/mainflux/mainflux/ws"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
const expectedCount = uint64(1)
|
|
|
|
var (
|
|
msgChan = make(chan []byte)
|
|
c *ws.Client
|
|
count uint64
|
|
|
|
upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
)
|
|
|
|
func handler(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
for {
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
atomic.AddUint64(&count, 1)
|
|
msgChan <- message
|
|
}
|
|
}
|
|
|
|
func TestHandle(t *testing.T) {
|
|
s := httptest.NewServer(http.HandlerFunc(handler))
|
|
defer s.Close()
|
|
|
|
// Convert http://127.0.0.1 to ws://127.0.0.1
|
|
u := strings.Replace(s.URL, "http", "ws", 1)
|
|
|
|
// Connect to the server
|
|
wsConn, _, err := websocket.DefaultDialer.Dial(u, nil)
|
|
if err != nil {
|
|
t.Fatalf("%v", err)
|
|
}
|
|
defer wsConn.Close()
|
|
|
|
c = ws.NewClient(wsConn)
|
|
|
|
cases := []struct {
|
|
desc string
|
|
publisher string
|
|
expectedPayload []byte
|
|
expectMsg bool
|
|
}{
|
|
{
|
|
desc: "handling with different id from ws.Client",
|
|
publisher: msg.Publisher,
|
|
expectedPayload: msg.Payload,
|
|
expectMsg: true,
|
|
},
|
|
{
|
|
desc: "handling with same id as ws.Client (empty by default) drops message",
|
|
publisher: "",
|
|
expectedPayload: []byte{},
|
|
expectMsg: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
msg.Publisher = tc.publisher
|
|
err = c.Handle(&msg)
|
|
assert.Nil(t, err, fmt.Sprintf("expected nil error from handle, got: %s", err))
|
|
receivedMsg := []byte{}
|
|
switch tc.expectMsg {
|
|
case true:
|
|
rec := <-msgChan // Wait for the message to be received.
|
|
receivedMsg = rec
|
|
case false:
|
|
time.Sleep(100 * time.Millisecond) // Give time to server to process c.Handle call.
|
|
}
|
|
assert.Equal(t, tc.expectedPayload, receivedMsg, fmt.Sprintf("%s: expected %+v, got %+v", tc.desc, &msg, receivedMsg))
|
|
}
|
|
c := atomic.LoadUint64(&count)
|
|
assert.Equal(t, expectedCount, c, fmt.Sprintf("expected message count %d, got %d", expectedCount, c))
|
|
}
|