diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index 5887712..7c22549 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "encoding/json" "errors" "net" @@ -99,7 +100,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) { } // URLTest get the delay for the specified URL -func (p *Proxy) URLTest(url string) (t uint16, err error) { +func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { defer func() { p.alive = err == nil record := C.DelayHistory{Time: time.Now()} @@ -123,6 +124,13 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) { return } defer instance.Close() + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return + } + req = req.WithContext(ctx) + transport := &http.Transport{ Dial: func(string, string) (net.Conn, error) { return instance, nil @@ -133,8 +141,9 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + client := http.Client{Transport: transport} - resp, err := client.Get(url) + resp, err := client.Do(req) if err != nil { return } diff --git a/adapters/outbound/fallback.go b/adapters/outbound/fallback.go index 913383a..67a6ab0 100644 --- a/adapters/outbound/fallback.go +++ b/adapters/outbound/fallback.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "encoding/json" "errors" "net" @@ -90,7 +91,7 @@ func (f *Fallback) validTest() { for _, p := range f.proxies { go func(p C.Proxy) { - p.URLTest(f.rawURL) + p.URLTest(context.Background(), f.rawURL) wg.Done() }(p) } diff --git a/adapters/outbound/loadbalance.go b/adapters/outbound/loadbalance.go index e870271..9418863 100644 --- a/adapters/outbound/loadbalance.go +++ b/adapters/outbound/loadbalance.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "encoding/json" "errors" "net" @@ -95,7 +96,7 @@ func (lb *LoadBalance) validTest() { for _, p := range lb.proxies { go func(p C.Proxy) { - p.URLTest(lb.rawURL) + p.URLTest(context.Background(), lb.rawURL) wg.Done() }(p) } diff --git a/adapters/outbound/urltest.go b/adapters/outbound/urltest.go index 9a219cc..cfe6b5b 100644 --- a/adapters/outbound/urltest.go +++ b/adapters/outbound/urltest.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "net" - "sync" "sync/atomic" "time" @@ -103,35 +102,22 @@ func (u *URLTest) speedTest() { } defer atomic.StoreInt32(&u.once, 0) - wg := sync.WaitGroup{} - wg.Add(len(u.proxies)) - c := make(chan interface{}) - fast := picker.SelectFast(context.Background(), c) - timer := time.NewTimer(u.interval) - + ctx, cancel := context.WithTimeout(context.Background(), u.interval) + defer cancel() + picker, ctx := picker.WithContext(ctx) for _, p := range u.proxies { - go func(p C.Proxy) { - _, err := p.URLTest(u.rawURL) - if err == nil { - c <- p + picker.Go(func() (interface{}, error) { + _, err := p.URLTest(ctx, u.rawURL) + if err != nil { + return nil, err } - wg.Done() - }(p) + return p, nil + }) } - go func() { - wg.Wait() - close(c) - }() - - select { - case <-timer.C: - // Wait for fast to return or close. - <-fast - case p, open := <-fast: - if open { - u.fast = p.(C.Proxy) - } + fast := picker.Wait() + if fast != nil { + u.fast = fast.(C.Proxy) } } diff --git a/common/picker/picker.go b/common/picker/picker.go index 07e2076..f420679 100644 --- a/common/picker/picker.go +++ b/common/picker/picker.go @@ -1,22 +1,53 @@ package picker -import "context" +import ( + "context" + "sync" +) + +// Picker provides synchronization, and Context cancelation +// for groups of goroutines working on subtasks of a common task. +// Inspired by errGroup +type Picker struct { + cancel func() + + wg sync.WaitGroup + + once sync.Once + result interface{} +} + +// WithContext returns a new Picker and an associated Context derived from ctx. +func WithContext(ctx context.Context) (*Picker, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Picker{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, +// then returns the first nil error result (if any) from them. +func (p *Picker) Wait() interface{} { + p.wg.Wait() + if p.cancel != nil { + p.cancel() + } + return p.result +} + +// Go calls the given function in a new goroutine. +// The first call to return a nil error cancels the group; its result will be returned by Wait. +func (p *Picker) Go(f func() (interface{}, error)) { + p.wg.Add(1) -func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} { - out := make(chan interface{}) go func() { - select { - case p, open := <-in: - if open { - out <- p - } - case <-ctx.Done(): - } + defer p.wg.Done() - close(out) - for range in { + if ret, err := f(); err == nil { + p.once.Do(func() { + p.result = ret + if p.cancel != nil { + p.cancel() + } + }) } }() - - return out } diff --git a/common/picker/picker_test.go b/common/picker/picker_test.go index f33627f..7b225d3 100644 --- a/common/picker/picker_test.go +++ b/common/picker/picker_test.go @@ -6,39 +6,37 @@ import ( "time" ) -func sleepAndSend(delay int, in chan<- interface{}, input interface{}) { - time.Sleep(time.Millisecond * time.Duration(delay)) - in <- input -} - -func sleepAndClose(delay int, in chan interface{}) { - time.Sleep(time.Millisecond * time.Duration(delay)) - close(in) +func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) { + return func() (interface{}, error) { + timer := time.NewTimer(time.Millisecond * time.Duration(delay)) + select { + case <-timer.C: + return input, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } } func TestPicker_Basic(t *testing.T) { - in := make(chan interface{}) - fast := SelectFast(context.Background(), in) - go sleepAndSend(20, in, 1) - go sleepAndSend(30, in, 2) - go sleepAndClose(40, in) + picker, ctx := WithContext(context.Background()) + picker.Go(sleepAndSend(ctx, 30, 2)) + picker.Go(sleepAndSend(ctx, 20, 1)) - number, exist := <-fast - if !exist || number != 1 { - t.Error("should recv 1", exist, number) + number := picker.Wait() + if number != nil && number.(int) != 1 { + t.Error("should recv 1", number) } } func TestPicker_Timeout(t *testing.T) { - in := make(chan interface{}) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5) defer cancel() - fast := SelectFast(ctx, in) - go sleepAndSend(20, in, 1) - go sleepAndClose(30, in) + picker, ctx := WithContext(ctx) + picker.Go(sleepAndSend(ctx, 20, 1)) - _, exist := <-fast - if exist { - t.Error("should recv false") + number := picker.Wait() + if number != nil { + t.Error("should recv nil") } } diff --git a/constant/adapters.go b/constant/adapters.go index e05cd58..11844bc 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -1,6 +1,7 @@ package constant import ( + "context" "net" "time" ) @@ -44,7 +45,7 @@ type Proxy interface { Alive() bool DelayHistory() []DelayHistory LastDelay() uint16 - URLTest(url string) (uint16, error) + URLTest(ctx context.Context, url string) (uint16, error) } // AdapterType is enum of adapter type diff --git a/dns/resolver.go b/dns/resolver.go index 75c3796..044976f 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -163,32 +163,22 @@ func (r *Resolver) IsFakeIP() bool { } func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { - in := make(chan interface{}) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - fast := picker.SelectFast(ctx, in) + fast, ctx := picker.WithContext(ctx) - wg := sync.WaitGroup{} - wg.Add(len(clients)) for _, r := range clients { - go func(r resolver) { - defer wg.Done() + fast.Go(func() (interface{}, error) { msg, err := r.ExchangeContext(ctx, m) if err != nil || msg.Rcode != D.RcodeSuccess { - return + return nil, errors.New("resolve error") } - in <- msg - }(r) + return msg, nil + }) } - // release in channel - go func() { - wg.Wait() - close(in) - }() - - elm, exist := <-fast - if !exist { + elm := fast.Wait() + if elm == nil { return nil, errors.New("All DNS requests failed") } diff --git a/hub/route/proxies.go b/hub/route/proxies.go index 7122191..e4c77dc 100644 --- a/hub/route/proxies.go +++ b/hub/route/proxies.go @@ -9,6 +9,7 @@ import ( "time" A "github.com/Dreamacro/clash/adapters/outbound" + "github.com/Dreamacro/clash/common/picker" C "github.com/Dreamacro/clash/constant" T "github.com/Dreamacro/clash/tunnel" @@ -110,27 +111,28 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) { proxy := r.Context().Value(CtxKeyProxy).(C.Proxy) - sigCh := make(chan uint16) - go func() { - t, err := proxy.URLTest(url) - if err != nil { - sigCh <- 0 - } - sigCh <- t - }() + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout)) + defer cancel() + picker, ctx := picker.WithContext(ctx) + picker.Go(func() (interface{}, error) { + return proxy.URLTest(ctx, url) + }) - select { - case <-time.After(time.Millisecond * time.Duration(timeout)): + elm := picker.Wait() + if elm == nil { render.Status(r, http.StatusRequestTimeout) render.JSON(w, r, ErrRequestTimeout) - case t := <-sigCh: - if t == 0 { - render.Status(r, http.StatusServiceUnavailable) - render.JSON(w, r, newError("An error occurred in the delay test")) - } else { - render.JSON(w, r, render.M{ - "delay": t, - }) - } + return } + + delay := elm.(uint16) + if delay == 0 { + render.Status(r, http.StatusServiceUnavailable) + render.JSON(w, r, newError("An error occurred in the delay test")) + return + } + + render.JSON(w, r, render.M{ + "delay": delay, + }) }