Add ExchangeContext methods. (#497)
These obey the timeouts provided in a Context.
This commit is contained in:
parent
e78414ef75
commit
e46719b2fe
78
client.go
78
client.go
|
@ -4,6 +4,7 @@ package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) {
|
||||||
return r, err
|
return r, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExchangeContext performs a synchronous UDP query, like Exchange. It
|
||||||
|
// additionally obeys deadlines from the passed Context.
|
||||||
|
func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
|
||||||
|
// Combine context deadline with built-in timeout. Context chooses whichever
|
||||||
|
// is sooner.
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
|
defer cancel()
|
||||||
|
deadline, _ := timeoutCtx.Deadline()
|
||||||
|
|
||||||
|
co := new(Conn)
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer co.Conn.Close()
|
||||||
|
|
||||||
|
opt := m.IsEdns0()
|
||||||
|
// If EDNS0 is used use that for size.
|
||||||
|
if opt != nil && opt.UDPSize() >= MinMsgSize {
|
||||||
|
co.UDPSize = opt.UDPSize()
|
||||||
|
}
|
||||||
|
|
||||||
|
co.SetWriteDeadline(deadline)
|
||||||
|
if err = co.WriteMsg(m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
co.SetReadDeadline(deadline)
|
||||||
|
r, err = co.ReadMsg()
|
||||||
|
if err == nil && r.Id != m.Id {
|
||||||
|
err = ErrId
|
||||||
|
}
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
// ExchangeConn performs a synchronous query. It sends the message m via the connection
|
// ExchangeConn performs a synchronous query. It sends the message m via the connection
|
||||||
// c and waits for a reply. The connection c is not closed by ExchangeConn.
|
// c and waits for a reply. The connection c is not closed by ExchangeConn.
|
||||||
// This function is going away, but can easily be mimicked:
|
// This function is going away, but can easily be mimicked:
|
||||||
|
@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
|
||||||
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
|
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
|
||||||
// of 512 bytes.
|
// of 512 bytes.
|
||||||
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
|
return c.ExchangeContext(context.Background(), m, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeContext acts like Exchange, but honors the deadline on the provided
|
||||||
|
// context, if present. If there is both a context deadline and a configured
|
||||||
|
// timeout on the client, the earliest of the two takes effect.
|
||||||
|
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (
|
||||||
|
r *Msg,
|
||||||
|
rtt time.Duration,
|
||||||
|
err error) {
|
||||||
if !c.SingleInflight {
|
if !c.SingleInflight {
|
||||||
return c.exchange(m, a)
|
return c.exchange(ctx, m, a)
|
||||||
}
|
}
|
||||||
// This adds a bunch of garbage, TODO(miek).
|
// This adds a bunch of garbage, TODO(miek).
|
||||||
t := "nop"
|
t := "nop"
|
||||||
|
@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
cl = cl1
|
cl = cl1
|
||||||
}
|
}
|
||||||
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
|
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
|
||||||
return c.exchange(m, a)
|
return c.exchange(ctx, m, a)
|
||||||
})
|
})
|
||||||
if r != nil && shared {
|
if r != nil && shared {
|
||||||
r = r.Copy()
|
r = r.Copy()
|
||||||
|
@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration {
|
||||||
return dnsTimeout
|
return dnsTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
var co *Conn
|
var co *Conn
|
||||||
network := "udp"
|
network := "udp"
|
||||||
tls := false
|
tls := false
|
||||||
|
@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
deadline = time.Now().Add(c.Timeout)
|
deadline = time.Now().Add(c.Timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout())
|
||||||
|
dialTimeout := dialDeadline.Sub(time.Now())
|
||||||
|
|
||||||
if tls {
|
if tls {
|
||||||
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
|
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout)
|
||||||
} else {
|
} else {
|
||||||
co, err = DialTimeout(network, a, c.dialTimeout())
|
co, err = DialTimeout(network, a, dialTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
}
|
}
|
||||||
|
|
||||||
co.TsigSecret = c.TsigSecret
|
co.TsigSecret = c.TsigSecret
|
||||||
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
|
co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout()))
|
||||||
if err = co.WriteMsg(m); err != nil {
|
if err = co.WriteMsg(m); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
|
co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout()))
|
||||||
r, err = co.ReadMsg()
|
r, err = co.ReadMsg()
|
||||||
if err == nil && r.Id != m.Id {
|
if err == nil && r.Id != m.Id {
|
||||||
err = ErrId
|
err = ErrId
|
||||||
|
@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deadlineOrTimeout chooses between the provided deadline and timeout
|
||||||
|
// by always preferring the deadline so long as it's non-zero (regardless
|
||||||
|
// of which is bigger), and returns the equivalent deadline value.
|
||||||
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
|
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
|
||||||
if deadline.IsZero() {
|
if deadline.IsZero() {
|
||||||
return time.Now().Add(timeout)
|
return time.Now().Add(timeout)
|
||||||
}
|
}
|
||||||
return deadline
|
return deadline
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the
|
||||||
|
// output of deadlineOrtimeout.
|
||||||
|
func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time {
|
||||||
|
result := deadlineOrTimeout(deadline, timeout)
|
||||||
|
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) {
|
||||||
|
result = ctxDeadline
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -423,7 +424,7 @@ func TestTimeout(t *testing.T) {
|
||||||
|
|
||||||
// Use a channel + timeout to ensure we don't get stuck if the
|
// Use a channel + timeout to ensure we don't get stuck if the
|
||||||
// Client Timeout is not working properly
|
// Client Timeout is not working properly
|
||||||
done := make(chan struct{})
|
done := make(chan struct{}, 2)
|
||||||
|
|
||||||
timeout := time.Millisecond
|
timeout := time.Millisecond
|
||||||
allowable := timeout + (10 * time.Millisecond)
|
allowable := timeout + (10 * time.Millisecond)
|
||||||
|
@ -435,14 +436,28 @@ func TestTimeout(t *testing.T) {
|
||||||
c := &Client{Timeout: timeout}
|
c := &Client{Timeout: timeout}
|
||||||
_, _, err := c.Exchange(m, addrstr)
|
_, _, err := c.Exchange(m, addrstr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("no timeout using Client")
|
t.Error("no timeout using Client.Exchange")
|
||||||
}
|
}
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
go func() {
|
||||||
case <-done:
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
case <-time.After(abortAfter):
|
defer cancel()
|
||||||
|
c := &Client{}
|
||||||
|
_, _, err := c.ExchangeContext(ctx, m, addrstr)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("no timeout using Client.ExchangeContext")
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for both the Exchange and ExchangeContext tests to be done.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(abortAfter):
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
length := time.Since(start)
|
length := time.Since(start)
|
||||||
|
|
Loading…
Reference in New Issue