From 475ab8086778bafb9d0833257c4970eedc8f5e04 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Sat, 14 May 2016 17:56:20 +0100 Subject: [PATCH] Remove (most) reflection Remove the use of reflection when packing and unpacking, instead generate all the pack and unpack functions using msg_generate. This will generate zmsg.go which in turn calls the helper functions from msg_helper.go. This increases the speed by about ~30% while cutting back on memory usage. Not all RRs are using it, but that will be rectified in upcoming PR. Most of the speed increase is in the header/question section parsing. These functions *are* not generated, but straight forward enough. The implementation can be found in msg.go. The new code has been fuzzed by go-fuzz, which turned up some issues. All files that started with 'z', and not autogenerated were renamed, i.e. zscan.go is now scan.go. Reflection is still used, in subsequent PRs it will be removed entirely. --- README.md | 8 - client.go | 4 +- dns_bench_test.go | 205 +++++++++ dns_test.go | 132 +----- dnssec.go | 2 +- edns.go | 16 +- zgenerate.go => generate.go | 0 labels.go | 2 +- labels_test.go | 2 + msg.go | 338 ++++++++++----- msg_generate.go | 307 +++++++++++++ msg_helpers.go | 399 +++++++++++++++++ rawmsg.go | 12 +- reverse.go | 38 ++ zscan.go => scan.go | 0 zscan_rr.go => scan_rr.go | 0 server.go | 4 +- sig0.go | 18 +- tsig.go | 4 +- types_generate.go | 2 +- update_test.go | 2 +- zmsg.go | 827 ++++++++++++++++++++++++++++++++++++ ztypes.go | 2 +- 23 files changed, 2052 insertions(+), 272 deletions(-) create mode 100644 dns_bench_test.go rename zgenerate.go => generate.go (100%) create mode 100644 msg_generate.go create mode 100644 msg_helpers.go create mode 100644 reverse.go rename zscan.go => scan.go (100%) rename zscan_rr.go => scan_rr.go (100%) create mode 100644 zmsg.go diff --git a/README.md b/README.md index 30152cff..8fada255 100644 --- a/README.md +++ b/README.md @@ -150,11 +150,3 @@ Example programs can be found in the `github.com/miekg/exdns` repository. * `NSD` * `Net::DNS` * `GRONG` - -## TODO - -* privatekey.Precompute() when signing? -* Last remaining RRs: APL, ATMA, A6, NSAP and NXT. -* Missing in parsing: ISDN, UNSPEC, NSAP and ATMA. -* NSEC(3) cover/match/closest enclose. -* Replies with TC bit are not parsed to the end. diff --git a/client.go b/client.go index c1a4a430..7868cac2 100644 --- a/client.go +++ b/client.go @@ -300,7 +300,7 @@ func tcpMsgLen(t io.Reader) (int, error) { if n != 2 { return 0, ErrShortRead } - l, _ := unpackUint16(p, 0) + l, _ := unpackUint16Msg(p, 0) if l == 0 { return 0, ErrShortRead } @@ -392,7 +392,7 @@ func (co *Conn) Write(p []byte) (n int, err error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, lp+2) - l[0], l[1] = packUint16(uint16(lp)) + l[0], l[1] = packUint16Msg(uint16(lp)) p = append(l, p...) n, err := io.Copy(w, bytes.NewReader(p)) return int(n), err diff --git a/dns_bench_test.go b/dns_bench_test.go new file mode 100644 index 00000000..de930087 --- /dev/null +++ b/dns_bench_test.go @@ -0,0 +1,205 @@ +package dns + +import ( + "net" + "testing" +) + +func BenchmarkMsgLength(b *testing.B) { + b.StopTimer() + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.StartTimer() + for i := 0; i < b.N; i++ { + msg.Len() + } +} + +func BenchmarkMsgLengthPack(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = msg.Pack() + } +} + +func BenchmarkPackDomainName(b *testing.B) { + name1 := "12345678901234567890123456789012345.12345678.123." + buf := make([]byte, len(name1)+1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackDomainName(name1, buf, 0, nil, false) + } +} + +func BenchmarkUnpackDomainName(b *testing.B) { + name1 := "12345678901234567890123456789012345.12345678.123." + buf := make([]byte, len(name1)+1) + _, _ = PackDomainName(name1, buf, 0, nil, false) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackDomainName(buf, 0) + } +} + +func BenchmarkUnpackDomainNameUnprintable(b *testing.B) { + name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123." + buf := make([]byte, len(name1)+1) + _, _ = PackDomainName(name1, buf, 0, nil, false) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackDomainName(buf, 0) + } +} + +func BenchmarkCopy(b *testing.B) { + b.ReportAllocs() + m := new(Msg) + m.SetQuestion("miek.nl.", TypeA) + rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") + m.Answer = []RR{rr} + rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1") + m.Ns = []RR{rr} + rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1") + m.Extra = []RR{rr} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Copy() + } +} + +func BenchmarkPackA(b *testing.B) { + a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)} + + buf := make([]byte, a.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(a, buf, 0, nil, false) + } +} + +func BenchmarkUnpackA(b *testing.B) { + a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)} + + buf := make([]byte, a.len()) + PackRR(a, buf, 0, nil, false) + a = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackMX(b *testing.B) { + m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."} + + buf := make([]byte, m.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(m, buf, 0, nil, false) + } +} + +func BenchmarkUnpackMX(b *testing.B) { + m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."} + + buf := make([]byte, m.len()) + PackRR(m, buf, 0, nil, false) + m = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackAAAAA(b *testing.B) { + aaaa, _ := NewRR(". IN A ::1") + + buf := make([]byte, aaaa.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(aaaa, buf, 0, nil, false) + } +} + +func BenchmarkUnpackAAAA(b *testing.B) { + aaaa, _ := NewRR(". IN A ::1") + + buf := make([]byte, aaaa.len()) + PackRR(aaaa, buf, 0, nil, false) + aaaa = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackMsg(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + buf := make([]byte, 512) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = msg.PackBuffer(buf) + } +} + +func BenchmarkUnpackMsg(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + msgBuf, _ := msg.Pack() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = msg.Unpack(msgBuf) + } +} diff --git a/dns_test.go b/dns_test.go index 20b33f0c..a709cce0 100644 --- a/dns_test.go +++ b/dns_test.go @@ -274,9 +274,9 @@ func TestMsgLength2(t *testing.T) { for i, hexData := range testMessages { // we won't fail the decoding of the hex input, _ := hex.DecodeString(hexData) + m := new(Msg) m.Unpack(input) - //println(m.String()) m.Compress = true lenComp := m.Len() b, _ := m.Pack() @@ -310,114 +310,6 @@ func TestMsgLengthCompressionMalformed(t *testing.T) { m.Len() // Should not crash. } -func BenchmarkMsgLength(b *testing.B) { - b.StopTimer() - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - b.StartTimer() - for i := 0; i < b.N; i++ { - msg.Len() - } -} - -func BenchmarkMsgLengthPack(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = msg.Pack() - } -} - -func BenchmarkMsgPackBuffer(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - buf := make([]byte, 512) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = msg.PackBuffer(buf) - } -} - -func BenchmarkMsgUnpack(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - msgBuf, _ := msg.Pack() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = msg.Unpack(msgBuf) - } -} - -func BenchmarkPackDomainName(b *testing.B) { - name1 := "12345678901234567890123456789012345.12345678.123." - buf := make([]byte, len(name1)+1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = PackDomainName(name1, buf, 0, nil, false) - } -} - -func BenchmarkUnpackDomainName(b *testing.B) { - name1 := "12345678901234567890123456789012345.12345678.123." - buf := make([]byte, len(name1)+1) - _, _ = PackDomainName(name1, buf, 0, nil, false) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = UnpackDomainName(buf, 0) - } -} - -func BenchmarkUnpackDomainNameUnprintable(b *testing.B) { - name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123." - buf := make([]byte, len(name1)+1) - _, _ = PackDomainName(name1, buf, 0, nil, false) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = UnpackDomainName(buf, 0) - } -} - func TestToRFC3597(t *testing.T) { a, _ := NewRR("miek.nl. IN A 10.0.1.1") x := new(RFC3597) @@ -431,7 +323,7 @@ func TestNoRdataPack(t *testing.T) { data := make([]byte, 1024) for typ, fn := range TypeToRR { r := fn() - *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600} + *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16} _, err := PackRR(r, data, 0, nil, false) if err != nil { t.Errorf("failed to pack RR with zero rdata: %s: %v", TypeToString[typ], err) @@ -439,7 +331,6 @@ func TestNoRdataPack(t *testing.T) { } } -// TODO(miek): fix dns buffer too small errors this throws func TestNoRdataUnpack(t *testing.T) { data := make([]byte, 1024) for typ, fn := range TypeToRR { @@ -449,7 +340,7 @@ func TestNoRdataUnpack(t *testing.T) { continue } r := fn() - *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600} + *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16} off, err := PackRR(r, data, 0, nil, false) if err != nil { // Should always works, TestNoDataPack should have caught this @@ -513,23 +404,6 @@ func TestMsgCopy(t *testing.T) { } } -func BenchmarkCopy(b *testing.B) { - b.ReportAllocs() - m := new(Msg) - m.SetQuestion("miek.nl.", TypeA) - rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") - m.Answer = []RR{rr} - rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1") - m.Ns = []RR{rr} - rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1") - m.Extra = []RR{rr} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - m.Copy() - } -} - func TestPackIPSECKEY(t *testing.T) { tests := []string{ "38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )", diff --git a/dnssec.go b/dnssec.go index 84cb2142..94010497 100644 --- a/dnssec.go +++ b/dnssec.go @@ -144,7 +144,7 @@ func (k *DNSKEY) KeyTag() uint16 { // at the base64 values. But I'm lazy. modulus, _ := fromBase64([]byte(k.PublicKey)) if len(modulus) > 1 { - x, _ := unpackUint16(modulus, len(modulus)-2) + x, _ := unpackUint16Msg(modulus, len(modulus)-2) keytag = int(x) } default: diff --git a/edns.go b/edns.go index c1aebc11..5634cd70 100644 --- a/edns.go +++ b/edns.go @@ -213,7 +213,7 @@ func (e *EDNS0_SUBNET) Option() uint16 { func (e *EDNS0_SUBNET) pack() ([]byte, error) { b := make([]byte, 4) - b[0], b[1] = packUint16(e.Family) + b[0], b[1] = packUint16Msg(e.Family) b[2] = e.SourceNetmask b[3] = e.SourceScope switch e.Family { @@ -247,7 +247,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { if len(b) < 4 { return ErrBuf } - e.Family, _ = unpackUint16(b, 0) + e.Family, _ = unpackUint16Msg(b, 0) e.SourceNetmask = b[2] e.SourceScope = b[3] switch e.Family { @@ -369,9 +369,9 @@ func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ } func (e *EDNS0_LLQ) pack() ([]byte, error) { b := make([]byte, 18) - b[0], b[1] = packUint16(e.Version) - b[2], b[3] = packUint16(e.Opcode) - b[4], b[5] = packUint16(e.Error) + b[0], b[1] = packUint16Msg(e.Version) + b[2], b[3] = packUint16Msg(e.Opcode) + b[4], b[5] = packUint16Msg(e.Error) b[6] = byte(e.Id >> 56) b[7] = byte(e.Id >> 48) b[8] = byte(e.Id >> 40) @@ -391,9 +391,9 @@ func (e *EDNS0_LLQ) unpack(b []byte) error { if len(b) < 18 { return ErrBuf } - e.Version, _ = unpackUint16(b, 0) - e.Opcode, _ = unpackUint16(b, 2) - e.Error, _ = unpackUint16(b, 4) + e.Version, _ = unpackUint16Msg(b, 0) + e.Opcode, _ = unpackUint16Msg(b, 2) + e.Error, _ = unpackUint16Msg(b, 4) e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 | uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7]) e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3]) diff --git a/zgenerate.go b/generate.go similarity index 100% rename from zgenerate.go rename to generate.go diff --git a/labels.go b/labels.go index cb549fc6..fca5c7dd 100644 --- a/labels.go +++ b/labels.go @@ -107,7 +107,7 @@ func CountLabel(s string) (labels int) { // Split splits a name s into its label indexes. // www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}. -// The root name (.) returns nil. Also see SplitDomainName. +// The root name (.) returns nil. Also see SplitDomainName. // s must be a syntactically valid domain name. func Split(s string) []int { if s == "." { diff --git a/labels_test.go b/labels_test.go index 50f31163..536757d5 100644 --- a/labels_test.go +++ b/labels_test.go @@ -184,12 +184,14 @@ func BenchmarkLenLabels(b *testing.B) { } func BenchmarkCompareLabels(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { CompareDomainName("www.example.com", "aa.example.com") } } func BenchmarkIsSubDomain(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { IsSubDomain("www.example.com", "aa.example.com") IsSubDomain("example.com", "aa.example.com") diff --git a/msg.go b/msg.go index 41b6e0da..60463b6d 100644 --- a/msg.go +++ b/msg.go @@ -8,9 +8,9 @@ package dns +//go:generate go run msg_generate.go + import ( - "encoding/base32" - "encoding/base64" "encoding/hex" "math/big" "math/rand" @@ -92,18 +92,6 @@ type Msg struct { Extra []RR // Holds the RR(s) of the additional section. } -// StringToType is the reverse of TypeToString, needed for string parsing. -var StringToType = reverseInt16(TypeToString) - -// StringToClass is the reverse of ClassToString, needed for string parsing. -var StringToClass = reverseInt16(ClassToString) - -// Map of opcodes strings. -var StringToOpcode = reverseInt(OpcodeToString) - -// Map of rcodes strings. -var StringToRcode = reverseInt(RcodeToString) - // ClassToString is a maps Classes to strings for each CLASS wire type. var ClassToString = map[uint16]string{ ClassINET: "IN", @@ -291,11 +279,11 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c if pointer != -1 { // We have two bytes (14 bits) to put the pointer in // if msg == nil, we will never do compression - msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000)) + msg[nameoffset], msg[nameoffset+1] = packUint16Msg(uint16(pointer ^ 0xC000)) off = nameoffset + 1 goto End } - if msg != nil { + if msg != nil && off < len(msg) { msg[off] = 0 } End: @@ -423,7 +411,7 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) { func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { lenByteOffset := offset - if offset >= len(msg) { + if offset >= len(msg) || len(s) > len(tmp) { return offset, ErrBuf } offset++ @@ -465,7 +453,7 @@ func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { } func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) { - if offset >= len(msg) { + if offset >= len(msg) || len(s) > len(tmp) { return offset, ErrBuf } bs := tmp[:len(s)] @@ -600,9 +588,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str return lenmsg, &Error{err: "overflow packing opt"} } // Option code - msg[off], msg[off+1] = packUint16(element.(EDNS0).Option()) + msg[off], msg[off+1] = packUint16Msg(element.(EDNS0).Option()) // Length - msg[off+2], msg[off+3] = packUint16(uint16(len(b))) + msg[off+2], msg[off+3] = packUint16Msg(uint16(len(b))) off += 4 if off+len(b) > lenmsg { copy(msg[off:], b) @@ -783,6 +771,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str if e != nil { return lenmsg, e } + if off+len(b64) > lenmsg { + return lenmsg, &Error{err: "overflow packing base64"} + } copy(msg[off:off+len(b64)], b64) off += len(b64) case `dns:"domain-name"`: @@ -811,6 +802,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str if e != nil { return lenmsg, e } + if off+len(b32) > lenmsg { + return lenmsg, &Error{err: "overflow packing base32"} + } copy(msg[off:off+len(b32)], b32) off += len(b32) case `dns:"size-hex"`: @@ -827,6 +821,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str copy(msg[off:off+hex.DecodedLen(len(s))], h) off += hex.DecodedLen(len(s)) case `dns:"size"`: + // TODO(miek): WTF? size? // the size is already encoded in the RR, we can safely use the // length of string. String is RAW (not encoded in hex, nor base64) copy(msg[off:off+len(s)], s) @@ -930,8 +925,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+4 > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } - code, off = unpackUint16(msg, off) - optlen, off1 := unpackUint16(msg, off) + code, off = unpackUint16Msg(msg, off) + optlen, off1 := unpackUint16Msg(msg, off) if off1+int(optlen) > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } @@ -1174,7 +1169,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+2 > lenmsg { return lenmsg, &Error{err: "overflow unpacking uint16"} } - i, off = unpackUint16(msg, off) + i, off = unpackUint16Msg(msg, off) fv.SetUint(uint64(i)) case reflect.Uint32: if off == lenmsg { @@ -1334,38 +1329,6 @@ func intToBytes(i *big.Int, length int) []byte { return buf } -func unpackUint16(msg []byte, off int) (uint16, int) { - return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2 -} - -func packUint16(i uint16) (byte, byte) { - return byte(i >> 8), byte(i) -} - -func toBase32(b []byte) string { - return base32.HexEncoding.EncodeToString(b) -} - -func fromBase32(s []byte) (buf []byte, err error) { - buflen := base32.HexEncoding.DecodedLen(len(s)) - buf = make([]byte, buflen) - n, err := base32.HexEncoding.Decode(buf, s) - buf = buf[:n] - return -} - -func toBase64(b []byte) string { - return base64.StdEncoding.EncodeToString(b) -} - -func fromBase64(s []byte) (buf []byte, err error) { - buflen := base64.StdEncoding.DecodedLen(len(s)) - buf = make([]byte, buflen) - n, err := base64.StdEncoding.Decode(buf, s) - buf = buf[:n] - return -} - // PackRR packs a resource record rr into msg[off:]. // See PackDomainName for documentation about the compression. func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { @@ -1373,7 +1336,62 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo return len(msg), &Error{err: "nil rr"} } - off1, err = packStructCompress(rr, msg, off, compression, compress) + _, ok := typeToUnpack[rr.Header().Rrtype] + switch ok { + case true: + // Shortcut reflection, `pack' needs to be added to the RR interface so we can just do this: + // off1, err = t.pack(msg, off, compression, compress) + // TODO(miek): revert the logic and make a blacklist for types that still use reflection. Kill + // typeToUnpack and just generate all the pack and unpack functions even though we don't use + // them for all types (yet). + switch t := rr.(type) { + case *RR_Header: + // we can be called with an empty RR, consisting only out of the header, see update_test.go's + // TestDynamicUpdateZeroRdataUnpack for an example. This is OK as RR_Header also implements the RR interface. + off1, err = t.pack(msg, off, compression, compress) + case *ANY: + // Also "weird" setup, see (again) update_test.go's TestRemoveRRset, where the Rrtype is 1 but the type is *ANY. + off1, err = t.pack(msg, off, compression, compress) + case *A: + off1, err = t.pack(msg, off, compression, compress) + case *AAAA: + off1, err = t.pack(msg, off, compression, compress) + case *CNAME: + off1, err = t.pack(msg, off, compression, compress) + case *DNAME: + off1, err = t.pack(msg, off, compression, compress) + case *HINFO: + off1, err = t.pack(msg, off, compression, compress) + case *L32: + off1, err = t.pack(msg, off, compression, compress) + case *LOC: + off1, err = t.pack(msg, off, compression, compress) + case *MB: + off1, err = t.pack(msg, off, compression, compress) + case *MD: + off1, err = t.pack(msg, off, compression, compress) + case *MF: + off1, err = t.pack(msg, off, compression, compress) + case *MG: + off1, err = t.pack(msg, off, compression, compress) + case *MX: + off1, err = t.pack(msg, off, compression, compress) + case *NID: + off1, err = t.pack(msg, off, compression, compress) + case *NS: + off1, err = t.pack(msg, off, compression, compress) + case *PTR: + off1, err = t.pack(msg, off, compression, compress) + case *RP: + off1, err = t.pack(msg, off, compression, compress) + case *SRV: + off1, err = t.pack(msg, off, compression, compress) + case *DNSKEY: + off1, err = t.pack(msg, off, compression, compress) + } + default: + off1, err = packStructCompress(rr, msg, off, compression, compress) + } if err != nil { return len(msg), err } @@ -1385,21 +1403,27 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo // UnpackRR unpacks msg[off:] into an RR. func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { - // unpack just the header, to find the rr type and length - var h RR_Header off0 := off - if off, err = UnpackStruct(&h, msg, off); err != nil { + h, off, msg, err := unpackHeader(msg, off) + if err != nil { return nil, len(msg), err } end := off + int(h.Rdlength) - // make an rr of that type and re-unpack. - mk, known := TypeToRR[h.Rrtype] - if !known { - rr = new(RFC3597) - } else { - rr = mk() + + fn, ok := typeToUnpack[h.Rrtype] + switch ok { + case true: + // Shortcut reflection. + rr, off, err = fn(h, msg, off) + default: + mk, known := TypeToRR[h.Rrtype] + if !known { + rr = new(RFC3597) + } else { + rr = mk() + } + off, err = UnpackStruct(rr, msg, off0) } - off, err = UnpackStruct(rr, msg, off0) if off != end { return &h, end, &Error{err: "bad rdlength"} } @@ -1432,31 +1456,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) return dst, off, err } -// Reverse a map -func reverseInt8(m map[uint8]string) map[string]uint8 { - n := make(map[string]uint8) - for u, s := range m { - n[s] = u - } - return n -} - -func reverseInt16(m map[uint16]string) map[string]uint16 { - n := make(map[string]uint16) - for u, s := range m { - n[s] = u - } - return n -} - -func reverseInt(m map[int]string) map[string]int { - n := make(map[string]int) - for u, s := range m { - n[s] = u - } - return n -} - // Convert a MsgHdr to a string, with dig-like headers: // //;; opcode: QUERY, status: NOERROR, id: 48404 @@ -1510,8 +1509,11 @@ func (dns *Msg) Pack() (msg []byte, err error) { // PackBuffer packs a Msg, using the given buffer buf. If buf is too small // a new buffer is allocated. func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { - var dh Header - var compression map[string]int + var ( + dh Header + compression map[string]int + ) + if dns.Compress { compression = make(map[string]int) // Compression pointer mappings } @@ -1579,12 +1581,12 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { // Pack it in: header and then the pieces. off := 0 - off, err = packStructCompress(&dh, msg, off, compression, dns.Compress) + off, err = dh.pack(msg, off, compression, dns.Compress) if err != nil { return nil, err } for i := 0; i < len(question); i++ { - off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress) + off, err = question[i].pack(msg, off, compression, dns.Compress) if err != nil { return nil, err } @@ -1612,12 +1614,17 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { // Unpack unpacks a binary message to a Msg structure. func (dns *Msg) Unpack(msg []byte) (err error) { - // Header. - var dh Header - off := 0 - if off, err = UnpackStruct(&dh, msg, off); err != nil { + var ( + dh Header + off int + ) + if dh, off, err = unpackMsgHdr(msg, off); err != nil { return err } + if off == len(msg) { + return ErrTruncated + } + dns.Id = dh.Id dns.Response = (dh.Bits & _QR) != 0 dns.Opcode = int(dh.Bits>>11) & 0xF @@ -1633,10 +1640,10 @@ func (dns *Msg) Unpack(msg []byte) (err error) { // Optimistically use the count given to us in the header dns.Question = make([]Question, 0, int(dh.Qdcount)) - var q Question for i := 0; i < int(dh.Qdcount); i++ { off1 := off - off, err = UnpackStruct(&q, msg, off) + var q Question + q, off, err = unpackQuestion(msg, off) if err != nil { // Even if Truncated is set, we only will set ErrTruncated if we // actually got the questions @@ -1662,6 +1669,7 @@ func (dns *Msg) Unpack(msg []byte) (err error) { } // The header counts might have been wrong so we need to update it dh.Arcount = uint16(len(dns.Extra)) + if off != len(msg) { // TODO(miek) make this an error? // use PackOpt to let people tell how detailed the error reporting should be? @@ -1735,6 +1743,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Answer); i++ { + if dns.Answer[i] == nil { + continue + } l += dns.Answer[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name) @@ -1750,6 +1761,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Ns); i++ { + if dns.Ns[i] == nil { + continue + } l += dns.Ns[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name) @@ -1765,6 +1779,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Extra); i++ { + if dns.Extra[i] == nil { + continue + } l += dns.Extra[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name) @@ -1955,3 +1972,122 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg { return r1 } + +func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := PackDomainName(q.Name, msg, off, compression, compress) + if err != nil { + return off, err + } + off, err = packUint16(q.Qtype, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(q.Qclass, msg, off) + if err != nil { + return off, err + } + return off, nil +} + +func unpackQuestion(msg []byte, off int) (Question, int, error) { + var ( + q Question + err error + ) + q.Name, off, err = UnpackDomainName(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qtype, off, err = unpackUint16(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qclass, off, err = unpackUint16(msg, off) + if off == len(msg) { + return q, off, nil + } + return q, off, err +} + +func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := packUint16(dh.Id, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Bits, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Qdcount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Ancount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Nscount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Arcount, msg, off) + return off, err +} + +func unpackMsgHdr(msg []byte, off int) (Header, int, error) { + var ( + dh Header + err error + ) + dh.Id, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Bits, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Qdcount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Ancount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Nscount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Arcount, off, err = unpackUint16(msg, off) + return dh, off, err +} + +// Which types have type specific unpack functions. +var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){ + TypeAAAA: unpackAAAA, + TypeA: unpackA, + TypeCNAME: unpackCNAME, + TypeDNAME: unpackDNAME, + TypeL32: unpackL32, + TypeLOC: unpackLOC, + TypeMB: unpackMB, + TypeMD: unpackMD, + TypeMF: unpackMF, + TypeMG: unpackMG, + TypeMR: unpackMR, + TypeMX: unpackMX, + TypeNID: unpackNID, + TypeNS: unpackNS, + TypePTR: unpackPTR, + TypeRP: unpackRP, + TypeSRV: unpackSRV, + TypeHINFO: unpackHINFO, + TypeDNSKEY: unpackDNSKEY, +} diff --git a/msg_generate.go b/msg_generate.go new file mode 100644 index 00000000..2f7bc99f --- /dev/null +++ b/msg_generate.go @@ -0,0 +1,307 @@ +//+build ignore + +// msg_generate.go is meant to run with go generate. It will use +// go/{importer,types} to track down all the RR struct types. Then for each type +// it will generate pack/unpack methods based on the struct tags. The generated source is +// written to zmsg.go, and is meant to be checked into git. +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/importer" + "go/types" + "log" + "os" +) + +// All RR pack and unpack functions should be generated, currently RR that present some +// problems +// * NSEC/NSEC3 - type bitmap +// * TXT/SPF - string slice +// * URI - weird octet thing there +// * NSEC3/TSIG - size hex +// * OPT RR - EDNS0 parsing - needs to some looking at +// * HIP - uses "hex", but is actually size-hex - might drop size-hex? +// * Z +// * WKS - uint16 slice +// * NINFO +// * PrivateRR + +// What types are we generating, should be kept in sync with typeToUnpack in msg.go +var generate = map[string]bool{ + "AAAA": true, + "ANY": true, + "A": true, + "CNAME": true, + "DNAME": true, + "DNSKEY": true, + "HINFO": true, + "L32": true, + "LOC": true, + "MB": true, + "MD": true, + "MF": true, + "MG": true, + "MR": true, + "MX": true, + "NID": true, + "NS": true, + "PTR": true, + "RP": true, + "SRV": true, +} + +func shouldGenerate(name string) bool { + _, ok := generate[name] + return ok +} + +// For later: IPSECKEY is weird. + +var packageHdr = ` +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from msg_generate.go + +package dns + +` + +// getTypeStruct will take a type and the package scope, and return the +// (innermost) struct if the type is considered a RR type (currently defined as +// those structs beginning with a RR_Header, could be redefined as implementing +// the RR interface). The bool return value indicates if embedded structs were +// resolved. +func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { + st, ok := t.Underlying().(*types.Struct) + if !ok { + return nil, false + } + if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { + return st, false + } + if st.Field(0).Anonymous() { + st, _ := getTypeStruct(st.Field(0).Type(), scope) + return st, true + } + return nil, false +} + +func main() { + // Import and type-check the package + pkg, err := importer.Default().Import("github.com/miekg/dns") + fatalIfErr(err) + scope := pkg.Scope() + + // Collect actual types (*X) + var namedTypes []string + for _, name := range scope.Names() { + o := scope.Lookup(name) + if o == nil || !o.Exported() { + continue + } + if st, _ := getTypeStruct(o.Type(), scope); st == nil { + continue + } + if name == "PrivateRR" { + continue + } + + // Check if corresponding TypeX exists + if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { + log.Fatalf("Constant Type%s does not exist.", o.Name()) + } + + namedTypes = append(namedTypes, o.Name()) + } + + b := &bytes.Buffer{} + b.WriteString(packageHdr) + + fmt.Fprint(b, "// pack*() functions\n\n") + for _, name := range namedTypes { + o := scope.Lookup(name) + st, isEmbedded := getTypeStruct(o.Type(), scope) + if isEmbedded || !shouldGenerate(name) { + continue + } + + fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) + fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) +if err != nil { + return off, err +} +headerEnd := off +`) + for i := 1; i < st.NumFields(); i++ { + o := func(s string) { + fmt.Fprintf(b, s, st.Field(i).Name()) + fmt.Fprint(b, `if err != nil { +return off, err +} +`) + } + + //if _, ok := st.Field(i).Type().(*types.Slice); ok { + //switch st.Tag(i) { + //case `dns:"-"`: + //// ignored + //case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: + //o("for _, x := range rr.%s { l += len(x) + 1 }\n") + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + //} + //continue + //} + + switch st.Tag(i) { + case `dns:"-"`: + // ignored + case `dns:"cdomain-name"`: + fallthrough + case `dns:"domain-name"`: + o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") + case `dns:"a"`: + o("off, err = packDataA(rr.%s, msg, off)\n") + case `dns:"aaaa"`: + o("off, err = packDataAAAA(rr.%s, msg, off)\n") + case `dns:"uint48"`: + o("off, err = packUint48(rr.%s, msg, off)\n") + case `dns:"txt"`: + o("off, err = packString(rr.%s, msg, off)\n") + case `dns:"base32"`: + o("off, err = packStringBase32(rr.%s, msg, off)\n") + case `dns:"base64"`: + o("off, err = packStringBase64(rr.%s, msg, off)\n") + case "": + switch st.Field(i).Type().(*types.Basic).Kind() { + case types.Uint8: + o("off, err = packUint8(rr.%s, msg, off)\n") + case types.Uint16: + o("off, err = packUint16(rr.%s, msg, off)\n") + case types.Uint32: + o("off, err = packUint32(rr.%s, msg, off)\n") + case types.Uint64: + o("off, err = packUint64(rr.%s, msg, off)\n") + case types.String: + o("off, err = packString(rr.%s, msg, off)\n") + default: + log.Fatalln(name, st.Field(i).Name()) + } + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + } + } + // We have packed everything, only now we know the rdlength of this RR + fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off- headerEnd)") + fmt.Fprintln(b, "return off, nil }\n") + } + + fmt.Fprint(b, "// unpack*() functions\n\n") + for _, name := range namedTypes { + o := scope.Lookup(name) + st, isEmbedded := getTypeStruct(o.Type(), scope) + if isEmbedded || !shouldGenerate(name) { + continue + } + + fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) + fmt.Fprint(b, `if noRdata(h) { +return nil, off, nil + } +var err error +rdStart := off +_ = rdStart + +`) + fmt.Fprintf(b, "rr := new(%s)\n", name) + fmt.Fprintln(b, "rr.Hdr = h\n") + for i := 1; i < st.NumFields(); i++ { + o := func(s string) { + fmt.Fprintf(b, s, st.Field(i).Name()) + fmt.Fprint(b, `if err != nil { +return rr, off, err +} +`) + } + + //if _, ok := st.Field(i).Type().(*types.Slice); ok { + //switch st.Tag(i) { + //case `dns:"-"`: + //// ignored + //case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: + //o("for _, x := range rr.%s { l += len(x) + 1 }\n") + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + //} + //continue + //} + + switch st.Tag(i) { + case `dns:"-"`: + // ignored + case `dns:"cdomain-name"`: + fallthrough + case `dns:"domain-name"`: + o("rr.%s, off, err = UnpackDomainName(msg, off)\n") + case `dns:"a"`: + o("rr.%s, off, err = unpackDataA(msg, off)\n") + case `dns:"aaaa"`: + o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") + case `dns:"uint48"`: + o("rr.%s, off, err = unpackUint48(msg, off)\n") + case `dns:"txt"`: + o("rr.%s, off, err = unpackString(msg, off)\n") + case `dns:"base32"`: + o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") + case `dns:"base64"`: + o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") + case "": + switch st.Field(i).Type().(*types.Basic).Kind() { + case types.Uint8: + o("rr.%s, off, err = unpackUint8(msg, off)\n") + case types.Uint16: + o("rr.%s, off, err = unpackUint16(msg, off)\n") + case types.Uint32: + o("rr.%s, off, err = unpackUint32(msg, off)\n") + case types.Uint64: + o("rr.%s, off, err = unpackUint64(msg, off)\n") + case types.String: + o("rr.%s, off, err = unpackString(msg, off)\n") + default: + log.Fatalln(name, st.Field(i).Name()) + } + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + } + // If we've hit len(msg) we return without error. + if i < st.NumFields()-1 { + fmt.Fprintf(b, `if off == len(msg) { +return rr, off, nil + } +`) + } + } + fmt.Fprintf(b, "return rr, off, err }\n\n") + } + + // gofmt + res, err := format.Source(b.Bytes()) + if err != nil { + b.WriteTo(os.Stderr) + log.Fatal(err) + } + + // write result + f, err := os.Create("zmsg.go") + fatalIfErr(err) + defer f.Close() + f.Write(res) +} + +func fatalIfErr(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/msg_helpers.go b/msg_helpers.go new file mode 100644 index 00000000..06eea492 --- /dev/null +++ b/msg_helpers.go @@ -0,0 +1,399 @@ +package dns + +import ( + "encoding/base32" + "encoding/base64" + "encoding/hex" + "net" + "strconv" +) + +// helper functions called from the generated zmsg.go + +// These function are named after the tag to help pack/unpack, if there is no tag it is the name +// of the type they pack/unpack (string, int, etc). We prefix all with unpackData or packData, so packDataA or +// packDataDomainName. + +func unpackDataA(msg []byte, off int) (net.IP, int, error) { + if off+net.IPv4len > len(msg) { + return nil, len(msg), &Error{err: "overflow unpacking a"} + } + a := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) + off += net.IPv4len + return a, off, nil +} + +func packDataA(a net.IP, msg []byte, off int) (int, error) { + // It must be a slice of 4, even if it is 16, we encode only the first 4 + if off+net.IPv4len > len(msg) { + return len(msg), &Error{err: "overflow packing a"} + } + switch len(a) { + case net.IPv6len: + msg[off] = a[12] + msg[off+1] = a[13] + msg[off+2] = a[14] + msg[off+3] = a[15] + off += net.IPv4len + case net.IPv4len: + msg[off] = a[0] + msg[off+1] = a[1] + msg[off+2] = a[2] + msg[off+3] = a[3] + off += net.IPv4len + case 0: + // Allowed, for dynamic updates. + default: + return len(msg), &Error{err: "overflow packing a"} + } + return off, nil +} + +func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) { + if off+net.IPv6len > len(msg) { + return nil, len(msg), &Error{err: "overflow unpacking aaaa"} + } + aaaa := net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], + msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10], + msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]} + off += net.IPv6len + return aaaa, off, nil +} + +func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) { + if off+net.IPv6len > len(msg) { + return len(msg), &Error{err: "overflow packing aaaa"} + } + + switch len(aaaa) { + case net.IPv6len: + msg[off] = aaaa[0] + msg[off+1] = aaaa[1] + msg[off+2] = aaaa[2] + msg[off+3] = aaaa[3] + msg[off+4] = aaaa[4] + msg[off+5] = aaaa[5] + msg[off+6] = aaaa[6] + msg[off+7] = aaaa[7] + msg[off+8] = aaaa[8] + msg[off+9] = aaaa[9] + msg[off+10] = aaaa[10] + msg[off+11] = aaaa[11] + msg[off+12] = aaaa[12] + msg[off+13] = aaaa[13] + msg[off+14] = aaaa[14] + msg[off+15] = aaaa[15] + off += net.IPv6len + case 0: + // Allowed, dynamic updates. + default: + return len(msg), &Error{err: "overflow packing aaaa"} + } + return off, nil +} + +// unpackHeader unpacks an RR header, returning the offset to the end of the header and a +// re-sliced msg according to the expected length of the RR. +func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) { + hdr := RR_Header{} + if off == len(msg) { + return hdr, off, msg, nil + } + + hdr.Name, off, err = UnpackDomainName(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Rrtype, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Class, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Ttl, off, err = unpackUint32(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Rdlength, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength) + return hdr, off, msg, nil +} + +// pack packs an RR header, returning the offset to the end of the header. +// See PackDomainName for documentation about the compression. +func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + if off == len(msg) { + return off, nil + } + + off, err = PackDomainName(hdr.Name, msg, off, compression, compress) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Rrtype, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Class, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint32(hdr.Ttl, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Rdlength, msg, off) + if err != nil { + return len(msg), err + } + return off, nil +} + +// helper helper functions. + +// truncateMsgFromRdLength truncates msg to match the expected length of the RR. +// Returns an error if msg is smaller than the expected size. +func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) { + lenrd := off + int(rdlength) + if lenrd > len(msg) { + return msg, &Error{err: "overflowing header size"} + } + return msg[:lenrd], nil +} + +func fromBase32(s []byte) (buf []byte, err error) { + buflen := base32.HexEncoding.DecodedLen(len(s)) + buf = make([]byte, buflen) + n, err := base32.HexEncoding.Decode(buf, s) + buf = buf[:n] + return +} + +func toBase32(b []byte) string { return base32.HexEncoding.EncodeToString(b) } + +func fromBase64(s []byte) (buf []byte, err error) { + buflen := base64.StdEncoding.DecodedLen(len(s)) + buf = make([]byte, buflen) + n, err := base64.StdEncoding.Decode(buf, s) + buf = buf[:n] + return +} + +func toBase64(b []byte) string { return base64.StdEncoding.EncodeToString(b) } + +func unpackUint16Msg(msg []byte, off int) (uint16, int) { + return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2 +} + +func packUint16Msg(i uint16) (byte, byte) { return byte(i >> 8), byte(i) } + +func unpackUint32Msg(msg []byte, off int) (uint32, int) { + return uint32(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))), off + 4 +} + +func packUint32Msg(i uint32) (byte, byte, byte, byte) { + return byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i) +} + +// dynamicUpdate returns true if the Rdlength is zero. +func noRdata(h RR_Header) bool { return h.Rdlength == 0 } + +func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) { + if off+1 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint8"} + } + return uint8(msg[off]), off + 1, nil +} + +func packUint8(i uint8, msg []byte, off int) (off1 int, err error) { + if off+1 > len(msg) { + return len(msg), &Error{err: "overflow packing uint8"} + } + msg[off] = byte(i) + return off + 1, nil +} + +func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) { + if off+2 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint16"} + } + i, off = unpackUint16Msg(msg, off) + return i, off, nil +} + +func packUint16(i uint16, msg []byte, off int) (off1 int, err error) { + if off+2 > len(msg) { + return len(msg), &Error{err: "overflow packing uint16"} + } + msg[off], msg[off+1] = packUint16Msg(i) + return off + 2, nil +} + +func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) { + if off+4 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint32"} + } + i, off = unpackUint32Msg(msg, off) + return i, off, nil +} + +func packUint32(i uint32, msg []byte, off int) (off1 int, err error) { + if off+4 > len(msg) { + return len(msg), &Error{err: "overflow packing uint32"} + } + msg[off], msg[off+1], msg[off+2], msg[off+3] = packUint32Msg(i) + return off + 4, nil +} + +func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) { + if off+6 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"} + } + // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) + i = (uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | + uint64(msg[off+4])<<8 | uint64(msg[off+5]))) + off += 6 + return i, off, nil +} + +func packUint48(i uint64, msg []byte, off int) (off1 int, err error) { + if off+6 > len(msg) { + return len(msg), &Error{err: "overflow packing uint64 as uint48"} + } + msg[off] = byte(i >> 40) + msg[off+1] = byte(i >> 32) + msg[off+2] = byte(i >> 24) + msg[off+3] = byte(i >> 16) + msg[off+4] = byte(i >> 8) + msg[off+5] = byte(i) + off += 6 + return off, nil +} + +func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) { + if off+8 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint64"} + } + i = (uint64(uint64(msg[off])<<56 | uint64(msg[off+1])<<48 | uint64(msg[off+2])<<40 | + uint64(msg[off+3])<<32 | uint64(msg[off+4])<<24 | uint64(msg[off+5])<<16 | uint64(msg[off+6])<<8 | uint64(msg[off+7]))) + off += 8 + return i, off, nil +} + +func packUint64(i uint64, msg []byte, off int) (off1 int, err error) { + if off+8 > len(msg) { + return len(msg), &Error{err: "overflow packing uint64"} + } + msg[off] = byte(i >> 56) + msg[off+1] = byte(i >> 48) + msg[off+2] = byte(i >> 40) + msg[off+3] = byte(i >> 32) + msg[off+4] = byte(i >> 24) + msg[off+5] = byte(i >> 16) + msg[off+6] = byte(i >> 8) + msg[off+7] = byte(i) + off += 8 + return off, nil +} + +func unpackString(msg []byte, off int) (string, int, error) { + if off+1 > len(msg) { + return "", off, &Error{err: "overflow unpacking txt"} + } + l := int(msg[off]) + if off+l+1 > len(msg) { + return "", off, &Error{err: "overflow unpacking txt"} + } + s := make([]byte, 0, l) + for _, b := range msg[off+1 : off+1+l] { + switch b { + case '"', '\\': + s = append(s, '\\', b) + case '\t': + s = append(s, '\t') + case '\r': + s = append(s, '\r') + case '\n': + s = append(s, '\n') + default: + if b < 32 || b > 127 { // unprintable + var buf [3]byte + bufs := strconv.AppendInt(buf[:0], int64(b), 10) + s = append(s, '\\') + for i := 0; i < 3-len(bufs); i++ { + s = append(s, '0') + } + for _, r := range bufs { + s = append(s, r) + } + } else { + s = append(s, b) + } + } + } + off += 1 + l + return string(s), off, nil +} + +func packString(s string, msg []byte, off int) (int, error) { + txtTmp := make([]byte, 256*4+1) + off, err := packTxtString(s, msg, off, txtTmp) + if err != nil { + return len(msg), err + } + return off, nil +} + +func unpackStringBase64(msg []byte, off, end int) (string, int, error) { + // Rest of the RR is base64 encoded value, so we don't need an explicit length + // to be set. Thus far all RR's that have base64 encoded fields have those as their + // last one. What we do need is the end of the RR! + if end > len(msg) { + return "", len(msg), &Error{err: "overflow unpacking base64"} + } + s := toBase64(msg[off:end]) + return s, end, nil +} + +func packStringBase64(s string, msg []byte, off int) (int, error) { + b64, e := fromBase64([]byte(s)) + if e != nil { + return len(msg), e + } + if off+len(b64) > len(msg) { + return len(msg), &Error{err: "overflow packing base64"} + } + copy(msg[off:off+len(b64)], b64) + off += len(b64) + return off, nil +} + +func unpackStringHex(msg []byte, off, end int) (string, int, error) { + // Rest of the RR is hex encoded value, so we don't need an explicit length + // to be set. NSEC and TSIG have hex fields with a length field. + // What we do need is the end of the RR! + if end > len(msg) { + return "", len(msg), &Error{err: "overflow unpacking hex"} + } + + s := hex.EncodeToString(msg[off:end]) + return s, end, nil +} + +func packStringHex(s string, msg []byte, off int) (int, error) { + h, e := hex.DecodeString(s) + if e != nil { + return len(msg), e + } + if off+(len(h)) > len(msg) { + return len(msg), &Error{err: "overflow packing hex"} + } + copy(msg[off:off+len(h)], h) + off += len(h) + return off, nil +} diff --git a/rawmsg.go b/rawmsg.go index b4a706b9..e4e5374d 100644 --- a/rawmsg.go +++ b/rawmsg.go @@ -8,7 +8,7 @@ func rawSetId(msg []byte, i uint16) bool { if len(msg) < 2 { return false } - msg[0], msg[1] = packUint16(i) + msg[0], msg[1] = packUint16Msg(i) return true } @@ -17,7 +17,7 @@ func rawSetQuestionLen(msg []byte, i uint16) bool { if len(msg) < 6 { return false } - msg[4], msg[5] = packUint16(i) + msg[4], msg[5] = packUint16Msg(i) return true } @@ -26,7 +26,7 @@ func rawSetAnswerLen(msg []byte, i uint16) bool { if len(msg) < 8 { return false } - msg[6], msg[7] = packUint16(i) + msg[6], msg[7] = packUint16Msg(i) return true } @@ -35,7 +35,7 @@ func rawSetNsLen(msg []byte, i uint16) bool { if len(msg) < 10 { return false } - msg[8], msg[9] = packUint16(i) + msg[8], msg[9] = packUint16Msg(i) return true } @@ -44,7 +44,7 @@ func rawSetExtraLen(msg []byte, i uint16) bool { if len(msg) < 12 { return false } - msg[10], msg[11] = packUint16(i) + msg[10], msg[11] = packUint16Msg(i) return true } @@ -90,6 +90,6 @@ Loop: if rdatalen > 0xFFFF { return false } - msg[off], msg[off+1] = packUint16(uint16(rdatalen)) + msg[off], msg[off+1] = packUint16Msg(uint16(rdatalen)) return true } diff --git a/reverse.go b/reverse.go new file mode 100644 index 00000000..099dac94 --- /dev/null +++ b/reverse.go @@ -0,0 +1,38 @@ +package dns + +// StringToType is the reverse of TypeToString, needed for string parsing. +var StringToType = reverseInt16(TypeToString) + +// StringToClass is the reverse of ClassToString, needed for string parsing. +var StringToClass = reverseInt16(ClassToString) + +// Map of opcodes strings. +var StringToOpcode = reverseInt(OpcodeToString) + +// Map of rcodes strings. +var StringToRcode = reverseInt(RcodeToString) + +// Reverse a map +func reverseInt8(m map[uint8]string) map[string]uint8 { + n := make(map[string]uint8, len(m)) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt16(m map[uint16]string) map[string]uint16 { + n := make(map[string]uint16, len(m)) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt(m map[int]string) map[string]int { + n := make(map[string]int, len(m)) + for u, s := range m { + n[s] = u + } + return n +} diff --git a/zscan.go b/scan.go similarity index 100% rename from zscan.go rename to scan.go diff --git a/zscan_rr.go b/scan_rr.go similarity index 100% rename from zscan_rr.go rename to scan_rr.go diff --git a/server.go b/server.go index edc5c625..158cd3b7 100644 --- a/server.go +++ b/server.go @@ -615,7 +615,7 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } return nil, ErrShortRead } - length, _ := unpackUint16(l, 0) + length, _ := unpackUint16Msg(l, 0) if length == 0 { return nil, ErrShortRead } @@ -690,7 +690,7 @@ func (w *response) Write(m []byte) (int, error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, 2+lm) - l[0], l[1] = packUint16(uint16(lm)) + l[0], l[1] = packUint16Msg(uint16(lm)) m = append(l, m...) n, err := io.Copy(w.tcp, bytes.NewReader(m)) diff --git a/sig0.go b/sig0.go index 0fccddbc..ea3e5b67 100644 --- a/sig0.go +++ b/sig0.go @@ -67,13 +67,13 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { } // Adjust sig data length rdoff := len(mbuf) + 1 + 2 + 2 + 4 - rdlen, _ := unpackUint16(buf, rdoff) + rdlen, _ := unpackUint16Msg(buf, rdoff) rdlen += uint16(len(sig)) - buf[rdoff], buf[rdoff+1] = packUint16(rdlen) + buf[rdoff], buf[rdoff+1] = packUint16Msg(rdlen) // Adjust additional count - adc, _ := unpackUint16(buf, 10) + adc, _ := unpackUint16Msg(buf, 10) adc++ - buf[10], buf[11] = packUint16(adc) + buf[10], buf[11] = packUint16Msg(adc) return buf, nil } @@ -103,10 +103,10 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { hasher := hash.New() buflen := len(buf) - qdc, _ := unpackUint16(buf, 4) - anc, _ := unpackUint16(buf, 6) - auc, _ := unpackUint16(buf, 8) - adc, offset := unpackUint16(buf, 10) + qdc, _ := unpackUint16Msg(buf, 4) + anc, _ := unpackUint16Msg(buf, 6) + auc, _ := unpackUint16Msg(buf, 8) + adc, offset := unpackUint16Msg(buf, 10) var err error for i := uint16(0); i < qdc && offset < buflen; i++ { _, offset, err = UnpackDomainName(buf, offset) @@ -127,7 +127,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { continue } var rdlen uint16 - rdlen, offset = unpackUint16(buf, offset) + rdlen, offset = unpackUint16Msg(buf, offset) offset += int(rdlen) } if offset >= buflen { diff --git a/tsig.go b/tsig.go index c3374e19..7a089ba2 100644 --- a/tsig.go +++ b/tsig.go @@ -301,8 +301,8 @@ func stripTsig(msg []byte) ([]byte, *TSIG, error) { if dns.Extra[i].Header().Rrtype == TypeTSIG { rr = dns.Extra[i].(*TSIG) // Adjust Arcount. - arcount, _ := unpackUint16(msg, 10) - msg[10], msg[11] = packUint16(arcount - 1) + arcount, _ := unpackUint16Msg(msg, 10) + msg[10], msg[11] = packUint16Msg(arcount - 1) break } } diff --git a/types_generate.go b/types_generate.go index 63bfda0e..4df136ef 100644 --- a/types_generate.go +++ b/types_generate.go @@ -29,7 +29,7 @@ var skipLen = map[string]struct{}{ var packageHdr = ` // *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate +// AUTOGENERATED BY go generate from type_generate.go package dns diff --git a/update_test.go b/update_test.go index 175c73b1..56602dfe 100644 --- a/update_test.go +++ b/update_test.go @@ -77,7 +77,7 @@ func TestRemoveRRset(t *testing.T) { if !bytes.Equal(actual, expect) { tmp := new(Msg) if err := tmp.Unpack(actual); err != nil { - t.Fatalf("error unpacking actual msg: %v", err) + t.Fatalf("error unpacking actual msg: %v\nexpected: %v\ngot: %v\n", err, expect, actual) } t.Errorf("expected msg:\n%s", expectstr) t.Errorf("actual msg:\n%v", tmp) diff --git a/zmsg.go b/zmsg.go new file mode 100644 index 00000000..7855d406 --- /dev/null +++ b/zmsg.go @@ -0,0 +1,827 @@ +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from msg_generate.go + +package dns + +// pack*() functions + +func (rr *A) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packDataA(rr.A, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *AAAA) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packDataAAAA(rr.AAAA, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *ANY) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *CNAME) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *DNAME) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *DNSKEY) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Flags, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Protocol, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Algorithm, msg, off) + if err != nil { + return off, err + } + off, err = packStringBase64(rr.PublicKey, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *HINFO) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packString(rr.Cpu, msg, off) + if err != nil { + return off, err + } + off, err = packString(rr.Os, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *L32) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = packDataA(rr.Locator32, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *LOC) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint8(rr.Version, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Size, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.HorizPre, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.VertPre, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Latitude, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Longitude, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Altitude, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MB) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mb, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MD) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Md, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MF) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mf, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MG) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mg, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mr, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MX) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Mx, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *NID) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = packUint64(rr.NodeID, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *NS) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Ns, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *PTR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Ptr, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *RP) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mbox, msg, off, compression, compress) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Txt, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *SRV) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Priority, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(rr.Weight, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(rr.Port, msg, off) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +// unpack*() functions + +func unpackA(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(A) + rr.Hdr = h + + rr.A, off, err = unpackDataA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackAAAA(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(AAAA) + rr.Hdr = h + + rr.AAAA, off, err = unpackDataAAAA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackANY(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(ANY) + rr.Hdr = h + + return rr, off, err +} + +func unpackCNAME(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(CNAME) + rr.Hdr = h + + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackDNAME(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(DNAME) + rr.Hdr = h + + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackDNSKEY(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(DNSKEY) + rr.Hdr = h + + rr.Flags, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Protocol, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Algorithm, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.PublicKey, off, err = unpackStringBase64(msg, off, rdStart+int(rr.Hdr.Rdlength)) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackHINFO(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(HINFO) + rr.Hdr = h + + rr.Cpu, off, err = unpackString(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Os, off, err = unpackString(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackL32(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(L32) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Locator32, off, err = unpackDataA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackLOC(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(LOC) + rr.Hdr = h + + rr.Version, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Size, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.HorizPre, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.VertPre, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Latitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Longitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Altitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMB(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MB) + rr.Hdr = h + + rr.Mb, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMD(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MD) + rr.Hdr = h + + rr.Md, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMF(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MF) + rr.Hdr = h + + rr.Mf, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMG(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MG) + rr.Hdr = h + + rr.Mg, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMR(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MR) + rr.Hdr = h + + rr.Mr, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMX(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MX) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Mx, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackNID(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(NID) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.NodeID, off, err = unpackUint64(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackNS(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(NS) + rr.Hdr = h + + rr.Ns, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackPTR(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(PTR) + rr.Hdr = h + + rr.Ptr, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackRP(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(RP) + rr.Hdr = h + + rr.Mbox, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Txt, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackSRV(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(SRV) + rr.Hdr = h + + rr.Priority, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Weight, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Port, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} diff --git a/ztypes.go b/ztypes.go index 3d0f9aef..858272bc 100644 --- a/ztypes.go +++ b/ztypes.go @@ -1,5 +1,5 @@ // *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate +// AUTOGENERATED BY go generate from type_generate.go package dns