From ec3443f85d88a7e5f4664093efcee4051354c53b Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Sat, 3 Nov 2018 20:14:07 +1030 Subject: [PATCH] Fix TCP connection tracking memory leak (#808) * Add test that srv.conns is empty in checkInProgressQueriesAtShutdownServer * Track ResponseWriter Close without nil-ing tcp * Remove LocalAddr and RemoteAddr panic after Close This is no longer needed as the tcp field is no longer set to nil in Close. * Add more explicit WriteMsg panic after Close Previously this would panic with `dns: Write called after Close` which is obviously less clear. * Panic if Hijack is called after Close Previously this worked, but later calls to Write would panic. This is more explicit. * Return an error if Close called multiple times Neither io.Closer, nor ResponseWriter, provide any guarantees about the behaviour of multiple calls to Close. This was made explicit in https://golang.org/cl/8575043 and in practice implementations differ wildly. This matches ShutdownContext which returns an error if called multiple times. * Check map len under lock in checkInProgressQueriesAtShutdownServer * Correct error message in checkInProgressQueriesAtShutdownServer * Remove panic-after-Close from Hijack * Return errors, not panic, on Write after Close --- server.go | 34 ++++++++++++++++++++++--------- server_test.go | 54 +++++++++++++++++++++++++++++--------------------- 2 files changed, 56 insertions(+), 32 deletions(-) diff --git a/server.go b/server.go index 4b4ec33c..06984e7c 100644 --- a/server.go +++ b/server.go @@ -82,6 +82,7 @@ type ConnectionStater interface { type response struct { msg []byte + closed bool // connection has been closed hijacked bool // connection has been hijacked by handler tsigTimersOnly bool tsigStatus error @@ -728,6 +729,10 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { + if w.closed { + return &Error{err: "WriteMsg called after Close"} + } + var data []byte if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { @@ -749,6 +754,10 @@ func (w *response) WriteMsg(m *Msg) (err error) { // Write implements the ResponseWriter.Write method. func (w *response) Write(m []byte) (int, error) { + if w.closed { + return 0, &Error{err: "Write called after Close"} + } + switch { case w.udp != nil: n, err := WriteToSessionUDP(w.udp, m, w.udpSession) @@ -768,7 +777,7 @@ func (w *response) Write(m []byte) (int, error) { n, err := io.Copy(w.tcp, bytes.NewReader(m)) return int(n), err default: - panic("dns: Write called after Close") + panic("dns: internal error: udp and tcp both nil") } } @@ -780,7 +789,7 @@ func (w *response) LocalAddr() net.Addr { case w.tcp != nil: return w.tcp.LocalAddr() default: - panic("dns: LocalAddr called after Close") + panic("dns: internal error: udp and tcp both nil") } } @@ -792,7 +801,7 @@ func (w *response) RemoteAddr() net.Addr { case w.tcp != nil: return w.tcp.RemoteAddr() default: - panic("dns: RemoteAddr called after Close") + panic("dns: internal error: udpSession and tcp both nil") } } @@ -807,13 +816,20 @@ func (w *response) Hijack() { w.hijacked = true } // Close implements the ResponseWriter.Close method func (w *response) Close() error { - // Can't close the udp conn, as that is actually the listener. - if w.tcp != nil { - e := w.tcp.Close() - w.tcp = nil - return e + if w.closed { + return &Error{err: "connection already closed"} + } + w.closed = true + + switch { + case w.udp != nil: + // Can't close the udp conn, as that is actually the listener. + return nil + case w.tcp != nil: + return w.tcp.Close() + default: + panic("dns: internal error: udp and tcp both nil") } - return nil } // ConnectionState() implements the ConnectionStater.ConnectionState() interface. diff --git a/server_test.go b/server_test.go index 9966e54d..5c6b8006 100644 --- a/server_test.go +++ b/server_test.go @@ -719,7 +719,13 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri } if eg.Wait() != nil { - t.Fatalf("conn.ReadMsg error: %v", eg.Wait()) + t.Errorf("conn.ReadMsg error: %v", eg.Wait()) + } + + srv.lock.RLock() + defer srv.lock.RUnlock() + if len(srv.conns) != 0 { + t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns)) } } @@ -966,32 +972,34 @@ func TestServerRoundtripTsig(t *testing.T) { } func TestResponseAfterClose(t *testing.T) { - testPanic := func(name string, fn func()) { - defer func() { - expect := fmt.Sprintf("dns: %s called after Close", name) - if err := recover(); err == nil { - t.Errorf("expected panic from %s after Close", name) - } else if err != expect { - t.Errorf("expected explicit panic from %s after Close, expected %q, got %q", name, expect, err) - } - }() - fn() + testError := func(name string, err error) { + t.Helper() + + expect := fmt.Sprintf("dns: %s called after Close", name) + if err == nil { + t.Errorf("expected error from %s after Close", name) + } else if err.Error() != expect { + t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err) + } } rw := &response{ - tcp: nil, // Close sets tcp to nil - udp: nil, - udpSession: nil, + closed: true, + } + + _, err := rw.Write(make([]byte, 2)) + testError("Write", err) + + testError("WriteMsg", rw.WriteMsg(new(Msg))) +} + +func TestResponseDoubleClose(t *testing.T) { + rw := &response{ + closed: true, + } + if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect { + t.Errorf("Close did not return expected: error %q, got: %v", expect, err) } - testPanic("Write", func() { - rw.Write(make([]byte, 2)) - }) - testPanic("LocalAddr", func() { - rw.LocalAddr() - }) - testPanic("RemoteAddr", func() { - rw.RemoteAddr() - }) } type ExampleFrameLengthWriter struct {