diff --git a/server.go b/server.go index 2901f872..df0c19a3 100644 --- a/server.go +++ b/server.go @@ -735,20 +735,23 @@ func (srv *Server) serve(w *response) { } } -func (srv *Server) serveDNS(w *response) { - req := new(Msg) - err := req.Unpack(w.msg) +func (srv *Server) disposeBuffer(w *response) { if w.udp != nil && cap(w.msg) == srv.UDPSize { srv.udpPool.Put(w.msg[:srv.UDPSize]) } w.msg = nil +} + +func (srv *Server) serveDNS(w *response) { + req := new(Msg) + err := req.Unpack(w.msg) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) w.WriteMsg(x) - return } - if !srv.Unsafe && req.Response { + if err != nil || !srv.Unsafe && req.Response { + srv.disposeBuffer(w) return } @@ -765,6 +768,8 @@ func (srv *Server) serveDNS(w *response) { } } + srv.disposeBuffer(w) + handler := srv.Handler if handler == nil { handler = DefaultServeMux diff --git a/server_test.go b/server_test.go index d60f0cd0..b846f8de 100644 --- a/server_test.go +++ b/server_test.go @@ -59,7 +59,7 @@ func RunLocalUDPServer(laddr string) (*Server, string, error) { return server, l, err } -func RunLocalUDPServerWithFinChan(laddr string) (*Server, string, chan error, error) { +func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { pc, err := net.ListenPacket("udp", laddr) if err != nil { return nil, "", nil, err @@ -75,6 +75,10 @@ func RunLocalUDPServerWithFinChan(laddr string) (*Server, string, chan error, er // in RunLocalUDPServer and can happen in TestShutdownUDP. fin := make(chan error, 1) + for _, opt := range opts { + opt(server) + } + go func() { fin <- server.ActivateAndServe() pc.Close() @@ -999,6 +1003,56 @@ func TestServerReuseport(t *testing.T) { } } +func TestServerRoundtripTsig(t *testing.T) { + secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="} + + s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) { + srv.TsigSecret = secret + }) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + HandleFunc("example.com.", func(w ResponseWriter, r *Msg) { + m := new(Msg) + m.SetReply(r) + if r.IsTsig() != nil { + status := w.TsigStatus() + if status == nil { + // *Msg r has an TSIG record and it was validated + m.SetTsig("test.", HmacMD5, 300, time.Now().Unix()) + } else { + // *Msg r has an TSIG records and it was not valided + t.Errorf("invalid TSIG: %v", status) + } + } else { + t.Error("missing TSIG") + } + w.WriteMsg(m) + }) + + c := new(Client) + m := new(Msg) + m.Opcode = OpcodeUpdate + m.SetQuestion("example.com.", TypeSOA) + m.Ns = []RR{&CNAME{ + Hdr: RR_Header{ + Name: "foo.example.com.", + Rrtype: TypeCNAME, + Class: ClassINET, + Ttl: 300, + }, + Target: "bar.example.com.", + }} + c.TsigSecret = secret + m.SetTsig("test.", HmacMD5, 300, time.Now().Unix()) + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Fatal("failed to exchange", err) + } +} + type ExampleFrameLengthWriter struct { Writer }