diff --git a/server.go b/server.go index 685753f4..20ef6b10 100644 --- a/server.go +++ b/server.go @@ -51,7 +51,6 @@ type response struct { udp *net.UDPConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right - remoteAddr net.Addr // address of the client writer Writer // writer to output the raw DNS bits } @@ -447,16 +446,10 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) - } - handler := srv.Handler if handler == nil { handler = DefaultServeMux } - rtimeout := srv.getReadTimeout() // deadline is not used here for { rw, err := l.Accept() @@ -472,14 +465,7 @@ func (srv *Server) serveTCP(l net.Listener) error { } return err } - go func() { - m, err := reader.ReadTCP(rw, rtimeout) - if err != nil { - rw.Close() - return - } - srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) - }() + go srv.serveTCPConn(handler, rw) } } @@ -520,80 +506,94 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if len(m) < headerSize { continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + go srv.serveUDPPacket(handler, m, l, s) } } -// Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +// Serve a new TCP connection. +func (srv *Server) serveTCPConn(h Handler, t net.Conn) { + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + + w := &response{tsigSecret: srv.TsigSecret, tcp: t} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { w.writer = w } - q := 0 // counter for the amount of TCP queries we get + defer func() { + if !w.hijacked { + w.Close() + } + }() - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() } -Redo: + + timeout := srv.getReadTimeout() + + // TODO(miek): make maxTCPQueries configurable? + for q := 0; q < maxTCPQueries; q++ { + m, err := reader.ReadTCP(t, timeout) + if err != nil { + // TODO(tmthrgd): handle error + break + } + srv.serveDNS(h, m, w) + if w.tcp == nil { + break // Close() was called + } + if w.hijacked { + break // client will call Close() themselves + } + // The first read uses the read timeout, the rest use the + // idle timeout. + timeout = idleTimeout + } +} + +// Serve a new UDP request. +func (srv *Server) serveUDPPacket(h Handler, m []byte, u *net.UDPConn, s *SessionUDP) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} + if srv.DecorateWriter != nil { + w.writer = srv.DecorateWriter(w) + } else { + w.writer = w + } + srv.serveDNS(h, m, w) +} + +func (srv *Server) serveDNS(h Handler, m []byte, w *response) { req := new(Msg) err := req.Unpack(m) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) w.WriteMsg(x) - goto Exit + return } if !srv.Unsafe && req.Response { - goto Exit + return } w.tsigStatus = nil if w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.tsigSecret[secret]; !ok { - w.tsigStatus = ErrKeyAlg + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(m, secret, "", false) + } else { + w.tsigStatus = ErrSecret } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } h.ServeDNS(w, req) // Writes back to the client - -Exit: - if w.tcp == nil { - return - } - // TODO(miek): make this number configurable? - if q > maxTCPQueries { // close socket after this many queries - w.Close() - return - } - - if w.hijacked { - return // client calls Close() - } - if u != nil { // UDP, "close" and return - w.Close() - return - } - idleTimeout := tcpIdleTimeout - if srv.IdleTimeout != nil { - idleTimeout = srv.IdleTimeout() - } - m, err = reader.ReadTCP(w.tcp, idleTimeout) - if err == nil { - q++ - goto Redo - } - w.Close() - return } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { @@ -696,7 +696,12 @@ func (w *response) LocalAddr() net.Addr { } // RemoteAddr implements the ResponseWriter.RemoteAddr method. -func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } +func (w *response) RemoteAddr() net.Addr { + if w.tcp != nil { + return w.tcp.RemoteAddr() + } + return w.udpSession.RemoteAddr() +} // TsigStatus implements the ResponseWriter.TsigStatus method. func (w *response) TsigStatus() error { return w.tsigStatus }