Use workers instead spawning goroutines for each incoming DNS request (#664)
* Use workers instead spawning goroutines for each incoming DNS request * Replace count (int) with inUse (bool)
This commit is contained in:
parent
9c76f9827e
commit
98a1ef4565
128
server.go
128
server.go
|
@ -9,12 +9,19 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Maximum number of TCP queries before we close the socket.
|
// Maximum number of TCP queries before we close the socket.
|
||||||
const maxTCPQueries = 128
|
const maxTCPQueries = 128
|
||||||
|
|
||||||
|
// Interval for stop worker if no load
|
||||||
|
const idleWorkerTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// Maximum number of workers
|
||||||
|
const maxWorkersCount = 10000
|
||||||
|
|
||||||
// Handler is implemented by any value that implements ServeDNS.
|
// Handler is implemented by any value that implements ServeDNS.
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
ServeDNS(w ResponseWriter, r *Msg)
|
ServeDNS(w ResponseWriter, r *Msg)
|
||||||
|
@ -43,6 +50,7 @@ type ResponseWriter interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type response struct {
|
type response struct {
|
||||||
|
msg []byte
|
||||||
hijacked bool // connection has been hijacked by handler
|
hijacked bool // connection has been hijacked by handler
|
||||||
tsigStatus error
|
tsigStatus error
|
||||||
tsigTimersOnly bool
|
tsigTimersOnly bool
|
||||||
|
@ -296,11 +304,60 @@ type Server struct {
|
||||||
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
|
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
|
||||||
DecorateWriter DecorateWriter
|
DecorateWriter DecorateWriter
|
||||||
|
|
||||||
|
// UDP packet or TCP connection queue
|
||||||
|
queue chan *response
|
||||||
|
// Workers count
|
||||||
|
workersCount int32
|
||||||
// Shutdown handling
|
// Shutdown handling
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
started bool
|
started bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) worker(w *response) {
|
||||||
|
srv.serve(w)
|
||||||
|
|
||||||
|
for {
|
||||||
|
count := atomic.LoadInt32(&srv.workersCount)
|
||||||
|
if count > maxWorkersCount {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer atomic.AddInt32(&srv.workersCount, -1)
|
||||||
|
|
||||||
|
inUse := false
|
||||||
|
timeout := time.NewTimer(idleWorkerTimeout)
|
||||||
|
defer timeout.Stop()
|
||||||
|
LOOP:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case w, ok := <-srv.queue:
|
||||||
|
if !ok {
|
||||||
|
break LOOP
|
||||||
|
}
|
||||||
|
inUse = true
|
||||||
|
srv.serve(w)
|
||||||
|
case <-timeout.C:
|
||||||
|
if !inUse {
|
||||||
|
break LOOP
|
||||||
|
}
|
||||||
|
inUse = false
|
||||||
|
timeout.Reset(idleWorkerTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) spawnWorker(w *response) {
|
||||||
|
select {
|
||||||
|
case srv.queue <- w:
|
||||||
|
default:
|
||||||
|
go srv.worker(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListenAndServe starts a nameserver on the configured address in *Server.
|
// ListenAndServe starts a nameserver on the configured address in *Server.
|
||||||
func (srv *Server) ListenAndServe() error {
|
func (srv *Server) ListenAndServe() error {
|
||||||
srv.lock.Lock()
|
srv.lock.Lock()
|
||||||
|
@ -308,6 +365,7 @@ func (srv *Server) ListenAndServe() error {
|
||||||
if srv.started {
|
if srv.started {
|
||||||
return &Error{err: "server already started"}
|
return &Error{err: "server already started"}
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := srv.Addr
|
addr := srv.Addr
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
addr = ":domain"
|
addr = ":domain"
|
||||||
|
@ -315,6 +373,8 @@ func (srv *Server) ListenAndServe() error {
|
||||||
if srv.UDPSize == 0 {
|
if srv.UDPSize == 0 {
|
||||||
srv.UDPSize = MinMsgSize
|
srv.UDPSize = MinMsgSize
|
||||||
}
|
}
|
||||||
|
srv.queue = make(chan *response)
|
||||||
|
defer close(srv.queue)
|
||||||
switch srv.Net {
|
switch srv.Net {
|
||||||
case "tcp", "tcp4", "tcp6":
|
case "tcp", "tcp4", "tcp6":
|
||||||
a, err := net.ResolveTCPAddr(srv.Net, addr)
|
a, err := net.ResolveTCPAddr(srv.Net, addr)
|
||||||
|
@ -379,8 +439,11 @@ func (srv *Server) ActivateAndServe() error {
|
||||||
if srv.started {
|
if srv.started {
|
||||||
return &Error{err: "server already started"}
|
return &Error{err: "server already started"}
|
||||||
}
|
}
|
||||||
|
|
||||||
pConn := srv.PacketConn
|
pConn := srv.PacketConn
|
||||||
l := srv.Listener
|
l := srv.Listener
|
||||||
|
srv.queue = make(chan *response)
|
||||||
|
defer close(srv.queue)
|
||||||
if pConn != nil {
|
if pConn != nil {
|
||||||
if srv.UDPSize == 0 {
|
if srv.UDPSize == 0 {
|
||||||
srv.UDPSize = MinMsgSize
|
srv.UDPSize = MinMsgSize
|
||||||
|
@ -438,7 +501,6 @@ func (srv *Server) getReadTimeout() time.Duration {
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveTCP starts a TCP listener for the server.
|
// serveTCP starts a TCP listener for the server.
|
||||||
// Each request is handled in a separate goroutine.
|
|
||||||
func (srv *Server) serveTCP(l net.Listener) error {
|
func (srv *Server) serveTCP(l net.Listener) error {
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
|
@ -446,11 +508,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
|
||||||
srv.NotifyStartedFunc()
|
srv.NotifyStartedFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := srv.Handler
|
|
||||||
if handler == nil {
|
|
||||||
handler = DefaultServeMux
|
|
||||||
}
|
|
||||||
// deadline is not used here
|
|
||||||
for {
|
for {
|
||||||
rw, err := l.Accept()
|
rw, err := l.Accept()
|
||||||
srv.lock.RLock()
|
srv.lock.RLock()
|
||||||
|
@ -465,12 +522,11 @@ func (srv *Server) serveTCP(l net.Listener) error {
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go srv.serveTCPConn(handler, rw)
|
srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveUDP starts a UDP listener for the server.
|
// serveUDP starts a UDP listener for the server.
|
||||||
// Each request is handled in a separate goroutine.
|
|
||||||
func (srv *Server) serveUDP(l *net.UDPConn) error {
|
func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
|
@ -483,10 +539,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
reader = srv.DecorateReader(reader)
|
reader = srv.DecorateReader(reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := srv.Handler
|
|
||||||
if handler == nil {
|
|
||||||
handler = DefaultServeMux
|
|
||||||
}
|
|
||||||
rtimeout := srv.getReadTimeout()
|
rtimeout := srv.getReadTimeout()
|
||||||
// deadline is not used here
|
// deadline is not used here
|
||||||
for {
|
for {
|
||||||
|
@ -506,24 +558,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
if len(m) < headerSize {
|
if len(m) < headerSize {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go srv.serveUDPPacket(handler, m, l, s)
|
srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve a new TCP connection.
|
func (srv *Server) serve(w *response) {
|
||||||
func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
|
|
||||||
reader := Reader(&defaultReader{srv})
|
|
||||||
if srv.DecorateReader != nil {
|
|
||||||
reader = srv.DecorateReader(reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
w := &response{tsigSecret: srv.TsigSecret, tcp: t}
|
|
||||||
if srv.DecorateWriter != nil {
|
if srv.DecorateWriter != nil {
|
||||||
w.writer = srv.DecorateWriter(w)
|
w.writer = srv.DecorateWriter(w)
|
||||||
} else {
|
} else {
|
||||||
w.writer = w
|
w.writer = w
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if w.udp != nil {
|
||||||
|
// serve UDP
|
||||||
|
srv.serveDNS(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := Reader(&defaultReader{srv})
|
||||||
|
if srv.DecorateReader != nil {
|
||||||
|
reader = srv.DecorateReader(reader)
|
||||||
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if !w.hijacked {
|
if !w.hijacked {
|
||||||
w.Close()
|
w.Close()
|
||||||
|
@ -539,12 +595,13 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
|
||||||
|
|
||||||
// TODO(miek): make maxTCPQueries configurable?
|
// TODO(miek): make maxTCPQueries configurable?
|
||||||
for q := 0; q < maxTCPQueries; q++ {
|
for q := 0; q < maxTCPQueries; q++ {
|
||||||
m, err := reader.ReadTCP(t, timeout)
|
var err error
|
||||||
|
w.msg, err = reader.ReadTCP(w.tcp, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(tmthrgd): handle error
|
// TODO(tmthrgd): handle error
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
srv.serveDNS(h, m, w)
|
srv.serveDNS(w)
|
||||||
if w.tcp == nil {
|
if w.tcp == nil {
|
||||||
break // Close() was called
|
break // Close() was called
|
||||||
}
|
}
|
||||||
|
@ -557,20 +614,9 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve a new UDP request.
|
func (srv *Server) serveDNS(w *response) {
|
||||||
func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) {
|
|
||||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
|
|
||||||
if srv.DecorateWriter != nil {
|
|
||||||
w.writer = srv.DecorateWriter(w)
|
|
||||||
} else {
|
|
||||||
w.writer = w
|
|
||||||
}
|
|
||||||
srv.serveDNS(h, m, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
|
|
||||||
req := new(Msg)
|
req := new(Msg)
|
||||||
err := req.Unpack(m)
|
err := req.Unpack(w.msg)
|
||||||
if err != nil { // Send a FormatError back
|
if err != nil { // Send a FormatError back
|
||||||
x := new(Msg)
|
x := new(Msg)
|
||||||
x.SetRcodeFormatError(req)
|
x.SetRcodeFormatError(req)
|
||||||
|
@ -585,7 +631,7 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
|
||||||
if w.tsigSecret != nil {
|
if w.tsigSecret != nil {
|
||||||
if t := req.IsTsig(); t != nil {
|
if t := req.IsTsig(); t != nil {
|
||||||
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
|
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
|
||||||
w.tsigStatus = TsigVerify(m, secret, "", false)
|
w.tsigStatus = TsigVerify(w.msg, secret, "", false)
|
||||||
} else {
|
} else {
|
||||||
w.tsigStatus = ErrSecret
|
w.tsigStatus = ErrSecret
|
||||||
}
|
}
|
||||||
|
@ -593,7 +639,13 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
|
||||||
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
|
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
h.ServeDNS(w, req) // Writes back to the client
|
|
||||||
|
handler := srv.Handler
|
||||||
|
if handler == nil {
|
||||||
|
handler = DefaultServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeDNS(w, req) // Writes back to the client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
||||||
|
|
Loading…
Reference in New Issue