Add Client.Timeout to allow limiting total exchange duration (#345)

This commit is contained in:
Will Bond 2016-04-19 06:29:51 -04:00 committed by Miek Gieben
parent a5cc44dc6b
commit c9d1302d54
2 changed files with 69 additions and 5 deletions

View File

@ -28,9 +28,10 @@ type Client struct {
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
UDPSize uint16 // minimum receive buffer for UDP messages
TLSConfig *tls.Config // TLS connection configuration
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
Timeout time.Duration // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
@ -129,6 +130,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}
func (c *Client) dialTimeout() time.Duration {
if c.Timeout != 0 {
return c.Timeout
}
if c.DialTimeout != 0 {
return c.DialTimeout
}
@ -170,6 +174,11 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}
}
var deadline time.Time
if c.Timeout != 0 {
deadline = time.Now().Add(c.Timeout)
}
if tls {
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
} else {
@ -192,12 +201,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}
co.TsigSecret = c.TsigSecret
co.SetWriteDeadline(time.Now().Add(c.writeTimeout()))
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
co.SetReadDeadline(time.Now().Add(c.readTimeout()))
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
@ -434,3 +443,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
}
return conn, nil
}
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
if deadline.IsZero() {
return time.Now().Add(timeout)
}
return deadline
}

View File

@ -419,3 +419,51 @@ func TestTruncatedMsg(t *testing.T) {
t.Fail()
}
}
func TestTimeout(t *testing.T) {
// Set up a dummy UDP server that won't respond
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("unable to resolve local udp address: %v", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer conn.Close()
addrstr := conn.LocalAddr().String()
// Message to send
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
// Use a channel + timeout to ensure we don't get stuck if the
// Client Timeout is not working properly
done := make(chan struct{})
timeout := time.Millisecond
allowable := timeout + (10 * time.Millisecond)
abortAfter := timeout + (100 * time.Millisecond)
start := time.Now()
go func() {
c := &Client{Timeout: timeout}
_, _, err := c.Exchange(m, addrstr)
if err == nil {
t.Error("no timeout using Client")
}
done <- struct{}{}
}()
select {
case <-done:
case <-time.After(abortAfter):
}
length := time.Since(start)
if length > allowable {
t.Errorf("exchange took longer (%v) than specified Timeout (%v)", length, timeout)
}
}