From 2c18e7259a35458cf282adbfa12b04de0d00c899 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Tue, 27 Nov 2018 10:43:01 +0000 Subject: [PATCH] Add MsgAcceptFunc in server Generalize the srv.Unsafe and make it pluggeable. Also add a default accept function that allows to discard malformed DNS messages very early on. Before we allocate and parse anything furher. Also re-use the client's message when sending a reply. Signed-off-by: Miek Gieben --- acceptfunc.go | 46 ++++++++++++++++++++++++++++++++++++++++++++ msg.go | 52 +++++++++++++++++++++++++++----------------------- server.go | 42 ++++++++++++++++++++++++++++++---------- server_test.go | 16 ++-------------- 4 files changed, 108 insertions(+), 48 deletions(-) create mode 100644 acceptfunc.go diff --git a/acceptfunc.go b/acceptfunc.go new file mode 100644 index 00000000..5f680e2b --- /dev/null +++ b/acceptfunc.go @@ -0,0 +1,46 @@ +package dns + +// MsgAcceptFunc is used early in the server code to accept or reject a message with RcodeFormatError. +// There are to booleans to be returned, once signaling the rejection and another to signal if +// a reply is to be send back (you want to prevent DNS ping-pong and not reply to a response for instance). +type MsgAcceptFunc func(dh Header) (accept bool, respond bool) + +// DefaultMsgAcceptFunc checks the request and will reject if: +// +// * isn't a request (don't respond in that case). +// * opcode isn't OpcodeQuery or OpcodeNotify +// * Zero bit isn't zero +// * has more than 1 question in the question section +// * has more than 0 RRs in the Answer section +// * has more than 0 RRs in the Authority section +// * has more than 2 RRs in the Additional section +var DefaultMsgAcceptFunc = defaultMsgAcceptFunc + +var defaultMsgAcceptFunc = func(dh Header) (bool, bool) { + if isResponse := dh.Bits&_QR != 0; isResponse { + return false, false + } + + // Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs. + opcode := int(dh.Bits>>11) & 0xF + if opcode != OpcodeQuery && opcode != OpcodeNotify { + return false, true + } + + if isZero := dh.Bits&_Z != 0; isZero { + return false, true + } + if dh.Qdcount != 1 { + return false, true + } + if dh.Ancount != 0 { + return false, true + } + if dh.Nscount != 0 { + return false, true + } + if dh.Arcount > 2 { + return false, true + } + return true, true +} diff --git a/msg.go b/msg.go index 2b4c9901..ab6ebca2 100644 --- a/msg.go +++ b/msg.go @@ -778,30 +778,8 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string] return msg[:off], nil } -// Unpack unpacks a binary message to a Msg structure. -func (dns *Msg) Unpack(msg []byte) (err error) { - // We use a similar function in tsig.go's stripTsig. - - var ( - dh Header - off int - ) - if dh, off, err = unpackMsgHdr(msg, off); err != nil { - return err - } - - dns.Id = dh.Id - dns.Response = dh.Bits&_QR != 0 - dns.Opcode = int(dh.Bits>>11) & 0xF - dns.Authoritative = dh.Bits&_AA != 0 - dns.Truncated = dh.Bits&_TC != 0 - dns.RecursionDesired = dh.Bits&_RD != 0 - dns.RecursionAvailable = dh.Bits&_RA != 0 - dns.Zero = dh.Bits&_Z != 0 - dns.AuthenticatedData = dh.Bits&_AD != 0 - dns.CheckingDisabled = dh.Bits&_CD != 0 - dns.Rcode = int(dh.Bits & 0xF) - +func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) { + dns.setHdr(dh) // If we are at the end of the message we should return *just* the // header. This can still be useful to the caller. 9.9.9.9 sends these // when responding with REFUSED for instance. @@ -854,6 +832,17 @@ func (dns *Msg) Unpack(msg []byte) (err error) { // println("dns: extra bytes in dns packet", off, "<", len(msg)) } return err + +} + +// Unpack unpacks a binary message to a Msg structure. +func (dns *Msg) Unpack(msg []byte) (err error) { + dh, off, err := unpackMsgHdr(msg, 0) + if err != nil { + return err + } + + return dns.unpack(dh, msg, off) } // Convert a complete message to a string with dig-like output. @@ -1196,3 +1185,18 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) { dh.Arcount, off, err = unpackUint16(msg, off) return dh, off, err } + +// setHdr set the header in the dns using the binary data in dh. +func (dns *Msg) setHdr(dh Header) { + dns.Id = dh.Id + dns.Response = dh.Bits&_QR != 0 + dns.Opcode = int(dh.Bits>>11) & 0xF + dns.Authoritative = dh.Bits&_AA != 0 + dns.Truncated = dh.Bits&_TC != 0 + dns.RecursionDesired = dh.Bits&_RD != 0 + dns.RecursionAvailable = dh.Bits&_RA != 0 + dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite. + dns.AuthenticatedData = dh.Bits&_AD != 0 + dns.CheckingDisabled = dh.Bits&_CD != 0 + dns.Rcode = int(dh.Bits & 0xF) +} diff --git a/server.go b/server.go index 06984e7c..0cf1bd92 100644 --- a/server.go +++ b/server.go @@ -203,9 +203,6 @@ type Server struct { IdleTimeout func() time.Duration // Secret(s) for Tsig map[]. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string - // Unsafe instructs the server to disregard any sanity checks and directly hand the message to - // the handler. It will specifically not check if the query has the QR bit not set. - Unsafe bool // If NotifyStartedFunc is set it is called once the server has started listening. NotifyStartedFunc func() // DecorateReader is optional, allows customization of the process that reads raw DNS messages. @@ -217,6 +214,9 @@ type Server struct { // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. // It is only supported on go1.11+ and when using ListenAndServe. ReusePort bool + // AcceptMsgFunc will check the incoming message and will reject it early in the process. This function must be + // defined. By default DefaultMsgAcceptFunc will be used. + MsgAcceptFunc MsgAcceptFunc // UDP packet or TCP connection queue queue chan *response @@ -300,6 +300,9 @@ func (srv *Server) init() { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } + if srv.MsgAcceptFunc == nil { + srv.MsgAcceptFunc = defaultMsgAcceptFunc + } srv.udpPool.New = makeUDPBuffer(srv.UDPSize) } @@ -630,14 +633,33 @@ func (srv *Server) disposeBuffer(w *response) { } func (srv *Server) serveDNS(w *response) { - req := new(Msg) - err := req.Unpack(w.msg) - if err != nil { // Send a FormatError back - x := new(Msg) - x.SetRcodeFormatError(req) - w.WriteMsg(x) + dh, off, err := unpackMsgHdr(w.msg, 0) + if err != nil { + // Let client hang, they are sending crap; any reply can be used to amplify. + return } - if err != nil || !srv.Unsafe && req.Response { + + req := new(Msg) + req.setHdr(dh) + + if accept, respond := srv.MsgAcceptFunc(dh); !accept { + if !respond { + return + } + req.SetRcodeFormatError(req) + // Are we allowed to delete any OPT records here? + req.Ns, req.Answer, req.Extra = nil, nil, nil + + w.WriteMsg(req) + srv.disposeBuffer(w) + return + } + + if err := req.unpack(dh, w.msg, off); err != nil { + req.SetRcodeFormatError(req) + req.Ns, req.Answer, req.Extra = nil, nil, nil + + w.WriteMsg(req) srv.disposeBuffer(w) return } diff --git a/server_test.go b/server_test.go index 5c6b8006..58b4cde9 100644 --- a/server_test.go +++ b/server_test.go @@ -102,6 +102,7 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er } server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + server.init() waitLock := sync.Mutex{} waitLock.Lock() @@ -568,6 +569,7 @@ func TestServingResponse(t *testing.T) { if err != nil { t.Fatalf("unable to run test server: %v", err) } + defer s.Shutdown() c := new(Client) m := new(Msg) @@ -582,20 +584,6 @@ func TestServingResponse(t *testing.T) { if err == nil { t.Fatal("exchanged response message") } - - s.Shutdown() - s, addrstr, _, err = RunLocalUDPServerWithFinChan(":0", - func(srv *Server) { srv.Unsafe = true }) - if err != nil { - t.Fatalf("unable to run test server: %v", err) - } - defer s.Shutdown() - - m.Response = true - _, _, err = c.Exchange(m, addrstr) - if err != nil { - t.Fatal("could exchanged response message in Unsafe mode") - } } func TestShutdownTCP(t *testing.T) {