diff --git a/labels.go b/labels.go index aa9dd41e..e32d2a1d 100644 --- a/labels.go +++ b/labels.go @@ -13,29 +13,27 @@ func SplitDomainName(s string) (labels []string) { if len(s) == 0 { return nil } - if s == "." { - return nil - } - // offset of the final '.' or the length of the name - var fqdnEnd int + fqdnEnd := 0 // offset of the final '.' or the length of the name + idx := Split(s) + begin := 0 if IsFqdn(s) { fqdnEnd = len(s) - 1 } else { fqdnEnd = len(s) } - var ( - begin int - off int - end bool - ) - for { - off, end = NextLabel(s, off) - if end { - break + + switch len(idx) { + case 0: + return nil + case 1: + // no-op + default: + for _, end := range idx[1:] { + labels = append(labels, s[begin:end-1]) + begin = end } - labels = append(labels, s[begin:off-1]) - begin = off } + return append(labels, s[begin:fqdnEnd]) } @@ -54,50 +52,52 @@ func CompareDomainName(s1, s2 string) (n int) { return 0 } - j1 := len(s1) - if s1[j1-1] == '.' { - j1-- + l1 := Split(s1) + l2 := Split(s2) + + j1 := len(l1) - 1 // end + i1 := len(l1) - 2 // start + j2 := len(l2) - 1 + i2 := len(l2) - 2 + // the second check can be done here: last/only label + // before we fall through into the for-loop below + if equal(s1[l1[j1]:], s2[l2[j2]:]) { + n++ + } else { + return } - j2 := len(s2) - if s2[j2-1] == '.' { - j2-- - } - var i1, i2 int for { - i1 = prevLabel(s1, j1-1) - i2 = prevLabel(s2, j2-1) - if equal(s1[i1:j1], s2[i2:j2]) { + if i1 < 0 || i2 < 0 { + break + } + if equal(s1[l1[i1]:l1[j1]], s2[l2[i2]:l2[j2]]) { n++ } else { break } - if i1 == 0 || i2 == 0 { - break - } - j1 = i1 - 2 - j2 = i2 - 2 + j1-- + i1-- + j2-- + i2-- } return } // CountLabel counts the the number of labels in the string s. // s must be a syntactically valid domain name. -func CountLabel(s string) int { +func CountLabel(s string) (labels int) { if s == "." { - return 0 + return } - labels := 1 - for i := 0; i < len(s)-1; i++ { - c := s[i] - if c == '\\' { - i++ - continue - } - if c == '.' { - labels++ + off := 0 + end := false + for { + off, end = NextLabel(s, off) + labels++ + if end { + return } } - return labels } // Split splits a name s into its label indexes. @@ -126,70 +126,40 @@ func Split(s string) []int { // The bool end is true when the end of the string has been reached. // Also see PrevLabel. func NextLabel(s string, offset int) (i int, end bool) { + quote := false for i = offset; i < len(s)-1; i++ { - c := s[i] - if c == '\\' { - i++ - continue - } - if c == '.' { + switch s[i] { + case '\\': + quote = !quote + default: + quote = false + case '.': + if quote { + quote = !quote + continue + } return i + 1, false } } return i + 1, true } -func prevLabel(s string, offset int) int { - for i := offset; i >= 0; i-- { - if s[i] == '.' { - if i == 0 || s[i-1] != '\\' { - return i + 1 // the '.' is not escaped - } - // We are at '\.' and need to check if the '\' itself is escaped. - // We do this by walking backwards from '\.' and counting the - // number of '\' we encounter. If the number of '\' is even - // (though here it's actually odd since we start at '\.') the '\' - // is escaped. - j := i - 2 - for ; j >= 0 && s[j] == '\\'; j-- { - } - // An odd number here indicates that the '\' preceding the '.' - // is escaped. - if (i-j)&1 == 1 { - return i + 1 - } - i = j + 1 - } - } - return 0 -} - // PrevLabel returns the index of the label when starting from the right and // jumping n labels to the left. // The bool start is true when the start of the string has been overshot. // Also see NextLabel. func PrevLabel(s string, n int) (i int, start bool) { - if s == "." { - return 0, true - } if n == 0 { return len(s), false } - i = len(s) - 1 - if s[i] == '.' { - i-- - } - for ; n > 0; n-- { - i = prevLabel(s, i) - if i == 0 { - break - } - i -= 2 - } - if n > 0 { + lab := Split(s) + if lab == nil { return 0, true } - return i + 2, false + if n > len(lab) { + return 0, true + } + return lab[len(lab)-n], false } // equal compares a and b while ignoring case. It returns true when equal otherwise false. @@ -200,19 +170,18 @@ func equal(a, b string) bool { if la != lb { return false } - if a != b { - // case-insensitive comparison - for i := la - 1; i >= 0; i-- { - ai := a[i] - bi := b[i] - if ai != bi { - if bi < ai { - bi, ai = ai, bi - } - if !('A' <= ai && ai <= 'Z' && bi == ai+'a'-'A') { - return false - } - } + + for i := la - 1; i >= 0; i-- { + ai := a[i] + bi := b[i] + if ai >= 'A' && ai <= 'Z' { + ai |= 'a' - 'A' + } + if bi >= 'A' && bi <= 'Z' { + bi |= 'a' - 'A' + } + if ai != bi { + return false } } return true diff --git a/labels_test.go b/labels_test.go index bd339d91..3f666df6 100644 --- a/labels_test.go +++ b/labels_test.go @@ -1,50 +1,55 @@ package dns -import ( - "strings" - "testing" -) +import "testing" func TestCompareDomainName(t *testing.T) { - tests := []struct { - s1, s2 string - expected int - }{ - {"www.miek.nl.", "miek.nl.", 2}, - {"miek.nl.", "www.bla.nl.", 1}, - {"www.bla.nl.", "nl.www.bla.", 0}, - {"www.miek.nl.", "nl.", 1}, - {"www.miek.nl.", "miek.nl.", 2}, - {"www.miek.nl.", ".", 0}, - {".", ".", 0}, - {"test.com.", "TEST.COM.", 2}, - {"a.b.c.d.e.f.", "a.b.c.d.e.", 0}, - {"a.b.c.d.e.", "a.b.c.d.e.", 5}, + s1 := "www.miek.nl." + s2 := "miek.nl." + s3 := "www.bla.nl." + s4 := "nl.www.bla." + s5 := "nl." + s6 := "miek.nl." + + if CompareDomainName(s1, s2) != 2 { + t.Errorf("%s with %s should be %d", s1, s2, 2) } - for _, x := range tests { - if i := CompareDomainName(x.s1, x.s2); i != x.expected { - t.Errorf("%s with %s should be %d got: %d", x.s1, x.s2, x.expected, i) - } + if CompareDomainName(s1, s3) != 1 { + t.Errorf("%s with %s should be %d", s1, s3, 1) + } + if CompareDomainName(s3, s4) != 0 { + t.Errorf("%s with %s should be %d", s3, s4, 0) + } + // Non qualified tests + if CompareDomainName(s1, s5) != 1 { + t.Errorf("%s with %s should be %d", s1, s5, 1) + } + if CompareDomainName(s1, s6) != 2 { + t.Errorf("%s with %s should be %d", s1, s5, 2) + } + + if CompareDomainName(s1, ".") != 0 { + t.Errorf("%s with %s should be %d", s1, s5, 0) + } + if CompareDomainName(".", ".") != 0 { + t.Errorf("%s with %s should be %d", ".", ".", 0) + } + if CompareDomainName("test.com.", "TEST.COM.") != 2 { + t.Errorf("test.com. and TEST.COM. should be an exact match") } } func TestSplit(t *testing.T) { splitter := map[string]int{ - "www.miek.nl.": 3, - "www.miek.nl": 3, - "www..miek.nl": 4, - `www\.miek.nl.`: 2, - `www\\.miek.nl.`: 3, - `\\.miek.nl.`: 3, - `\\\.miek.nl.`: 2, - `\\\\.miek.nl.`: 3, - `www.miek\\\\.nl.`: 3, - `www.miek\\\.nl.`: 2, - ".": 0, - "nl.": 1, - "nl": 1, - "com.": 1, - ".com.": 2, + "www.miek.nl.": 3, + "www.miek.nl": 3, + "www..miek.nl": 4, + `www\.miek.nl.`: 2, + `www\\.miek.nl.`: 3, + ".": 0, + "nl.": 1, + "nl": 1, + "com.": 1, + ".com.": 2, } for s, i := range splitter { if x := len(Split(s)); x != i { @@ -93,17 +98,7 @@ func TestPrevLabel(t *testing.T) { {"www.miek.nl.", 3}: 0, {"www.miek.nl", 3}: 0, - - {"a.b.c.", 1}: 4, - {"a.b.c", 1}: 4, } - - // make sure we are safe when the label begins with a possibly escaped '.' - for i := 1; i < 8; i++ { - s := strings.Repeat(`\`, i) + "." - prever[prev{s, 0}] = i + 1 - } - for s, i := range prever { x, ok := PrevLabel(s.string, s.int) if i != x { @@ -215,27 +210,6 @@ func TestIsFqdnEscaped(t *testing.T) { } } -func TestEqual(t *testing.T) { - type testcase struct { - a, b string - match bool - } - tests := []testcase{ - {"a", "a", true}, - {"a", "A", true}, - {"A", "a", true}, - {"A", "b", false}, - {"www.example.com.", "www.exAmpLe.com.", true}, - {"www.example.com.", "www.exAmpLe.org.", false}, - } - for _, x := range tests { - eq := equal(x.a, x.b) - if eq != x.match { - t.Errorf("%+v: want: %t got: %t", x, x.match, eq) - } - } -} - func BenchmarkSplitLabels(b *testing.B) { for i := 0; i < b.N; i++ { Split("www.example.com.")