Remove (most) reflection

Remove the use of reflection when packing and unpacking, instead
generate all the pack and unpack functions using msg_generate.
This will generate zmsg.go which in turn calls the helper functions from
msg_helper.go.

This increases the speed by about ~30% while cutting back on memory
usage. Not all RRs are using it, but that will be rectified in upcoming
PR.

Most of the speed increase is in the header/question section parsing.
These functions *are* not generated, but straight forward enough. The
implementation can be found in msg.go.

The new code has been fuzzed by go-fuzz, which turned up some issues.

All files that started with 'z', and not autogenerated were renamed,
i.e. zscan.go is now scan.go.

Reflection is still used, in subsequent PRs it will be removed entirely.
This commit is contained in:
Miek Gieben 2016-05-14 17:56:20 +01:00
parent b5deb9f6c0
commit 475ab80867
23 changed files with 2052 additions and 272 deletions

View File

@ -150,11 +150,3 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* `NSD`
* `Net::DNS`
* `GRONG`
## TODO
* privatekey.Precompute() when signing?
* Last remaining RRs: APL, ATMA, A6, NSAP and NXT.
* Missing in parsing: ISDN, UNSPEC, NSAP and ATMA.
* NSEC(3) cover/match/closest enclose.
* Replies with TC bit are not parsed to the end.

View File

@ -300,7 +300,7 @@ func tcpMsgLen(t io.Reader) (int, error) {
if n != 2 {
return 0, ErrShortRead
}
l, _ := unpackUint16(p, 0)
l, _ := unpackUint16Msg(p, 0)
if l == 0 {
return 0, ErrShortRead
}
@ -392,7 +392,7 @@ func (co *Conn) Write(p []byte) (n int, err error) {
return 0, &Error{err: "message too large"}
}
l := make([]byte, 2, lp+2)
l[0], l[1] = packUint16(uint16(lp))
l[0], l[1] = packUint16Msg(uint16(lp))
p = append(l, p...)
n, err := io.Copy(w, bytes.NewReader(p))
return int(n), err

205
dns_bench_test.go Normal file
View File

@ -0,0 +1,205 @@
package dns
import (
"net"
"testing"
)
func BenchmarkMsgLength(b *testing.B) {
b.StopTimer()
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
msg.Len()
}
}
func BenchmarkMsgLengthPack(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.Pack()
}
}
func BenchmarkPackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackDomainName(name1, buf, 0, nil, false)
}
}
func BenchmarkUnpackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func BenchmarkCopy(b *testing.B) {
b.ReportAllocs()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Extra = []RR{rr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.Copy()
}
}
func BenchmarkPackA(b *testing.B) {
a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)}
buf := make([]byte, a.len())
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackRR(a, buf, 0, nil, false)
}
}
func BenchmarkUnpackA(b *testing.B) {
a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)}
buf := make([]byte, a.len())
PackRR(a, buf, 0, nil, false)
a = nil
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackRR(buf, 0)
}
}
func BenchmarkPackMX(b *testing.B) {
m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."}
buf := make([]byte, m.len())
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackRR(m, buf, 0, nil, false)
}
}
func BenchmarkUnpackMX(b *testing.B) {
m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."}
buf := make([]byte, m.len())
PackRR(m, buf, 0, nil, false)
m = nil
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackRR(buf, 0)
}
}
func BenchmarkPackAAAAA(b *testing.B) {
aaaa, _ := NewRR(". IN A ::1")
buf := make([]byte, aaaa.len())
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackRR(aaaa, buf, 0, nil, false)
}
}
func BenchmarkUnpackAAAA(b *testing.B) {
aaaa, _ := NewRR(". IN A ::1")
buf := make([]byte, aaaa.len())
PackRR(aaaa, buf, 0, nil, false)
aaaa = nil
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackRR(buf, 0)
}
}
func BenchmarkPackMsg(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
buf := make([]byte, 512)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.PackBuffer(buf)
}
}
func BenchmarkUnpackMsg(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
msgBuf, _ := msg.Pack()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = msg.Unpack(msgBuf)
}
}

View File

