diff --git a/client.go b/client.go index edd9368b..e7ff786a 100644 --- a/client.go +++ b/client.go @@ -185,9 +185,20 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro } co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout()))) - r, err = co.ReadMsg() - if err == nil && r.Id != m.Id { - err = ErrId + if _, ok := co.Conn.(net.PacketConn); ok { + for { + r, err = co.ReadMsg() + // Ignore replies with mismatched IDs because they might be + // responses to earlier queries that timed out. + if err != nil || r.Id == m.Id { + break + } + } + } else { + r, err = co.ReadMsg() + if err == nil && r.Id != m.Id { + err = ErrId + } } rtt = time.Since(t) return r, rtt, err diff --git a/client_test.go b/client_test.go index b75d946c..11c8469d 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "crypto/tls" + "errors" "fmt" "net" "strconv" @@ -162,6 +163,12 @@ func TestClientTLSSyncV4(t *testing.T) { } } +func isNetworkTimeout(err error) bool { + // TODO: when Go 1.14 support is dropped, do this: https://golang.org/doc/go1.15#net + var netError net.Error + return errors.As(err, &netError) && netError.Timeout() +} + func TestClientSyncBadID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadID) defer HandleRemove("miek.nl.") @@ -175,12 +182,66 @@ func TestClientSyncBadID(t *testing.T) { m := new(Msg) m.SetQuestion("miek.nl.", TypeSOA) - c := new(Client) - if _, _, err := c.Exchange(m, addrstr); err != ErrId { - t.Errorf("did not find a bad Id") + c := &Client{ + Timeout: 50 * time.Millisecond, + } + if _, _, err := c.Exchange(m, addrstr); err == nil || !isNetworkTimeout(err) { + t.Errorf("query did not time out") } // And now with plain Exchange(). - if _, err := Exchange(m, addrstr); err != ErrId { + if _, err = Exchange(m, addrstr); err == nil || !isNetworkTimeout(err) { + t.Errorf("query did not time out") + } +} + +func TestClientSyncBadThenGoodID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadThenGoodID) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := new(Client) + r, _, err := c.Exchange(m, addrstr) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + if r.Id != m.Id { + t.Errorf("failed to get response with expected Id") + } + // And now with plain Exchange(). + r, err = Exchange(m, addrstr) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + if r.Id != m.Id { + t.Errorf("failed to get response with expected Id") + } +} + +func TestClientSyncTCPBadID(t *testing.T) { + HandleFunc("miek.nl.", HelloServerBadID) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + c := &Client{ + Net: "tcp", + } + if _, _, err := c.Exchange(m, addrstr); err != ErrId { t.Errorf("did not find a bad Id") } } diff --git a/server_test.go b/server_test.go index 864b52f5..c76c7a77 100644 --- a/server_test.go +++ b/server_test.go @@ -35,6 +35,19 @@ func HelloServerBadID(w ResponseWriter, req *Msg) { w.WriteMsg(m) } +func HelloServerBadThenGoodID(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + m.Id++ + + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + w.WriteMsg(m) + + m.Id-- + w.WriteMsg(m) +} + func HelloServerEchoAddrPort(w ResponseWriter, req *Msg) { m := new(Msg) m.SetReply(req)