mirror of https://github.com/Dreamacro/clash.git
Feature: support domain in fallback filter (#964)
This commit is contained in:
parent
e09931dcf7
commit
a6444bb449
|
@ -69,6 +69,7 @@ type DNS struct {
|
|||
type FallbackFilter struct {
|
||||
GeoIP bool `yaml:"geoip"`
|
||||
IPCIDR []*net.IPNet `yaml:"ipcidr"`
|
||||
Domain []string `yaml:"domain"`
|
||||
}
|
||||
|
||||
// Experimental config
|
||||
|
@ -103,6 +104,7 @@ type RawDNS struct {
|
|||
type RawFallbackFilter struct {
|
||||
GeoIP bool `yaml:"geoip"`
|
||||
IPCIDR []string `yaml:"ipcidr"`
|
||||
Domain []string `yaml:"domain"`
|
||||
}
|
||||
|
||||
type RawConfig struct {
|
||||
|
@ -561,6 +563,7 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
|
|||
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
|
||||
dnsCfg.FallbackFilter.IPCIDR = fallbackip
|
||||
}
|
||||
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
|
||||
|
||||
if cfg.UseHosts {
|
||||
dnsCfg.Hosts = hosts
|
||||
|
|
|
@ -4,9 +4,10 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/Dreamacro/clash/component/mmdb"
|
||||
"github.com/Dreamacro/clash/component/trie"
|
||||
)
|
||||
|
||||
type fallbackFilter interface {
|
||||
type fallbackIPFilter interface {
|
||||
Match(net.IP) bool
|
||||
}
|
||||
|
||||
|
@ -24,3 +25,22 @@ type ipnetFilter struct {
|
|||
func (inf *ipnetFilter) Match(ip net.IP) bool {
|
||||
return inf.ipnet.Contains(ip)
|
||||
}
|
||||
|
||||
type fallbackDomainFilter interface {
|
||||
Match(domain string) bool
|
||||
}
|
||||
type domainFilter struct {
|
||||
tree *trie.DomainTrie
|
||||
}
|
||||
|
||||
func NewDomainFilter(domains []string) *domainFilter {
|
||||
df := domainFilter{tree: trie.New()}
|
||||
for _, domain := range domains {
|
||||
df.tree.Insert(domain, "")
|
||||
}
|
||||
return &df
|
||||
}
|
||||
|
||||
func (df *domainFilter) Match(domain string) bool {
|
||||
return df.tree.Search(domain) != nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
|
@ -34,13 +35,14 @@ type result struct {
|
|||
}
|
||||
|
||||
type Resolver struct {
|
||||
ipv6 bool
|
||||
hosts *trie.DomainTrie
|
||||
main []dnsClient
|
||||
fallback []dnsClient
|
||||
fallbackFilters []fallbackFilter
|
||||
group singleflight.Group
|
||||
lruCache *cache.LruCache
|
||||
ipv6 bool
|
||||
hosts *trie.DomainTrie
|
||||
main []dnsClient
|
||||
fallback []dnsClient
|
||||
fallbackDomainFilters []fallbackDomainFilter
|
||||
fallbackIPFilters []fallbackIPFilter
|
||||
group singleflight.Group
|
||||
lruCache *cache.LruCache
|
||||
}
|
||||
|
||||
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA
|
||||
|
@ -78,8 +80,8 @@ func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
|
|||
return r.resolveIP(host, D.TypeAAAA)
|
||||
}
|
||||
|
||||
func (r *Resolver) shouldFallback(ip net.IP) bool {
|
||||
for _, filter := range r.fallbackFilters {
|
||||
func (r *Resolver) shouldIPFallback(ip net.IP) bool {
|
||||
for _, filter := range r.fallbackIPFilters {
|
||||
if filter.Match(ip) {
|
||||
return true
|
||||
}
|
||||
|
@ -126,7 +128,7 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
|
|||
|
||||
isIPReq := isIPRequest(q)
|
||||
if isIPReq {
|
||||
return r.fallbackExchange(m)
|
||||
return r.ipExchange(m)
|
||||
}
|
||||
|
||||
return r.batchExchange(r.main, m)
|
||||
|
@ -170,19 +172,49 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err
|
|||
return
|
||||
}
|
||||
|
||||
func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) {
|
||||
func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
|
||||
if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := r.msgToDomain(m)
|
||||
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, df := range r.fallbackDomainFilters {
|
||||
if df.Match(domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
|
||||
|
||||
onlyFallback := r.shouldOnlyQueryFallback(m)
|
||||
|
||||
if onlyFallback {
|
||||
res := <-r.asyncExchange(r.fallback, m)
|
||||
return res.Msg, res.Error
|
||||
}
|
||||
|
||||
msgCh := r.asyncExchange(r.main, m)
|
||||
if r.fallback == nil {
|
||||
|
||||
if r.fallback == nil { // directly return if no fallback servers are available
|
||||
res := <-msgCh
|
||||
msg, err = res.Msg, res.Error
|
||||
return
|
||||
}
|
||||
|
||||
fallbackMsg := r.asyncExchange(r.fallback, m)
|
||||
res := <-msgCh
|
||||
if res.Error == nil {
|
||||
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
|
||||
if !r.shouldFallback(ips[0]) {
|
||||
msg = res.Msg
|
||||
if !r.shouldIPFallback(ips[0]) {
|
||||
msg = res.Msg // no need to wait for fallback result
|
||||
err = res.Error
|
||||
return msg, err
|
||||
}
|
||||
|
@ -240,6 +272,14 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
|
|||
return ips
|
||||
}
|
||||
|
||||
func (r *Resolver) msgToDomain(msg *D.Msg) string {
|
||||
if len(msg.Question) > 0 {
|
||||
return strings.TrimRight(msg.Question[0].Name, ".")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
|
||||
ch := make(chan *result, 1)
|
||||
go func() {
|
||||
|
@ -257,6 +297,7 @@ type NameServer struct {
|
|||
type FallbackFilter struct {
|
||||
GeoIP bool
|
||||
IPCIDR []*net.IPNet
|
||||
Domain []string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
@ -286,14 +327,19 @@ func NewResolver(config Config) *Resolver {
|
|||
r.fallback = transform(config.Fallback, defaultResolver)
|
||||
}
|
||||
|
||||
fallbackFilters := []fallbackFilter{}
|
||||
fallbackIPFilters := []fallbackIPFilter{}
|
||||
if config.FallbackFilter.GeoIP {
|
||||
fallbackFilters = append(fallbackFilters, &geoipFilter{})
|
||||
fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{})
|
||||
}
|
||||
for _, ipnet := range config.FallbackFilter.IPCIDR {
|
||||
fallbackFilters = append(fallbackFilters, &ipnetFilter{ipnet: ipnet})
|
||||
fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
|
||||
}
|
||||
r.fallbackIPFilters = fallbackIPFilters
|
||||
|
||||
if len(config.FallbackFilter.Domain) != 0 {
|
||||
fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)}
|
||||
r.fallbackDomainFilters = fallbackDomainFilters
|
||||
}
|
||||
r.fallbackFilters = fallbackFilters
|
||||
|
||||
return r
|
||||
}
|
||||
|
|
|
@ -118,6 +118,7 @@ func updateDNS(c *config.DNS) {
|
|||
FallbackFilter: dns.FallbackFilter{
|
||||
GeoIP: c.FallbackFilter.GeoIP,
|
||||
IPCIDR: c.FallbackFilter.IPCIDR,
|
||||
Domain: c.FallbackFilter.Domain,
|
||||
},
|
||||
Default: c.DefaultNameserver,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue