From 64746df23bbe418a542e5940c2cdf11291134819 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Wed, 16 May 2018 17:24:01 +0930 Subject: [PATCH] WIP: DNS-over-HTTPS support for Client.Exchange API (#671) * Add DNS-over-HTTPS support to (*Client).Exchange * Ignore net/http goroutine leak from DoH * Use existing Dialer and TLSConfig fields on Client for DOH * Make DOH http.Client fully configurable * Pipe context into exchangeDOH --- client.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 22 ++++++++++++ leak_test.go | 1 + 3 files changed, 116 insertions(+) diff --git a/client.go b/client.go index 282565af..6aa4235d 100644 --- a/client.go +++ b/client.go @@ -7,8 +7,12 @@ import ( "context" "crypto/tls" "encoding/binary" + "fmt" "io" + "io/ioutil" "net" + "net/http" + "net/url" "strings" "time" ) @@ -16,6 +20,8 @@ import ( const dnsTimeout time.Duration = 2 * time.Second const tcpIdleTimeout time.Duration = 8 * time.Second +const dohMimeType = "application/dns-udpwireformat" + // A Conn represents a connection to a DNS server. type Conn struct { net.Conn // a net.Conn holding the connection @@ -37,6 +43,7 @@ type Client struct { DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero + HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight @@ -134,6 +141,11 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { // attribute appropriately func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { + if c.Net == "https" { + // TODO(tmthrgd): pipe timeouts into exchangeDOH + return c.exchangeDOH(context.TODO(), m, address) + } + return c.exchange(m, address) } @@ -146,6 +158,11 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er cl = cl1 } r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { + if c.Net == "https" { + // TODO(tmthrgd): pipe timeouts into exchangeDOH + return c.exchangeDOH(context.TODO(), m, address) + } + return c.exchange(m, address) }) if r != nil && shared { @@ -191,6 +208,77 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro return r, rtt, err } +func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { + p, err := m.Pack() + if err != nil { + return nil, 0, err + } + + // TODO(tmthrgd): Allow the path to be customised? + u := &url.URL{ + Scheme: "https", + Host: a, + Path: "/.well-known/dns-query", + } + if u.Port() == "443" { + u.Host = u.Hostname() + } + + req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(p)) + if err != nil { + return nil, 0, err + } + + req.Header.Set("Content-Type", dohMimeType) + req.Header.Set("Accept", dohMimeType) + + t := time.Now() + + hc := http.DefaultClient + if c.HTTPClient != nil { + hc = c.HTTPClient + } + + if ctx != context.Background() && ctx != context.TODO() { + req = req.WithContext(ctx) + } + + resp, err := hc.Do(req) + if err != nil { + return nil, 0, err + } + defer closeHTTPBody(resp.Body) + + if resp.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status) + } + + if ct := resp.Header.Get("Content-Type"); ct != dohMimeType { + return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType) + } + + p, err = ioutil.ReadAll(resp.Body) + if err != nil { + return nil, 0, err + } + + rtt = time.Since(t) + + r = new(Msg) + if err := r.Unpack(p); err != nil { + return r, 0, err + } + + // TODO: TSIG? Is it even supported over DoH? + + return r, rtt, nil +} + +func closeHTTPBody(r io.ReadCloser) error { + io.Copy(ioutil.Discard, io.LimitReader(r, 8<<20)) + return r.Close() +} + // ReadMsg reads a message from the connection co. // If the received message contains a TSIG record the transaction signature // is verified. This method always tries to return the message, however if an @@ -490,6 +578,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout // context, if present. If there is both a context deadline and a configured // timeout on the client, the earliest of the two takes effect. func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) { + if !c.SingleInflight && c.Net == "https" { + return c.exchangeDOH(ctx, m, a) + } + var timeout time.Duration if deadline, ok := ctx.Deadline(); !ok { timeout = 0 @@ -498,6 +590,7 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, } // not passing the context to the underlying calls, as the API does not support // context. For timeouts you should set up Client.Dialer and call Client.Exchange. + // TODO(tmthrgd): this is a race condition c.Dialer = &net.Dialer{Timeout: timeout} return c.Exchange(m, a) } diff --git a/client_test.go b/client_test.go index cc419831..99fb08ae 100644 --- a/client_test.go +++ b/client_test.go @@ -588,3 +588,25 @@ func TestConcurrentExchanges(t *testing.T) { } } } + +func TestDoHExchange(t *testing.T) { + const addrstr = "dns.cloudflare.com:443" + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + cl := &Client{Net: "https"} + + r, _, err := cl.Exchange(m, addrstr) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + + if r == nil || r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } + + t.Log(r) + + // TODO: proper tests for this +} diff --git a/leak_test.go b/leak_test.go index af37011d..ff83ac74 100644 --- a/leak_test.go +++ b/leak_test.go @@ -29,6 +29,7 @@ func interestingGoroutines() (gs []string) { strings.Contains(stack, "closeWriteAndWait") || strings.Contains(stack, "testing.Main(") || strings.Contains(stack, "testing.(*T).Run(") || + strings.Contains(stack, "created by net/http.(*http2Transport).newClientConn") || // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") ||