From a58e9c7a9e3fecc21e6b51b7b08f557c069a42b4 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 7 Oct 2015 00:10:38 +0100 Subject: [PATCH] Refactor server shutdown to call Close() on conn and sync on srv.started Remove the necessity for the hackish (and unreliable) fake packet. Fix a couple races and unclutter the start/stop internal state. --- server.go | 121 ++++++++++++++++++++---------------------------------- 1 file changed, 44 insertions(+), 77 deletions(-) diff --git a/server.go b/server.go index 48584f5b..223489a8 100644 --- a/server.go +++ b/server.go @@ -271,27 +271,21 @@ type Server struct { // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter - // For graceful shutdown. - stopUDP chan bool - stopTCP chan bool - wgUDP sync.WaitGroup - wgTCP sync.WaitGroup + // Graceful shutdown handling - // make start/shutdown not racy - lock sync.Mutex + inFlight sync.WaitGroup + + lock sync.RWMutex started bool } // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { srv.lock.Lock() - // We can't use defer() becasue serveTCP/serveUDP don't return. + defer srv.lock.Unlock() if srv.started { - srv.lock.Unlock() return &Error{err: "server already started"} } - srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) - srv.started = true addr := srv.Addr if addr == "" { addr = ":domain" @@ -303,43 +297,37 @@ func (srv *Server) ListenAndServe() error { case "tcp", "tcp4", "tcp6": a, e := net.ResolveTCPAddr(srv.Net, addr) if e != nil { - srv.lock.Unlock() - srv.started = false return e } l, e := net.ListenTCP(srv.Net, a) if e != nil { - srv.lock.Unlock() - srv.started = false return e } srv.Listener = l + srv.started = true srv.lock.Unlock() - return srv.serveTCP(l) + e = srv.serveTCP(l) + srv.lock.Lock() // to satisfy the defer at the top + return e case "udp", "udp4", "udp6": a, e := net.ResolveUDPAddr(srv.Net, addr) if e != nil { - srv.lock.Unlock() - srv.started = false return e } l, e := net.ListenUDP(srv.Net, a) if e != nil { - srv.lock.Unlock() - srv.started = false return e } if e := setUDPSocketOptions(l); e != nil { - srv.lock.Unlock() - srv.started = false return e } srv.PacketConn = l + srv.started = true srv.lock.Unlock() - return srv.serveUDP(l) + e = srv.serveUDP(l) + srv.lock.Lock() // to satisfy the defer at the top + return e } - srv.lock.Unlock() - srv.started = false return &Error{err: "bad network"} } @@ -347,12 +335,10 @@ func (srv *Server) ListenAndServe() error { // configured in *Server. Its main use is to start a server from systemd. func (srv *Server) ActivateAndServe() error { srv.lock.Lock() + defer srv.lock.Unlock() if srv.started { - srv.lock.Unlock() return &Error{err: "server already started"} } - srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool) - srv.started = true pConn := srv.PacketConn l := srv.Listener if pConn != nil { @@ -361,22 +347,24 @@ func (srv *Server) ActivateAndServe() error { } if t, ok := pConn.(*net.UDPConn); ok { if e := setUDPSocketOptions(t); e != nil { - srv.lock.Unlock() - srv.started = false return e } + srv.started = true srv.lock.Unlock() - return srv.serveUDP(t) + e := srv.serveUDP(t) + srv.lock.Lock() // to satisfy the defer at the top + return e } } if l != nil { if t, ok := l.(*net.TCPListener); ok { + srv.started = true srv.lock.Unlock() - return srv.serveTCP(t) + e := srv.serveTCP(t) + srv.lock.Lock() // to satisfy the defer at the top + return e } } - srv.lock.Unlock() - srv.started = false return &Error{err: "bad listeners"} } @@ -391,36 +379,20 @@ func (srv *Server) Shutdown() error { return &Error{err: "server not started"} } srv.started = false - net, addr := srv.Net, srv.Addr - switch { - case srv.Listener != nil: - a := srv.Listener.Addr() - net, addr = a.Network(), a.String() - case srv.PacketConn != nil: - a := srv.PacketConn.LocalAddr() - net, addr = a.Network(), a.String() - } srv.lock.Unlock() - fin := make(chan bool) - switch net { - case "tcp", "tcp4", "tcp6": - go func() { - srv.stopTCP <- true - srv.wgTCP.Wait() - fin <- true - }() - - case "udp", "udp4", "udp6": - go func() { - srv.stopUDP <- true - srv.wgUDP.Wait() - fin <- true - }() + if srv.PacketConn != nil { + srv.PacketConn.Close() + } + if srv.Listener != nil { + srv.Listener.Close() } - c := &Client{Net: net} - go c.Exchange(new(Msg), addr) // extra query to help ReadXXX loop to pass + fin := make(chan bool) + go func() { + srv.inFlight.Wait() + fin <- true + }() select { case <-time.After(srv.getReadTimeout()): @@ -465,15 +437,16 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { continue } m, e := reader.ReadTCP(rw, rtimeout) - select { - case <-srv.stopTCP: + srv.lock.RLock() + if !srv.started { + srv.lock.RUnlock() return nil - default: } + srv.lock.RUnlock() if e != nil { continue } - srv.wgTCP.Add(1) + srv.inFlight.Add(1) go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) } } @@ -500,21 +473,24 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { // deadline is not used here for { m, s, e := reader.ReadUDP(l, rtimeout) - select { - case <-srv.stopUDP: + srv.lock.RLock() + if !srv.started { + srv.lock.RUnlock() return nil - default: } + srv.lock.RUnlock() if e != nil { continue } - srv.wgUDP.Add(1) + srv.inFlight.Add(1) go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) } } // Serve a new connection. func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) { + defer srv.inFlight.Done() + w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) @@ -524,15 +500,6 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses q := 0 // counter for the amount of TCP queries we get - defer func() { - if u != nil { - srv.wgUDP.Done() - } - if t != nil { - srv.wgTCP.Done() - } - }() - reader := Reader(&defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader)