Reduce UDP server memory usage (#735)

* Clear the response.msg field after unpacking

The allocated buffer cannot be freed by the garbage collector while the
response is alive, by clearing msg here, the GC can collect the buffer
sooner.

* Use a sync.Pool for UDP message buffers

* Return UDP message buffer to pool in all paths

* Move udpPool.New closure out of (*Server).init

The closure used to capture the *Server which would cause a reference
loop and prevent it from ever being released by the garbage collector.

This also gives the closure a more obvious name in memory profiles:
  github.com/miekg/dns.makeUDPBuffer.func1 rather than
  github.com/miekg/dns.(*Server).init.func1.
This commit is contained in:
Tom Thorogood 2018-09-09 01:45:17 +09:30 committed by Miek Gieben
parent 3ce7efeace
commit bf6da3a5bd
1 changed files with 34 additions and 10 deletions

View File

@ -320,6 +320,9 @@ type Server struct {
// Shutdown handling // Shutdown handling
lock sync.RWMutex lock sync.RWMutex
started bool started bool
// A pool for UDP message buffers.
udpPool sync.Pool
} }
func (srv *Server) isStarted() bool { func (srv *Server) isStarted() bool {
@ -374,6 +377,22 @@ func (srv *Server) spawnWorker(w *response) {
} }
} }
func makeUDPBuffer(size int) func() interface{} {
return func() interface{} {
return make([]byte, size)
}
}
func (srv *Server) init() {
srv.queue = make(chan *response)
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
}
func unlockOnce(l sync.Locker) func() { func unlockOnce(l sync.Locker) func() {
var once sync.Once var once sync.Once
return func() { once.Do(l.Unlock) } return func() { once.Do(l.Unlock) }
@ -393,11 +412,10 @@ func (srv *Server) ListenAndServe() error {
if addr == "" { if addr == "" {
addr = ":domain" addr = ":domain"
} }
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.init()
}
srv.queue = make(chan *response)
defer close(srv.queue) defer close(srv.queue)
switch srv.Net { switch srv.Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr) a, err := net.ResolveTCPAddr(srv.Net, addr)
@ -459,14 +477,12 @@ func (srv *Server) ActivateAndServe() error {
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
srv.init()
defer close(srv.queue)
pConn := srv.PacketConn pConn := srv.PacketConn
l := srv.Listener l := srv.Listener
srv.queue = make(chan *response)
defer close(srv.queue)
if pConn != nil { if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
// Check PacketConn interface's type is valid and value // Check PacketConn interface's type is valid and value
// is not nil // is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil { if t, ok := pConn.(*net.UDPConn); ok && t != nil {
@ -565,6 +581,9 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
return err return err
} }
if len(m) < headerSize { if len(m) < headerSize {
if cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
continue continue
} }
srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
@ -630,6 +649,10 @@ func (srv *Server) serve(w *response) {
func (srv *Server) serveDNS(w *response) { func (srv *Server) serveDNS(w *response) {
req := new(Msg) req := new(Msg)
err := req.Unpack(w.msg) err := req.Unpack(w.msg)
if w.udp != nil && cap(w.msg) == srv.UDPSize {
srv.udpPool.Put(w.msg[:srv.UDPSize])
}
w.msg = nil
if err != nil { // Send a FormatError back if err != nil { // Send a FormatError back
x := new(Msg) x := new(Msg)
x.SetRcodeFormatError(req) x.SetRcodeFormatError(req)
@ -698,9 +721,10 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
conn.SetReadDeadline(time.Now().Add(timeout)) conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize) m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m) n, s, err := ReadFromSessionUDP(conn, m)
if err != nil { if err != nil {
srv.udpPool.Put(m)
return nil, nil, err return nil, nil, err
} }
m = m[:n] m = m[:n]