diff --git a/compress_generate.go b/compress_generate.go new file mode 100644 index 00000000..1a301e9f --- /dev/null +++ b/compress_generate.go @@ -0,0 +1,184 @@ +//+build ignore + +// compression_generate.go is meant to run with go generate. It will use +// go/{importer,types} to track down all the RR struct types. Then for each type +// it will look to see if there are (compressible) names, if so it will add that +// type to compressionLenHelperType and comressionLenSearchType which "fake" the +// compression so that Len() is fast. +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/importer" + "go/types" + "log" + "os" +) + +var packageHdr = ` +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from compress_generate.go + +package dns + +` + +// getTypeStruct will take a type and the package scope, and return the +// (innermost) struct if the type is considered a RR type (currently defined as +// those structs beginning with a RR_Header, could be redefined as implementing +// the RR interface). The bool return value indicates if embedded structs were +// resolved. +func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { + st, ok := t.Underlying().(*types.Struct) + if !ok { + return nil, false + } + if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { + return st, false + } + if st.Field(0).Anonymous() { + st, _ := getTypeStruct(st.Field(0).Type(), scope) + return st, true + } + return nil, false +} + +func main() { + // Import and type-check the package + pkg, err := importer.Default().Import("github.com/miekg/dns") + fatalIfErr(err) + scope := pkg.Scope() + + domainTypes := map[string]bool{} // Types that have a domain name in them (either comressible or not). + cdomainTypes := map[string]bool{} // Types that have a compressible domain name in them (subset of domainType) + for _, name := range scope.Names() { + o := scope.Lookup(name) + if o == nil || !o.Exported() { + continue + } + st, _ := getTypeStruct(o.Type(), scope) + if st == nil { + continue + } + if name == "PrivateRR" { + continue + } + + if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { + log.Fatalf("Constant Type%s does not exist.", o.Name()) + } + + for i := 1; i < st.NumFields(); i++ { + if _, ok := st.Field(i).Type().(*types.Slice); ok { + if st.Tag(i) == `dns:"domain-name"` { + domainTypes[o.Name()] = true + } + if st.Tag(i) == `dns:"cdomain-name"` { + cdomainTypes[o.Name()] = true + domainTypes[o.Name()] = true + } + continue + } + + switch { + case st.Tag(i) == `dns:"domain-name"`: + domainTypes[o.Name()] = true + case st.Tag(i) == `dns:"cdomain-name"`: + cdomainTypes[o.Name()] = true + domainTypes[o.Name()] = true + } + } + } + + b := &bytes.Buffer{} + b.WriteString(packageHdr) + + // compressionLenHelperType - all types that have domain-name/cdomain-name can be used for compressing names + + fmt.Fprint(b, "func compressionLenHelperType(c map[string]int, r RR) {\n") + fmt.Fprint(b, "switch x := r.(type) {\n") + for name, _ := range domainTypes { + o := scope.Lookup(name) + st, _ := getTypeStruct(o.Type(), scope) + + fmt.Fprintf(b, "case *%s:\n", name) + for i := 1; i < st.NumFields(); i++ { + out := func(s string) { fmt.Fprintf(b, "compressionLenHelper(c, x.%s)\n", st.Field(i).Name()) } + + if _, ok := st.Field(i).Type().(*types.Slice); ok { + switch st.Tag(i) { + case `dns:"domain-name"`: + fallthrough + case `dns:"cdomain-name"`: + // For HIP we need to slice over the elements in this slice. + fmt.Fprintf(b, `for i := range x.%s { + compressionLenHelper(c, x.%s[i]) + } +`, st.Field(i).Name(), st.Field(i).Name()) + } + continue + } + + switch { + case st.Tag(i) == `dns:"cdomain-name"`: + fallthrough + case st.Tag(i) == `dns:"domain-name"`: + out(st.Field(i).Name()) + } + } + } + fmt.Fprintln(b, "}\n}\n\n") + + // compressionLenSearchType - search cdomain-tags types for compressible names. + + fmt.Fprint(b, "func compressionLenSearchType(c map[string]int, r RR) (int, bool) {\n") + fmt.Fprint(b, "switch x := r.(type) {\n") + for name, _ := range cdomainTypes { + o := scope.Lookup(name) + st, _ := getTypeStruct(o.Type(), scope) + + fmt.Fprintf(b, "case *%s:\n", name) + j := 1 + for i := 1; i < st.NumFields(); i++ { + out := func(s string, j int) { + fmt.Fprintf(b, "k%d, ok%d := compressionLenSearch(c, x.%s)\n", j, j, st.Field(i).Name()) + } + + // There are no slice types with names that can be compressed. + + switch { + case st.Tag(i) == `dns:"cdomain-name"`: + out(st.Field(i).Name(), j) + j++ + } + } + k := "k1" + ok := "ok1" + for i := 2; i < j; i++ { + k += fmt.Sprintf(" + k%d", i) + ok += fmt.Sprintf(" && ok%d", i) + } + fmt.Fprintf(b, "return %s, %s\n", k, ok) + } + fmt.Fprintln(b, "}\nreturn 0, false\n}\n\n") + + // gofmt + res, err := format.Source(b.Bytes()) + if err != nil { + b.WriteTo(os.Stderr) + log.Fatal(err) + } + + f, err := os.Create("zcompress.go") + fatalIfErr(err) + defer f.Close() + f.Write(res) +} + +func fatalIfErr(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/dns_test.go b/dns_test.go index ad68533f..dbfe2532 100644 --- a/dns_test.go +++ b/dns_test.go @@ -310,6 +310,23 @@ func TestMsgLengthCompressionMalformed(t *testing.T) { m.Len() // Should not crash. } +func TestMsgCompressLength2(t *testing.T) { + msg := new(Msg) + msg.Compress = true + msg.SetQuestion(Fqdn("bliep."), TypeANY) + msg.Answer = append(msg.Answer, &SRV{Hdr: RR_Header{Name: "blaat.", Rrtype: 0x21, Class: 0x1, Ttl: 0x3c}, Port: 0x4c57, Target: "foo.bar."}) + msg.Extra = append(msg.Extra, &A{Hdr: RR_Header{Name: "foo.bar.", Rrtype: 0x1, Class: 0x1, Ttl: 0x3c}, A: net.IP{0xac, 0x11, 0x0, 0x3}}) + predicted := msg.Len() + buf, err := msg.Pack() + if err != nil { + t.Error(err) + } + if predicted != len(buf) { + t.Errorf("predicted compressed length is wrong: predicted %s (len=%d) %d, actual %d", + msg.Question[0].Name, len(msg.Answer), predicted, len(buf)) + } +} + func TestToRFC3597(t *testing.T) { a, _ := NewRR("miek.nl. IN A 10.0.1.1") x := new(RFC3597) diff --git a/msg.go b/msg.go index c74e724f..55ba05fc 100644 --- a/msg.go +++ b/msg.go @@ -9,6 +9,7 @@ package dns //go:generate go run msg_generate.go +//go:generate go run compress_generate.go import ( crand "crypto/rand" @@ -262,7 +263,9 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c bsFresh = true } // Don't try to compress '.' - if compress && roBs[begin:] != "." { + // We should only compress when compress it true, but we should also still pick + // up names that can be used for *future* compression(s). + if compression != nil && roBs[begin:] != "." { if p, ok := compression[roBs[begin:]]; !ok { // Only offsets smaller than this can be used. if offset < maxCompressionOffset { @@ -892,10 +895,8 @@ func (dns *Msg) String() string { func (dns *Msg) Len() int { // We always return one more than needed. l := 12 // Message header is always 12 bytes - var compression map[string]int - if dns.Compress { - compression = make(map[string]int) - } + compression := map[string]int{} + for i := 0; i < len(dns.Question); i++ { l += dns.Question[i].len() if dns.Compress { @@ -991,97 +992,6 @@ func compressionLenSearch(c map[string]int, s string) (int, bool) { return 0, false } -// TODO(miek): should add all types, because the all can be *used* for compression. Autogenerate from msg_generate and put in zmsg.go -func compressionLenHelperType(c map[string]int, r RR) { - switch x := r.(type) { - case *NS: - compressionLenHelper(c, x.Ns) - case *MX: - compressionLenHelper(c, x.Mx) - case *CNAME: - compressionLenHelper(c, x.Target) - case *PTR: - compressionLenHelper(c, x.Ptr) - case *SOA: - compressionLenHelper(c, x.Ns) - compressionLenHelper(c, x.Mbox) - case *MB: - compressionLenHelper(c, x.Mb) - case *MG: - compressionLenHelper(c, x.Mg) - case *MR: - compressionLenHelper(c, x.Mr) - case *MF: - compressionLenHelper(c, x.Mf) - case *MD: - compressionLenHelper(c, x.Md) - case *RT: - compressionLenHelper(c, x.Host) - case *RP: - compressionLenHelper(c, x.Mbox) - compressionLenHelper(c, x.Txt) - case *MINFO: - compressionLenHelper(c, x.Rmail) - compressionLenHelper(c, x.Email) - case *AFSDB: - compressionLenHelper(c, x.Hostname) - case *SRV: - compressionLenHelper(c, x.Target) - case *NAPTR: - compressionLenHelper(c, x.Replacement) - case *RRSIG: - compressionLenHelper(c, x.SignerName) - case *NSEC: - compressionLenHelper(c, x.NextDomain) - // HIP? - } -} - -// Only search on compressing these types. -func compressionLenSearchType(c map[string]int, r RR) (int, bool) { - switch x := r.(type) { - case *NS: - return compressionLenSearch(c, x.Ns) - case *MX: - return compressionLenSearch(c, x.Mx) - case *CNAME: - return compressionLenSearch(c, x.Target) - case *DNAME: - return compressionLenSearch(c, x.Target) - case *PTR: - return compressionLenSearch(c, x.Ptr) - case *SOA: - k, ok := compressionLenSearch(c, x.Ns) - k1, ok1 := compressionLenSearch(c, x.Mbox) - if !ok && !ok1 { - return 0, false - } - return k + k1, true - case *MB: - return compressionLenSearch(c, x.Mb) - case *MG: - return compressionLenSearch(c, x.Mg) - case *MR: - return compressionLenSearch(c, x.Mr) - case *MF: - return compressionLenSearch(c, x.Mf) - case *MD: - return compressionLenSearch(c, x.Md) - case *RT: - return compressionLenSearch(c, x.Host) - case *MINFO: - k, ok := compressionLenSearch(c, x.Rmail) - k1, ok1 := compressionLenSearch(c, x.Email) - if !ok && !ok1 { - return 0, false - } - return k + k1, true - case *AFSDB: - return compressionLenSearch(c, x.Hostname) - } - return 0, false -} - // Copy returns a new RR which is a deep-copy of r. func Copy(r RR) RR { r1 := r.copy(); return r1 } diff --git a/zcompress.go b/zcompress.go new file mode 100644 index 00000000..1871d4ae --- /dev/null +++ b/zcompress.go @@ -0,0 +1,119 @@ +// *** DO NOT MODIFY *** +// AUTOGENERATED BY go generate from compress_generate.go + +package dns + +func compressionLenHelperType(c map[string]int, r RR) { + switch x := r.(type) { + case *AFSDB: + compressionLenHelper(c, x.Hostname) + case *RRSIG: + compressionLenHelper(c, x.SignerName) + case *MD: + compressionLenHelper(c, x.Md) + case *MX: + compressionLenHelper(c, x.Mx) + case *TALINK: + compressionLenHelper(c, x.PreviousName) + compressionLenHelper(c, x.NextName) + case *MF: + compressionLenHelper(c, x.Mf) + case *RP: + compressionLenHelper(c, x.Mbox) + compressionLenHelper(c, x.Txt) + case *RT: + compressionLenHelper(c, x.Host) + case *SOA: + compressionLenHelper(c, x.Ns) + compressionLenHelper(c, x.Mbox) + case *NSAPPTR: + compressionLenHelper(c, x.Ptr) + case *CNAME: + compressionLenHelper(c, x.Target) + case *LP: + compressionLenHelper(c, x.Fqdn) + case *NSEC: + compressionLenHelper(c, x.NextDomain) + case *PTR: + compressionLenHelper(c, x.Ptr) + case *TSIG: + compressionLenHelper(c, x.Algorithm) + case *MG: + compressionLenHelper(c, x.Mg) + case *PX: + compressionLenHelper(c, x.Map822) + compressionLenHelper(c, x.Mapx400) + case *SRV: + compressionLenHelper(c, x.Target) + case *TKEY: + compressionLenHelper(c, x.Algorithm) + case *MINFO: + compressionLenHelper(c, x.Rmail) + compressionLenHelper(c, x.Email) + case *NAPTR: + compressionLenHelper(c, x.Replacement) + case *SIG: + compressionLenHelper(c, x.SignerName) + case *DNAME: + compressionLenHelper(c, x.Target) + case *HIP: + for i := range x.RendezvousServers { + compressionLenHelper(c, x.RendezvousServers[i]) + } + case *KX: + compressionLenHelper(c, x.Exchanger) + case *MB: + compressionLenHelper(c, x.Mb) + case *MR: + compressionLenHelper(c, x.Mr) + case *NS: + compressionLenHelper(c, x.Ns) + } +} + +func compressionLenSearchType(c map[string]int, r RR) (int, bool) { + switch x := r.(type) { + case *PTR: + k1, ok1 := compressionLenSearch(c, x.Ptr) + return k1, ok1 + case *RT: + k1, ok1 := compressionLenSearch(c, x.Host) + return k1, ok1 + case *SOA: + k1, ok1 := compressionLenSearch(c, x.Ns) + k2, ok2 := compressionLenSearch(c, x.Mbox) + return k1 + k2, ok1 && ok2 + case *MB: + k1, ok1 := compressionLenSearch(c, x.Mb) + return k1, ok1 + case *MG: + k1, ok1 := compressionLenSearch(c, x.Mg) + return k1, ok1 + case *MR: + k1, ok1 := compressionLenSearch(c, x.Mr) + return k1, ok1 + case *NS: + k1, ok1 := compressionLenSearch(c, x.Ns) + return k1, ok1 + case *MINFO: + k1, ok1 := compressionLenSearch(c, x.Rmail) + k2, ok2 := compressionLenSearch(c, x.Email) + return k1 + k2, ok1 && ok2 + case *MX: + k1, ok1 := compressionLenSearch(c, x.Mx) + return k1, ok1 + case *AFSDB: + k1, ok1 := compressionLenSearch(c, x.Hostname) + return k1, ok1 + case *CNAME: + k1, ok1 := compressionLenSearch(c, x.Target) + return k1, ok1 + case *MD: + k1, ok1 := compressionLenSearch(c, x.Md) + return k1, ok1 + case *MF: + k1, ok1 := compressionLenSearch(c, x.Mf) + return k1, ok1 + } + return 0, false +}