From 7fdfb0141b47f4a23be03aa1a1cb3b80c7c3971b Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Sun, 1 Apr 2018 12:27:36 +0100 Subject: [PATCH] Revert "Cleanup serve function (#653)" This reverts commit d174bbf0a57b4ab555db36b0e55f692d5e8dfca8. --- server.go | 133 ++++++++++++++++++++++++++---------------------------- 1 file changed, 64 insertions(+), 69 deletions(-) diff --git a/server.go b/server.go index 20ef6b10..685753f4 100644 --- a/server.go +++ b/server.go @@ -51,6 +51,7 @@ type response struct { udp *net.UDPConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right + remoteAddr net.Addr // address of the client writer Writer // writer to output the raw DNS bits } @@ -446,10 +447,16 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + handler := srv.Handler if handler == nil { handler = DefaultServeMux } + rtimeout := srv.getReadTimeout() // deadline is not used here for { rw, err := l.Accept() @@ -465,7 +472,14 @@ func (srv *Server) serveTCP(l net.Listener) error { } return err } - go srv.serveTCPConn(handler, rw) + go func() { + m, err := reader.ReadTCP(rw, rtimeout) + if err != nil { + rw.Close() + return + } + srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) + }() } } @@ -506,94 +520,80 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if len(m) < headerSize { continue } - go srv.serveUDPPacket(handler, m, l, s) + go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) } } -// Serve a new TCP connection. -func (srv *Server) serveTCPConn(h Handler, t net.Conn) { +// Serve a new connection. +func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} + if srv.DecorateWriter != nil { + w.writer = srv.DecorateWriter(w) + } else { + w.writer = w + } + + q := 0 // counter for the amount of TCP queries we get + reader := Reader(&defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } - - w := &response{tsigSecret: srv.TsigSecret, tcp: t} - if srv.DecorateWriter != nil { - w.writer = srv.DecorateWriter(w) - } else { - w.writer = w - } - - defer func() { - if !w.hijacked { - w.Close() - } - }() - - idleTimeout := tcpIdleTimeout - if srv.IdleTimeout != nil { - idleTimeout = srv.IdleTimeout() - } - - timeout := srv.getReadTimeout() - - // TODO(miek): make maxTCPQueries configurable? - for q := 0; q < maxTCPQueries; q++ { - m, err := reader.ReadTCP(t, timeout) - if err != nil { - // TODO(tmthrgd): handle error - break - } - srv.serveDNS(h, m, w) - if w.tcp == nil { - break // Close() was called - } - if w.hijacked { - break // client will call Close() themselves - } - // The first read uses the read timeout, the rest use the - // idle timeout. - timeout = idleTimeout - } -} - -// Serve a new UDP request. -func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} - if srv.DecorateWriter != nil { - w.writer = srv.DecorateWriter(w) - } else { - w.writer = w - } - srv.serveDNS(h, m, w) -} - -func (srv *Server) serveDNS(h Handler, m []byte, w *response) { +Redo: req := new(Msg) err := req.Unpack(m) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) w.WriteMsg(x) - return + goto Exit } if !srv.Unsafe && req.Response { - return + goto Exit } w.tsigStatus = nil if w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(m, secret, "", false) - } else { - w.tsigStatus = ErrSecret + secret := t.Hdr.Name + if _, ok := w.tsigSecret[secret]; !ok { + w.tsigStatus = ErrKeyAlg } + w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } h.ServeDNS(w, req) // Writes back to the client + +Exit: + if w.tcp == nil { + return + } + // TODO(miek): make this number configurable? + if q > maxTCPQueries { // close socket after this many queries + w.Close() + return + } + + if w.hijacked { + return // client calls Close() + } + if u != nil { // UDP, "close" and return + w.Close() + return + } + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() + } + m, err = reader.ReadTCP(w.tcp, idleTimeout) + if err == nil { + q++ + goto Redo + } + w.Close() + return } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { @@ -696,12 +696,7 @@ func (w *response) LocalAddr() net.Addr { } // RemoteAddr implements the ResponseWriter.RemoteAddr method. -func (w *response) RemoteAddr() net.Addr { - if w.tcp != nil { - return w.tcp.RemoteAddr() - } - return w.udpSession.RemoteAddr() -} +func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } // TsigStatus implements the ResponseWriter.TsigStatus method. func (w *response) TsigStatus() error { return w.tsigStatus }