mirror of https://github.com/Dreamacro/clash.git
Optimization: refactor picker
This commit is contained in:
parent
0eff8516c0
commit
7c6c147a18
|
@ -1,6 +1,7 @@
|
|||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -99,7 +100,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// URLTest get the delay for the specified URL
|
||||
func (p *Proxy) URLTest(url string) (t uint16, err error) {
|
||||
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
|
||||
defer func() {
|
||||
p.alive = err == nil
|
||||
record := C.DelayHistory{Time: time.Now()}
|
||||
|
@ -123,6 +124,13 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
|
|||
return
|
||||
}
|
||||
defer instance.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
transport := &http.Transport{
|
||||
Dial: func(string, string) (net.Conn, error) {
|
||||
return instance, nil
|
||||
|
@ -133,8 +141,9 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
|
|||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
client := http.Client{Transport: transport}
|
||||
resp, err := client.Get(url)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -90,7 +91,7 @@ func (f *Fallback) validTest() {
|
|||
|
||||
for _, p := range f.proxies {
|
||||
go func(p C.Proxy) {
|
||||
p.URLTest(f.rawURL)
|
||||
p.URLTest(context.Background(), f.rawURL)
|
||||
wg.Done()
|
||||
}(p)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package adapters
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -95,7 +96,7 @@ func (lb *LoadBalance) validTest() {
|
|||
|
||||
for _, p := range lb.proxies {
|
||||
go func(p C.Proxy) {
|
||||
p.URLTest(lb.rawURL)
|
||||
p.URLTest(context.Background(), lb.rawURL)
|
||||
wg.Done()
|
||||
}(p)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -103,35 +102,22 @@ func (u *URLTest) speedTest() {
|
|||
}
|
||||
defer atomic.StoreInt32(&u.once, 0)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(u.proxies))
|
||||
c := make(chan interface{})
|
||||
fast := picker.SelectFast(context.Background(), c)
|
||||
timer := time.NewTimer(u.interval)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), u.interval)
|
||||
defer cancel()
|
||||
picker, ctx := picker.WithContext(ctx)
|
||||
for _, p := range u.proxies {
|
||||
go func(p C.Proxy) {
|
||||
_, err := p.URLTest(u.rawURL)
|
||||
if err == nil {
|
||||
c <- p
|
||||
picker.Go(func() (interface{}, error) {
|
||||
_, err := p.URLTest(ctx, u.rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wg.Done()
|
||||
}(p)
|
||||
return p, nil
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
// Wait for fast to return or close.
|
||||
<-fast
|
||||
case p, open := <-fast:
|
||||
if open {
|
||||
u.fast = p.(C.Proxy)
|
||||
}
|
||||
fast := picker.Wait()
|
||||
if fast != nil {
|
||||
u.fast = fast.(C.Proxy)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,22 +1,53 @@
|
|||
package picker
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Picker provides synchronization, and Context cancelation
|
||||
// for groups of goroutines working on subtasks of a common task.
|
||||
// Inspired by errGroup
|
||||
type Picker struct {
|
||||
cancel func()
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
once sync.Once
|
||||
result interface{}
|
||||
}
|
||||
|
||||
// WithContext returns a new Picker and an associated Context derived from ctx.
|
||||
func WithContext(ctx context.Context) (*Picker, context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Picker{cancel: cancel}, ctx
|
||||
}
|
||||
|
||||
// Wait blocks until all function calls from the Go method have returned,
|
||||
// then returns the first nil error result (if any) from them.
|
||||
func (p *Picker) Wait() interface{} {
|
||||
p.wg.Wait()
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
}
|
||||
return p.result
|
||||
}
|
||||
|
||||
// Go calls the given function in a new goroutine.
|
||||
// The first call to return a nil error cancels the group; its result will be returned by Wait.
|
||||
func (p *Picker) Go(f func() (interface{}, error)) {
|
||||
p.wg.Add(1)
|
||||
|
||||
func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} {
|
||||
out := make(chan interface{})
|
||||
go func() {
|
||||
select {
|
||||
case p, open := <-in:
|
||||
if open {
|
||||
out <- p
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
defer p.wg.Done()
|
||||
|
||||
close(out)
|
||||
for range in {
|
||||
if ret, err := f(); err == nil {
|
||||
p.once.Do(func() {
|
||||
p.result = ret
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return out
|
||||
}
|
||||
|
|
|
@ -6,39 +6,37 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func sleepAndSend(delay int, in chan<- interface{}, input interface{}) {
|
||||
time.Sleep(time.Millisecond * time.Duration(delay))
|
||||
in <- input
|
||||
}
|
||||
|
||||
func sleepAndClose(delay int, in chan interface{}) {
|
||||
time.Sleep(time.Millisecond * time.Duration(delay))
|
||||
close(in)
|
||||
func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) {
|
||||
return func() (interface{}, error) {
|
||||
timer := time.NewTimer(time.Millisecond * time.Duration(delay))
|
||||
select {
|
||||
case <-timer.C:
|
||||
return input, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicker_Basic(t *testing.T) {
|
||||
in := make(chan interface{})
|
||||
fast := SelectFast(context.Background(), in)
|
||||
go sleepAndSend(20, in, 1)
|
||||
go sleepAndSend(30, in, 2)
|
||||
go sleepAndClose(40, in)
|
||||
picker, ctx := WithContext(context.Background())
|
||||
picker.Go(sleepAndSend(ctx, 30, 2))
|
||||
picker.Go(sleepAndSend(ctx, 20, 1))
|
||||
|
||||
number, exist := <-fast
|
||||
if !exist || number != 1 {
|
||||
t.Error("should recv 1", exist, number)
|
||||
number := picker.Wait()
|
||||
if number != nil && number.(int) != 1 {
|
||||
t.Error("should recv 1", number)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicker_Timeout(t *testing.T) {
|
||||
in := make(chan interface{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5)
|
||||
defer cancel()
|
||||
fast := SelectFast(ctx, in)
|
||||
go sleepAndSend(20, in, 1)
|
||||
go sleepAndClose(30, in)
|
||||
picker, ctx := WithContext(ctx)
|
||||
picker.Go(sleepAndSend(ctx, 20, 1))
|
||||
|
||||
_, exist := <-fast
|
||||
if exist {
|
||||
t.Error("should recv false")
|
||||
number := picker.Wait()
|
||||
if number != nil {
|
||||
t.Error("should recv nil")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package constant
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
@ -44,7 +45,7 @@ type Proxy interface {
|
|||
Alive() bool
|
||||
DelayHistory() []DelayHistory
|
||||
LastDelay() uint16
|
||||
URLTest(url string) (uint16, error)
|
||||
URLTest(ctx context.Context, url string) (uint16, error)
|
||||
}
|
||||
|
||||
// AdapterType is enum of adapter type
|
||||
|
|
|
@ -163,32 +163,22 @@ func (r *Resolver) IsFakeIP() bool {
|
|||
}
|
||||
|
||||
func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) {
|
||||
in := make(chan interface{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
fast := picker.SelectFast(ctx, in)
|
||||
fast, ctx := picker.WithContext(ctx)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(clients))
|
||||
for _, r := range clients {
|
||||
go func(r resolver) {
|
||||
defer wg.Done()
|
||||
fast.Go(func() (interface{}, error) {
|
||||
msg, err := r.ExchangeContext(ctx, m)
|
||||
if err != nil || msg.Rcode != D.RcodeSuccess {
|
||||
return
|
||||
return nil, errors.New("resolve error")
|
||||
}
|
||||
in <- msg
|
||||
}(r)
|
||||
return msg, nil
|
||||
})
|
||||
}
|
||||
|
||||
// release in channel
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(in)
|
||||
}()
|
||||
|
||||
elm, exist := <-fast
|
||||
if !exist {
|
||||
elm := fast.Wait()
|
||||
if elm == nil {
|
||||
return nil, errors.New("All DNS requests failed")
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
A "github.com/Dreamacro/clash/adapters/outbound"
|
||||
"github.com/Dreamacro/clash/common/picker"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
T "github.com/Dreamacro/clash/tunnel"
|
||||
|
||||
|
@ -110,27 +111,28 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
proxy := r.Context().Value(CtxKeyProxy).(C.Proxy)
|
||||
|
||||
sigCh := make(chan uint16)
|
||||
go func() {
|
||||
t, err := proxy.URLTest(url)
|
||||
if err != nil {
|
||||
sigCh <- 0
|
||||
}
|
||||
sigCh <- t
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
|
||||
defer cancel()
|
||||
picker, ctx := picker.WithContext(ctx)
|
||||
picker.Go(func() (interface{}, error) {
|
||||
return proxy.URLTest(ctx, url)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-time.After(time.Millisecond * time.Duration(timeout)):
|
||||
elm := picker.Wait()
|
||||
if elm == nil {
|
||||
render.Status(r, http.StatusRequestTimeout)
|
||||
render.JSON(w, r, ErrRequestTimeout)
|
||||
case t := <-sigCh:
|
||||
if t == 0 {
|
||||
render.Status(r, http.StatusServiceUnavailable)
|
||||
render.JSON(w, r, newError("An error occurred in the delay test"))
|
||||
} else {
|
||||
render.JSON(w, r, render.M{
|
||||
"delay": t,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
delay := elm.(uint16)
|
||||
if delay == 0 {
|
||||
render.Status(r, http.StatusServiceUnavailable)
|
||||
render.JSON(w, r, newError("An error occurred in the delay test"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, r, render.M{
|
||||
"delay": delay,
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue