Only treat a *net.UnixConn of unixgram as a packet conn (#1322)
* Refactor net.PacketConn checks into helper function * Only treat a *net.UnixConn of unixgram as a packet conn * Handle wrapped net.Conn types in isPacketConn * Use Error instead of Fatal where appropriate in TestIsPacketConn
This commit is contained in:
parent
af5144a5ca
commit
0544c8bb11
20
client.go
20
client.go
|
@ -18,6 +18,18 @@ const (
|
|||
tcpIdleTimeout time.Duration = 8 * time.Second
|
||||
)
|
||||
|
||||
func isPacketConn(c net.Conn) bool {
|
||||
if _, ok := c.(net.PacketConn); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if ua, ok := c.LocalAddr().(*net.UnixAddr); ok {
|
||||
return ua.Net == "unixgram"
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// A Conn represents a connection to a DNS server.
|
||||
type Conn struct {
|
||||
net.Conn // a net.Conn holding the connection
|
||||
|
@ -221,7 +233,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg,
|
|||
return nil, 0, err
|
||||
}
|
||||
|
||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
||||
if isPacketConn(co.Conn) {
|
||||
for {
|
||||
r, err = co.ReadMsg()
|
||||
// Ignore replies with mismatched IDs because they might be
|
||||
|
@ -282,7 +294,7 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
|
|||
err error
|
||||
)
|
||||
|
||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
||||
if isPacketConn(co.Conn) {
|
||||
if co.UDPSize > MinMsgSize {
|
||||
p = make([]byte, co.UDPSize)
|
||||
} else {
|
||||
|
@ -322,7 +334,7 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
|||
return 0, ErrConnEmpty
|
||||
}
|
||||
|
||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
||||
if isPacketConn(co.Conn) {
|
||||
// UDP connection
|
||||
return co.Conn.Read(p)
|
||||
}
|
||||
|
@ -371,7 +383,7 @@ func (co *Conn) Write(p []byte) (int, error) {
|
|||
return 0, &Error{err: "message too large"}
|
||||
}
|
||||
|
||||
if _, ok := co.Conn.(net.PacketConn); ok {
|
||||
if isPacketConn(co.Conn) {
|
||||
return co.Conn.Write(p)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,12 +6,87 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsPacketConn(t *testing.T) {
|
||||
// UDP
|
||||
s, addrstr, _, err := RunLocalUDPServer(":0")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
c, err := net.Dial("udp", addrstr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if !isPacketConn(c) {
|
||||
t.Error("UDP connection should be a packet conn")
|
||||
}
|
||||
if !isPacketConn(struct{ *net.UDPConn }{c.(*net.UDPConn)}) {
|
||||
t.Error("UDP connection (wrapped type) should be a packet conn")
|
||||
}
|
||||
|
||||
// TCP
|
||||
s, addrstr, _, err = RunLocalTCPServer(":0")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
c, err = net.Dial("tcp", addrstr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if isPacketConn(c) {
|
||||
t.Error("TCP connection should not be a packet conn")
|
||||
}
|
||||
if isPacketConn(struct{ *net.TCPConn }{c.(*net.TCPConn)}) {
|
||||
t.Error("TCP connection (wrapped type) should not be a packet conn")
|
||||
}
|
||||
|
||||
// Unix datagram
|
||||
s, addrstr, _, err = RunLocalUnixGramServer(filepath.Join(t.TempDir(), "unixgram.sock"))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
c, err = net.Dial("unixgram", addrstr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if !isPacketConn(c) {
|
||||
t.Error("Unix datagram connection should be a packet conn")
|
||||
}
|
||||
if !isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
|
||||
t.Error("Unix datagram connection (wrapped type) should be a packet conn")
|
||||
}
|
||||
|
||||
// Unix stream
|
||||
s, addrstr, _, err = RunLocalUnixServer(filepath.Join(t.TempDir(), "unixstream.sock"))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
c, err = net.Dial("unix", addrstr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if isPacketConn(c) {
|
||||
t.Error("Unix stream connection should not be a packet conn")
|
||||
}
|
||||
if isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
|
||||
t.Error("Unix stream connection (wrapped type) should not be a packet conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialUDP(t *testing.T) {
|
||||
HandleFunc("miek.nl.", HelloServer)
|
||||
defer HandleRemove("miek.nl.")
|
||||
|
|
|
@ -67,12 +67,14 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) {
|
|||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
pc, err := net.ListenPacket("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
func RunLocalServer(pc net.PacketConn, l net.Listener, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
server := &Server{
|
||||
PacketConn: pc,
|
||||
Listener: l,
|
||||
|
||||
ReadTimeout: time.Hour,
|
||||
WriteTimeout: time.Hour,
|
||||
}
|
||||
server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
|
@ -82,6 +84,18 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
|||
opt(server)
|
||||
}
|
||||
|
||||
var (
|
||||
addr string
|
||||
closer io.Closer
|
||||
)
|
||||
if l != nil {
|
||||
addr = l.Addr().String()
|
||||
closer = l
|
||||
} else {
|
||||
addr = pc.LocalAddr().String()
|
||||
closer = pc
|
||||
}
|
||||
|
||||
// fin must be buffered so the goroutine below won't block
|
||||
// forever if fin is never read from. This always happens
|
||||
// if the channel is discarded and can happen in TestShutdownUDP.
|
||||
|
@ -89,11 +103,20 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
|||
|
||||
go func() {
|
||||
fin <- server.ActivateAndServe()
|
||||
pc.Close()
|
||||
closer.Close()
|
||||
}()
|
||||
|
||||
waitLock.Lock()
|
||||
return server, pc.LocalAddr().String(), fin, nil
|
||||
return server, addr, fin, nil
|
||||
}
|
||||
|
||||
func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
pc, err := net.ListenPacket("udp", laddr)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
return RunLocalServer(pc, nil, opts...)
|
||||
}
|
||||
|
||||
func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
|
@ -109,26 +132,7 @@ func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
|
|||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
server.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
// See the comment in RunLocalUDPServer as to why fin must be buffered.
|
||||
fin := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
fin <- server.ActivateAndServe()
|
||||
l.Close()
|
||||
}()
|
||||
|
||||
waitLock.Lock()
|
||||
return server, l.Addr().String(), fin, nil
|
||||
return RunLocalServer(nil, l, opts...)
|
||||
}
|
||||
|
||||
func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
|
||||
|
@ -137,6 +141,24 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan
|
|||
})
|
||||
}
|
||||
|
||||
func RunLocalUnixServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
l, err := net.Listen("unix", laddr)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
return RunLocalServer(nil, l, opts...)
|
||||
}
|
||||
|
||||
func RunLocalUnixGramServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
|
||||
pc, err := net.ListenPacket("unixgram", laddr)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
|
||||
return RunLocalServer(pc, nil, opts...)
|
||||
}
|
||||
|
||||
func TestServing(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
|
|
Loading…
Reference in New Issue