diff --git a/compress_generate.go b/compress_generate.go index 9a136c41..4bcefcb8 100644 --- a/compress_generate.go +++ b/compress_generate.go @@ -101,7 +101,7 @@ Names: // compressionLenHelperType - all types that have domain-name/cdomain-name can be used for compressing names - fmt.Fprint(b, "func compressionLenHelperType(c map[string]int, r RR, initLen int) int {\n") + fmt.Fprint(b, "func compressionLenHelperType(c map[string]struct{}, r RR, initLen int) int {\n") fmt.Fprint(b, "currentLen := initLen\n") fmt.Fprint(b, "switch x := r.(type) {\n") for _, name := range domainTypes { @@ -145,7 +145,7 @@ Names: // compressionLenSearchType - search cdomain-tags types for compressible names. - fmt.Fprint(b, "func compressionLenSearchType(c map[string]int, r RR) (int, bool, int) {\n") + fmt.Fprint(b, "func compressionLenSearchType(c map[string]struct{}, r RR) (int, bool, int) {\n") fmt.Fprint(b, "switch x := r.(type) {\n") for _, name := range cdomainTypes { o := scope.Lookup(name) diff --git a/dns_bench_test.go b/dns_bench_test.go index 3fdfc03f..8653f8ff 100644 --- a/dns_bench_test.go +++ b/dns_bench_test.go @@ -1,6 +1,7 @@ package dns import ( + "fmt" "net" "testing" ) @@ -63,6 +64,31 @@ func BenchmarkMsgLengthPack(b *testing.B) { } } +func BenchmarkMsgLengthMassive(b *testing.B) { + makeMsg := func(question string, ans, ns, e []RR) *Msg { + msg := new(Msg) + msg.SetQuestion(Fqdn(question), TypeANY) + msg.Answer = append(msg.Answer, ans...) + msg.Ns = append(msg.Ns, ns...) + msg.Extra = append(msg.Extra, e...) + msg.Compress = true + return msg + } + const name1 = "12345678901234567890123456789012345.12345678.123." + rrMx := testRR(name1 + " 3600 IN MX 10 " + name1) + answer := []RR{rrMx, rrMx} + for i := 0; i < 128; i++ { + rrA := testRR(fmt.Sprintf("example%03d.something%03delse.org. 2311 IN A 127.0.0.1", i/32, i%32)) + answer = append(answer, rrA) + } + answer = append(answer, rrMx, rrMx) + msg := makeMsg(name1, answer, nil, nil) + b.ResetTimer() + for i := 0; i < b.N; i++ { + msg.Len() + } +} + func BenchmarkMsgLengthOnlyQuestion(b *testing.B) { msg := new(Msg) msg.SetQuestion(Fqdn("12345678901234567890123456789012345.12345678.123."), TypeANY) diff --git a/length_test.go b/length_test.go index cb16741d..17bca197 100644 --- a/length_test.go +++ b/length_test.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "fmt" "net" - "reflect" "strings" "testing" ) @@ -83,85 +82,85 @@ func TestMsgLength(t *testing.T) { } func TestCompressionLenHelper(t *testing.T) { - c := make(map[string]int) + c := make(map[string]struct{}) compressionLenHelper(c, "example.com", 12) - if c["example.com"] != 12 { - t.Errorf("bad %d", c["example.com"]) + if _, ok := c["example.com"]; !ok { + t.Errorf("bad example.com") } - if c["com"] != 20 { - t.Errorf("bad %d", c["com"]) + if _, ok := c["com"]; !ok { + t.Errorf("bad com") } // Test boundaries - c = make(map[string]int) + c = make(map[string]struct{}) // foo label starts at 16379 // com label starts at 16384 compressionLenHelper(c, "foo.com", 16379) - if c["foo.com"] != 16379 { - t.Errorf("bad %d", c["foo.com"]) + if _, ok := c["foo.com"]; !ok { + t.Errorf("bad foo.com") } // com label is accessible - if c["com"] != 16383 { - t.Errorf("bad %d", c["com"]) + if _, ok := c["com"]; !ok { + t.Errorf("bad com") } - c = make(map[string]int) + c = make(map[string]struct{}) // foo label starts at 16379 // com label starts at 16385 => outside range compressionLenHelper(c, "foo.com", 16380) - if c["foo.com"] != 16380 { - t.Errorf("bad %d", c["foo.com"]) + if _, ok := c["foo.com"]; !ok { + t.Errorf("bad foo.com") } // com label is NOT accessible - if c["com"] != 0 { - t.Errorf("bad %d", c["com"]) + if _, ok := c["com"]; ok { + t.Errorf("bad com") } - c = make(map[string]int) + c = make(map[string]struct{}) compressionLenHelper(c, "example.com", 16375) - if c["example.com"] != 16375 { - t.Errorf("bad %d", c["example.com"]) + if _, ok := c["example.com"]; !ok { + t.Errorf("bad example.com") } // com starts AFTER 16384 - if c["com"] != 16383 { - t.Errorf("bad %d", c["com"]) + if _, ok := c["com"]; !ok { + t.Errorf("bad com") } - c = make(map[string]int) + c = make(map[string]struct{}) compressionLenHelper(c, "example.com", 16376) - if c["example.com"] != 16376 { - t.Errorf("bad %d", c["example.com"]) + if _, ok := c["example.com"]; !ok { + t.Errorf("bad example.com") } // com starts AFTER 16384 - if c["com"] != 0 { - t.Errorf("bad %d", c["com"]) + if _, ok := c["com"]; ok { + t.Errorf("bad com") } } func TestCompressionLenSearch(t *testing.T) { - c := make(map[string]int) + c := make(map[string]struct{}) compressed, ok, fullSize := compressionLenSearch(c, "a.b.org.") if compressed != 0 || ok || fullSize != 14 { t.Errorf("Failed: compressed:=%d, ok:=%v, fullSize:=%d", compressed, ok, fullSize) } - c["org."] = 3 + c["org."] = struct{}{} compressed, ok, fullSize = compressionLenSearch(c, "a.b.org.") if compressed != 4 || !ok || fullSize != 8 { t.Errorf("Failed: compressed:=%d, ok:=%v, fullSize:=%d", compressed, ok, fullSize) } - c["b.org."] = 5 + c["b.org."] = struct{}{} compressed, ok, fullSize = compressionLenSearch(c, "a.b.org.") if compressed != 6 || !ok || fullSize != 4 { t.Errorf("Failed: compressed:=%d, ok:=%v, fullSize:=%d", compressed, ok, fullSize) } // Not found long compression - c["x.b.org."] = 5 + c["x.b.org."] = struct{}{} compressed, ok, fullSize = compressionLenSearch(c, "a.b.org.") if compressed != 6 || !ok || fullSize != 4 { t.Errorf("Failed: compressed:=%d, ok:=%v, fullSize:=%d", compressed, ok, fullSize) } // Found long compression - c["a.b.org."] = 5 + c["a.b.org."] = struct{}{} compressed, ok, fullSize = compressionLenSearch(c, "a.b.org.") if compressed != 8 || !ok || fullSize != 0 { t.Errorf("Failed: compressed:=%d, ok:=%v, fullSize:=%d", compressed, ok, fullSize) @@ -262,6 +261,20 @@ func TestMsgCompressLengthLargeRecords(t *testing.T) { } } +func compressionMapsEqual(a map[string]struct{}, b map[string]int) bool { + if len(a) != len(b) { + return false + } + + for k := range a { + if _, ok := b[k]; !ok { + return false + } + } + + return true +} + func TestCompareCompressionMapsForANY(t *testing.T) { msg := new(Msg) msg.Compress = true @@ -278,7 +291,7 @@ func TestCompareCompressionMapsForANY(t *testing.T) { for labelSize := 0; labelSize < 63; labelSize++ { msg.SetQuestion(fmt.Sprintf("a%s.service.acme.", strings.Repeat("x", labelSize)), TypeANY) - compressionFake := make(map[string]int) + compressionFake := make(map[string]struct{}) lenFake := compressedLenWithCompressionMap(msg, compressionFake) compressionReal := make(map[string]int) @@ -289,7 +302,7 @@ func TestCompareCompressionMapsForANY(t *testing.T) { if lenFake != len(buf) { t.Fatalf("padding= %d ; Predicted len := %d != real:= %d", labelSize, lenFake, len(buf)) } - if !reflect.DeepEqual(compressionFake, compressionReal) { + if !compressionMapsEqual(compressionFake, compressionReal) { t.Fatalf("padding= %d ; Fake Compression Map != Real Compression Map\n*** Real:= %v\n\n***Fake:= %v", labelSize, compressionReal, compressionFake) } } @@ -311,7 +324,7 @@ func TestCompareCompressionMapsForSRV(t *testing.T) { for labelSize := 0; labelSize < 63; labelSize++ { msg.SetQuestion(fmt.Sprintf("a%s.service.acme.", strings.Repeat("x", labelSize)), TypeAAAA) - compressionFake := make(map[string]int) + compressionFake := make(map[string]struct{}) lenFake := compressedLenWithCompressionMap(msg, compressionFake) compressionReal := make(map[string]int) @@ -322,7 +335,7 @@ func TestCompareCompressionMapsForSRV(t *testing.T) { if lenFake != len(buf) { t.Fatalf("padding= %d ; Predicted len := %d != real:= %d", labelSize, lenFake, len(buf)) } - if !reflect.DeepEqual(compressionFake, compressionReal) { + if !compressionMapsEqual(compressionFake, compressionReal) { t.Fatalf("padding= %d ; Fake Compression Map != Real Compression Map\n*** Real:= %v\n\n***Fake:= %v", labelSize, compressionReal, compressionFake) } } diff --git a/msg.go b/msg.go index 4208eb45..90f02973 100644 --- a/msg.go +++ b/msg.go @@ -907,7 +907,7 @@ func (dns *Msg) isCompressible() bool { len(dns.Ns) > 0 || len(dns.Extra) > 0 } -func compressedLenWithCompressionMap(dns *Msg, compression map[string]int) int { +func compressedLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int { l := 12 // Message header is always 12 bytes for _, r := range dns.Question { compressionLenHelper(compression, r.Name, l) @@ -927,7 +927,7 @@ func compressedLen(dns *Msg, compress bool) int { // If this message can't be compressed, avoid filling the // compression map and creating garbage. if compress && dns.isCompressible() { - compression := map[string]int{} + compression := make(map[string]struct{}) return compressedLenWithCompressionMap(dns, compression) } @@ -954,7 +954,7 @@ func compressedLen(dns *Msg, compress bool) int { return l } -func compressionLenSlice(lenp int, c map[string]int, rs []RR) int { +func compressionLenSlice(lenp int, c map[string]struct{}, rs []RR) int { initLen := lenp for _, r := range rs { if r == nil { @@ -986,7 +986,7 @@ func compressionLenSlice(lenp int, c map[string]int, rs []RR) int { } // Put the parts of the name in the compression map, return the size in bytes added in payload -func compressionLenHelper(c map[string]int, s string, currentLen int) int { +func compressionLenHelper(c map[string]struct{}, s string, currentLen int) int { if currentLen > maxCompressionOffset { // We won't be able to add any label that could be re-used later anyway return 0 @@ -1005,7 +1005,7 @@ func compressionLenHelper(c map[string]int, s string, currentLen int) int { if _, ok := c[pref]; !ok { // If first byte label is within the first 14bits, it might be re-used later if currentLen < maxCompressionOffset { - c[pref] = currentLen + c[pref] = struct{}{} } } else { added := currentLen - initLen @@ -1023,7 +1023,7 @@ func compressionLenHelper(c map[string]int, s string, currentLen int) int { // keep on searching so we get the longest match. // Will return the size of compression found, whether a match has been // found and the size of record if added in payload -func compressionLenSearch(c map[string]int, s string) (int, bool, int) { +func compressionLenSearch(c map[string]struct{}, s string) (int, bool, int) { off := 0 end := false if s == "" { // don't bork on bogus data diff --git a/zcompress.go b/zcompress.go index 6391a350..fefdd2a0 100644 --- a/zcompress.go +++ b/zcompress.go @@ -2,7 +2,7 @@ package dns -func compressionLenHelperType(c map[string]int, r RR, initLen int) int { +func compressionLenHelperType(c map[string]struct{}, r RR, initLen int) int { currentLen := initLen switch x := r.(type) { case *AFSDB: @@ -107,7 +107,7 @@ func compressionLenHelperType(c map[string]int, r RR, initLen int) int { return currentLen - initLen } -func compressionLenSearchType(c map[string]int, r RR) (int, bool, int) { +func compressionLenSearchType(c map[string]struct{}, r RR) (int, bool, int) { switch x := r.(type) { case *CNAME: k1, ok1, sz1 := compressionLenSearch(c, x.Target)