From 0e1c4e69ddf2b81648425770bb63f94696765738 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Sun, 25 Oct 2020 02:23:01 +1030 Subject: [PATCH] Support generic net.PacketConn's for the Server (#1174) * Support generic net.PacketConn's for the Server This commit adds support for listening on generic net.PacketConn's for UDP DNS requests, previously *net.UDPConn was the only supported type. In the event of a future v2 of this module, this should be streamlined. * Eliminate wrapper functions around RunLocalXServerWithFinChan * Eliminate RunLocalTCPServerWithTsig function * Replace RunLocalTLSServer with a wrapper around RunLocalTCPServer This reduces code duplication. * Add net.PacketConn server tests This provides coverage over nearly all of the newly added code (with the unfortunate exception of (*response).RemoteAddr). * Fix broken client_test.go tests a433fbede4dc was merged into master between this PR being opened and being merged. This broke the CI tests in rather strange ways as the code was being merged into master in a way that wasn't at all clear. This commit fixes the two broken lines. --- acceptfunc_test.go | 2 +- client_test.go | 24 ++--- server.go | 104 +++++++++++++++++----- server_test.go | 217 +++++++++++++++++++++++++-------------------- xfr_test.go | 48 ++-------- 5 files changed, 222 insertions(+), 173 deletions(-) diff --git a/acceptfunc_test.go b/acceptfunc_test.go index 2d784a61..d40d4e4c 100644 --- a/acceptfunc_test.go +++ b/acceptfunc_test.go @@ -6,7 +6,7 @@ import ( func TestAcceptNotify(t *testing.T) { HandleFunc("example.org.", handleNotify) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } diff --git a/client_test.go b/client_test.go index 11c8469d..3ac8354c 100644 --- a/client_test.go +++ b/client_test.go @@ -16,7 +16,7 @@ func TestDialUDP(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -39,7 +39,7 @@ func TestClientSync(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -73,7 +73,7 @@ func TestClientLocalAddress(t *testing.T) { HandleFunc("miek.nl.", HelloServerEchoAddrPort) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -117,7 +117,7 @@ func TestClientTLSSyncV4(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -173,7 +173,7 @@ func TestClientSyncBadID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -198,7 +198,7 @@ func TestClientSyncBadThenGoodID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadThenGoodID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -229,7 +229,7 @@ func TestClientSyncTCPBadID(t *testing.T) { HandleFunc("miek.nl.", HelloServerBadID) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -250,7 +250,7 @@ func TestClientEDNS0(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -297,7 +297,7 @@ func TestClientEDNS0Local(t *testing.T) { HandleFunc("miek.nl.", handler) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -347,7 +347,7 @@ func TestClientConn(t *testing.T) { defer HandleRemove("miek.nl.") // This uses TCP just to make it slightly different than TestClientSync - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -594,7 +594,7 @@ func TestConcurrentExchanges(t *testing.T) { HandleFunc("miek.nl.", handler) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -631,7 +631,7 @@ func TestExchangeWithConn(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } diff --git a/server.go b/server.go index 77b43dea..d7f23485 100644 --- a/server.go +++ b/server.go @@ -72,9 +72,10 @@ type response struct { tsigStatus error tsigRequestMAC string tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used + udp net.PacketConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right + pcSession net.Addr // address to use when writing to a generic net.PacketConn writer Writer // writer to output the raw DNS bits } @@ -147,12 +148,24 @@ type Reader interface { ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) } -// defaultReader is an adapter for the Server struct that implements the Reader interface -// using the readTCP and readUDP func of the embedded Server. +// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns. +type PacketConnReader interface { + Reader + + // ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may + // alter connection properties, for example the read-deadline. + ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) +} + +// defaultReader is an adapter for the Server struct that implements the Reader and +// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs +// of the embedded Server. type defaultReader struct { *Server } +var _ PacketConnReader = defaultReader{} + func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { return dr.readTCP(conn, timeout) } @@ -161,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt return dr.readUDP(conn, timeout) } +func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + return dr.readPacketConn(conn, timeout) +} + // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. // Implementations should never return a nil Reader. +// Readers should also implement the optional ReaderPacketConn interface. +// ReaderPacketConn is required to use a generic net.PacketConn. type DecorateReader func(Reader) Reader // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. @@ -325,24 +344,22 @@ func (srv *Server) ActivateAndServe() error { srv.init() - pConn := srv.PacketConn - l := srv.Listener - if pConn != nil { + if srv.PacketConn != nil { // Check PacketConn interface's type is valid and value // is not nil - if t, ok := pConn.(*net.UDPConn); ok && t != nil { + if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil { if e := setUDPSocketOptions(t); e != nil { return e } - srv.started = true - unlock() - return srv.serveUDP(t) } - } - if l != nil { srv.started = true unlock() - return srv.serveTCP(l) + return srv.serveUDP(srv.PacketConn) + } + if srv.Listener != nil { + srv.started = true + unlock() + return srv.serveTCP(srv.Listener) } return &Error{err: "bad listeners"} } @@ -446,18 +463,24 @@ func (srv *Server) serveTCP(l net.Listener) error { } // serveUDP starts a UDP listener for the server. -func (srv *Server) serveUDP(l *net.UDPConn) error { +func (srv *Server) serveUDP(l net.PacketConn) error { defer l.Close() - if srv.NotifyStartedFunc != nil { - srv.NotifyStartedFunc() - } - reader := Reader(defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } + lUDP, isUDP := l.(*net.UDPConn) + readerPC, canPacketConn := reader.(PacketConnReader) + if !isUDP && !canPacketConn { + return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"} + } + + if srv.NotifyStartedFunc != nil { + srv.NotifyStartedFunc() + } + var wg sync.WaitGroup defer func() { wg.Wait() @@ -467,7 +490,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout := srv.getReadTimeout() // deadline is not used here for srv.isStarted() { - m, s, err := reader.ReadUDP(l, rtimeout) + var ( + m []byte + sPC net.Addr + sUDP *SessionUDP + err error + ) + if isUDP { + m, sUDP, err = reader.ReadUDP(lUDP, rtimeout) + } else { + m, sPC, err = readerPC.ReadPacketConn(l, rtimeout) + } if err != nil { if !srv.isStarted() { return nil @@ -484,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { continue } wg.Add(1) - go srv.serveUDPPacket(&wg, m, l, s) + go srv.serveUDPPacket(&wg, m, l, sUDP, sPC) } return nil @@ -546,8 +579,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { } // Serve a new UDP request. -func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} +func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -659,6 +692,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S return m, s, nil } +func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) { + srv.lock.RLock() + if srv.started { + // See the comment in readTCP above. + conn.SetReadDeadline(time.Now().Add(timeout)) + } + srv.lock.RUnlock() + + m := srv.udpPool.Get().([]byte) + n, addr, err := conn.ReadFrom(m) + if err != nil { + srv.udpPool.Put(m) + return nil, nil, err + } + m = m[:n] + return m, addr, nil +} + // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { if w.closed { @@ -692,7 +743,10 @@ func (w *response) Write(m []byte) (int, error) { switch { case w.udp != nil: - return WriteToSessionUDP(w.udp, m, w.udpSession) + if u, ok := w.udp.(*net.UDPConn); ok { + return WriteToSessionUDP(u, m, w.udpSession) + } + return w.udp.WriteTo(m, w.pcSession) case w.tcp != nil: if len(m) > MaxMsgSize { return 0, &Error{err: "message too large"} @@ -725,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr { switch { case w.udpSession != nil: return w.udpSession.RemoteAddr() + case w.pcSession != nil: + return w.pcSession case w.tcp != nil: return w.tcp.RemoteAddr() default: - panic("dns: internal error: udpSession and tcp both nil") + panic("dns: internal error: udpSession, pcSession and tcp are all nil") } } diff --git a/server_test.go b/server_test.go index 5dea723f..6f4bb5a6 100644 --- a/server_test.go +++ b/server_test.go @@ -67,13 +67,7 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) { w.WriteMsg(m) } -func RunLocalUDPServer(laddr string) (*Server, string, error) { - server, l, _, err := RunLocalUDPServerWithFinChan(laddr) - - return server, l, err -} - -func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { +func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { pc, err := net.ListenPacket("udp", laddr) if err != nil { return nil, "", nil, err @@ -84,15 +78,15 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, waitLock.Lock() server.NotifyStartedFunc = waitLock.Unlock - // fin must be buffered so the goroutine below won't block - // forever if fin is never read from. This always happens - // in RunLocalUDPServer and can happen in TestShutdownUDP. - fin := make(chan error, 1) - for _, opt := range opts { opt(server) } + // fin must be buffered so the goroutine below won't block + // forever if fin is never read from. This always happens + // if the channel is discarded and can happen in TestShutdownUDP. + fin := make(chan error, 1) + go func() { fin <- server.ActivateAndServe() pc.Close() @@ -102,13 +96,14 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, return server, pc.LocalAddr().String(), fin, nil } -func RunLocalTCPServer(laddr string) (*Server, string, error) { - server, l, _, err := RunLocalTCPServerWithFinChan(laddr) - - return server, l, err +func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { + return RunLocalUDPServer(laddr, append(opts, func(srv *Server) { + // Make srv.PacketConn opaque to trigger the generic code paths. + srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn} + })...) } -func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) { +func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { l, err := net.Listen("tcp", laddr) if err != nil { return nil, "", nil, err @@ -120,8 +115,11 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er waitLock.Lock() server.NotifyStartedFunc = waitLock.Unlock - // See the comment in RunLocalUDPServerWithFinChan as to - // why fin must be buffered. + for _, opt := range opts { + opt(server) + } + + // See the comment in RunLocalUDPServer as to why fin must be buffered. fin := make(chan error, 1) go func() { @@ -133,70 +131,69 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er return server, l.Addr().String(), fin, nil } -func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) { - l, err := tls.Listen("tcp", laddr, config) - if err != nil { - return nil, "", err - } - - server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - go func() { - server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), nil +func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) { + return RunLocalTCPServer(laddr, func(srv *Server) { + srv.Listener = tls.NewListener(srv.Listener, config) + }) } func TestServing(t *testing.T) { - HandleFunc("miek.nl.", HelloServer) - HandleFunc("example.com.", AnotherHelloServer) - defer HandleRemove("miek.nl.") - defer HandleRemove("example.com.") + for _, tc := range []struct { + name string + network string + runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) + }{ + {"udp", "udp", RunLocalUDPServer}, + {"tcp", "tcp", RunLocalTCPServer}, + {"PacketConn", "udp", RunLocalPacketConnServer}, + } { + t.Run(tc.name, func(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + HandleFunc("example.com.", AnotherHelloServer) + defer HandleRemove("miek.nl.") + defer HandleRemove("example.com.") - s, addrstr, err := RunLocalUDPServer(":0") - if err != nil { - t.Fatalf("unable to run test server: %v", err) - } - defer s.Shutdown() + s, addrstr, _, err := tc.runServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() - c := new(Client) - m := new(Msg) - m.SetQuestion("miek.nl.", TypeTXT) - r, _, err := c.Exchange(m, addrstr) - if err != nil || len(r.Extra) == 0 { - t.Fatal("failed to exchange miek.nl", err) - } - txt := r.Extra[0].(*TXT).Txt[0] - if txt != "Hello world" { - t.Error("unexpected result for miek.nl", txt, "!= Hello world") - } + c := &Client{ + Net: tc.network, + } + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + r, _, err := c.Exchange(m, addrstr) + if err != nil || len(r.Extra) == 0 { + t.Fatal("failed to exchange miek.nl", err) + } + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello world" { + t.Error("unexpected result for miek.nl", txt, "!= Hello world") + } - m.SetQuestion("example.com.", TypeTXT) - r, _, err = c.Exchange(m, addrstr) - if err != nil { - t.Fatal("failed to exchange example.com", err) - } - txt = r.Extra[0].(*TXT).Txt[0] - if txt != "Hello example" { - t.Error("unexpected result for example.com", txt, "!= Hello example") - } + m.SetQuestion("example.com.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Fatal("failed to exchange example.com", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } - // Test Mixes cased as noticed by Ask. - m.SetQuestion("eXaMplE.cOm.", TypeTXT) - r, _, err = c.Exchange(m, addrstr) - if err != nil { - t.Error("failed to exchange eXaMplE.cOm", err) - } - txt = r.Extra[0].(*TXT).Txt[0] - if txt != "Hello example" { - t.Error("unexpected result for example.com", txt, "!= Hello example") + // Test Mixes cased as noticed by Ask. + m.SetQuestion("eXaMplE.cOm.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange eXaMplE.cOm", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } + }) } } @@ -204,7 +201,7 @@ func TestServing(t *testing.T) { func TestServeIgnoresZFlag(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -233,7 +230,7 @@ func TestServeNotImplemented(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) opcode := 15 - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -272,7 +269,7 @@ func TestServingTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -358,7 +355,7 @@ func TestServingTLSConnectionState(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addrstr, err := RunLocalTLSServer(":0", &config) + s, addrstr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -381,7 +378,7 @@ func TestServingTLSConnectionState(t *testing.T) { // UDP DNS Server HandleFunc(".", tlsHandlerTLS(false)) defer HandleRemove(".") - s, addrstr, err = RunLocalUDPServer(":0") + s, addrstr, _, err = RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -395,7 +392,7 @@ func TestServingTLSConnectionState(t *testing.T) { } // TCP DNS Server - s, addrstr, err = RunLocalTCPServer(":0") + s, addrstr, _, err = RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -479,7 +476,7 @@ func BenchmarkServe(b *testing.B) { defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { b.Fatalf("unable to run test server: %v", err) } @@ -504,7 +501,7 @@ func BenchmarkServe6(b *testing.B) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer("[::1]:0") + s, addrstr, _, err := RunLocalUDPServer("[::1]:0") if err != nil { if strings.Contains(err.Error(), "bind: cannot assign requested address") { b.Skip("missing IPv6 support") @@ -541,7 +538,7 @@ func BenchmarkServeCompress(b *testing.B) { HandleFunc("miek.nl.", HelloServerCompress) defer HandleRemove("miek.nl.") a := runtime.GOMAXPROCS(4) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { b.Fatalf("unable to run test server: %v", err) } @@ -594,7 +591,7 @@ func TestServingLargeResponses(t *testing.T) { HandleFunc("example.", HelloServerLargeResponse) defer HandleRemove("example.") - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -634,7 +631,7 @@ func TestServingResponse(t *testing.T) { t.Skip("skipping test in short mode.") } HandleFunc("miek.nl.", HelloServer) - s, addrstr, err := RunLocalUDPServer(":0") + s, addrstr, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -657,7 +654,7 @@ func TestServingResponse(t *testing.T) { } func TestShutdownTCP(t *testing.T) { - s, _, fin, err := RunLocalTCPServerWithFinChan(":0") + s, _, fin, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -788,7 +785,7 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri } func TestInProgressQueriesAtShutdownTCP(t *testing.T) { - s, addr, _, err := RunLocalTCPServerWithFinChan(":0") + s, addr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -807,7 +804,7 @@ func TestShutdownTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, _, err := RunLocalTLSServer(":0", &config) + s, _, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -827,7 +824,7 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) { Certificates: []tls.Certificate{cert}, } - s, addr, err := RunLocalTLSServer(":0", &config) + s, addr, _, err := RunLocalTLSServer(":0", &config) if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -842,7 +839,6 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) { } func TestHandlerCloseTCP(t *testing.T) { - ln, err := net.Listen("tcp", ":0") if err != nil { panic(err) @@ -887,7 +883,26 @@ func TestHandlerCloseTCP(t *testing.T) { } func TestShutdownUDP(t *testing.T) { - s, _, fin, err := RunLocalUDPServerWithFinChan(":0") + s, _, fin, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + err = s.Shutdown() + if err != nil { + t.Errorf("could not shutdown test UDP server, %v", err) + } + select { + case err := <-fin: + if err != nil { + t.Errorf("error returned from ActivateAndServe, %v", err) + } + case <-time.After(2 * time.Second): + t.Error("could not shutdown test UDP server. Gave up waiting") + } +} + +func TestShutdownPacketConn(t *testing.T) { + s, _, fin, err := RunLocalPacketConnServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -906,7 +921,17 @@ func TestShutdownUDP(t *testing.T) { } func TestInProgressQueriesAtShutdownUDP(t *testing.T) { - s, addr, _, err := RunLocalUDPServerWithFinChan(":0") + s, addr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + c := &Client{Net: "udp"} + checkInProgressQueriesAtShutdownServer(t, s, addr, c) +} + +func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) { + s, addr, _, err := RunLocalPacketConnServer(":0") if err != nil { t.Fatalf("unable to run test server: %v", err) } @@ -919,7 +944,7 @@ func TestServerStartStopRace(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) - s, _, _, err := RunLocalUDPServerWithFinChan(":0") + s, _, _, err := RunLocalUDPServer(":0") if err != nil { t.Fatalf("could not start server: %s", err) } @@ -982,7 +1007,7 @@ func TestServerReuseport(t *testing.T) { func TestServerRoundtripTsig(t *testing.T) { secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="} - s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) { + s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) { srv.TsigSecret = secret srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction { // defaultMsgAcceptFunc does reject UPDATE queries diff --git a/xfr_test.go b/xfr_test.go index 6510e41c..7cb9f1d3 100644 --- a/xfr_test.go +++ b/xfr_test.go @@ -1,11 +1,6 @@ package dns -import ( - "net" - "sync" - "testing" - "time" -) +import "testing" var ( tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} @@ -52,7 +47,7 @@ func TestInvalidXfr(t *testing.T) { HandleFunc("miek.nl.", InvalidXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServer(":0") + s, addrstr, _, err := RunLocalTCPServer(":0") if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -78,7 +73,9 @@ func TestSingleEnvelopeXfr(t *testing.T) { HandleFunc("miek.nl.", SingleEnvelopeXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret) + s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { + srv.TsigSecret = tsigSecret + }) if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -91,7 +88,9 @@ func TestMultiEnvelopeXfr(t *testing.T) { HandleFunc("miek.nl.", MultipleEnvelopeXfrServer) defer HandleRemove("miek.nl.") - s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret) + s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { + srv.TsigSecret = tsigSecret + }) if err != nil { t.Fatalf("unable to run test server: %s", err) } @@ -100,37 +99,6 @@ func TestMultiEnvelopeXfr(t *testing.T) { axfrTestingSuite(t, addrstr) } -func RunLocalTCPServerWithTsig(laddr string, tsig map[string]string) (*Server, string, error) { - server, l, _, err := RunLocalTCPServerWithFinChanWithTsig(laddr, tsig) - - return server, l, err -} - -func RunLocalTCPServerWithFinChanWithTsig(laddr string, tsig map[string]string) (*Server, string, chan error, error) { - l, err := net.Listen("tcp", laddr) - if err != nil { - return nil, "", nil, err - } - - server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour, TsigSecret: tsig} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - // See the comment in RunLocalUDPServerWithFinChan as to - // why fin must be buffered. - fin := make(chan error, 1) - - go func() { - fin <- server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), fin, nil -} - func axfrTestingSuite(t *testing.T, addrstr string) { tr := new(Transfer) m := new(Msg)