Support generic net.PacketConn's for the Server (#1174)
* Support generic net.PacketConn's for the Server
This commit adds support for listening on generic net.PacketConn's for
UDP DNS requests, previously *net.UDPConn was the only supported type.
In the event of a future v2 of this module, this should be streamlined.
* Eliminate wrapper functions around RunLocalXServerWithFinChan
* Eliminate RunLocalTCPServerWithTsig function
* Replace RunLocalTLSServer with a wrapper around RunLocalTCPServer
This reduces code duplication.
* Add net.PacketConn server tests
This provides coverage over nearly all of the newly added code (with
the unfortunate exception of (*response).RemoteAddr).
* Fix broken client_test.go tests
a433fbede4
was merged into master between this PR being opened and
being merged. This broke the CI tests in rather strange ways as the
code was being merged into master in a way that wasn't at all clear.
This commit fixes the two broken lines.
This commit is contained in:
parent
a3ad44419a
commit
0e1c4e69dd
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
func TestAcceptNotify(t *testing.T) {
|
func TestAcceptNotify(t *testing.T) {
|
||||||
HandleFunc("example.org.", handleNotify)
|
HandleFunc("example.org.", handleNotify)
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ func TestDialUDP(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func TestClientSync(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -73,7 +73,7 @@ func TestClientLocalAddress(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServerEchoAddrPort)
|
HandleFunc("miek.nl.", HelloServerEchoAddrPort)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ func TestClientTLSSyncV4(t *testing.T) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTLSServer(":0", &config)
|
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ func TestClientSyncBadID(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServerBadID)
|
HandleFunc("miek.nl.", HelloServerBadID)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ func TestClientSyncBadThenGoodID(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServerBadThenGoodID)
|
HandleFunc("miek.nl.", HelloServerBadThenGoodID)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -229,7 +229,7 @@ func TestClientSyncTCPBadID(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServerBadID)
|
HandleFunc("miek.nl.", HelloServerBadID)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTCPServer(":0")
|
s, addrstr, _, err := RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ func TestClientEDNS0(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -297,7 +297,7 @@ func TestClientEDNS0Local(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", handler)
|
HandleFunc("miek.nl.", handler)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %s", err)
|
t.Fatalf("unable to run test server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -347,7 +347,7 @@ func TestClientConn(t *testing.T) {
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
// This uses TCP just to make it slightly different than TestClientSync
|
// This uses TCP just to make it slightly different than TestClientSync
|
||||||
s, addrstr, err := RunLocalTCPServer(":0")
|
s, addrstr, _, err := RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -594,7 +594,7 @@ func TestConcurrentExchanges(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", handler)
|
HandleFunc("miek.nl.", handler)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %s", err)
|
t.Fatalf("unable to run test server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -631,7 +631,7 @@ func TestExchangeWithConn(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
104
server.go
104
server.go
|
@ -72,9 +72,10 @@ type response struct {
|
||||||
tsigStatus error
|
tsigStatus error
|
||||||
tsigRequestMAC string
|
tsigRequestMAC string
|
||||||
tsigSecret map[string]string // the tsig secrets
|
tsigSecret map[string]string // the tsig secrets
|
||||||
udp *net.UDPConn // i/o connection if UDP was used
|
udp net.PacketConn // i/o connection if UDP was used
|
||||||
tcp net.Conn // i/o connection if TCP was used
|
tcp net.Conn // i/o connection if TCP was used
|
||||||
udpSession *SessionUDP // oob data to get egress interface right
|
udpSession *SessionUDP // oob data to get egress interface right
|
||||||
|
pcSession net.Addr // address to use when writing to a generic net.PacketConn
|
||||||
writer Writer // writer to output the raw DNS bits
|
writer Writer // writer to output the raw DNS bits
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,12 +148,24 @@ type Reader interface {
|
||||||
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
|
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultReader is an adapter for the Server struct that implements the Reader interface
|
// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
|
||||||
// using the readTCP and readUDP func of the embedded Server.
|
type PacketConnReader interface {
|
||||||
|
Reader
|
||||||
|
|
||||||
|
// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
|
||||||
|
// alter connection properties, for example the read-deadline.
|
||||||
|
ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultReader is an adapter for the Server struct that implements the Reader and
|
||||||
|
// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
|
||||||
|
// of the embedded Server.
|
||||||
type defaultReader struct {
|
type defaultReader struct {
|
||||||
*Server
|
*Server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ PacketConnReader = defaultReader{}
|
||||||
|
|
||||||
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
||||||
return dr.readTCP(conn, timeout)
|
return dr.readTCP(conn, timeout)
|
||||||
}
|
}
|
||||||
|
@ -161,8 +174,14 @@ func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byt
|
||||||
return dr.readUDP(conn, timeout)
|
return dr.readUDP(conn, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
|
||||||
|
return dr.readPacketConn(conn, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
|
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
|
||||||
// Implementations should never return a nil Reader.
|
// Implementations should never return a nil Reader.
|
||||||
|
// Readers should also implement the optional ReaderPacketConn interface.
|
||||||
|
// ReaderPacketConn is required to use a generic net.PacketConn.
|
||||||
type DecorateReader func(Reader) Reader
|
type DecorateReader func(Reader) Reader
|
||||||
|
|
||||||
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
|
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
|
||||||
|
@ -325,24 +344,22 @@ func (srv *Server) ActivateAndServe() error {
|
||||||
|
|
||||||
srv.init()
|
srv.init()
|
||||||
|
|
||||||
pConn := srv.PacketConn
|
if srv.PacketConn != nil {
|
||||||
l := srv.Listener
|
|
||||||
if pConn != nil {
|
|
||||||
// Check PacketConn interface's type is valid and value
|
// Check PacketConn interface's type is valid and value
|
||||||
// is not nil
|
// is not nil
|
||||||
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
|
if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
|
||||||
if e := setUDPSocketOptions(t); e != nil {
|
if e := setUDPSocketOptions(t); e != nil {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
srv.started = true
|
|
||||||
unlock()
|
|
||||||
return srv.serveUDP(t)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if l != nil {
|
|
||||||
srv.started = true
|
srv.started = true
|
||||||
unlock()
|
unlock()
|
||||||
return srv.serveTCP(l)
|
return srv.serveUDP(srv.PacketConn)
|
||||||
|
}
|
||||||
|
if srv.Listener != nil {
|
||||||
|
srv.started = true
|
||||||
|
unlock()
|
||||||
|
return srv.serveTCP(srv.Listener)
|
||||||
}
|
}
|
||||||
return &Error{err: "bad listeners"}
|
return &Error{err: "bad listeners"}
|
||||||
}
|
}
|
||||||
|
@ -446,18 +463,24 @@ func (srv *Server) serveTCP(l net.Listener) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveUDP starts a UDP listener for the server.
|
// serveUDP starts a UDP listener for the server.
|
||||||
func (srv *Server) serveUDP(l *net.UDPConn) error {
|
func (srv *Server) serveUDP(l net.PacketConn) error {
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
if srv.NotifyStartedFunc != nil {
|
|
||||||
srv.NotifyStartedFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
reader := Reader(defaultReader{srv})
|
reader := Reader(defaultReader{srv})
|
||||||
if srv.DecorateReader != nil {
|
if srv.DecorateReader != nil {
|
||||||
reader = srv.DecorateReader(reader)
|
reader = srv.DecorateReader(reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lUDP, isUDP := l.(*net.UDPConn)
|
||||||
|
readerPC, canPacketConn := reader.(PacketConnReader)
|
||||||
|
if !isUDP && !canPacketConn {
|
||||||
|
return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if srv.NotifyStartedFunc != nil {
|
||||||
|
srv.NotifyStartedFunc()
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer func() {
|
defer func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
@ -467,7 +490,17 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
rtimeout := srv.getReadTimeout()
|
rtimeout := srv.getReadTimeout()
|
||||||
// deadline is not used here
|
// deadline is not used here
|
||||||
for srv.isStarted() {
|
for srv.isStarted() {
|
||||||
m, s, err := reader.ReadUDP(l, rtimeout)
|
var (
|
||||||
|
m []byte
|
||||||
|
sPC net.Addr
|
||||||
|
sUDP *SessionUDP
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if isUDP {
|
||||||
|
m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
|
||||||
|
} else {
|
||||||
|
m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !srv.isStarted() {
|
if !srv.isStarted() {
|
||||||
return nil
|
return nil
|
||||||
|
@ -484,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go srv.serveUDPPacket(&wg, m, l, s)
|
go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -546,8 +579,8 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve a new UDP request.
|
// Serve a new UDP request.
|
||||||
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
|
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
|
||||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
|
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
|
||||||
if srv.DecorateWriter != nil {
|
if srv.DecorateWriter != nil {
|
||||||
w.writer = srv.DecorateWriter(w)
|
w.writer = srv.DecorateWriter(w)
|
||||||
} else {
|
} else {
|
||||||
|
@ -659,6 +692,24 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
|
||||||
return m, s, nil
|
return m, s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
|
||||||
|
srv.lock.RLock()
|
||||||
|
if srv.started {
|
||||||
|
// See the comment in readTCP above.
|
||||||
|
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
srv.lock.RUnlock()
|
||||||
|
|
||||||
|
m := srv.udpPool.Get().([]byte)
|
||||||
|
n, addr, err := conn.ReadFrom(m)
|
||||||
|
if err != nil {
|
||||||
|
srv.udpPool.Put(m)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
m = m[:n]
|
||||||
|
return m, addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
// WriteMsg implements the ResponseWriter.WriteMsg method.
|
// WriteMsg implements the ResponseWriter.WriteMsg method.
|
||||||
func (w *response) WriteMsg(m *Msg) (err error) {
|
func (w *response) WriteMsg(m *Msg) (err error) {
|
||||||
if w.closed {
|
if w.closed {
|
||||||
|
@ -692,7 +743,10 @@ func (w *response) Write(m []byte) (int, error) {
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case w.udp != nil:
|
case w.udp != nil:
|
||||||
return WriteToSessionUDP(w.udp, m, w.udpSession)
|
if u, ok := w.udp.(*net.UDPConn); ok {
|
||||||
|
return WriteToSessionUDP(u, m, w.udpSession)
|
||||||
|
}
|
||||||
|
return w.udp.WriteTo(m, w.pcSession)
|
||||||
case w.tcp != nil:
|
case w.tcp != nil:
|
||||||
if len(m) > MaxMsgSize {
|
if len(m) > MaxMsgSize {
|
||||||
return 0, &Error{err: "message too large"}
|
return 0, &Error{err: "message too large"}
|
||||||
|
@ -725,10 +779,12 @@ func (w *response) RemoteAddr() net.Addr {
|
||||||
switch {
|
switch {
|
||||||
case w.udpSession != nil:
|
case w.udpSession != nil:
|
||||||
return w.udpSession.RemoteAddr()
|
return w.udpSession.RemoteAddr()
|
||||||
|
case w.pcSession != nil:
|
||||||
|
return w.pcSession
|
||||||
case w.tcp != nil:
|
case w.tcp != nil:
|
||||||
return w.tcp.RemoteAddr()
|
return w.tcp.RemoteAddr()
|
||||||
default:
|
default:
|
||||||
panic("dns: internal error: udpSession and tcp both nil")
|
panic("dns: internal error: udpSession, pcSession and tcp are all nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
217
server_test.go
217
server_test.go
|
@ -67,13 +67,7 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) {
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalUDPServer(laddr string) (*Server, string, error) {
|
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
server, l, _, err := RunLocalUDPServerWithFinChan(laddr)
|
|
||||||
|
|
||||||
return server, l, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
|
||||||
pc, err := net.ListenPacket("udp", laddr)
|
pc, err := net.ListenPacket("udp", laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", nil, err
|
return nil, "", nil, err
|
||||||
|
@ -84,15 +78,15 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server,
|
||||||
waitLock.Lock()
|
waitLock.Lock()
|
||||||
server.NotifyStartedFunc = waitLock.Unlock
|
server.NotifyStartedFunc = waitLock.Unlock
|
||||||
|
|
||||||
// fin must be buffered so the goroutine below won't block
|
|
||||||
// forever if fin is never read from. This always happens
|
|
||||||
// in RunLocalUDPServer and can happen in TestShutdownUDP.
|
|
||||||
fin := make(chan error, 1)
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(server)
|
opt(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fin must be buffered so the goroutine below won't block
|
||||||
|
// forever if fin is never read from. This always happens
|
||||||
|
// if the channel is discarded and can happen in TestShutdownUDP.
|
||||||
|
fin := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fin <- server.ActivateAndServe()
|
fin <- server.ActivateAndServe()
|
||||||
pc.Close()
|
pc.Close()
|
||||||
|
@ -102,13 +96,14 @@ func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server,
|
||||||
return server, pc.LocalAddr().String(), fin, nil
|
return server, pc.LocalAddr().String(), fin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalTCPServer(laddr string) (*Server, string, error) {
|
func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
server, l, _, err := RunLocalTCPServerWithFinChan(laddr)
|
return RunLocalUDPServer(laddr, append(opts, func(srv *Server) {
|
||||||
|
// Make srv.PacketConn opaque to trigger the generic code paths.
|
||||||
return server, l, err
|
srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn}
|
||||||
|
})...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) {
|
func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
l, err := net.Listen("tcp", laddr)
|
l, err := net.Listen("tcp", laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", nil, err
|
return nil, "", nil, err
|
||||||
|
@ -120,8 +115,11 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
|
||||||
waitLock.Lock()
|
waitLock.Lock()
|
||||||
server.NotifyStartedFunc = waitLock.Unlock
|
server.NotifyStartedFunc = waitLock.Unlock
|
||||||
|
|
||||||
// See the comment in RunLocalUDPServerWithFinChan as to
|
for _, opt := range opts {
|
||||||
// why fin must be buffered.
|
opt(server)
|
||||||
|
}
|
||||||
|
|
||||||
|
// See the comment in RunLocalUDPServer as to why fin must be buffered.
|
||||||
fin := make(chan error, 1)
|
fin := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -133,70 +131,69 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
|
||||||
return server, l.Addr().String(), fin, nil
|
return server, l.Addr().String(), fin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) {
|
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
|
||||||
l, err := tls.Listen("tcp", laddr, config)
|
return RunLocalTCPServer(laddr, func(srv *Server) {
|
||||||
if err != nil {
|
srv.Listener = tls.NewListener(srv.Listener, config)
|
||||||
return nil, "", err
|
})
|
||||||
}
|
|
||||||
|
|
||||||
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
|
||||||
|
|
||||||
waitLock := sync.Mutex{}
|
|
||||||
waitLock.Lock()
|
|
||||||
server.NotifyStartedFunc = waitLock.Unlock
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
server.ActivateAndServe()
|
|
||||||
l.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
waitLock.Lock()
|
|
||||||
return server, l.Addr().String(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServing(t *testing.T) {
|
func TestServing(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
for _, tc := range []struct {
|
||||||
HandleFunc("example.com.", AnotherHelloServer)
|
name string
|
||||||
defer HandleRemove("miek.nl.")
|
network string
|
||||||
defer HandleRemove("example.com.")
|
runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error)
|
||||||
|
}{
|
||||||
|
{"udp", "udp", RunLocalUDPServer},
|
||||||
|
{"tcp", "tcp", RunLocalTCPServer},
|
||||||
|
{"PacketConn", "udp", RunLocalPacketConnServer},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
|
HandleFunc("example.com.", AnotherHelloServer)
|
||||||
|
defer HandleRemove("miek.nl.")
|
||||||
|
defer HandleRemove("example.com.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := tc.runServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
defer s.Shutdown()
|
defer s.Shutdown()
|
||||||
|
|
||||||
c := new(Client)
|
c := &Client{
|
||||||
m := new(Msg)
|
Net: tc.network,
|
||||||
m.SetQuestion("miek.nl.", TypeTXT)
|
}
|
||||||
r, _, err := c.Exchange(m, addrstr)
|
m := new(Msg)
|
||||||
if err != nil || len(r.Extra) == 0 {
|
m.SetQuestion("miek.nl.", TypeTXT)
|
||||||
t.Fatal("failed to exchange miek.nl", err)
|
r, _, err := c.Exchange(m, addrstr)
|
||||||
}
|
if err != nil || len(r.Extra) == 0 {
|
||||||
txt := r.Extra[0].(*TXT).Txt[0]
|
t.Fatal("failed to exchange miek.nl", err)
|
||||||
if txt != "Hello world" {
|
}
|
||||||
t.Error("unexpected result for miek.nl", txt, "!= Hello world")
|
txt := r.Extra[0].(*TXT).Txt[0]
|
||||||
}
|
if txt != "Hello world" {
|
||||||
|
t.Error("unexpected result for miek.nl", txt, "!= Hello world")
|
||||||
|
}
|
||||||
|
|
||||||
m.SetQuestion("example.com.", TypeTXT)
|
m.SetQuestion("example.com.", TypeTXT)
|
||||||
r, _, err = c.Exchange(m, addrstr)
|
r, _, err = c.Exchange(m, addrstr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to exchange example.com", err)
|
t.Fatal("failed to exchange example.com", err)
|
||||||
}
|
}
|
||||||
txt = r.Extra[0].(*TXT).Txt[0]
|
txt = r.Extra[0].(*TXT).Txt[0]
|
||||||
if txt != "Hello example" {
|
if txt != "Hello example" {
|
||||||
t.Error("unexpected result for example.com", txt, "!= Hello example")
|
t.Error("unexpected result for example.com", txt, "!= Hello example")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test Mixes cased as noticed by Ask.
|
// Test Mixes cased as noticed by Ask.
|
||||||
m.SetQuestion("eXaMplE.cOm.", TypeTXT)
|
m.SetQuestion("eXaMplE.cOm.", TypeTXT)
|
||||||
r, _, err = c.Exchange(m, addrstr)
|
r, _, err = c.Exchange(m, addrstr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to exchange eXaMplE.cOm", err)
|
t.Error("failed to exchange eXaMplE.cOm", err)
|
||||||
}
|
}
|
||||||
txt = r.Extra[0].(*TXT).Txt[0]
|
txt = r.Extra[0].(*TXT).Txt[0]
|
||||||
if txt != "Hello example" {
|
if txt != "Hello example" {
|
||||||
t.Error("unexpected result for example.com", txt, "!= Hello example")
|
t.Error("unexpected result for example.com", txt, "!= Hello example")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,7 +201,7 @@ func TestServing(t *testing.T) {
|
||||||
func TestServeIgnoresZFlag(t *testing.T) {
|
func TestServeIgnoresZFlag(t *testing.T) {
|
||||||
HandleFunc("example.com.", AnotherHelloServer)
|
HandleFunc("example.com.", AnotherHelloServer)
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -233,7 +230,7 @@ func TestServeNotImplemented(t *testing.T) {
|
||||||
HandleFunc("example.com.", AnotherHelloServer)
|
HandleFunc("example.com.", AnotherHelloServer)
|
||||||
opcode := 15
|
opcode := 15
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -272,7 +269,7 @@ func TestServingTLS(t *testing.T) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTLSServer(":0", &config)
|
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -358,7 +355,7 @@ func TestServingTLSConnectionState(t *testing.T) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTLSServer(":0", &config)
|
s, addrstr, _, err := RunLocalTLSServer(":0", &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -381,7 +378,7 @@ func TestServingTLSConnectionState(t *testing.T) {
|
||||||
// UDP DNS Server
|
// UDP DNS Server
|
||||||
HandleFunc(".", tlsHandlerTLS(false))
|
HandleFunc(".", tlsHandlerTLS(false))
|
||||||
defer HandleRemove(".")
|
defer HandleRemove(".")
|
||||||
s, addrstr, err = RunLocalUDPServer(":0")
|
s, addrstr, _, err = RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -395,7 +392,7 @@ func TestServingTLSConnectionState(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCP DNS Server
|
// TCP DNS Server
|
||||||
s, addrstr, err = RunLocalTCPServer(":0")
|
s, addrstr, _, err = RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -479,7 +476,7 @@ func BenchmarkServe(b *testing.B) {
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
a := runtime.GOMAXPROCS(4)
|
a := runtime.GOMAXPROCS(4)
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("unable to run test server: %v", err)
|
b.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -504,7 +501,7 @@ func BenchmarkServe6(b *testing.B) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
a := runtime.GOMAXPROCS(4)
|
a := runtime.GOMAXPROCS(4)
|
||||||
s, addrstr, err := RunLocalUDPServer("[::1]:0")
|
s, addrstr, _, err := RunLocalUDPServer("[::1]:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "bind: cannot assign requested address") {
|
if strings.Contains(err.Error(), "bind: cannot assign requested address") {
|
||||||
b.Skip("missing IPv6 support")
|
b.Skip("missing IPv6 support")
|
||||||
|
@ -541,7 +538,7 @@ func BenchmarkServeCompress(b *testing.B) {
|
||||||
HandleFunc("miek.nl.", HelloServerCompress)
|
HandleFunc("miek.nl.", HelloServerCompress)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
a := runtime.GOMAXPROCS(4)
|
a := runtime.GOMAXPROCS(4)
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("unable to run test server: %v", err)
|
b.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -594,7 +591,7 @@ func TestServingLargeResponses(t *testing.T) {
|
||||||
HandleFunc("example.", HelloServerLargeResponse)
|
HandleFunc("example.", HelloServerLargeResponse)
|
||||||
defer HandleRemove("example.")
|
defer HandleRemove("example.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -634,7 +631,7 @@ func TestServingResponse(t *testing.T) {
|
||||||
t.Skip("skipping test in short mode.")
|
t.Skip("skipping test in short mode.")
|
||||||
}
|
}
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
s, addrstr, err := RunLocalUDPServer(":0")
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -657,7 +654,7 @@ func TestServingResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShutdownTCP(t *testing.T) {
|
func TestShutdownTCP(t *testing.T) {
|
||||||
s, _, fin, err := RunLocalTCPServerWithFinChan(":0")
|
s, _, fin, err := RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -788,7 +785,7 @@ func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInProgressQueriesAtShutdownTCP(t *testing.T) {
|
func TestInProgressQueriesAtShutdownTCP(t *testing.T) {
|
||||||
s, addr, _, err := RunLocalTCPServerWithFinChan(":0")
|
s, addr, _, err := RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -807,7 +804,7 @@ func TestShutdownTLS(t *testing.T) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
s, _, err := RunLocalTLSServer(":0", &config)
|
s, _, _, err := RunLocalTLSServer(":0", &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -827,7 +824,7 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
s, addr, err := RunLocalTLSServer(":0", &config)
|
s, addr, _, err := RunLocalTLSServer(":0", &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -842,7 +839,6 @@ func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerCloseTCP(t *testing.T) {
|
func TestHandlerCloseTCP(t *testing.T) {
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", ":0")
|
ln, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -887,7 +883,26 @@ func TestHandlerCloseTCP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShutdownUDP(t *testing.T) {
|
func TestShutdownUDP(t *testing.T) {
|
||||||
s, _, fin, err := RunLocalUDPServerWithFinChan(":0")
|
s, _, fin, err := RunLocalUDPServer(":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
err = s.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("could not shutdown test UDP server, %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-fin:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error returned from ActivateAndServe, %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("could not shutdown test UDP server. Gave up waiting")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShutdownPacketConn(t *testing.T) {
|
||||||
|
s, _, fin, err := RunLocalPacketConnServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -906,7 +921,17 @@ func TestShutdownUDP(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInProgressQueriesAtShutdownUDP(t *testing.T) {
|
func TestInProgressQueriesAtShutdownUDP(t *testing.T) {
|
||||||
s, addr, _, err := RunLocalUDPServerWithFinChan(":0")
|
s, addr, _, err := RunLocalUDPServer(":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Client{Net: "udp"}
|
||||||
|
checkInProgressQueriesAtShutdownServer(t, s, addr, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) {
|
||||||
|
s, addr, _, err := RunLocalPacketConnServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %v", err)
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -919,7 +944,7 @@ func TestServerStartStopRace(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
s, _, _, err := RunLocalUDPServerWithFinChan(":0")
|
s, _, _, err := RunLocalUDPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("could not start server: %s", err)
|
t.Fatalf("could not start server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -982,7 +1007,7 @@ func TestServerReuseport(t *testing.T) {
|
||||||
func TestServerRoundtripTsig(t *testing.T) {
|
func TestServerRoundtripTsig(t *testing.T) {
|
||||||
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
|
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
|
||||||
|
|
||||||
s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) {
|
s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) {
|
||||||
srv.TsigSecret = secret
|
srv.TsigSecret = secret
|
||||||
srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
|
srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
|
||||||
// defaultMsgAcceptFunc does reject UPDATE queries
|
// defaultMsgAcceptFunc does reject UPDATE queries
|
||||||
|
|
48
xfr_test.go
48
xfr_test.go
|
@ -1,11 +1,6 @@
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import "testing"
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
|
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
|
||||||
|
@ -52,7 +47,7 @@ func TestInvalidXfr(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", InvalidXfrServer)
|
HandleFunc("miek.nl.", InvalidXfrServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTCPServer(":0")
|
s, addrstr, _, err := RunLocalTCPServer(":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %s", err)
|
t.Fatalf("unable to run test server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -78,7 +73,9 @@ func TestSingleEnvelopeXfr(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
|
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret)
|
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
|
||||||
|
srv.TsigSecret = tsigSecret
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %s", err)
|
t.Fatalf("unable to run test server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -91,7 +88,9 @@ func TestMultiEnvelopeXfr(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", MultipleEnvelopeXfrServer)
|
HandleFunc("miek.nl.", MultipleEnvelopeXfrServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
s, addrstr, err := RunLocalTCPServerWithTsig(":0", tsigSecret)
|
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
|
||||||
|
srv.TsigSecret = tsigSecret
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to run test server: %s", err)
|
t.Fatalf("unable to run test server: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -100,37 +99,6 @@ func TestMultiEnvelopeXfr(t *testing.T) {
|
||||||
axfrTestingSuite(t, addrstr)
|
axfrTestingSuite(t, addrstr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalTCPServerWithTsig(laddr string, tsig map[string]string) (*Server, string, error) {
|
|
||||||
server, l, _, err := RunLocalTCPServerWithFinChanWithTsig(laddr, tsig)
|
|
||||||
|
|
||||||
return server, l, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunLocalTCPServerWithFinChanWithTsig(laddr string, tsig map[string]string) (*Server, string, chan error, error) {
|
|
||||||
l, err := net.Listen("tcp", laddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour, TsigSecret: tsig}
|
|
||||||
|
|
||||||
waitLock := sync.Mutex{}
|
|
||||||
waitLock.Lock()
|
|
||||||
server.NotifyStartedFunc = waitLock.Unlock
|
|
||||||
|
|
||||||
// See the comment in RunLocalUDPServerWithFinChan as to
|
|
||||||
// why fin must be buffered.
|
|
||||||
fin := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
fin <- server.ActivateAndServe()
|
|
||||||
l.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
waitLock.Lock()
|
|
||||||
return server, l.Addr().String(), fin, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func axfrTestingSuite(t *testing.T, addrstr string) {
|
func axfrTestingSuite(t *testing.T, addrstr string) {
|
||||||
tr := new(Transfer)
|
tr := new(Transfer)
|
||||||
m := new(Msg)
|
m := new(Msg)
|
||||||
|
|
Loading…
Reference in New Issue