diff --git a/client.go b/client.go index 6bae3a1c..99bd1896 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,18 @@ const ( tcpIdleTimeout time.Duration = 8 * time.Second ) +func isPacketConn(c net.Conn) bool { + if _, ok := c.(net.PacketConn); !ok { + return false + } + + if ua, ok := c.LocalAddr().(*net.UnixAddr); ok { + return ua.Net == "unixgram" + } + + return true +} + // A Conn represents a connection to a DNS server. type Conn struct { net.Conn // a net.Conn holding the connection @@ -221,7 +233,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, return nil, 0, err } - if _, ok := co.Conn.(net.PacketConn); ok { + if isPacketConn(co.Conn) { for { r, err = co.ReadMsg() // Ignore replies with mismatched IDs because they might be @@ -282,7 +294,7 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { err error ) - if _, ok := co.Conn.(net.PacketConn); ok { + if isPacketConn(co.Conn) { if co.UDPSize > MinMsgSize { p = make([]byte, co.UDPSize) } else { @@ -322,7 +334,7 @@ func (co *Conn) Read(p []byte) (n int, err error) { return 0, ErrConnEmpty } - if _, ok := co.Conn.(net.PacketConn); ok { + if isPacketConn(co.Conn) { // UDP connection return co.Conn.Read(p) } @@ -371,7 +383,7 @@ func (co *Conn) Write(p []byte) (int, error) { return 0, &Error{err: "message too large"} } - if _, ok := co.Conn.(net.PacketConn); ok { + if isPacketConn(co.Conn) { return co.Conn.Write(p) } diff --git a/client_test.go b/client_test.go index 4bff6eb4..71c61770 100644 --- a/client_test.go +++ b/client_test.go @@ -6,12 +6,87 @@ import ( "errors" "fmt" "net" + "path/filepath" "strconv" "strings" "testing" "time" ) +func TestIsPacketConn(t *testing.T) { + // UDP + s, addrstr, _, err := RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + c, err := net.Dial("udp", addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + if !isPacketConn(c) { + t.Error("UDP connection should be a packet conn") + } + if !isPacketConn(struct{ *net.UDPConn }{c.(*net.UDPConn)}) { + t.Error("UDP connection (wrapped type) should be a packet conn") + } + + // TCP + s, addrstr, _, err = RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + c, err = net.Dial("tcp", addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + if isPacketConn(c) { + t.Error("TCP connection should not be a packet conn") + } + if isPacketConn(struct{ *net.TCPConn }{c.(*net.TCPConn)}) { + t.Error("TCP connection (wrapped type) should not be a packet conn") + } + + // Unix datagram + s, addrstr, _, err = RunLocalUnixGramServer(filepath.Join(t.TempDir(), "unixgram.sock")) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + c, err = net.Dial("unixgram", addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + if !isPacketConn(c) { + t.Error("Unix datagram connection should be a packet conn") + } + if !isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) { + t.Error("Unix datagram connection (wrapped type) should be a packet conn") + } + + // Unix stream + s, addrstr, _, err = RunLocalUnixServer(filepath.Join(t.TempDir(), "unixstream.sock")) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + c, err = net.Dial("unix", addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + if isPacketConn(c) { + t.Error("Unix stream connection should not be a packet conn") + } + if isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) { + t.Error("Unix stream connection (wrapped type) should not be a packet conn") + } +} + func TestDialUDP(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") diff --git a/server_test.go b/server_test.go index a5207aa5..85da176d 100644 --- a/server_test.go +++ b/server_test.go @@ -67,12 +67,14 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) { w.WriteMsg(m) } -func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { - pc, err := net.ListenPacket("udp", laddr) - if err != nil { - return nil, "", nil, err +func RunLocalServer(pc net.PacketConn, l net.Listener, opts ...func(*Server)) (*Server, string, chan error, error) { + server := &Server{ + PacketConn: pc, + Listener: l, + + ReadTimeout: time.Hour, + WriteTimeout: time.Hour, } - server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} waitLock := sync.Mutex{} waitLock.Lock() @@ -82,6 +84,18 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch opt(server) } + var ( + addr string + closer io.Closer + ) + if l != nil { + addr = l.Addr().String() + closer = l + } else { + addr = pc.LocalAddr().String() + closer = pc + } + // fin must be buffered so the goroutine below won't block // forever if fin is never read from. This always happens // if the channel is discarded and can happen in TestShutdownUDP. @@ -89,11 +103,20 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch go func() { fin <- server.ActivateAndServe() - pc.Close() + closer.Close() }() waitLock.Lock() - return server, pc.LocalAddr().String(), fin, nil + return server, addr, fin, nil +} + +func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { + pc, err := net.ListenPacket("udp", laddr) + if err != nil { + return nil, "", nil, err + } + + return RunLocalServer(pc, nil, opts...) } func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { @@ -109,26 +132,7 @@ func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, ch return nil, "", nil, err } - server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - for _, opt := range opts { - opt(server) - } - - // See the comment in RunLocalUDPServer as to why fin must be buffered. - fin := make(chan error, 1) - - go func() { - fin <- server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), fin, nil + return RunLocalServer(nil, l, opts...) } func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) { @@ -137,6 +141,24 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan }) } +func RunLocalUnixServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { + l, err := net.Listen("unix", laddr) + if err != nil { + return nil, "", nil, err + } + + return RunLocalServer(nil, l, opts...) +} + +func RunLocalUnixGramServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { + pc, err := net.ListenPacket("unixgram", laddr) + if err != nil { + return nil, "", nil, err + } + + return RunLocalServer(pc, nil, opts...) +} + func TestServing(t *testing.T) { for _, tc := range []struct { name string