diff --git a/server.go b/server.go index d3f526fc..4fbf7db6 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ package dns import ( "bytes" + "context" "crypto/tls" "encoding/binary" "errors" @@ -31,6 +32,10 @@ const maxIdleWorkersCount = 10000 // The maximum length of time a worker may idle for before being destroyed. const idleWorkerTimeout = 10 * time.Second +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancelation of network operations. +var aLongTimeAgo = time.Unix(1, 0) + // Handler is implemented by any value that implements ServeDNS. type Handler interface { ServeDNS(w ResponseWriter, r *Msg) @@ -69,6 +74,7 @@ type response struct { tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right writer Writer // writer to output the raw DNS bits + wg *sync.WaitGroup // for gracefull shutdown } // ServeMux is an DNS request multiplexer. It matches the @@ -322,9 +328,12 @@ type Server struct { queue chan *response // Workers count workersCount int32 + // Shutdown handling - lock sync.RWMutex - started bool + lock sync.RWMutex + started bool + shutdown chan struct{} + conns map[net.Conn]struct{} // A pool for UDP message buffers. udpPool sync.Pool @@ -391,6 +400,9 @@ func makeUDPBuffer(size int) func() interface{} { func (srv *Server) init() { srv.queue = make(chan *response) + srv.shutdown = make(chan struct{}) + srv.conns = make(map[net.Conn]struct{}) + if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } @@ -501,23 +513,58 @@ func (srv *Server) ActivateAndServe() error { // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and // ActivateAndServe will return. func (srv *Server) Shutdown() error { + return srv.ShutdownContext(context.Background()) +} + +// ShutdownContext shuts down a server. After a call to ShutdownContext, +// ListenAndServe and ActivateAndServe will return. +// +// A context.Context may be passed to limit how long to wait for connections +// to terminate. +func (srv *Server) ShutdownContext(ctx context.Context) error { srv.lock.Lock() - if !srv.started { - srv.lock.Unlock() - return &Error{err: "server not started"} - } + started := srv.started srv.started = false srv.lock.Unlock() + if !started { + return &Error{err: "server not started"} + } + + if srv.PacketConn != nil { + srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads + } + + if srv.Listener != nil { + srv.Listener.Close() + } + + srv.lock.Lock() + for rw := range srv.conns { + rw.SetReadDeadline(aLongTimeAgo) // Unblock reads + } + srv.lock.Unlock() + + if testShutdownNotify != nil { + testShutdownNotify.Broadcast() + } + + var ctxErr error + select { + case <-srv.shutdown: + case <-ctx.Done(): + ctxErr = ctx.Err() + } + if srv.PacketConn != nil { srv.PacketConn.Close() } - if srv.Listener != nil { - srv.Listener.Close() - } - return nil + + return ctxErr } +var testShutdownNotify *sync.Cond + // getReadTimeout is a helper func to use system timeout if server did not intend to change it. func (srv *Server) getReadTimeout() time.Duration { rtimeout := dnsTimeout @@ -535,19 +582,36 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - for { + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() + + for srv.isStarted() { rw, err := l.Accept() - if !srv.isStarted() { - return nil - } if err != nil { + if !srv.isStarted() { + return nil + } if neterr, ok := err.(net.Error); ok && neterr.Temporary() { continue } return err } - srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw}) + srv.lock.Lock() + // Track the connection to allow unblocking reads on shutdown. + srv.conns[rw] = struct{}{} + srv.lock.Unlock() + wg.Add(1) + srv.spawnWorker(&response{ + tsigSecret: srv.TsigSecret, + tcp: rw, + wg: &wg, + }) } + + return nil } // serveUDP starts a UDP listener for the server. @@ -563,14 +627,20 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } + var wg sync.WaitGroup + defer func() { + wg.Wait() + close(srv.shutdown) + }() + rtimeout := srv.getReadTimeout() // deadline is not used here - for { + for srv.isStarted() { m, s, err := reader.ReadUDP(l, rtimeout) - if !srv.isStarted() { - return nil - } if err != nil { + if !srv.isStarted() { + return nil + } if netErr, ok := err.(net.Error); ok && netErr.Temporary() { continue } @@ -582,8 +652,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { } continue } - srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) + wg.Add(1) + srv.spawnWorker(&response{ + msg: m, + tsigSecret: srv.TsigSecret, + udp: l, + udpSession: s, + wg: &wg, + }) } + + return nil } func (srv *Server) serve(w *response) { @@ -596,20 +675,28 @@ func (srv *Server) serve(w *response) { if w.udp != nil { // serve UDP srv.serveDNS(w) - return - } - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) + w.wg.Done() + return } defer func() { if !w.hijacked { w.Close() } + + srv.lock.Lock() + delete(srv.conns, w.tcp) + srv.lock.Unlock() + + w.wg.Done() }() + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + idleTimeout := tcpIdleTimeout if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() @@ -622,7 +709,7 @@ func (srv *Server) serve(w *response) { limit = maxTCPQueries } - for q := 0; q < limit || limit == -1; q++ { + for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ { var err error w.msg, err = reader.ReadTCP(w.tcp, timeout) if err != nil { diff --git a/server_test.go b/server_test.go index e6f073d7..c53edd03 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "crypto/tls" "fmt" "io" @@ -10,6 +11,8 @@ import ( "sync" "testing" "time" + + "golang.org/x/sync/errgroup" ) func HelloServer(w ResponseWriter, req *Msg) { @@ -588,6 +591,128 @@ func TestShutdownTCP(t *testing.T) { } } +func init() { + testShutdownNotify = &sync.Cond{ + L: new(sync.Mutex), + } +} + +func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) { + const requests = 100 + + var wg sync.WaitGroup + wg.Add(requests) + + var errOnce sync.Once + + HandleFunc("example.com.", func(w ResponseWriter, req *Msg) { + defer wg.Done() + + // Wait until ShutdownContext is called before replying. + testShutdownNotify.L.Lock() + testShutdownNotify.Wait() + testShutdownNotify.L.Unlock() + + m := new(Msg) + m.SetReply(req) + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} + + if err := w.WriteMsg(m); err != nil { + errOnce.Do(func() { + t.Errorf("ResponseWriter.WriteMsg error: %s", err) + }) + } + }) + defer HandleRemove("example.com.") + + client.Timeout = 10 * time.Second + + conns := make([]*Conn, requests) + eg := new(errgroup.Group) + + for i := range conns { + conn := &conns[i] + eg.Go(func() error { + var err error + *conn, err = client.Dial(addr) + return err + }) + } + + if eg.Wait() != nil { + t.Fatalf("client.Dial error: %v", eg.Wait()) + } + + m := new(Msg) + m.SetQuestion("example.com.", TypeTXT) + eg = new(errgroup.Group) + + for _, conn := range conns { + conn := conn + eg.Go(func() error { + conn.SetWriteDeadline(time.Now().Add(client.Timeout)) + + return conn.WriteMsg(m) + }) + } + + if eg.Wait() != nil { + t.Fatalf("conn.WriteMsg error: %v", eg.Wait()) + } + + // This sleep is needed to allow time for the requests to + // pass from the client through the kernel and back into + // the server. Without it, some requests may still be in + // the kernel's buffer when ShutdownContext is called. + time.Sleep(100 * time.Millisecond) + + eg = new(errgroup.Group) + + for _, conn := range conns { + conn := conn + eg.Go(func() error { + conn.SetReadDeadline(time.Now().Add(client.Timeout)) + + _, err := conn.ReadMsg() + return err + }) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + ctx, cancel := context.WithTimeout(context.Background(), client.Timeout) + defer cancel() + + if err := srv.ShutdownContext(ctx); err != nil { + t.Errorf("could not shutdown test server, %v", err) + } + + select { + case <-done: + default: + t.Error("ShutdownContext returned before replies") + } + + if eg.Wait() != nil { + t.Fatalf("conn.ReadMsg error: %v", eg.Wait()) + } +} + +func TestInProgressQueriesAtShutdownTCP(t *testing.T) { + s, addr, _, err := RunLocalTCPServerWithFinChan(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{Net: "tcp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + func TestShutdownTLS(t *testing.T) { cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) if err != nil { @@ -608,6 +733,30 @@ func TestShutdownTLS(t *testing.T) { } } +func TestInProgressQueriesAtShutdownTLS(t *testing.T) { + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addr, err := RunLocalTLSServer(":0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{ + Net: "tcp-tls", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + type trigger struct { done bool sync.RWMutex @@ -684,6 +833,16 @@ func TestShutdownUDP(t *testing.T) { } } +func TestInProgressQueriesAtShutdownUDP(t *testing.T) { + s, addr, _, err := RunLocalUDPServerWithFinChan(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{Net: "udp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + func TestServerStartStopRace(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 10; i++ {