Only treat a *net.UnixConn of unixgram as a packet conn (#1322)
* Refactor net.PacketConn checks into helper function * Only treat a *net.UnixConn of unixgram as a packet conn * Handle wrapped net.Conn types in isPacketConn * Use Error instead of Fatal where appropriate in TestIsPacketConn
This commit is contained in:
parent
af5144a5ca
commit
0544c8bb11
20
client.go
20
client.go
|
@ -18,6 +18,18 @@ const (
|
||||||
tcpIdleTimeout time.Duration = 8 * time.Second
|
tcpIdleTimeout time.Duration = 8 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isPacketConn(c net.Conn) bool {
|
||||||
|
if _, ok := c.(net.PacketConn); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if ua, ok := c.LocalAddr().(*net.UnixAddr); ok {
|
||||||
|
return ua.Net == "unixgram"
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// A Conn represents a connection to a DNS server.
|
// A Conn represents a connection to a DNS server.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
net.Conn // a net.Conn holding the connection
|
net.Conn // a net.Conn holding the connection
|
||||||
|
@ -221,7 +233,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg,
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
if isPacketConn(co.Conn) {
|
||||||
for {
|
for {
|
||||||
r, err = co.ReadMsg()
|
r, err = co.ReadMsg()
|
||||||
// Ignore replies with mismatched IDs because they might be
|
// Ignore replies with mismatched IDs because they might be
|
||||||
|
@ -282,7 +294,7 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
if isPacketConn(co.Conn) {
|
||||||
if co.UDPSize > MinMsgSize {
|
if co.UDPSize > MinMsgSize {
|
||||||
p = make([]byte, co.UDPSize)
|
p = make([]byte, co.UDPSize)
|
||||||
} else {
|
} else {
|
||||||
|
@ -322,7 +334,7 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
||||||
return 0, ErrConnEmpty
|
return 0, ErrConnEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
if isPacketConn(co.Conn) {
|
||||||
// UDP connection
|
// UDP connection
|
||||||
return co.Conn.Read(p)
|
return co.Conn.Read(p)
|
||||||
}
|
}
|
||||||
|
@ -371,7 +383,7 @@ func (co *Conn) Write(p []byte) (int, error) {
|
||||||
return 0, &Error{err: "message too large"}
|
return 0, &Error{err: "message too large"}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
if isPacketConn(co.Conn) {
|
||||||
return co.Conn.Write(p)
|
return co.Conn.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,87 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestIsPacketConn(t *testing.T) {
|
||||||
|
// UDP
|
||||||
|
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
c, err := net.Dial("udp", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
if !isPacketConn(c) {
|
||||||
|
t.Error("UDP connection should be a packet conn")
|
||||||
|
}
|
||||||
|
if !isPacketConn(struct{ *net.UDPConn }{c.(*net.UDPConn)}) {
|
||||||
|
t.Error("UDP connection (wrapped type) should be a packet conn")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCP
|
||||||
|
s, addrstr, _, err = RunLocalTCPServer(":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
c, err = net.Dial("tcp", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
if isPacketConn(c) {
|
||||||
|
t.Error("TCP connection should not be a packet conn")
|
||||||
|
}
|
||||||
|
if isPacketConn(struct{ *net.TCPConn }{c.(*net.TCPConn)}) {
|
||||||
|
t.Error("TCP connection (wrapped type) should not be a packet conn")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unix datagram
|
||||||
|
s, addrstr, _, err = RunLocalUnixGramServer(filepath.Join(t.TempDir(), "unixgram.sock"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
c, err = net.Dial("unixgram", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
if !isPacketConn(c) {
|
||||||
|
t.Error("Unix datagram connection should be a packet conn")
|
||||||
|
}
|
||||||
|
if !isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
|
||||||
|
t.Error("Unix datagram connection (wrapped type) should be a packet conn")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unix stream
|
||||||
|
s, addrstr, _, err = RunLocalUnixServer(filepath.Join(t.TempDir(), "unixstream.sock"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
c, err = net.Dial("unix", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
if isPacketConn(c) {
|
||||||
|
t.Error("Unix stream connection should not be a packet conn")
|
||||||
|
}
|
||||||
|
if isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
|
||||||
|
t.Error("Unix stream connection (wrapped type) should not be a packet conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialUDP(t *testing.T) {
|
func TestDialUDP(t *testing.T) {
|
||||||
HandleFunc("miek.nl.", HelloServer)
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
defer HandleRemove("miek.nl.")
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
|
@ -67,12 +67,14 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) {
|
||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
func RunLocalServer(pc net.PacketConn, l net.Listener, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
pc, err := net.ListenPacket("udp", laddr)
|
server := &Server{
|
||||||
if err != nil {
|
PacketConn: pc,
|
||||||
return nil, "", nil, err
|
Listener: l,
|
||||||
|
|
||||||
|
ReadTimeout: time.Hour,
|
||||||
|
WriteTimeout: time.Hour,
|
||||||
}
|
}
|
||||||
server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
|
||||||
|
|
||||||
waitLock := sync.Mutex{}
|
waitLock := sync.Mutex{}
|
||||||
waitLock.Lock()
|
waitLock.Lock()
|
||||||
|
@ -82,6 +84,18 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
||||||
opt(server)
|
opt(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
addr string
|
||||||
|
closer io.Closer
|
||||||
|
)
|
||||||
|
if l != nil {
|
||||||
|
addr = l.Addr().String()
|
||||||
|
closer = l
|
||||||
|
} else {
|
||||||
|
addr = pc.LocalAddr().String()
|
||||||
|
closer = pc
|
||||||
|
}
|
||||||
|
|
||||||
// fin must be buffered so the goroutine below won't block
|
// fin must be buffered so the goroutine below won't block
|
||||||
// forever if fin is never read from. This always happens
|
// forever if fin is never read from. This always happens
|
||||||
// if the channel is discarded and can happen in TestShutdownUDP.
|
// if the channel is discarded and can happen in TestShutdownUDP.
|
||||||
|
@ -89,11 +103,20 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fin <- server.ActivateAndServe()
|
fin <- server.ActivateAndServe()
|
||||||
pc.Close()
|
closer.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
waitLock.Lock()
|
waitLock.Lock()
|
||||||
return server, pc.LocalAddr().String(), fin, nil
|
return server, addr, fin, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
|
pc, err := net.ListenPacket("udp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return RunLocalServer(pc, nil, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
|
@ -109,26 +132,7 @@ func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
||||||
return nil, "", nil, err
|
return nil, "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
return RunLocalServer(nil, l, opts...)
|
||||||
|
|
||||||
waitLock := sync.Mutex{}
|
|
||||||
waitLock.Lock()
|
|
||||||
server.NotifyStartedFunc = waitLock.Unlock
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(server)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See the comment in RunLocalUDPServer as to why fin must be buffered.
|
|
||||||
fin := make(chan error, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
fin <- server.ActivateAndServe()
|
|
||||||
l.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
waitLock.Lock()
|
|
||||||
return server, l.Addr().String(), fin, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
|
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
|
||||||
|
@ -137,6 +141,24 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunLocalUnixServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
|
l, err := net.Listen("unix", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return RunLocalServer(nil, l, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunLocalUnixGramServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||||
|
pc, err := net.ListenPacket("unixgram", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return RunLocalServer(pc, nil, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
func TestServing(t *testing.T) {
|
func TestServing(t *testing.T) {
|
||||||
for _, tc := range []struct {
|
for _, tc := range []struct {
|
||||||
name string
|
name string
|
||||||
|
|
Loading…
Reference in New Issue