Hold srv.lock while calling SetReadDeadline (#780)

* Hold srv.lock while calling SetReadDeadline

* Only hold the read lock in readTCP and readUDP
This commit is contained in:
Tom Thorogood 2018-10-10 04:16:15 +10:30 committed by Miek Gieben
parent e6cede5dc8
commit a3f088363b
1 changed files with 15 additions and 12 deletions

View File

@ -416,14 +416,13 @@ func (srv *Server) Shutdown() error {
// to terminate.
func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock()
started := srv.started
srv.started = false
srv.lock.Unlock()
if !started {
if !srv.started {
srv.lock.Unlock()
return &Error{err: "server not started"}
}
srv.started = false
if srv.PacketConn != nil {
srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
@ -432,10 +431,10 @@ func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.Listener.Close()
}
srv.lock.Lock()
for rw := range srv.conns {
rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
srv.lock.Unlock()
if testShutdownNotify != nil {
@ -666,13 +665,15 @@ func (srv *Server) serveDNS(w *response) {
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
if srv.isStarted() {
// If we race with ShutdownContext, the read deadline may
// have been set in the distant past to unblock the read
// below. We must not override it, otherwise we may block
// ShutdownContext.
// If we race with ShutdownContext, the read deadline may
// have been set in the distant past to unblock the read
// below. We must not override it, otherwise we may block
// ShutdownContext.
srv.lock.RLock()
if srv.started {
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
l := make([]byte, 2)
n, err := conn.Read(l)
@ -708,10 +709,12 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
if srv.isStarted() {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m)