diff --git a/client.go b/client.go index f907698b..6bae3a1c 100644 --- a/client.go +++ b/client.go @@ -82,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration { // Dial connects to the address on the named network. func (c *Client) Dial(address string) (conn *Conn, err error) { + return c.DialContext(context.Background(), address) +} + +// DialContext connects to the address on the named network, with a context.Context. +// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released. +func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) { // create a new dialer with the appropriate timeout var d net.Dialer if c.Dialer == nil { @@ -101,9 +107,17 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { if useTLS { network = strings.TrimSuffix(network, "-tls") + // TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases. + /* + tlsDialer := tls.Dialer{ + NetDialer: &d, + Config: c.TLSConfig, + } + conn.Conn, err = tlsDialer.DialContext(ctx, network, address) + */ conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) } else { - conn.Conn, err = d.Dial(network, address) + conn.Conn, err = d.DialContext(ctx, network, address) } if err != nil { return nil, err @@ -139,24 +153,34 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er // ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection // that will be used instead of creating a new one. // Usage pattern with a *dns.Client: +// // c := new(dns.Client) // // connection management logic goes here // // conn := c.Dial(address) // in, rtt, err := c.ExchangeWithConn(message, conn) // -// This allows users of the library to implement their own connection management, -// as opposed to Exchange, which will always use new connections and incur the added overhead -// that entails when using "tcp" and especially "tcp-tls" clients. +// This allows users of the library to implement their own connection management, +// as opposed to Exchange, which will always use new connections and incur the added overhead +// that entails when using "tcp" and especially "tcp-tls" clients. +// +// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to +// prevent one cancelation from canceling all outstanding requests. func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { + return c.exchangeWithConnContext(context.Background(), m, conn) +} + +func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - return c.exchange(m, conn) + return c.exchangeContext(ctx, m, conn) } q := m.Question[0] key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass) r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) { - return c.exchange(m, conn) + // When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries. + // Hence we ignore the context and use Background(). + return c.exchangeContext(context.Background(), m, conn) }) if r != nil && shared { r = r.Copy() @@ -165,8 +189,7 @@ func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration return r, rtt, err } -func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { - +func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { opt := m.IsEdns0() // If EDNS0 is used use that for size. if opt != nil && opt.UDPSize() >= MinMsgSize { @@ -177,15 +200,27 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro co.UDPSize = c.UDPSize } - co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider - t := time.Now() // write with the appropriate write timeout - co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) + t := time.Now() + writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout())) + readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout())) + if deadline, ok := ctx.Deadline(); ok { + if deadline.Before(writeDeadline) { + writeDeadline = deadline + } + if deadline.Before(readDeadline) { + readDeadline = deadline + } + } + co.SetWriteDeadline(writeDeadline) + co.SetReadDeadline(readDeadline) + + co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider + if err = co.WriteMsg(m); err != nil { return nil, 0, err } - co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) if _, ok := co.Conn.(net.PacketConn); ok { for { r, err = co.ReadMsg() @@ -435,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout // context, if present. If there is both a context deadline and a configured // timeout on the client, the earliest of the two takes effect. func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var timeout time.Duration - if deadline, ok := ctx.Deadline(); !ok { - timeout = 0 - } else { - timeout = time.Until(deadline) + conn, err := c.DialContext(ctx, a) + if err != nil { + return nil, 0, err } - // not passing the context to the underlying calls, as the API does not support - // context. For timeouts you should set up Client.Dialer and call Client.Exchange. - // TODO(tmthrgd,miekg): this is a race condition. - c.Dialer = &net.Dialer{Timeout: timeout} - return c.Exchange(m, a) + defer conn.Close() + + return c.exchangeWithConnContext(ctx, m, conn) }