Correctly implement multiple queries over 1 tcp conn.

Completely transparant give users another query to handle.
This commit is contained in:
Miek Gieben 2013-10-18 23:06:28 +01:00
parent ed0b128bd2
commit 5eca59c9e7
1 changed files with 85 additions and 57 deletions

142
server.go
View File

@ -195,8 +195,8 @@ type Server struct {
UDPSize int // default buffer size to use to read incoming UDP messages
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, defaults to 1 time.Minute
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, default to 1 time.Minute
}
// ListenAndServe starts a nameserver on the configured address in *Server.
@ -238,43 +238,20 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
if handler == nil {
handler = DefaultServeMux
}
forever:
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
}
for {
rw, e := l.AcceptTCP()
if e != nil {
// don't bail out, but wait for a new request
continue
}
if srv.ReadTimeout != 0 {
rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
}
if srv.WriteTimeout != 0 {
rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
}
l := make([]byte, 2)
n, err := rw.Read(l)
if err != nil || n != 2 {
m, e := srv.readTCP(rw, rtimeout)
if e != nil {
continue
}
length, _ := unpackUint16(l, 0)
if length == 0 {
continue
}
m := make([]byte, int(length))
n, err = rw.Read(m[:int(length)])
if err != nil || n == 0 {
continue
}
i := n
for i < int(length) {
j, err := rw.Read(m[i:int(length)])
if err != nil {
continue forever
}
i += j
}
n = i
go serve(rw.RemoteAddr(), handler, m, nil, rw, srv.TsigSecret)
go srv.serve(rw.RemoteAddr(), handler, m, nil, rw)
}
panic("dns: not reached")
}
@ -290,62 +267,113 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
}
for {
if srv.ReadTimeout != 0 {
l.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
}
if srv.WriteTimeout != 0 {
l.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
}
m := make([]byte, srv.UDPSize)
n, a, e := l.ReadFromUDP(m)
if e != nil || n == 0 {
// don't bail out, but wait for a new request
m, a, e := srv.readUDP(l, rtimeout)
if e != nil {
continue
}
m = m[:n]
go serve(a, handler, m, l, nil, srv.TsigSecret)
go srv.serve(a, handler, m, l, nil)
}
panic("dns: not reached")
}
// Serve a new connection.
func serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn, tsigSecret map[string]string) {
// Request has been read in serveUDP or serveTCP
w := &response{tsigSecret: tsigSecret, udp: u, tcp: t, remoteAddr: a}
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a}
Redo:
req := new(Msg)
if req.Unpack(m) != nil {
// Send a format error back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
w.Close()
return
goto Exit
}
defer func() {
if w.hijacked {
// client takes care of the connection, i.e. calls Close()
return
}
w.Close()
}()
w.tsigStatus = nil
if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name
if _, ok := tsigSecret[secret]; !ok {
if _, ok := w.tsigSecret[secret]; !ok {
w.tsigStatus = ErrKeyAlg
}
w.tsigStatus = TsigVerify(m, tsigSecret[secret], "", false)
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
h.ServeDNS(w, req) // this does the writing back to the client
Exit:
if w.hijacked {
return // client takes care of the connection, i.e. calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, e := srv.readTCP(w.tcp, idleTimeout)
if e == nil {
goto Redo
}
w.Close()
return
}
func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
l := make([]byte, 2)
n, err := conn.Read(l)
if err != nil || n != 2 {
if err != nil {
return nil, err
}
return nil, ErrConn
}
length, _ := unpackUint16(l, 0)
if length == 0 {
return nil, ErrConn
}
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrConn
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
}
n = i
m = m[:n]
return m, nil
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, net.Addr, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize)
n, a, e := conn.ReadFromUDP(m)
if e != nil || n == 0 {
return nil, nil, ErrConn
}
m = m[:n]
return m, a, nil
}
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
var data []byte