Implement Shutdown() call

And fix some tests to call Fatal().
This commit is contained in:
Miek Gieben 2014-08-18 22:06:29 +01:00
parent 6cf24a5637
commit f3a6c86462
2 changed files with 46 additions and 7 deletions

View File

@ -36,7 +36,7 @@ type ResponseWriter interface {
// TsigTimersOnly sets the tsig timers only boolean.
TsigTimersOnly(bool)
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection
// After a call to Hijack(), the DNS package will not do anything with the connection.
Hijack()
}
@ -217,14 +217,22 @@ type Server struct {
// Default buffer size to use to read incoming UDP messages. If not set
// it defaults to MinMsgSize (512 B).
UDPSize int
// The net.Conn.SetReadTimeout value for new connections.
// The net.Conn.SetReadTimeout value for new connections, defaults to 2 seconds.
ReadTimeout time.Duration
// The net.Conn.SetWriteTimeout value for new connections.
// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 seconds.
WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration
// Listener deadline timeout, defaults to 2 seconds.
Deadline time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>.
TsigSecret map[string]string
// For graceful shutdown.
stopUDP chan bool
stopTCP chan bool
wgUDP sync.WaitGroup
wgTCP sync.WaitGroup
}
// ListenAndServe starts a nameserver on the configured address in *Server.
@ -238,6 +246,7 @@ func (srv *Server) ListenAndServe() error {
}
switch srv.Net {
case "tcp", "tcp4", "tcp6":
srv.stopTCP = make(chan bool)
a, e := net.ResolveTCPAddr(srv.Net, addr)
if e != nil {
return e
@ -248,6 +257,7 @@ func (srv *Server) ListenAndServe() error {
}
return srv.serveTCP(l)
case "udp", "udp4", "udp6":
srv.stopUDP = make(chan bool)
a, e := net.ResolveUDPAddr(srv.Net, addr)
if e != nil {
return e
@ -268,7 +278,11 @@ func (srv *Server) ListenAndServe() error {
// ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.PacketConn != nil {
srv.stopUDP = make(chan bool)
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
@ -280,6 +294,7 @@ func (srv *Server) ActivateAndServe() error {
}
}
if srv.Listener != nil {
srv.stopTCP = make(chan bool)
if t, ok := srv.Listener.(*net.TCPListener); ok {
return srv.serveTCP(t)
}
@ -290,7 +305,8 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. When Shutdown returns all currently in progress
// queries have been answered and all started goroutines have been stopped.
func (srv *Server) Shutdown() {
srv.stopTCP <- true
srv.stopUDP <- true
}
// serveTCP starts a TCP listener for the server.
@ -307,6 +323,13 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
}
for {
rw, e := l.AcceptTCP()
select {
case <-srv.stopTCP:
// Asked to shutdown
srv.wgTCP.Wait()
return nil
default:
}
if e != nil {
continue
}
@ -314,6 +337,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
if e != nil {
continue
}
srv.wgTCP.Add(1)
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
}
panic("dns: not reached")
@ -333,10 +357,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
}
for {
m, s, e := srv.readUDP(l, rtimeout)
select {
case <-srv.stopUDP:
// Asked to shutdown
srv.wgUDP.Wait()
return nil
default:
}
if e != nil {
// TODO(miek): logging?
continue
}
srv.wgUDP.Add(1)
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
}
panic("dns: not reached")
@ -346,6 +377,14 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
q := 0
defer func() {
if u != nil {
srv.wgUDP.Done()
}
if t != nil {
srv.wgTCP.Done()
}
}()
Redo:
req := new(Msg)
err := req.Unpack(m)

View File

@ -48,7 +48,7 @@ func TestServing(t *testing.T) {
r, _, err := c.Exchange(m, "127.0.0.1:8053")
if err != nil {
t.Log("Failed to exchange miek.nl", err)
t.Fail()
t.Fatal()
}
txt := r.Extra[0].(*TXT).Txt[0]
if txt != "Hello world" {
@ -60,7 +60,7 @@ func TestServing(t *testing.T) {
r, _, err = c.Exchange(m, "127.0.0.1:8053")
if err != nil {
t.Log("Failed to exchange example.com", err)
t.Fail()
t.Fatal()
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {