From 5eca59c9e7a6c8c6088e96fba8d0809aacd9e7d8 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Fri, 18 Oct 2013 23:06:28 +0100 Subject: [PATCH] Correctly implement multiple queries over 1 tcp conn. Completely transparant give users another query to handle. --- server.go | 142 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 85 insertions(+), 57 deletions(-) diff --git a/server.go b/server.go index 0de38237..208df6d4 100644 --- a/server.go +++ b/server.go @@ -195,8 +195,8 @@ type Server struct { UDPSize int // default buffer size to use to read incoming UDP messages ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections + IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, defaults to 1 time.Minute TsigSecret map[string]string // secret(s) for Tsig map[] - IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, default to 1 time.Minute } // ListenAndServe starts a nameserver on the configured address in *Server. @@ -238,43 +238,20 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if handler == nil { handler = DefaultServeMux } -forever: + rtimeout := dnsTimeout + if srv.ReadTimeout != 0 { + rtimeout = srv.ReadTimeout + } for { rw, e := l.AcceptTCP() if e != nil { - // don't bail out, but wait for a new request continue } - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } - l := make([]byte, 2) - n, err := rw.Read(l) - if err != nil || n != 2 { + m, e := srv.readTCP(rw, rtimeout) + if e != nil { continue } - length, _ := unpackUint16(l, 0) - if length == 0 { - continue - } - m := make([]byte, int(length)) - n, err = rw.Read(m[:int(length)]) - if err != nil || n == 0 { - continue - } - i := n - for i < int(length) { - j, err := rw.Read(m[i:int(length)]) - if err != nil { - continue forever - } - i += j - } - n = i - go serve(rw.RemoteAddr(), handler, m, nil, rw, srv.TsigSecret) + go srv.serve(rw.RemoteAddr(), handler, m, nil, rw) } panic("dns: not reached") } @@ -290,62 +267,113 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } + rtimeout := dnsTimeout + if srv.ReadTimeout != 0 { + rtimeout = srv.ReadTimeout + } for { - if srv.ReadTimeout != 0 { - l.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - l.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } - m := make([]byte, srv.UDPSize) - n, a, e := l.ReadFromUDP(m) - if e != nil || n == 0 { - // don't bail out, but wait for a new request + m, a, e := srv.readUDP(l, rtimeout) + if e != nil { continue } - m = m[:n] - go serve(a, handler, m, l, nil, srv.TsigSecret) + go srv.serve(a, handler, m, l, nil) } panic("dns: not reached") } // Serve a new connection. -func serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn, tsigSecret map[string]string) { - // Request has been read in serveUDP or serveTCP - w := &response{tsigSecret: tsigSecret, udp: u, tcp: t, remoteAddr: a} +func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a} +Redo: req := new(Msg) if req.Unpack(m) != nil { // Send a format error back x := new(Msg) x.SetRcodeFormatError(req) w.WriteMsg(x) - w.Close() - return + goto Exit } - defer func() { - if w.hijacked { - // client takes care of the connection, i.e. calls Close() - return - } - w.Close() - }() w.tsigStatus = nil if w.tsigSecret != nil { if t := req.IsTsig(); t != nil { secret := t.Hdr.Name - if _, ok := tsigSecret[secret]; !ok { + if _, ok := w.tsigSecret[secret]; !ok { w.tsigStatus = ErrKeyAlg } - w.tsigStatus = TsigVerify(m, tsigSecret[secret], "", false) + w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } h.ServeDNS(w, req) // this does the writing back to the client + +Exit: + if w.hijacked { + return // client takes care of the connection, i.e. calls Close() + } + if u != nil { // UDP, "close" and return + w.Close() + return + } + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() + } + m, e := srv.readTCP(w.tcp, idleTimeout) + if e == nil { + goto Redo + } + w.Close() return } +func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { + conn.SetReadDeadline(time.Now().Add(timeout)) + l := make([]byte, 2) + n, err := conn.Read(l) + if err != nil || n != 2 { + if err != nil { + return nil, err + } + return nil, ErrConn + } + length, _ := unpackUint16(l, 0) + if length == 0 { + return nil, ErrConn + } + m := make([]byte, int(length)) + n, err = conn.Read(m[:int(length)]) + if err != nil || n == 0 { + if err != nil { + return nil, err + } + return nil, ErrConn + } + i := n + for i < int(length) { + j, err := conn.Read(m[i:int(length)]) + if err != nil { + return nil, err + } + i += j + } + n = i + m = m[:n] + return m, nil +} + +func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, net.Addr, error) { + conn.SetReadDeadline(time.Now().Add(timeout)) + m := make([]byte, srv.UDPSize) + n, a, e := conn.ReadFromUDP(m) + if e != nil || n == 0 { + return nil, nil, ErrConn + } + m = m[:n] + return m, a, nil +} + // WriteMsg implements the ResponseWriter.WriteMsg method. func (w *response) WriteMsg(m *Msg) (err error) { var data []byte