Add ExchangeContext methods. (#497)
These obey the timeouts provided in a Context.
This commit is contained in:
parent
e78414ef75
commit
e46719b2fe
78
client.go
78
client.go
|
@ -4,6 +4,7 @@ package dns
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
|
@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) {
|
|||
return r, err
|
||||
}
|
||||
|
||||
// ExchangeContext performs a synchronous UDP query, like Exchange. It
|
||||
// additionally obeys deadlines from the passed Context.
|
||||
func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
|
||||
// Combine context deadline with built-in timeout. Context chooses whichever
|
||||
// is sooner.
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
deadline, _ := timeoutCtx.Deadline()
|
||||
|
||||
co := new(Conn)
|
||||
dialer := net.Dialer{}
|
||||
co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer co.Conn.Close()
|
||||
|
||||
opt := m.IsEdns0()
|
||||
// If EDNS0 is used use that for size.
|
||||
if opt != nil && opt.UDPSize() >= MinMsgSize {
|
||||
co.UDPSize = opt.UDPSize()
|
||||
}
|
||||
|
||||
co.SetWriteDeadline(deadline)
|
||||
if err = co.WriteMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.SetReadDeadline(deadline)
|
||||
r, err = co.ReadMsg()
|
||||
if err == nil && r.Id != m.Id {
|
||||
err = ErrId
|
||||
}
|
||||
return r, err
|
||||
}
|
||||
|
||||
// ExchangeConn performs a synchronous query. It sends the message m via the connection
|
||||
// c and waits for a reply. The connection c is not closed by ExchangeConn.
|
||||
// This function is going away, but can easily be mimicked:
|
||||
|
@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
|
|||
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
|
||||
// of 512 bytes.
|
||||
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||
return c.ExchangeContext(context.Background(), m, a)
|
||||
}
|
||||
|
||||
// ExchangeContext acts like Exchange, but honors the deadline on the provided
|
||||
// context, if present. If there is both a context deadline and a configured
|
||||
// timeout on the client, the earliest of the two takes effect.
|
||||
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (
|
||||
r *Msg,
|
||||
rtt time.Duration,
|
||||
err error) {
|
||||
if !c.SingleInflight {
|
||||
return c.exchange(m, a)
|
||||
return c.exchange(ctx, m, a)
|
||||
}
|
||||
// This adds a bunch of garbage, TODO(miek).
|
||||
t := "nop"
|
||||
|
@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
|||
cl = cl1
|
||||
}
|
||||
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
|
||||
return c.exchange(m, a)
|
||||
return c.exchange(ctx, m, a)
|
||||
})
|
||||
if r != nil && shared {
|
||||
r = r.Copy()
|
||||
|
@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration {
|
|||
return dnsTimeout
|
||||
}
|
||||
|
||||
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||
func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||
var co *Conn
|
||||
network := "udp"
|
||||
tls := false
|
||||
|
@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
|||
deadline = time.Now().Add(c.Timeout)
|
||||
}
|
||||
|
||||
dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout())
|
||||
dialTimeout := dialDeadline.Sub(time.Now())
|
||||
|
||||
if tls {
|
||||
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
|
||||
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout)
|
||||
} else {
|
||||
co, err = DialTimeout(network, a, c.dialTimeout())
|
||||
co, err = DialTimeout(network, a, dialTimeout)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
|||
}
|
||||
|
||||
co.TsigSecret = c.TsigSecret
|
||||
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
|
||||
co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout()))
|
||||
if err = co.WriteMsg(m); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
|
||||
co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout()))
|
||||
r, err = co.ReadMsg()
|
||||
if err == nil && r.Id != m.Id {
|
||||
err = ErrId
|
||||
|
@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
// deadlineOrTimeout chooses between the provided deadline and timeout
|
||||
// by always preferring the deadline so long as it's non-zero (regardless
|
||||
// of which is bigger), and returns the equivalent deadline value.
|
||||
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
|
||||
if deadline.IsZero() {
|
||||
return time.Now().Add(timeout)
|
||||
}
|
||||
return deadline
|
||||
}
|
||||
|
||||
// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the
|
||||
// output of deadlineOrtimeout.
|
||||
func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time {
|
||||
result := deadlineOrTimeout(deadline, timeout)
|
||||
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) {
|
||||
result = ctxDeadline
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -423,7 +424,7 @@ func TestTimeout(t *testing.T) {
|
|||
|
||||
// Use a channel + timeout to ensure we don't get stuck if the
|
||||
// Client Timeout is not working properly
|
||||
done := make(chan struct{})
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
timeout := time.Millisecond
|
||||
allowable := timeout + (10 * time.Millisecond)
|
||||
|
@ -435,14 +436,28 @@ func TestTimeout(t *testing.T) {
|
|||
c := &Client{Timeout: timeout}
|
||||
_, _, err := c.Exchange(m, addrstr)
|
||||
if err == nil {
|
||||
t.Error("no timeout using Client")
|
||||
t.Error("no timeout using Client.Exchange")
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(abortAfter):
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
c := &Client{}
|
||||
_, _, err := c.ExchangeContext(ctx, m, addrstr)
|
||||
if err == nil {
|
||||
t.Error("no timeout using Client.ExchangeContext")
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for both the Exchange and ExchangeContext tests to be done.
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(abortAfter):
|
||||
}
|
||||
}
|
||||
|
||||
length := time.Since(start)
|
||||
|
|
Loading…
Reference in New Issue