Add MsgAcceptFunc in server

Generalize the srv.Unsafe and make it pluggeable. Also add a default
accept function that allows to discard malformed DNS messages very early
on. Before we allocate and parse anything furher.

Also re-use the client's message when sending a reply.

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2018-11-27 10:43:01 +00:00
parent 6bf402f3c4
commit 2c18e7259a
4 changed files with 108 additions and 48 deletions

46
acceptfunc.go Normal file
View File

@ -0,0 +1,46 @@
package dns
// MsgAcceptFunc is used early in the server code to accept or reject a message with RcodeFormatError.
// There are to booleans to be returned, once signaling the rejection and another to signal if
// a reply is to be send back (you want to prevent DNS ping-pong and not reply to a response for instance).
type MsgAcceptFunc func(dh Header) (accept bool, respond bool)
// DefaultMsgAcceptFunc checks the request and will reject if:
//
// * isn't a request (don't respond in that case).
// * opcode isn't OpcodeQuery or OpcodeNotify
// * Zero bit isn't zero
// * has more than 1 question in the question section
// * has more than 0 RRs in the Answer section
// * has more than 0 RRs in the Authority section
// * has more than 2 RRs in the Additional section
var DefaultMsgAcceptFunc = defaultMsgAcceptFunc
var defaultMsgAcceptFunc = func(dh Header) (bool, bool) {
if isResponse := dh.Bits&_QR != 0; isResponse {
return false, false
}
// Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs.
opcode := int(dh.Bits>>11) & 0xF
if opcode != OpcodeQuery && opcode != OpcodeNotify {
return false, true
}
if isZero := dh.Bits&_Z != 0; isZero {
return false, true
}
if dh.Qdcount != 1 {
return false, true
}
if dh.Ancount != 0 {
return false, true
}
if dh.Nscount != 0 {
return false, true
}
if dh.Arcount > 2 {
return false, true
}
return true, true
}

52
msg.go
View File

@ -778,30 +778,8 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]
return msg[:off], nil
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
// We use a similar function in tsig.go's stripTsig.
var (
dh Header
off int
)
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
return err
}
dns.Id = dh.Id
dns.Response = dh.Bits&_QR != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = dh.Bits&_AA != 0
dns.Truncated = dh.Bits&_TC != 0
dns.RecursionDesired = dh.Bits&_RD != 0
dns.RecursionAvailable = dh.Bits&_RA != 0
dns.Zero = dh.Bits&_Z != 0
dns.AuthenticatedData = dh.Bits&_AD != 0
dns.CheckingDisabled = dh.Bits&_CD != 0
dns.Rcode = int(dh.Bits & 0xF)
func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
dns.setHdr(dh)
// If we are at the end of the message we should return *just* the
// header. This can still be useful to the caller. 9.9.9.9 sends these
// when responding with REFUSED for instance.
@ -854,6 +832,17 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
// println("dns: extra bytes in dns packet", off, "<", len(msg))
}
return err
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
dh, off, err := unpackMsgHdr(msg, 0)
if err != nil {
return err
}
return dns.unpack(dh, msg, off)
}
// Convert a complete message to a string with dig-like output.
@ -1196,3 +1185,18 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
dh.Arcount, off, err = unpackUint16(msg, off)
return dh, off, err
}
// setHdr set the header in the dns using the binary data in dh.
func (dns *Msg) setHdr(dh Header) {
dns.Id = dh.Id
dns.Response = dh.Bits&_QR != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = dh.Bits&_AA != 0
dns.Truncated = dh.Bits&_TC != 0
dns.RecursionDesired = dh.Bits&_RD != 0
dns.RecursionAvailable = dh.Bits&_RA != 0
dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
dns.AuthenticatedData = dh.Bits&_AD != 0
dns.CheckingDisabled = dh.Bits&_CD != 0
dns.Rcode = int(dh.Bits & 0xF)
}

View File

@ -203,9 +203,6 @@ type Server struct {
IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specifically not check if the query has the QR bit not set.
Unsafe bool
// If NotifyStartedFunc is set it is called once the server has started listening.
NotifyStartedFunc func()
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
@ -217,6 +214,9 @@ type Server struct {
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
// It is only supported on go1.11+ and when using ListenAndServe.
ReusePort bool
// AcceptMsgFunc will check the incoming message and will reject it early in the process. This function must be
// defined. By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
// UDP packet or TCP connection queue
queue chan *response
@ -300,6 +300,9 @@ func (srv *Server) init() {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.MsgAcceptFunc == nil {
srv.MsgAcceptFunc = defaultMsgAcceptFunc
}
srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
}
@ -630,14 +633,33 @@ func (srv *Server) disposeBuffer(w *response) {
}
func (srv *Server) serveDNS(w *response) {
req := new(Msg)
err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
dh, off, err := unpackMsgHdr(w.msg, 0)
if err != nil {
// Let client hang, they are sending crap; any reply can be used to amplify.
return
}
if err != nil || !srv.Unsafe && req.Response {
req := new(Msg)
req.setHdr(dh)
if accept, respond := srv.MsgAcceptFunc(dh); !accept {
if !respond {
return
}
req.SetRcodeFormatError(req)
// Are we allowed to delete any OPT records here?
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
srv.disposeBuffer(w)
return
}
if err := req.unpack(dh, w.msg, off); err != nil {
req.SetRcodeFormatError(req)
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
srv.disposeBuffer(w)
return
}

View File

@ -102,6 +102,7 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
}
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
server.init()
waitLock := sync.Mutex{}
waitLock.Lock()
@ -568,6 +569,7 @@ func TestServingResponse(t *testing.T) {
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c := new(Client)
m := new(Msg)
@ -582,20 +584,6 @@ func TestServingResponse(t *testing.T) {
if err == nil {
t.Fatal("exchanged response message")
}
s.Shutdown()
s, addrstr, _, err = RunLocalUDPServerWithFinChan(":0",
func(srv *Server) { srv.Unsafe = true })
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
m.Response = true
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Fatal("could exchanged response message in Unsafe mode")
}
}
func TestShutdownTCP(t *testing.T) {