From e46719b2fef404d2e531c0dd9055b1c95ff01e2e Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Fri, 30 Jun 2017 04:44:44 -0700 Subject: [PATCH] Add ExchangeContext methods. (#497) These obey the timeouts provided in a Context. --- client.go | 78 +++++++++++++++++++++++++++++++++++++++++++++----- client_test.go | 25 ++++++++++++---- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 301dab9c..1c14a19d 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ package dns import ( "bytes" + "context" "crypto/tls" "encoding/binary" "io" @@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) { return r, err } +// ExchangeContext performs a synchronous UDP query, like Exchange. It +// additionally obeys deadlines from the passed Context. +func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) { + // Combine context deadline with built-in timeout. Context chooses whichever + // is sooner. + timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + deadline, _ := timeoutCtx.Deadline() + + co := new(Conn) + dialer := net.Dialer{} + co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a) + if err != nil { + return nil, err + } + + defer co.Conn.Close() + + opt := m.IsEdns0() + // If EDNS0 is used use that for size. + if opt != nil && opt.UDPSize() >= MinMsgSize { + co.UDPSize = opt.UDPSize() + } + + co.SetWriteDeadline(deadline) + if err = co.WriteMsg(m); err != nil { + return nil, err + } + + co.SetReadDeadline(deadline) + r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } + return r, err +} + // ExchangeConn performs a synchronous query. It sends the message m via the connection // c and waits for a reply. The connection c is not closed by ExchangeConn. // This function is going away, but can easily be mimicked: @@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit // of 512 bytes. func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { + return c.ExchangeContext(context.Background(), m, a) +} + +// ExchangeContext acts like Exchange, but honors the deadline on the provided +// context, if present. If there is both a context deadline and a configured +// timeout on the client, the earliest of the two takes effect. +func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) ( + r *Msg, + rtt time.Duration, + err error) { if !c.SingleInflight { - return c.exchange(m, a) + return c.exchange(ctx, m, a) } // This adds a bunch of garbage, TODO(miek). t := "nop" @@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro cl = cl1 } r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { - return c.exchange(m, a) + return c.exchange(ctx, m, a) }) if r != nil && shared { r = r.Copy() @@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration { return dnsTimeout } -func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { +func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { var co *Conn network := "udp" tls := false @@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro deadline = time.Now().Add(c.Timeout) } + dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout()) + dialTimeout := dialDeadline.Sub(time.Now()) + if tls { - co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout()) + co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout) } else { - co, err = DialTimeout(network, a, c.dialTimeout()) + co, err = DialTimeout(network, a, dialTimeout) } if err != nil { @@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } co.TsigSecret = c.TsigSecret - co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout())) + co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout())) if err = co.WriteMsg(m); err != nil { return nil, 0, err } - co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout())) + co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout())) r, err = co.ReadMsg() if err == nil && r.Id != m.Id { err = ErrId @@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout return conn, nil } +// deadlineOrTimeout chooses between the provided deadline and timeout +// by always preferring the deadline so long as it's non-zero (regardless +// of which is bigger), and returns the equivalent deadline value. func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time { if deadline.IsZero() { return time.Now().Add(timeout) } return deadline } + +// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the +// output of deadlineOrtimeout. +func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time { + result := deadlineOrTimeout(deadline, timeout) + if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) { + result = ctxDeadline + } + return result +} diff --git a/client_test.go b/client_test.go index dee585f3..d29e4e3b 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "crypto/tls" "fmt" "net" @@ -423,7 +424,7 @@ func TestTimeout(t *testing.T) { // Use a channel + timeout to ensure we don't get stuck if the // Client Timeout is not working properly - done := make(chan struct{}) + done := make(chan struct{}, 2) timeout := time.Millisecond allowable := timeout + (10 * time.Millisecond) @@ -435,14 +436,28 @@ func TestTimeout(t *testing.T) { c := &Client{Timeout: timeout} _, _, err := c.Exchange(m, addrstr) if err == nil { - t.Error("no timeout using Client") + t.Error("no timeout using Client.Exchange") } done <- struct{}{} }() - select { - case <-done: - case <-time.After(abortAfter): + go func() { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + c := &Client{} + _, _, err := c.ExchangeContext(ctx, m, addrstr) + if err == nil { + t.Error("no timeout using Client.ExchangeContext") + } + done <- struct{}{} + }() + + // Wait for both the Exchange and ExchangeContext tests to be done. + for i := 0; i < 2; i++ { + select { + case <-done: + case <-time.After(abortAfter): + } } length := time.Since(start)