Check that the query ID matches the answer ID.

Reduce some code duplication by making Exchange() use Client.Exchange().

When performing an Exchange if the query ID does not match the answer ID
return an error.  Also add a test for this condition.
This commit is contained in:
Michael Haro 2015-05-05 22:56:42 -07:00
parent 015384b10e
commit dddcd696ba
3 changed files with 42 additions and 21 deletions

View File

@ -46,27 +46,9 @@ type Client struct {
// co.Close() // co.Close()
// //
func Exchange(m *Msg, a string) (r *Msg, err error) { func Exchange(m *Msg, a string) (r *Msg, err error) {
var co *Conn c := Client{}
co, err = DialTimeout("udp", a, dnsTimeout) r, _, err = c.Exchange(m, a)
if err != nil { return
return nil, err
}
defer co.Close()
co.SetReadDeadline(time.Now().Add(dnsTimeout))
co.SetWriteDeadline(time.Now().Add(dnsTimeout))
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
} }
// ExchangeConn performs a synchronous query. It sends the message m via the connection // ExchangeConn performs a synchronous query. It sends the message m via the connection
@ -86,6 +68,9 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
return nil, err return nil, err
} }
r, err = co.ReadMsg() r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err return r, err
} }
@ -161,6 +146,9 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
return nil, 0, err return nil, 0, err
} }
r, err = co.ReadMsg() r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, co.rtt, err return r, co.rtt, err
} }

View File

@ -37,6 +37,29 @@ func TestClientSync(t *testing.T) {
} }
} }
func TestClientSyncBadId(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadId)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %v", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSOA)
c := new(Client)
if _, _, err := c.Exchange(m, addrstr); err != ErrId {
t.Errorf("did not find a bad Id")
}
// And now with plain Exchange().
if _, err := Exchange(m, addrstr); err != ErrId {
t.Errorf("did not find a bad Id")
}
}
func TestClientEDNS0(t *testing.T) { func TestClientEDNS0(t *testing.T) {
HandleFunc("miek.nl.", HelloServer) HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.") defer HandleRemove("miek.nl.")

View File

@ -17,6 +17,16 @@ func HelloServer(w ResponseWriter, req *Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func HelloServerBadId(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Id += 1
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
w.WriteMsg(m)
}
func AnotherHelloServer(w ResponseWriter, req *Msg) { func AnotherHelloServer(w ResponseWriter, req *Msg) {
m := new(Msg) m := new(Msg)
m.SetReply(req) m.SetReply(req)