diff --git a/listen_go111.go b/listen_go111.go new file mode 100644 index 00000000..bd024c89 --- /dev/null +++ b/listen_go111.go @@ -0,0 +1,43 @@ +// +build go1.11,!windows + +package dns + +import ( + "context" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +const supportsReusePort = true + +func reuseportControl(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) + if err != nil { + return err + } + + return opErr +} + +func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { + var lc net.ListenConfig + if reuseport { + lc.Control = reuseportControl + } + + return lc.Listen(context.Background(), network, addr) +} + +func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { + var lc net.ListenConfig + if reuseport { + lc.Control = reuseportControl + } + + return lc.ListenPacket(context.Background(), network, addr) +} diff --git a/listen_go_not111.go b/listen_go_not111.go new file mode 100644 index 00000000..f1fc652c --- /dev/null +++ b/listen_go_not111.go @@ -0,0 +1,23 @@ +// +build !go1.11 windows + +package dns + +import "net" + +const supportsReusePort = false + +func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { + if reuseport { + // TODO(tmthrgd): return an error? + } + + return net.Listen(network, addr) +} + +func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { + if reuseport { + // TODO(tmthrgd): return an error? + } + + return net.ListenPacket(network, addr) +} diff --git a/server.go b/server.go index c0f18037..d3f526fc 100644 --- a/server.go +++ b/server.go @@ -6,8 +6,10 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "errors" "io" "net" + "strings" "sync" "sync/atomic" "time" @@ -312,6 +314,9 @@ type Server struct { DecorateWriter DecorateWriter // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1). MaxTCPQueries int + // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. + // It is only supported on go1.11+ and when using ListenAndServe. + ReusePort bool // UDP packet or TCP connection queue queue chan *response @@ -418,11 +423,7 @@ func (srv *Server) ListenAndServe() error { switch srv.Net { case "tcp", "tcp4", "tcp6": - a, err := net.ResolveTCPAddr(srv.Net, addr) - if err != nil { - return err - } - l, err := net.ListenTCP(srv.Net, a) + l, err := listenTCP(srv.Net, addr, srv.ReusePort) if err != nil { return err } @@ -431,37 +432,32 @@ func (srv *Server) ListenAndServe() error { unlock() return srv.serveTCP(l) case "tcp-tls", "tcp4-tls", "tcp6-tls": - network := "tcp" - if srv.Net == "tcp4-tls" { - network = "tcp4" - } else if srv.Net == "tcp6-tls" { - network = "tcp6" + if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { + return errors.New("dns: neither Certificates nor GetCertificate set in Config") } - - l, err := tls.Listen(network, addr, srv.TLSConfig) + network := strings.TrimSuffix(srv.Net, "-tls") + l, err := listenTCP(network, addr, srv.ReusePort) if err != nil { return err } + l = tls.NewListener(l, srv.TLSConfig) srv.Listener = l srv.started = true unlock() return srv.serveTCP(l) case "udp", "udp4", "udp6": - a, err := net.ResolveUDPAddr(srv.Net, addr) + l, err := listenUDP(srv.Net, addr, srv.ReusePort) if err != nil { return err } - l, err := net.ListenUDP(srv.Net, a) - if err != nil { - return err - } - if e := setUDPSocketOptions(l); e != nil { + u := l.(*net.UDPConn) + if e := setUDPSocketOptions(u); e != nil { return e } srv.PacketConn = l srv.started = true unlock() - return srv.serveUDP(l) + return srv.serveUDP(u) } return &Error{err: "bad network"} } diff --git a/server_test.go b/server_test.go index e8fc05d2..e6f073d7 100644 --- a/server_test.go +++ b/server_test.go @@ -702,6 +702,52 @@ func TestServerStartStopRace(t *testing.T) { wg.Wait() } +func TestServerReuseport(t *testing.T) { + if !supportsReusePort { + t.Skip("reuseport is not supported") + } + + startServer := func(addr string) (*Server, chan error) { + wait := make(chan struct{}) + srv := &Server{ + Net: "udp", + Addr: addr, + NotifyStartedFunc: func() { close(wait) }, + ReusePort: true, + } + + fin := make(chan error, 1) + go func() { + fin <- srv.ListenAndServe() + }() + + select { + case <-wait: + case err := <-fin: + t.Fatalf("failed to start server: %v", err) + } + + return srv, fin + } + + srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel + srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String()) + + if err := srv1.Shutdown(); err != nil { + t.Fatalf("failed to shutdown first server: %v", err) + } + if err := srv2.Shutdown(); err != nil { + t.Fatalf("failed to shutdown second server: %v", err) + } + + if err := <-fin1; err != nil { + t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) + } + if err := <-fin2; err != nil { + t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) + } +} + type ExampleFrameLengthWriter struct { Writer }