Fix race condition in ExchangeContext. (#1281)
Automatically submitted.
This commit is contained in:
parent
af0c865ab3
commit
ab67aa6423
69
client.go
69
client.go
|
@ -82,6 +82,12 @@ func (c *Client) writeTimeout() time.Duration {
|
||||||
|
|
||||||
// Dial connects to the address on the named network.
|
// Dial connects to the address on the named network.
|
||||||
func (c *Client) Dial(address string) (conn *Conn, err error) {
|
func (c *Client) Dial(address string) (conn *Conn, err error) {
|
||||||
|
return c.DialContext(context.Background(), address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext connects to the address on the named network, with a context.Context.
|
||||||
|
// For TLS over TCP (DoT) the context isn't used yet. This will be enabled when Go 1.18 is released.
|
||||||
|
func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
|
||||||
// create a new dialer with the appropriate timeout
|
// create a new dialer with the appropriate timeout
|
||||||
var d net.Dialer
|
var d net.Dialer
|
||||||
if c.Dialer == nil {
|
if c.Dialer == nil {
|
||||||
|
@ -101,9 +107,17 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
|
||||||
if useTLS {
|
if useTLS {
|
||||||
network = strings.TrimSuffix(network, "-tls")
|
network = strings.TrimSuffix(network, "-tls")
|
||||||
|
|
||||||
|
// TODO(miekg): Enable after Go 1.18 is released, to be able to support two prev. releases.
|
||||||
|
/*
|
||||||
|
tlsDialer := tls.Dialer{
|
||||||
|
NetDialer: &d,
|
||||||
|
Config: c.TLSConfig,
|
||||||
|
}
|
||||||
|
conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
|
||||||
|
*/
|
||||||
conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
|
conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
|
||||||
} else {
|
} else {
|
||||||
conn.Conn, err = d.Dial(network, address)
|
conn.Conn, err = d.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -139,6 +153,7 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
|
||||||
// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
|
// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection
|
||||||
// that will be used instead of creating a new one.
|
// that will be used instead of creating a new one.
|
||||||
// Usage pattern with a *dns.Client:
|
// Usage pattern with a *dns.Client:
|
||||||
|
//
|
||||||
// c := new(dns.Client)
|
// c := new(dns.Client)
|
||||||
// // connection management logic goes here
|
// // connection management logic goes here
|
||||||
//
|
//
|
||||||
|
@ -148,15 +163,24 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
|
||||||
// This allows users of the library to implement their own connection management,
|
// This allows users of the library to implement their own connection management,
|
||||||
// as opposed to Exchange, which will always use new connections and incur the added overhead
|
// as opposed to Exchange, which will always use new connections and incur the added overhead
|
||||||
// that entails when using "tcp" and especially "tcp-tls" clients.
|
// that entails when using "tcp" and especially "tcp-tls" clients.
|
||||||
|
//
|
||||||
|
// When the singleflight is set for this client the context is _not_ forwarded to the (shared) exchange, to
|
||||||
|
// prevent one cancelation from canceling all outstanding requests.
|
||||||
func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
|
||||||
|
return c.exchangeWithConnContext(context.Background(), m, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) exchangeWithConnContext(ctx context.Context, m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
|
||||||
if !c.SingleInflight {
|
if !c.SingleInflight {
|
||||||
return c.exchange(m, conn)
|
return c.exchangeContext(ctx, m, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
q := m.Question[0]
|
q := m.Question[0]
|
||||||
key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
|
key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass)
|
||||||
r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
|
r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) {
|
||||||
return c.exchange(m, conn)
|
// When we're doing singleflight we don't want one context cancelation, cancel _all_ outstanding queries.
|
||||||
|
// Hence we ignore the context and use Background().
|
||||||
|
return c.exchangeContext(context.Background(), m, conn)
|
||||||
})
|
})
|
||||||
if r != nil && shared {
|
if r != nil && shared {
|
||||||
r = r.Copy()
|
r = r.Copy()
|
||||||
|
@ -165,8 +189,7 @@ func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration
|
||||||
return r, rtt, err
|
return r, rtt, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
|
||||||
|
|
||||||
opt := m.IsEdns0()
|
opt := m.IsEdns0()
|
||||||
// If EDNS0 is used use that for size.
|
// If EDNS0 is used use that for size.
|
||||||
if opt != nil && opt.UDPSize() >= MinMsgSize {
|
if opt != nil && opt.UDPSize() >= MinMsgSize {
|
||||||
|
@ -177,15 +200,27 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro
|
||||||
co.UDPSize = c.UDPSize
|
co.UDPSize = c.UDPSize
|
||||||
}
|
}
|
||||||
|
|
||||||
co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
|
|
||||||
t := time.Now()
|
|
||||||
// write with the appropriate write timeout
|
// write with the appropriate write timeout
|
||||||
co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
|
t := time.Now()
|
||||||
|
writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
|
||||||
|
readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if deadline.Before(writeDeadline) {
|
||||||
|
writeDeadline = deadline
|
||||||
|
}
|
||||||
|
if deadline.Before(readDeadline) {
|
||||||
|
readDeadline = deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
co.SetWriteDeadline(writeDeadline)
|
||||||
|
co.SetReadDeadline(readDeadline)
|
||||||
|
|
||||||
|
co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
|
||||||
|
|
||||||
if err = co.WriteMsg(m); err != nil {
|
if err = co.WriteMsg(m); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout())))
|
|
||||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
if _, ok := co.Conn.(net.PacketConn); ok {
|
||||||
for {
|
for {
|
||||||
r, err = co.ReadMsg()
|
r, err = co.ReadMsg()
|
||||||
|
@ -435,15 +470,11 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
|
||||||
// context, if present. If there is both a context deadline and a configured
|
// context, if present. If there is both a context deadline and a configured
|
||||||
// timeout on the client, the earliest of the two takes effect.
|
// timeout on the client, the earliest of the two takes effect.
|
||||||
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
var timeout time.Duration
|
conn, err := c.DialContext(ctx, a)
|
||||||
if deadline, ok := ctx.Deadline(); !ok {
|
if err != nil {
|
||||||
timeout = 0
|
return nil, 0, err
|
||||||
} else {
|
|
||||||
timeout = time.Until(deadline)
|
|
||||||
}
|
}
|
||||||
// not passing the context to the underlying calls, as the API does not support
|
defer conn.Close()
|
||||||
// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
|
|
||||||
// TODO(tmthrgd,miekg): this is a race condition.
|
return c.exchangeWithConnContext(ctx, m, conn)
|
||||||
c.Dialer = &net.Dialer{Timeout: timeout}
|
|
||||||
return c.Exchange(m, a)
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue