From f3a6c86462e1bf5b0b8be03e0c7a95f10f827023 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Mon, 18 Aug 2014 22:06:29 +0100 Subject: [PATCH] Implement Shutdown() call And fix some tests to call Fatal(). --- server.go | 49 ++++++++++++++++++++++++++++++++++++++++++++----- server_test.go | 4 ++-- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/server.go b/server.go index 276e37f0..3c4c97d3 100644 --- a/server.go +++ b/server.go @@ -36,7 +36,7 @@ type ResponseWriter interface { // TsigTimersOnly sets the tsig timers only boolean. TsigTimersOnly(bool) // Hijack lets the caller take over the connection. - // After a call to Hijack(), the DNS package will not do anything with the connection + // After a call to Hijack(), the DNS package will not do anything with the connection. Hijack() } @@ -217,14 +217,22 @@ type Server struct { // Default buffer size to use to read incoming UDP messages. If not set // it defaults to MinMsgSize (512 B). UDPSize int - // The net.Conn.SetReadTimeout value for new connections. + // The net.Conn.SetReadTimeout value for new connections, defaults to 2 seconds. ReadTimeout time.Duration - // The net.Conn.SetWriteTimeout value for new connections. + // The net.Conn.SetWriteTimeout value for new connections, defaults to 2 seconds. WriteTimeout time.Duration // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). IdleTimeout func() time.Duration + // Listener deadline timeout, defaults to 2 seconds. + Deadline time.Duration // Secret(s) for Tsig map[]. TsigSecret map[string]string + + // For graceful shutdown. + stopUDP chan bool + stopTCP chan bool + wgUDP sync.WaitGroup + wgTCP sync.WaitGroup } // ListenAndServe starts a nameserver on the configured address in *Server. @@ -238,6 +246,7 @@ func (srv *Server) ListenAndServe() error { } switch srv.Net { case "tcp", "tcp4", "tcp6": + srv.stopTCP = make(chan bool) a, e := net.ResolveTCPAddr(srv.Net, addr) if e != nil { return e @@ -248,6 +257,7 @@ func (srv *Server) ListenAndServe() error { } return srv.serveTCP(l) case "udp", "udp4", "udp6": + srv.stopUDP = make(chan bool) a, e := net.ResolveUDPAddr(srv.Net, addr) if e != nil { return e @@ -268,7 +278,11 @@ func (srv *Server) ListenAndServe() error { // ActivateAndServe starts a nameserver with the PacketConn or Listener // configured in *Server. Its main use is to start a server from systemd. func (srv *Server) ActivateAndServe() error { + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } if srv.PacketConn != nil { + srv.stopUDP = make(chan bool) if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } @@ -280,6 +294,7 @@ func (srv *Server) ActivateAndServe() error { } } if srv.Listener != nil { + srv.stopTCP = make(chan bool) if t, ok := srv.Listener.(*net.TCPListener); ok { return srv.serveTCP(t) } @@ -290,7 +305,8 @@ func (srv *Server) ActivateAndServe() error { // Shutdown shuts down a server. When Shutdown returns all currently in progress // queries have been answered and all started goroutines have been stopped. func (srv *Server) Shutdown() { - + srv.stopTCP <- true + srv.stopUDP <- true } // serveTCP starts a TCP listener for the server. @@ -307,6 +323,13 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { } for { rw, e := l.AcceptTCP() + select { + case <-srv.stopTCP: + // Asked to shutdown + srv.wgTCP.Wait() + return nil + default: + } if e != nil { continue } @@ -314,6 +337,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if e != nil { continue } + srv.wgTCP.Add(1) go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) } panic("dns: not reached") @@ -333,10 +357,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { } for { m, s, e := srv.readUDP(l, rtimeout) + select { + case <-srv.stopUDP: + // Asked to shutdown + srv.wgUDP.Wait() + return nil + default: + } if e != nil { - // TODO(miek): logging? continue } + srv.wgUDP.Add(1) go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) } panic("dns: not reached") @@ -346,6 +377,14 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) { w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} q := 0 + defer func() { + if u != nil { + srv.wgUDP.Done() + } + if t != nil { + srv.wgTCP.Done() + } + }() Redo: req := new(Msg) err := req.Unpack(m) diff --git a/server_test.go b/server_test.go index 763e25b7..a2285ef7 100644 --- a/server_test.go +++ b/server_test.go @@ -48,7 +48,7 @@ func TestServing(t *testing.T) { r, _, err := c.Exchange(m, "127.0.0.1:8053") if err != nil { t.Log("Failed to exchange miek.nl", err) - t.Fail() + t.Fatal() } txt := r.Extra[0].(*TXT).Txt[0] if txt != "Hello world" { @@ -60,7 +60,7 @@ func TestServing(t *testing.T) { r, _, err = c.Exchange(m, "127.0.0.1:8053") if err != nil { t.Log("Failed to exchange example.com", err) - t.Fail() + t.Fatal() } txt = r.Extra[0].(*TXT).Txt[0] if txt != "Hello example" {