106 lines
2.2 KiB
Go
106 lines
2.2 KiB
Go
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
var (
	// ErrInvalidMethod is returned by Upgrade when the handshake request
	// does not use the GET method.
	ErrInvalidMethod = errors.New("websocket: handshake method must be GET")

	// ErrInvalidVersion is returned by Upgrade when the client's
	// Sec-WebSocket-Version header is not "13".
	ErrInvalidVersion = errors.New("websocket: unsupported Sec-WebSocket-Version (only 13 is supported)")

	// ErrInvalidUpgrade is returned by Upgrade when the client's Upgrade
	// header does not request "websocket".
	ErrInvalidUpgrade = errors.New("websocket: \"Upgrade\" header must include \"websocket\"")

	// ErrInvalidConnection is returned by Upgrade when the client's
	// Connection header does not request "Upgrade".
	ErrInvalidConnection = errors.New("websocket: \"Connection\" header must include \"Upgrade\"")

	// ErrMissingKey is returned by Upgrade when the request carries no
	// Sec-WebSocket-Key header.
	ErrMissingKey = errors.New("websocket: missing Sec-WebSocket-Key header")

	// ErrHijacker is returned by Upgrade when the http.ResponseWriter does
	// not implement http.Hijacker, so the raw connection cannot be taken over.
	ErrHijacker = errors.New("websocket: response writer does not implement http.Hijacker")

	// ErrNoEmptyConn is returned by Upgrade when the hijacked connection's
	// read buffer already holds data, i.e. the client sent bytes before the
	// handshake response was written.
	ErrNoEmptyConn = errors.New("websocket: client sent data before handshake completed")
)
|
||
|
|
||
|
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
|
||
|
if r.Method != "GET" {
|
||
|
return nil, ErrInvalidMethod
|
||
|
}
|
||
|
|
||
|
if r.Header.Get("Sec-Websocket-Version") != "13" {
|
||
|
return nil, ErrInvalidVersion
|
||
|
}
|
||
|
|
||
|
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
|
||
|
return nil, ErrInvalidUpgrade
|
||
|
}
|
||
|
|
||
|
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
|
||
|
return nil, ErrInvalidConnection
|
||
|
}
|
||
|
|
||
|
var acceptKey string
|
||
|
|
||
|
if key := r.Header.Get("Sec-Websocket-key"); len(key) == 0 {
|
||
|
return nil, ErrMissingKey
|
||
|
} else {
|
||
|
acceptKey = calcAcceptKey(key)
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
netConn net.Conn
|
||
|
br *bufio.Reader
|
||
|
err error
|
||
|
)
|
||
|
|
||
|
h, ok := w.(http.Hijacker)
|
||
|
if !ok {
|
||
|
return nil, ErrHijacker
|
||
|
}
|
||
|
|
||
|
var rw *bufio.ReadWriter
|
||
|
netConn, rw, err = h.Hijack()
|
||
|
br = rw.Reader
|
||
|
|
||
|
if br.Buffered() > 0 {
|
||
|
netConn.Close()
|
||
|
return nil, ErrNoEmptyConn
|
||
|
}
|
||
|
|
||
|
c := NewConn(netConn, true)
|
||
|
|
||
|
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
|
||
|
|
||
|
buf.WriteString(acceptKey)
|
||
|
buf.WriteString("\r\n")
|
||
|
|
||
|
subProtol := selectSubProtocol(r)
|
||
|
if len(subProtol) > 0 {
|
||
|
buf.WriteString("Sec-Websocket-Protocol: ")
|
||
|
buf.WriteString(subProtol)
|
||
|
buf.WriteString("\r\n")
|
||
|
}
|
||
|
|
||
|
for k, vs := range responseHeader {
|
||
|
for _, v := range vs {
|
||
|
buf.WriteString(k)
|
||
|
buf.WriteString(": ")
|
||
|
buf.WriteString(v)
|
||
|
buf.WriteString("\r\n")
|
||
|
}
|
||
|
}
|
||
|
buf.WriteString("\r\n")
|
||
|
|
||
|
if _, err = netConn.Write(buf.Bytes()); err != nil {
|
||
|
netConn.Close()
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return c, nil
|
||
|
}
|
||
|
|
||
|
// selectSubProtocol returns the first non-empty sub-protocol the client
// listed in its Sec-WebSocket-Protocol header(s), with surrounding whitespace
// trimmed. It returns "" if none was offered.
//
// It searches every header line, not just the first, and skips empty list
// elements (e.g. ", chat" selects "chat").
//
// NOTE(review): the choice is not checked against a server-supported list.
// RFC 6455 expects the server to echo a protocol it actually supports.
func selectSubProtocol(r *http.Request) string {
	for _, line := range r.Header.Values("Sec-Websocket-Protocol") {
		for _, p := range strings.Split(line, ",") {
			if p = strings.TrimSpace(p); p != "" {
				return p
			}
		}
	}
	return ""
}
|