Check that the query ID matches the answer ID.

Reduce some code duplication by making Exchange() use Client.Exchange().

When performing an Exchange if the query ID does not match the answer ID
return an error.  Also add a test for this condition.
This commit is contained in:
Michael Haro 2015-05-05 22:56:42 -07:00
parent 015384b10e
commit dddcd696ba
3 changed files with 42 additions and 21 deletions

View File

@ -46,27 +46,9 @@ type Client struct {
// co.Close()
//
func Exchange(m *Msg, a string) (r *Msg, err error) {
var co *Conn
co, err = DialTimeout("udp", a, dnsTimeout)
if err != nil {
return nil, err
}
defer co.Close()
co.SetReadDeadline(time.Now().Add(dnsTimeout))
co.SetWriteDeadline(time.Now().Add(dnsTimeout))
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
c := Client{}
r, _, err = c.Exchange(m, a)
return
}
// ExchangeConn performs a synchronous query. It sends the message m via the connection
@ -86,6 +68,9 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
return nil, err
}
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err
}
@ -161,6 +146,9 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
return nil, 0, err
}
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, co.rtt, err
}

View File

@ -37,6 +37,29 @@ func TestClientSync(t *testing.T) {
}
}
func TestClientSyncBadId(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadId)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %v", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSOA)
c := new(Client)
if _, _, err := c.Exchange(m, addrstr); err != ErrId {
t.Errorf("did not find a bad Id")
}
// And now with plain Exchange().
if _, err := Exchange(m, addrstr); err != ErrId {
t.Errorf("did not find a bad Id")
}
}
func TestClientEDNS0(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")

View File

@ -17,6 +17,16 @@ func HelloServer(w ResponseWriter, req *Msg) {
w.WriteMsg(m)
}
func HelloServerBadId(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
m.Id += 1
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
w.WriteMsg(m)
}
func AnotherHelloServer(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)