From c9d1302d540edfb97d9ecbfe90b4fb515088630b Mon Sep 17 00:00:00 2001 From: Will Bond Date: Tue, 19 Apr 2016 06:29:51 -0400 Subject: [PATCH] Add Client.Timeout to allow limiting total exchange duration (#345) --- client.go | 26 +++++++++++++++++++++----- client_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 7a05eb99..c1a4a430 100644 --- a/client.go +++ b/client.go @@ -28,9 +28,10 @@ type Client struct { Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) UDPSize uint16 // minimum receive buffer for UDP messages TLSConfig *tls.Config // TLS connection configuration - DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds + Timeout time.Duration // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero + DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight @@ -129,6 +130,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } func (c *Client) dialTimeout() time.Duration { + if c.Timeout != 0 { + return c.Timeout + } if c.DialTimeout != 0 { return c.DialTimeout } @@ -170,6 +174,11 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } } + var deadline time.Time + if c.Timeout != 0 { + deadline = time.Now().Add(c.Timeout) + } + if tls { co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout()) } else { @@ -192,12 +201,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro } co.TsigSecret = c.TsigSecret - co.SetWriteDeadline(time.Now().Add(c.writeTimeout())) + co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout())) if err = co.WriteMsg(m); err != nil { return nil, 0, err } - co.SetReadDeadline(time.Now().Add(c.readTimeout())) + co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout())) r, err = co.ReadMsg() if err == nil && r.Id != m.Id { err = ErrId @@ -434,3 +443,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout } return conn, nil } + +func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time { + if deadline.IsZero() { + return time.Now().Add(timeout) + } + return deadline +} diff --git a/client_test.go b/client_test.go index 09ec7037..f3d229c4 100644 --- a/client_test.go +++ b/client_test.go @@ -419,3 +419,51 @@ func TestTruncatedMsg(t *testing.T) { t.Fail() } } + +func TestTimeout(t *testing.T) { + // Set up a dummy UDP server that won't respond + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("unable to resolve local udp address: %v", err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer conn.Close() + addrstr := conn.LocalAddr().String() + + // Message to send + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + + // Use a channel + timeout to ensure we don't get stuck if the + // Client Timeout is not working properly + done := make(chan struct{}) + + timeout := time.Millisecond + allowable := timeout + (10 * time.Millisecond) + abortAfter := timeout + (100 * time.Millisecond) + + start := time.Now() + + go func() { + c := &Client{Timeout: timeout} + _, _, err := c.Exchange(m, addrstr) + if err == nil { + t.Error("no timeout using Client") + } + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(abortAfter): + } + + length := time.Since(start) + + if length > allowable { + t.Errorf("exchange took longer (%v) than specified Timeout (%v)", length, timeout) + } +}