From 50b3d497f6ff89ac05aa7d648ced801c1d6b8a8e Mon Sep 17 00:00:00 2001 From: Dreamacro <8615343+Dreamacro@users.noreply.github.com> Date: Thu, 22 Oct 2020 00:11:49 +0800 Subject: [PATCH] Feature: use native syscall to bind interface on Linux and macOS --- component/dialer/bind.go | 104 ++++++++++++++++++++++++++ component/dialer/bind_darwin.go | 46 ++++++++++++ component/dialer/bind_linux.go | 26 +++++++ component/dialer/bind_others.go | 13 ++++ component/dialer/dialer.go | 26 ++----- component/dialer/hook.go | 125 +++----------------------------- 6 files changed, 205 insertions(+), 135 deletions(-) create mode 100644 component/dialer/bind.go create mode 100644 component/dialer/bind_darwin.go create mode 100644 component/dialer/bind_linux.go create mode 100644 component/dialer/bind_others.go diff --git a/component/dialer/bind.go b/component/dialer/bind.go new file mode 100644 index 0000000..cb24a8b --- /dev/null +++ b/component/dialer/bind.go @@ -0,0 +1,104 @@ +package dialer + +import ( + "errors" + "net" +) + +var ( + errPlatformNotSupport = errors.New("unsupport platform") +) + +func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) { + ipv4 := ip.To4() != nil + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok { + continue + } + + addrV4 := addr.IP.To4() != nil + + if addrV4 && ipv4 { + return &net.TCPAddr{IP: addr.IP, Port: 0}, nil + } else if !addrV4 && !ipv4 { + return &net.TCPAddr{IP: addr.IP, Port: 0}, nil + } + } + + return nil, ErrAddrNotFound +} + +func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) { + ipv4 := ip.To4() != nil + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok { + continue + } + + addrV4 := addr.IP.To4() != nil + + if addrV4 && ipv4 { + return &net.UDPAddr{IP: addr.IP, Port: 0}, nil + } else if !addrV4 && !ipv4 { + return &net.UDPAddr{IP: addr.IP, Port: 0}, nil + } + } + + return nil, ErrAddrNotFound +} + +func fallbackBindToDialer(dialer *net.Dialer, network string, ip net.IP, name string) error { + iface, err := net.InterfaceByName(name) + if err != nil { + return err + } + + addrs, err := iface.Addrs() + if err != nil { + return err + } + + switch network { + case "tcp", "tcp4", "tcp6": + if addr, err := lookupTCPAddr(ip, addrs); err == nil { + dialer.LocalAddr = addr + } else { + return err + } + case "udp", "udp4", "udp6": + if addr, err := lookupUDPAddr(ip, addrs); err == nil { + dialer.LocalAddr = addr + } else { + return err + } + } + + return nil +} + +func fallbackBindToListenConfig(name string) (string, error) { + iface, err := net.InterfaceByName(name) + if err != nil { + return "", err + } + + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + + for _, elm := range addrs { + addr, ok := elm.(*net.IPNet) + if !ok || addr.IP.To4() == nil { + continue + } + + return net.JoinHostPort(addr.IP.String(), "0"), nil + } + + return "", ErrAddrNotFound +} diff --git a/component/dialer/bind_darwin.go b/component/dialer/bind_darwin.go new file mode 100644 index 0000000..d46c673 --- /dev/null +++ b/component/dialer/bind_darwin.go @@ -0,0 +1,46 @@ +package dialer + +import ( + "net" + "syscall" +) + +func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error { + iface, err := net.InterfaceByName(ifaceName) + if err != nil { + return err + } + + dialer.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + switch network { + case "tcp4", "udp4": + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index) + case "tcp6", "udp6": + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index) + } + }) + } + + return nil +} + +func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error { + iface, err := net.InterfaceByName(ifaceName) + if err != nil { + return err + } + + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + switch network { + case "tcp4", "udp4": + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index) + case "tcp6", "udp6": + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index) + } + }) + } + + return nil +} diff --git a/component/dialer/bind_linux.go b/component/dialer/bind_linux.go new file mode 100644 index 0000000..8afa3d3 --- /dev/null +++ b/component/dialer/bind_linux.go @@ -0,0 +1,26 @@ +package dialer + +import ( + "net" + "syscall" +) + +func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error { + dialer.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + syscall.BindToDevice(int(fd), ifaceName) + }) + } + + return nil +} + +func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error { + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + syscall.BindToDevice(int(fd), ifaceName) + }) + } + + return nil +} diff --git a/component/dialer/bind_others.go b/component/dialer/bind_others.go new file mode 100644 index 0000000..87bb47d --- /dev/null +++ b/component/dialer/bind_others.go @@ -0,0 +1,13 @@ +// +build !linux,!darwin + +package dialer + +import "net" + +func bindIfaceToDialer(dialer *net.Dialer, ifaceName string) error { + return errNotSupport +} + +func bindIfaceToListenConfig(lc *net.ListenConfig, ifaceName string) error { + return errNotSupport +} diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index b0be7a6..be26681 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -19,17 +19,6 @@ func Dialer() (*net.Dialer, error) { return dialer, nil } -func ListenConfig() (*net.ListenConfig, error) { - cfg := &net.ListenConfig{} - if ListenConfigHook != nil { - if err := ListenConfigHook(cfg); err != nil { - return nil, err - } - } - - return cfg, nil -} - func Dial(network, address string) (net.Conn, error) { return DialContext(context.Background(), network, address) } @@ -73,19 +62,16 @@ func DialContext(ctx context.Context, network, address string) (net.Conn, error) } func ListenPacket(network, address string) (net.PacketConn, error) { - lc, err := ListenConfig() - if err != nil { - return nil, err - } - - if ListenPacketHook != nil && address == "" { - ip, err := ListenPacketHook() + cfg := &net.ListenConfig{} + if ListenPacketHook != nil { + var err error + address, err = ListenPacketHook(cfg, address) if err != nil { return nil, err } - address = net.JoinHostPort(ip.String(), "0") } - return lc.ListenPacket(context.Background(), network, address) + + return cfg.ListenPacket(context.Background(), network, address) } func dualStackDialContext(ctx context.Context, network, address string) (net.Conn, error) { diff --git a/component/dialer/hook.go b/component/dialer/hook.go index d4c955a..356e4b2 100644 --- a/component/dialer/hook.go +++ b/component/dialer/hook.go @@ -3,20 +3,15 @@ package dialer import ( "errors" "net" - "time" - - "github.com/Dreamacro/clash/common/singledo" ) type DialerHookFunc = func(dialer *net.Dialer) error type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP) error -type ListenConfigHookFunc = func(*net.ListenConfig) error -type ListenPacketHookFunc = func() (net.IP, error) +type ListenPacketHookFunc = func(lc *net.ListenConfig, address string) (string, error) var ( DialerHook DialerHookFunc DialHook DialHookFunc - ListenConfigHook ListenConfigHookFunc ListenPacketHook ListenPacketHookFunc ) @@ -25,124 +20,24 @@ var ( ErrNetworkNotSupport = errors.New("network not support") ) -func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) { - ipv4 := ip.To4() != nil - - for _, elm := range addrs { - addr, ok := elm.(*net.IPNet) - if !ok { - continue - } - - addrV4 := addr.IP.To4() != nil - - if addrV4 && ipv4 { - return &net.TCPAddr{IP: addr.IP, Port: 0}, nil - } else if !addrV4 && !ipv4 { - return &net.TCPAddr{IP: addr.IP, Port: 0}, nil - } - } - - return nil, ErrAddrNotFound -} - -func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) { - ipv4 := ip.To4() != nil - - for _, elm := range addrs { - addr, ok := elm.(*net.IPNet) - if !ok { - continue - } - - addrV4 := addr.IP.To4() != nil - - if addrV4 && ipv4 { - return &net.UDPAddr{IP: addr.IP, Port: 0}, nil - } else if !addrV4 && !ipv4 { - return &net.UDPAddr{IP: addr.IP, Port: 0}, nil - } - } - - return nil, ErrAddrNotFound -} - func ListenPacketWithInterface(name string) ListenPacketHookFunc { - single := singledo.NewSingle(5 * time.Second) - - return func() (net.IP, error) { - elm, err, _ := single.Do(func() (interface{}, error) { - iface, err := net.InterfaceByName(name) - if err != nil { - return nil, err - } - - addrs, err := iface.Addrs() - if err != nil { - return nil, err - } - - return addrs, nil - }) - - if err != nil { - return nil, err + return func(lc *net.ListenConfig, address string) (string, error) { + err := bindIfaceToListenConfig(lc, name) + if err == errPlatformNotSupport { + address, err = fallbackBindToListenConfig(name) } - addrs := elm.([]net.Addr) - - for _, elm := range addrs { - addr, ok := elm.(*net.IPNet) - if !ok || addr.IP.To4() == nil { - continue - } - - return addr.IP, nil - } - - return nil, ErrAddrNotFound + return address, err } } func DialerWithInterface(name string) DialHookFunc { - single := singledo.NewSingle(5 * time.Second) - return func(dialer *net.Dialer, network string, ip net.IP) error { - elm, err, _ := single.Do(func() (interface{}, error) { - iface, err := net.InterfaceByName(name) - if err != nil { - return nil, err - } - - addrs, err := iface.Addrs() - if err != nil { - return nil, err - } - - return addrs, nil - }) - - if err != nil { - return err + err := bindIfaceToDialer(dialer, name) + if err == errPlatformNotSupport { + err = fallbackBindToDialer(dialer, network, ip, name) } - addrs := elm.([]net.Addr) - - switch network { - case "tcp", "tcp4", "tcp6": - if addr, err := lookupTCPAddr(ip, addrs); err == nil { - dialer.LocalAddr = addr - } else { - return err - } - case "udp", "udp4", "udp6": - if addr, err := lookupUDPAddr(ip, addrs); err == nil { - dialer.LocalAddr = addr - } else { - return err - } - } - - return nil + return err } }