diff --git a/README.md b/README.md index 30152cff..8fada255 100644 --- a/README.md +++ b/README.md @@ -150,11 +150,3 @@ Example programs can be found in the `github.com/miekg/exdns` repository. * `NSD` * `Net::DNS` * `GRONG` - -## TODO - -* privatekey.Precompute() when signing? -* Last remaining RRs: APL, ATMA, A6, NSAP and NXT. -* Missing in parsing: ISDN, UNSPEC, NSAP and ATMA. -* NSEC(3) cover/match/closest enclose. -* Replies with TC bit are not parsed to the end. diff --git a/client.go b/client.go index c1a4a430..7868cac2 100644 --- a/client.go +++ b/client.go @@ -300,7 +300,7 @@ func tcpMsgLen(t io.Reader) (int, error) { if n != 2 { return 0, ErrShortRead } - l, _ := unpackUint16(p, 0) + l, _ := unpackUint16Msg(p, 0) if l == 0 { return 0, ErrShortRead } @@ -392,7 +392,7 @@ func (co *Conn) Write(p []byte) (n int, err error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, lp+2) - l[0], l[1] = packUint16(uint16(lp)) + l[0], l[1] = packUint16Msg(uint16(lp)) p = append(l, p...) n, err := io.Copy(w, bytes.NewReader(p)) return int(n), err diff --git a/dns_bench_test.go b/dns_bench_test.go new file mode 100644 index 00000000..de930087 --- /dev/null +++ b/dns_bench_test.go @@ -0,0 +1,205 @@ +package dns + +import ( + "net" + "testing" +) + +func BenchmarkMsgLength(b *testing.B) { + b.StopTimer() + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.StartTimer() + for i := 0; i < b.N; i++ { + msg.Len() + } +} + +func BenchmarkMsgLengthPack(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = msg.Pack() + } +} + +func BenchmarkPackDomainName(b *testing.B) { + name1 := "12345678901234567890123456789012345.12345678.123." + buf := make([]byte, len(name1)+1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackDomainName(name1, buf, 0, nil, false) + } +} + +func BenchmarkUnpackDomainName(b *testing.B) { + name1 := "12345678901234567890123456789012345.12345678.123." + buf := make([]byte, len(name1)+1) + _, _ = PackDomainName(name1, buf, 0, nil, false) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackDomainName(buf, 0) + } +} + +func BenchmarkUnpackDomainNameUnprintable(b *testing.B) { + name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123." + buf := make([]byte, len(name1)+1) + _, _ = PackDomainName(name1, buf, 0, nil, false) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackDomainName(buf, 0) + } +} + +func BenchmarkCopy(b *testing.B) { + b.ReportAllocs() + m := new(Msg) + m.SetQuestion("miek.nl.", TypeA) + rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") + m.Answer = []RR{rr} + rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1") + m.Ns = []RR{rr} + rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1") + m.Extra = []RR{rr} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Copy() + } +} + +func BenchmarkPackA(b *testing.B) { + a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)} + + buf := make([]byte, a.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(a, buf, 0, nil, false) + } +} + +func BenchmarkUnpackA(b *testing.B) { + a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)} + + buf := make([]byte, a.len()) + PackRR(a, buf, 0, nil, false) + a = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackMX(b *testing.B) { + m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."} + + buf := make([]byte, m.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(m, buf, 0, nil, false) + } +} + +func BenchmarkUnpackMX(b *testing.B) { + m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."} + + buf := make([]byte, m.len()) + PackRR(m, buf, 0, nil, false) + m = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackAAAAA(b *testing.B) { + aaaa, _ := NewRR(". IN A ::1") + + buf := make([]byte, aaaa.len()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = PackRR(aaaa, buf, 0, nil, false) + } +} + +func BenchmarkUnpackAAAA(b *testing.B) { + aaaa, _ := NewRR(". IN A ::1") + + buf := make([]byte, aaaa.len()) + PackRR(aaaa, buf, 0, nil, false) + aaaa = nil + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = UnpackRR(buf, 0) + } +} + +func BenchmarkPackMsg(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + buf := make([]byte, 512) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = msg.PackBuffer(buf) + } +} + +func BenchmarkUnpackMsg(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + name1 := "12345678901234567890123456789012345.12345678.123." + rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) + msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) + msgBuf, _ := msg.Pack() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = msg.Unpack(msgBuf) + } +} diff --git a/dns_test.go b/dns_test.go index 20b33f0c..a709cce0 100644 --- a/dns_test.go +++ b/dns_test.go @@ -274,9 +274,9 @@ func TestMsgLength2(t *testing.T) { for i, hexData := range testMessages { // we won't fail the decoding of the hex input, _ := hex.DecodeString(hexData) + m := new(Msg) m.Unpack(input) - //println(m.String()) m.Compress = true lenComp := m.Len() b, _ := m.Pack() @@ -310,114 +310,6 @@ func TestMsgLengthCompressionMalformed(t *testing.T) { m.Len() // Should not crash. } -func BenchmarkMsgLength(b *testing.B) { - b.StopTimer() - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - b.StartTimer() - for i := 0; i < b.N; i++ { - msg.Len() - } -} - -func BenchmarkMsgLengthPack(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = msg.Pack() - } -} - -func BenchmarkMsgPackBuffer(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - buf := make([]byte, 512) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = msg.PackBuffer(buf) - } -} - -func BenchmarkMsgUnpack(b *testing.B) { - makeMsg := func(question string, ans, ns, e []RR) *Msg { - msg := new(Msg) - msg.SetQuestion(Fqdn(question), TypeANY) - msg.Answer = append(msg.Answer, ans...) - msg.Ns = append(msg.Ns, ns...) - msg.Extra = append(msg.Extra, e...) - msg.Compress = true - return msg - } - name1 := "12345678901234567890123456789012345.12345678.123." - rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1) - msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil) - msgBuf, _ := msg.Pack() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = msg.Unpack(msgBuf) - } -} - -func BenchmarkPackDomainName(b *testing.B) { - name1 := "12345678901234567890123456789012345.12345678.123." - buf := make([]byte, len(name1)+1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = PackDomainName(name1, buf, 0, nil, false) - } -} - -func BenchmarkUnpackDomainName(b *testing.B) { - name1 := "12345678901234567890123456789012345.12345678.123." - buf := make([]byte, len(name1)+1) - _, _ = PackDomainName(name1, buf, 0, nil, false) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = UnpackDomainName(buf, 0) - } -} - -func BenchmarkUnpackDomainNameUnprintable(b *testing.B) { - name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123." - buf := make([]byte, len(name1)+1) - _, _ = PackDomainName(name1, buf, 0, nil, false) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = UnpackDomainName(buf, 0) - } -} - func TestToRFC3597(t *testing.T) { a, _ := NewRR("miek.nl. IN A 10.0.1.1") x := new(RFC3597) @@ -431,7 +323,7 @@ func TestNoRdataPack(t *testing.T) { data := make([]byte, 1024) for typ, fn := range TypeToRR { r := fn() - *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600} + *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16} _, err := PackRR(r, data, 0, nil, false) if err != nil { t.Errorf("failed to pack RR with zero rdata: %s: %v", TypeToString[typ], err) @@ -439,7 +331,6 @@ func TestNoRdataPack(t *testing.T) { } } -// TODO(miek): fix dns buffer too small errors this throws func TestNoRdataUnpack(t *testing.T) { data := make([]byte, 1024) for typ, fn := range TypeToRR { @@ -449,7 +340,7 @@ func TestNoRdataUnpack(t *testing.T) { continue } r := fn() - *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600} + *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16} off, err := PackRR(r, data, 0, nil, false) if err != nil { // Should always works, TestNoDataPack should have caught this @@ -513,23 +404,6 @@ func TestMsgCopy(t *testing.T) { } } -func BenchmarkCopy(b *testing.B) { - b.ReportAllocs() - m := new(Msg) - m.SetQuestion("miek.nl.", TypeA) - rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1") - m.Answer = []RR{rr} - rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1") - m.Ns = []RR{rr} - rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1") - m.Extra = []RR{rr} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - m.Copy() - } -} - func TestPackIPSECKEY(t *testing.T) { tests := []string{ "38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )", diff --git a/dnssec.go b/dnssec.go index 84cb2142..94010497 100644 --- a/dnssec.go +++ b/dnssec.go @@ -144,7 +144,7 @@ func (k *DNSKEY) KeyTag() uint16 { // at the base64 values. But I'm lazy. modulus, _ := fromBase64([]byte(k.PublicKey)) if len(modulus) > 1 { - x, _ := unpackUint16(modulus, len(modulus)-2) + x, _ := unpackUint16Msg(modulus, len(modulus)-2) keytag = int(x) } default: diff --git a/edns.go b/edns.go index c1aebc11..5634cd70 100644 --- a/edns.go +++ b/edns.go @@ -213,7 +213,7 @@ func (e *EDNS0_SUBNET) Option() uint16 { func (e *EDNS0_SUBNET) pack() ([]byte, error) { b := make([]byte, 4) - b[0], b[1] = packUint16(e.Family) + b[0], b[1] = packUint16Msg(e.Family) b[2] = e.SourceNetmask b[3] = e.SourceScope switch e.Family { @@ -247,7 +247,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { if len(b) < 4 { return ErrBuf } - e.Family, _ = unpackUint16(b, 0) + e.Family, _ = unpackUint16Msg(b, 0) e.SourceNetmask = b[2] e.SourceScope = b[3] switch e.Family { @@ -369,9 +369,9 @@ func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ } func (e *EDNS0_LLQ) pack() ([]byte, error) { b := make([]byte, 18) - b[0], b[1] = packUint16(e.Version) - b[2], b[3] = packUint16(e.Opcode) - b[4], b[5] = packUint16(e.Error) + b[0], b[1] = packUint16Msg(e.Version) + b[2], b[3] = packUint16Msg(e.Opcode) + b[4], b[5] = packUint16Msg(e.Error) b[6] = byte(e.Id >> 56) b[7] = byte(e.Id >> 48) b[8] = byte(e.Id >> 40) @@ -391,9 +391,9 @@ func (e *EDNS0_LLQ) unpack(b []byte) error { if len(b) < 18 { return ErrBuf } - e.Version, _ = unpackUint16(b, 0) - e.Opcode, _ = unpackUint16(b, 2) - e.Error, _ = unpackUint16(b, 4) + e.Version, _ = unpackUint16Msg(b, 0) + e.Opcode, _ = unpackUint16Msg(b, 2) + e.Error, _ = unpackUint16Msg(b, 4) e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 | uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7]) e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3]) diff --git a/zgenerate.go b/generate.go similarity index 100% rename from zgenerate.go rename to generate.go diff --git a/labels.go b/labels.go index cb549fc6..fca5c7dd 100644 --- a/labels.go +++ b/labels.go @@ -107,7 +107,7 @@ func CountLabel(s string) (labels int) { // Split splits a name s into its label indexes. // www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}. -// The root name (.) returns nil. Also see SplitDomainName. +// The root name (.) returns nil. Also see SplitDomainName. // s must be a syntactically valid domain name. func Split(s string) []int { if s == "." { diff --git a/labels_test.go b/labels_test.go index 50f31163..536757d5 100644 --- a/labels_test.go +++ b/labels_test.go @@ -184,12 +184,14 @@ func BenchmarkLenLabels(b *testing.B) { } func BenchmarkCompareLabels(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { CompareDomainName("www.example.com", "aa.example.com") } } func BenchmarkIsSubDomain(b *testing.B) { + b.ReportAllocs() for i := 0; i < b.N; i++ { IsSubDomain("www.example.com", "aa.example.com") IsSubDomain("example.com", "aa.example.com") diff --git a/msg.go b/msg.go index 41b6e0da..60463b6d 100644 --- a/msg.go +++ b/msg.go @@ -8,9 +8,9 @@ package dns +//go:generate go run msg_generate.go + import ( - "encoding/base32" - "encoding/base64" "encoding/hex" "math/big" "math/rand" @@ -92,18 +92,6 @@ type Msg struct { Extra []RR // Holds the RR(s) of the additional section. } -// StringToType is the reverse of TypeToString, needed for string parsing. -var StringToType = reverseInt16(TypeToString) - -// StringToClass is the reverse of ClassToString, needed for string parsing. -var StringToClass = reverseInt16(ClassToString) - -// Map of opcodes strings. -var StringToOpcode = reverseInt(OpcodeToString) - -// Map of rcodes strings. -var StringToRcode = reverseInt(RcodeToString) - // ClassToString is a maps Classes to strings for each CLASS wire type. var ClassToString = map[uint16]string{ ClassINET: "IN", @@ -291,11 +279,11 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c if pointer != -1 { // We have two bytes (14 bits) to put the pointer in // if msg == nil, we will never do compression - msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000)) + msg[nameoffset], msg[nameoffset+1] = packUint16Msg(uint16(pointer ^ 0xC000)) off = nameoffset + 1 goto End } - if msg != nil { + if msg != nil && off < len(msg) { msg[off] = 0 } End: @@ -423,7 +411,7 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) { func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { lenByteOffset := offset - if offset >= len(msg) { + if offset >= len(msg) || len(s) > len(tmp) { return offset, ErrBuf } offset++ @@ -465,7 +453,7 @@ func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) { } func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) { - if offset >= len(msg) { + if offset >= len(msg) || len(s) > len(tmp) { return offset, ErrBuf } bs := tmp[:len(s)] @@ -600,9 +588,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str return lenmsg, &Error{err: "overflow packing opt"} } // Option code - msg[off], msg[off+1] = packUint16(element.(EDNS0).Option()) + msg[off], msg[off+1] = packUint16Msg(element.(EDNS0).Option()) // Length - msg[off+2], msg[off+3] = packUint16(uint16(len(b))) + msg[off+2], msg[off+3] = packUint16Msg(uint16(len(b))) off += 4 if off+len(b) > lenmsg { copy(msg[off:], b) @@ -783,6 +771,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str if e != nil { return lenmsg, e } + if off+len(b64) > lenmsg { + return lenmsg, &Error{err: "overflow packing base64"} + } copy(msg[off:off+len(b64)], b64) off += len(b64) case `dns:"domain-name"`: @@ -811,6 +802,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str if e != nil { return lenmsg, e } + if off+len(b32) > lenmsg { + return lenmsg, &Error{err: "overflow packing base32"} + } copy(msg[off:off+len(b32)], b32) off += len(b32) case `dns:"size-hex"`: @@ -827,6 +821,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str copy(msg[off:off+hex.DecodedLen(len(s))], h) off += hex.DecodedLen(len(s)) case `dns:"size"`: + // TODO(miek): WTF? size? // the size is already encoded in the RR, we can safely use the // length of string. String is RAW (not encoded in hex, nor base64) copy(msg[off:off+len(s)], s) @@ -930,8 +925,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+4 > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } - code, off = unpackUint16(msg, off) - optlen, off1 := unpackUint16(msg, off) + code, off = unpackUint16Msg(msg, off) + optlen, off1 := unpackUint16Msg(msg, off) if off1+int(optlen) > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } @@ -1174,7 +1169,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+2 > lenmsg { return lenmsg, &Error{err: "overflow unpacking uint16"} } - i, off = unpackUint16(msg, off) + i, off = unpackUint16Msg(msg, off) fv.SetUint(uint64(i)) case reflect.Uint32: if off == lenmsg { @@ -1334,38 +1329,6 @@ func intToBytes(i *big.Int, length int) []byte { return buf } -func unpackUint16(msg []byte, off int) (uint16, int) { - return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2 -} - -func packUint16(i uint16) (byte, byte) { - return byte(i >> 8), byte(i) -} - -func toBase32(b []byte) string { - return base32.HexEncoding.EncodeToString(b) -} - -func fromBase32(s []byte) (buf []byte, err error) { - buflen := base32.HexEncoding.DecodedLen(len(s)) - buf = make([]byte, buflen) - n, err := base32.HexEncoding.Decode(buf, s) - buf = buf[:n] - return -} - -func toBase64(b []byte) string { - return base64.StdEncoding.EncodeToString(b) -} - -func fromBase64(s []byte) (buf []byte, err error) { - buflen := base64.StdEncoding.DecodedLen(len(s)) - buf = make([]byte, buflen) - n, err := base64.StdEncoding.Decode(buf, s) - buf = buf[:n] - return -} - // PackRR packs a resource record rr into msg[off:]. // See PackDomainName for documentation about the compression. func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { @@ -1373,7 +1336,62 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo return len(msg), &Error{err: "nil rr"} } - off1, err = packStructCompress(rr, msg, off, compression, compress) + _, ok := typeToUnpack[rr.Header().Rrtype] + switch ok { + case true: + // Shortcut reflection, `pack' needs to be added to the RR interface so we can just do this: + // off1, err = t.pack(msg, off, compression, compress) + // TODO(miek): revert the logic and make a blacklist for types that still use reflection. Kill + // typeToUnpack and just generate all the pack and unpack functions even though we don't use + // them for all types (yet). + switch t := rr.(type) { + case *RR_Header: + // we can be called with an empty RR, consisting only out of the header, see update_test.go's + // TestDynamicUpdateZeroRdataUnpack for an example. This is OK as RR_Header also implements the RR interface. + off1, err = t.pack(msg, off, compression, compress) + case *ANY: + // Also "weird" setup, see (again) update_test.go's TestRemoveRRset, where the Rrtype is 1 but the type is *ANY. + off1, err = t.pack(msg, off, compression, compress) + case *A: + off1, err = t.pack(msg, off, compression, compress) + case *AAAA: + off1, err = t.pack(msg, off, compression, compress) + case *CNAME: + off1, err = t.pack(msg, off, compression, compress) + case *DNAME: + off1, err = t.pack(msg, off, compression, compress) + case *HINFO: + off1, err = t.pack(msg, off, compression, compress) + case *L32: + off1, err = t.pack(msg, off, compression, compress) + case *LOC: + off1, err = t.pack(msg, off, compression, compress) + case *MB: + off1, err = t.pack(msg, off, compression, compress) + case *MD: + off1, err = t.pack(msg, off, compression, compress) + case *MF: + off1, err = t.pack(msg, off, compression, compress) + case *MG: + off1, err = t.pack(msg, off, compression, compress) + case *MX: + off1, err = t.pack(msg, off, compression, compress) + case *NID: + off1, err = t.pack(msg, off, compression, compress) + case *NS: + off1, err = t.pack(msg, off, compression, compress) + case *PTR: + off1, err = t.pack(msg, off, compression, compress) + case *RP: + off1, err = t.pack(msg, off, compression, compress) + case *SRV: + off1, err = t.pack(msg, off, compression, compress) + case *DNSKEY: + off1, err = t.pack(msg, off, compression, compress) + } + default: + off1, err = packStructCompress(rr, msg, off, compression, compress) + } if err != nil { return len(msg), err } @@ -1385,21 +1403,27 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo // UnpackRR unpacks msg[off:] into an RR. func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { - // unpack just the header, to find the rr type and length - var h RR_Header off0 := off - if off, err = UnpackStruct(&h, msg, off); err != nil { + h, off, msg, err := unpackHeader(msg, off) + if err != nil { return nil, len(msg), err } end := off + int(h.Rdlength) - // make an rr of that type and re-unpack. - mk, known := TypeToRR[h.Rrtype] - if !known { - rr = new(RFC3597) - } else { - rr = mk() + + fn, ok := typeToUnpack[h.Rrtype] + switch ok { + case true: + // Shortcut reflection. + rr, off, err = fn(h, msg, off) + default: + mk, known := TypeToRR[h.Rrtype] + if !known { + rr = new(RFC3597) + } else { + rr = mk() + } + off, err = UnpackStruct(rr, msg, off0) } - off, err = UnpackStruct(rr, msg, off0) if off != end { return &h, end, &Error{err: "bad rdlength"} } @@ -1432,31 +1456,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) return dst, off, err } -// Reverse a map -func reverseInt8(m map[uint8]string) map[string]uint8 { - n := make(map[string]uint8) - for u, s := range m { - n[s] = u - } - return n -} - -func reverseInt16(m map[uint16]string) map[string]uint16 { - n := make(map[string]uint16) - for u, s := range m { - n[s] = u - } - return n -} - -func reverseInt(m map[int]string) map[string]int { - n := make(map[string]int) - for u, s := range m { - n[s] = u - } - return n -} - // Convert a MsgHdr to a string, with dig-like headers: // //;; opcode: QUERY, status: NOERROR, id: 48404 @@ -1510,8 +1509,11 @@ func (dns *Msg) Pack() (msg []byte, err error) { // PackBuffer packs a Msg, using the given buffer buf. If buf is too small // a new buffer is allocated. func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { - var dh Header - var compression map[string]int + var ( + dh Header + compression map[string]int + ) + if dns.Compress { compression = make(map[string]int) // Compression pointer mappings } @@ -1579,12 +1581,12 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { // Pack it in: header and then the pieces. off := 0 - off, err = packStructCompress(&dh, msg, off, compression, dns.Compress) + off, err = dh.pack(msg, off, compression, dns.Compress) if err != nil { return nil, err } for i := 0; i < len(question); i++ { - off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress) + off, err = question[i].pack(msg, off, compression, dns.Compress) if err != nil { return nil, err } @@ -1612,12 +1614,17 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { // Unpack unpacks a binary message to a Msg structure. func (dns *Msg) Unpack(msg []byte) (err error) { - // Header. - var dh Header - off := 0 - if off, err = UnpackStruct(&dh, msg, off); err != nil { + var ( + dh Header + off int + ) + if dh, off, err = unpackMsgHdr(msg, off); err != nil { return err } + if off == len(msg) { + return ErrTruncated + } + dns.Id = dh.Id dns.Response = (dh.Bits & _QR) != 0 dns.Opcode = int(dh.Bits>>11) & 0xF @@ -1633,10 +1640,10 @@ func (dns *Msg) Unpack(msg []byte) (err error) { // Optimistically use the count given to us in the header dns.Question = make([]Question, 0, int(dh.Qdcount)) - var q Question for i := 0; i < int(dh.Qdcount); i++ { off1 := off - off, err = UnpackStruct(&q, msg, off) + var q Question + q, off, err = unpackQuestion(msg, off) if err != nil { // Even if Truncated is set, we only will set ErrTruncated if we // actually got the questions @@ -1662,6 +1669,7 @@ func (dns *Msg) Unpack(msg []byte) (err error) { } // The header counts might have been wrong so we need to update it dh.Arcount = uint16(len(dns.Extra)) + if off != len(msg) { // TODO(miek) make this an error? // use PackOpt to let people tell how detailed the error reporting should be? @@ -1735,6 +1743,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Answer); i++ { + if dns.Answer[i] == nil { + continue + } l += dns.Answer[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name) @@ -1750,6 +1761,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Ns); i++ { + if dns.Ns[i] == nil { + continue + } l += dns.Ns[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name) @@ -1765,6 +1779,9 @@ func (dns *Msg) Len() int { } } for i := 0; i < len(dns.Extra); i++ { + if dns.Extra[i] == nil { + continue + } l += dns.Extra[i].len() if dns.Compress { k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name) @@ -1955,3 +1972,122 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg { return r1 } + +func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := PackDomainName(q.Name, msg, off, compression, compress) + if err != nil { + return off, err + } + off, err = packUint16(q.Qtype, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(q.Qclass, msg, off) + if err != nil { + return off, err + } + return off, nil +} + +func unpackQuestion(msg []byte, off int) (Question, int, error) { + var ( + q Question + err error + ) + q.Name, off, err = UnpackDomainName(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qtype, off, err = unpackUint16(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qclass, off, err = unpackUint16(msg, off) + if off == len(msg) { + return q, off, nil + } + return q, off, err +} + +func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := packUint16(dh.Id, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Bits, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Qdcount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Ancount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Nscount, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(dh.Arcount, msg, off) + return off, err +} + +func unpackMsgHdr(msg []byte, off int) (Header, int, error) { + var ( + dh Header + err error + ) + dh.Id, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Bits, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Qdcount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Ancount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Nscount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Arcount, off, err = unpackUint16(msg, off) + return dh, off, err +} + +// Which types have type specific unpack functions. +var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){ + TypeAAAA: unpackAAAA, + TypeA: unpackA, + TypeCNAME: unpackCNAME, + TypeDNAME: unpackDNAME, + TypeL32: unpackL32, + TypeLOC: unpackLOC, + TypeMB: unpackMB, + TypeMD: unpackMD, + TypeMF: unpackMF, + TypeMG: unpackMG, + TypeMR: unpackMR, + TypeMX: unpackMX, + TypeNID: unpackNID, + TypeNS: unpackNS, + TypePTR: unpackPTR, + TypeRP: unpackRP, + TypeSRV: unpackSRV, + TypeHINFO: unpackHINFO, + TypeDNSKEY: unpackDNSKEY, +} diff --git a/msg_generate.go b/msg_generate.go new file mode 100644 index 00000000..2f7bc99f --- /dev/null +++ b/msg_generate.go @@ -0,0 +1,307 @@ +//+build ignore + +// msg_generate.go is meant to run with go generate. It will use +// go/{importer,types} to track down all the RR struct types. Then for each type +// it will generate pack/unpack methods based on the struct tags. The generated source is +// written to zmsg.go, and is meant to be checked into git. +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/importer" + "go/types" + "log" + "os" +) + +// All RR pack and unpack functions should be generated, currently RR that present some +// problems +// * NSEC/NSEC3 - type bitmap +// * TXT/SPF - string slice +// * URI - weird octet thing there +// * NSEC3/TSIG - size hex +// * OPT RR - EDNS0 parsing - needs to some looking at +// * HIP - uses "hex", but is actually size-hex - might drop size-hex? +// * Z +// * WKS - uint16 slice +// * NINFO +// * PrivateRR + +// What types are we generating, should be kept in sync with typeToUnpack in msg.go +var generate = map[string]bool{ + "AAAA": true, + "ANY": true, + "A": true, + "CNAME": true, + "DNAME": true, + "DNSKEY": true, + "HINFO": true, + "L32": true, + "LOC": true, + "MB": true, + "MD": true, + "MF": true, + "MG": true, + "MR": true, + "MX": true, + "NID": true, + "NS": true, + "PTR": true, + "RP": true, + "SRV": true, +} + +func shouldGenerate(name string) bool { + _, ok := generate[name] + return ok +} + +// For later: IPSECKEY is weird. + +var packageHdr = ` +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from msg_generate.go + +package dns + +` + +// getTypeStruct will take a type and the package scope, and return the +// (innermost) struct if the type is considered a RR type (currently defined as +// those structs beginning with a RR_Header, could be redefined as implementing +// the RR interface). The bool return value indicates if embedded structs were +// resolved. +func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { + st, ok := t.Underlying().(*types.Struct) + if !ok { + return nil, false + } + if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { + return st, false + } + if st.Field(0).Anonymous() { + st, _ := getTypeStruct(st.Field(0).Type(), scope) + return st, true + } + return nil, false +} + +func main() { + // Import and type-check the package + pkg, err := importer.Default().Import("github.com/miekg/dns") + fatalIfErr(err) + scope := pkg.Scope() + + // Collect actual types (*X) + var namedTypes []string + for _, name := range scope.Names() { + o := scope.Lookup(name) + if o == nil || !o.Exported() { + continue + } + if st, _ := getTypeStruct(o.Type(), scope); st == nil { + continue + } + if name == "PrivateRR" { + continue + } + + // Check if corresponding TypeX exists + if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { + log.Fatalf("Constant Type%s does not exist.", o.Name()) + } + + namedTypes = append(namedTypes, o.Name()) + } + + b := &bytes.Buffer{} + b.WriteString(packageHdr) + + fmt.Fprint(b, "// pack*() functions\n\n") + for _, name := range namedTypes { + o := scope.Lookup(name) + st, isEmbedded := getTypeStruct(o.Type(), scope) + if isEmbedded || !shouldGenerate(name) { + continue + } + + fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) + fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) +if err != nil { + return off, err +} +headerEnd := off +`) + for i := 1; i < st.NumFields(); i++ { + o := func(s string) { + fmt.Fprintf(b, s, st.Field(i).Name()) + fmt.Fprint(b, `if err != nil { +return off, err +} +`) + } + + //if _, ok := st.Field(i).Type().(*types.Slice); ok { + //switch st.Tag(i) { + //case `dns:"-"`: + //// ignored + //case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: + //o("for _, x := range rr.%s { l += len(x) + 1 }\n") + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + //} + //continue + //} + + switch st.Tag(i) { + case `dns:"-"`: + // ignored + case `dns:"cdomain-name"`: + fallthrough + case `dns:"domain-name"`: + o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") + case `dns:"a"`: + o("off, err = packDataA(rr.%s, msg, off)\n") + case `dns:"aaaa"`: + o("off, err = packDataAAAA(rr.%s, msg, off)\n") + case `dns:"uint48"`: + o("off, err = packUint48(rr.%s, msg, off)\n") + case `dns:"txt"`: + o("off, err = packString(rr.%s, msg, off)\n") + case `dns:"base32"`: + o("off, err = packStringBase32(rr.%s, msg, off)\n") + case `dns:"base64"`: + o("off, err = packStringBase64(rr.%s, msg, off)\n") + case "": + switch st.Field(i).Type().(*types.Basic).Kind() { + case types.Uint8: + o("off, err = packUint8(rr.%s, msg, off)\n") + case types.Uint16: + o("off, err = packUint16(rr.%s, msg, off)\n") + case types.Uint32: + o("off, err = packUint32(rr.%s, msg, off)\n") + case types.Uint64: + o("off, err = packUint64(rr.%s, msg, off)\n") + case types.String: + o("off, err = packString(rr.%s, msg, off)\n") + default: + log.Fatalln(name, st.Field(i).Name()) + } + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + } + } + // We have packed everything, only now we know the rdlength of this RR + fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off- headerEnd)") + fmt.Fprintln(b, "return off, nil }\n") + } + + fmt.Fprint(b, "// unpack*() functions\n\n") + for _, name := range namedTypes { + o := scope.Lookup(name) + st, isEmbedded := getTypeStruct(o.Type(), scope) + if isEmbedded || !shouldGenerate(name) { + continue + } + + fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) + fmt.Fprint(b, `if noRdata(h) { +return nil, off, nil + } +var err error +rdStart := off +_ = rdStart + +`) + fmt.Fprintf(b, "rr := new(%s)\n", name) + fmt.Fprintln(b, "rr.Hdr = h\n") + for i := 1; i < st.NumFields(); i++ { + o := func(s string) { + fmt.Fprintf(b, s, st.Field(i).Name()) + fmt.Fprint(b, `if err != nil { +return rr, off, err +} +`) + } + + //if _, ok := st.Field(i).Type().(*types.Slice); ok { + //switch st.Tag(i) { + //case `dns:"-"`: + //// ignored + //case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: + //o("for _, x := range rr.%s { l += len(x) + 1 }\n") + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + //} + //continue + //} + + switch st.Tag(i) { + case `dns:"-"`: + // ignored + case `dns:"cdomain-name"`: + fallthrough + case `dns:"domain-name"`: + o("rr.%s, off, err = UnpackDomainName(msg, off)\n") + case `dns:"a"`: + o("rr.%s, off, err = unpackDataA(msg, off)\n") + case `dns:"aaaa"`: + o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") + case `dns:"uint48"`: + o("rr.%s, off, err = unpackUint48(msg, off)\n") + case `dns:"txt"`: + o("rr.%s, off, err = unpackString(msg, off)\n") + case `dns:"base32"`: + o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") + case `dns:"base64"`: + o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") + case "": + switch st.Field(i).Type().(*types.Basic).Kind() { + case types.Uint8: + o("rr.%s, off, err = unpackUint8(msg, off)\n") + case types.Uint16: + o("rr.%s, off, err = unpackUint16(msg, off)\n") + case types.Uint32: + o("rr.%s, off, err = unpackUint32(msg, off)\n") + case types.Uint64: + o("rr.%s, off, err = unpackUint64(msg, off)\n") + case types.String: + o("rr.%s, off, err = unpackString(msg, off)\n") + default: + log.Fatalln(name, st.Field(i).Name()) + } + //default: + //log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) + } + // If we've hit len(msg) we return without error. + if i < st.NumFields()-1 { + fmt.Fprintf(b, `if off == len(msg) { +return rr, off, nil + } +`) + } + } + fmt.Fprintf(b, "return rr, off, err }\n\n") + } + + // gofmt + res, err := format.Source(b.Bytes()) + if err != nil { + b.WriteTo(os.Stderr) + log.Fatal(err) + } + + // write result + f, err := os.Create("zmsg.go") + fatalIfErr(err) + defer f.Close() + f.Write(res) +} + +func fatalIfErr(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/msg_helpers.go b/msg_helpers.go new file mode 100644 index 00000000..06eea492 --- /dev/null +++ b/msg_helpers.go @@ -0,0 +1,399 @@ +package dns + +import ( + "encoding/base32" + "encoding/base64" + "encoding/hex" + "net" + "strconv" +) + +// helper functions called from the generated zmsg.go + +// These function are named after the tag to help pack/unpack, if there is no tag it is the name +// of the type they pack/unpack (string, int, etc). We prefix all with unpackData or packData, so packDataA or +// packDataDomainName. + +func unpackDataA(msg []byte, off int) (net.IP, int, error) { + if off+net.IPv4len > len(msg) { + return nil, len(msg), &Error{err: "overflow unpacking a"} + } + a := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) + off += net.IPv4len + return a, off, nil +} + +func packDataA(a net.IP, msg []byte, off int) (int, error) { + // It must be a slice of 4, even if it is 16, we encode only the first 4 + if off+net.IPv4len > len(msg) { + return len(msg), &Error{err: "overflow packing a"} + } + switch len(a) { + case net.IPv6len: + msg[off] = a[12] + msg[off+1] = a[13] + msg[off+2] = a[14] + msg[off+3] = a[15] + off += net.IPv4len + case net.IPv4len: + msg[off] = a[0] + msg[off+1] = a[1] + msg[off+2] = a[2] + msg[off+3] = a[3] + off += net.IPv4len + case 0: + // Allowed, for dynamic updates. + default: + return len(msg), &Error{err: "overflow packing a"} + } + return off, nil +} + +func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) { + if off+net.IPv6len > len(msg) { + return nil, len(msg), &Error{err: "overflow unpacking aaaa"} + } + aaaa := net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], + msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10], + msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]} + off += net.IPv6len + return aaaa, off, nil +} + +func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) { + if off+net.IPv6len > len(msg) { + return len(msg), &Error{err: "overflow packing aaaa"} + } + + switch len(aaaa) { + case net.IPv6len: + msg[off] = aaaa[0] + msg[off+1] = aaaa[1] + msg[off+2] = aaaa[2] + msg[off+3] = aaaa[3] + msg[off+4] = aaaa[4] + msg[off+5] = aaaa[5] + msg[off+6] = aaaa[6] + msg[off+7] = aaaa[7] + msg[off+8] = aaaa[8] + msg[off+9] = aaaa[9] + msg[off+10] = aaaa[10] + msg[off+11] = aaaa[11] + msg[off+12] = aaaa[12] + msg[off+13] = aaaa[13] + msg[off+14] = aaaa[14] + msg[off+15] = aaaa[15] + off += net.IPv6len + case 0: + // Allowed, dynamic updates. + default: + return len(msg), &Error{err: "overflow packing aaaa"} + } + return off, nil +} + +// unpackHeader unpacks an RR header, returning the offset to the end of the header and a +// re-sliced msg according to the expected length of the RR. +func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) { + hdr := RR_Header{} + if off == len(msg) { + return hdr, off, msg, nil + } + + hdr.Name, off, err = UnpackDomainName(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Rrtype, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Class, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Ttl, off, err = unpackUint32(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + hdr.Rdlength, off, err = unpackUint16(msg, off) + if err != nil { + return hdr, len(msg), msg, err + } + msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength) + return hdr, off, msg, nil +} + +// pack packs an RR header, returning the offset to the end of the header. +// See PackDomainName for documentation about the compression. +func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { + if off == len(msg) { + return off, nil + } + + off, err = PackDomainName(hdr.Name, msg, off, compression, compress) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Rrtype, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Class, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint32(hdr.Ttl, msg, off) + if err != nil { + return len(msg), err + } + off, err = packUint16(hdr.Rdlength, msg, off) + if err != nil { + return len(msg), err + } + return off, nil +} + +// helper helper functions. + +// truncateMsgFromRdLength truncates msg to match the expected length of the RR. +// Returns an error if msg is smaller than the expected size. +func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) { + lenrd := off + int(rdlength) + if lenrd > len(msg) { + return msg, &Error{err: "overflowing header size"} + } + return msg[:lenrd], nil +} + +func fromBase32(s []byte) (buf []byte, err error) { + buflen := base32.HexEncoding.DecodedLen(len(s)) + buf = make([]byte, buflen) + n, err := base32.HexEncoding.Decode(buf, s) + buf = buf[:n] + return +} + +func toBase32(b []byte) string { return base32.HexEncoding.EncodeToString(b) } + +func fromBase64(s []byte) (buf []byte, err error) { + buflen := base64.StdEncoding.DecodedLen(len(s)) + buf = make([]byte, buflen) + n, err := base64.StdEncoding.Decode(buf, s) + buf = buf[:n] + return +} + +func toBase64(b []byte) string { return base64.StdEncoding.EncodeToString(b) } + +func unpackUint16Msg(msg []byte, off int) (uint16, int) { + return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2 +} + +func packUint16Msg(i uint16) (byte, byte) { return byte(i >> 8), byte(i) } + +func unpackUint32Msg(msg []byte, off int) (uint32, int) { + return uint32(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))), off + 4 +} + +func packUint32Msg(i uint32) (byte, byte, byte, byte) { + return byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i) +} + +// dynamicUpdate returns true if the Rdlength is zero. +func noRdata(h RR_Header) bool { return h.Rdlength == 0 } + +func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) { + if off+1 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint8"} + } + return uint8(msg[off]), off + 1, nil +} + +func packUint8(i uint8, msg []byte, off int) (off1 int, err error) { + if off+1 > len(msg) { + return len(msg), &Error{err: "overflow packing uint8"} + } + msg[off] = byte(i) + return off + 1, nil +} + +func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) { + if off+2 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint16"} + } + i, off = unpackUint16Msg(msg, off) + return i, off, nil +} + +func packUint16(i uint16, msg []byte, off int) (off1 int, err error) { + if off+2 > len(msg) { + return len(msg), &Error{err: "overflow packing uint16"} + } + msg[off], msg[off+1] = packUint16Msg(i) + return off + 2, nil +} + +func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) { + if off+4 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint32"} + } + i, off = unpackUint32Msg(msg, off) + return i, off, nil +} + +func packUint32(i uint32, msg []byte, off int) (off1 int, err error) { + if off+4 > len(msg) { + return len(msg), &Error{err: "overflow packing uint32"} + } + msg[off], msg[off+1], msg[off+2], msg[off+3] = packUint32Msg(i) + return off + 4, nil +} + +func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) { + if off+6 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"} + } + // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) + i = (uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | + uint64(msg[off+4])<<8 | uint64(msg[off+5]))) + off += 6 + return i, off, nil +} + +func packUint48(i uint64, msg []byte, off int) (off1 int, err error) { + if off+6 > len(msg) { + return len(msg), &Error{err: "overflow packing uint64 as uint48"} + } + msg[off] = byte(i >> 40) + msg[off+1] = byte(i >> 32) + msg[off+2] = byte(i >> 24) + msg[off+3] = byte(i >> 16) + msg[off+4] = byte(i >> 8) + msg[off+5] = byte(i) + off += 6 + return off, nil +} + +func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) { + if off+8 > len(msg) { + return 0, len(msg), &Error{err: "overflow unpacking uint64"} + } + i = (uint64(uint64(msg[off])<<56 | uint64(msg[off+1])<<48 | uint64(msg[off+2])<<40 | + uint64(msg[off+3])<<32 | uint64(msg[off+4])<<24 | uint64(msg[off+5])<<16 | uint64(msg[off+6])<<8 | uint64(msg[off+7]))) + off += 8 + return i, off, nil +} + +func packUint64(i uint64, msg []byte, off int) (off1 int, err error) { + if off+8 > len(msg) { + return len(msg), &Error{err: "overflow packing uint64"} + } + msg[off] = byte(i >> 56) + msg[off+1] = byte(i >> 48) + msg[off+2] = byte(i >> 40) + msg[off+3] = byte(i >> 32) + msg[off+4] = byte(i >> 24) + msg[off+5] = byte(i >> 16) + msg[off+6] = byte(i >> 8) + msg[off+7] = byte(i) + off += 8 + return off, nil +} + +func unpackString(msg []byte, off int) (string, int, error) { + if off+1 > len(msg) { + return "", off, &Error{err: "overflow unpacking txt"} + } + l := int(msg[off]) + if off+l+1 > len(msg) { + return "", off, &Error{err: "overflow unpacking txt"} + } + s := make([]byte, 0, l) + for _, b := range msg[off+1 : off+1+l] { + switch b { + case '"', '\\': + s = append(s, '\\', b) + case '\t': + s = append(s, '\t') + case '\r': + s = append(s, '\r') + case '\n': + s = append(s, '\n') + default: + if b < 32 || b > 127 { // unprintable + var buf [3]byte + bufs := strconv.AppendInt(buf[:0], int64(b), 10) + s = append(s, '\\') + for i := 0; i < 3-len(bufs); i++ { + s = append(s, '0') + } + for _, r := range bufs { + s = append(s, r) + } + } else { + s = append(s, b) + } + } + } + off += 1 + l + return string(s), off, nil +} + +func packString(s string, msg []byte, off int) (int, error) { + txtTmp := make([]byte, 256*4+1) + off, err := packTxtString(s, msg, off, txtTmp) + if err != nil { + return len(msg), err + } + return off, nil +} + +func unpackStringBase64(msg []byte, off, end int) (string, int, error) { + // Rest of the RR is base64 encoded value, so we don't need an explicit length + // to be set. Thus far all RR's that have base64 encoded fields have those as their + // last one. What we do need is the end of the RR! + if end > len(msg) { + return "", len(msg), &Error{err: "overflow unpacking base64"} + } + s := toBase64(msg[off:end]) + return s, end, nil +} + +func packStringBase64(s string, msg []byte, off int) (int, error) { + b64, e := fromBase64([]byte(s)) + if e != nil { + return len(msg), e + } + if off+len(b64) > len(msg) { + return len(msg), &Error{err: "overflow packing base64"} + } + copy(msg[off:off+len(b64)], b64) + off += len(b64) + return off, nil +} + +func unpackStringHex(msg []byte, off, end int) (string, int, error) { + // Rest of the RR is hex encoded value, so we don't need an explicit length + // to be set. NSEC and TSIG have hex fields with a length field. + // What we do need is the end of the RR! + if end > len(msg) { + return "", len(msg), &Error{err: "overflow unpacking hex"} + } + + s := hex.EncodeToString(msg[off:end]) + return s, end, nil +} + +func packStringHex(s string, msg []byte, off int) (int, error) { + h, e := hex.DecodeString(s) + if e != nil { + return len(msg), e + } + if off+(len(h)) > len(msg) { + return len(msg), &Error{err: "overflow packing hex"} + } + copy(msg[off:off+len(h)], h) + off += len(h) + return off, nil +} diff --git a/rawmsg.go b/rawmsg.go index b4a706b9..e4e5374d 100644 --- a/rawmsg.go +++ b/rawmsg.go @@ -8,7 +8,7 @@ func rawSetId(msg []byte, i uint16) bool { if len(msg) < 2 { return false } - msg[0], msg[1] = packUint16(i) + msg[0], msg[1] = packUint16Msg(i) return true } @@ -17,7 +17,7 @@ func rawSetQuestionLen(msg []byte, i uint16) bool { if len(msg) < 6 { return false } - msg[4], msg[5] = packUint16(i) + msg[4], msg[5] = packUint16Msg(i) return true } @@ -26,7 +26,7 @@ func rawSetAnswerLen(msg []byte, i uint16) bool { if len(msg) < 8 { return false } - msg[6], msg[7] = packUint16(i) + msg[6], msg[7] = packUint16Msg(i) return true } @@ -35,7 +35,7 @@ func rawSetNsLen(msg []byte, i uint16) bool { if len(msg) < 10 { return false } - msg[8], msg[9] = packUint16(i) + msg[8], msg[9] = packUint16Msg(i) return true } @@ -44,7 +44,7 @@ func rawSetExtraLen(msg []byte, i uint16) bool { if len(msg) < 12 { return false } - msg[10], msg[11] = packUint16(i) + msg[10], msg[11] = packUint16Msg(i) return true } @@ -90,6 +90,6 @@ Loop: if rdatalen > 0xFFFF { return false } - msg[off], msg[off+1] = packUint16(uint16(rdatalen)) + msg[off], msg[off+1] = packUint16Msg(uint16(rdatalen)) return true } diff --git a/reverse.go b/reverse.go new file mode 100644 index 00000000..099dac94 --- /dev/null +++ b/reverse.go @@ -0,0 +1,38 @@ +package dns + +// StringToType is the reverse of TypeToString, needed for string parsing. +var StringToType = reverseInt16(TypeToString) + +// StringToClass is the reverse of ClassToString, needed for string parsing. +var StringToClass = reverseInt16(ClassToString) + +// Map of opcodes strings. +var StringToOpcode = reverseInt(OpcodeToString) + +// Map of rcodes strings. +var StringToRcode = reverseInt(RcodeToString) + +// Reverse a map +func reverseInt8(m map[uint8]string) map[string]uint8 { + n := make(map[string]uint8, len(m)) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt16(m map[uint16]string) map[string]uint16 { + n := make(map[string]uint16, len(m)) + for u, s := range m { + n[s] = u + } + return n +} + +func reverseInt(m map[int]string) map[string]int { + n := make(map[string]int, len(m)) + for u, s := range m { + n[s] = u + } + return n +} diff --git a/zscan.go b/scan.go similarity index 100% rename from zscan.go rename to scan.go diff --git a/zscan_rr.go b/scan_rr.go similarity index 100% rename from zscan_rr.go rename to scan_rr.go diff --git a/server.go b/server.go index edc5c625..158cd3b7 100644 --- a/server.go +++ b/server.go @@ -615,7 +615,7 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } return nil, ErrShortRead } - length, _ := unpackUint16(l, 0) + length, _ := unpackUint16Msg(l, 0) if length == 0 { return nil, ErrShortRead } @@ -690,7 +690,7 @@ func (w *response) Write(m []byte) (int, error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, 2+lm) - l[0], l[1] = packUint16(uint16(lm)) + l[0], l[1] = packUint16Msg(uint16(lm)) m = append(l, m...) n, err := io.Copy(w.tcp, bytes.NewReader(m)) diff --git a/sig0.go b/sig0.go index 0fccddbc..ea3e5b67 100644 --- a/sig0.go +++ b/sig0.go @@ -67,13 +67,13 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { } // Adjust sig data length rdoff := len(mbuf) + 1 + 2 + 2 + 4 - rdlen, _ := unpackUint16(buf, rdoff) + rdlen, _ := unpackUint16Msg(buf, rdoff) rdlen += uint16(len(sig)) - buf[rdoff], buf[rdoff+1] = packUint16(rdlen) + buf[rdoff], buf[rdoff+1] = packUint16Msg(rdlen) // Adjust additional count - adc, _ := unpackUint16(buf, 10) + adc, _ := unpackUint16Msg(buf, 10) adc++ - buf[10], buf[11] = packUint16(adc) + buf[10], buf[11] = packUint16Msg(adc) return buf, nil } @@ -103,10 +103,10 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { hasher := hash.New() buflen := len(buf) - qdc, _ := unpackUint16(buf, 4) - anc, _ := unpackUint16(buf, 6) - auc, _ := unpackUint16(buf, 8) - adc, offset := unpackUint16(buf, 10) + qdc, _ := unpackUint16Msg(buf, 4) + anc, _ := unpackUint16Msg(buf, 6) + auc, _ := unpackUint16Msg(buf, 8) + adc, offset := unpackUint16Msg(buf, 10) var err error for i := uint16(0); i < qdc && offset < buflen; i++ { _, offset, err = UnpackDomainName(buf, offset) @@ -127,7 +127,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { continue } var rdlen uint16 - rdlen, offset = unpackUint16(buf, offset) + rdlen, offset = unpackUint16Msg(buf, offset) offset += int(rdlen) } if offset >= buflen { diff --git a/tsig.go b/tsig.go index c3374e19..7a089ba2 100644 --- a/tsig.go +++ b/tsig.go @@ -301,8 +301,8 @@ func stripTsig(msg []byte) ([]byte, *TSIG, error) { if dns.Extra[i].Header().Rrtype == TypeTSIG { rr = dns.Extra[i].(*TSIG) // Adjust Arcount. - arcount, _ := unpackUint16(msg, 10) - msg[10], msg[11] = packUint16(arcount - 1) + arcount, _ := unpackUint16Msg(msg, 10) + msg[10], msg[11] = packUint16Msg(arcount - 1) break } } diff --git a/types_generate.go b/types_generate.go index 63bfda0e..4df136ef 100644 --- a/types_generate.go +++ b/types_generate.go @@ -29,7 +29,7 @@ var skipLen = map[string]struct{}{ var packageHdr = ` // *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate +// AUTOGENERATED BY go generate from type_generate.go package dns diff --git a/update_test.go b/update_test.go index 175c73b1..56602dfe 100644 --- a/update_test.go +++ b/update_test.go @@ -77,7 +77,7 @@ func TestRemoveRRset(t *testing.T) { if !bytes.Equal(actual, expect) { tmp := new(Msg) if err := tmp.Unpack(actual); err != nil { - t.Fatalf("error unpacking actual msg: %v", err) + t.Fatalf("error unpacking actual msg: %v\nexpected: %v\ngot: %v\n", err, expect, actual) } t.Errorf("expected msg:\n%s", expectstr) t.Errorf("actual msg:\n%v", tmp) diff --git a/zmsg.go b/zmsg.go new file mode 100644 index 00000000..7855d406 --- /dev/null +++ b/zmsg.go @@ -0,0 +1,827 @@ +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from msg_generate.go + +package dns + +// pack*() functions + +func (rr *A) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packDataA(rr.A, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *AAAA) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packDataAAAA(rr.AAAA, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *ANY) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *CNAME) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *DNAME) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *DNSKEY) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Flags, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Protocol, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Algorithm, msg, off) + if err != nil { + return off, err + } + off, err = packStringBase64(rr.PublicKey, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *HINFO) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packString(rr.Cpu, msg, off) + if err != nil { + return off, err + } + off, err = packString(rr.Os, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *L32) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = packDataA(rr.Locator32, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *LOC) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint8(rr.Version, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.Size, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.HorizPre, msg, off) + if err != nil { + return off, err + } + off, err = packUint8(rr.VertPre, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Latitude, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Longitude, msg, off) + if err != nil { + return off, err + } + off, err = packUint32(rr.Altitude, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MB) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mb, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MD) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Md, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MF) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mf, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MG) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mg, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mr, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *MX) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Mx, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *NID) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Preference, msg, off) + if err != nil { + return off, err + } + off, err = packUint64(rr.NodeID, msg, off) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *NS) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Ns, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *PTR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Ptr, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *RP) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = PackDomainName(rr.Mbox, msg, off, compression, compress) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Txt, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +func (rr *SRV) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { + off, err := rr.Hdr.pack(msg, off, compression, compress) + if err != nil { + return off, err + } + headerEnd := off + off, err = packUint16(rr.Priority, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(rr.Weight, msg, off) + if err != nil { + return off, err + } + off, err = packUint16(rr.Port, msg, off) + if err != nil { + return off, err + } + off, err = PackDomainName(rr.Target, msg, off, compression, compress) + if err != nil { + return off, err + } + rr.Header().Rdlength = uint16(off - headerEnd) + return off, nil +} + +// unpack*() functions + +func unpackA(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(A) + rr.Hdr = h + + rr.A, off, err = unpackDataA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackAAAA(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(AAAA) + rr.Hdr = h + + rr.AAAA, off, err = unpackDataAAAA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackANY(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(ANY) + rr.Hdr = h + + return rr, off, err +} + +func unpackCNAME(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(CNAME) + rr.Hdr = h + + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackDNAME(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(DNAME) + rr.Hdr = h + + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackDNSKEY(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(DNSKEY) + rr.Hdr = h + + rr.Flags, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Protocol, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Algorithm, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.PublicKey, off, err = unpackStringBase64(msg, off, rdStart+int(rr.Hdr.Rdlength)) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackHINFO(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(HINFO) + rr.Hdr = h + + rr.Cpu, off, err = unpackString(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Os, off, err = unpackString(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackL32(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(L32) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Locator32, off, err = unpackDataA(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackLOC(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(LOC) + rr.Hdr = h + + rr.Version, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Size, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.HorizPre, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.VertPre, off, err = unpackUint8(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Latitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Longitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Altitude, off, err = unpackUint32(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMB(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MB) + rr.Hdr = h + + rr.Mb, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMD(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MD) + rr.Hdr = h + + rr.Md, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMF(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MF) + rr.Hdr = h + + rr.Mf, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMG(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MG) + rr.Hdr = h + + rr.Mg, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMR(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MR) + rr.Hdr = h + + rr.Mr, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackMX(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(MX) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Mx, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackNID(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(NID) + rr.Hdr = h + + rr.Preference, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.NodeID, off, err = unpackUint64(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackNS(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(NS) + rr.Hdr = h + + rr.Ns, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackPTR(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(PTR) + rr.Hdr = h + + rr.Ptr, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackRP(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(RP) + rr.Hdr = h + + rr.Mbox, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Txt, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} + +func unpackSRV(h RR_Header, msg []byte, off int) (RR, int, error) { + if noRdata(h) { + return nil, off, nil + } + var err error + rdStart := off + _ = rdStart + + rr := new(SRV) + rr.Hdr = h + + rr.Priority, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Weight, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Port, off, err = unpackUint16(msg, off) + if err != nil { + return rr, off, err + } + if off == len(msg) { + return rr, off, nil + } + rr.Target, off, err = UnpackDomainName(msg, off) + if err != nil { + return rr, off, err + } + return rr, off, err +} diff --git a/ztypes.go b/ztypes.go index 3d0f9aef..858272bc 100644 --- a/ztypes.go +++ b/ztypes.go @@ -1,5 +1,5 @@ // *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate +// AUTOGENERATED BY go generate from type_generate.go package dns