WIP: DNS-over-HTTPS support for Client.Exchange API (#671)
* Add DNS-over-HTTPS support to (*Client).Exchange * Ignore net/http goroutine leak from DoH * Use existing Dialer and TLSConfig fields on Client for DOH * Make DOH http.Client fully configurable * Pipe context into exchangeDOH
This commit is contained in:
parent
3745b9737d
commit
64746df23b
93
client.go
93
client.go
|
@ -7,8 +7,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -16,6 +20,8 @@ import (
|
||||||
const dnsTimeout time.Duration = 2 * time.Second
|
const dnsTimeout time.Duration = 2 * time.Second
|
||||||
const tcpIdleTimeout time.Duration = 8 * time.Second
|
const tcpIdleTimeout time.Duration = 8 * time.Second
|
||||||
|
|
||||||
|
const dohMimeType = "application/dns-udpwireformat"
|
||||||
|
|
||||||
// A Conn represents a connection to a DNS server.
|
// A Conn represents a connection to a DNS server.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
net.Conn // a net.Conn holding the connection
|
net.Conn // a net.Conn holding the connection
|
||||||
|
@ -37,6 +43,7 @@ type Client struct {
|
||||||
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
|
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
|
||||||
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
||||||
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
|
||||||
|
HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS
|
||||||
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
|
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
|
||||||
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
|
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
|
||||||
group singleflight
|
group singleflight
|
||||||
|
@ -134,6 +141,11 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
|
||||||
// attribute appropriately
|
// attribute appropriately
|
||||||
func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
|
||||||
if !c.SingleInflight {
|
if !c.SingleInflight {
|
||||||
|
if c.Net == "https" {
|
||||||
|
// TODO(tmthrgd): pipe timeouts into exchangeDOH
|
||||||
|
return c.exchangeDOH(context.TODO(), m, address)
|
||||||
|
}
|
||||||
|
|
||||||
return c.exchange(m, address)
|
return c.exchange(m, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,6 +158,11 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
|
||||||
cl = cl1
|
cl = cl1
|
||||||
}
|
}
|
||||||
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
|
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
|
||||||
|
if c.Net == "https" {
|
||||||
|
// TODO(tmthrgd): pipe timeouts into exchangeDOH
|
||||||
|
return c.exchangeDOH(context.TODO(), m, address)
|
||||||
|
}
|
||||||
|
|
||||||
return c.exchange(m, address)
|
return c.exchange(m, address)
|
||||||
})
|
})
|
||||||
if r != nil && shared {
|
if r != nil && shared {
|
||||||
|
@ -191,6 +208,77 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
return r, rtt, err
|
return r, rtt, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
|
p, err := m.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(tmthrgd): Allow the path to be customised?
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: a,
|
||||||
|
Path: "/.well-known/dns-query",
|
||||||
|
}
|
||||||
|
if u.Port() == "443" {
|
||||||
|
u.Host = u.Hostname()
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(p))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", dohMimeType)
|
||||||
|
req.Header.Set("Accept", dohMimeType)
|
||||||
|
|
||||||
|
t := time.Now()
|
||||||
|
|
||||||
|
hc := http.DefaultClient
|
||||||
|
if c.HTTPClient != nil {
|
||||||
|
hc = c.HTTPClient
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx != context.Background() && ctx != context.TODO() {
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := hc.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer closeHTTPBody(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != dohMimeType {
|
||||||
|
return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err = ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rtt = time.Since(t)
|
||||||
|
|
||||||
|
r = new(Msg)
|
||||||
|
if err := r.Unpack(p); err != nil {
|
||||||
|
return r, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: TSIG? Is it even supported over DoH?
|
||||||
|
|
||||||
|
return r, rtt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeHTTPBody(r io.ReadCloser) error {
|
||||||
|
io.Copy(ioutil.Discard, io.LimitReader(r, 8<<20))
|
||||||
|
return r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// ReadMsg reads a message from the connection co.
|
// ReadMsg reads a message from the connection co.
|
||||||
// If the received message contains a TSIG record the transaction signature
|
// If the received message contains a TSIG record the transaction signature
|
||||||
// is verified. This method always tries to return the message, however if an
|
// is verified. This method always tries to return the message, however if an
|
||||||
|
@ -490,6 +578,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
|
||||||
// context, if present. If there is both a context deadline and a configured
|
// context, if present. If there is both a context deadline and a configured
|
||||||
// timeout on the client, the earliest of the two takes effect.
|
// timeout on the client, the earliest of the two takes effect.
|
||||||
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
|
if !c.SingleInflight && c.Net == "https" {
|
||||||
|
return c.exchangeDOH(ctx, m, a)
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
if deadline, ok := ctx.Deadline(); !ok {
|
if deadline, ok := ctx.Deadline(); !ok {
|
||||||
timeout = 0
|
timeout = 0
|
||||||
|
@ -498,6 +590,7 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg,
|
||||||
}
|
}
|
||||||
// not passing the context to the underlying calls, as the API does not support
|
// not passing the context to the underlying calls, as the API does not support
|
||||||
// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
|
// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
|
||||||
|
// TODO(tmthrgd): this is a race condition
|
||||||
c.Dialer = &net.Dialer{Timeout: timeout}
|
c.Dialer = &net.Dialer{Timeout: timeout}
|
||||||
return c.Exchange(m, a)
|
return c.Exchange(m, a)
|
||||||
}
|
}
|
||||||
|
|
|
@ -588,3 +588,25 @@ func TestConcurrentExchanges(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDoHExchange(t *testing.T) {
|
||||||
|
const addrstr = "dns.cloudflare.com:443"
|
||||||
|
|
||||||
|
m := new(Msg)
|
||||||
|
m.SetQuestion("miek.nl.", TypeSOA)
|
||||||
|
|
||||||
|
cl := &Client{Net: "https"}
|
||||||
|
|
||||||
|
r, _, err := cl.Exchange(m, addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to exchange: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == nil || r.Rcode != RcodeSuccess {
|
||||||
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(r)
|
||||||
|
|
||||||
|
// TODO: proper tests for this
|
||||||
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ func interestingGoroutines() (gs []string) {
|
||||||
strings.Contains(stack, "closeWriteAndWait") ||
|
strings.Contains(stack, "closeWriteAndWait") ||
|
||||||
strings.Contains(stack, "testing.Main(") ||
|
strings.Contains(stack, "testing.Main(") ||
|
||||||
strings.Contains(stack, "testing.(*T).Run(") ||
|
strings.Contains(stack, "testing.(*T).Run(") ||
|
||||||
|
strings.Contains(stack, "created by net/http.(*http2Transport).newClientConn") ||
|
||||||
// These only show up with GOTRACEBACK=2; Issue 5005 (comment 28)
|
// These only show up with GOTRACEBACK=2; Issue 5005 (comment 28)
|
||||||
strings.Contains(stack, "runtime.goexit") ||
|
strings.Contains(stack, "runtime.goexit") ||
|
||||||
strings.Contains(stack, "created by runtime.gc") ||
|
strings.Contains(stack, "created by runtime.gc") ||
|
||||||
|
|
Loading…
Reference in New Issue