diff --git a/labels.go b/labels.go index e32d2a1d..aa9dd41e 100644 --- a/labels.go +++ b/labels.go @@ -13,27 +13,29 @@ func SplitDomainName(s string) (labels []string) { if len(s) == 0 { return nil } - fqdnEnd := 0 // offset of the final '.' or the length of the name - idx := Split(s) - begin := 0 + if s == "." { + return nil + } + // offset of the final '.' or the length of the name + var fqdnEnd int if IsFqdn(s) { fqdnEnd = len(s) - 1 } else { fqdnEnd = len(s) } - - switch len(idx) { - case 0: - return nil - case 1: - // no-op - default: - for _, end := range idx[1:] { - labels = append(labels, s[begin:end-1]) - begin = end + var ( + begin int + off int + end bool + ) + for { + off, end = NextLabel(s, off) + if end { + break } + labels = append(labels, s[begin:off-1]) + begin = off } - return append(labels, s[begin:fqdnEnd]) } @@ -52,52 +54,50 @@ func CompareDomainName(s1, s2 string) (n int) { return 0 } - l1 := Split(s1) - l2 := Split(s2) - - j1 := len(l1) - 1 // end - i1 := len(l1) - 2 // start - j2 := len(l2) - 1 - i2 := len(l2) - 2 - // the second check can be done here: last/only label - // before we fall through into the for-loop below - if equal(s1[l1[j1]:], s2[l2[j2]:]) { - n++ - } else { - return + j1 := len(s1) + if s1[j1-1] == '.' { + j1-- } + j2 := len(s2) + if s2[j2-1] == '.' { + j2-- + } + var i1, i2 int for { - if i1 < 0 || i2 < 0 { - break - } - if equal(s1[l1[i1]:l1[j1]], s2[l2[i2]:l2[j2]]) { + i1 = prevLabel(s1, j1-1) + i2 = prevLabel(s2, j2-1) + if equal(s1[i1:j1], s2[i2:j2]) { n++ } else { break } - j1-- - i1-- - j2-- - i2-- + if i1 == 0 || i2 == 0 { + break + } + j1 = i1 - 2 + j2 = i2 - 2 } return } // CountLabel counts the the number of labels in the string s. // s must be a syntactically valid domain name. -func CountLabel(s string) (labels int) { +func CountLabel(s string) int { if s == "." { - return + return 0 } - off := 0 - end := false - for { - off, end = NextLabel(s, off) - labels++ - if end { - return + labels := 1 + for i := 0; i < len(s)-1; i++ { + c := s[i] + if c == '\\' { + i++ + continue + } + if c == '.' { + labels++ } } + return labels } // Split splits a name s into its label indexes. @@ -126,40 +126,70 @@ func Split(s string) []int { // The bool end is true when the end of the string has been reached. // Also see PrevLabel. func NextLabel(s string, offset int) (i int, end bool) { - quote := false for i = offset; i < len(s)-1; i++ { - switch s[i] { - case '\\': - quote = !quote - default: - quote = false - case '.': - if quote { - quote = !quote - continue - } + c := s[i] + if c == '\\' { + i++ + continue + } + if c == '.' { return i + 1, false } } return i + 1, true } +func prevLabel(s string, offset int) int { + for i := offset; i >= 0; i-- { + if s[i] == '.' { + if i == 0 || s[i-1] != '\\' { + return i + 1 // the '.' is not escaped + } + // We are at '\.' and need to check if the '\' itself is escaped. + // We do this by walking backwards from '\.' and counting the + // number of '\' we encounter. If the number of '\' is even + // (though here it's actually odd since we start at '\.') the '\' + // is escaped. + j := i - 2 + for ; j >= 0 && s[j] == '\\'; j-- { + } + // An odd number here indicates that the '\' preceding the '.' + // is escaped. + if (i-j)&1 == 1 { + return i + 1 + } + i = j + 1 + } + } + return 0 +} + // PrevLabel returns the index of the label when starting from the right and // jumping n labels to the left. // The bool start is true when the start of the string has been overshot. // Also see NextLabel. func PrevLabel(s string, n int) (i int, start bool) { + if s == "." { + return 0, true + } if n == 0 { return len(s), false } - lab := Split(s) - if lab == nil { + i = len(s) - 1 + if s[i] == '.' { + i-- + } + for ; n > 0; n-- { + i = prevLabel(s, i) + if i == 0 { + break + } + i -= 2 + } + if n > 0 { return 0, true } - if n > len(lab) { - return 0, true - } - return lab[len(lab)-n], false + return i + 2, false } // equal compares a and b while ignoring case. It returns true when equal otherwise false. @@ -170,18 +200,19 @@ func equal(a, b string) bool { if la != lb { return false } - - for i := la - 1; i >= 0; i-- { - ai := a[i] - bi := b[i] - if ai >= 'A' && ai <= 'Z' { - ai |= 'a' - 'A' - } - if bi >= 'A' && bi <= 'Z' { - bi |= 'a' - 'A' - } - if ai != bi { - return false + if a != b { + // case-insensitive comparison + for i := la - 1; i >= 0; i-- { + ai := a[i] + bi := b[i] + if ai != bi { + if bi < ai { + bi, ai = ai, bi + } + if !('A' <= ai && ai <= 'Z' && bi == ai+'a'-'A') { + return false + } + } } } return true diff --git a/labels_test.go b/labels_test.go index 3f666df6..bd339d91 100644 --- a/labels_test.go +++ b/labels_test.go @@ -1,55 +1,50 @@ package dns -import "testing" +import ( + "strings" + "testing" +) func TestCompareDomainName(t *testing.T) { - s1 := "www.miek.nl." - s2 := "miek.nl." - s3 := "www.bla.nl." - s4 := "nl.www.bla." - s5 := "nl." - s6 := "miek.nl." - - if CompareDomainName(s1, s2) != 2 { - t.Errorf("%s with %s should be %d", s1, s2, 2) + tests := []struct { + s1, s2 string + expected int + }{ + {"www.miek.nl.", "miek.nl.", 2}, + {"miek.nl.", "www.bla.nl.", 1}, + {"www.bla.nl.", "nl.www.bla.", 0}, + {"www.miek.nl.", "nl.", 1}, + {"www.miek.nl.", "miek.nl.", 2}, + {"www.miek.nl.", ".", 0}, + {".", ".", 0}, + {"test.com.", "TEST.COM.", 2}, + {"a.b.c.d.e.f.", "a.b.c.d.e.", 0}, + {"a.b.c.d.e.", "a.b.c.d.e.", 5}, } - if CompareDomainName(s1, s3) != 1 { - t.Errorf("%s with %s should be %d", s1, s3, 1) - } - if CompareDomainName(s3, s4) != 0 { - t.Errorf("%s with %s should be %d", s3, s4, 0) - } - // Non qualified tests - if CompareDomainName(s1, s5) != 1 { - t.Errorf("%s with %s should be %d", s1, s5, 1) - } - if CompareDomainName(s1, s6) != 2 { - t.Errorf("%s with %s should be %d", s1, s5, 2) - } - - if CompareDomainName(s1, ".") != 0 { - t.Errorf("%s with %s should be %d", s1, s5, 0) - } - if CompareDomainName(".", ".") != 0 { - t.Errorf("%s with %s should be %d", ".", ".", 0) - } - if CompareDomainName("test.com.", "TEST.COM.") != 2 { - t.Errorf("test.com. and TEST.COM. should be an exact match") + for _, x := range tests { + if i := CompareDomainName(x.s1, x.s2); i != x.expected { + t.Errorf("%s with %s should be %d got: %d", x.s1, x.s2, x.expected, i) + } } } func TestSplit(t *testing.T) { splitter := map[string]int{ - "www.miek.nl.": 3, - "www.miek.nl": 3, - "www..miek.nl": 4, - `www\.miek.nl.`: 2, - `www\\.miek.nl.`: 3, - ".": 0, - "nl.": 1, - "nl": 1, - "com.": 1, - ".com.": 2, + "www.miek.nl.": 3, + "www.miek.nl": 3, + "www..miek.nl": 4, + `www\.miek.nl.`: 2, + `www\\.miek.nl.`: 3, + `\\.miek.nl.`: 3, + `\\\.miek.nl.`: 2, + `\\\\.miek.nl.`: 3, + `www.miek\\\\.nl.`: 3, + `www.miek\\\.nl.`: 2, + ".": 0, + "nl.": 1, + "nl": 1, + "com.": 1, + ".com.": 2, } for s, i := range splitter { if x := len(Split(s)); x != i { @@ -98,7 +93,17 @@ func TestPrevLabel(t *testing.T) { {"www.miek.nl.", 3}: 0, {"www.miek.nl", 3}: 0, + + {"a.b.c.", 1}: 4, + {"a.b.c", 1}: 4, } + + // make sure we are safe when the label begins with a possibly escaped '.' + for i := 1; i < 8; i++ { + s := strings.Repeat(`\`, i) + "." + prever[prev{s, 0}] = i + 1 + } + for s, i := range prever { x, ok := PrevLabel(s.string, s.int) if i != x { @@ -210,6 +215,27 @@ func TestIsFqdnEscaped(t *testing.T) { } } +func TestEqual(t *testing.T) { + type testcase struct { + a, b string + match bool + } + tests := []testcase{ + {"a", "a", true}, + {"a", "A", true}, + {"A", "a", true}, + {"A", "b", false}, + {"www.example.com.", "www.exAmpLe.com.", true}, + {"www.example.com.", "www.exAmpLe.org.", false}, + } + for _, x := range tests { + eq := equal(x.a, x.b) + if eq != x.match { + t.Errorf("%+v: want: %t got: %t", x, x.match, eq) + } + } +} + func BenchmarkSplitLabels(b *testing.B) { for i := 0; i < b.N; i++ { Split("www.example.com.")