Split Server.serve into separate TCP and UDP methods (#933)

* Split Server.serve into separate TCP and UDP methods

* Merge reject cases in Server.serveDNS

* Inline Server.disposeBuffer method
This commit is contained in:
Tom Thorogood 2019-03-10 23:22:08 +10:30 committed by Miek Gieben
parent 53b8a87e14
commit 1a5555c783
1 changed files with 35 additions and 33 deletions

View File

@ -67,7 +67,6 @@ type ConnectionStater interface {
} }
type response struct { type response struct {
msg []byte
closed bool // connection has been closed closed bool // connection has been closed
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
tsigTimersOnly bool tsigTimersOnly bool
@ -433,7 +432,7 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.conns[rw] = struct{}{} srv.conns[rw] = struct{}{}
srv.lock.Unlock() srv.lock.Unlock()
wg.Add(1) wg.Add(1)
go srv.serve(&response{tsigSecret: srv.TsigSecret, tcp: rw}, &wg) go srv.serveTCPConn(&wg, rw)
} }
return nil return nil
@ -478,26 +477,21 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
continue continue
} }
wg.Add(1) wg.Add(1)
go srv.serve(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}, &wg) go srv.serveUDPPacket(&wg, m, l, s)
} }
return nil return nil
} }
func (srv *Server) serve(w *response, wg *sync.WaitGroup) { // Serve a new TCP connection.
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
if srv.DecorateWriter != nil { if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w) w.writer = srv.DecorateWriter(w)
} else { } else {
w.writer = w w.writer = w
} }
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
wg.Done()
return
}
reader := Reader(defaultReader{srv}) reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil { if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
@ -516,13 +510,12 @@ func (srv *Server) serve(w *response, wg *sync.WaitGroup) {
} }
for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ { for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
var err error m, err := reader.ReadTCP(w.tcp, timeout)
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil { if err != nil {
// TODO(tmthrgd): handle error // TODO(tmthrgd): handle error
break break
} }
srv.serveDNS(w) srv.serveDNS(m, w)
if w.closed { if w.closed {
break // Close() was called break // Close() was called
} }
@ -545,15 +538,21 @@ func (srv *Server) serve(w *response, wg *sync.WaitGroup) {
wg.Done() wg.Done()
} }
func (srv *Server) disposeBuffer(w *response) { // Serve a new UDP request.
if w.udp != nil && cap(w.msg) == srv.UDPSize { func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
srv.udpPool.Put(w.msg[:srv.UDPSize]) w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
} }
w.msg = nil
srv.serveDNS(m, w)
wg.Done()
} }
func (srv *Server) serveDNS(w *response) { func (srv *Server) serveDNS(m []byte, w *response) {
dh, off, err := unpackMsgHdr(w.msg, 0) dh, off, err := unpackMsgHdr(m, 0)
if err != nil { if err != nil {
// Let client hang, they are sending crap; any reply can be used to amplify. // Let client hang, they are sending crap; any reply can be used to amplify.
return return
@ -564,24 +563,24 @@ func (srv *Server) serveDNS(w *response) {
switch srv.MsgAcceptFunc(dh) { switch srv.MsgAcceptFunc(dh) {
case MsgAccept: case MsgAccept:
case MsgIgnore: if req.unpack(dh, m, off) == nil {
return break
}
fallthrough
case MsgReject: case MsgReject:
req.SetRcodeFormatError(req) req.SetRcodeFormatError(req)
// Are we allowed to delete any OPT records here? // Are we allowed to delete any OPT records here?
req.Ns, req.Answer, req.Extra = nil, nil, nil req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req) w.WriteMsg(req)
srv.disposeBuffer(w)
if w.udp != nil && cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
return return
} case MsgIgnore:
if err := req.unpack(dh, w.msg, off); err != nil {
req.SetRcodeFormatError(req)
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
srv.disposeBuffer(w)
return return
} }
@ -589,7 +588,7 @@ func (srv *Server) serveDNS(w *response) {
if w.tsigSecret != nil { if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil { if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(w.msg, secret, "", false) w.tsigStatus = TsigVerify(m, secret, "", false)
} else { } else {
w.tsigStatus = ErrSecret w.tsigStatus = ErrSecret
} }
@ -598,7 +597,10 @@ func (srv *Server) serveDNS(w *response) {
} }
} }
srv.disposeBuffer(w) if w.udp != nil && cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
srv.Handler.ServeDNS(w, req) // Writes back to the client srv.Handler.ServeDNS(w, req) // Writes back to the client
} }