mirror of https://github.com/Dreamacro/clash.git
Fix: limit concurrency number of provider health check
This commit is contained in:
parent
53e17a916b
commit
8d37220566
|
@ -136,6 +136,8 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
|
|||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
defer client.CloseIdleConnections()
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -2,9 +2,9 @@ package provider
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Dreamacro/clash/common/batch"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
@ -60,19 +60,16 @@ func (hc *HealthCheck) touch() {
|
|||
|
||||
func (hc *HealthCheck) check() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout)
|
||||
wg := &sync.WaitGroup{}
|
||||
defer cancel()
|
||||
|
||||
b, ctx := batch.WithContext(ctx, batch.WithConcurrencyNum(10))
|
||||
for _, proxy := range hc.proxies {
|
||||
wg.Add(1)
|
||||
|
||||
go func(p C.Proxy) {
|
||||
p.URLTest(ctx, hc.url)
|
||||
wg.Done()
|
||||
}(proxy)
|
||||
p := proxy
|
||||
b.Go(p.Name(), func() (interface{}, error) {
|
||||
return p.URLTest(ctx, hc.url)
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
b.Wait()
|
||||
}
|
||||
|
||||
func (hc *HealthCheck) close() {
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
package batch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Option = func(b *Batch)
|
||||
|
||||
type Result struct {
|
||||
Value interface{}
|
||||
Err error
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Key string
|
||||
Err error
|
||||
}
|
||||
|
||||
func WithConcurrencyNum(n int) Option {
|
||||
return func(b *Batch) {
|
||||
q := make(chan struct{}, n)
|
||||
for i := 0; i < n; i++ {
|
||||
q <- struct{}{}
|
||||
}
|
||||
b.queue = q
|
||||
}
|
||||
}
|
||||
|
||||
// Batch similar to errgroup, but can control the maximum number of concurrent
|
||||
type Batch struct {
|
||||
result map[string]Result
|
||||
queue chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mux sync.Mutex
|
||||
err *Error
|
||||
once sync.Once
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func (b *Batch) Go(key string, fn func() (interface{}, error)) {
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
if b.queue != nil {
|
||||
<-b.queue
|
||||
defer func() {
|
||||
b.queue <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
value, err := fn()
|
||||
if err != nil {
|
||||
b.once.Do(func() {
|
||||
b.err = &Error{key, err}
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
ret := Result{value, err}
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
b.result[key] = ret
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Batch) Wait() *Error {
|
||||
b.wg.Wait()
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
return b.err
|
||||
}
|
||||
|
||||
func (b *Batch) WaitAndGetResult() (map[string]Result, *Error) {
|
||||
err := b.Wait()
|
||||
return b.Result(), err
|
||||
}
|
||||
|
||||
func (b *Batch) Result() map[string]Result {
|
||||
b.mux.Lock()
|
||||
defer b.mux.Unlock()
|
||||
copy := map[string]Result{}
|
||||
for k, v := range b.result {
|
||||
copy[k] = v
|
||||
}
|
||||
return copy
|
||||
}
|
||||
|
||||
func New(opts ...Option) *Batch {
|
||||
b := &Batch{
|
||||
result: map[string]Result{},
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o(b)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func WithContext(ctx context.Context, opts ...Option) (*Batch, context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
b := New(opts...)
|
||||
b.cancel = cancel
|
||||
|
||||
return b, ctx
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package batch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBatch(t *testing.T) {
|
||||
b := New()
|
||||
|
||||
now := time.Now()
|
||||
b.Go("foo", func() (interface{}, error) {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return "foo", nil
|
||||
})
|
||||
b.Go("bar", func() (interface{}, error) {
|
||||
time.Sleep(time.Millisecond * 150)
|
||||
return "bar", nil
|
||||
})
|
||||
result, err := b.WaitAndGetResult()
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
duration := time.Since(now)
|
||||
assert.Less(t, duration, time.Millisecond*200)
|
||||
assert.Equal(t, 2, len(result))
|
||||
|
||||
for k, v := range result {
|
||||
assert.NoError(t, v.Err)
|
||||
assert.Equal(t, k, v.Value.(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchWithConcurrencyNum(t *testing.T) {
|
||||
b := New(
|
||||
WithConcurrencyNum(3),
|
||||
)
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 7; i++ {
|
||||
idx := i
|
||||
b.Go(strconv.Itoa(idx), func() (interface{}, error) {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return strconv.Itoa(idx), nil
|
||||
})
|
||||
}
|
||||
result, _ := b.WaitAndGetResult()
|
||||
duration := time.Since(now)
|
||||
assert.Greater(t, duration, time.Millisecond*260)
|
||||
assert.Equal(t, 7, len(result))
|
||||
|
||||
for k, v := range result {
|
||||
assert.NoError(t, v.Err)
|
||||
assert.Equal(t, k, v.Value.(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchContext(t *testing.T) {
|
||||
b, ctx := WithContext(context.Background())
|
||||
|
||||
b.Go("error", func() (interface{}, error) {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return nil, errors.New("test error")
|
||||
})
|
||||
|
||||
b.Go("ctx", func() (interface{}, error) {
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
})
|
||||
|
||||
result, err := b.WaitAndGetResult()
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "error", err.Key)
|
||||
|
||||
assert.Equal(t, ctx.Err(), result["ctx"].Err)
|
||||
}
|
Loading…
Reference in New Issue