Remove (most) reflection
Remove the use of reflection when packing and unpacking, instead generate all the pack and unpack functions using msg_generate. This will generate zmsg.go which in turn calls the helper functions from msg_helper.go. This increases the speed by about ~30% while cutting back on memory usage. Not all RRs are using it, but that will be rectified in upcoming PR. Most of the speed increase is in the header/question section parsing. These functions *are* not generated, but straight forward enough. The implementation can be found in msg.go. The new code has been fuzzed by go-fuzz, which turned up some issues. All files that started with 'z', and not autogenerated were renamed, i.e. zscan.go is now scan.go. Reflection is still used, in subsequent PRs it will be removed entirely.
This commit is contained in:
parent
b5deb9f6c0
commit
475ab80867
@ -150,11 +150,3 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
|
||||
* `NSD`
|
||||
* `Net::DNS`
|
||||
* `GRONG`
|
||||
|
||||
## TODO
|
||||
|
||||
* privatekey.Precompute() when signing?
|
||||
* Last remaining RRs: APL, ATMA, A6, NSAP and NXT.
|
||||
* Missing in parsing: ISDN, UNSPEC, NSAP and ATMA.
|
||||
* NSEC(3) cover/match/closest enclose.
|
||||
* Replies with TC bit are not parsed to the end.
|
||||
|
@ -300,7 +300,7 @@ func tcpMsgLen(t io.Reader) (int, error) {
|
||||
if n != 2 {
|
||||
return 0, ErrShortRead
|
||||
}
|
||||
l, _ := unpackUint16(p, 0)
|
||||
l, _ := unpackUint16Msg(p, 0)
|
||||
if l == 0 {
|
||||
return 0, ErrShortRead
|
||||
}
|
||||
@ -392,7 +392,7 @@ func (co *Conn) Write(p []byte) (n int, err error) {
|
||||
return 0, &Error{err: "message too large"}
|
||||
}
|
||||
l := make([]byte, 2, lp+2)
|
||||
l[0], l[1] = packUint16(uint16(lp))
|
||||
l[0], l[1] = packUint16Msg(uint16(lp))
|
||||
p = append(l, p...)
|
||||
n, err := io.Copy(w, bytes.NewReader(p))
|
||||
return int(n), err
|
||||
|
205
dns_bench_test.go
Normal file
205
dns_bench_test.go
Normal file
@ -0,0 +1,205 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkMsgLength(b *testing.B) {
|
||||
b.StopTimer()
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg.Len()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMsgLengthPack(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = msg.Pack()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackDomainName(b *testing.B) {
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackDomainName(b *testing.B) {
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackDomainName(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
|
||||
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackDomainName(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCopy(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Msg)
|
||||
m.SetQuestion("miek.nl.", TypeA)
|
||||
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
|
||||
m.Answer = []RR{rr}
|
||||
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
|
||||
m.Ns = []RR{rr}
|
||||
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
|
||||
m.Extra = []RR{rr}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Copy()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackA(b *testing.B) {
|
||||
a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)}
|
||||
|
||||
buf := make([]byte, a.len())
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = PackRR(a, buf, 0, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackA(b *testing.B) {
|
||||
a := &A{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, A: net.IPv4(127, 0, 0, 1)}
|
||||
|
||||
buf := make([]byte, a.len())
|
||||
PackRR(a, buf, 0, nil, false)
|
||||
a = nil
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackRR(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackMX(b *testing.B) {
|
||||
m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."}
|
||||
|
||||
buf := make([]byte, m.len())
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = PackRR(m, buf, 0, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackMX(b *testing.B) {
|
||||
m := &MX{Hdr: RR_Header{Name: ".", Rrtype: TypeA, Class: ClassANY}, Mx: "mx.miek.nl."}
|
||||
|
||||
buf := make([]byte, m.len())
|
||||
PackRR(m, buf, 0, nil, false)
|
||||
m = nil
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackRR(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackAAAAA(b *testing.B) {
|
||||
aaaa, _ := NewRR(". IN A ::1")
|
||||
|
||||
buf := make([]byte, aaaa.len())
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = PackRR(aaaa, buf, 0, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackAAAA(b *testing.B) {
|
||||
aaaa, _ := NewRR(". IN A ::1")
|
||||
|
||||
buf := make([]byte, aaaa.len())
|
||||
PackRR(aaaa, buf, 0, nil, false)
|
||||
aaaa = nil
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackRR(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackMsg(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
buf := make([]byte, 512)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = msg.PackBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackMsg(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
msgBuf, _ := msg.Pack()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = msg.Unpack(msgBuf)
|
||||
}
|
||||
}
|
132
dns_test.go
132
dns_test.go
@ -274,9 +274,9 @@ func TestMsgLength2(t *testing.T) {
|
||||
for i, hexData := range testMessages {
|
||||
// we won't fail the decoding of the hex
|
||||
input, _ := hex.DecodeString(hexData)
|
||||
|
||||
m := new(Msg)
|
||||
m.Unpack(input)
|
||||
//println(m.String())
|
||||
m.Compress = true
|
||||
lenComp := m.Len()
|
||||
b, _ := m.Pack()
|
||||
@ -310,114 +310,6 @@ func TestMsgLengthCompressionMalformed(t *testing.T) {
|
||||
m.Len() // Should not crash.
|
||||
}
|
||||
|
||||
func BenchmarkMsgLength(b *testing.B) {
|
||||
b.StopTimer()
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg.Len()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMsgLengthPack(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = msg.Pack()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMsgPackBuffer(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
buf := make([]byte, 512)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = msg.PackBuffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMsgUnpack(b *testing.B) {
|
||||
makeMsg := func(question string, ans, ns, e []RR) *Msg {
|
||||
msg := new(Msg)
|
||||
msg.SetQuestion(Fqdn(question), TypeANY)
|
||||
msg.Answer = append(msg.Answer, ans...)
|
||||
msg.Ns = append(msg.Ns, ns...)
|
||||
msg.Extra = append(msg.Extra, e...)
|
||||
msg.Compress = true
|
||||
return msg
|
||||
}
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
|
||||
msg := makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)
|
||||
msgBuf, _ := msg.Pack()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = msg.Unpack(msgBuf)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPackDomainName(b *testing.B) {
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackDomainName(b *testing.B) {
|
||||
name1 := "12345678901234567890123456789012345.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackDomainName(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnpackDomainNameUnprintable(b *testing.B) {
|
||||
name1 := "\x02\x02\x02\x025\x02\x02\x02\x02.12345678.123."
|
||||
buf := make([]byte, len(name1)+1)
|
||||
_, _ = PackDomainName(name1, buf, 0, nil, false)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = UnpackDomainName(buf, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToRFC3597(t *testing.T) {
|
||||
a, _ := NewRR("miek.nl. IN A 10.0.1.1")
|
||||
x := new(RFC3597)
|
||||
@ -431,7 +323,7 @@ func TestNoRdataPack(t *testing.T) {
|
||||
data := make([]byte, 1024)
|
||||
for typ, fn := range TypeToRR {
|
||||
r := fn()
|
||||
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
|
||||
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16}
|
||||
_, err := PackRR(r, data, 0, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to pack RR with zero rdata: %s: %v", TypeToString[typ], err)
|
||||
@ -439,7 +331,6 @@ func TestNoRdataPack(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(miek): fix dns buffer too small errors this throws
|
||||
func TestNoRdataUnpack(t *testing.T) {
|
||||
data := make([]byte, 1024)
|
||||
for typ, fn := range TypeToRR {
|
||||
@ -449,7 +340,7 @@ func TestNoRdataUnpack(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
r := fn()
|
||||
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}
|
||||
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 16}
|
||||
off, err := PackRR(r, data, 0, nil, false)
|
||||
if err != nil {
|
||||
// Should always works, TestNoDataPack should have caught this
|
||||
@ -513,23 +404,6 @@ func TestMsgCopy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCopy(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Msg)
|
||||
m.SetQuestion("miek.nl.", TypeA)
|
||||
rr, _ := NewRR("miek.nl. 2311 IN A 127.0.0.1")
|
||||
m.Answer = []RR{rr}
|
||||
rr, _ = NewRR("miek.nl. 2311 IN NS 127.0.0.1")
|
||||
m.Ns = []RR{rr}
|
||||
rr, _ = NewRR("miek.nl. 2311 IN A 127.0.0.1")
|
||||
m.Extra = []RR{rr}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Copy()
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackIPSECKEY(t *testing.T) {
|
||||
tests := []string{
|
||||
"38.2.0.192.in-addr.arpa. 7200 IN IPSECKEY ( 10 1 2 192.0.2.38 AQNRU3mG7TVTO2BkR47usntb102uFJtugbo6BSGvgqt4AQ== )",
|
||||
|
@ -144,7 +144,7 @@ func (k *DNSKEY) KeyTag() uint16 {
|
||||
// at the base64 values. But I'm lazy.
|
||||
modulus, _ := fromBase64([]byte(k.PublicKey))
|
||||
if len(modulus) > 1 {
|
||||
x, _ := unpackUint16(modulus, len(modulus)-2)
|
||||
x, _ := unpackUint16Msg(modulus, len(modulus)-2)
|
||||
keytag = int(x)
|
||||
}
|
||||
default:
|
||||
|
16
edns.go
16
edns.go
@ -213,7 +213,7 @@ func (e *EDNS0_SUBNET) Option() uint16 {
|
||||
|
||||
func (e *EDNS0_SUBNET) pack() ([]byte, error) {
|
||||
b := make([]byte, 4)
|
||||
b[0], b[1] = packUint16(e.Family)
|
||||
b[0], b[1] = packUint16Msg(e.Family)
|
||||
b[2] = e.SourceNetmask
|
||||
b[3] = e.SourceScope
|
||||
switch e.Family {
|
||||
@ -247,7 +247,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
|
||||
if len(b) < 4 {
|
||||
return ErrBuf
|
||||
}
|
||||
e.Family, _ = unpackUint16(b, 0)
|
||||
e.Family, _ = unpackUint16Msg(b, 0)
|
||||
e.SourceNetmask = b[2]
|
||||
e.SourceScope = b[3]
|
||||
switch e.Family {
|
||||
@ -369,9 +369,9 @@ func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ }
|
||||
|
||||
func (e *EDNS0_LLQ) pack() ([]byte, error) {
|
||||
b := make([]byte, 18)
|
||||
b[0], b[1] = packUint16(e.Version)
|
||||
b[2], b[3] = packUint16(e.Opcode)
|
||||
b[4], b[5] = packUint16(e.Error)
|
||||
b[0], b[1] = packUint16Msg(e.Version)
|
||||
b[2], b[3] = packUint16Msg(e.Opcode)
|
||||
b[4], b[5] = packUint16Msg(e.Error)
|
||||
b[6] = byte(e.Id >> 56)
|
||||
b[7] = byte(e.Id >> 48)
|
||||
b[8] = byte(e.Id >> 40)
|
||||
@ -391,9 +391,9 @@ func (e *EDNS0_LLQ) unpack(b []byte) error {
|
||||
if len(b) < 18 {
|
||||
return ErrBuf
|
||||
}
|
||||
e.Version, _ = unpackUint16(b, 0)
|
||||
e.Opcode, _ = unpackUint16(b, 2)
|
||||
e.Error, _ = unpackUint16(b, 4)
|
||||
e.Version, _ = unpackUint16Msg(b, 0)
|
||||
e.Opcode, _ = unpackUint16Msg(b, 2)
|
||||
e.Error, _ = unpackUint16Msg(b, 4)
|
||||
e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 |
|
||||
uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7])
|
||||
e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3])
|
||||
|
@ -107,7 +107,7 @@ func CountLabel(s string) (labels int) {
|
||||
|
||||
// Split splits a name s into its label indexes.
|
||||
// www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}.
|
||||
// The root name (.) returns nil. Also see SplitDomainName.
|
||||
// The root name (.) returns nil. Also see SplitDomainName.
|
||||
// s must be a syntactically valid domain name.
|
||||
func Split(s string) []int {
|
||||
if s == "." {
|
||||
|
@ -184,12 +184,14 @@ func BenchmarkLenLabels(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkCompareLabels(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
CompareDomainName("www.example.com", "aa.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsSubDomain(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsSubDomain("www.example.com", "aa.example.com")
|
||||
IsSubDomain("example.com", "aa.example.com")
|
||||
|
338
msg.go
338
msg.go
@ -8,9 +8,9 @@
|
||||
|
||||
package dns
|
||||
|
||||
//go:generate go run msg_generate.go
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
@ -92,18 +92,6 @@ type Msg struct {
|
||||
Extra []RR // Holds the RR(s) of the additional section.
|
||||
}
|
||||
|
||||
// StringToType is the reverse of TypeToString, needed for string parsing.
|
||||
var StringToType = reverseInt16(TypeToString)
|
||||
|
||||
// StringToClass is the reverse of ClassToString, needed for string parsing.
|
||||
var StringToClass = reverseInt16(ClassToString)
|
||||
|
||||
// Map of opcodes strings.
|
||||
var StringToOpcode = reverseInt(OpcodeToString)
|
||||
|
||||
// Map of rcodes strings.
|
||||
var StringToRcode = reverseInt(RcodeToString)
|
||||
|
||||
// ClassToString is a maps Classes to strings for each CLASS wire type.
|
||||
var ClassToString = map[uint16]string{
|
||||
ClassINET: "IN",
|
||||
@ -291,11 +279,11 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c
|
||||
if pointer != -1 {
|
||||
// We have two bytes (14 bits) to put the pointer in
|
||||
// if msg == nil, we will never do compression
|
||||
msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000))
|
||||
msg[nameoffset], msg[nameoffset+1] = packUint16Msg(uint16(pointer ^ 0xC000))
|
||||
off = nameoffset + 1
|
||||
goto End
|
||||
}
|
||||
if msg != nil {
|
||||
if msg != nil && off < len(msg) {
|
||||
msg[off] = 0
|
||||
}
|
||||
End:
|
||||
@ -423,7 +411,7 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
|
||||
|
||||
func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
|
||||
lenByteOffset := offset
|
||||
if offset >= len(msg) {
|
||||
if offset >= len(msg) || len(s) > len(tmp) {
|
||||
return offset, ErrBuf
|
||||
}
|
||||
offset++
|
||||
@ -465,7 +453,7 @@ func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
|
||||
}
|
||||
|
||||
func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
|
||||
if offset >= len(msg) {
|
||||
if offset >= len(msg) || len(s) > len(tmp) {
|
||||
return offset, ErrBuf
|
||||
}
|
||||
bs := tmp[:len(s)]
|
||||
@ -600,9 +588,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
||||
return lenmsg, &Error{err: "overflow packing opt"}
|
||||
}
|
||||
// Option code
|
||||
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
|
||||
msg[off], msg[off+1] = packUint16Msg(element.(EDNS0).Option())
|
||||
// Length
|
||||
msg[off+2], msg[off+3] = packUint16(uint16(len(b)))
|
||||
msg[off+2], msg[off+3] = packUint16Msg(uint16(len(b)))
|
||||
off += 4
|
||||
if off+len(b) > lenmsg {
|
||||
copy(msg[off:], b)
|
||||
@ -783,6 +771,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
||||
if e != nil {
|
||||
return lenmsg, e
|
||||
}
|
||||
if off+len(b64) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing base64"}
|
||||
}
|
||||
copy(msg[off:off+len(b64)], b64)
|
||||
off += len(b64)
|
||||
case `dns:"domain-name"`:
|
||||
@ -811,6 +802,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
||||
if e != nil {
|
||||
return lenmsg, e
|
||||
}
|
||||
if off+len(b32) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing base32"}
|
||||
}
|
||||
copy(msg[off:off+len(b32)], b32)
|
||||
off += len(b32)
|
||||
case `dns:"size-hex"`:
|
||||
@ -827,6 +821,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
||||
copy(msg[off:off+hex.DecodedLen(len(s))], h)
|
||||
off += hex.DecodedLen(len(s))
|
||||
case `dns:"size"`:
|
||||
// TODO(miek): WTF? size?
|
||||
// the size is already encoded in the RR, we can safely use the
|
||||
// length of string. String is RAW (not encoded in hex, nor base64)
|
||||
copy(msg[off:off+len(s)], s)
|
||||
@ -930,8 +925,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
|
||||
if off+4 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking opt"}
|
||||
}
|
||||
code, off = unpackUint16(msg, off)
|
||||
optlen, off1 := unpackUint16(msg, off)
|
||||
code, off = unpackUint16Msg(msg, off)
|
||||
optlen, off1 := unpackUint16Msg(msg, off)
|
||||
if off1+int(optlen) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking opt"}
|
||||
}
|
||||
@ -1174,7 +1169,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
|
||||
if off+2 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint16"}
|
||||
}
|
||||
i, off = unpackUint16(msg, off)
|
||||
i, off = unpackUint16Msg(msg, off)
|
||||
fv.SetUint(uint64(i))
|
||||
case reflect.Uint32:
|
||||
if off == lenmsg {
|
||||
@ -1334,38 +1329,6 @@ func intToBytes(i *big.Int, length int) []byte {
|
||||
return buf
|
||||
}
|
||||
|
||||
func unpackUint16(msg []byte, off int) (uint16, int) {
|
||||
return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2
|
||||
}
|
||||
|
||||
func packUint16(i uint16) (byte, byte) {
|
||||
return byte(i >> 8), byte(i)
|
||||
}
|
||||
|
||||
func toBase32(b []byte) string {
|
||||
return base32.HexEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func fromBase32(s []byte) (buf []byte, err error) {
|
||||
buflen := base32.HexEncoding.DecodedLen(len(s))
|
||||
buf = make([]byte, buflen)
|
||||
n, err := base32.HexEncoding.Decode(buf, s)
|
||||
buf = buf[:n]
|
||||
return
|
||||
}
|
||||
|
||||
func toBase64(b []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func fromBase64(s []byte) (buf []byte, err error) {
|
||||
buflen := base64.StdEncoding.DecodedLen(len(s))
|
||||
buf = make([]byte, buflen)
|
||||
n, err := base64.StdEncoding.Decode(buf, s)
|
||||
buf = buf[:n]
|
||||
return
|
||||
}
|
||||
|
||||
// PackRR packs a resource record rr into msg[off:].
|
||||
// See PackDomainName for documentation about the compression.
|
||||
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
@ -1373,7 +1336,62 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo
|
||||
return len(msg), &Error{err: "nil rr"}
|
||||
}
|
||||
|
||||
off1, err = packStructCompress(rr, msg, off, compression, compress)
|
||||
_, ok := typeToUnpack[rr.Header().Rrtype]
|
||||
switch ok {
|
||||
case true:
|
||||
// Shortcut reflection, `pack' needs to be added to the RR interface so we can just do this:
|
||||
// off1, err = t.pack(msg, off, compression, compress)
|
||||
// TODO(miek): revert the logic and make a blacklist for types that still use reflection. Kill
|
||||
// typeToUnpack and just generate all the pack and unpack functions even though we don't use
|
||||
// them for all types (yet).
|
||||
switch t := rr.(type) {
|
||||
case *RR_Header:
|
||||
// we can be called with an empty RR, consisting only out of the header, see update_test.go's
|
||||
// TestDynamicUpdateZeroRdataUnpack for an example. This is OK as RR_Header also implements the RR interface.
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *ANY:
|
||||
// Also "weird" setup, see (again) update_test.go's TestRemoveRRset, where the Rrtype is 1 but the type is *ANY.
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *A:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *AAAA:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *CNAME:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *DNAME:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *HINFO:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *L32:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *LOC:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *MB:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *MD:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *MF:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *MG:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *MX:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *NID:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *NS:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *PTR:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *RP:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *SRV:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
case *DNSKEY:
|
||||
off1, err = t.pack(msg, off, compression, compress)
|
||||
}
|
||||
default:
|
||||
off1, err = packStructCompress(rr, msg, off, compression, compress)
|
||||
}
|
||||
if err != nil {
|
||||
return len(msg), err
|
||||
}
|
||||
@ -1385,21 +1403,27 @@ func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress boo
|
||||
|
||||
// UnpackRR unpacks msg[off:] into an RR.
|
||||
func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
|
||||
// unpack just the header, to find the rr type and length
|
||||
var h RR_Header
|
||||
off0 := off
|
||||
if off, err = UnpackStruct(&h, msg, off); err != nil {
|
||||
h, off, msg, err := unpackHeader(msg, off)
|
||||
if err != nil {
|
||||
return nil, len(msg), err
|
||||
}
|
||||
end := off + int(h.Rdlength)
|
||||
// make an rr of that type and re-unpack.
|
||||
mk, known := TypeToRR[h.Rrtype]
|
||||
if !known {
|
||||
rr = new(RFC3597)
|
||||
} else {
|
||||
rr = mk()
|
||||
|
||||
fn, ok := typeToUnpack[h.Rrtype]
|
||||
switch ok {
|
||||
case true:
|
||||
// Shortcut reflection.
|
||||
rr, off, err = fn(h, msg, off)
|
||||
default:
|
||||
mk, known := TypeToRR[h.Rrtype]
|
||||
if !known {
|
||||
rr = new(RFC3597)
|
||||
} else {
|
||||
rr = mk()
|
||||
}
|
||||
off, err = UnpackStruct(rr, msg, off0)
|
||||
}
|
||||
off, err = UnpackStruct(rr, msg, off0)
|
||||
if off != end {
|
||||
return &h, end, &Error{err: "bad rdlength"}
|
||||
}
|
||||
@ -1432,31 +1456,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error)
|
||||
return dst, off, err
|
||||
}
|
||||
|
||||
// Reverse a map
|
||||
func reverseInt8(m map[uint8]string) map[string]uint8 {
|
||||
n := make(map[string]uint8)
|
||||
for u, s := range m {
|
||||
n[s] = u
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func reverseInt16(m map[uint16]string) map[string]uint16 {
|
||||
n := make(map[string]uint16)
|
||||
for u, s := range m {
|
||||
n[s] = u
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func reverseInt(m map[int]string) map[string]int {
|
||||
n := make(map[string]int)
|
||||
for u, s := range m {
|
||||
n[s] = u
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Convert a MsgHdr to a string, with dig-like headers:
|
||||
//
|
||||
//;; opcode: QUERY, status: NOERROR, id: 48404
|
||||
@ -1510,8 +1509,11 @@ func (dns *Msg) Pack() (msg []byte, err error) {
|
||||
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small
|
||||
// a new buffer is allocated.
|
||||
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
|
||||
var dh Header
|
||||
var compression map[string]int
|
||||
var (
|
||||
dh Header
|
||||
compression map[string]int
|
||||
)
|
||||
|
||||
if dns.Compress {
|
||||
compression = make(map[string]int) // Compression pointer mappings
|
||||
}
|
||||
@ -1579,12 +1581,12 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
|
||||
|
||||
// Pack it in: header and then the pieces.
|
||||
off := 0
|
||||
off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
|
||||
off, err = dh.pack(msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := 0; i < len(question); i++ {
|
||||
off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
|
||||
off, err = question[i].pack(msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1612,12 +1614,17 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
|
||||
|
||||
// Unpack unpacks a binary message to a Msg structure.
|
||||
func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
// Header.
|
||||
var dh Header
|
||||
off := 0
|
||||
if off, err = UnpackStruct(&dh, msg, off); err != nil {
|
||||
var (
|
||||
dh Header
|
||||
off int
|
||||
)
|
||||
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
|
||||
return err
|
||||
}
|
||||
if off == len(msg) {
|
||||
return ErrTruncated
|
||||
}
|
||||
|
||||
dns.Id = dh.Id
|
||||
dns.Response = (dh.Bits & _QR) != 0
|
||||
dns.Opcode = int(dh.Bits>>11) & 0xF
|
||||
@ -1633,10 +1640,10 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
// Optimistically use the count given to us in the header
|
||||
dns.Question = make([]Question, 0, int(dh.Qdcount))
|
||||
|
||||
var q Question
|
||||
for i := 0; i < int(dh.Qdcount); i++ {
|
||||
off1 := off
|
||||
off, err = UnpackStruct(&q, msg, off)
|
||||
var q Question
|
||||
q, off, err = unpackQuestion(msg, off)
|
||||
if err != nil {
|
||||
// Even if Truncated is set, we only will set ErrTruncated if we
|
||||
// actually got the questions
|
||||
@ -1662,6 +1669,7 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
}
|
||||
// The header counts might have been wrong so we need to update it
|
||||
dh.Arcount = uint16(len(dns.Extra))
|
||||
|
||||
if off != len(msg) {
|
||||
// TODO(miek) make this an error?
|
||||
// use PackOpt to let people tell how detailed the error reporting should be?
|
||||
@ -1735,6 +1743,9 @@ func (dns *Msg) Len() int {
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Answer); i++ {
|
||||
if dns.Answer[i] == nil {
|
||||
continue
|
||||
}
|
||||
l += dns.Answer[i].len()
|
||||
if dns.Compress {
|
||||
k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name)
|
||||
@ -1750,6 +1761,9 @@ func (dns *Msg) Len() int {
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Ns); i++ {
|
||||
if dns.Ns[i] == nil {
|
||||
continue
|
||||
}
|
||||
l += dns.Ns[i].len()
|
||||
if dns.Compress {
|
||||
k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name)
|
||||
@ -1765,6 +1779,9 @@ func (dns *Msg) Len() int {
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Extra); i++ {
|
||||
if dns.Extra[i] == nil {
|
||||
continue
|
||||
}
|
||||
l += dns.Extra[i].len()
|
||||
if dns.Compress {
|
||||
k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name)
|
||||
@ -1955,3 +1972,122 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
|
||||
|
||||
return r1
|
||||
}
|
||||
|
||||
func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
|
||||
off, err := PackDomainName(q.Name, msg, off, compression, compress)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(q.Qtype, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(q.Qclass, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func unpackQuestion(msg []byte, off int) (Question, int, error) {
|
||||
var (
|
||||
q Question
|
||||
err error
|
||||
)
|
||||
q.Name, off, err = UnpackDomainName(msg, off)
|
||||
if err != nil {
|
||||
return q, off, err
|
||||
}
|
||||
if off == len(msg) {
|
||||
return q, off, nil
|
||||
}
|
||||
q.Qtype, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return q, off, err
|
||||
}
|
||||
if off == len(msg) {
|
||||
return q, off, nil
|
||||
}
|
||||
q.Qclass, off, err = unpackUint16(msg, off)
|
||||
if off == len(msg) {
|
||||
return q, off, nil
|
||||
}
|
||||
return q, off, err
|
||||
}
|
||||
|
||||
func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
|
||||
off, err := packUint16(dh.Id, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(dh.Bits, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(dh.Qdcount, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(dh.Ancount, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(dh.Nscount, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(dh.Arcount, msg, off)
|
||||
return off, err
|
||||
}
|
||||
|
||||
func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
|
||||
var (
|
||||
dh Header
|
||||
err error
|
||||
)
|
||||
dh.Id, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return dh, off, err
|
||||
}
|
||||
dh.Bits, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return dh, off, err
|
||||
}
|
||||
dh.Qdcount, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return dh, off, err
|
||||
}
|
||||
dh.Ancount, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return dh, off, err
|
||||
}
|
||||
dh.Nscount, off, err = unpackUint16(msg, off)
|
||||
if err != nil {
|
||||
return dh, off, err
|
||||
}
|
||||
dh.Arcount, off, err = unpackUint16(msg, off)
|
||||
return dh, off, err
|
||||
}
|
||||
|
||||
// Which types have type specific unpack functions.
|
||||
var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){
|
||||
TypeAAAA: unpackAAAA,
|
||||
TypeA: unpackA,
|
||||
TypeCNAME: unpackCNAME,
|
||||
TypeDNAME: unpackDNAME,
|
||||
TypeL32: unpackL32,
|
||||
TypeLOC: unpackLOC,
|
||||
TypeMB: unpackMB,
|
||||
TypeMD: unpackMD,
|
||||
TypeMF: unpackMF,
|
||||
TypeMG: unpackMG,
|
||||
TypeMR: unpackMR,
|
||||
TypeMX: unpackMX,
|
||||
TypeNID: unpackNID,
|
||||
TypeNS: unpackNS,
|
||||
TypePTR: unpackPTR,
|
||||
TypeRP: unpackRP,
|
||||
TypeSRV: unpackSRV,
|
||||
TypeHINFO: unpackHINFO,
|
||||
TypeDNSKEY: unpackDNSKEY,
|
||||
}
|
||||
|
307
msg_generate.go
Normal file
307
msg_generate.go
Normal file
@ -0,0 +1,307 @@
|
||||
//+build ignore
|
||||
|
||||
// msg_generate.go is meant to run with go generate. It will use
|
||||
// go/{importer,types} to track down all the RR struct types. Then for each type
|
||||
// it will generate pack/unpack methods based on the struct tags. The generated source is
|
||||
// written to zmsg.go, and is meant to be checked into git.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/importer"
|
||||
"go/types"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// All RR pack and unpack functions should be generated, currently RR that present some
|
||||
// problems
|
||||
// * NSEC/NSEC3 - type bitmap
|
||||
// * TXT/SPF - string slice
|
||||
// * URI - weird octet thing there
|
||||
// * NSEC3/TSIG - size hex
|
||||
// * OPT RR - EDNS0 parsing - needs to some looking at
|
||||
// * HIP - uses "hex", but is actually size-hex - might drop size-hex?
|
||||
// * Z
|
||||
// * WKS - uint16 slice
|
||||
// * NINFO
|
||||
// * PrivateRR
|
||||
|
||||
// What types are we generating, should be kept in sync with typeToUnpack in msg.go
|
||||
var generate = map[string]bool{
|
||||
"AAAA": true,
|
||||
"ANY": true,
|
||||
"A": true,
|
||||
"CNAME": true,
|
||||
"DNAME": true,
|
||||
"DNSKEY": true,
|
||||
"HINFO": true,
|
||||
"L32": true,
|
||||
"LOC": true,
|
||||
"MB": true,
|
||||
"MD": true,
|
||||
"MF": true,
|
||||
"MG": true,
|
||||
"MR": true,
|
||||
"MX": true,
|
||||
"NID": true,
|
||||
"NS": true,
|
||||
"PTR": true,
|
||||
"RP": true,
|
||||
"SRV": true,
|
||||
}
|
||||
|
||||
func shouldGenerate(name string) bool {
|
||||
_, ok := generate[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// For later: IPSECKEY is weird.
|
||||
|
||||
var packageHdr = `
|
||||
// *** DO NOT MODIFY ***
|
||||
// AUTOGENERATED BY go generate from msg_generate.go
|
||||
|
||||
package dns
|
||||
|
||||
`
|
||||
|
||||
// getTypeStruct will take a type and the package scope, and return the
|
||||
// (innermost) struct if the type is considered a RR type (currently defined as
|
||||
// those structs beginning with a RR_Header, could be redefined as implementing
|
||||
// the RR interface). The bool return value indicates if embedded structs were
|
||||
// resolved.
|
||||
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
|
||||
st, ok := t.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
|
||||
return st, false
|
||||
}
|
||||
if st.Field(0).Anonymous() {
|
||||
st, _ := getTypeStruct(st.Field(0).Type(), scope)
|
||||
return st, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Import and type-check the package
|
||||
pkg, err := importer.Default().Import("github.com/miekg/dns")
|
||||
fatalIfErr(err)
|
||||
scope := pkg.Scope()
|
||||
|
||||
// Collect actual types (*X)
|
||||
var namedTypes []string
|
||||
for _, name := range scope.Names() {
|
||||
o := scope.Lookup(name)
|
||||
if o == nil || !o.Exported() {
|
||||
continue
|
||||
}
|
||||
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
|
||||
continue
|
||||
}
|
||||
if name == "PrivateRR" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if corresponding TypeX exists
|
||||
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
|
||||
log.Fatalf("Constant Type%s does not exist.", o.Name())
|
||||
}
|
||||
|
||||
namedTypes = append(namedTypes, o.Name())
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
b.WriteString(packageHdr)
|
||||
|
||||
fmt.Fprint(b, "// pack*() functions\n\n")
|
||||
for _, name := range namedTypes {
|
||||
o := scope.Lookup(name)
|
||||
st, isEmbedded := getTypeStruct(o.Type(), scope)
|
||||
if isEmbedded || !shouldGenerate(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
|
||||
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
headerEnd := off
|
||||
`)
|
||||
for i := 1; i < st.NumFields(); i++ {
|
||||
o := func(s string) {
|
||||
fmt.Fprintf(b, s, st.Field(i).Name())
|
||||
fmt.Fprint(b, `if err != nil {
|
||||
return off, err
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
//if _, ok := st.Field(i).Type().(*types.Slice); ok {
|
||||
//switch st.Tag(i) {
|
||||
//case `dns:"-"`:
|
||||
//// ignored
|
||||
//case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
|
||||
//o("for _, x := range rr.%s { l += len(x) + 1 }\n")
|
||||
//default:
|
||||
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
||||
//}
|
||||
//continue
|
||||
//}
|
||||
|
||||
switch st.Tag(i) {
|
||||
case `dns:"-"`:
|
||||
// ignored
|
||||
case `dns:"cdomain-name"`:
|
||||
fallthrough
|
||||
case `dns:"domain-name"`:
|
||||
o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
|
||||
case `dns:"a"`:
|
||||
o("off, err = packDataA(rr.%s, msg, off)\n")
|
||||
case `dns:"aaaa"`:
|
||||
o("off, err = packDataAAAA(rr.%s, msg, off)\n")
|
||||
case `dns:"uint48"`:
|
||||
o("off, err = packUint48(rr.%s, msg, off)\n")
|
||||
case `dns:"txt"`:
|
||||
o("off, err = packString(rr.%s, msg, off)\n")
|
||||
case `dns:"base32"`:
|
||||
o("off, err = packStringBase32(rr.%s, msg, off)\n")
|
||||
case `dns:"base64"`:
|
||||
o("off, err = packStringBase64(rr.%s, msg, off)\n")
|
||||
case "":
|
||||
switch st.Field(i).Type().(*types.Basic).Kind() {
|
||||
case types.Uint8:
|
||||
o("off, err = packUint8(rr.%s, msg, off)\n")
|
||||
case types.Uint16:
|
||||
o("off, err = packUint16(rr.%s, msg, off)\n")
|
||||
case types.Uint32:
|
||||
o("off, err = packUint32(rr.%s, msg, off)\n")
|
||||
case types.Uint64:
|
||||
o("off, err = packUint64(rr.%s, msg, off)\n")
|
||||
case types.String:
|
||||
o("off, err = packString(rr.%s, msg, off)\n")
|
||||
default:
|
||||
log.Fatalln(name, st.Field(i).Name())
|
||||
}
|
||||
//default:
|
||||
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
||||
}
|
||||
}
|
||||
// We have packed everything, only now we know the rdlength of this RR
|
||||
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off- headerEnd)")
|
||||
fmt.Fprintln(b, "return off, nil }\n")
|
||||
}
|
||||
|
||||
fmt.Fprint(b, "// unpack*() functions\n\n")
|
||||
for _, name := range namedTypes {
|
||||
o := scope.Lookup(name)
|
||||
st, isEmbedded := getTypeStruct(o.Type(), scope)
|
||||
if isEmbedded || !shouldGenerate(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
|
||||
fmt.Fprint(b, `if noRdata(h) {
|
||||
return nil, off, nil
|
||||
}
|
||||
var err error
|
||||
rdStart := off
|
||||
_ = rdStart
|
||||
|
||||
`)
|
||||
fmt.Fprintf(b, "rr := new(%s)\n", name)
|
||||
fmt.Fprintln(b, "rr.Hdr = h\n")
|
||||
for i := 1; i < st.NumFields(); i++ {
|
||||
o := func(s string) {
|
||||
fmt.Fprintf(b, s, st.Field(i).Name())
|
||||
fmt.Fprint(b, `if err != nil {
|
||||
return rr, off, err
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
//if _, ok := st.Field(i).Type().(*types.Slice); ok {
|
||||
//switch st.Tag(i) {
|
||||
//case `dns:"-"`:
|
||||
//// ignored
|
||||
//case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
|
||||
//o("for _, x := range rr.%s { l += len(x) + 1 }\n")
|
||||
//default:
|
||||
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
||||
//}
|
||||
//continue
|
||||
//}
|
||||
|
||||
switch st.Tag(i) {
|
||||
case `dns:"-"`:
|
||||
// ignored
|
||||
case `dns:"cdomain-name"`:
|
||||
fallthrough
|
||||
case `dns:"domain-name"`:
|
||||
o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
|
||||
case `dns:"a"`:
|
||||
o("rr.%s, off, err = unpackDataA(msg, off)\n")
|
||||
case `dns:"aaaa"`:
|
||||
o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
|
||||
case `dns:"uint48"`:
|
||||
o("rr.%s, off, err = unpackUint48(msg, off)\n")
|
||||
case `dns:"txt"`:
|
||||
o("rr.%s, off, err = unpackString(msg, off)\n")
|
||||
case `dns:"base32"`:
|
||||
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
||||
case `dns:"base64"`:
|
||||
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
||||
case "":
|
||||
switch st.Field(i).Type().(*types.Basic).Kind() {
|
||||
case types.Uint8:
|
||||
o("rr.%s, off, err = unpackUint8(msg, off)\n")
|
||||
case types.Uint16:
|
||||
o("rr.%s, off, err = unpackUint16(msg, off)\n")
|
||||
case types.Uint32:
|
||||
o("rr.%s, off, err = unpackUint32(msg, off)\n")
|
||||
case types.Uint64:
|
||||
o("rr.%s, off, err = unpackUint64(msg, off)\n")
|
||||
case types.String:
|
||||
o("rr.%s, off, err = unpackString(msg, off)\n")
|
||||
default:
|
||||
log.Fatalln(name, st.Field(i).Name())
|
||||
}
|
||||
//default:
|
||||
//log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
||||
}
|
||||
// If we've hit len(msg) we return without error.
|
||||
if i < st.NumFields()-1 {
|
||||
fmt.Fprintf(b, `if off == len(msg) {
|
||||
return rr, off, nil
|
||||
}
|
||||
`)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(b, "return rr, off, err }\n\n")
|
||||
}
|
||||
|
||||
// gofmt
|
||||
res, err := format.Source(b.Bytes())
|
||||
if err != nil {
|
||||
b.WriteTo(os.Stderr)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// write result
|
||||
f, err := os.Create("zmsg.go")
|
||||
fatalIfErr(err)
|
||||
defer f.Close()
|
||||
f.Write(res)
|
||||