diff --git a/server_test.go b/server_test.go index b3e0c72b..e19671bd 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "io" "net" "runtime" "sync" @@ -389,6 +390,46 @@ func TestShutdownTCP(t *testing.T) { } } +func TestHandlerCloseTCP(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(err) + } + addr := ln.Addr().String() + + server := &Server{Addr: addr, Net: "tcp", Listener: ln} + + hname := "testhandlerclosetcp." + h_triggered := false + HandleFunc(hname, func(w ResponseWriter, r *Msg) { + h_triggered = true + w.Close() + }) + defer HandleRemove(hname) + + go func() { + defer server.Shutdown() + c := &Client{Net: "tcp"} + m := new(Msg).SetQuestion(hname, 1) + tries := 0 + exchange: + _, _, err := c.Exchange(m, addr) + if err != nil && err != io.EOF { + t.Logf("Exchange failed: %s\n", err) + if tries == 3 { + return + } + time.Sleep(time.Second / 10) + tries += 1 + goto exchange + } + }() + server.ActivateAndServe() + if !h_triggered { + t.Fatalf("Handler never called") + } +} + func TestShutdownUDP(t *testing.T) { s, _, err := RunLocalUDPServer("127.0.0.1:0") if err != nil {