mirror of https://github.com/fatedier/frp.git
51 lines
1.1 KiB
Go
51 lines
1.1 KiB
Go
package net
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/url"
|
|
|
|
libnet "github.com/fatedier/golib/net"
|
|
"golang.org/x/net/websocket"
|
|
)
|
|
|
|
func DialHookCustomTLSHeadByte(enableTLS bool, disableCustomTLSHeadByte bool) libnet.AfterHookFunc {
|
|
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
|
|
if enableTLS && !disableCustomTLSHeadByte {
|
|
_, err := c.Write([]byte{byte(FRPTLSHeadByte)})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
return ctx, c, nil
|
|
}
|
|
}
|
|
|
|
func DialHookWebsocket(protocol string, host string) libnet.AfterHookFunc {
|
|
return func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
|
|
if protocol != "wss" {
|
|
protocol = "ws"
|
|
}
|
|
if host == "" {
|
|
host = addr
|
|
}
|
|
addr = protocol + "://" + host + FrpWebsocketPath
|
|
uri, err := url.Parse(addr)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
origin := "http://" + uri.Host
|
|
cfg, err := websocket.NewConfig(addr, origin)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
conn, err := websocket.NewClient(cfg, c)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return ctx, conn, nil
|
|
}
|
|
}
|