diff --git a/server.go b/server.go index 4b4ec33c..06984e7c 100644 --- a/server.go +++ b/server.go @@ -82,6 +82,7 @@ type ConnectionStater interface { type response struct { msg []byte + closed bool // connection has been closed hijacked bool // connection has been hijacked by handler tsigTimersOnly bool tsigStatus error @@ -728,6 +729,10 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { + if w.closed { + return &Error{err: "WriteMsg called after Close"} + } + var data []byte if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { @@ -749,6 +754,10 @@ func (w *response) WriteMsg(m *Msg) (err error) { // Write implements the ResponseWriter.Write method. func (w *response) Write(m []byte) (int, error) { + if w.closed { + return 0, &Error{err: "Write called after Close"} + } + switch { case w.udp != nil: n, err := WriteToSessionUDP(w.udp, m, w.udpSession) @@ -768,7 +777,7 @@ func (w *response) Write(m []byte) (int, error) { n, err := io.Copy(w.tcp, bytes.NewReader(m)) return int(n), err default: - panic("dns: Write called after Close") + panic("dns: internal error: udp and tcp both nil") } } @@ -780,7 +789,7 @@ func (w *response) LocalAddr() net.Addr { case w.tcp != nil: return w.tcp.LocalAddr() default: - panic("dns: LocalAddr called after Close") + panic("dns: internal error: udp and tcp both nil") } } @@ -792,7 +801,7 @@ func (w *response) RemoteAddr() net.Addr { case w.tcp != nil: return w.tcp.RemoteAddr() default: - panic("dns: RemoteAddr called after Close") + panic("dns: internal error: udpSession and tcp both nil") } } @@ -807,13 +816,20 @@ func (w *response) Hijack() { w.hijacked = true } // Close implements the ResponseWriter.Close method func (w *response) Close() error { - // Can't close the udp conn, as that is actually the listener. - if w.tcp != nil { - e := w.tcp.Close() - w.tcp = nil - return e + if w.closed { + return &Error{err: "connection already closed"} + } + w.closed = true + + switch { + case w.udp != nil: + // Can't close the udp conn, as that is actually the listener. + return nil + case w.tcp != nil: + return w.tcp.Close() + default: + panic("dns: internal error: udp and tcp both nil") } - return nil } // ConnectionState() implements the ConnectionStater.ConnectionState() interface. diff --git a/server_test.go b/server_test.go index 9966e54d..5c6b8006 100644 --- a/server_test.go +++ b/server_test.go @@ -719,7 +719,13 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri } if eg.Wait() != nil { - t.Fatalf("conn.ReadMsg error: %v", eg.Wait()) + t.Errorf("conn.ReadMsg error: %v", eg.Wait()) + } + + srv.lock.RLock() + defer srv.lock.RUnlock() + if len(srv.conns) != 0 { + t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns)) } } @@ -966,32 +972,34 @@ func TestServerRoundtripTsig(t *testing.T) { } func TestResponseAfterClose(t *testing.T) { - testPanic := func(name string, fn func()) { - defer func() { - expect := fmt.Sprintf("dns: %s called after Close", name) - if err := recover(); err == nil { - t.Errorf("expected panic from %s after Close", name) - } else if err != expect { - t.Errorf("expected explicit panic from %s after Close, expected %q, got %q", name, expect, err) - } - }() - fn() + testError := func(name string, err error) { + t.Helper() + + expect := fmt.Sprintf("dns: %s called after Close", name) + if err == nil { + t.Errorf("expected error from %s after Close", name) + } else if err.Error() != expect { + t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err) + } } rw := &response{ - tcp: nil, // Close sets tcp to nil - udp: nil, - udpSession: nil, + closed: true, + } + + _, err := rw.Write(make([]byte, 2)) + testError("Write", err) + + testError("WriteMsg", rw.WriteMsg(new(Msg))) +} + +func TestResponseDoubleClose(t *testing.T) { + rw := &response{ + closed: true, + } + if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect { + t.Errorf("Close did not return expected: error %q, got: %v", expect, err) } - testPanic("Write", func() { - rw.Write(make([]byte, 2)) - }) - testPanic("LocalAddr", func() { - rw.LocalAddr() - }) - testPanic("RemoteAddr", func() { - rw.RemoteAddr() - }) } type ExampleFrameLengthWriter struct {