diff --git a/common/picker/picker.go b/common/picker/picker.go index 49a58f0..0c846cb 100644 --- a/common/picker/picker.go +++ b/common/picker/picker.go @@ -17,15 +17,12 @@ type Picker struct { once sync.Once result interface{} - - firstDone chan struct{} } func newPicker(ctx context.Context, cancel func()) *Picker { return &Picker{ - ctx: ctx, - cancel: cancel, - firstDone: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, } } @@ -42,12 +39,6 @@ func WithTimeout(ctx context.Context, timeout time.Duration) (*Picker, context.C return newPicker(ctx, cancel), ctx } -// WithoutAutoCancel returns a new Picker and an associated Context derived from ctx, -// but it wouldn't cancel context when the first element return. -func WithoutAutoCancel(ctx context.Context) *Picker { - return newPicker(ctx, nil) -} - // Wait blocks until all function calls from the Go method have returned, // then returns the first nil error result (if any) from them. func (p *Picker) Wait() interface{} { @@ -58,17 +49,6 @@ func (p *Picker) Wait() interface{} { return p.result } -// WaitWithoutCancel blocks until the first result return, if timeout will return nil. -// The return of this function will not wait for the cancel of context. -func (p *Picker) WaitWithoutCancel() interface{} { - select { - case <-p.firstDone: - return p.result - case <-p.ctx.Done(): - return p.result - } -} - // Go calls the given function in a new goroutine. // The first call to return a nil error cancels the group; its result will be returned by Wait. func (p *Picker) Go(f func() (interface{}, error)) { @@ -80,7 +60,6 @@ func (p *Picker) Go(f func() (interface{}, error)) { if ret, err := f(); err == nil { p.once.Do(func() { p.result = ret - p.firstDone <- struct{}{} if p.cancel != nil { p.cancel() } diff --git a/common/picker/picker_test.go b/common/picker/picker_test.go index 9e16500..8f0ba95 100644 --- a/common/picker/picker_test.go +++ b/common/picker/picker_test.go @@ -37,30 +37,3 @@ func TestPicker_Timeout(t *testing.T) { number := picker.Wait() assert.Nil(t, number) } - -func TestPicker_WaitWithoutAutoCancel(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*60) - defer cancel() - picker := WithoutAutoCancel(ctx) - - trigger := false - picker.Go(sleepAndSend(ctx, 10, 1)) - picker.Go(func() (interface{}, error) { - timer := time.NewTimer(time.Millisecond * time.Duration(30)) - select { - case <-timer.C: - trigger = true - return 2, nil - case <-ctx.Done(): - return nil, ctx.Err() - } - }) - elm := picker.WaitWithoutCancel() - - assert.NotNil(t, elm) - assert.Equal(t, elm.(int), 1) - - elm = picker.Wait() - assert.True(t, trigger) - assert.Equal(t, elm.(int), 1) -} diff --git a/dns/client.go b/dns/client.go index 91ba7ec..a3d5666 100644 --- a/dns/client.go +++ b/dns/client.go @@ -20,7 +20,22 @@ func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) { func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { c.Client.Dialer = dialer.Dialer() - // Please note that miekg/dns ExchangeContext doesn't respond to context cancel. - msg, _, err = c.Client.ExchangeContext(ctx, m, c.Address) - return + // miekg/dns ExchangeContext doesn't respond to context cancel. + // this is a workaround + type result struct { + msg *D.Msg + err error + } + ch := make(chan result, 1) + go func() { + msg, _, err := c.Client.ExchangeContext(ctx, m, c.Address) + ch <- result{msg, err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case ret := <-ch: + return ret.msg, ret.err + } } diff --git a/dns/middleware.go b/dns/middleware.go index 5aa691a..90d5a7f 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -39,7 +39,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { msg.Answer = []D.RR{rr} setMsgTTL(msg, 1) - msg.SetReply(r) + msg.SetRcode(r, msg.Rcode) + msg.Authoritative = true w.WriteMsg(msg) return } @@ -55,7 +56,8 @@ func withResolver(resolver *Resolver) handler { D.HandleFailed(w, r) return } - msg.SetReply(r) + msg.SetRcode(r, msg.Rcode) + msg.Authoritative = true w.WriteMsg(msg) return } diff --git a/dns/resolver.go b/dns/resolver.go index 7e76e81..58c2a3e 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "math/rand" "net" "strings" "time" @@ -178,15 +179,11 @@ func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err for _, client := range clients { r := client fast.Go(func() (interface{}, error) { - msg, err := r.ExchangeContext(ctx, m) - if err != nil || msg.Rcode != D.RcodeSuccess { - return nil, errors.New("resolve error") - } - return msg, nil + return r.ExchangeContext(ctx, m) }) } - elm := fast.WaitWithoutCancel() + elm := fast.Wait() if elm == nil { return nil, errors.New("All DNS requests failed") } @@ -239,11 +236,12 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) } ips := r.msgToIP(msg) - if len(ips) == 0 { + ipLength := len(ips) + if ipLength == 0 { return nil, errIPNotFound } - ip = ips[0] + ip = ips[rand.Intn(ipLength)] return }