Refactor server shutdown to call Close() on conn and sync on srv.started
Remove the necessity for the hackish (and unreliable) fake packet. Fix a couple races and unclutter the start/stop internal state.
This commit is contained in:
parent
7463496b65
commit
a58e9c7a9e
121
server.go
121
server.go
|
@ -271,27 +271,21 @@ type Server struct {
|
|||
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
|
||||
DecorateWriter DecorateWriter
|
||||
|
||||
// For graceful shutdown.
|
||||
stopUDP chan bool
|
||||
stopTCP chan bool
|
||||
wgUDP sync.WaitGroup
|
||||
wgTCP sync.WaitGroup
|
||||
// Graceful shutdown handling
|
||||
|
||||
// make start/shutdown not racy
|
||||
lock sync.Mutex
|
||||
inFlight sync.WaitGroup
|
||||
|
||||
lock sync.RWMutex
|
||||
started bool
|
||||
}
|
||||
|
||||
// ListenAndServe starts a nameserver on the configured address in *Server.
|
||||
func (srv *Server) ListenAndServe() error {
|
||||
srv.lock.Lock()
|
||||
// We can't use defer() becasue serveTCP/serveUDP don't return.
|
||||
defer srv.lock.Unlock()
|
||||
if srv.started {
|
||||
srv.lock.Unlock()
|
||||
return &Error{err: "server already started"}
|
||||
}
|
||||
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
||||
srv.started = true
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":domain"
|
||||
|
@ -303,43 +297,37 @@ func (srv *Server) ListenAndServe() error {
|
|||
case "tcp", "tcp4", "tcp6":
|
||||
a, e := net.ResolveTCPAddr(srv.Net, addr)
|
||||
if e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenTCP(srv.Net, a)
|
||||
if e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
srv.Listener = l
|
||||
srv.started = true
|
||||
srv.lock.Unlock()
|
||||
return srv.serveTCP(l)
|
||||
e = srv.serveTCP(l)
|
||||
srv.lock.Lock() // to satisfy the defer at the top
|
||||
return e
|
||||
case "udp", "udp4", "udp6":
|
||||
a, e := net.ResolveUDPAddr(srv.Net, addr)
|
||||
if e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenUDP(srv.Net, a)
|
||||
if e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
if e := setUDPSocketOptions(l); e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
srv.PacketConn = l
|
||||
srv.started = true
|
||||
srv.lock.Unlock()
|
||||
return srv.serveUDP(l)
|
||||
e = srv.serveUDP(l)
|
||||
srv.lock.Lock() // to satisfy the defer at the top
|
||||
return e
|
||||
}
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return &Error{err: "bad network"}
|
||||
}
|
||||
|
||||
|
@ -347,12 +335,10 @@ func (srv *Server) ListenAndServe() error {
|
|||
// configured in *Server. Its main use is to start a server from systemd.
|
||||
func (srv *Server) ActivateAndServe() error {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
if srv.started {
|
||||
srv.lock.Unlock()
|
||||
return &Error{err: "server already started"}
|
||||
}
|
||||
srv.stopUDP, srv.stopTCP = make(chan bool), make(chan bool)
|
||||
srv.started = true
|
||||
pConn := srv.PacketConn
|
||||
l := srv.Listener
|
||||
if pConn != nil {
|
||||
|
@ -361,22 +347,24 @@ func (srv *Server) ActivateAndServe() error {
|
|||
}
|
||||
if t, ok := pConn.(*net.UDPConn); ok {
|
||||
if e := setUDPSocketOptions(t); e != nil {
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return e
|
||||
}
|
||||
srv.started = true
|
||||
srv.lock.Unlock()
|
||||
return srv.serveUDP(t)
|
||||
e := srv.serveUDP(t)
|
||||
srv.lock.Lock() // to satisfy the defer at the top
|
||||
return e
|
||||
}
|
||||
}
|
||||
if l != nil {
|
||||
if t, ok := l.(*net.TCPListener); ok {
|
||||
srv.started = true
|
||||
srv.lock.Unlock()
|
||||
return srv.serveTCP(t)
|
||||
e := srv.serveTCP(t)
|
||||
srv.lock.Lock() // to satisfy the defer at the top
|
||||
return e
|
||||
}
|
||||
}
|
||||
srv.lock.Unlock()
|
||||
srv.started = false
|
||||
return &Error{err: "bad listeners"}
|
||||
}
|
||||
|
||||
|
@ -391,36 +379,20 @@ func (srv *Server) Shutdown() error {
|
|||
return &Error{err: "server not started"}
|
||||
}
|
||||
srv.started = false
|
||||
net, addr := srv.Net, srv.Addr
|
||||
switch {
|
||||
case srv.Listener != nil:
|
||||
a := srv.Listener.Addr()
|
||||
net, addr = a.Network(), a.String()
|
||||
case srv.PacketConn != nil:
|
||||
a := srv.PacketConn.LocalAddr()
|
||||
net, addr = a.Network(), a.String()
|
||||
}
|
||||
srv.lock.Unlock()
|
||||
|
||||
fin := make(chan bool)
|
||||
switch net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
go func() {
|
||||
srv.stopTCP <- true
|
||||
srv.wgTCP.Wait()
|
||||
fin <- true
|
||||
}()
|
||||
|
||||
case "udp", "udp4", "udp6":
|
||||
go func() {
|
||||
srv.stopUDP <- true
|
||||
srv.wgUDP.Wait()
|
||||
fin <- true
|
||||
}()
|
||||
if srv.PacketConn != nil {
|
||||
srv.PacketConn.Close()
|
||||
}
|
||||
if srv.Listener != nil {
|
||||
srv.Listener.Close()
|
||||
}
|
||||
|
||||
c := &Client{Net: net}
|
||||
go c.Exchange(new(Msg), addr) // extra query to help ReadXXX loop to pass
|
||||
fin := make(chan bool)
|
||||
go func() {
|
||||
srv.inFlight.Wait()
|
||||
fin <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(srv.getReadTimeout()):
|
||||
|
@ -465,15 +437,16 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|||
continue
|
||||
}
|
||||
m, e := reader.ReadTCP(rw, rtimeout)
|
||||
select {
|
||||
case <-srv.stopTCP:
|
||||
srv.lock.RLock()
|
||||
if !srv.started {
|
||||
srv.lock.RUnlock()
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
srv.lock.RUnlock()
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
srv.wgTCP.Add(1)
|
||||
srv.inFlight.Add(1)
|
||||
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
|
||||
}
|
||||
}
|
||||
|
@ -500,21 +473,24 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|||
// deadline is not used here
|
||||
for {
|
||||
m, s, e := reader.ReadUDP(l, rtimeout)
|
||||
select {
|
||||
case <-srv.stopUDP:
|
||||
srv.lock.RLock()
|
||||
if !srv.started {
|
||||
srv.lock.RUnlock()
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
srv.lock.RUnlock()
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
srv.wgUDP.Add(1)
|
||||
srv.inFlight.Add(1)
|
||||
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve a new connection.
|
||||
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
||||
defer srv.inFlight.Done()
|
||||
|
||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
||||
if srv.DecorateWriter != nil {
|
||||
w.writer = srv.DecorateWriter(w)
|
||||
|
@ -524,15 +500,6 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
|
|||
|
||||
q := 0 // counter for the amount of TCP queries we get
|
||||
|
||||
defer func() {
|
||||
if u != nil {
|
||||
srv.wgUDP.Done()
|
||||
}
|
||||
if t != nil {
|
||||
srv.wgTCP.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
reader := Reader(&defaultReader{srv})
|
||||
if srv.DecorateReader != nil {
|
||||
reader = srv.DecorateReader(reader)
|
||||
|
|
Loading…
Reference in New Issue