From dddcd696baaf92de07a6e946357268635fd1af34 Mon Sep 17 00:00:00 2001 From: Michael Haro Date: Tue, 5 May 2015 22:56:42 -0700 Subject: [PATCH] Check that the query ID matches the answer ID. Reduce some code duplication by making Exchange() use Client.Exchange(). When performing an Exchange if the query ID does not match the answer ID return an error. Also add a test for this condition. --- client.go | 30 +++++++++--------------------- client_test.go | 23 +++++++++++++++++++++++ server_test.go | 10 ++++++++++ 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/client.go b/client.go index cdab4432..49a24a96 100644 --- a/client.go +++ b/client.go @@ -46,27 +46,9 @@ type Client struct { // co.Close() // func Exchange(m *Msg, a string) (r *Msg, err error) { - var co *Conn - co, err = DialTimeout("udp", a, dnsTimeout) - if err != nil { - return nil, err - } - - defer co.Close() - co.SetReadDeadline(time.Now().Add(dnsTimeout)) - co.SetWriteDeadline(time.Now().Add(dnsTimeout)) - - opt := m.IsEdns0() - // If EDNS0 is used use that for size. - if opt != nil && opt.UDPSize() >= MinMsgSize { - co.UDPSize = opt.UDPSize() - } - - if err = co.WriteMsg(m); err != nil { - return nil, err - } - r, err = co.ReadMsg() - return r, err + c := Client{} + r, _, err = c.Exchange(m, a) + return } // ExchangeConn performs a synchronous query. It sends the message m via the connection @@ -86,6 +68,9 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { return nil, err } r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } return r, err } @@ -161,6 +146,9 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro return nil, 0, err } r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } return r, co.rtt, err } diff --git a/client_test.go b/client_test.go index d0294f97..8a70c7ea 100644 --- a/client_test.go +++ b/client_test.go @@ -37,6 +37,29 @@ func TestClientSync(t *testing.T) { } } +func TestClientSyncBadId(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadId) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := new(Client) + if _, _, err := c.Exchange(m, addrstr); err != ErrId { + t.Errorf("did not find a bad Id") + } + // And now with plain Exchange(). + if _, err := Exchange(m, addrstr); err != ErrId { + t.Errorf("did not find a bad Id") + } +} + func TestClientEDNS0(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") diff --git a/server_test.go b/server_test.go index c2422b1f..dff0fb52 100644 --- a/server_test.go +++ b/server_test.go @@ -17,6 +17,16 @@ func HelloServer(w ResponseWriter, req *Msg) { w.WriteMsg(m) } +func HelloServerBadId(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + m.Id += 1 + + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + w.WriteMsg(m) +} + func AnotherHelloServer(w ResponseWriter, req *Msg) { m := new(Msg) m.SetReply(req)