Implement Shutdown() call
And fix some tests to call Fatal().
This commit is contained in:
parent
6cf24a5637
commit
f3a6c86462
49
server.go
49
server.go
|
@ -36,7 +36,7 @@ type ResponseWriter interface {
|
||||||
// TsigTimersOnly sets the tsig timers only boolean.
|
// TsigTimersOnly sets the tsig timers only boolean.
|
||||||
TsigTimersOnly(bool)
|
TsigTimersOnly(bool)
|
||||||
// Hijack lets the caller take over the connection.
|
// Hijack lets the caller take over the connection.
|
||||||
// After a call to Hijack(), the DNS package will not do anything with the connection
|
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||||
Hijack()
|
Hijack()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,14 +217,22 @@ type Server struct {
|
||||||
// Default buffer size to use to read incoming UDP messages. If not set
|
// Default buffer size to use to read incoming UDP messages. If not set
|
||||||
// it defaults to MinMsgSize (512 B).
|
// it defaults to MinMsgSize (512 B).
|
||||||
UDPSize int
|
UDPSize int
|
||||||
// The net.Conn.SetReadTimeout value for new connections.
|
// The net.Conn.SetReadTimeout value for new connections, defaults to 2 seconds.
|
||||||
ReadTimeout time.Duration
|
ReadTimeout time.Duration
|
||||||
// The net.Conn.SetWriteTimeout value for new connections.
|
// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 seconds.
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
|
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
|
||||||
IdleTimeout func() time.Duration
|
IdleTimeout func() time.Duration
|
||||||
|
// Listener deadline timeout, defaults to 2 seconds.
|
||||||
|
Deadline time.Duration
|
||||||
// Secret(s) for Tsig map[<zonename>]<base64 secret>.
|
// Secret(s) for Tsig map[<zonename>]<base64 secret>.
|
||||||
TsigSecret map[string]string
|
TsigSecret map[string]string
|
||||||
|
|
||||||
|
// For graceful shutdown.
|
||||||
|
stopUDP chan bool
|
||||||
|
stopTCP chan bool
|
||||||
|
wgUDP sync.WaitGroup
|
||||||
|
wgTCP sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenAndServe starts a nameserver on the configured address in *Server.
|
// ListenAndServe starts a nameserver on the configured address in *Server.
|
||||||
|
@ -238,6 +246,7 @@ func (srv *Server) ListenAndServe() error {
|
||||||
}
|
}
|
||||||
switch srv.Net {
|
switch srv.Net {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
srv.stopTCP = make(chan bool)
|
||||||
a, e := net.ResolveTCPAddr(srv.Net, addr)
|
a, e := net.ResolveTCPAddr(srv.Net, addr)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return e
|
return e
|
||||||
|
@ -248,6 +257,7 @@ func (srv *Server) ListenAndServe() error {
|
||||||
}
|
}
|
||||||
return srv.serveTCP(l)
|
return srv.serveTCP(l)
|
||||||
case "udp", "udp4", "udp6":
|
case "udp", "udp4", "udp6":
|
||||||
|
srv.stopUDP = make(chan bool)
|
||||||
a, e := net.ResolveUDPAddr(srv.Net, addr)
|
a, e := net.ResolveUDPAddr(srv.Net, addr)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return e
|
return e
|
||||||
|
@ -268,7 +278,11 @@ func (srv *Server) ListenAndServe() error {
|
||||||
// ActivateAndServe starts a nameserver with the PacketConn or Listener
|
// ActivateAndServe starts a nameserver with the PacketConn or Listener
|
||||||
// configured in *Server. Its main use is to start a server from systemd.
|
// configured in *Server. Its main use is to start a server from systemd.
|
||||||
func (srv *Server) ActivateAndServe() error {
|
func (srv *Server) ActivateAndServe() error {
|
||||||
|
if srv.UDPSize == 0 {
|
||||||
|
srv.UDPSize = MinMsgSize
|
||||||
|
}
|
||||||
if srv.PacketConn != nil {
|
if srv.PacketConn != nil {
|
||||||
|
srv.stopUDP = make(chan bool)
|
||||||
if srv.UDPSize == 0 {
|
if srv.UDPSize == 0 {
|
||||||
srv.UDPSize = MinMsgSize
|
srv.UDPSize = MinMsgSize
|
||||||
}
|
}
|
||||||
|
@ -280,6 +294,7 @@ func (srv *Server) ActivateAndServe() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if srv.Listener != nil {
|
if srv.Listener != nil {
|
||||||
|
srv.stopTCP = make(chan bool)
|
||||||
if t, ok := srv.Listener.(*net.TCPListener); ok {
|
if t, ok := srv.Listener.(*net.TCPListener); ok {
|
||||||
return srv.serveTCP(t)
|
return srv.serveTCP(t)
|
||||||
}
|
}
|
||||||
|
@ -290,7 +305,8 @@ func (srv *Server) ActivateAndServe() error {
|
||||||
// Shutdown shuts down a server. When Shutdown returns all currently in progress
|
// Shutdown shuts down a server. When Shutdown returns all currently in progress
|
||||||
// queries have been answered and all started goroutines have been stopped.
|
// queries have been answered and all started goroutines have been stopped.
|
||||||
func (srv *Server) Shutdown() {
|
func (srv *Server) Shutdown() {
|
||||||
|
srv.stopTCP <- true
|
||||||
|
srv.stopUDP <- true
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveTCP starts a TCP listener for the server.
|
// serveTCP starts a TCP listener for the server.
|
||||||
|
@ -307,6 +323,13 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
rw, e := l.AcceptTCP()
|
rw, e := l.AcceptTCP()
|
||||||
|
select {
|
||||||
|
case <-srv.stopTCP:
|
||||||
|
// Asked to shutdown
|
||||||
|
srv.wgTCP.Wait()
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
if e != nil {
|
if e != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -314,6 +337,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
||||||
if e != nil {
|
if e != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
srv.wgTCP.Add(1)
|
||||||
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
|
go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
|
||||||
}
|
}
|
||||||
panic("dns: not reached")
|
panic("dns: not reached")
|
||||||
|
@ -333,10 +357,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
m, s, e := srv.readUDP(l, rtimeout)
|
m, s, e := srv.readUDP(l, rtimeout)
|
||||||
|
select {
|
||||||
|
case <-srv.stopUDP:
|
||||||
|
// Asked to shutdown
|
||||||
|
srv.wgUDP.Wait()
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
if e != nil {
|
if e != nil {
|
||||||
// TODO(miek): logging?
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
srv.wgUDP.Add(1)
|
||||||
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
|
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
|
||||||
}
|
}
|
||||||
panic("dns: not reached")
|
panic("dns: not reached")
|
||||||
|
@ -346,6 +377,14 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) {
|
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *sessionUDP, t *net.TCPConn) {
|
||||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
||||||
q := 0
|
q := 0
|
||||||
|
defer func() {
|
||||||
|
if u != nil {
|
||||||
|
srv.wgUDP.Done()
|
||||||
|
}
|
||||||
|
if t != nil {
|
||||||
|
srv.wgTCP.Done()
|
||||||
|
}
|
||||||
|
}()
|
||||||
Redo:
|
Redo:
|
||||||
req := new(Msg)
|
req := new(Msg)
|
||||||
err := req.Unpack(m)
|
err := req.Unpack(m)
|
||||||
|
|
|
@ -48,7 +48,7 @@ func TestServing(t *testing.T) {
|
||||||
r, _, err := c.Exchange(m, "127.0.0.1:8053")
|
r, _, err := c.Exchange(m, "127.0.0.1:8053")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("Failed to exchange miek.nl", err)
|
t.Log("Failed to exchange miek.nl", err)
|
||||||
t.Fail()
|
t.Fatal()
|
||||||
}
|
}
|
||||||
txt := r.Extra[0].(*TXT).Txt[0]
|
txt := r.Extra[0].(*TXT).Txt[0]
|
||||||
if txt != "Hello world" {
|
if txt != "Hello world" {
|
||||||
|
@ -60,7 +60,7 @@ func TestServing(t *testing.T) {
|
||||||
r, _, err = c.Exchange(m, "127.0.0.1:8053")
|
r, _, err = c.Exchange(m, "127.0.0.1:8053")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log("Failed to exchange example.com", err)
|
t.Log("Failed to exchange example.com", err)
|
||||||
t.Fail()
|
t.Fatal()
|
||||||
}
|
}
|
||||||
txt = r.Extra[0].(*TXT).Txt[0]
|
txt = r.Extra[0].(*TXT).Txt[0]
|
||||||
if txt != "Hello example" {
|
if txt != "Hello example" {
|
||||||
|
|
Loading…
Reference in New Issue