Add Client.Timeout to allow limiting total exchange duration (#345)
This commit is contained in:
parent
a5cc44dc6b
commit
c9d1302d54
26
client.go
26
client.go
|
@ -28,9 +28,10 @@ type Client struct {
|
||||||
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
|
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
|
||||||
UDPSize uint16 // minimum receive buffer for UDP messages
|
UDPSize uint16 // minimum receive buffer for UDP messages
|
||||||
TLSConfig *tls.Config // TLS connection configuration
|
TLSConfig *tls.Config // TLS connection configuration
|
||||||
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
|
Timeout time.Duration // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero
|
||||||
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
|
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
||||||
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
|
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
||||||
|
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
||||||
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
|
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
|
||||||
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
|
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
|
||||||
group singleflight
|
group singleflight
|
||||||
|
@ -129,6 +130,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) dialTimeout() time.Duration {
|
func (c *Client) dialTimeout() time.Duration {
|
||||||
|
if c.Timeout != 0 {
|
||||||
|
return c.Timeout
|
||||||
|
}
|
||||||
if c.DialTimeout != 0 {
|
if c.DialTimeout != 0 {
|
||||||
return c.DialTimeout
|
return c.DialTimeout
|
||||||
}
|
}
|
||||||
|
@ -170,6 +174,11 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var deadline time.Time
|
||||||
|
if c.Timeout != 0 {
|
||||||
|
deadline = time.Now().Add(c.Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
if tls {
|
if tls {
|
||||||
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
|
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
|
||||||
} else {
|
} else {
|
||||||
|
@ -192,12 +201,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
}
|
}
|
||||||
|
|
||||||
co.TsigSecret = c.TsigSecret
|
co.TsigSecret = c.TsigSecret
|
||||||
co.SetWriteDeadline(time.Now().Add(c.writeTimeout()))
|
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
|
||||||
if err = co.WriteMsg(m); err != nil {
|
if err = co.WriteMsg(m); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
co.SetReadDeadline(time.Now().Add(c.readTimeout()))
|
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
|
||||||
r, err = co.ReadMsg()
|
r, err = co.ReadMsg()
|
||||||
if err == nil && r.Id != m.Id {
|
if err == nil && r.Id != m.Id {
|
||||||
err = ErrId
|
err = ErrId
|
||||||
|
@ -434,3 +443,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
|
||||||
|
if deadline.IsZero() {
|
||||||
|
return time.Now().Add(timeout)
|
||||||
|
}
|
||||||
|
return deadline
|
||||||
|
}
|
||||||
|
|
|
@ -419,3 +419,51 @@ func TestTruncatedMsg(t *testing.T) {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeout(t *testing.T) {
|
||||||
|
// Set up a dummy UDP server that won't respond
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to resolve local udp address: %v", err)
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
addrstr := conn.LocalAddr().String()
|
||||||
|
|
||||||
|
// Message to send
|
||||||
|
m := new(Msg)
|
||||||
|
m.SetQuestion("miek.nl.", TypeTXT)
|
||||||
|
|
||||||
|
// Use a channel + timeout to ensure we don't get stuck if the
|
||||||
|
// Client Timeout is not working properly
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
timeout := time.Millisecond
|
||||||
|
allowable := timeout + (10 * time.Millisecond)
|
||||||
|
abortAfter := timeout + (100 * time.Millisecond)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c := &Client{Timeout: timeout}
|
||||||
|
_, _, err := c.Exchange(m, addrstr)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("no timeout using Client")
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(abortAfter):
|
||||||
|
}
|
||||||
|
|
||||||
|
length := time.Since(start)
|
||||||
|
|
||||||
|
if length > allowable {
|
||||||
|
t.Errorf("exchange took longer (%v) than specified Timeout (%v)", length, timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue