diff --git a/server.go b/server.go index 7d3be2ac..5c365bfa 100644 --- a/server.go +++ b/server.go @@ -11,6 +11,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) @@ -302,34 +303,49 @@ func (srv *Server) ActivateAndServe() error { // Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and // ActivateAndServe will return. All in progress queries are completed before the server -// is taken down. If the Shutdown was not succesful an error is returned. +// is taken down. If the Shutdown was not succesful an error is taking longer than reading +// timeout. func (srv *Server) Shutdown() error { - var net, addr string - + net := srv.Net switch { case srv.Listener != nil: - a := srv.Listener.Addr() - net, addr = a.Network(), a.String() + net = srv.Listener.Addr().Network() case srv.PacketConn != nil: - a := srv.PacketConn.LocalAddr() - net, addr = a.Network(), a.String() - default: - net, addr = srv.Net, srv.Addr + net = srv.PacketConn.LocalAddr().Network() } + fin := make(chan bool) switch net { case "tcp", "tcp4", "tcp6": - go func() { srv.stopTCP <- true }() + go func() { + srv.stopTCP <- true + srv.wgTCP.Wait() + fin <- true + }() + case "udp", "udp4", "udp6": - go func() { srv.stopUDP <- true }() + go func() { + srv.stopUDP <- true + srv.wgUDP.Wait() + fin <- true + }() } - // Send packet to server socket in order to force readUDP or readTCP to finish waiting for data. - // TODO(asergeyev): Alternative concurrent watchdog is possible to create in "serve*" in future - c := &Client{Net: net} - go c.Exchange(new(Msg), addr) + select { + case <-time.After(srv.getReadTimeout()): + return &Error{err: "shutdown is pending"} + case <-fin: + return nil + } +} - return nil +// getReadTimeout is a helper func to use system timeout if server did not intend to change it. +func (srv *Server) getReadTimeout() time.Duration { + rtimeout := dnsTimeout + if srv.ReadTimeout != 0 { + rtimeout = srv.ReadTimeout + } + return rtimeout } // serveTCP starts a TCP listener for the server. @@ -344,7 +360,14 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if srv.ReadTimeout != 0 { rtimeout = srv.ReadTimeout } - for { + // deadline is not used here + done := int32(0) + go func() { + <-srv.stopTCP // there is no way out of serving but to receive stop + l.SetDeadline(time.Now()) + atomic.StoreInt32(&done, 1) + }() + for done == 0 { rw, e := l.AcceptTCP() if e != nil { continue @@ -355,21 +378,15 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { } srv.wgTCP.Add(1) go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) - select { - case <-srv.stopTCP: - // Asked to shutdown - srv.wgTCP.Wait() - return nil - default: - } } - panic("dns: not reached") + return nil } // serveUDP starts a UDP listener for the server. // Each request is handled in a seperate goroutine. func (srv *Server) serveUDP(l *net.UDPConn) error { defer l.Close() + handler := srv.Handler if handler == nil { handler = DefaultServeMux @@ -379,22 +396,22 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout = srv.ReadTimeout } // deadline is not used here - for { + done := int32(0) + go func() { + <-srv.stopUDP // there is no way out of serving but to receive stop + l.SetDeadline(time.Now()) + atomic.StoreInt32(&done, 1) + }() + for done == 0 { m, s, e := srv.readUDP(l, rtimeout) if e != nil { continue } srv.wgUDP.Add(1) go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) - select { - case <-srv.stopUDP: - // Asked to shutdown - srv.wgUDP.Wait() - return nil - default: - } } - panic("dns: not reached") + srv.wgUDP.Wait() + return nil } // Serve a new connection. diff --git a/server_test.go b/server_test.go index cce21289..3264a358 100644 --- a/server_test.go +++ b/server_test.go @@ -313,7 +313,7 @@ func TestShutdownTCP(t *testing.T) { } err = s.Shutdown() if err != nil { - t.Error("Could not shutdown test TCP server, %s", err) + t.Errorf("Could not shutdown test TCP server, %s", err) } } @@ -324,6 +324,6 @@ func TestShutdownUDP(t *testing.T) { } err = s.Shutdown() if err != nil { - t.Error("Could not shutdown test UDP server, %s", err) + t.Errorf("Could not shutdown test UDP server, %s", err) } }