Correctly implement multiple queries over 1 tcp conn.

Completely transparant give users another query to handle.
This commit is contained in:
Miek Gieben 2013-10-18 23:06:28 +01:00
parent ed0b128bd2
commit 5eca59c9e7
1 changed files with 85 additions and 57 deletions

142
server.go
View File

@ -195,8 +195,8 @@ type Server struct {
UDPSize int // default buffer size to use to read incoming UDP messages UDPSize int // default buffer size to use to read incoming UDP messages
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, defaults to 1 time.Minute
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret> TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, default to 1 time.Minute
} }
// ListenAndServe starts a nameserver on the configured address in *Server. // ListenAndServe starts a nameserver on the configured address in *Server.
@ -238,43 +238,20 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
} }
forever: rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
}
for { for {
rw, e := l.AcceptTCP() rw, e := l.AcceptTCP()
if e != nil { if e != nil {
// don't bail out, but wait for a new request
continue continue
} }
if srv.ReadTimeout != 0 { m, e := srv.readTCP(rw, rtimeout)
rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) if e != nil {
}
if srv.WriteTimeout != 0 {
rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
}
l := make([]byte, 2)
n, err := rw.Read(l)
if err != nil || n != 2 {
continue continue
} }
length, _ := unpackUint16(l, 0) go srv.serve(rw.RemoteAddr(), handler, m, nil, rw)
if length == 0 {
continue
}
m := make([]byte, int(length))
n, err = rw.Read(m[:int(length)])
if err != nil || n == 0 {
continue
}
i := n
for i < int(length) {
j, err := rw.Read(m[i:int(length)])
if err != nil {
continue forever
}
i += j
}
n = i
go serve(rw.RemoteAddr(), handler, m, nil, rw, srv.TsigSecret)
} }
panic("dns: not reached") panic("dns: not reached")
} }
@ -290,62 +267,113 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if srv.UDPSize == 0 { if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.UDPSize = MinMsgSize
} }
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
}
for { for {
if srv.ReadTimeout != 0 { m, a, e := srv.readUDP(l, rtimeout)
l.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) if e != nil {
}
if srv.WriteTimeout != 0 {
l.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
}
m := make([]byte, srv.UDPSize)
n, a, e := l.ReadFromUDP(m)
if e != nil || n == 0 {
// don't bail out, but wait for a new request
continue continue
} }
m = m[:n] go srv.serve(a, handler, m, l, nil)
go serve(a, handler, m, l, nil, srv.TsigSecret)
} }
panic("dns: not reached") panic("dns: not reached")
} }
// Serve a new connection. // Serve a new connection.
func serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn, tsigSecret map[string]string) { func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn) {
// Request has been read in serveUDP or serveTCP w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a}
w := &response{tsigSecret: tsigSecret, udp: u, tcp: t, remoteAddr: a} Redo:
req := new(Msg) req := new(Msg)
if req.Unpack(m) != nil { if req.Unpack(m) != nil {
// Send a format error back // Send a format error back
x := new(Msg) x := new(Msg)
x.SetRcodeFormatError(req) x.SetRcodeFormatError(req)
w.WriteMsg(x) w.WriteMsg(x)
w.Close() goto Exit
return
} }
defer func() {
if w.hijacked {
// client takes care of the connection, i.e. calls Close()
return
}
w.Close()
}()
w.tsigStatus = nil w.tsigStatus = nil
if w.tsigSecret != nil { if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil { if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name secret := t.Hdr.Name
if _, ok := tsigSecret[secret]; !ok { if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg w.tsigStatus = ErrKeyAlg
} }
w.tsigStatus = TsigVerify(m, tsigSecret[secret], "", false) w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
} }
} }
h.ServeDNS(w, req) // this does the writing back to the client h.ServeDNS(w, req) // this does the writing back to the client
Exit:
if w.hijacked {
return // client takes care of the connection, i.e. calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, e := srv.readTCP(w.tcp, idleTimeout)
if e == nil {
goto Redo
}
w.Close()
return return
} }
func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
l := make([]byte, 2)
n, err := conn.Read(l)
if err != nil || n != 2 {
if err != nil {
return nil, err
}
return nil, ErrConn
}
length, _ := unpackUint16(l, 0)
if length == 0 {
return nil, ErrConn
}
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrConn
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
}
n = i
m = m[:n]
return m, nil
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, net.Addr, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize)
n, a, e := conn.ReadFromUDP(m)
if e != nil || n == 0 {
return nil, nil, ErrConn
}
m = m[:n]
return m, a, nil
}
// WriteMsg implements the ResponseWriter.WriteMsg method. // WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) { func (w *response) WriteMsg(m *Msg) (err error) {
var data []byte var data []byte