Mainflux.mainflux/ws/api/endpoints.go

172 lines
3.9 KiB
Go

// Copyright (c) Mainflux
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/go-zoo/bone"
"github.com/gorilla/websocket"
"github.com/mainflux/mainflux/pkg/errors"
"github.com/mainflux/mainflux/pkg/messaging"
"github.com/mainflux/mainflux/ws"
)
// channelPartRegExp matches a request URI of the form
// /channels/<chanID>/messages[/<subtopic>][?<query>].
//
// Submatch 1 is the channel ID (word characters and '-'), submatch 2 is the
// raw subtopic including its leading '/' (empty when absent), and submatch 3
// is the query string including the leading '?' (empty when absent).
var (
	channelPartRegExp = regexp.MustCompile(`^/channels/([\w\-]+)/messages(/[^?]*)?(\?.*)?$`)
)
// handshake returns the HTTP handler that upgrades a request on
// /channels/{chanID}/messages[/subtopic] to a WebSocket connection.
//
// The request is decoded first. If decoding fails, an HTTP error status is
// written via encodeError and no upgrade is attempted. After a successful
// upgrade the connection is wrapped in a ws client and subscribed to the
// requested channel/subtopic. Two goroutines are then started: listen reads
// frames from the socket into a channel, and process publishes them.
func handshake(svc ws.Service) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		req, err := decodeRequest(r)
		if err != nil {
			encodeError(w, err)
			return
		}

		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error()))
			return
		}
		req.conn = conn
		client := ws.NewClient(conn)

		// net/http cancels r.Context() as soon as this handler returns, which
		// happens right after the goroutines below are started. The
		// subscription and the publishing loop outlive the handler, so they
		// must not be bound to the request context.
		ctx := context.Background()

		if err := svc.Subscribe(ctx, req.thingKey, req.chanID, req.subtopic, client); err != nil {
			logger.Warn(fmt.Sprintf("Failed to subscribe to channel %s: %s", req.chanID, err.Error()))
			// Best-effort close; there is nothing more to do with this socket.
			req.conn.Close()
			return
		}

		logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.chanID))
		msgs := make(chan []byte)

		// Listen for messages received from the chan messages, and publish them to broker
		go process(ctx, svc, req, msgs)
		go listen(conn, msgs)
	}
}
// decodeRequest extracts the connection parameters from a handshake request.
//
// The thing key comes from the Authorization header. When the header is
// empty, the first "authorization" query parameter is used instead. The
// channel ID is the router's "chanID" path value. The subtopic is the part of
// the request URI after "/messages", normalized by parseSubTopic.
//
// It returns errUnauthorizedAccess when no key is supplied, and
// errors.ErrMalformedEntity when the URI does not match channelPartRegExp.
// It also propagates any error from parseSubTopic.
func decodeRequest(r *http.Request) (connReq, error) {
	thingKey := r.Header.Get("Authorization")
	if thingKey == "" {
		fromQuery := bone.GetQuery(r, "authorization")
		if len(fromQuery) == 0 {
			logger.Debug("Missing authorization key.")
			return connReq{}, errUnauthorizedAccess
		}
		thingKey = fromQuery[0]
	}

	// FindStringSubmatch yields nil on mismatch, or all four submatches.
	parts := channelPartRegExp.FindStringSubmatch(r.RequestURI)
	if len(parts) < 2 {
		logger.Warn("Empty channel id or malformed url")
		return connReq{}, errors.ErrMalformedEntity
	}

	subtopic, err := parseSubTopic(parts[2])
	if err != nil {
		return connReq{}, err
	}

	return connReq{
		thingKey: thingKey,
		chanID:   bone.GetValue(r, "chanID"),
		subtopic: subtopic,
	}, nil
}
// parseSubTopic converts the raw subtopic suffix of a request path (for
// example "/temp/room1") into a dot-separated subtopic ("temp.room1").
//
// The input is query-unescaped, '/' separators become '.', and empty segments
// are dropped. An empty input yields an empty subtopic.
//
// errMalformedSubtopic is returned when:
//   - unescaping fails;
//   - a segment combines a wildcard ('*' or '>') with other characters;
//   - a segment contains whitespace.
func parseSubTopic(subtopic string) (string, error) {
	if subtopic == "" {
		return subtopic, nil
	}

	subtopic, err := url.QueryUnescape(subtopic)
	if err != nil {
		return "", errMalformedSubtopic
	}
	subtopic = strings.ReplaceAll(subtopic, "/", ".")

	// Whitespace can appear here from a literal '+' (QueryUnescape maps it
	// to a space) or from escapes such as %20 or %0D%0A.
	const whitespace = " \t\r\n\v\f"

	elems := strings.Split(subtopic, ".")
	filteredElems := make([]string, 0, len(elems))
	for _, elem := range elems {
		if elem == "" {
			continue
		}
		// Wildcards are only accepted as a whole segment.
		if len(elem) > 1 && strings.ContainsAny(elem, "*>") {
			return "", errMalformedSubtopic
		}
		// The '*'/'>' rules above follow NATS-style subjects, where
		// whitespace delimits protocol fields. A segment containing it would
		// split the subject (e.g. into subject + reply/queue), so reject it.
		if strings.ContainsAny(elem, whitespace) {
			return "", errMalformedSubtopic
		}
		filteredElems = append(filteredElems, elem)
	}

	return strings.Join(filteredElems, "."), nil
}
// listen forwards every payload read from conn into msgs until a read fails.
// It then closes msgs, so the consumer ranging over it terminates.
//
// IsUnexpectedCloseError is called with no expected codes, so it matches any
// websocket close error. Those are logged at debug level. Any other read
// error is logged as a warning. The frame type is discarded, so text and
// binary payloads are forwarded alike.
func listen(conn *websocket.Conn, msgs chan<- []byte) {
	defer close(msgs)

	for {
		_, payload, err := conn.ReadMessage()
		switch {
		case websocket.IsUnexpectedCloseError(err):
			logger.Debug(fmt.Sprintf("Closing WS connection: %s", err.Error()))
			return
		case err != nil:
			logger.Warn(fmt.Sprintf("Failed to read message: %s", err.Error()))
			return
		}
		msgs <- payload
	}
}
// process publishes every payload received on msgs as a messaging.Message on
// req's channel and subtopic. Each call to svc.Publish passes req.thingKey,
// presumably so the service can authorize the sender (verify in ws.Service).
//
// Once msgs is closed, the connection is unsubscribed. If unsubscribing
// fails, the WebSocket connection is closed here.
func process(ctx context.Context, svc ws.Service, req connReq, msgs <-chan []byte) {
	for payload := range msgs {
		m := messaging.Message{
			Channel:  req.chanID,
			Subtopic: req.subtopic,
			Protocol: "websocket",
			Payload:  payload,
			Created:  time.Now().UnixNano(),
		}
		// A failed publish does not tear down the connection, but it must not
		// disappear silently either.
		if err := svc.Publish(ctx, req.thingKey, &m); err != nil {
			logger.Warn(fmt.Sprintf("Failed to publish message to channel %s: %s", req.chanID, err.Error()))
		}
	}

	if err := svc.Unsubscribe(ctx, req.thingKey, req.chanID, req.subtopic); err != nil {
		logger.Warn(fmt.Sprintf("Failed to unsubscribe from channel %s: %s", req.chanID, err.Error()))
		req.conn.Close()
	}
}
// encodeError logs err and writes the matching HTTP status code to w (no
// body). In this file it rejects a handshake request before any WebSocket
// upgrade takes place.
//
// Status mapping:
//   - malformed input (empty ID/topic, malformed entity or subtopic) -> 400;
//   - errUnauthorizedAccess -> 403;
//   - anything else -> 404.
//
// Errors are compared by identity, so wrapped errors fall through to 404.
func encodeError(w http.ResponseWriter, err error) {
	var statusCode int
	switch err {
	case ws.ErrEmptyID, ws.ErrEmptyTopic:
		statusCode = http.StatusBadRequest
	case errUnauthorizedAccess:
		statusCode = http.StatusForbidden
	case errMalformedSubtopic, errors.ErrMalformedEntity:
		statusCode = http.StatusBadRequest
	default:
		statusCode = http.StatusNotFound
	}
	// Report the actual rejection: not every error here is an authorization
	// failure (e.g. a malformed URL or subtopic yields 400).
	logger.Warn(fmt.Sprintf("Rejected WS connection request with status %d: %s", statusCode, err.Error()))
	w.WriteHeader(statusCode)
}