From 59aea23afe5598de8cfac8c7421b13fc7d3f45ce Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Thu, 7 Jan 2021 14:28:20 +0000 Subject: [PATCH] Add GSS-TSIG support (#1201) Automatically submitted. --- client.go | 26 +++++++---- doc.go | 24 ++++++++++ tsig.go | 118 +++++++++++++++++++++++++++-------------------- tsig_test.go | 128 ++++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 230 insertions(+), 66 deletions(-) diff --git a/client.go b/client.go index f580205e..aa2c49d3 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ type Conn struct { net.Conn // a net.Conn holding the connection UDPSize uint16 // minimum receive buffer for UDP messages TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. tsigRequestMAC string } @@ -40,6 +41,7 @@ type Client struct { ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight } @@ -175,7 +177,7 @@ func (c *Client) exchange(m *Msg, co *Conn) (r *Msg, rtt time.Duration, err erro co.UDPSize = c.UDPSize } - co.TsigSecret = c.TsigSecret + co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider t := time.Now() // write with the appropriate write timeout co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout()))) @@ -222,11 +224,15 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } if t := m.IsTsig(); t != nil { - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return m, ErrSecret + if co.TsigProvider != nil { + err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false) + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } return m, err } @@ -304,10 +310,14 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { - return ErrSecret + if co.TsigProvider != nil { + out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false) + } else { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { + return ErrSecret + } + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) // Set for the next read, although only used in zone transfers co.tsigRequestMAC = mac } else { diff --git a/doc.go b/doc.go index 6861de77..f7629ec3 100644 --- a/doc.go +++ b/doc.go @@ -194,6 +194,30 @@ request an AXFR for miek.nl. with TSIG key named "axfr." and secret You can now read the records from the transfer as they come in. Each envelope is checked with TSIG. If something is not correct an error is returned. +A custom TSIG implementation can be used. This requires additional code to +perform any session establishment and signature generation/verification. The +client must be configured with an implementation of the TsigProvider interface: + + type Provider struct{} + + func (*Provider) Generate(msg []byte, tsig *dns.TSIG) ([]byte, error) { + // Use tsig.Hdr.Name and tsig.Algorithm in your code to + // generate the MAC using msg as the payload. + } + + func (*Provider) Verify(msg []byte, tsig *dns.TSIG) error { + // Use tsig.Hdr.Name and tsig.Algorithm in your code to verify + // that msg matches the value in tsig.MAC. + } + + c := new(dns.Client) + c.TsigProvider = new(Provider) + m := new(dns.Msg) + m.SetQuestion("miek.nl.", dns.TypeMX) + m.SetTsig(keyname, dns.HmacSHA1, 300, time.Now().Unix()) + ... + // TSIG RR is calculated by calling your Generate method + Basic use pattern validating and replying to a message that has TSIG set. server := &dns.Server{Addr: ":53", Net: "udp"} diff --git a/tsig.go b/tsig.go index 59904dd6..5b52e1c4 100644 --- a/tsig.go +++ b/tsig.go @@ -24,6 +24,56 @@ const ( HmacMD5 = "hmac-md5.sig-alg.reg.int." // Deprecated: HmacMD5 is no longer supported. ) +// TsigProvider provides the API to plug-in a custom TSIG implementation. +type TsigProvider interface { + // Generate is passed the DNS message to be signed and the partial TSIG RR. It returns the signature and nil, otherwise an error. + Generate(msg []byte, t *TSIG) ([]byte, error) + // Verify is passed the DNS message to be verified and the TSIG RR. If the signature is valid it will return nil, otherwise an error. + Verify(msg []byte, t *TSIG) error +} + +type tsigHMACProvider string + +func (key tsigHMACProvider) Generate(msg []byte, t *TSIG) ([]byte, error) { + // If we barf here, the caller is to blame + rawsecret, err := fromBase64([]byte(key)) + if err != nil { + return nil, err + } + var h hash.Hash + switch CanonicalName(t.Algorithm) { + case HmacSHA1: + h = hmac.New(sha1.New, rawsecret) + case HmacSHA224: + h = hmac.New(sha256.New224, rawsecret) + case HmacSHA256: + h = hmac.New(sha256.New, rawsecret) + case HmacSHA384: + h = hmac.New(sha512.New384, rawsecret) + case HmacSHA512: + h = hmac.New(sha512.New, rawsecret) + default: + return nil, ErrKeyAlg + } + h.Write(msg) + return h.Sum(nil), nil +} + +func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error { + b, err := key.Generate(msg, t) + if err != nil { + return err + } + mac, err := hex.DecodeString(t.MAC) + if err != nil { + return err + } + if !hmac.Equal(b, mac) { + return ErrSig + } + return nil +} + // TSIG is the RR the holds the transaction signature of a message. // See RFC 2845 and RFC 4635. type TSIG struct { @@ -98,14 +148,13 @@ type timerWireFmt struct { // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { + return tsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly) +} + +func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) { if m.IsTsig() == nil { panic("dns: TSIG not last RR in additional") } - // If we barf here, the caller is to blame - rawsecret, err := fromBase64([]byte(secret)) - if err != nil { - return nil, "", err - } rr := m.Extra[len(m.Extra)-1].(*TSIG) m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg @@ -119,25 +168,13 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s } t := new(TSIG) - var h hash.Hash - switch CanonicalName(rr.Algorithm) { - case HmacSHA1: - h = hmac.New(sha1.New, rawsecret) - case HmacSHA224: - h = hmac.New(sha256.New224, rawsecret) - case HmacSHA256: - h = hmac.New(sha256.New, rawsecret) - case HmacSHA384: - h = hmac.New(sha512.New384, rawsecret) - case HmacSHA512: - h = hmac.New(sha512.New, rawsecret) - default: - return nil, "", ErrKeyAlg - } - h.Write(buf) // Copy all TSIG fields except MAC and its size, which are filled using the computed digest. *t = *rr - t.MAC = hex.EncodeToString(h.Sum(nil)) + mac, err := provider.Generate(buf, rr) + if err != nil { + return nil, "", err + } + t.MAC = hex.EncodeToString(mac) t.MACSize = uint16(len(t.MAC) / 2) // Size is half! tbuf := make([]byte, Len(t)) @@ -156,49 +193,28 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s // If the signature does not validate err contains the // error, otherwise it is nil. func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { - return tsigVerify(msg, secret, requestMAC, timersOnly, uint64(time.Now().Unix())) + return tsigVerify(msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix())) +} + +func tsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error { + return tsigVerify(msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix())) } // actual implementation of TsigVerify, taking the current time ('now') as a parameter for the convenience of tests. -func tsigVerify(msg []byte, secret, requestMAC string, timersOnly bool, now uint64) error { - rawsecret, err := fromBase64([]byte(secret)) - if err != nil { - return err - } +func tsigVerify(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool, now uint64) error { // Strip the TSIG from the incoming msg stripped, tsig, err := stripTsig(msg) if err != nil { return err } - msgMAC, err := hex.DecodeString(tsig.MAC) - if err != nil { - return err - } - buf, err := tsigBuffer(stripped, tsig, requestMAC, timersOnly) if err != nil { return err } - var h hash.Hash - switch CanonicalName(tsig.Algorithm) { - case HmacSHA1: - h = hmac.New(sha1.New, rawsecret) - case HmacSHA224: - h = hmac.New(sha256.New224, rawsecret) - case HmacSHA256: - h = hmac.New(sha256.New, rawsecret) - case HmacSHA384: - h = hmac.New(sha512.New384, rawsecret) - case HmacSHA512: - h = hmac.New(sha512.New, rawsecret) - default: - return ErrKeyAlg - } - h.Write(buf) - if !hmac.Equal(h.Sum(nil), msgMAC) { - return ErrSig + if err := provider.Verify(buf, tsig); err != nil { + return err } // Fudge factor works both ways. A message can arrive before it was signed because diff --git a/tsig_test.go b/tsig_test.go index 81bb1796..e3f06566 100644 --- a/tsig_test.go +++ b/tsig_test.go @@ -3,6 +3,7 @@ package dns import ( "encoding/binary" "encoding/hex" + "errors" "fmt" "strings" "testing" @@ -79,23 +80,23 @@ func TestTsigErrors(t *testing.T) { } // the signature is valid but 'time signed' is too far from the "current time". - if err := tsigVerify(buildMsgData(timeSigned), testSecret, "", false, timeSigned+301); err != ErrTime { + if err := tsigVerify(buildMsgData(timeSigned), tsigHMACProvider(testSecret), "", false, timeSigned+301); err != ErrTime { t.Fatalf("expected an error '%v' but got '%v'", ErrTime, err) } - if err := tsigVerify(buildMsgData(timeSigned), testSecret, "", false, timeSigned-301); err != ErrTime { + if err := tsigVerify(buildMsgData(timeSigned), tsigHMACProvider(testSecret), "", false, timeSigned-301); err != ErrTime { t.Fatalf("expected an error '%v' but got '%v'", ErrTime, err) } // the signature is invalid and 'time signed' is too far. // the signature should be checked first, so we should see ErrSig. - if err := tsigVerify(buildMsgData(timeSigned+301), testSecret, "", false, timeSigned); err != ErrSig { + if err := tsigVerify(buildMsgData(timeSigned+301), tsigHMACProvider(testSecret), "", false, timeSigned); err != ErrSig { t.Fatalf("expected an error '%v' but got '%v'", ErrSig, err) } // tweak the algorithm name in the wire data, resulting in the "unknown algorithm" error. msgData := buildMsgData(timeSigned) copy(msgData[67:], "bogus") - if err := tsigVerify(msgData, testSecret, "", false, timeSigned); err != ErrKeyAlg { + if err := tsigVerify(msgData, tsigHMACProvider(testSecret), "", false, timeSigned); err != ErrKeyAlg { t.Fatalf("expected an error '%v' but got '%v'", ErrKeyAlg, err) } @@ -104,7 +105,7 @@ func TestTsigErrors(t *testing.T) { if err != nil { t.Fatal(err) } - if err := tsigVerify(msgData, testSecret, "", false, timeSigned); err != ErrNoSig { + if err := tsigVerify(msgData, tsigHMACProvider(testSecret), "", false, timeSigned); err != ErrNoSig { t.Fatalf("expected an error '%v' but got '%v'", ErrNoSig, err) } @@ -120,7 +121,7 @@ func TestTsigErrors(t *testing.T) { if msgData, err = msg.Pack(); err != nil { t.Fatal(err) } - err = tsigVerify(msgData, testSecret, "", false, timeSigned) + err = tsigVerify(msgData, tsigHMACProvider(testSecret), "", false, timeSigned) if err == nil || !strings.Contains(err.Error(), "overflow") { t.Errorf("expected error to contain %q, but got %v", "overflow", err) } @@ -231,9 +232,122 @@ func TestTSIGHMAC224And384(t *testing.T) { if mac != tc.expectedMAC { t.Fatalf("MAC doesn't match: expected '%s' but got '%s'", tc.expectedMAC, mac) } - if err = tsigVerify(msgData, tc.secret, "", false, timeSigned); err != nil { + if err = tsigVerify(msgData, tsigHMACProvider(tc.secret), "", false, timeSigned); err != nil { t.Error(err) } }) } } + +const testGoodKeyName = "goodkey." + +var ( + testErrBadKey = errors.New("this is an intentional error") + testGoodMAC = []byte{0, 1, 2, 3} +) + +// testProvider always generates the same MAC and only accepts the one signature +type testProvider struct { + GenerateAllKeys bool +} + +func (provider *testProvider) Generate(_ []byte, t *TSIG) ([]byte, error) { + if t.Hdr.Name == testGoodKeyName || provider.GenerateAllKeys { + return testGoodMAC, nil + } + return nil, testErrBadKey +} + +func (*testProvider) Verify(_ []byte, t *TSIG) error { + if t.Hdr.Name == testGoodKeyName { + return nil + } + return testErrBadKey +} + +func TestTsigGenerateProvider(t *testing.T) { + tables := []struct { + keyname string + mac []byte + err error + }{ + { + testGoodKeyName, + testGoodMAC, + nil, + }, + { + "badkey.", + nil, + testErrBadKey, + }, + } + + for _, table := range tables { + t.Run(table.keyname, func(t *testing.T) { + tsig := TSIG{ + Hdr: RR_Header{Name: table.keyname, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0}, + Algorithm: HmacSHA1, + TimeSigned: timeSigned, + Fudge: 300, + OrigId: 42, + } + req := &Msg{ + MsgHdr: MsgHdr{Opcode: OpcodeUpdate}, + Question: []Question{Question{Name: "example.com.", Qtype: TypeSOA, Qclass: ClassINET}}, + Extra: []RR{&tsig}, + } + + _, mac, err := tsigGenerateProvider(req, new(testProvider), "", false) + if err != table.err { + t.Fatalf("error doesn't match: expected '%s' but got '%s'", table.err, err) + } + expectedMAC := hex.EncodeToString(table.mac) + if mac != expectedMAC { + t.Fatalf("MAC doesn't match: expected '%s' but got '%s'", table.mac, expectedMAC) + } + }) + } +} + +func TestTsigVerifyProvider(t *testing.T) { + tables := []struct { + keyname string + err error + }{ + { + testGoodKeyName, + nil, + }, + { + "badkey.", + testErrBadKey, + }, + } + + for _, table := range tables { + t.Run(table.keyname, func(t *testing.T) { + tsig := TSIG{ + Hdr: RR_Header{Name: table.keyname, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0}, + Algorithm: HmacSHA1, + TimeSigned: timeSigned, + Fudge: 300, + OrigId: 42, + } + req := &Msg{ + MsgHdr: MsgHdr{Opcode: OpcodeUpdate}, + Question: []Question{Question{Name: "example.com.", Qtype: TypeSOA, Qclass: ClassINET}}, + Extra: []RR{&tsig}, + } + + provider := &testProvider{true} + msgData, _, err := tsigGenerateProvider(req, provider, "", false) + if err != nil { + t.Error(err) + } + if err = tsigVerify(msgData, provider, "", false, timeSigned); err != table.err { + t.Fatalf("error doesn't match: expected '%s' but got '%s'", table.err, err) + } + }) + } +}