Correctly implement multiple queries over 1 tcp conn.
Completely transparant give users another query to handle.
This commit is contained in:
parent
ed0b128bd2
commit
5eca59c9e7
142
server.go
142
server.go
|
@ -195,8 +195,8 @@ type Server struct {
|
|||
UDPSize int // default buffer size to use to read incoming UDP messages
|
||||
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections
|
||||
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections
|
||||
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, defaults to 1 time.Minute
|
||||
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
|
||||
IdleTimeout func() time.Duration // TCP idle timeout, see RFC 5966, if nil, default to 1 time.Minute
|
||||
}
|
||||
|
||||
// ListenAndServe starts a nameserver on the configured address in *Server.
|
||||
|
@ -238,43 +238,20 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
forever:
|
||||
rtimeout := dnsTimeout
|
||||
if srv.ReadTimeout != 0 {
|
||||
rtimeout = srv.ReadTimeout
|
||||
}
|
||||
for {
|
||||
rw, e := l.AcceptTCP()
|
||||
if e != nil {
|
||||
// don't bail out, but wait for a new request
|
||||
continue
|
||||
}
|
||||
if srv.ReadTimeout != 0 {
|
||||
rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
|
||||
}
|
||||
l := make([]byte, 2)
|
||||
n, err := rw.Read(l)
|
||||
if err != nil || n != 2 {
|
||||
m, e := srv.readTCP(rw, rtimeout)
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
length, _ := unpackUint16(l, 0)
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
m := make([]byte, int(length))
|
||||
n, err = rw.Read(m[:int(length)])
|
||||
if err != nil || n == 0 {
|
||||
continue
|
||||
}
|
||||
i := n
|
||||
for i < int(length) {
|
||||
j, err := rw.Read(m[i:int(length)])
|
||||
if err != nil {
|
||||
continue forever
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
go serve(rw.RemoteAddr(), handler, m, nil, rw, srv.TsigSecret)
|
||||
go srv.serve(rw.RemoteAddr(), handler, m, nil, rw)
|
||||
}
|
||||
panic("dns: not reached")
|
||||
}
|
||||
|
@ -290,62 +267,113 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|||
if srv.UDPSize == 0 {
|
||||
srv.UDPSize = MinMsgSize
|
||||
}
|
||||
rtimeout := dnsTimeout
|
||||
if srv.ReadTimeout != 0 {
|
||||
rtimeout = srv.ReadTimeout
|
||||
}
|
||||
for {
|
||||
if srv.ReadTimeout != 0 {
|
||||
l.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
l.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
|
||||
}
|
||||
m := make([]byte, srv.UDPSize)
|
||||
n, a, e := l.ReadFromUDP(m)
|
||||
if e != nil || n == 0 {
|
||||
// don't bail out, but wait for a new request
|
||||
m, a, e := srv.readUDP(l, rtimeout)
|
||||
if e != nil {
|
||||
continue
|
||||
}
|
||||
m = m[:n]
|
||||
go serve(a, handler, m, l, nil, srv.TsigSecret)
|
||||
go srv.serve(a, handler, m, l, nil)
|
||||
}
|
||||
panic("dns: not reached")
|
||||
}
|
||||
|
||||
// Serve a new connection.
|
||||
func serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn, tsigSecret map[string]string) {
|
||||
// Request has been read in serveUDP or serveTCP
|
||||
w := &response{tsigSecret: tsigSecret, udp: u, tcp: t, remoteAddr: a}
|
||||
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn) {
|
||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a}
|
||||
Redo:
|
||||
req := new(Msg)
|
||||
if req.Unpack(m) != nil {
|
||||
// Send a format error back
|
||||
x := new(Msg)
|
||||
x.SetRcodeFormatError(req)
|
||||
w.WriteMsg(x)
|
||||
w.Close()
|
||||
return
|
||||
goto Exit
|
||||
}
|
||||
defer func() {
|
||||
if w.hijacked {
|
||||
// client takes care of the connection, i.e. calls Close()
|
||||
return
|
||||
}
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
w.tsigStatus = nil
|
||||
if w.tsigSecret != nil {
|
||||
if t := req.IsTsig(); t != nil {
|
||||
secret := t.Hdr.Name
|
||||
if _, ok := tsigSecret[secret]; !ok {
|
||||
if _, ok := w.tsigSecret[secret]; !ok {
|
||||
w.tsigStatus = ErrKeyAlg
|
||||
}
|
||||
w.tsigStatus = TsigVerify(m, tsigSecret[secret], "", false)
|
||||
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
|
||||
w.tsigTimersOnly = false
|
||||
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
|
||||
}
|
||||
}
|
||||
h.ServeDNS(w, req) // this does the writing back to the client
|
||||
|
||||
Exit:
|
||||
if w.hijacked {
|
||||
return // client takes care of the connection, i.e. calls Close()
|
||||
}
|
||||
if u != nil { // UDP, "close" and return
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
idleTimeout := tcpIdleTimeout
|
||||
if srv.IdleTimeout != nil {
|
||||
idleTimeout = srv.IdleTimeout()
|
||||
}
|
||||
m, e := srv.readTCP(w.tcp, idleTimeout)
|
||||
if e == nil {
|
||||
goto Redo
|
||||
}
|
||||
w.Close()
|
||||
return
|
||||
}
|
||||
|
||||
func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
l := make([]byte, 2)
|
||||
n, err := conn.Read(l)
|
||||
if err != nil || n != 2 {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, ErrConn
|
||||
}
|
||||
length, _ := unpackUint16(l, 0)
|
||||
if length == 0 {
|
||||
return nil, ErrConn
|
||||
}
|
||||
m := make([]byte, int(length))
|
||||
n, err = conn.Read(m[:int(length)])
|
||||
if err != nil || n == 0 {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, ErrConn
|
||||
}
|
||||
i := n
|
||||
for i < int(length) {
|
||||
j, err := conn.Read(m[i:int(length)])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
m = m[:n]
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, net.Addr, error) {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
m := make([]byte, srv.UDPSize)
|
||||
n, a, e := conn.ReadFromUDP(m)
|
||||
if e != nil || n == 0 {
|
||||
return nil, nil, ErrConn
|
||||
}
|
||||
m = m[:n]
|
||||
return m, a, nil
|
||||
}
|
||||
|
||||
// WriteMsg implements the ResponseWriter.WriteMsg method.
|
||||
func (w *response) WriteMsg(m *Msg) (err error) {
|
||||
var data []byte
|
||||
|
|
Loading…
Reference in New Issue