From b0dc93d2760ef438612a252a9e448d054d28b625 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Thu, 13 Sep 2018 23:06:28 +0930 Subject: [PATCH] Make Shutdown wait for connections to terminate gracefully (#717) * Make Shutdown wait for connections to terminate gracefully * Add graceful shutdown test files from #713 * Tidy up graceful shutdown tests * Call t.Error directly in checkInProgressQueriesAtShutdownServer * Remove timeout arguments from RunLocal*ServerWithFinChan * Merge defers together in (*Server).serve This removes the defer from the UDP path, in favour of directly calling (*sync.WaitGroup).Done after (*Server).serveDNS has returned. * Replace checkInProgressQueriesAtShutdownServer implementation This performs dialing, writing and reading as three separate steps. * Add sleep after writing shutdown test messages * Avoid race condition when setting server timeouts Server timeouts cannot be set after the server has started without triggering the race detector. The timeouts are not strictly needed, so remove them. * Use a sync.Cond for testShutdownNotify Using a chan erroneously triggered the race detector, using a sync.Cond avoids that problem. * Remove TestShutdownUDPWithContext This doesn't really add anything. 
* Move shutdown and conn into (*Server).init * Only log ResponseWriter.WriteMsg error once * Test that ShutdownContext waits for the reply * Remove stray newline from diff * Rename err to ctxErr in ShutdownContext * Reword testShutdownNotify comment --- server.go | 139 ++++++++++++++++++++++++++++++++++-------- server_test.go | 159 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 26 deletions(-) diff --git a/server.go b/server.go index d3f526fc..4fbf7db6 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ package dns import ( "bytes" + "context" "crypto/tls" "encoding/binary" "errors" @@ -31,6 +32,10 @@ const maxIdleWorkersCount = 10000 // The maximum length of time a worker may idle for before being destroyed. const idleWorkerTimeout = 10 * time.Second +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancelation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + // Handler is implemented by any value that implements ServeDNS. type Handler interface { ServeDNS(w ResponseWriter, r *Msg) @@ -69,6 +74,7 @@ type response struct { tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right writer Writer // writer to output the raw DNS bits + wg *sync.WaitGroup // for gracefull shutdown } // ServeMux is an DNS request multiplexer. It matches the @@ -322,9 +328,12 @@ type Server struct { queue chan *response // Workers count workersCount int32 + // Shutdown handling - lock sync.RWMutex - started bool + lock sync.RWMutex + started bool + shutdown chan struct{} + conns map[net.Conn]struct{} // A pool for UDP message buffers. 
udpPool sync.Pool @@ -391,6 +400,9 @@ func makeUDPBuffer(size int) func() interface{} { func (srv *Server) init() { srv.queue = make(chan *response) + srv.shutdown = make(chan struct{}) + srv.conns = make(map[net.Conn]struct{}) + if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } @@ -501,23 +513,58 @@ func (srv *Server) ActivateAndServe() error { // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and // ActivateAndServe will return. func (srv *Server) Shutdown() error { + return srv.ShutdownContext(context.Background()) +} + +// ShutdownContext shuts down a server. After a call to ShutdownContext, +// ListenAndServe and ActivateAndServe will return. +// +// A context.Context may be passed to limit how long to wait for connections +// to terminate. +func (srv *Server) ShutdownContext(ctx context.Context) error { srv.lock.Lock() - if !srv.started { - srv.lock.Unlock() - return &Error{err: "server not started"} - } + started := srv.started srv.started = false srv.lock.Unlock() + if !started { + return &Error{err: "server not started"} + } + + if srv.PacketConn != nil { + srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads + } + + if srv.Listener != nil { + srv.Listener.Close() + } + + srv.lock.Lock() + for rw := range srv.conns { + rw.SetReadDeadline(aLongTimeAgo) // Unblock reads + } + srv.lock.Unlock() + + if testShutdownNotify != nil { + testShutdownNotify.Broadcast() + } + + var ctxErr error + select { + case <-srv.shutdown: + case <-ctx.Done(): + ctxErr = ctx.Err() + } + if srv.PacketConn != nil { srv.PacketConn.Close() } - if srv.Listener != nil { - srv.Listener.Close() - } - return nil + + return ctxErr } +var testShutdownNotify *sync.Cond + // getReadTimeout is a helper func to use system timeout if server did not intend to change it. 
func (srv *Server) getReadTimeout() time.Duration { rtimeout := dnsTimeout @@ -535,19 +582,36 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - for { + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() + + for srv.isStarted() { rw, err := l.Accept() - if !srv.isStarted() { - return nil - } if err != nil { + if !srv.isStarted() { + return nil + } if neterr, ok := err.(net.Error); ok && neterr.Temporary() { continue } return err } - srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw}) + srv.lock.Lock() + // Track the connection to allow unblocking reads on shutdown. + srv.conns[rw] = struct{}{} + srv.lock.Unlock() + wg.Add(1) + srv.spawnWorker(&response{ + tsigSecret: srv.TsigSecret, + tcp: rw, + wg: &wg, + }) } + + return nil } // serveUDP starts a UDP listener for the server. @@ -563,14 +627,20 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() + rtimeout := srv.getReadTimeout() // deadline is not used here - for { + for srv.isStarted() { m, s, err := reader.ReadUDP(l, rtimeout) - if !srv.isStarted() { - return nil - } if err != nil { + if !srv.isStarted() { + return nil + } if netErr, ok := err.(net.Error); ok && netErr.Temporary() { continue } @@ -582,8 +652,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { } continue } - srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) + wg.Add(1) + srv.spawnWorker(&response{ + msg: m, + tsigSecret: srv.TsigSecret, + udp: l, + udpSession: s, + wg: &wg, + }) } + + return nil } func (srv *Server) serve(w *response) { @@ -596,20 +675,28 @@ func (srv *Server) serve(w *response) { if w.udp != nil { // serve UDP srv.serveDNS(w) - return - } - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) + w.wg.Done() + return } defer func() { if 
!w.hijacked { w.Close() } + + srv.lock.Lock() + delete(srv.conns, w.tcp) + srv.lock.Unlock() + + w.wg.Done() }() + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + idleTimeout := tcpIdleTimeout if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() @@ -622,7 +709,7 @@ func (srv *Server) serve(w *response) { limit = maxTCPQueries } - for q := 0; q < limit || limit == -1; q++ { + for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ { var err error w.msg, err = reader.ReadTCP(w.tcp, timeout) if err != nil { diff --git a/server_test.go b/server_test.go index e6f073d7..c53edd03 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "crypto/tls" "fmt" "io" @@ -10,6 +11,8 @@ import ( "sync" "testing" "time" + + "golang.org/x/sync/errgroup" ) func HelloServer(w ResponseWriter, req *Msg) { @@ -588,6 +591,128 @@ func TestShutdownTCP(t *testing.T) { } } +func init() { + testShutdownNotify = &sync.Cond{ + L: new(sync.Mutex), + } +} + +func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) { + const requests = 100 + + var wg sync.WaitGroup + wg.Add(requests) + + var errOnce sync.Once + + HandleFunc("example.com.", func(w ResponseWriter, req *Msg) { + defer wg.Done() + + // Wait until ShutdownContext is called before replying. 
+ testShutdownNotify.L.Lock() + testShutdownNotify.Wait() + testShutdownNotify.L.Unlock() + + m := new(Msg) + m.SetReply(req) + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + + if err := w.WriteMsg(m); err != nil { + errOnce.Do(func() { + t.Errorf("ResponseWriter.WriteMsg error: %s", err) + }) + } + }) + defer HandleRemove("example.com.") + + client.Timeout = 10 * time.Second + + conns := make([]*Conn, requests) + eg := new(errgroup.Group) + + for i := range conns { + conn := &conns[i] + eg.Go(func() error { + var err error + *conn, err = client.Dial(addr) + return err + }) + } + + if eg.Wait() != nil { + t.Fatalf("client.Dial error: %v", eg.Wait()) + } + + m := new(Msg) + m.SetQuestion("example.com.", TypeTXT) + eg = new(errgroup.Group) + + for _, conn := range conns { + conn := conn + eg.Go(func() error { + conn.SetWriteDeadline(time.Now().Add(client.Timeout)) + + return conn.WriteMsg(m) + }) + } + + if eg.Wait() != nil { + t.Fatalf("conn.WriteMsg error: %v", eg.Wait()) + } + + // This sleep is needed to allow time for the requests to + // pass from the client through the kernel and back into + // the server. Without it, some requests may still be in + // the kernel's buffer when ShutdownContext is called. 
+ time.Sleep(100 * time.Millisecond) + + eg = new(errgroup.Group) + + for _, conn := range conns { + conn := conn + eg.Go(func() error { + conn.SetReadDeadline(time.Now().Add(client.Timeout)) + + _, err := conn.ReadMsg() + return err + }) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + ctx, cancel := context.WithTimeout(context.Background(), client.Timeout) + defer cancel() + + if err := srv.ShutdownContext(ctx); err != nil { + t.Errorf("could not shutdown test server, %v", err) + } + + select { + case <-done: + default: + t.Error("ShutdownContext returned before replies") + } + + if eg.Wait() != nil { + t.Fatalf("conn.ReadMsg error: %v", eg.Wait()) + } +} + +func TestInProgressQueriesAtShutdownTCP(t *testing.T) { + s, addr, _, err := RunLocalTCPServerWithFinChan(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{Net: "tcp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + func TestShutdownTLS(t *testing.T) { cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) if err != nil { @@ -608,6 +733,30 @@ func TestShutdownTLS(t *testing.T) { } } +func TestInProgressQueriesAtShutdownTLS(t *testing.T) { + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addr, err := RunLocalTLSServer(":0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{ + Net: "tcp-tls", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + type trigger struct { done bool sync.RWMutex @@ -684,6 +833,16 @@ func TestShutdownUDP(t *testing.T) { } } +func TestInProgressQueriesAtShutdownUDP(t *testing.T) { + s, addr, _, err := RunLocalUDPServerWithFinChan(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := 
&Client{Net: "udp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + func TestServerStartStopRace(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 10; i++ {