Add MsgAcceptFunc in server
Generalize the srv.Unsafe and make it pluggeable. Also add a default accept function that allows to discard malformed DNS messages very early on. Before we allocate and parse anything furher. Also re-use the client's message when sending a reply. Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
6bf402f3c4
commit
2c18e7259a
|
@ -0,0 +1,46 @@
|
|||
package dns
|
||||
|
||||
// MsgAcceptFunc is used early in the server code to accept or reject a message with RcodeFormatError.
|
||||
// There are to booleans to be returned, once signaling the rejection and another to signal if
|
||||
// a reply is to be send back (you want to prevent DNS ping-pong and not reply to a response for instance).
|
||||
type MsgAcceptFunc func(dh Header) (accept bool, respond bool)
|
||||
|
||||
// DefaultMsgAcceptFunc checks the request and will reject if:
|
||||
//
|
||||
// * isn't a request (don't respond in that case).
|
||||
// * opcode isn't OpcodeQuery or OpcodeNotify
|
||||
// * Zero bit isn't zero
|
||||
// * has more than 1 question in the question section
|
||||
// * has more than 0 RRs in the Answer section
|
||||
// * has more than 0 RRs in the Authority section
|
||||
// * has more than 2 RRs in the Additional section
|
||||
var DefaultMsgAcceptFunc = defaultMsgAcceptFunc
|
||||
|
||||
var defaultMsgAcceptFunc = func(dh Header) (bool, bool) {
|
||||
if isResponse := dh.Bits&_QR != 0; isResponse {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs.
|
||||
opcode := int(dh.Bits>>11) & 0xF
|
||||
if opcode != OpcodeQuery && opcode != OpcodeNotify {
|
||||
return false, true
|
||||
}
|
||||
|
||||
if isZero := dh.Bits&_Z != 0; isZero {
|
||||
return false, true
|
||||
}
|
||||
if dh.Qdcount != 1 {
|
||||
return false, true
|
||||
}
|
||||
if dh.Ancount != 0 {
|
||||
return false, true
|
||||
}
|
||||
if dh.Nscount != 0 {
|
||||
return false, true
|
||||
}
|
||||
if dh.Arcount > 2 {
|
||||
return false, true
|
||||
}
|
||||
return true, true
|
||||
}
|
52
msg.go
52
msg.go
|
@ -778,30 +778,8 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]
|
|||
return msg[:off], nil
|
||||
}
|
||||
|
||||
// Unpack unpacks a binary message to a Msg structure.
|
||||
func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
// We use a similar function in tsig.go's stripTsig.
|
||||
|
||||
var (
|
||||
dh Header
|
||||
off int
|
||||
)
|
||||
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dns.Id = dh.Id
|
||||
dns.Response = dh.Bits&_QR != 0
|
||||
dns.Opcode = int(dh.Bits>>11) & 0xF
|
||||
dns.Authoritative = dh.Bits&_AA != 0
|
||||
dns.Truncated = dh.Bits&_TC != 0
|
||||
dns.RecursionDesired = dh.Bits&_RD != 0
|
||||
dns.RecursionAvailable = dh.Bits&_RA != 0
|
||||
dns.Zero = dh.Bits&_Z != 0
|
||||
dns.AuthenticatedData = dh.Bits&_AD != 0
|
||||
dns.CheckingDisabled = dh.Bits&_CD != 0
|
||||
dns.Rcode = int(dh.Bits & 0xF)
|
||||
|
||||
func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
|
||||
dns.setHdr(dh)
|
||||
// If we are at the end of the message we should return *just* the
|
||||
// header. This can still be useful to the caller. 9.9.9.9 sends these
|
||||
// when responding with REFUSED for instance.
|
||||
|
@ -854,6 +832,17 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
|
|||
// println("dns: extra bytes in dns packet", off, "<", len(msg))
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
// Unpack unpacks a binary message to a Msg structure.
|
||||
func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
dh, off, err := unpackMsgHdr(msg, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dns.unpack(dh, msg, off)
|
||||
}
|
||||
|
||||
// Convert a complete message to a string with dig-like output.
|
||||
|
@ -1196,3 +1185,18 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
|
|||
dh.Arcount, off, err = unpackUint16(msg, off)
|
||||
return dh, off, err
|
||||
}
|
||||
|
||||
// setHdr set the header in the dns using the binary data in dh.
|
||||
func (dns *Msg) setHdr(dh Header) {
|
||||
dns.Id = dh.Id
|
||||
dns.Response = dh.Bits&_QR != 0
|
||||
dns.Opcode = int(dh.Bits>>11) & 0xF
|
||||
dns.Authoritative = dh.Bits&_AA != 0
|
||||
dns.Truncated = dh.Bits&_TC != 0
|
||||
dns.RecursionDesired = dh.Bits&_RD != 0
|
||||
dns.RecursionAvailable = dh.Bits&_RA != 0
|
||||
dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
|
||||
dns.AuthenticatedData = dh.Bits&_AD != 0
|
||||
dns.CheckingDisabled = dh.Bits&_CD != 0
|
||||
dns.Rcode = int(dh.Bits & 0xF)
|
||||
}
|
||||
|
|
42
server.go
42
server.go
|
@ -203,9 +203,6 @@ type Server struct {
|
|||
IdleTimeout func() time.Duration
|
||||
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
|
||||
TsigSecret map[string]string
|
||||
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
|
||||
// the handler. It will specifically not check if the query has the QR bit not set.
|
||||
Unsafe bool
|
||||
// If NotifyStartedFunc is set it is called once the server has started listening.
|
||||
NotifyStartedFunc func()
|
||||
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
|
||||
|
@ -217,6 +214,9 @@ type Server struct {
|
|||
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
|
||||
// It is only supported on go1.11+ and when using ListenAndServe.
|
||||
ReusePort bool
|
||||
// AcceptMsgFunc will check the incoming message and will reject it early in the process. This function must be
|
||||
// defined. By default DefaultMsgAcceptFunc will be used.
|
||||
MsgAcceptFunc MsgAcceptFunc
|
||||
|
||||
// UDP packet or TCP connection queue
|
||||
queue chan *response
|
||||
|
@ -300,6 +300,9 @@ func (srv *Server) init() {
|
|||
if srv.UDPSize == 0 {
|
||||
srv.UDPSize = MinMsgSize
|
||||
}
|
||||
if srv.MsgAcceptFunc == nil {
|
||||
srv.MsgAcceptFunc = defaultMsgAcceptFunc
|
||||
}
|
||||
|
||||
srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
|
||||
}
|
||||
|
@ -630,14 +633,33 @@ func (srv *Server) disposeBuffer(w *response) {
|
|||
}
|
||||
|
||||
func (srv *Server) serveDNS(w *response) {
|
||||
req := new(Msg)
|
||||
err := req.Unpack(w.msg)
|
||||
if err != nil { // Send a FormatError back
|
||||
x := new(Msg)
|
||||
x.SetRcodeFormatError(req)
|
||||
w.WriteMsg(x)
|
||||
dh, off, err := unpackMsgHdr(w.msg, 0)
|
||||
if err != nil {
|
||||
// Let client hang, they are sending crap; any reply can be used to amplify.
|
||||
return
|
||||
}
|
||||
if err != nil || !srv.Unsafe && req.Response {
|
||||
|
||||
req := new(Msg)
|
||||
req.setHdr(dh)
|
||||
|
||||
if accept, respond := srv.MsgAcceptFunc(dh); !accept {
|
||||
if !respond {
|
||||
return
|
||||
}
|
||||
req.SetRcodeFormatError(req)
|
||||
// Are we allowed to delete any OPT records here?
|
||||
req.Ns, req.Answer, req.Extra = nil, nil, nil
|
||||
|
||||
w.WriteMsg(req)
|
||||
srv.disposeBuffer(w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := req.unpack(dh, w.msg, off); err != nil {
|
||||
req.SetRcodeFormatError(req)
|
||||
req.Ns, req.Answer, req.Extra = nil, nil, nil
|
||||
|
||||
w.WriteMsg(req)
|
||||
srv.disposeBuffer(w)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -102,6 +102,7 @@ func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, er
|
|||
}
|
||||
|
||||
server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
|
||||
server.init()
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
|
@ -568,6 +569,7 @@ func TestServingResponse(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
c := new(Client)
|
||||
m := new(Msg)
|
||||
|
@ -582,20 +584,6 @@ func TestServingResponse(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatal("exchanged response message")
|
||||
}
|
||||
|
||||
s.Shutdown()
|
||||
s, addrstr, _, err = RunLocalUDPServerWithFinChan(":0",
|
||||
func(srv *Server) { srv.Unsafe = true })
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %v", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
m.Response = true
|
||||
_, _, err = c.Exchange(m, addrstr)
|
||||
if err != nil {
|
||||
t.Fatal("could exchanged response message in Unsafe mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownTCP(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue