Implement Shutdown() call

And fix some tests to call Fatal().
This commit is contained in:
Miek Gieben 2014-08-18 22:06:29 +01:00
parent 6cf24a5637
commit f3a6c86462
2 changed files with 46 additions and 7 deletions

View File

@ -36,7 +36,7 @@ type ResponseWriter interface {
// TsigTimersOnly sets the tsig timers only boolean. // TsigTimersOnly sets the tsig timers only boolean.
TsigTimersOnly(bool) TsigTimersOnly(bool)
// Hijack lets the caller take over the connection. // Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection // After a call to Hijack(), the DNS package will not do anything with the connection.
Hijack() Hijack()
} }
@ -217,14 +217,22 @@ type Server struct {
// Default buffer size to use to read incoming UDP messages. If not set // Default buffer size to use to read incoming UDP messages. If not set
// it defaults to MinMsgSize (512 B). // it defaults to MinMsgSize (512 B).
UDPSize int UDPSize int
// The net.Conn.SetReadTimeout value for new connections. // The net.Conn.SetReadTimeout value for new connections, defaults to 2 seconds.
ReadTimeout time.Duration ReadTimeout time.Duration
// The net.Conn.SetWriteTimeout value for new connections. // The net.Conn.SetWriteTimeout value for new connections, defaults to 2 seconds.
WriteTimeout time.Duration WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration IdleTimeout func() time.Duration
// Listener deadline timeout, defaults to 2 seconds.
Deadline time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>. // Secret(s) for Tsig map[<zonename>]<base64 secret>.
TsigSecret map[string]string TsigSecret map[string]string
// For graceful shutdown.
stopUDP chan bool
stopTCP chan bool
wgUDP sync.WaitGroup
wgTCP sync.WaitGroup
} }
// ListenAndServe starts a nameserver on the configured address in *Server. // ListenAndServe starts a nameserver on the configured address in *Server.
@ -238,6 +246,7 @@ func (srv *Server) ListenAndServe() error {
} }
switch srv.Net { switch srv.Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
srv.stopTCP = make(chan bool)
a, e := net.ResolveTCPAddr(srv.Net, addr) a, e := net.ResolveTCPAddr(srv.Net, addr)
if e != nil { if e != nil {
return e return e
@ -248,6 +257,7 @@ func (srv *Server) ListenAndServe() error {
} }
return srv.serveTCP(l) return srv.serveTCP(l)
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
srv.stopUDP = make(chan bool)
a, e := net.ResolveUDPAddr(srv.Net, addr) a, e := net.ResolveUDPAddr(srv.Net, addr)
if e != nil { if e != nil {
return e return e
@ -268,7 +278,11 @@ func (srv *Server) ListenAndServe() error {
// ActivateAndServe starts a nameserver with the PacketConn or Listener // ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd. // configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error { func (srv *Server) ActivateAndServe() error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.PacketConn != nil { if srv.PacketConn != nil {
srv.stopUDP = make(chan bool)
if srv.UDPSize == 0 { if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.UDPSize = MinMsgSize
} }
@ -280,6 +294,7 @@ func (srv *Server) ActivateAndServe() error {
} }
} }
if srv.Listener != nil { if srv.Listener != nil {
srv.stopTCP = make(chan bool)
if t, ok := srv.Listener.(*net.TCPListener); ok { if t, ok := srv.Listener.(*net.TCPListener); ok {
return srv.serveTCP(t) return srv.serveTCP(t)
} }
@ -290,7 +305,8 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. When Shutdown returns all currently in progress // Shutdown shuts down a server. When Shutdown returns all currently in progress
// queries have been answered and all started goroutines have been stopped. // queries have been answered and all started goroutines have been stopped.
func (srv *Server) Shutdown() { func (srv *Server) Shutdown() {
srv.stopTCP <- true
srv.stopUDP <- true
} }
// serveTCP starts a TCP listener for the server. // serveTCP starts a TCP listener for the server.
@ -307,6 +323,13 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
} }
for { for {
rw, e := l.AcceptTCP() rw, e := l.AcceptTCP()
select {
case <-srv.stopTCP:
// Asked to shutdown
srv.wgTCP.Wait()
return nil
default:
}
if e != nil { if e != nil {
continue continue
} }
@ -314,6 +337,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
if e != nil { if e != nil {
continue continue
} }
srv.wgTCP.Add(1)
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
} }
panic("dns: not reached") panic("dns: not reached")
@ -333,10 +357,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
} }
for { for {
m, s, e := srv.readUDP(l, rtimeout) m, s, e := srv.readUDP(l, rtimeout)
select {
case <-srv.stopUDP:
// Asked to shutdown
srv.wgUDP.Wait()
return nil
default:
}
if e != nil { if e != nil {
// TODO(miek): logging?
continue continue
} }
srv.wgUDP.Add(1)
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
} }
panic("dns: not reached") panic("dns: not reached")
@ -346,6 +377,14 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) { func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
q := 0 q := 0
defer func() {
if u != nil {
srv.wgUDP.Done()
}
if t != nil {
srv.wgTCP.Done()
}
}()
Redo: Redo:
req := new(Msg) req := new(Msg)
err := req.Unpack(m) err := req.Unpack(m)

View File

@ -48,7 +48,7 @@ func TestServing(t *testing.T) {
r, _, err := c.Exchange(m, "127.0.0.1:8053") r, _, err := c.Exchange(m, "127.0.0.1:8053")
if err != nil { if err != nil {
t.Log("Failed to exchange miek.nl", err) t.Log("Failed to exchange miek.nl", err)
t.Fail() t.Fatal()
} }
txt := r.Extra[0].(*TXT).Txt[0] txt := r.Extra[0].(*TXT).Txt[0]
if txt != "Hello world" { if txt != "Hello world" {
@ -60,7 +60,7 @@ func TestServing(t *testing.T) {
r, _, err = c.Exchange(m, "127.0.0.1:8053") r, _, err = c.Exchange(m, "127.0.0.1:8053")
if err != nil { if err != nil {
t.Log("Failed to exchange example.com", err) t.Log("Failed to exchange example.com", err)
t.Fail() t.Fatal()
} }
txt = r.Extra[0].(*TXT).Txt[0] txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" { if txt != "Hello example" {