diff --git a/server.go b/server.go index 575c657d..c0f18037 100644 --- a/server.go +++ b/server.go @@ -320,6 +320,9 @@ type Server struct { // Shutdown handling lock sync.RWMutex started bool + + // A pool for UDP message buffers. + udpPool sync.Pool } func (srv *Server) isStarted() bool { @@ -374,6 +377,22 @@ func (srv *Server) spawnWorker(w *response) { } } +func makeUDPBuffer(size int) func() interface{} { + return func() interface{} { + return make([]byte, size) + } +} + +func (srv *Server) init() { + srv.queue = make(chan *response) + + if srv.UDPSize == 0 { + srv.UDPSize = MinMsgSize + } + + srv.udpPool.New = makeUDPBuffer(srv.UDPSize) +} + func unlockOnce(l sync.Locker) func() { var once sync.Once return func() { once.Do(l.Unlock) } @@ -393,11 +412,10 @@ func (srv *Server) ListenAndServe() error { if addr == "" { addr = ":domain" } - if srv.UDPSize == 0 { - srv.UDPSize = MinMsgSize - } - srv.queue = make(chan *response) + + srv.init() defer close(srv.queue) + switch srv.Net { case "tcp", "tcp4", "tcp6": a, err := net.ResolveTCPAddr(srv.Net, addr) @@ -459,14 +477,12 @@ func (srv *Server) ActivateAndServe() error { return &Error{err: "server already started"} } + srv.init() + defer close(srv.queue) + pConn := srv.PacketConn l := srv.Listener - srv.queue = make(chan *response) - defer close(srv.queue) if pConn != nil { - if srv.UDPSize == 0 { - srv.UDPSize = MinMsgSize - } // Check PacketConn interface's type is valid and value // is not nil if t, ok := pConn.(*net.UDPConn); ok && t != nil { @@ -565,6 +581,9 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { return err } if len(m) < headerSize { + if cap(m) == srv.UDPSize { + srv.udpPool.Put(m[:srv.UDPSize]) + } continue } srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) @@ -630,6 +649,10 @@ func (srv *Server) serve(w *response) { func (srv *Server) serveDNS(w *response) { req := new(Msg) err := req.Unpack(w.msg) + if w.udp != nil && cap(w.msg) == srv.UDPSize { + srv.udpPool.Put(w.msg[:srv.UDPSize]) + } + w.msg = nil if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) @@ -698,9 +721,10 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { conn.SetReadDeadline(time.Now().Add(timeout)) - m := make([]byte, srv.UDPSize) + m := srv.udpPool.Get().([]byte) n, s, err := ReadFromSessionUDP(conn, m) if err != nil { + srv.udpPool.Put(m) return nil, nil, err } m = m[:n]