Use workers instead spawning goroutines for each incoming DNS request (#664)

* Use workers instead spawning goroutines for each incoming DNS request

* Replace count (int) with inUse (bool)
This commit is contained in:
Uladzimir Trehubenka 2018-05-09 18:44:32 +03:00 committed by Miek Gieben
parent 9c76f9827e
commit 98a1ef4565
1 changed files with 90 additions and 38 deletions

128
server.go
View File

@ -9,12 +9,19 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
// Maximum number of TCP queries before we close the socket. // Maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128 const maxTCPQueries = 128
// Interval for stop worker if no load
const idleWorkerTimeout = 10 * time.Second
// Maximum number of workers
const maxWorkersCount = 10000
// Handler is implemented by any value that implements ServeDNS. // Handler is implemented by any value that implements ServeDNS.
type Handler interface { type Handler interface {
ServeDNS(w ResponseWriter, r *Msg) ServeDNS(w ResponseWriter, r *Msg)
@ -43,6 +50,7 @@ type ResponseWriter interface {
} }
type response struct { type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
tsigStatus error tsigStatus error
tsigTimersOnly bool tsigTimersOnly bool
@ -296,11 +304,60 @@ type Server struct {
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages. // DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter DecorateWriter DecorateWriter
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
// Shutdown handling // Shutdown handling
lock sync.RWMutex lock sync.RWMutex
started bool started bool
} }
func (srv *Server) worker(w *response) {
srv.serve(w)
for {
count := atomic.LoadInt32(&srv.workersCount)
if count > maxWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
break
}
}
defer atomic.AddInt32(&srv.workersCount, -1)
inUse := false
timeout := time.NewTimer(idleWorkerTimeout)
defer timeout.Stop()
LOOP:
for {
select {
case w, ok := <-srv.queue:
if !ok {
break LOOP
}
inUse = true
srv.serve(w)
case <-timeout.C:
if !inUse {
break LOOP
}
inUse = false
timeout.Reset(idleWorkerTimeout)
}
}
}
func (srv *Server) spawnWorker(w *response) {
select {
case srv.queue <- w:
default:
go srv.worker(w)
}
}
// ListenAndServe starts a nameserver on the configured address in *Server. // ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error { func (srv *Server) ListenAndServe() error {
srv.lock.Lock() srv.lock.Lock()
@ -308,6 +365,7 @@ func (srv *Server) ListenAndServe() error {
if srv.started { if srv.started {
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":domain" addr = ":domain"
@ -315,6 +373,8 @@ func (srv *Server) ListenAndServe() error {
if srv.UDPSize == 0 { if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.UDPSize = MinMsgSize
} }
srv.queue = make(chan *response)
defer close(srv.queue)
switch srv.Net { switch srv.Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr) a, err := net.ResolveTCPAddr(srv.Net, addr)
@ -379,8 +439,11 @@ func (srv *Server) ActivateAndServe() error {
if srv.started { if srv.started {
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
pConn := srv.PacketConn pConn := srv.PacketConn
l := srv.Listener l := srv.Listener
srv.queue = make(chan *response)
defer close(srv.queue)
if pConn != nil { if pConn != nil {
if srv.UDPSize == 0 { if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.UDPSize = MinMsgSize
@ -438,7 +501,6 @@ func (srv *Server) getReadTimeout() time.Duration {
} }
// serveTCP starts a TCP listener for the server. // serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error { func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close() defer l.Close()
@ -446,11 +508,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc() srv.NotifyStartedFunc()
} }
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
// deadline is not used here
for { for {
rw, err := l.Accept() rw, err := l.Accept()
srv.lock.RLock() srv.lock.RLock()
@ -465,12 +522,11 @@ func (srv *Server) serveTCP(l net.Listener) error {
} }
return err return err
} }
go srv.serveTCPConn(handler, rw) srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
} }
} }
// serveUDP starts a UDP listener for the server. // serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error { func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close() defer l.Close()
@ -483,10 +539,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
} }
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout() rtimeout := srv.getReadTimeout()
// deadline is not used here // deadline is not used here
for { for {
@ -506,24 +558,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if len(m) < headerSize { if len(m) < headerSize {
continue continue
} }
go srv.serveUDPPacket(handler, m, l, s) srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
} }
} }
// Serve a new TCP connection. func (srv *Server) serve(w *response) {
func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
w := &response{tsigSecret: srv.TsigSecret, tcp: t}
if srv.DecorateWriter != nil { if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w) w.writer = srv.DecorateWriter(w)
} else { } else {
w.writer = w w.writer = w
} }
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
return
}
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
defer func() { defer func() {
if !w.hijacked { if !w.hijacked {
w.Close() w.Close()
@ -539,12 +595,13 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
// TODO(miek): make maxTCPQueries configurable? // TODO(miek): make maxTCPQueries configurable?
for q := 0; q < maxTCPQueries; q++ { for q := 0; q < maxTCPQueries; q++ {
m, err := reader.ReadTCP(t, timeout) var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil { if err != nil {
// TODO(tmthrgd): handle error // TODO(tmthrgd): handle error
break break
} }
srv.serveDNS(h, m, w) srv.serveDNS(w)
if w.tcp == nil { if w.tcp == nil {
break // Close() was called break // Close() was called
} }
@ -557,20 +614,9 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
} }
} }
// Serve a new UDP request. func (srv *Server) serveDNS(w *response) {
func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
srv.serveDNS(h, m, w)
}
func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
req := new(Msg) req := new(Msg)
err := req.Unpack(m) err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back if err != nil { // Send a FormatError back
x := new(Msg) x := new(Msg)
x.SetRcodeFormatError(req) x.SetRcodeFormatError(req)
@ -585,7 +631,7 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
if w.tsigSecret != nil { if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil { if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(m, secret, "", false) w.tsigStatus = TsigVerify(w.msg, secret, "", false)
} else { } else {
w.tsigStatus = ErrSecret w.tsigStatus = ErrSecret
} }
@ -593,7 +639,13 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
} }
} }
h.ServeDNS(w, req) // Writes back to the client
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
handler.ServeDNS(w, req) // Writes back to the client
} }
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {