From 33e64002b6f9bc4d1c694ce6677d6185ba37544f Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Sat, 5 Feb 2022 10:53:49 +1030 Subject: [PATCH] Support TsigProvider for Server and Transfer (#1331) Automatically submitted. --- client.go | 32 ++++++++++++------------------ server.go | 42 ++++++++++++++++++++++++---------------- tsig.go | 18 +++++++++++++++++ xfr.go | 27 +++++++++++++++----------- xfr_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 126 insertions(+), 49 deletions(-) diff --git a/client.go b/client.go index 99bd1896..31bf5759 100644 --- a/client.go +++ b/client.go @@ -39,6 +39,14 @@ type Conn struct { tsigRequestMAC string } +func (co *Conn) tsigProvider() TsigProvider { + if co.TsigProvider != nil { + return co.TsigProvider + } + // tsigSecretProvider will return ErrSecret if co.TsigSecret is nil. + return tsigSecretProvider(co.TsigSecret) +} + // A Client defines parameters for a DNS client. type Client struct { Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) @@ -271,15 +279,8 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } if t := m.IsTsig(); t != nil { - if co.TsigProvider != nil { - err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false) - } else { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return m, ErrSecret - } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) - } + // Need to work on the original message p, as that was used to calculate the tsig. + err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false) } return m, err } @@ -356,17 +357,8 @@ func (co *Conn) Read(p []byte) (n int, err error) { func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { - mac := "" - if co.TsigProvider != nil { - out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false) - } else { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return ErrSecret - } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) - } - // Set for the next read, although only used in zone transfers - co.tsigRequestMAC = mac + // Set tsigRequestMAC for the next read, although only used in zone transfers. + out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false) } else { out, err = m.Pack() } diff --git a/server.go b/server.go index b2a63bda..b962e6f3 100644 --- a/server.go +++ b/server.go @@ -71,12 +71,12 @@ type response struct { tsigTimersOnly bool tsigStatus error tsigRequestMAC string - tsigSecret map[string]string // the tsig secrets - udp net.PacketConn // i/o connection if UDP was used - tcp net.Conn // i/o connection if TCP was used - udpSession *SessionUDP // oob data to get egress interface right - pcSession net.Addr // address to use when writing to a generic net.PacketConn - writer Writer // writer to output the raw DNS bits + tsigProvider TsigProvider + udp net.PacketConn // i/o connection if UDP was used + tcp net.Conn // i/o connection if TCP was used + udpSession *SessionUDP // oob data to get egress interface right + pcSession net.Addr // address to use when writing to a generic net.PacketConn + writer Writer // writer to output the raw DNS bits } // handleRefused returns a HandlerFunc that returns REFUSED for every request it gets. @@ -211,6 +211,8 @@ type Server struct { WriteTimeout time.Duration // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). IdleTimeout func() time.Duration + // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. + TsigProvider TsigProvider // Secret(s) for Tsig map[]. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). TsigSecret map[string]string // If NotifyStartedFunc is set it is called once the server has started listening. @@ -238,6 +240,16 @@ type Server struct { udpPool sync.Pool } +func (srv *Server) tsigProvider() TsigProvider { + if srv.TsigProvider != nil { + return srv.TsigProvider + } + if srv.TsigSecret != nil { + return tsigSecretProvider(srv.TsigSecret) + } + return nil +} + func (srv *Server) isStarted() bool { srv.lock.RLock() started := srv.started @@ -526,7 +538,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error { // Serve a new TCP connection. func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, tcp: rw} + w := &response{tsigProvider: srv.tsigProvider(), tcp: rw} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -581,7 +593,7 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { // Serve a new UDP request. func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession} + w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession} if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { @@ -632,15 +644,11 @@ func (srv *Server) serveDNS(m []byte, w *response) { } w.tsigStatus = nil - if w.tsigSecret != nil { + if w.tsigProvider != nil { if t := req.IsTsig(); t != nil { - if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { - w.tsigStatus = TsigVerify(m, secret, "", false) - } else { - w.tsigStatus = ErrSecret - } + w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false) w.tsigTimersOnly = false - w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC + w.tsigRequestMAC = t.MAC } } @@ -718,9 +726,9 @@ func (w *response) WriteMsg(m *Msg) (err error) { } var data []byte - if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) + if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { - data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly) if err != nil { return err } diff --git a/tsig.go b/tsig.go index 55ca7521..8b37cc84 100644 --- a/tsig.go +++ b/tsig.go @@ -74,6 +74,24 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error { return nil } +type tsigSecretProvider map[string]string + +func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) { + key, ok := ts[t.Hdr.Name] + if !ok { + return nil, ErrSecret + } + return tsigHMACProvider(key).Generate(msg, t) +} + +func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error { + key, ok := ts[t.Hdr.Name] + if !ok { + return ErrSecret + } + return tsigHMACProvider(key).Verify(msg, t) +} + // TSIG is the RR the holds the transaction signature of a message. // See RFC 2845 and RFC 4635. type TSIG struct { diff --git a/xfr.go b/xfr.go index 43970e64..f0dcf61d 100644 --- a/xfr.go +++ b/xfr.go @@ -17,11 +17,22 @@ type Transfer struct { DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) tsigTimersOnly bool } -// Think we need to away to stop the transfer +func (t *Transfer) tsigProvider() TsigProvider { + if t.TsigProvider != nil { + return t.TsigProvider + } + if t.TsigSecret != nil { + return tsigSecretProvider(t.TsigSecret) + } + return nil +} + +// TODO: Think we need to away to stop the transfer // In performs an incoming transfer with the server in a. // If you would like to set the source IP, or some other attribute @@ -224,12 +235,9 @@ func (t *Transfer) ReadMsg() (*Msg, error) { if err := m.Unpack(p); err != nil { return nil, err } - if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return m, ErrSecret - } + if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil { // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly) t.tsigRequestMAC = ts.MAC } return m, err @@ -238,11 +246,8 @@ func (t *Transfer) ReadMsg() (*Msg, error) { // WriteMsg writes a message through the transfer connection t. func (t *Transfer) WriteMsg(m *Msg) (err error) { var out []byte - if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return ErrSecret - } - out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil { + out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly) } else { out, err = m.Pack() } diff --git a/xfr_test.go b/xfr_test.go index 7cb9f1d3..f6c5e98c 100644 --- a/xfr_test.go +++ b/xfr_test.go @@ -1,6 +1,9 @@ package dns -import "testing" +import ( + "testing" + "time" +) var ( tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} @@ -127,3 +130,54 @@ func axfrTestingSuite(t *testing.T, addrstr string) { } } } + +func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) { + tr := new(Transfer) + m := new(Msg) + var err error + tr.Conn, err = Dial("tcp", addrstr) + if err != nil { + t.Fatal("failed to dial", err) + } + tr.TsigProvider = provider + m.SetAxfr("miek.nl.") + m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix()) + + c, err := tr.In(m, addrstr) + if err != nil { + t.Fatal("failed to zone transfer in", err) + } + + var records []RR + for msg := range c { + if msg.Error != nil { + t.Fatal(msg.Error) + } + records = append(records, msg.RR...) + } + + if len(records) != len(xfrTestData) { + t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData) + } + + for i, rr := range records { + if !IsDuplicate(rr, xfrTestData[i]) { + t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData) + } + } +} + +func TestCustomTsigProvider(t *testing.T) { + HandleFunc("miek.nl.", SingleEnvelopeXfrServer) + defer HandleRemove("miek.nl.") + + s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) { + srv.TsigProvider = tsigSecretProvider(tsigSecret) + }) + if err != nil { + t.Fatalf("unable to run test server: %s", err) + } + defer s.Shutdown() + + axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret)) +}