clash/dns/client.go

269 lines
5.0 KiB
Go

package dns
import (
"context"
"crypto/tls"
"errors"
"net"
"strings"
"sync"
"time"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/picker"
C "github.com/Dreamacro/clash/constant"
D "github.com/miekg/dns"
geoip2 "github.com/oschwald/geoip2-golang"
)
var (
globalSessionCache = tls.NewLRUClientSessionCache(64)
mmdb *geoip2.Reader
once sync.Once
resolver *Resolver
)
type Resolver struct {
ipv6 bool
mapping bool
fallback []*nameserver
main []*nameserver
cache *cache.Cache
}
type result struct {
Msg *D.Msg
Error error
}
func isIPRequest(q D.Question) bool {
if q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA) {
return true
}
return false
}
func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
if len(m.Question) == 0 {
return nil, errors.New("should have one question at least")
}
q := m.Question[0]
cache, expireTime := r.cache.GetWithExpire(q.String())
if cache != nil {
msg = cache.(*D.Msg).Copy()
setMsgTTL(msg, uint32(expireTime.Sub(time.Now()).Seconds()))
return
}
defer func() {
if msg == nil {
return
}
putMsgToCache(r.cache, q.String(), msg)
if r.mapping {
ips := r.msgToIP(msg)
for _, ip := range ips {
putMsgToCache(r.cache, ip.String(), msg)
}
}
}()
isIPReq := isIPRequest(q)
if isIPReq {
msg, err = r.resolveIP(m)
return
}
msg, err = r.exchange(r.main, m)
return
}
func (r *Resolver) exchange(servers []*nameserver, m *D.Msg) (msg *D.Msg, err error) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
fast := picker.SelectFast(ctx, in)
wg := sync.WaitGroup{}
wg.Add(len(servers))
for _, server := range servers {
go func(s *nameserver) {
defer wg.Done()
msg, _, err := s.Client.Exchange(m, s.Address)
if err != nil || msg.Rcode != D.RcodeSuccess {
return
}
in <- &result{Msg: msg, Error: err}
}(server)
}
// release in channel
go func() {
wg.Wait()
close(in)
}()
elm, exist := <-fast
if !exist {
return nil, errors.New("All DNS requests failed")
}
resp := elm.(*result)
msg, err = resp.Msg, resp.Error
return
}
func (r *Resolver) resolveIP(m *D.Msg) (msg *D.Msg, err error) {
msgCh := r.resolve(r.main, m)
if r.fallback == nil {
res := <-msgCh
msg, err = res.Msg, res.Error
return
}
fallbackMsg := r.resolve(r.fallback, m)
res := <-msgCh
if res.Error == nil {
if mmdb == nil {
return nil, errors.New("GeoIP can't use")
}
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
if record, _ := mmdb.Country(ips[0]); record.Country.IsoCode == "CN" || record.Country.IsoCode == "" {
// release channel
go func() { <-fallbackMsg }()
msg = res.Msg
return msg, err
}
}
}
res = <-fallbackMsg
msg, err = res.Msg, res.Error
return
}
func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
ip = net.ParseIP(host)
if ip != nil {
return ip, nil
}
query := &D.Msg{}
dnsType := D.TypeA
if r.ipv6 {
dnsType = D.TypeAAAA
}
query.SetQuestion(D.Fqdn(host), dnsType)
msg, err := r.Exchange(query)
if err != nil {
return nil, err
}
ips := r.msgToIP(msg)
if len(ips) == 0 {
return nil, errors.New("can't found ip")
}
ip = ips[0]
return
}
func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
for _, extra := range msg.Extra {
switch record := extra.(type) {
case *D.AAAA:
ips = append(ips, record.AAAA)
case *D.A:
ips = append(ips, record.A)
}
}
return ips
}
func (r *Resolver) IPToHost(ip net.IP) (string, bool) {
cache := r.cache.Get(ip.String())
if cache == nil {
return "", false
}
fqdn := cache.(*D.Msg).Question[0].Name
return strings.TrimRight(fqdn, "."), true
}
func (r *Resolver) resolve(client []*nameserver, msg *D.Msg) <-chan *result {
ch := make(chan *result)
go func() {
res, err := r.exchange(client, msg)
ch <- &result{Msg: res, Error: err}
}()
return ch
}
func (r *Resolver) IsMapping() bool {
return r.mapping
}
type NameServer struct {
Net string
Addr string
}
type nameserver struct {
Client *D.Client
Address string
}
type Config struct {
Main, Fallback []NameServer
IPv6 bool
EnhancedMode EnhancedMode
}
func transform(servers []NameServer) []*nameserver {
var ret []*nameserver
for _, s := range servers {
ret = append(ret, &nameserver{
Client: &D.Client{
Net: s.Net,
TLSConfig: &tls.Config{
ClientSessionCache: globalSessionCache,
},
},
Address: s.Addr,
})
}
return ret
}
func New(config Config) *Resolver {
once.Do(func() {
mmdb, _ = geoip2.Open(C.Path.MMDB())
})
r := &Resolver{
main: transform(config.Main),
ipv6: config.IPv6,
cache: cache.New(time.Second * 60),
mapping: config.EnhancedMode == MAPPING,
}
if config.Fallback != nil {
r.fallback = transform(config.Fallback)
}
return r
}