Fix TCP connection tracking memory leak (#808)

* Add test that srv.conns is empty in checkInProgressQueriesAtShutdownServer

* Track ResponseWriter Close without nil-ing tcp

* Remove LocalAddr and RemoteAddr panic after Close

This is no longer needed as the tcp field is no longer set to nil in
Close.

* Add more explicit WriteMsg panic after Close

Previously this would panic with `dns: Write called after Close` which
is obviously less clear.

* Panic if Hijack is called after Close

Previously this worked, but later calls to Write would panic. This is
more explicit.

* Return an error if Close called multiple times

Neither io.Closer, nor ResponseWriter, provide any guarantees about the
behaviour of multiple calls to Close. This was made explicit in
https://golang.org/cl/8575043 and in practice implementations differ
wildly.

This matches ShutdownContext which returns an error if called multiple
times.

* Check map len under lock in checkInProgressQueriesAtShutdownServer

* Correct error message in checkInProgressQueriesAtShutdownServer

* Remove panic-after-Close from Hijack

* Return errors, not panic, on Write after Close
This commit is contained in:
Tom Thorogood 2018-11-03 20:14:07 +10:30 committed by Miek Gieben
parent 6ae357d393
commit ec3443f85d
2 changed files with 56 additions and 32 deletions

View File

@ -82,6 +82,7 @@ type ConnectionStater interface {
type response struct { type response struct {
msg []byte msg []byte
closed bool // connection has been closed
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
tsigTimersOnly bool tsigTimersOnly bool
tsigStatus error tsigStatus error
@ -728,6 +729,10 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
// WriteMsg implements the ResponseWriter.WriteMsg method. // WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) { func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
return &Error{err: "WriteMsg called after Close"}
}
var data []byte var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil { if t := m.IsTsig(); t != nil {
@ -749,6 +754,10 @@ func (w *response) WriteMsg(m *Msg) (err error) {
// Write implements the ResponseWriter.Write method. // Write implements the ResponseWriter.Write method.
func (w *response) Write(m []byte) (int, error) { func (w *response) Write(m []byte) (int, error) {
if w.closed {
return 0, &Error{err: "Write called after Close"}
}
switch { switch {
case w.udp != nil: case w.udp != nil:
n, err := WriteToSessionUDP(w.udp, m, w.udpSession) n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
@ -768,7 +777,7 @@ func (w *response) Write(m []byte) (int, error) {
n, err := io.Copy(w.tcp, bytes.NewReader(m)) n, err := io.Copy(w.tcp, bytes.NewReader(m))
return int(n), err return int(n), err
default: default:
panic("dns: Write called after Close") panic("dns: internal error: udp and tcp both nil")
} }
} }
@ -780,7 +789,7 @@ func (w *response) LocalAddr() net.Addr {
case w.tcp != nil: case w.tcp != nil:
return w.tcp.LocalAddr() return w.tcp.LocalAddr()
default: default:
panic("dns: LocalAddr called after Close") panic("dns: internal error: udp and tcp both nil")
} }
} }
@ -792,7 +801,7 @@ func (w *response) RemoteAddr() net.Addr {
case w.tcp != nil: case w.tcp != nil:
return w.tcp.RemoteAddr() return w.tcp.RemoteAddr()
default: default:
panic("dns: RemoteAddr called after Close") panic("dns: internal error: udpSession and tcp both nil")
} }
} }
@ -807,13 +816,20 @@ func (w *response) Hijack() { w.hijacked = true }
// Close implements the ResponseWriter.Close method // Close implements the ResponseWriter.Close method
func (w *response) Close() error { func (w *response) Close() error {
// Can't close the udp conn, as that is actually the listener. if w.closed {
if w.tcp != nil { return &Error{err: "connection already closed"}
e := w.tcp.Close() }
w.tcp = nil w.closed = true
return e
switch {
case w.udp != nil:
// Can't close the udp conn, as that is actually the listener.
return nil
case w.tcp != nil:
return w.tcp.Close()
default:
panic("dns: internal error: udp and tcp both nil")
} }
return nil
} }
// ConnectionState() implements the ConnectionStater.ConnectionState() interface. // ConnectionState() implements the ConnectionStater.ConnectionState() interface.

View File

@ -719,7 +719,13 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri
} }
if eg.Wait() != nil { if eg.Wait() != nil {
t.Fatalf("conn.ReadMsg error: %v", eg.Wait()) t.Errorf("conn.ReadMsg error: %v", eg.Wait())
}
srv.lock.RLock()
defer srv.lock.RUnlock()
if len(srv.conns) != 0 {
t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns))
} }
} }
@ -966,32 +972,34 @@ func TestServerRoundtripTsig(t *testing.T) {
} }
func TestResponseAfterClose(t *testing.T) { func TestResponseAfterClose(t *testing.T) {
testPanic := func(name string, fn func()) { testError := func(name string, err error) {
defer func() { t.Helper()
expect := fmt.Sprintf("dns: %s called after Close", name)
if err := recover(); err == nil { expect := fmt.Sprintf("dns: %s called after Close", name)
t.Errorf("expected panic from %s after Close", name) if err == nil {
} else if err != expect { t.Errorf("expected error from %s after Close", name)
t.Errorf("expected explicit panic from %s after Close, expected %q, got %q", name, expect, err) } else if err.Error() != expect {
} t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err)
}() }
fn()
} }
rw := &response{ rw := &response{
tcp: nil, // Close sets tcp to nil closed: true,
udp: nil, }
udpSession: nil,
_, err := rw.Write(make([]byte, 2))
testError("Write", err)
testError("WriteMsg", rw.WriteMsg(new(Msg)))
}
func TestResponseDoubleClose(t *testing.T) {
rw := &response{
closed: true,
}
if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect {
t.Errorf("Close did not return expected: error %q, got: %v", expect, err)
} }
testPanic("Write", func() {
rw.Write(make([]byte, 2))
})
testPanic("LocalAddr", func() {
rw.LocalAddr()
})
testPanic("RemoteAddr", func() {
rw.RemoteAddr()
})
} }
type ExampleFrameLengthWriter struct { type ExampleFrameLengthWriter struct {