From a3f088363bfa218c9b198c2ac5e8efb6289edf3e Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Wed, 10 Oct 2018 04:16:15 +1030 Subject: [PATCH] Hold srv.lock while calling SetReadDeadline (#780) * Hold srv.lock while calling SetReadDeadline * Only hold the read lock in readTCP and readUDP --- server.go | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/server.go b/server.go index 4b19b437..4b4ec33c 100644 --- a/server.go +++ b/server.go @@ -416,14 +416,13 @@ func (srv *Server) Shutdown() error { // to terminate. func (srv *Server) ShutdownContext(ctx context.Context) error { srv.lock.Lock() - started := srv.started - srv.started = false - srv.lock.Unlock() - - if !started { + if !srv.started { + srv.lock.Unlock() return &Error{err: "server not started"} } + srv.started = false + if srv.PacketConn != nil { srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads } @@ -432,10 +431,10 @@ func (srv *Server) ShutdownContext(ctx context.Context) error { srv.Listener.Close() } - srv.lock.Lock() for rw := range srv.conns { rw.SetReadDeadline(aLongTimeAgo) // Unblock reads } + srv.lock.Unlock() if testShutdownNotify != nil { @@ -666,13 +665,15 @@ func (srv *Server) serveDNS(w *response) { } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { - if srv.isStarted() { - // If we race with ShutdownContext, the read deadline may - // have been set in the distant past to unblock the read - // below. We must not override it, otherwise we may block - // ShutdownContext. + // If we race with ShutdownContext, the read deadline may + // have been set in the distant past to unblock the read + // below. We must not override it, otherwise we may block + // ShutdownContext. + srv.lock.RLock() + if srv.started { conn.SetReadDeadline(time.Now().Add(timeout)) } + srv.lock.RUnlock() l := make([]byte, 2) n, err := conn.Read(l) @@ -708,10 +709,12 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { - if srv.isStarted() { + srv.lock.RLock() + if srv.started { // See the comment in readTCP above. conn.SetReadDeadline(time.Now().Add(timeout)) } + srv.lock.RUnlock() m := srv.udpPool.Get().([]byte) n, s, err := ReadFromSessionUDP(conn, m)