172 lines
3.9 KiB
Go
172 lines
3.9 KiB
Go
// Copyright (c) Mainflux
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package api
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-zoo/bone"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/mainflux/mainflux/pkg/errors"
|
|
"github.com/mainflux/mainflux/pkg/messaging"
|
|
"github.com/mainflux/mainflux/ws"
|
|
)
|
|
|
|
// channelPartRegExp matches a handshake request URI of the form
// /channels/<id>/messages[/<subtopic path>][?<query>].
// Capture group 1 is the channel ID, group 2 is the optional subtopic path
// (including its leading "/", up to but excluding any "?"), and group 3 is the
// optional query string including its leading "?".
var (
	channelPartRegExp = regexp.MustCompile(`^/channels/([\w\-]+)/messages(/[^?]*)?(\?.*)?$`)
)
|
|
|
|
// handshake returns the HTTP handler that upgrades an incoming request to a
// WebSocket connection bound to a channel (and optional subtopic).
//
// The request is decoded before the upgrade. Malformed or unauthenticated
// requests are therefore rejected with a plain HTTP status via encodeError.
// After a successful upgrade the new client is subscribed with svc.Subscribe.
// Two goroutines are then started: listen, which reads frames from the socket,
// and process, which publishes them through svc.
func handshake(svc ws.Service) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		req, err := decodeRequest(r)
		if err != nil {
			encodeError(w, err)
			return
		}

		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// On failure Upgrade has already replied to the client with an HTTP error.
			logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
			return
		}
		req.conn = conn
		client := ws.NewClient(conn)

		if err := svc.Subscribe(r.Context(), req.thingKey, req.chanID, req.subtopic, client); err != nil {
			// The connection is hijacked at this point, so no HTTP status can be
			// sent back; log the reason instead of dropping it silently.
			logger.Warn(fmt.Sprintf("Failed to subscribe to channel %s: %s", req.chanID, err))
			req.conn.Close()
			return
		}

		logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.chanID))
		msgs := make(chan []byte)

		// net/http cancels r.Context() as soon as this handler returns, but the
		// goroutines below outlive the handler. They therefore get a context
		// that is not tied to the request's lifetime.
		// NOTE(review): this also drops any request-scoped values from
		// r.Context() — confirm none are needed by Publish/Unsubscribe.
		ctx := context.Background()

		// Listen for messages received from the chan messages, and publish them to broker
		go process(ctx, svc, req, msgs)
		go listen(conn, msgs)
	}
}
|
|
|
|
// decodeRequest extracts the connection parameters from a handshake request.
//
// It reads three values:
//   - the thing key, from the Authorization header, falling back to the first
//     "authorization" query parameter;
//   - the channel ID, from the "chanID" route value;
//   - the optional subtopic that follows "/messages" in the request URI.
//
// It returns errUnauthorizedAccess when no non-empty key is supplied.
// It returns errors.ErrMalformedEntity when the URI does not match
// channelPartRegExp. If parseSubTopic rejects the subtopic, its error is
// returned as is.
func decodeRequest(r *http.Request) (connReq, error) {
	authKey := r.Header.Get("Authorization")
	if authKey == "" {
		authKeys := bone.GetQuery(r, "authorization")
		// An empty "?authorization=" value is treated the same as a missing one,
		// so the request is rejected before the connection is upgraded.
		if len(authKeys) == 0 || authKeys[0] == "" {
			logger.Debug("Missing authorization key.")
			return connReq{}, errUnauthorizedAccess
		}
		authKey = authKeys[0]
	}

	// FindStringSubmatch returns nil when there is no match; otherwise it returns
	// the full match plus one entry per capture group. Index 2 is the subtopic
	// group, so at least 3 entries are required before indexing it.
	channelParts := channelPartRegExp.FindStringSubmatch(r.RequestURI)
	if len(channelParts) < 3 {
		logger.Warn("Empty channel id or malformed url")
		return connReq{}, errors.ErrMalformedEntity
	}

	subtopic, err := parseSubTopic(channelParts[2])
	if err != nil {
		return connReq{}, err
	}

	return connReq{
		thingKey: authKey,
		chanID:   bone.GetValue(r, "chanID"),
		subtopic: subtopic,
	}, nil
}
|
|
|
|
func parseSubTopic(subtopic string) (string, error) {
|
|
if subtopic == "" {
|
|
return subtopic, nil
|
|
}
|
|
|
|
subtopic, err := url.QueryUnescape(subtopic)
|
|
if err != nil {
|
|
return "", errMalformedSubtopic
|
|
}
|
|
|
|
subtopic = strings.ReplaceAll(subtopic, "/", ".")
|
|
|
|
elems := strings.Split(subtopic, ".")
|
|
filteredElems := []string{}
|
|
for _, elem := range elems {
|
|
if elem == "" {
|
|
continue
|
|
}
|
|
|
|
if len(elem) > 1 && (strings.Contains(elem, "*") || strings.Contains(elem, ">")) {
|
|
return "", errMalformedSubtopic
|
|
}
|
|
|
|
filteredElems = append(filteredElems, elem)
|
|
}
|
|
|
|
subtopic = strings.Join(filteredElems, ".")
|
|
|
|
return subtopic, nil
|
|
}
|
|
|
|
// listen reads messages from conn and forwards each payload on msgs. It stops
// at the first read error and closes msgs on the way out, so a consumer
// ranging over the channel (such as process) terminates.
//
// websocket.IsUnexpectedCloseError is called with no expected codes, so any
// *websocket.CloseError is logged at debug level. Every other read error is
// logged as a warning.
func listen(conn *websocket.Conn, msgs chan<- []byte) {
	defer close(msgs)

	for {
		// Listen for message from the client, and push them to the msgs channel
		_, data, err := conn.ReadMessage()
		switch {
		case websocket.IsUnexpectedCloseError(err):
			logger.Debug(fmt.Sprintf("Closing WS connection: %s", err.Error()))
			return
		case err != nil:
			logger.Warn(fmt.Sprintf("Failed to read message: %s", err.Error()))
			return
		}
		msgs <- data
	}
}
|
|
|
|
// process publishes every payload received on msgs through svc.Publish. Each
// message is addressed to req's channel and subtopic and authenticated with
// req.thingKey. Once msgs is closed it calls svc.Unsubscribe and returns.
//
// A failed Publish is logged and the loop moves on to the next message.
// A failed Unsubscribe is logged and the websocket connection is closed here.
// On success the connection is not closed by this function.
// NOTE(review): presumably Unsubscribe closes it via the subscribed client —
// confirm in the ws.Service implementation.
func process(ctx context.Context, svc ws.Service, req connReq, msgs <-chan []byte) {
	for msg := range msgs {
		m := messaging.Message{
			Channel:  req.chanID,
			Subtopic: req.subtopic,
			Protocol: "websocket",
			Payload:  msg,
			Created:  time.Now().UnixNano(),
		}
		if err := svc.Publish(ctx, req.thingKey, &m); err != nil {
			// A single dropped message should not tear down the connection, but
			// the failure must not be invisible either.
			logger.Warn(fmt.Sprintf("Failed to publish message to channel %s: %s", req.chanID, err))
		}
	}

	if err := svc.Unsubscribe(ctx, req.thingKey, req.chanID, req.subtopic); err != nil {
		logger.Warn(fmt.Sprintf("Failed to unsubscribe from channel %s: %s", req.chanID, err))
		req.conn.Close()
	}
}
|
|
|
|
func encodeError(w http.ResponseWriter, err error) {
|
|
var statusCode int
|
|
|
|
switch err {
|
|
case ws.ErrEmptyID, ws.ErrEmptyTopic:
|
|
statusCode = http.StatusBadRequest
|
|
case errUnauthorizedAccess:
|
|
statusCode = http.StatusForbidden
|
|
case errMalformedSubtopic, errors.ErrMalformedEntity:
|
|
statusCode = http.StatusBadRequest
|
|
default:
|
|
statusCode = http.StatusNotFound
|
|
}
|
|
logger.Warn(fmt.Sprintf("Failed to authorize: %s", err.Error()))
|
|
w.WriteHeader(statusCode)
|
|
}
|