From ff611cdc4b48526df452d46249bc3ab278b6de63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveirinha?= Date: Wed, 8 Jun 2022 13:03:24 +0100 Subject: [PATCH] Add back support for *net.UnixCon with seqpacket type (#1378) This was broken by PR: https://github.com/miekg/dns/pull/1322 --- client.go | 2 +- client_test.go | 21 +++++++++++++++++++++ server_test.go | 15 +++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 31bf5759..fde5b5e3 100644 --- a/client.go +++ b/client.go @@ -24,7 +24,7 @@ func isPacketConn(c net.Conn) bool { } if ua, ok := c.LocalAddr().(*net.UnixAddr); ok { - return ua.Net == "unixgram" + return ua.Net == "unixgram" || ua.Net == "unixpacket" } return true diff --git a/client_test.go b/client_test.go index 83d5af70..ff0c2187 100644 --- a/client_test.go +++ b/client_test.go @@ -68,6 +68,27 @@ func TestIsPacketConn(t *testing.T) { t.Error("Unix datagram connection (wrapped type) should be a packet conn") } + // Unix Seqpacket + shutChan, addrstr, err := RunLocalUnixSeqPacketServer(filepath.Join(t.TempDir(), "unixpacket.sock")) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + + defer func() { + shutChan <- &struct{}{} + }() + c, err = net.Dial("unixpacket", addrstr) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer c.Close() + if !isPacketConn(c) { + t.Error("Unix datagram connection should be a packet conn") + } + if !isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) { + t.Error("Unix datagram connection (wrapped type) should be a packet conn") + } + // Unix stream s, addrstr, _, err = RunLocalUnixServer(filepath.Join(t.TempDir(), "unixstream.sock")) if err != nil { diff --git a/server_test.go b/server_test.go index 85da176d..aaaca704 100644 --- a/server_test.go +++ b/server_test.go @@ -159,6 +159,21 @@ func RunLocalUnixGramServer(laddr string, opts ...func(*Server)) (*Server, strin return RunLocalServer(pc, nil, opts...) } +func RunLocalUnixSeqPacketServer(laddr string) (chan interface{}, string, error) { + pc, err := net.Listen("unixpacket", laddr) + if err != nil { + return nil, "", err + } + + shutdownChan := make(chan interface{}) + go func() { + pc.Accept() + <-shutdownChan + }() + + return shutdownChan, pc.Addr().String(), nil +} + func TestServing(t *testing.T) { for _, tc := range []struct { name string