From e95d1070532099421013058da4ba2cee0976aea0 Mon Sep 17 00:00:00 2001 From: devnev Date: Fri, 17 Feb 2017 11:38:00 +0000 Subject: [PATCH] Fix data race in error handling. (#459) The response message must copied regardless of whether there was an error or not, otherwise two concurrent queries may modify the response as they write it out. --- client.go | 6 ++--- client_test.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 8b09e189..d54d6422 100644 --- a/client.go +++ b/client.go @@ -121,12 +121,12 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { return c.exchange(m, a) }) + if r != nil && shared { + r = r.Copy() + } if err != nil { return r, rtt, err } - if shared { - return r.Copy(), rtt, nil - } return r, rtt, nil } diff --git a/client_test.go b/client_test.go index 559afc5c..dee585f3 100644 --- a/client_test.go +++ b/client_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strconv" + "sync" "testing" "time" ) @@ -450,3 +451,61 @@ func TestTimeout(t *testing.T) { t.Errorf("exchange took longer (%v) than specified Timeout (%v)", length, timeout) } } + +// Check that responses from deduplicated requests aren't shared between callers +func TestConcurrentExchanges(t *testing.T) { + cases := make([]*Msg, 2) + cases[0] = new(Msg) + cases[1] = new(Msg) + cases[1].Truncated = true + for _, m := range cases { + block := make(chan struct{}) + waiting := make(chan struct{}) + + handler := func(w ResponseWriter, req *Msg) { + r := m.Copy() + r.SetReply(req) + + waiting <- struct{}{} + <-block + w.WriteMsg(r) + } + + HandleFunc("miek.nl.", handler) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSRV) + c := &Client{ + SingleInflight: true, + } + r := make([]*Msg, 2) + + var wg sync.WaitGroup + wg.Add(len(r)) + for i := 0; i < len(r); i++ { + go func(i int) { + r[i], _, _ = c.Exchange(m.Copy(), addrstr) + wg.Done() + }(i) + } + select { + case <-waiting: + case <-time.After(time.Second): + t.FailNow() + } + close(block) + wg.Wait() + + if r[0] == r[1] { + t.Log("Got same response object, expected non-shared responses") + t.Fail() + } + } +}