// Copyright (c) Mainflux // SPDX-License-Identifier: Apache-2.0 package bench import ( "crypto/rsa" "crypto/tls" "crypto/x509" "errors" "fmt" "log" "net" "strings" "sync" "time" mqtt "github.com/eclipse/paho.mqtt.golang" ) // Set default ping timeout to large value, so that ping // won't fail in the case of broker pingresp delay. const pingTimeout = 10000 // Client - represents mqtt client. type Client struct { ID string BrokerURL string BrokerUser string BrokerPass string MsgTopic string MsgSize int MsgCount int MsgQoS byte Quiet bool timeout int mqttClient *mqtt.Client MTLS bool SkipTLSVer bool Retain bool CA []byte ClientCert tls.Certificate ClientKey *rsa.PrivateKey SendMsg handler } type message struct { ID string `json:"id"` Topic string `json:"topic"` QoS byte `json:"qos"` Payload []byte `json:"payload"` Sent time.Time `json:"sent"` Delivered time.Time `json:"delivered"` Error bool `json:"error"` } type handler func(*message) ([]byte, error) func (c *Client) publish(r chan *runResults, errChan chan<- error) { res := &runResults{} times := make([]*float64, c.MsgCount) start := time.Now() if c.connect() != nil { flushMessages := make([]message, c.MsgCount) for i, m := range flushMessages { m.Error = true times[i] = calcMsgRes(&m, res) } r <- calcRes(res, start, arr(times)) } if !c.Quiet { log.Printf("Client %v is connected to the broker %v\n", c.ID, c.BrokerURL) } wg := sync.WaitGroup{} mu := sync.Mutex{} // Use a single message. m := message{ Topic: c.MsgTopic, QoS: c.MsgQoS, ID: c.ID, Sent: time.Now(), } payload, err := c.SendMsg(&m) if err != nil { errChan <- fmt.Errorf("failed to marshal payload - %s", err.Error()) } for i := 0; i < c.MsgCount; i++ { wg.Add(1) go func(mut *sync.Mutex, wg *sync.WaitGroup, i int, m message) { defer wg.Done() m.Sent = time.Now() token := (*c.mqttClient).Publish(m.Topic, m.QoS, c.Retain, payload) if !token.WaitTimeout(time.Second*time.Duration(c.timeout)) || token.Error() != nil || !(*c.mqttClient).IsConnectionOpen() { m.Error = true mu.Lock() times[i] = calcMsgRes(&m, res) mu.Unlock() return } m.Delivered = time.Now() m.Error = false mu.Lock() times[i] = calcMsgRes(&m, res) mu.Unlock() if !c.Quiet && i > 0 && i%100 == 0 { log.Printf("Client %v published %v messages and keeps publishing...\n", c.ID, i) } }(&mu, &wg, i, m) } wg.Wait() r <- calcRes(res, start, arr(times)) } func (c *Client) connect() error { opts := mqtt.NewClientOptions(). AddBroker(c.BrokerURL). SetClientID(c.ID). SetCleanSession(false). SetAutoReconnect(false). SetOnConnectHandler(c.connected). SetConnectionLostHandler(c.connLost). SetPingTimeout(time.Second * pingTimeout). SetAutoReconnect(true). SetCleanSession(false) if c.BrokerUser != "" && c.BrokerPass != "" { opts.SetUsername(c.BrokerUser) opts.SetPassword(c.BrokerPass) } if c.MTLS { cfg := &tls.Config{ InsecureSkipVerify: c.SkipTLSVer, } if c.CA != nil { cfg.RootCAs = x509.NewCertPool() cfg.RootCAs.AppendCertsFromPEM(c.CA) } if c.ClientCert.Certificate != nil { cfg.Certificates = []tls.Certificate{c.ClientCert} } opts.SetTLSConfig(cfg) opts.SetProtocolVersion(4) } client := mqtt.NewClient(opts) token := client.Connect() token.Wait() c.mqttClient = &client if token.Error() != nil { log.Printf("Client %v had error connecting to the broker: %s\n", c.ID, token.Error().Error()) return token.Error() } return nil } func checkConnection(broker string, timeoutSecs int) error { s := strings.Split(broker, ":") if len(s) != 3 { return errors.New("wrong host address format") } network := s[0] host := strings.Trim(s[1], "/") port := s[2] log.Println("Testing connection...") conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%s", host, port), time.Duration(timeoutSecs)*time.Second) conClose := func() { if conn != nil { log.Println("Closing testing connection...") conn.Close() } } defer conClose() if err, ok := err.(*net.OpError); ok && err.Timeout() { return fmt.Errorf("timeout error: %s", err.Error()) } if err != nil { return fmt.Errorf("error: %s", err.Error()) } log.Printf("Connection to %s://%s:%s looks OK\n", network, host, port) return nil } func arr(a []*float64) []float64 { ret := []float64{} for _, v := range a { if v != nil { ret = append(ret, *v) } } if len(ret) == 0 { ret = append(ret, 0) } return ret } func (c *Client) connected(client mqtt.Client) { if !c.Quiet { log.Printf("Client %v is connected to the broker %v\n", c.ID, c.BrokerURL) } } func (c *Client) connLost(client mqtt.Client, reason error) { log.Printf("Client %v had lost connection to the broker: %s\n", c.ID, reason.Error()) }