From 1fc9fa1db0edfea37366f1a6fbbda29dd62011a5 Mon Sep 17 00:00:00 2001 From: yaakov kuperman Date: Mon, 4 May 2020 04:22:21 -0400 Subject: [PATCH] Adds function ExchangeWithConn (#1110) * Implements ExchangeWithConn, a function that allows callers to pass in a connection instead of having the library create a new one for them. Exchange now wraps around this, implementing the existing behavior by creating a new connection and passing it to ExchangeWithConn. c.exchange has been updated to support this behavior as well. * adding tab * formatting problem * adds test case for ExchangeWithConn --- client.go | 37 ++++++++++++++++++++++++++----------- client_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index db2761d4..bb8667fd 100644 --- a/client.go +++ b/client.go @@ -124,15 +124,38 @@ func (c *Client) Dial(address string) (conn *Conn, err error) { // of 512 bytes // To specify a local address or a timeout, the caller has to set the `Client.Dialer` // attribute appropriately + func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) { + co, err := c.Dial(address) + + if err != nil { + return nil, 0, err + } + defer co.Close() + return c.ExchangeWithConn(m, co) +} + +// ExchangeWithConn has the same behavior as Exchange, just with a predetermined connection +// that will be used instead of creating a new one. +// Usage pattern with a *dns.Client: +// c := new(dns.Client) +// // connection management logic goes here +// +// conn := c.Dial(address) +// in, rtt, err := c.ExchangeWithConn(message, conn) +// +// This allows users of the library to implement their own connection management, +// as opposed to Exchange, which will always use new connections and incur the added overhead +// that entails when using "tcp" and especially "tcp-tls" clients. +func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - return c.exchange(m, address) + return c.exchange(m, conn) } q := m.Question[0] key := fmt.Sprintf("%s:%d:%d", q.Name, q.Qtype, q.Qclass) r, rtt, err, shared := c.group.Do(key, func() (*Msg, time.Duration, error) { - return c.exchange(m, address) + return c.exchange(m, conn) }) if r != nil && shared { r = r.Copy() @@ -141,15 +164,7 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er return r, rtt, err } -func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - var co *Conn - - co, err = c.Dial(a) - - if err != nil { - return nil, 0, err - } - defer co.Close() +func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) { opt := m.IsEdns0() // If EDNS0 is used use that for size. diff --git a/client_test.go b/client_test.go index e7bb8ba3..b75d946c 100644 --- a/client_test.go +++ b/client_test.go @@ -565,3 +565,34 @@ func TestConcurrentExchanges(t *testing.T) { } } } + +func TestExchangeWithConn(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := new(Client) + conn, err := c.Dial(addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + r, _, err := c.ExchangeWithConn(m, conn) + if err != nil { + t.Fatalf("failed to exchange: %v", err) + } + if r == nil { + t.Fatal("response is nil") + } + if r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } +}