diff --git a/acceptfunc_test.go b/acceptfunc_test.go index 2d784a61..d40d4e4c 100644 --- a/acceptfunc_test.go +++ b/acceptfunc_test.go @@ -6,7 +6,7 @@ import ( func TestAcceptNotify(t *testing.T) { HandleFunc("example.org.", handleNotify) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } diff --git a/client_test.go b/client_test.go index 11c8469d..3ac8354c 100644 --- a/client_test.go +++ b/client_test.go @@ -16,7 +16,7 @@ func TestDialUDP(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -39,7 +39,7 @@ func TestClientSync(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -73,7 +73,7 @@ func TestClientLocalAddress(t *testing.T) { HandleFunc("miek.nl.", HelloServerEchoAddrPort) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -117,7 +117,7 @@ func TestClientTLSSyncV4(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -173,7 +173,7 @@ func TestClientSyncBadID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -198,7 +198,7 @@ func TestClientSyncBadThenGoodID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadThenGoodID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -229,7 +229,7 @@ func TestClientSyncTCPBadID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -250,7 +250,7 @@ func TestClientEDNS0(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -297,7 +297,7 @@ func TestClientEDNS0Local(t *testing.T) { HandleFunc("miek.nl.", handler) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -347,7 +347,7 @@ func TestClientConn(t *testing.T) { defer HandleRemove("miek.nl.") // This uses TCP just to make it slightly different than TestClientSync - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -594,7 +594,7 @@ func TestConcurrentExchanges(t *testing.T) { HandleFunc("miek.nl.", handler) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -631,7 +631,7 @@ func TestExchangeWithConn(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } diff --git a/server.go b/server.go index 77b43dea..d7f23485 100644 --- a/server.go +++ b/server.go @@ -72,9 +72,10 @@ type response struct { tsigStatus error tsigRequestMAC string tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used + udp net.PacketConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right + pcSession net.Addr // address to use when writing to a generic net.PacketConn writer Writer // writer to output the raw DNS bits } @@ -147,12 +148,24 @@ type Reader interface { ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) } -// defaultReader is an adapter for the Server struct that implements the Reader interface -// using the readTCP and readUDP func of the embedded Server. +// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns. +type PacketConnReader interface { + Reader + + // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may + // alter connection properties, for example the read-deadline. + ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) +} + +// defaultReader is an adapter for the Server struct that implements the Reader and +// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs +// of the embedded Server. type defaultReader struct { *Server } +var _ PacketConnReader = defaultReader{} + func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { return dr.readTCP(conn, timeout) } @@ -161,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt return dr.readUDP(conn, timeout) } +func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + return dr.readPacketConn(conn, timeout) +} + // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. // Implementations should never return a nil Reader. +// Readers should also implement the optional ReaderPacketConn interface. +// ReaderPacketConn is required to use a generic net.PacketConn. type DecorateReader func(Reader) Reader // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. @@ -325,24 +344,22 @@ func (srv *Server) ActivateAndServe() error { srv.init() - pConn := srv.PacketConn - l := srv.Listener - if pConn != nil { + if srv.PacketConn != nil { // Check PacketConn interface's type is valid and value // is not nil - if t, ok := pConn.(*net.UDPConn); ok && t != nil { + if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil { if e := setUDPSocketOptions(t); e != nil { return e } - srv.started = true - unlock() - return srv.serveUDP(t) } - } - if l != nil { srv.started = true unlock() - return srv.serveTCP(l) + return srv.serveUDP(srv.PacketConn) + } + if srv.Listener != nil { + srv.started = true + unlock() + return srv.serveTCP(srv.Listener) } return &Error{err: "bad listeners"} } @@ -446,18 +463,24 @@ func (srv *Server) serveTCP(l net.Listener) error { } // serveUDP starts a UDP listener for the server. -func (srv *Server) serveUDP(l *net.UDPConn) error { +func (srv *Server) serveUDP(l net.PacketConn) error { defer l.Close() - if srv.NotifyStartedFunc != nil { - srv.NotifyStartedFunc() - } - reader := Reader(defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } + lUDP, isUDP := l.(*net.UDPConn) + readerPC, canPacketConn := reader.(PacketConnReader) + if !isUDP && !canPacketConn { + return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"} + } + + if srv.NotifyStartedFunc != nil { + srv.NotifyStartedFunc() + } + var wg sync.WaitGroup defer func() { wg.Wait() @@ -467,7 +490,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout := srv.getReadTimeout() // deadline is not used here for srv.isStarted() { - m, s, err := reader.ReadUDP(l, rtimeout) + var ( + m []byte + sPC net.Addr + sUDP *SessionUDP + err error + ) + if isUDP { + m, sUDP, err = reader.ReadUDP(lUDP, rtimeout) + } else { + m, sPC, err = readerPC.ReadPacketConn(l, rtimeout) + } if err != nil { if !srv.isStarted() { return nil @@ -484,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { continue } wg.Add(1) - go srv.serveUDPPacket(&wg, m, l, s) + go srv.serveUDPPacket(&wg, m, l, sUDP, sPC) } return nil @@ -546,8 +579,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { } // Serve a new UDP request. -func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} +func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -659,6 +692,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S return m, s, nil } +func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + srv.lock.RLock() + if srv.started { + // See the comment in readTCP above. + conn.SetReadDeadline(time.Now().Add(timeout)) + } + srv.lock.RUnlock() + + m := srv.udpPool.Get().([]byte) + n, addr, err := conn.ReadFrom(m) + if err != nil { + srv.udpPool.Put(m) + return nil, nil, err + } + m = m[:n] + return m, addr, nil +} + // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { if w.closed { @@ -692,7 +743,10 @@ func (w *response) Write(m []byte) (int, error) { switch { case w.udp != nil: - return WriteToSessionUDP(w.udp, m, w.udpSession) + if u, ok := w.udp.(*net.UDPConn); ok { + return WriteToSessionUDP(u, m, w.udpSession) + } + return w.udp.WriteTo(m, w.pcSession) case w.tcp != nil: if len(m) > MaxMsgSize { return 0, &Error{err: "message too large"} @@ -725,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr { switch { case w.udpSession != nil: return w.udpSession.RemoteAddr() + case w.pcSession != nil: + return w.pcSession case w.tcp != nil: return w.tcp.RemoteAddr() default: - panic("dns: internal error: udpSession and tcp both nil") + panic("dns: internal error: udpSession, pcSession and tcp are all nil") } } diff --git a/server_test.go b/server_test.go index 5dea723f..6f4bb5a6 100644 --- a/server_test.go +++ b/server_test.go @@ -67,13 +67,7 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) { w.WriteMsg(m) } -func RunLocalUDPServer(laddr string) (*Server, string, error) { - server, l, _, err := RunLocalUDPServerWithFinChan(laddr) - - return server, l, err -} - -func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { +func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { pc, err := net.ListenPacket("udp", laddr) if err != nil { return nil, "", nil, err @@ -84,15 +78,15 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, waitLock.Lock() server.NotifyStartedFunc = waitLock.Unlock - // fin must be buffered so the goroutine below won't block - // forever if fin is never read from. This always happens - // in RunLocalUDPServer and can happen in TestShutdownUDP. - fin := make(chan error, 1) - for _, opt := range opts { opt(server) } + // fin must be buffered so the goroutine below won't block + // forever if fin is never read from. This always happens + // if the channel is discarded and can happen in TestShutdownUDP. + fin := make(chan error, 1) + go func() { fin <- server.ActivateAndServe() pc.Close() @@ -102,13 +96,14 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, return server, pc.LocalAddr().String(), fin, nil } -func RunLocalTCPServer(laddr string) (*Server, string, error) { - server, l, _, err := RunLocalTCPServerWithFinChan(laddr) - - return server, l, err +func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { + return RunLocalUDPServer(laddr, append(opts, func(srv *Server) { + // Make srv.PacketConn opaque to trigger the generic code paths. + srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn} + })...) } -func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) { +func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { l, err := net.Listen("tcp", laddr) if err != nil { return nil, "", nil, err @@ -120,8 +115,11 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er waitLock.Lock() server.NotifyStartedFunc = waitLock.Unlock - // See the comment in RunLocalUDPServerWithFinChan as to - // why fin must be buffered. + for _, opt := range opts { + opt(server) + } + + // See the comment in RunLocalUDPServer as to why fin must be buffered. fin := make(chan error, 1) go func() { @@ -133,70 +131,69 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er return server, l.Addr().String(), fin, nil } -func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) { - l, err := tls.Listen("tcp", laddr, config) - if err != nil { - return nil, "", err - } - - server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - go func() { - server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), nil +func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) { + return RunLocalTCPServer(laddr, func(srv *Server) { + srv.Listener = tls.NewListener(srv.Listener, config) + }) } func TestServing(t *testing.T) { - HandleFunc("miek.nl.", HelloServer) - HandleFunc("example.com.", AnotherHelloServer) - defer HandleRemove("miek.nl.") - defer HandleRemove("example.com.") + for _, tc := range []struct { + name string + network string + runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) + }{ + {"udp", "udp", RunLocalUDPServer}, + {"tcp", "tcp", RunLocalTCPServer}, + {"PacketConn", "udp", RunLocalPacketConnServer}, + } { + t.Run(tc.name, func(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + HandleFunc("example.com.", AnotherHelloServer) + defer HandleRemove("miek.nl.") + defer HandleRemove("example.com.") - s, addrstr, err := RunLocalUDPServer(":0") - if err != nil { - t.Fatalf("unable to run test server: %v", err) - } - defer s.Shutdown() + s, addrstr, _, err := tc.runServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() - c := new(Client) - m := new(Msg) - m.SetQuestion("miek.nl.", TypeTXT) - r, _, err := c.Exchange(m, addrstr) - if err != nil || len(r.Extra) == 0 { - t.Fatal("failed to exchange miek.nl", err) - } - txt := r.Extra[0].(*TXT).Txt[0] - if txt != "Hello world" { - t.Error("unexpected result for miek.nl", txt, "!= Hello world") - } + c := &Client{ + Net: tc.network, + } + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + r, _, err := c.Exchange(m, addrstr) + if err != nil || len(r.Extra) == 0 { + t.Fatal("failed to exchange miek.nl", err) + } + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello world" { + t.Error("unexpected result for miek.nl", txt, "!= Hello world") + } - m.SetQuestion("example.com.", TypeTXT) - r, _, err = c.Exchange(m, addrstr) - if err != nil { - t.Fatal("failed to exchange example.com", err) - } - txt = r.Extra[0].(*TXT).Txt[0] - if txt != "Hello example" { - t.Error("unexpected result for example.com", txt, "!= Hello example") - } + m.SetQuestion("example.com.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Fatal("failed to exchange example.com", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } - // Test Mixes cased as noticed by Ask. - m.SetQuestion("eXaMplE.cOm.", TypeTXT) - r, _, err = c.Exchange(m, addrstr) - if err != nil { - t.Error("failed to exchange eXaMplE.cOm", err) - } - txt = r.Extra[0].(*TXT).Txt[0] - if txt != "Hello example" { - t.Error("unexpected result for example.com", txt, "!= Hello example") + // Test Mixes cased as noticed by Ask. + m.SetQuestion("eXaMplE.cOm.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange eXaMplE.cOm", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } + }) } } @@ -204,7 +201,7 @@ func TestServing(t *testing.T) { func TestServeIgnoresZFlag(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -233,7 +230,7 @@ func TestServeNotImplemented(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) opcode := 15 - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -272,7 +269,7 @@ func TestServingTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -358,7 +355,7 @@ func TestServingTLSConnectionState(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -381,7 +378,7 @@ func TestServingTLSConnectionState(t *testing.T) { // UDP DNS Server HandleFunc(".", tlsHandlerTLS(false)) defer HandleRemove(".") - s, addrstr, err = RunLocalUDPServer(":0") + s, addrstr, _, err = RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -395,7 +392,7 @@ func TestServingTLSConnectionState(t *testing.T) { } // TCP DNS Server - s, addrstr, err = RunLocalTCPServer(":0") + s, addrstr, _, err = RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -479,7 +476,7 @@ func BenchmarkServe(b *testing.B) { defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { b.Fatalf("unable to run test server: %v", err) } @@ -504,7 +501,7 @@ func BenchmarkServe6(b *testing.B) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer("[::1]:0") + s, addrstr, _, err := RunLocalUDPServer("[::1]:0") if err != nil { if strings.Contains(err.Error(), "bind: cannot assign requested address") { b.Skip("missing IPv6 support") @@ -541,7 +538,7 @@ func BenchmarkServeCompress(b *testing.B) { HandleFunc("miek.nl.", HelloServerCompress) defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { b.Fatalf("unable to run test server: %v", err) } @@ -594,7 +591,7 @@ func TestServingLargeResponses(t *testing.T) { HandleFunc("example.", HelloServerLargeResponse) defer HandleRemove("example.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -634,7 +631,7 @@ func TestServingResponse(t *testing.T) { t.Skip("skipping test in short mode.") } HandleFunc("miek.nl.", HelloServer) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -657,7 +654,7 @@ func TestServingResponse(t *testing.T) { } func TestShutdownTCP(t *testing.T) { - s, _, fin, err := RunLocalTCPServerWithFinChan(":0") + s, _, fin, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -788,7 +785,7 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri } func TestInProgressQueriesAtShutdownTCP(t *testing.T) { - s, addr, _, err := RunLocalTCPServerWithFinChan(":0") + s, addr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -807,7 +804,7 @@ func TestShutdownTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, _, err := RunLocalTLSServer(":0", &config) + s, _, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -827,7 +824,7 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addr, err := RunLocalTLSServer(":0", &config) + s, addr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -842,7 +839,6 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) { } func TestHandlerCloseTCP(t *testing.T) { - ln, err := net.Listen("tcp", ":0") if err != nil { panic(err) @@ -887,7 +883,26 @@ func TestHandlerCloseTCP(t *testing.T) { } func TestShutdownUDP(t *testing.T) { - s, _, fin, err := RunLocalUDPServerWithFinChan(":0") + s, _, fin, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + err = s.Shutdown() + if err != nil { + t.Errorf("could not shutdown test UDP server, %v", err) + } + select { + case err := <-fin: + if err != nil { + t.Errorf("error returned from ActivateAndServe, %v", err) + } + case <-time.After(2 * time.Second): + t.Error("could not shutdown test UDP server. Gave up waiting") + } +} + +func TestShutdownPacketConn(t *testing.T) { + s, _, fin, err := RunLocalPacketConnServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -906,7 +921,17 @@ func TestShutdownUDP(t *testing.T) { } func TestInProgressQueriesAtShutdownUDP(t *testing.T) { - s, addr, _, err := RunLocalUDPServerWithFinChan(":0") + s, addr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{Net: "udp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + +func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) { + s, addr, _, err := RunLocalPacketConnServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -919,7 +944,7 @@ func TestServerStartStopRace(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) - s, _, _, err := RunLocalUDPServerWithFinChan(":0") + s, _, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("could not start server: %s", err) } @@ -982,7 +1007,7 @@ func TestServerReuseport(t *testing.T) { func TestServerRoundtripTsig(t *testing.T) { secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="} - s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) { + s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) { srv.TsigSecret = secret srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction { // defaultMsgAcceptFunc does reject UPDATE queries diff --git a/xfr_test.go b/xfr_test.go index 6510e41c..7cb9f1d3 100644 --- a/xfr_test.go +++ b/xfr_test.go @@ -1,11 +1,6 @@ package dns -import ( - "net" - "sync" - "testing" - "time" -) +import "testing" var ( tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} @@ -52,7 +47,7 @@ func TestInvalidXfr(t *testing.T) { HandleFunc("miek.nl.", InvalidXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -78,7 +73,9 @@ func TestSingleEnvelopeXfr(t *testing.T) { HandleFunc("miek.nl.", SingleEnvelopeXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret) + s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { + srv.TsigSecret = tsigSecret + }) if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -91,7 +88,9 @@ func TestMultiEnvelopeXfr(t *testing.T) { HandleFunc("miek.nl.", MultipleEnvelopeXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret) + s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { + srv.TsigSecret = tsigSecret + }) if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -100,37 +99,6 @@ func TestMultiEnvelopeXfr(t *testing.T) { axfrTestingSuite(t, addrstr) } -func RunLocalTCPServerWithTsig(laddr string, tsig map[string]string) (*Server, string, error) { - server, l, _, err := RunLocalTCPServerWithFinChanWithTsig(laddr, tsig) - - return server, l, err -} - -func RunLocalTCPServerWithFinChanWithTsig(laddr string, tsig map[string]string) (*Server, string, chan error, error) { - l, err := net.Listen("tcp", laddr) - if err != nil { - return nil, "", nil, err - } - - server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour, TsigSecret: tsig} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - // See the comment in RunLocalUDPServerWithFinChan as to - // why fin must be buffered. - fin := make(chan error, 1) - - go func() { - fin <- server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), fin, nil -} - func axfrTestingSuite(t *testing.T, addrstr string) { tr := new(Transfer) m := new(Msg)