diff --git a/server.go b/server.go index b231c343..6bd202bd 100644 --- a/server.go +++ b/server.go @@ -753,24 +753,33 @@ func (w *response) Write(m []byte) (int, error) { n, err := io.Copy(w.tcp, bytes.NewReader(m)) return int(n), err + default: + panic("dns: Write called after Close") } - panic("not reached") } // LocalAddr implements the ResponseWriter.LocalAddr method. func (w *response) LocalAddr() net.Addr { - if w.tcp != nil { + switch { + case w.udp != nil: + return w.udp.LocalAddr() + case w.tcp != nil: return w.tcp.LocalAddr() + default: + panic("dns: LocalAddr called after Close") } - return w.udp.LocalAddr() } // RemoteAddr implements the ResponseWriter.RemoteAddr method. func (w *response) RemoteAddr() net.Addr { - if w.tcp != nil { + switch { + case w.udpSession != nil: + return w.udpSession.RemoteAddr() + case w.tcp != nil: return w.tcp.RemoteAddr() + default: + panic("dns: RemoteAddr called after Close") } - return w.udpSession.RemoteAddr() } // TsigStatus implements the ResponseWriter.TsigStatus method. diff --git a/server_test.go b/server_test.go index a1fd2798..dfb81141 100644 --- a/server_test.go +++ b/server_test.go @@ -990,6 +990,35 @@ func TestServerRoundtripTsig(t *testing.T) { } } +func TestResponseAfterClose(t *testing.T) { + testPanic := func(name string, fn func()) { + defer func() { + expect := fmt.Sprintf("dns: %s called after Close", name) + if err := recover(); err == nil { + t.Errorf("expected panic from %s after Close", name) + } else if err != expect { + t.Errorf("expected explicit panic from %s after Close, expected %q, got %q", name, expect, err) + } + }() + fn() + } + + rw := &response{ + tcp: nil, // Close sets tcp to nil + udp: nil, + udpSession: nil, + } + testPanic("Write", func() { + rw.Write(make([]byte, 2)) + }) + testPanic("LocalAddr", func() { + rw.LocalAddr() + }) + testPanic("RemoteAddr", func() { + rw.RemoteAddr() + }) +} + type ExampleFrameLengthWriter struct { Writer }