Support generic net.PacketConn's for the Server (#1174)

* Support generic net.PacketConn's for the Server

This commit adds support for listening on generic net.PacketConn's for
UDP DNS requests, previously *net.UDPConn was the only supported type.

In the event of a future v2 of this module, this should be streamlined.

* Eliminate wrapper functions around RunLocalXServerWithFinChan

* Eliminate RunLocalTCPServerWithTsig function

* Replace RunLocalTLSServer with a wrapper around RunLocalTCPServer

This reduces code duplication.

* Add net.PacketConn server tests

This provides coverage over nearly all of the newly added code (with
the unfortunate exception of (*response).RemoteAddr).

* Fix broken client_test.go tests

a433fbede4 was merged into master between this PR being opened and
being merged. This broke the CI tests in rather strange ways as the
code was being merged into master in a way that wasn't at all clear.
This commit fixes the two broken lines.
This commit is contained in:
Tom Thorogood 2020-10-25 02:23:01 +10:30 committed by GitHub
parent a3ad44419a
commit 0e1c4e69dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 222 additions and 173 deletions

View File

@ -6,7 +6,7 @@ import (
func TestAcceptNotify(t *testing.T) {
HandleFunc("example.org.", handleNotify)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}

View File

@ -16,7 +16,7 @@ func TestDialUDP(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -39,7 +39,7 @@ func TestClientSync(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -73,7 +73,7 @@ func TestClientLocalAddress(t *testing.T) {
HandleFunc("miek.nl.", HelloServerEchoAddrPort)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -117,7 +117,7 @@ func TestClientTLSSyncV4(t *testing.T) {
Certificates: []tls.Certificate{cert},
}
s, addrstr, err := RunLocalTLSServer(":0", &config)
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -173,7 +173,7 @@ func TestClientSyncBadID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadID)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -198,7 +198,7 @@ func TestClientSyncBadThenGoodID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadThenGoodID)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -229,7 +229,7 @@ func TestClientSyncTCPBadID(t *testing.T) {
HandleFunc("miek.nl.", HelloServerBadID)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalTCPServer(":0")
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -250,7 +250,7 @@ func TestClientEDNS0(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -297,7 +297,7 @@ func TestClientEDNS0Local(t *testing.T) {
HandleFunc("miek.nl.", handler)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
@ -347,7 +347,7 @@ func TestClientConn(t *testing.T) {
defer HandleRemove("miek.nl.")
// This uses TCP just to make it slightly different than TestClientSync
s, addrstr, err := RunLocalTCPServer(":0")
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -594,7 +594,7 @@ func TestConcurrentExchanges(t *testing.T) {
HandleFunc("miek.nl.", handler)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
@ -631,7 +631,7 @@ func TestExchangeWithConn(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}

104
server.go
View File

@ -72,9 +72,10 @@ type response struct {
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp *net.UDPConn // i/o connection if UDP was used
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}
@ -147,12 +148,24 @@ type Reader interface {
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
// defaultReader is an adapter for the Server struct that implements the Reader interface
// using the readTCP and readUDP func of the embedded Server.
// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
type PacketConnReader interface {
Reader
// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
// alter connection properties, for example the read-deadline.
ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}
// defaultReader is an adapter for the Server struct that implements the Reader and
// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
// of the embedded Server.
type defaultReader struct {
*Server
}
var _ PacketConnReader = defaultReader{}
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
}
@ -161,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
return dr.readUDP(conn, timeout)
}
func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
return dr.readPacketConn(conn, timeout)
}
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// Implementations should never return a nil Reader.
// Readers should also implement the optional ReaderPacketConn interface.
// ReaderPacketConn is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
@ -325,24 +344,22 @@ func (srv *Server) ActivateAndServe() error {
srv.init()
pConn := srv.PacketConn
l := srv.Listener
if pConn != nil {
if srv.PacketConn != nil {
// Check PacketConn interface's type is valid and value
// is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
if e := setUDPSocketOptions(t); e != nil {
return e
}
srv.started = true
unlock()
return srv.serveUDP(t)
}
}
if l != nil {
srv.started = true
unlock()
return srv.serveTCP(l)
return srv.serveUDP(srv.PacketConn)
}
if srv.Listener != nil {
srv.started = true
unlock()
return srv.serveTCP(srv.Listener)
}
return &Error{err: "bad listeners"}
}
@ -446,18 +463,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
}
// serveUDP starts a UDP listener for the server.
func (srv *Server) serveUDP(l *net.UDPConn) error {
func (srv *Server) serveUDP(l net.PacketConn) error {
defer l.Close()
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
lUDP, isUDP := l.(*net.UDPConn)
readerPC, canPacketConn := reader.(PacketConnReader)
if !isUDP && !canPacketConn {
return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
}
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
var wg sync.WaitGroup
defer func() {
wg.Wait()
@ -467,7 +490,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
rtimeout := srv.getReadTimeout()
// deadline is not used here
for srv.isStarted() {
m, s, err := reader.ReadUDP(l, rtimeout)
var (
m []byte
sPC net.Addr
sUDP *SessionUDP
err error
)
if isUDP {
m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
} else {
m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
}
if err != nil {
if !srv.isStarted() {
return nil
@ -484,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
continue
}
wg.Add(1)
go srv.serveUDPPacket(&wg, m, l, s)
go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
}
return nil
@ -546,8 +579,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
}
// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
@ -659,6 +692,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
return m, s, nil
}
func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte)
n, addr, err := conn.ReadFrom(m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
return m, addr, nil
}
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
@ -692,7 +743,10 @@ func (w *response) Write(m []byte) (int, error) {
switch {
case w.udp != nil:
return WriteToSessionUDP(w.udp, m, w.udpSession)
if u, ok := w.udp.(*net.UDPConn); ok {
return WriteToSessionUDP(u, m, w.udpSession)
}
return w.udp.WriteTo(m, w.pcSession)
case w.tcp != nil:
if len(m) > MaxMsgSize {
return 0, &Error{err: "message too large"}
@ -725,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
switch {
case w.udpSession != nil:
return w.udpSession.RemoteAddr()
case w.pcSession != nil:
return w.pcSession
case w.tcp != nil:
return w.tcp.RemoteAddr()
default:
panic("dns: internal error: udpSession and tcp both nil")
panic("dns: internal error: udpSession, pcSession and tcp are all nil")
}
}

View File

@ -67,13 +67,7 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) {
w.WriteMsg(m)
}
func RunLocalUDPServer(laddr string) (*Server, string, error) {
server, l, _, err := RunLocalUDPServerWithFinChan(laddr)
return server, l, err
}
func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", nil, err
@ -84,15 +78,15 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server,
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
// fin must be buffered so the goroutine below won't block
// forever if fin is never read from. This always happens
// in RunLocalUDPServer and can happen in TestShutdownUDP.
fin := make(chan error, 1)
for _, opt := range opts {
opt(server)
}
// fin must be buffered so the goroutine below won't block
// forever if fin is never read from. This always happens
// if the channel is discarded and can happen in TestShutdownUDP.
fin := make(chan error, 1)
go func() {
fin <- server.ActivateAndServe()
pc.Close()
@ -102,13 +96,14 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server,
return server, pc.LocalAddr().String(), fin, nil
}
func RunLocalTCPServer(laddr string) (*Server, string, error) {
server, l, _, err := RunLocalTCPServerWithFinChan(laddr)
return server, l, err
func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
return RunLocalUDPServer(laddr, append(opts, func(srv *Server) {
// Make srv.PacketConn opaque to trigger the generic code paths.
srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn}
})...)
}
func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) {
func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
l, err := net.Listen("tcp", laddr)
if err != nil {
return nil, "", nil, err
@ -120,8 +115,11 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
// See the comment in RunLocalUDPServerWithFinChan as to
// why fin must be buffered.
for _, opt := range opts {
opt(server)
}
// See the comment in RunLocalUDPServer as to why fin must be buffered.
fin := make(chan error, 1)
go func() {
@ -133,70 +131,69 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
return server, l.Addr().String(), fin, nil
}
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) {
l, err := tls.Listen("tcp", laddr, config)
if err != nil {
return nil, "", err
}
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
go func() {
server.ActivateAndServe()
l.Close()
}()
waitLock.Lock()
return server, l.Addr().String(), nil
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
return RunLocalTCPServer(laddr, func(srv *Server) {
srv.Listener = tls.NewListener(srv.Listener, config)
})
}
func TestServing(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
HandleFunc("example.com.", AnotherHelloServer)
defer HandleRemove("miek.nl.")
defer HandleRemove("example.com.")
for _, tc := range []struct {
name string
network string
runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error)
}{
{"udp", "udp", RunLocalUDPServer},
{"tcp", "tcp", RunLocalTCPServer},
{"PacketConn", "udp", RunLocalPacketConnServer},
} {
t.Run(tc.name, func(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
HandleFunc("example.com.", AnotherHelloServer)
defer HandleRemove("miek.nl.")
defer HandleRemove("example.com.")
s, addrstr, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
s, addrstr, _, err := tc.runServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
r, _, err := c.Exchange(m, addrstr)
if err != nil || len(r.Extra) == 0 {
t.Fatal("failed to exchange miek.nl", err)
}
txt := r.Extra[0].(*TXT).Txt[0]
if txt != "Hello world" {
t.Error("unexpected result for miek.nl", txt, "!= Hello world")
}
c := &Client{
Net: tc.network,
}
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)
r, _, err := c.Exchange(m, addrstr)
if err != nil || len(r.Extra) == 0 {
t.Fatal("failed to exchange miek.nl", err)
}
txt := r.Extra[0].(*TXT).Txt[0]
if txt != "Hello world" {
t.Error("unexpected result for miek.nl", txt, "!= Hello world")
}
m.SetQuestion("example.com.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Fatal("failed to exchange example.com", err)
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Error("unexpected result for example.com", txt, "!= Hello example")
}
m.SetQuestion("example.com.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Fatal("failed to exchange example.com", err)
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Error("unexpected result for example.com", txt, "!= Hello example")
}
// Test Mixes cased as noticed by Ask.
m.SetQuestion("eXaMplE.cOm.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange eXaMplE.cOm", err)
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Error("unexpected result for example.com", txt, "!= Hello example")
// Test Mixes cased as noticed by Ask.
m.SetQuestion("eXaMplE.cOm.", TypeTXT)
r, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange eXaMplE.cOm", err)
}
txt = r.Extra[0].(*TXT).Txt[0]
if txt != "Hello example" {
t.Error("unexpected result for example.com", txt, "!= Hello example")
}
})
}
}
@ -204,7 +201,7 @@ func TestServing(t *testing.T) {
func TestServeIgnoresZFlag(t *testing.T) {
HandleFunc("example.com.", AnotherHelloServer)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -233,7 +230,7 @@ func TestServeNotImplemented(t *testing.T) {
HandleFunc("example.com.", AnotherHelloServer)
opcode := 15
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -272,7 +269,7 @@ func TestServingTLS(t *testing.T) {
Certificates: []tls.Certificate{cert},
}
s, addrstr, err := RunLocalTLSServer(":0", &config)
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -358,7 +355,7 @@ func TestServingTLSConnectionState(t *testing.T) {
Certificates: []tls.Certificate{cert},
}
s, addrstr, err := RunLocalTLSServer(":0", &config)
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -381,7 +378,7 @@ func TestServingTLSConnectionState(t *testing.T) {
// UDP DNS Server
HandleFunc(".", tlsHandlerTLS(false))
defer HandleRemove(".")
s, addrstr, err = RunLocalUDPServer(":0")
s, addrstr, _, err = RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -395,7 +392,7 @@ func TestServingTLSConnectionState(t *testing.T) {
}
// TCP DNS Server
s, addrstr, err = RunLocalTCPServer(":0")
s, addrstr, _, err = RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -479,7 +476,7 @@ func BenchmarkServe(b *testing.B) {
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
b.Fatalf("unable to run test server: %v", err)
}
@ -504,7 +501,7 @@ func BenchmarkServe6(b *testing.B) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer("[::1]:0")
s, addrstr, _, err := RunLocalUDPServer("[::1]:0")
if err != nil {
if strings.Contains(err.Error(), "bind: cannot assign requested address") {
b.Skip("missing IPv6 support")
@ -541,7 +538,7 @@ func BenchmarkServeCompress(b *testing.B) {
HandleFunc("miek.nl.", HelloServerCompress)
defer HandleRemove("miek.nl.")
a := runtime.GOMAXPROCS(4)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
b.Fatalf("unable to run test server: %v", err)
}
@ -594,7 +591,7 @@ func TestServingLargeResponses(t *testing.T) {
HandleFunc("example.", HelloServerLargeResponse)
defer HandleRemove("example.")
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -634,7 +631,7 @@ func TestServingResponse(t *testing.T) {
t.Skip("skipping test in short mode.")
}
HandleFunc("miek.nl.", HelloServer)
s, addrstr, err := RunLocalUDPServer(":0")
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -657,7 +654,7 @@ func TestServingResponse(t *testing.T) {
}
func TestShutdownTCP(t *testing.T) {
s, _, fin, err := RunLocalTCPServerWithFinChan(":0")
s, _, fin, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -788,7 +785,7 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri
}
func TestInProgressQueriesAtShutdownTCP(t *testing.T) {
s, addr, _, err := RunLocalTCPServerWithFinChan(":0")
s, addr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -807,7 +804,7 @@ func TestShutdownTLS(t *testing.T) {
Certificates: []tls.Certificate{cert},
}
s, _, err := RunLocalTLSServer(":0", &config)
s, _, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -827,7 +824,7 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
Certificates: []tls.Certificate{cert},
}
s, addr, err := RunLocalTLSServer(":0", &config)
s, addr, _, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -842,7 +839,6 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
}
func TestHandlerCloseTCP(t *testing.T) {
ln, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
@ -887,7 +883,26 @@ func TestHandlerCloseTCP(t *testing.T) {
}
func TestShutdownUDP(t *testing.T) {
s, _, fin, err := RunLocalUDPServerWithFinChan(":0")
s, _, fin, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
err = s.Shutdown()
if err != nil {
t.Errorf("could not shutdown test UDP server, %v", err)
}
select {
case err := <-fin:
if err != nil {
t.Errorf("error returned from ActivateAndServe, %v", err)
}
case <-time.After(2 * time.Second):
t.Error("could not shutdown test UDP server. Gave up waiting")
}
}
func TestShutdownPacketConn(t *testing.T) {
s, _, fin, err := RunLocalPacketConnServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -906,7 +921,17 @@ func TestShutdownUDP(t *testing.T) {
}
func TestInProgressQueriesAtShutdownUDP(t *testing.T) {
s, addr, _, err := RunLocalUDPServerWithFinChan(":0")
s, addr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
c := &Client{Net: "udp"}
checkInProgressQueriesAtShutdownServer(t, s, addr, c)
}
func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) {
s, addr, _, err := RunLocalPacketConnServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
@ -919,7 +944,7 @@ func TestServerStartStopRace(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
s, _, _, err := RunLocalUDPServerWithFinChan(":0")
s, _, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("could not start server: %s", err)
}
@ -982,7 +1007,7 @@ func TestServerReuseport(t *testing.T) {
func TestServerRoundtripTsig(t *testing.T) {
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) {
s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) {
srv.TsigSecret = secret
srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
// defaultMsgAcceptFunc does reject UPDATE queries

View File

@ -1,11 +1,6 @@
package dns
import (
"net"
"sync"
"testing"
"time"
)
import "testing"
var (
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
@ -52,7 +47,7 @@ func TestInvalidXfr(t *testing.T) {
HandleFunc("miek.nl.", InvalidXfrServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalTCPServer(":0")
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
@ -78,7 +73,9 @@ func TestSingleEnvelopeXfr(t *testing.T) {
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret)
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
srv.TsigSecret = tsigSecret
})
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
@ -91,7 +88,9 @@ func TestMultiEnvelopeXfr(t *testing.T) {
HandleFunc("miek.nl.", MultipleEnvelopeXfrServer)
defer HandleRemove("miek.nl.")
s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret)
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
srv.TsigSecret = tsigSecret
})
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
@ -100,37 +99,6 @@ func TestMultiEnvelopeXfr(t *testing.T) {
axfrTestingSuite(t, addrstr)
}
func RunLocalTCPServerWithTsig(laddr string, tsig map[string]string) (*Server, string, error) {
server, l, _, err := RunLocalTCPServerWithFinChanWithTsig(laddr, tsig)
return server, l, err
}
func RunLocalTCPServerWithFinChanWithTsig(laddr string, tsig map[string]string) (*Server, string, chan error, error) {
l, err := net.Listen("tcp", laddr)
if err != nil {
return nil, "", nil, err
}
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour, TsigSecret: tsig}
waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock
// See the comment in RunLocalUDPServerWithFinChan as to
// why fin must be buffered.
fin := make(chan error, 1)
go func() {
fin <- server.ActivateAndServe()
l.Close()
}()
waitLock.Lock()
return server, l.Addr().String(), fin, nil
}
func axfrTestingSuite(t *testing.T, addrstr string) {
tr := new(Transfer)
m := new(Msg)