Review the goroutines in Loop

This commit is contained in:
Quentin Perez 2015-10-23 15:16:42 +02:00
parent 49e191195e
commit c318d94118
1 changed files with 97 additions and 47 deletions

View File

@ -1,10 +1,10 @@
package gottyclient package gottyclient
import ( import (
"bufio"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -15,10 +15,10 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"unicode/utf8"
"unsafe" "unsafe"
"github.com/moul/gotty-client/vendor/github.com/Sirupsen/logrus" "github.com/moul/gotty-client/vendor/github.com/Sirupsen/logrus"
"github.com/moul/gotty-client/vendor/github.com/creack/goselect"
"github.com/moul/gotty-client/vendor/github.com/gorilla/websocket" "github.com/moul/gotty-client/vendor/github.com/gorilla/websocket"
"github.com/moul/gotty-client/vendor/golang.org/x/crypto/ssh/terminal" "github.com/moul/gotty-client/vendor/golang.org/x/crypto/ssh/terminal"
) )
@ -200,11 +200,19 @@ func (c *Client) Loop() error {
} }
} }
var wg sync.WaitGroup
quit := make(chan struct{})
done := make(chan bool) done := make(chan bool)
go c.readLoop(done)
go c.writeLoop(done) wg.Add(1)
go c.termsizeLoop(done) go c.termsizeLoop(quit, &wg)
wg.Add(1)
go c.readLoop(done, quit, &wg)
wg.Add(1)
go c.writeLoop(done, quit, &wg)
<-done <-done
close(quit)
wg.Wait()
return nil return nil
} }
@ -216,9 +224,11 @@ type winsize struct {
y uint16 y uint16
} }
func (c *Client) termsizeLoop(done chan bool) { func (c *Client) termsizeLoop(quit chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGWINCH) signal.Notify(ch, syscall.SIGWINCH)
defer signal.Reset(syscall.SIGWINCH)
ws := winsize{} ws := winsize{}
for { for {
@ -235,66 +245,106 @@ func (c *Client) termsizeLoop(done chan bool) {
if err != nil { if err != nil {
logrus.Warnf("ws.WriteMessage failed: %v", err) logrus.Warnf("ws.WriteMessage failed: %v", err)
} }
select {
<-ch case <-quit:
return
case <-ch:
}
} }
} }
func (c *Client) writeLoop(done chan bool) { type exposeFd interface {
Fd() uintptr
}
func (c *Client) writeLoop(done chan bool, quit chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
buff := make([]byte, 128)
oldState, err := terminal.MakeRaw(0) oldState, err := terminal.MakeRaw(0)
if err == nil { if err == nil {
defer terminal.Restore(0, oldState) defer terminal.Restore(0, oldState)
} }
reader := bufio.NewReader(os.Stdin) rdfs := &goselect.FDSet{}
reader := io.Reader(os.Stdin)
for { for {
x, size, err := reader.ReadRune() rdfs.Zero()
if size <= 0 || err != nil { rdfs.Set(reader.(exposeFd).Fd())
done <- true err := goselect.Select(1, rdfs, nil, nil, 50*time.Millisecond)
return
}
p := make([]byte, size)
utf8.EncodeRune(p, x)
err = c.write(append([]byte("0"), p...))
if err != nil { if err != nil {
done <- true done <- true
return return
} }
if rdfs.IsSet(reader.(exposeFd).Fd()) {
size, err := reader.Read(buff)
if size <= 0 || err != nil {
done <- true
return
}
data := buff[:size]
err = c.write(append([]byte("0"), data...))
if err != nil {
done <- true
return
}
}
select {
case <-quit:
return
default:
break
}
} }
} }
func (c *Client) readLoop(done chan bool) { func (c *Client) readLoop(done chan bool, quit chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
type MessageNonBlocking struct {
Data []byte
Err error
}
msgChan := make(chan MessageNonBlocking)
for { for {
_, data, err := c.Conn.ReadMessage() go func() {
if err != nil { _, data, err := c.Conn.ReadMessage()
done <- true msgChan <- MessageNonBlocking{Data: data, Err: err}
logrus.Warnf("c.Conn.ReadMessage: %v", err) }()
select {
case <-quit:
return return
} case msg := <-msgChan:
if len(data) == 0 { if msg.Err != nil {
done <- true done <- true
logrus.Warnf("An error has occured") logrus.Warnf("c.Conn.ReadMessage: %v", msg.Err)
return return
} }
switch data[0] { if len(msg.Data) == 0 {
case '0': // data done <- true
buf, err := base64.StdEncoding.DecodeString(string(data[1:])) logrus.Warnf("An error has occured")
if err != nil { return
logrus.Warnf("Invalid base64 content: %q", data[1:]) }
switch msg.Data[0] {
case '0': // data
buf, err := base64.StdEncoding.DecodeString(string(msg.Data[1:]))
if err != nil {
logrus.Warnf("Invalid base64 content: %q", msg.Data[1:])
}
fmt.Print(string(buf))
case '1': // pong
case '2': // new title
newTitle := string(msg.Data[1:])
fmt.Printf("\033]0;%s\007", newTitle)
case '3': // json prefs
logrus.Debugf("Unhandled protocol message: json pref: %s", string(msg.Data[1:]))
case '4': // autoreconnect
logrus.Debugf("Unhandled protocol message: autoreconnect: %s", string(msg.Data))
default:
logrus.Warnf("Unhandled protocol message: %s", string(msg.Data))
} }
fmt.Print(string(buf))
case '1': // pong
case '2': // new title
newTitle := string(data[1:])
fmt.Printf("\033]0;%s\007", newTitle)
case '3': // json prefs
logrus.Debugf("Unhandled protocol message: json pref: %s", string(data[1:]))
case '4': // autoreconnect
logrus.Debugf("Unhandled protocol message: autoreconnect: %s", string(data))
default:
logrus.Warnf("Unhandled protocol message: %s", string(data))
} }
} }
} }