@ -274,9 +274,9 @@ func TestMsgLength2(t *testing.T) {
for i, hexData := range testMessages {
// we won't fail the decoding of the hex
input, _ := hex.DecodeString(hexData)
m := new(Msg)
m.Unpack(input)
//println(m.String())
m.Compress = true
lenComp := m.Len()
b, _ := m.Pack()
@ -310,114 +310,6 @@ func TestMsgLengthCompressionMalformed(t *testing.T) {
m.Len() // Should not crash.
}
func BenchmarkMsgLength(b *testing.B) {
b.StopTimer()
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.StartTimer()
for i := 0; i < b.N; i++ {
msg.Len()
}
}
func BenchmarkMsgLengthPack(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.Pack()
}
}
func BenchmarkMsgPackBuffer(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
buf := make([]byte, 512)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = msg.PackBuffer(buf)
}
}
func BenchmarkMsgUnpack(b *testing.B) {
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name1 := "12345678901234567890123456789012345.12345678.123."
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
msgBuf, _ := msg.Pack()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = msg.Unpack(msgBuf)
}
}
func BenchmarkPackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = PackDomainName(name1, buf, 0, nil, false)
}
}
func BenchmarkUnpackDomainName(b *testing.B) {
name1 := "12345678901234567890123456789012345.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
buf := make([]byte, len(name1)+1)
_, _ = PackDomainName(name1, buf, 0, nil, false)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = UnpackDomainName(buf, 0)
}
}
func TestToRFC3597(t *testing.T) {
a, _ := NewRR("miek.nl. IN A 10.0.1.1")
x := new(RFC3597)
@ -431,7 +323,7 @@ func TestNoRdataPack(t *testing.T) {
data := make([]byte, 1024)
for typ, fn := range TypeToRR {
r := fn()
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16}
_, err := PackRR(r, data, 0, nil, false)
if err != nil {
t.Errorf("failed to pack RR with zero rdata: %s: %v", TypeToString[typ], err)
@ -439,7 +331,6 @@ func TestNoRdataPack(t *testing.T) {
}
}
// TODO(miek): fix dns buffer too small errors this throws
func TestNoRdataUnpack(t *testing.T) {
data := make([]byte, 1024)
for typ, fn := range TypeToRR {
@ -449,7 +340,7 @@ func TestNoRdataUnpack(t *testing.T) {
continue
}
r := fn()
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16}
off, err := PackRR(r, data, 0, nil, false)
if err != nil {
// Should always works, TestNoDataPack should have caught this
@ -513,23 +404,6 @@ func TestMsgCopy(t *testing.T) {
}
}
func BenchmarkCopy(b *testing.B) {
b.ReportAllocs()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeA)
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Answer = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
m.Ns = []RR{rr}
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
m.Extra = []RR{rr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.Copy()
}
}
func TestPackIPSECKEY(t *testing.T) {
tests := []string{
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",

View File

@ -144,7 +144,7 @@ func (k *DNSKEY) KeyTag() uint16 {
// at the base64 values. But I'm lazy.
modulus, _ := fromBase64([]byte(k.PublicKey))
if len(modulus) > 1 {
x, _ := unpackUint16(modulus, len(modulus)-2)
x, _ := unpackUint16Msg(modulus, len(modulus)-2)
keytag = int(x)
}
default:

16
edns.go
View File

@ -213,7 +213,7 @@ func (e *EDNS0_SUBNET) Option() uint16 {
func (e *EDNS0_SUBNET) pack() ([]byte, error) {
b := make([]byte, 4)
b[0], b[1] = packUint16(e.Family)
b[0], b[1] = packUint16Msg(e.Family)
b[2] = e.SourceNetmask
b[3] = e.SourceScope
switch e.Family {
@ -247,7 +247,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
if len(b) < 4 {
return ErrBuf
}
e.Family, _ = unpackUint16(b, 0)
e.Family, _ = unpackUint16Msg(b, 0)
e.SourceNetmask = b[2]
e.SourceScope = b[3]
switch e.Family {
@ -369,9 +369,9 @@ func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ }
func (e *EDNS0_LLQ) pack() ([]byte, error) {
b := make([]byte, 18)
b[0], b[1] = packUint16(e.Version)
b[2], b[3] = packUint16(e.Opcode)
b[4], b[5] = packUint16(e.Error)
b[0], b[1] = packUint16Msg(e.Version)
b[2], b[3] = packUint16Msg(e.Opcode)
b[4], b[5] = packUint16Msg(e.Error)
b[6] = byte(e.Id >> 56)
b[7] = byte(e.Id >> 48)
b[8] = byte(e.Id >> 40)
@ -391,9 +391,9 @@ func (e *EDNS0_LLQ) unpack(b []byte) error {
if len(b) < 18 {
return ErrBuf
}
e.Version, _ = unpackUint16(b, 0)
e.Opcode, _ = unpackUint16(b, 2)
e.Error, _ = unpackUint16(b, 4)
e.Version, _ = unpackUint16Msg(b, 0)
e.Opcode, _ = unpackUint16Msg(b, 2)
e.Error, _ = unpackUint16Msg(b, 4)
e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 |
uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7])
e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3])

View File

@ -107,7 +107,7 @@ func CountLabel(s string) (labels int) {
// Split splits a name s into its label indexes.
// www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}.
// The root name (.) returns nil. Also see SplitDomainName.
// The root name (.) returns nil. Also see SplitDomainName.
// s must be a syntactically valid domain name.
func Split(s string) []int {
if s == "." {

View File

@ -184,12 +184,14 @@ func BenchmarkLenLabels(b *testing.B) {
}
func BenchmarkCompareLabels(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
CompareDomainName("www.example.com", "aa.example.com")
}
}
func BenchmarkIsSubDomain(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
IsSubDomain("www.example.com", "aa.example.com")
IsSubDomain("example.com", "aa.example.com")

338
msg.go
View File

@ -8,9 +8,9 @@
package dns
//go:generate go run msg_generate.go
import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"math/big"
"math/rand"
@ -92,18 +92,6 @@ type Msg struct {
Extra []RR // Holds the RR(s) of the additional section.
}
// StringToType is the reverse of TypeToString, needed for string parsing.
var StringToType = reverseInt16(TypeToString)
// StringToClass is the reverse of ClassToString, needed for string parsing.
var StringToClass = reverseInt16(ClassToString)
// Map of opcodes strings.
var StringToOpcode = reverseInt(OpcodeToString)
// Map of rcodes strings.
var StringToRcode = reverseInt(RcodeToString)
// ClassToString is a maps Classes to strings for each CLASS wire type.
var ClassToString = map[uint16]string{
ClassINET: "IN",
@ -291,11 +279,11 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c
if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in
// if msg == nil, we will never do compression
msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000))
msg[nameoffset], msg[nameoffset+1] = packUint16Msg(uint16(pointer ^ 0xC000))
off = nameoffset + 1
goto End
}
if msg != nil {
if msg != nil && off < len(msg) {
msg[off] = 0
}
End:
@ -423,7 +411,7 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
lenByteOffset := offset
if offset >= len(msg) {
if offset >= len(msg) || len(s) > len(tmp) {
return offset, ErrBuf
}
offset++
@ -465,7 +453,7 @@ func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
}
func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
if offset >= len(msg) {
if offset >= len(msg) || len(s) > len(tmp) {
return offset, ErrBuf
}
bs := tmp[:len(s)]
@ -600,9 +588,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
return lenmsg, &Error{err: "overflow packing opt"}
}
// Option code
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
msg[off], msg[off+1] = packUint16Msg(element.(EDNS0).Option())
// Length
msg[off+2], msg[off+3] = packUint16(uint16(len(b)))
msg[off+2], msg[off+3] = packUint16Msg(uint16(len(b)))
off += 4
if off+len(b) > lenmsg {
copy(msg[off:], b)
@ -783,6 +771,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
if e != nil {
return lenmsg, e
}
if off+len(b64) > lenmsg {
return lenmsg, &Error{err: "overflow packing base64"}
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case `dns:"domain-name"`:
@ -811,6 +802,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
if e != nil {
return lenmsg, e
}
if off+len(b32) > lenmsg {
return lenmsg, &Error{err: "overflow packing base32"}
}
copy(msg[off:off+len(b32)], b32)
off += len(b32)
case `dns:"size-hex"`:
@ -827,6 +821,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s))
case `dns:"size"`:
// TODO(miek): WTF? size?
// the size is already encoded in the RR, we can safely use the
// length of string. String is RAW (not encoded in hex, nor base64)
copy(msg[off:off+len(s)], s)
@ -930,8 +925,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
if off+4 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking opt"}
}
code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off)
code, off = unpackUint16Msg(msg, off)
optlen, off1 := unpackUint16Msg(msg, off)
if off1+int(optlen) > lenmsg {
return lenmsg, &Error{err: "overflow unpacking opt"}
}
@ -1174,7 +1169,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint16"}
}
i, off = unpackUint16(msg, off)
i, off = unpackUint16Msg(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off == lenmsg {
@ -1334,38 +1329,6 @@ func intToBytes(i *big.Int, length int) []byte {
return buf
}
func unpackUint16(msg []byte, off int) (uint16, int) {
return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2
}
func packUint16(i uint16) (byte, byte) {
return byte(i >> 8), byte(i)
}
func toBase32(b []byte) string {
return base32.HexEncoding.EncodeToString(b)
}
func fromBase32(s []byte) (buf []byte, err error) {
buflen := base32.HexEncoding.DecodedLen(len(s))
buf = make([]byte, buflen)
n, err := base32.HexEncoding.Decode(buf, s)
buf = buf[:n]
return
}
func toBase64(b []byte) string {
return base64.StdEncoding.EncodeToString(b)
}
func fromBase64(s []byte) (buf []byte, err error) {
buflen := base64.StdEncoding.DecodedLen(len(s))
buf = make([]byte, buflen)
n, err := base64.StdEncoding.Decode(buf, s)
buf = buf[:n]
return
}
// PackRR packs a resource record rr into msg[off:].
// See PackDomainName for documentation about the compression.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
@ -1373,7 +1336,62 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo
return len(msg), &Error{err: "nil rr"}
}
off1, err = packStructCompress(rr, msg, off, compression, compress)
_, ok := typeToUnpack[rr.Header().Rrtype]
switch ok {
case true:
// Shortcut reflection, `pack' needs to be added to the RR interface so we can just do this:
// off1, err = t.pack(msg, off, compression, compress)
// TODO(miek): revert the logic and make a blacklist for types that still use reflection. Kill
// typeToUnpack and just generate all the pack and unpack functions even though we don't use
// them for all types (yet).
switch t := rr.(type) {
case *RR_Header:
// we can be called with an empty RR, consisting only out of the header, see update_test.go's
// TestDynamicUpdateZeroRdataUnpack for an example. This is OK as RR_Header also implements the RR interface.
off1, err = t.pack(msg, off, compression, compress)
case *ANY:
// Also "weird" setup, see (again) update_test.go's TestRemoveRRset, where the Rrtype is 1 but the type is *ANY.
off1, err = t.pack(msg, off, compression, compress)
case *A:
off1, err = t.pack(msg, off, compression, compress)
case *AAAA:
off1, err = t.pack(msg, off, compression, compress)
case *CNAME:
off1, err = t.pack(msg, off, compression, compress)
case *DNAME:
off1, err = t.pack(msg, off, compression, compress)
case *HINFO:
off1, err = t.pack(msg, off, compression, compress)
case *L32:
off1, err = t.pack(msg, off, compression, compress)
case *LOC:
off1, err = t.pack(msg, off, compression, compress)
case *MB:
off1, err = t.pack(msg, off, compression, compress)
case *MD:
off1, err = t.pack(msg, off, compression, compress)
case *MF:
off1, err = t.pack(msg, off, compression, compress)
case *MG:
off1, err = t.pack(msg, off, compression, compress)
case *MX:
off1, err = t.pack(msg, off, compression, compress)
case *NID:
off1, err = t.pack(msg, off, compression, compress)
case *NS:
off1, err = t.pack(msg, off, compression, compress)
case *PTR:
off1, err = t.pack(msg, off, compression, compress)
case *RP:
off1, err = t.pack(msg, off, compression, compress)
case *SRV:
off1, err = t.pack(msg, off, compression, compress)
case *DNSKEY:
off1, err = t.pack(msg, off, compression, compress)
}
default:
off1, err = packStructCompress(rr, msg, off, compression, compress)
}
if err != nil {
return len(msg), err
}
@ -1385,21 +1403,27 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo
// UnpackRR unpacks msg[off:] into an RR.
func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// unpack just the header, to find the rr type and length
var h RR_Header
off0 := off
if off, err = UnpackStruct(&h, msg, off); err != nil {
h, off, msg, err := unpackHeader(msg, off)
if err != nil {
return nil, len(msg), err
}
end := off + int(h.Rdlength)
// make an rr of that type and re-unpack.
mk, known := TypeToRR[h.Rrtype]
if !known {
rr = new(RFC3597)
} else {
rr = mk()
fn, ok := typeToUnpack[h.Rrtype]
switch ok {
case true:
// Shortcut reflection.
rr, off, err = fn(h, msg, off)
default:
mk, known := TypeToRR[h.Rrtype]
if !known {
rr = new(RFC3597)
} else {
rr = mk()
}
off, err = UnpackStruct(rr, msg, off0)
}
off, err = UnpackStruct(rr, msg, off0)
if off != end {
return &h, end, &Error{err: "bad rdlength"}
}
@ -1432,31 +1456,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error)
return dst, off, err
}
// Reverse a map
func reverseInt8(m map[uint8]string) map[string]uint8 {
n := make(map[string]uint8)
for u, s := range m {
n[s] = u
}
return n
}
func reverseInt16(m map[uint16]string) map[string]uint16 {
n := make(map[string]uint16)
for u, s := range m {
n[s] = u
}
return n
}
func reverseInt(m map[int]string) map[string]int {
n := make(map[string]int)
for u, s := range m {
n[s] = u
}
return n
}
// Convert a MsgHdr to a string, with dig-like headers:
//
//;; opcode: QUERY, status: NOERROR, id: 48404
@ -1510,8 +1509,11 @@ func (dns *Msg) Pack() (msg []byte, err error) {
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small
// a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
var dh Header
var compression map[string]int
var (
dh Header
compression map[string]int
)
if dns.Compress {
compression = make(map[string]int) // Compression pointer mappings
}
@ -1579,12 +1581,12 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
// Pack it in: header and then the pieces.
off := 0
off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
off, err = dh.pack(msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
for i := 0; i < len(question); i++ {
off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
off, err = question[i].pack(msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
@ -1612,12 +1614,17 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
// Header.
var dh Header
off := 0
if off, err = UnpackStruct(&dh, msg, off); err != nil {
var (
dh Header
off int
)
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
return err
}
if off == len(msg) {
return ErrTruncated
}
dns.Id = dh.Id
dns.Response = (dh.Bits & _QR) != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
@ -1633,10 +1640,10 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
// Optimistically use the count given to us in the header
dns.Question = make([]Question, 0, int(dh.Qdcount))
var q Question
for i := 0; i < int(dh.Qdcount); i++ {
off1 := off
off, err = UnpackStruct(&q, msg, off)
var q Question
q, off, err = unpackQuestion(msg, off)
if err != nil {
// Even if Truncated is set, we only will set ErrTruncated if we
// actually got the questions
@ -1662,6 +1669,7 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
}
// The header counts might have been wrong so we need to update it
dh.Arcount = uint16(len(dns.Extra))
if off != len(msg) {
// TODO(miek) make this an error?
// use PackOpt to let people tell how detailed the error reporting should be?
@ -1735,6 +1743,9 @@ func (dns *Msg) Len() int {
}
}
for i := 0; i < len(dns.Answer); i++ {
if dns.Answer[i] == nil {
continue
}
l += dns.Answer[i].len()
if dns.Compress {
k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name)
@ -1750,6 +1761,9 @@ func (dns *Msg) Len() int {
}
}
for i := 0; i < len(dns.Ns); i++ {
if dns.Ns[i] == nil {
continue
}
l += dns.Ns[i].len()
if dns.Compress {
k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name)
@ -1765,6 +1779,9 @@ func (dns *Msg) Len() int {
}
}
for i := 0; i < len(dns.Extra); i++ {
if dns.Extra[i] == nil {
continue
}
l += dns.Extra[i].len()
if dns.Compress {
k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name)
@ -1955,3 +1972,122 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
return r1
}
func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
off, err := PackDomainName(q.Name, msg, off, compression, compress)
if err != nil {
return off, err
}
off, err = packUint16(q.Qtype, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(q.Qclass, msg, off)
if err != nil {
return off, err
}
return off, nil
}
func unpackQuestion(msg []byte, off int) (Question, int, error) {
var (
q Question
err error
)
q.Name, off, err = UnpackDomainName(msg, off)
if err != nil {
return q, off, err
}
if off == len(msg) {
return q, off, nil
}
q.Qtype, off, err = unpackUint16(msg, off)
if err != nil {
return q, off, err
}
if off == len(msg) {
return q, off, nil
}
q.Qclass, off, err = unpackUint16(msg, off)
if off == len(msg) {
return q, off, nil
}
return q, off, err
}
func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
off, err := packUint16(dh.Id, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(dh.Bits, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(dh.Qdcount, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(dh.Ancount, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(dh.Nscount, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(dh.Arcount, msg, off)
return off, err
}
func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
var (
dh Header
err error
)
dh.Id, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err
}
dh.Bits, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err
}
dh.Qdcount, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err
}
dh.Ancount, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err
}
dh.Nscount, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err
}
dh.Arcount, off, err = unpackUint16(msg, off)
return dh, off, err
}
// Which types have type specific unpack functions.
var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){
TypeAAAA: unpackAAAA,
TypeA: unpackA,
TypeCNAME: unpackCNAME,
TypeDNAME: unpackDNAME,
TypeL32: unpackL32,
TypeLOC: unpackLOC,
TypeMB: unpackMB,
TypeMD: unpackMD,
TypeMF: unpackMF,
TypeMG: unpackMG,
TypeMR: unpackMR,
TypeMX: unpackMX,
TypeNID: unpackNID,
TypeNS: unpackNS,
TypePTR: unpackPTR,
TypeRP: unpackRP,
TypeSRV: unpackSRV,
TypeHINFO: unpackHINFO,
TypeDNSKEY: unpackDNSKEY,
}

307
msg_generate.go Normal file
View File

@ -0,0 +1,307 @@
//+build ignore
// msg_generate.go is meant to run with go generate. It will use
// go/{importer,types} to track down all the RR struct types. Then for each type
// it will generate pack/unpack methods based on the struct tags. The generated source is
// written to zmsg.go, and is meant to be checked into git.
package main
import (
"bytes"
"fmt"
"go/format"
"go/importer"
"go/types"
"log"
"os"
)
// All RR pack and unpack functions should be generated, currently RR that present some
// problems
// * NSEC/NSEC3 - type bitmap
// * TXT/SPF - string slice
// * URI - weird octet thing there
// * NSEC3/TSIG - size hex
// * OPT RR - EDNS0 parsing - needs to some looking at
// * HIP - uses "hex", but is actually size-hex - might drop size-hex?
// * Z
// * WKS - uint16 slice
// * NINFO
// * PrivateRR
// What types are we generating, should be kept in sync with typeToUnpack in msg.go
var generate = map[string]bool{
"AAAA": true,
"ANY": true,
"A": true,
"CNAME": true,
"DNAME": true,
"DNSKEY": true,
"HINFO": true,
"L32": true,
"LOC": true,
"MB": true,
"MD": true,
"MF": true,
"MG": true,
"MR": true,
"MX": true,
"NID": true,
"NS": true,
"PTR": true,
"RP": true,
"SRV": true,
}
func shouldGenerate(name string) bool {
_, ok := generate[name]
return ok
}
// For later: IPSECKEY is weird.
var packageHdr = `
// *** DO NOT MODIFY ***
// AUTOGENERATED BY go generate from msg_generate.go
package dns
`
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
st, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
return st, false
}
if st.Field(0).Anonymous() {
st, _ := getTypeStruct(st.Field(0).Type(), scope)
return st, true
}
return nil, false
}
func main() {
// Import and type-check the package
pkg, err := importer.Default().Import("github.com/miekg/dns")
fatalIfErr(err)
scope := pkg.Scope()
// Collect actual types (*X)
var namedTypes []string
for _, name := range scope.Names() {
o := scope.Lookup(name)
if o == nil || !o.Exported() {
continue
}
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
continue
}
if name == "PrivateRR" {
continue
}
// Check if corresponding TypeX exists
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
log.Fatalf("Constant Type%s does not exist.", o.Name())
}
namedTypes = append(namedTypes, o.Name())
}
b := &bytes.Buffer{}
b.WriteString(packageHdr)
fmt.Fprint(b, "// pack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded || !shouldGenerate(name) {
continue
}
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, err
}
headerEnd := off
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return off, err
}
`)
}
//if _, ok := st.Field(i).Type().(*types.Slice); ok {
//switch st.Tag(i) {
//case `dns:"-"`:
//// ignored
//case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
//o("for _, x := range rr.%s { l += len(x) + 1 }\n")
//default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
//}
//continue
//}
switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
case `dns:"a"`:
o("off, err = packDataA(rr.%s, msg, off)\n")
case `dns:"aaaa"`:
o("off, err = packDataAAAA(rr.%s, msg, off)\n")
case `dns:"uint48"`:
o("off, err = packUint48(rr.%s, msg, off)\n")
case `dns:"txt"`:
o("off, err = packString(rr.%s, msg, off)\n")
case `dns:"base32"`:
o("off, err = packStringBase32(rr.%s, msg, off)\n")
case `dns:"base64"`:
o("off, err = packStringBase64(rr.%s, msg, off)\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("off, err = packUint8(rr.%s, msg, off)\n")
case types.Uint16:
o("off, err = packUint16(rr.%s, msg, off)\n")
case types.Uint32:
o("off, err = packUint32(rr.%s, msg, off)\n")
case types.Uint64:
o("off, err = packUint64(rr.%s, msg, off)\n")
case types.String:
o("off, err = packString(rr.%s, msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
//default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
}
// We have packed everything, only now we know the rdlength of this RR
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off- headerEnd)")
fmt.Fprintln(b, "return off, nil }\n")
}
fmt.Fprint(b, "// unpack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded || !shouldGenerate(name) {
continue
}
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
fmt.Fprint(b, `if noRdata(h) {
return nil, off, nil
}
var err error
rdStart := off
_ = rdStart
`)
fmt.Fprintf(b, "rr := new(%s)\n", name)
fmt.Fprintln(b, "rr.Hdr = h\n")
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return rr, off, err
}
`)
}
//if _, ok := st.Field(i).Type().(*types.Slice); ok {
//switch st.Tag(i) {
//case `dns:"-"`:
//// ignored
//case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
//o("for _, x := range rr.%s { l += len(x) + 1 }\n")
//default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
//}
//continue
//}
switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
case `dns:"a"`:
o("rr.%s, off, err = unpackDataA(msg, off)\n")
case `dns:"aaaa"`:
o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
case `dns:"uint48"`:
o("rr.%s, off, err = unpackUint48(msg, off)\n")
case `dns:"txt"`:
o("rr.%s, off, err = unpackString(msg, off)\n")
case `dns:"base32"`:
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"base64"`:
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("rr.%s, off, err = unpackUint8(msg, off)\n")
case types.Uint16:
o("rr.%s, off, err = unpackUint16(msg, off)\n")
case types.Uint32:
o("rr.%s, off, err = unpackUint32(msg, off)\n")
case types.Uint64:
o("rr.%s, off, err = unpackUint64(msg, off)\n")
case types.String:
o("rr.%s, off, err = unpackString(msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
//default:
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
// If we've hit len(msg) we return without error.
if i < st.NumFields()-1 {
fmt.Fprintf(b, `if off == len(msg) {
return rr, off, nil
}
`)
}
}
fmt.Fprintf(b, "return rr, off, err }\n\n")
}
// gofmt
res, err := format.Source(b.Bytes())
if err != nil {
b.WriteTo(os.Stderr)
log.Fatal(err)
}
// write result
f, err := os.Create("zmsg.go")
fatalIfErr(err)
defer f.Close()
f.Write(res)