Use workers instead spawning goroutines for each incoming DNS request (#664)

* Use workers instead spawning goroutines for each incoming DNS request

* Replace count (int) with inUse (bool)
This commit is contained in:
Uladzimir Trehubenka 2018-05-09 18:44:32 +03:00 committed by Miek Gieben
parent 9c76f9827e
commit 98a1ef4565
1 changed files with 90 additions and 38 deletions

128
server.go
View File

@ -9,12 +9,19 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// Maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
// Interval for stop worker if no load
const idleWorkerTimeout = 10 * time.Second
// Maximum number of workers
const maxWorkersCount = 10000
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
@ -43,6 +50,7 @@ type ResponseWriter interface {
}
type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool
@ -296,11 +304,60 @@ type Server struct {
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
// Shutdown handling
lock sync.RWMutex
started bool
}
func (srv *Server) worker(w *response) {
srv.serve(w)
for {
count := atomic.LoadInt32(&srv.workersCount)
if count > maxWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
break
}
}
defer atomic.AddInt32(&srv.workersCount, -1)
inUse := false
timeout := time.NewTimer(idleWorkerTimeout)
defer timeout.Stop()
LOOP:
for {
select {
case w, ok := <-srv.queue:
if !ok {
break LOOP
}
inUse = true
srv.serve(w)
case <-timeout.C:
if !inUse {
break LOOP
}
inUse = false
timeout.Reset(idleWorkerTimeout)
}
}
}
func (srv *Server) spawnWorker(w *response) {
select {
case srv.queue <- w:
default:
go srv.worker(w)
}
}
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
@ -308,6 +365,7 @@ func (srv *Server) ListenAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
addr := srv.Addr
if addr == "" {
addr = ":domain"
@ -315,6 +373,8 @@ func (srv *Server) ListenAndServe() error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
srv.queue = make(chan *response)
defer close(srv.queue)
switch srv.Net {
case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr)
@ -379,8 +439,11 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
pConn := srv.PacketConn
l := srv.Listener
srv.queue = make(chan *response)
defer close(srv.queue)
if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
@ -438,7 +501,6 @@ func (srv *Server) getReadTimeout() time.Duration {
}
// serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()
@ -446,11 +508,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
// deadline is not used here
for {
rw, err := l.Accept()
srv.lock.RLock()
@ -465,12 +522,11 @@ func (srv *Server) serveTCP(l net.Listener) error {
}
return err
}
go srv.serveTCPConn(handler, rw)
srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
}
}
// serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
@ -483,10 +539,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
@ -506,24 +558,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if len(m) < headerSize {
continue
}
go srv.serveUDPPacket(handler, m, l, s)
srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
}
}
// Serve a new TCP connection.
func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
w := &response{tsigSecret: srv.TsigSecret, tcp: t}
func (srv *Server) serve(w *response) {
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
return
}
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
defer func() {
if !w.hijacked {
w.Close()
@ -539,12 +595,13 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
// TODO(miek): make maxTCPQueries configurable?
for q := 0; q < maxTCPQueries; q++ {
m, err := reader.ReadTCP(t, timeout)
var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(h, m, w)
srv.serveDNS(w)
if w.tcp == nil {
break // Close() was called
}
@ -557,20 +614,9 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
}
}
// Serve a new UDP request.
func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
srv.serveDNS(h, m, w)
}
func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
func (srv *Server) serveDNS(w *response) {
req := new(Msg)
err := req.Unpack(m)
err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
@ -585,7 +631,7 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(m, secret, "", false)
w.tsigStatus = TsigVerify(w.msg, secret, "", false)
} else {
w.tsigStatus = ErrSecret
}
@ -593,7 +639,13 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // Writes back to the client
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
handler.ServeDNS(w, req) // Writes back to the client
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {