From 7793f5554518d785d1b69b3818ece4d001972907 Mon Sep 17 00:00:00 2001 From: fatedier Date: Fri, 10 Aug 2018 11:43:08 +0800 Subject: [PATCH] websocket: update muxer for websocket --- models/config/client_common.go | 7 +- server/service.go | 53 ++++---------- utils/net/conn.go | 108 +++++++++++++++++---------- utils/net/websocket.go | 130 ++++++++++++++------------------- 4 files changed, 143 insertions(+), 155 deletions(-) diff --git a/models/config/client_common.go b/models/config/client_common.go index c1d61cb4..5dc49aa0 100644 --- a/models/config/client_common.go +++ b/models/config/client_common.go @@ -186,9 +186,10 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c } if tmpStr, ok = conf.Get("common", "protocol"); ok { - // Now it only support tcp and kcp. - if tmpStr != "kcp" && tmpStr != "websocket" { - tmpStr = "tcp" + // Now it only support tcp and kcp and websocket. + if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" { + err = fmt.Errorf("Parse conf error: invalid protocol") + return } cfg.Protocol = tmpStr } diff --git a/server/service.go b/server/service.go index dcb7a2ba..024b6835 100644 --- a/server/service.go +++ b/server/service.go @@ -15,11 +15,11 @@ package server import ( + "bytes" "fmt" "io/ioutil" "net" "net/http" - "strings" "time" "github.com/fatedier/frp/assets" @@ -139,6 +139,13 @@ func NewService() (svr *Service, err error) { log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort) } + // Listen for accepting connections from client using websocket protocol. + websocketPrefix := []byte("GET /%23frp") + websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool { + return bytes.Equal(data, websocketPrefix) + }) + svr.websocketListener = frpNet.NewWebsocketListener(websocketLn) + // Create http vhost muxer. if cfg.VhostHttpPort > 0 { rp := vhost.NewHttpReverseProxy() @@ -150,7 +157,9 @@ func NewService() (svr *Service, err error) { Handler: rp, } var l net.Listener - if !httpMuxOn { + if httpMuxOn { + l = svr.muxer.ListenHttp(1) + } else { l, err = net.Listen("tcp", address) if err != nil { err = fmt.Errorf("Create vhost http listener error, %v", err) @@ -165,7 +174,7 @@ func NewService() (svr *Service, err error) { if cfg.VhostHttpsPort > 0 { var l net.Listener if httpsMuxOn { - l = svr.muxer.ListenHttps(0) + l = svr.muxer.ListenHttps(1) } else { l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort)) if err != nil { @@ -205,37 +214,6 @@ func NewService() (svr *Service, err error) { log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort) } - if !httpMuxOn { - svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil) - return - } - - // server := &http.Server{} - if httpMuxOn { - rp := svr.httpReverseProxy - svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), - func(w http.ResponseWriter, req *http.Request) bool { - domain := getHostFromAddr(req.Host) - location := req.URL.Path - headers := rp.GetHeaders(domain, location) - if headers == nil { - return true - } - rp.ServeHTTP(w, req) - return false - }) - } - - return -} - -func getHostFromAddr(addr string) (host string) { - strs := strings.Split(addr, ":") - if len(strs) > 1 { - host = strs[0] - } else { - host = addr - } return } @@ -246,9 +224,9 @@ func (svr *Service) Run() { if g.GlbServerCfg.KcpBindPort > 0 { go svr.HandleListener(svr.kcpListener) } - if svr.websocketListener != nil { - go svr.HandleListener(svr.websocketListener) - } + + go svr.HandleListener(svr.websocketListener) + svr.HandleListener(svr.listener) } @@ -260,6 +238,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) { log.Warn("Listener for incoming connections from client closed") return } + // Start a new goroutine for dealing connections. go func(frpConn frpNet.Conn) { dealFn := func(conn frpNet.Conn) { diff --git a/utils/net/conn.go b/utils/net/conn.go index 825a9896..6dab2bdb 100644 --- a/utils/net/conn.go +++ b/utils/net/conn.go @@ -96,6 +96,75 @@ func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } +type CloseNotifyConn struct { + net.Conn + log.Logger + + // 1 means closed + closeFlag int32 + + closeFn func() +} + +// closeFn will be only called once +func WrapCloseNotifyConn(c net.Conn, closeFn func()) Conn { + return &CloseNotifyConn{ + Conn: c, + Logger: log.NewPrefixLogger(""), + closeFn: closeFn, + } +} + +func (cc *CloseNotifyConn) Close() (err error) { + pflag := atomic.SwapInt32(&cc.closeFlag, 1) + if pflag == 0 { + err = cc.Close() + if cc.closeFn != nil { + cc.closeFn() + } + } + return +} + +type StatsConn struct { + Conn + + closed int64 // 1 means closed + totalRead int64 + totalWrite int64 + statsFunc func(totalRead, totalWrite int64) +} + +func WrapStatsConn(conn Conn, statsFunc func(total, totalWrite int64)) *StatsConn { + return &StatsConn{ + Conn: conn, + statsFunc: statsFunc, + } +} + +func (statsConn *StatsConn) Read(p []byte) (n int, err error) { + n, err = statsConn.Conn.Read(p) + statsConn.totalRead += int64(n) + return +} + +func (statsConn *StatsConn) Write(p []byte) (n int, err error) { + n, err = statsConn.Conn.Write(p) + statsConn.totalWrite += int64(n) + return +} + +func (statsConn *StatsConn) Close() (err error) { + old := atomic.SwapInt64(&statsConn.closed, 1) + if old != 1 { + err = statsConn.Conn.Close() + if statsConn.statsFunc != nil { + statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite) + } + } + return +} + func ConnectServer(protocol string, addr string) (c Conn, err error) { switch protocol { case "tcp": @@ -138,42 +207,3 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn return nil, fmt.Errorf("unsupport protocol: %s", protocol) } } - -type StatsConn struct { - Conn - - closed int64 // 1 means closed - totalRead int64 - totalWrite int64 - statsFunc func(totalRead, totalWrite int64) -} - -func WrapStatsConn(conn Conn, statsFunc func(total, totalWrite int64)) *StatsConn { - return &StatsConn{ - Conn: conn, - statsFunc: statsFunc, - } -} - -func (statsConn *StatsConn) Read(p []byte) (n int, err error) { - n, err = statsConn.Conn.Read(p) - statsConn.totalRead += int64(n) - return -} - -func (statsConn *StatsConn) Write(p []byte) (n int, err error) { - n, err = statsConn.Conn.Write(p) - statsConn.totalWrite += int64(n) - return -} - -func (statsConn *StatsConn) Close() (err error) { - old := atomic.SwapInt64(&statsConn.closed, 1) - if old != 1 { - err = statsConn.Conn.Close() - if statsConn.statsFunc != nil { - statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite) - } - } - return -} diff --git a/utils/net/websocket.go b/utils/net/websocket.go index 04111129..a3bf0f0a 100644 --- a/utils/net/websocket.go +++ b/utils/net/websocket.go @@ -1,127 +1,105 @@ package net import ( + "errors" "fmt" "net" "net/http" "net/url" - "sync/atomic" "time" "github.com/fatedier/frp/utils/log" + "golang.org/x/net/websocket" ) +var ( + ErrWebsocketListenerClosed = errors.New("websocket listener closed") +) + +const ( + FrpWebsocketPath = "/#frp" +) + type WebsocketListener struct { + net.Addr + ln net.Listener + accept chan Conn log.Logger + server *http.Server httpMutex *http.ServeMux - connChan chan *WebsocketConn - closeFlag bool } -func NewWebsocketListener(ln net.Listener, - filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) { - l = &WebsocketListener{ - httpMutex: http.NewServeMux(), - connChan: make(chan *WebsocketConn), - Logger: log.NewPrefixLogger(""), +// ln: tcp listener for websocket connections +func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { + wl = &WebsocketListener{ + Addr: ln.Addr(), + accept: make(chan Conn), + Logger: log.NewPrefixLogger(""), } - l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) { - conn := NewWebScoketConn(c) - l.connChan <- conn - conn.waitClose() + + muxer := http.NewServeMux() + muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) { + notifyCh := make(chan struct{}) + conn := WrapCloseNotifyConn(c, func() { + close(notifyCh) + }) + wl.accept <- conn + <-notifyCh })) - l.server = &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if filter != nil && !filter(w, r) { - return - } - l.httpMutex.ServeHTTP(w, r) - }), + + wl.server = &http.Server{ + Addr: ln.Addr().String(), + Handler: muxer, } - ch := make(chan struct{}) - go func() { - close(ch) - err = l.server.Serve(ln) - }() - <-ch - <-time.After(time.Millisecond) + + go wl.server.Serve(ln) return } -func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) { - ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) +func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) { + tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) if err != nil { - return + return nil, err } - l, err = NewWebsocketListener(ln, nil) - return + l := NewWebsocketListener(tcpLn) + return l, nil } func (p *WebsocketListener) Accept() (Conn, error) { - c := <-p.connChan + c, ok := <-p.accept + if !ok { + return nil, ErrWebsocketListenerClosed + } return c, nil } func (p *WebsocketListener) Close() error { - if !p.closeFlag { - p.closeFlag = true - p.server.Close() - } - return nil + return p.server.Close() } -type WebsocketConn struct { - net.Conn - log.Logger - closed int32 - wait chan struct{} -} - -func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) { - c = &WebsocketConn{ - Conn: conn, - Logger: log.NewPrefixLogger(""), - wait: make(chan struct{}), - } - return -} - -func (p *WebsocketConn) Close() error { - if atomic.SwapInt32(&p.closed, 1) == 1 { - return nil - } - close(p.wait) - return p.Conn.Close() -} - -func (p *WebsocketConn) waitClose() { - <-p.wait -} - -// ConnectWebsocketServer : -// addr: ws://domain:port -func ConnectWebsocketServer(addr string) (c Conn, err error) { - addr = "ws://" + addr +// addr: domain:port +func ConnectWebsocketServer(addr string) (Conn, error) { + addr = "ws://" + addr + FrpWebsocketPath uri, err := url.Parse(addr) if err != nil { - return + return nil, err } origin := "http://" + uri.Host cfg, err := websocket.NewConfig(addr, origin) if err != nil { - return + return nil, err } cfg.Dialer = &net.Dialer{ - Timeout: time.Second * 10, + Timeout: 10 * time.Second, } conn, err := websocket.DialConfig(cfg) if err != nil { - return + return nil, err } - c = NewWebScoketConn(conn) - return + c := WrapConn(conn) + return c, nil }