From 98a1ef4565e7ee768fdb395432704321f0a4ba21 Mon Sep 17 00:00:00 2001 From: Uladzimir Trehubenka Date: Wed, 9 May 2018 18:44:32 +0300 Subject: [PATCH] Use workers instead spawning goroutines for each incoming DNS request (#664) * Use workers instead spawning goroutines for each incoming DNS request * Replace count (int) with inUse (bool) --- server.go | 128 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 90 insertions(+), 38 deletions(-) diff --git a/server.go b/server.go index 20ef6b10..54bce126 100644 --- a/server.go +++ b/server.go @@ -9,12 +9,19 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) // Maximum number of TCP queries before we close the socket. const maxTCPQueries = 128 +// Interval for stop worker if no load +const idleWorkerTimeout = 10 * time.Second + +// Maximum number of workers +const maxWorkersCount = 10000 + // Handler is implemented by any value that implements ServeDNS. type Handler interface { ServeDNS(w ResponseWriter, r *Msg) @@ -43,6 +50,7 @@ type ResponseWriter interface { } type response struct { + msg []byte hijacked bool // connection has been hijacked by handler tsigStatus error tsigTimersOnly bool @@ -296,11 +304,60 @@ type Server struct { // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter + // UDP packet or TCP connection queue + queue chan *response + // Workers count + workersCount int32 // Shutdown handling lock sync.RWMutex started bool } +func (srv *Server) worker(w *response) { + srv.serve(w) + + for { + count := atomic.LoadInt32(&srv.workersCount) + if count > maxWorkersCount { + return + } + if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) { + break + } + } + + defer atomic.AddInt32(&srv.workersCount, -1) + + inUse := false + timeout := time.NewTimer(idleWorkerTimeout) + defer timeout.Stop() +LOOP: + for { + select { + case w, ok := <-srv.queue: + if !ok { + break LOOP + } + inUse = true + srv.serve(w) + case <-timeout.C: + if !inUse { + break LOOP + } + inUse = false + timeout.Reset(idleWorkerTimeout) + } + } +} + +func (srv *Server) spawnWorker(w *response) { + select { + case srv.queue <- w: + default: + go srv.worker(w) + } +} + // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { srv.lock.Lock() @@ -308,6 +365,7 @@ func (srv *Server) ListenAndServe() error { if srv.started { return &Error{err: "server already started"} } + addr := srv.Addr if addr == "" { addr = ":domain" @@ -315,6 +373,8 @@ func (srv *Server) ListenAndServe() error { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } + srv.queue = make(chan *response) + defer close(srv.queue) switch srv.Net { case "tcp", "tcp4", "tcp6": a, err := net.ResolveTCPAddr(srv.Net, addr) @@ -379,8 +439,11 @@ func (srv *Server) ActivateAndServe() error { if srv.started { return &Error{err: "server already started"} } + pConn := srv.PacketConn l := srv.Listener + srv.queue = make(chan *response) + defer close(srv.queue) if pConn != nil { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize @@ -438,7 +501,6 @@ func (srv *Server) getReadTimeout() time.Duration { } // serveTCP starts a TCP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveTCP(l net.Listener) error { defer l.Close() @@ -446,11 +508,6 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } - // deadline is not used here for { rw, err := l.Accept() srv.lock.RLock() @@ -465,12 +522,11 @@ func (srv *Server) serveTCP(l net.Listener) error { } return err } - go srv.serveTCPConn(handler, rw) + srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw}) } } // serveUDP starts a UDP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveUDP(l *net.UDPConn) error { defer l.Close() @@ -483,10 +539,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here for { @@ -506,24 +558,28 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if len(m) < headerSize { continue } - go srv.serveUDPPacket(handler, m, l, s) + srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) } } -// Serve a new TCP connection. -func (srv *Server) serveTCPConn(h Handler, t net.Conn) { - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) - } - - w := &response{tsigSecret: srv.TsigSecret, tcp: t} +func (srv *Server) serve(w *response) { if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { w.writer = w } + if w.udp != nil { + // serve UDP + srv.serveDNS(w) + return + } + + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + defer func() { if !w.hijacked { w.Close() @@ -539,12 +595,13 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) { // TODO(miek): make maxTCPQueries configurable? for q := 0; q < maxTCPQueries; q++ { - m, err := reader.ReadTCP(t, timeout) + var err error + w.msg, err = reader.ReadTCP(w.tcp, timeout) if err != nil { // TODO(tmthrgd): handle error break } - srv.serveDNS(h, m, w) + srv.serveDNS(w) if w.tcp == nil { break // Close() was called } @@ -557,20 +614,9 @@ func (srv *Server) serveTCPConn(h Handler, t net.Conn) { } } -// Serve a new UDP request. -func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} - if srv.DecorateWriter != nil { - w.writer = srv.DecorateWriter(w) - } else { - w.writer = w - } - srv.serveDNS(h, m, w) -} - -func (srv *Server) serveDNS(h Handler, m []byte, w *response) { +func (srv *Server) serveDNS(w *response) { req := new(Msg) - err := req.Unpack(m) + err := req.Unpack(w.msg) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) @@ -585,7 +631,7 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) { if w.tsigSecret != nil { if t := req.IsTsig(); t != nil { if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(m, secret, "", false) + w.tsigStatus = TsigVerify(w.msg, secret, "", false) } else { w.tsigStatus = ErrSecret } @@ -593,7 +639,13 @@ func (srv *Server) serveDNS(h Handler, m []byte, w *response) { w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } - h.ServeDNS(w, req) // Writes back to the client + + handler := srv.Handler + if handler == nil { + handler = DefaultServeMux + } + + handler.ServeDNS(w, req) // Writes back to the client } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {