Hold srv.lock while calling SetReadDeadline (#780)

* Hold srv.lock while calling SetReadDeadline

* Only hold the read lock in readTCP and readUDP
This commit is contained in:
Tom Thorogood 2018-10-10 04:16:15 +10:30 committed by Miek Gieben
parent e6cede5dc8
commit a3f088363b
1 changed files with 15 additions and 12 deletions

View File

@ -416,14 +416,13 @@ func (srv *Server) Shutdown() error {
// to terminate. // to terminate.
func (srv *Server) ShutdownContext(ctx context.Context) error { func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock() srv.lock.Lock()
started := srv.started if !srv.started {
srv.started = false srv.lock.Unlock()
srv.lock.Unlock()
if !started {
return &Error{err: "server not started"} return &Error{err: "server not started"}
} }
srv.started = false
if srv.PacketConn != nil { if srv.PacketConn != nil {
srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
} }
@ -432,10 +431,10 @@ func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.Listener.Close() srv.Listener.Close()
} }
srv.lock.Lock()
for rw := range srv.conns { for rw := range srv.conns {
rw.SetReadDeadline(aLongTimeAgo) // Unblock reads rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
} }
srv.lock.Unlock() srv.lock.Unlock()
if testShutdownNotify != nil { if testShutdownNotify != nil {
@ -666,13 +665,15 @@ func (srv *Server) serveDNS(w *response) {
} }
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
if srv.isStarted() { // If we race with ShutdownContext, the read deadline may
// If we race with ShutdownContext, the read deadline may // have been set in the distant past to unblock the read
// have been set in the distant past to unblock the read // below. We must not override it, otherwise we may block
// below. We must not override it, otherwise we may block // ShutdownContext.
// ShutdownContext. srv.lock.RLock()
if srv.started {
conn.SetReadDeadline(time.Now().Add(timeout)) conn.SetReadDeadline(time.Now().Add(timeout))
} }
srv.lock.RUnlock()
l := make([]byte, 2) l := make([]byte, 2)
n, err := conn.Read(l) n, err := conn.Read(l)
@ -708,10 +709,12 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
} }
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
if srv.isStarted() { srv.lock.RLock()
if srv.started {
// See the comment in readTCP above. // See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout)) conn.SetReadDeadline(time.Now().Add(timeout))
} }
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte) m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m) n, s, err := ReadFromSessionUDP(conn, m)