Cleanup serve function (#653)

* Split central ServeDNS code out of (*Server).serve

* Add UDP and TCP specific (*Server).serve wrappers

* Move UDP serve functionality into serveUDPPacket

* Merge serve into serveTCPConn

* Cleanup serveTCPConn replacing goto with for

* defer Close in serveTCPConn

* Remove remoteAddr field from response struct

* Fix broken tsigSecret check in serveDNS

* Reorder serveDNS arguments

This makes it consistent with the ordering of arguments to
serveUDPPacket and serveTCPConn.
This commit is contained in:
Tom Thorogood 2018-03-31 00:20:27 +10:30 committed by Miek Gieben
parent 22cb769f47
commit d174bbf0a5
1 changed files with 65 additions and 60 deletions

125
server.go
View File

@ -51,7 +51,6 @@ type response struct {
udp *net.UDPConn // i/o connection if UDP was used udp *net.UDPConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right udpSession *SessionUDP // oob data to get egress interface right
remoteAddr net.Addr // address of the client
writer Writer // writer to output the raw DNS bits writer Writer // writer to output the raw DNS bits
} }
@ -447,16 +446,10 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc() srv.NotifyStartedFunc()
} }
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
handler := srv.Handler handler := srv.Handler
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
} }
rtimeout := srv.getReadTimeout()
// deadline is not used here // deadline is not used here
for { for {
rw, err := l.Accept() rw, err := l.Accept()
@ -472,14 +465,7 @@ func (srv *Server) serveTCP(l net.Listener) error {
} }
return err return err
} }
go func() { go srv.serveTCPConn(handler, rw)
m, err := reader.ReadTCP(rw, rtimeout)
if err != nil {
rw.Close()
return
}
srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
}()
} }
} }
@ -520,80 +506,94 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if len(m) < headerSize { if len(m) < headerSize {
continue continue
} }
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) go srv.serveUDPPacket(handler, m, l, s)
} }
} }
// Serve a new connection. // Serve a new TCP connection.
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { func (srv *Server) serveTCPConn(h Handler, t net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
w := &response{tsigSecret: srv.TsigSecret, tcp: t}
if srv.DecorateWriter != nil { if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w) w.writer = srv.DecorateWriter(w)
} else { } else {
w.writer = w w.writer = w
} }
q := 0 // counter for the amount of TCP queries we get defer func() {
if !w.hijacked {
w.Close()
}
}()
reader := Reader(&defaultReader{srv}) idleTimeout := tcpIdleTimeout
if srv.DecorateReader != nil { if srv.IdleTimeout != nil {
reader = srv.DecorateReader(reader) idleTimeout = srv.IdleTimeout()
} }
Redo:
timeout := srv.getReadTimeout()
// TODO(miek): make maxTCPQueries configurable?
for q := 0; q < maxTCPQueries; q++ {
m, err := reader.ReadTCP(t, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(h, m, w)
if w.tcp == nil {
break // Close() was called
}
if w.hijacked {
break // client will call Close() themselves
}
// The first read uses the read timeout, the rest use the
// idle timeout.
timeout = idleTimeout
}
}
// Serve a new UDP request.
func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
srv.serveDNS(h, m, w)
}
func (srv *Server) serveDNS(h Handler, m []byte, w *response) {
req := new(Msg) req := new(Msg)
err := req.Unpack(m) err := req.Unpack(m)
if err != nil { // Send a FormatError back if err != nil { // Send a FormatError back
x := new(Msg) x := new(Msg)
x.SetRcodeFormatError(req) x.SetRcodeFormatError(req)
w.WriteMsg(x) w.WriteMsg(x)
goto Exit return
} }
if !srv.Unsafe && req.Response { if !srv.Unsafe && req.Response {
goto Exit return
} }
w.tsigStatus = nil w.tsigStatus = nil
if w.tsigSecret != nil { if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil { if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
if _, ok := w.tsigSecret[secret]; !ok { w.tsigStatus = TsigVerify(m, secret, "", false)
w.tsigStatus = ErrKeyAlg } else {
w.tsigStatus = ErrSecret
} }
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
} }
} }
h.ServeDNS(w, req) // Writes back to the client h.ServeDNS(w, req) // Writes back to the client
Exit:
if w.tcp == nil {
return
}
// TODO(miek): make this number configurable?
if q > maxTCPQueries { // close socket after this many queries
w.Close()
return
}
if w.hijacked {
return // client calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
}
w.Close()
return
} }
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
@ -696,7 +696,12 @@ func (w *response) LocalAddr() net.Addr {
} }
// RemoteAddr implements the ResponseWriter.RemoteAddr method. // RemoteAddr implements the ResponseWriter.RemoteAddr method.
func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } func (w *response) RemoteAddr() net.Addr {
if w.tcp != nil {
return w.tcp.RemoteAddr()
}
return w.udpSession.RemoteAddr()
}
// TsigStatus implements the ResponseWriter.TsigStatus method. // TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus } func (w *response) TsigStatus() error { return w.tsigStatus }