// Copyright (c) Mainflux // SPDX-License-Identifier: Apache-2.0 package api import ( "context" "fmt" "net/http" "net/url" "regexp" "strings" "time" "github.com/go-zoo/bone" "github.com/gorilla/websocket" "github.com/mainflux/mainflux/pkg/errors" "github.com/mainflux/mainflux/pkg/messaging" "github.com/mainflux/mainflux/ws" ) var channelPartRegExp = regexp.MustCompile(`^/channels/([\w\-]+)/messages(/[^?]*)?(\?.*)?$`) func handshake(svc ws.Service) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() req, err := decodeRequest(r) if err != nil { encodeError(w, err) return } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error())) return } req.conn = conn client := ws.NewClient(conn) if err := svc.Subscribe(ctx, req.thingKey, req.chanID, req.subtopic, client); err != nil { req.conn.Close() return } logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.chanID)) msgs := make(chan []byte) // Listen for messages received from the chan messages, and publish them to broker go process(ctx, svc, req, msgs) go listen(conn, msgs) } } func decodeRequest(r *http.Request) (connReq, error) { authKey := r.Header.Get("Authorization") if authKey == "" { authKeys := bone.GetQuery(r, "authorization") if len(authKeys) == 0 { logger.Debug("Missing authorization key.") return connReq{}, errUnauthorizedAccess } authKey = authKeys[0] } chanID := bone.GetValue(r, "chanID") req := connReq{ thingKey: authKey, chanID: chanID, } channelParts := channelPartRegExp.FindStringSubmatch(r.RequestURI) if len(channelParts) < 2 { logger.Warn("Empty channel id or malformed url") return connReq{}, errors.ErrMalformedEntity } subtopic, err := parseSubTopic(channelParts[2]) if err != nil { return connReq{}, err } req.subtopic = subtopic return req, nil } func parseSubTopic(subtopic string) (string, error) { if subtopic == "" { return subtopic, nil } subtopic, err := url.QueryUnescape(subtopic) if err != nil { return "", errMalformedSubtopic } subtopic = strings.ReplaceAll(subtopic, "/", ".") elems := strings.Split(subtopic, ".") filteredElems := []string{} for _, elem := range elems { if elem == "" { continue } if len(elem) > 1 && (strings.Contains(elem, "*") || strings.Contains(elem, ">")) { return "", errMalformedSubtopic } filteredElems = append(filteredElems, elem) } subtopic = strings.Join(filteredElems, ".") return subtopic, nil } func listen(conn *websocket.Conn, msgs chan<- []byte) { for { // Listen for message from the client, and push them to the msgs channel _, payload, err := conn.ReadMessage() if websocket.IsUnexpectedCloseError(err) { logger.Debug(fmt.Sprintf("Closing WS connection: %s", err.Error())) close(msgs) return } if err != nil { logger.Warn(fmt.Sprintf("Failed to read message: %s", err.Error())) close(msgs) return } msgs <- payload } } func process(ctx context.Context, svc ws.Service, req connReq, msgs <-chan []byte) { for msg := range msgs { m := messaging.Message{ Channel: req.chanID, Subtopic: req.subtopic, Protocol: "websocket", Payload: msg, Created: time.Now().UnixNano(), } _ = svc.Publish(ctx, req.thingKey, &m) } if err := svc.Unsubscribe(ctx, req.thingKey, req.chanID, req.subtopic); err != nil { req.conn.Close() } } func encodeError(w http.ResponseWriter, err error) { var statusCode int switch err { case ws.ErrEmptyID, ws.ErrEmptyTopic: statusCode = http.StatusBadRequest case errUnauthorizedAccess: statusCode = http.StatusForbidden case errMalformedSubtopic, errors.ErrMalformedEntity: statusCode = http.StatusBadRequest default: statusCode = http.StatusNotFound } logger.Warn(fmt.Sprintf("Failed to authorize: %s", err.Error())) w.WriteHeader(statusCode) }