diff --git a/server_test.go b/server_test.go index a12dc293..9966e54d 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "testing" "time" @@ -625,17 +626,15 @@ func init() { func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) { const requests = 100 - var wg sync.WaitGroup - wg.Add(requests) - var errOnce sync.Once // t.Fail will panic if it's called after the test function has // finished. Burning the sync.Once with a defer will prevent the // handler from calling t.Errorf after we've returned. defer errOnce.Do(func() {}) + toHandle := int32(requests) HandleFunc("example.com.", func(w ResponseWriter, req *Msg) { - defer wg.Done() + defer atomic.AddInt32(&toHandle, -1) // Wait until ShutdownContext is called before replying. testShutdownNotify.L.Lock() @@ -708,23 +707,15 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri }) } - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - ctx, cancel := context.WithTimeout(context.Background(), client.Timeout) defer cancel() if err := srv.ShutdownContext(ctx); err != nil { - t.Errorf("could not shutdown test server, %v", err) + t.Errorf("could not shutdown test server: %v", err) } - select { - case <-done: - default: - t.Error("ShutdownContext returned before replies") + if left := atomic.LoadInt32(&toHandle); left != 0 { + t.Errorf("ShutdownContext returned before %d replies", left) } if eg.Wait() != nil {