From 808f9308ef0e3219360d94ac7e73447efdc03f7d Mon Sep 17 00:00:00 2001 From: "tal@whatexit.org" Date: Tue, 12 Jan 2016 14:08:31 -0500 Subject: [PATCH] Added new functions: TrimDomainName()/AddOrigin() * TrimDomainName() Trims the domain off a FQDN. * AddOrigin() Adds a domain to a shortname or leaves a FQDN alone. --- dnsutil/util.go | 79 ++++++++++++++++++++++++++ dnsutil/util_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++ labels.go | 2 + labels_test.go | 5 +- 4 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 dnsutil/util.go create mode 100644 dnsutil/util_test.go diff --git a/dnsutil/util.go b/dnsutil/util.go new file mode 100644 index 00000000..9ed03f29 --- /dev/null +++ b/dnsutil/util.go @@ -0,0 +1,79 @@ +// Package dnsutil contains higher-level methods useful with the dns +// package. While package dns implements the DNS protocols itself, +// these functions are related but not directly required for protocol +// processing. They are often useful in preparing input/output of the +// functions in package dns. +package dnsutil + +import ( + "strings" + + "github.com/miekg/dns" +) + +// AddDomain adds origin to s if s is not already a FQDN. +// Note that the result may not be a FQDN. If origin does not end +// with a ".", the result won't either. +// This implements the zonefile convention (specified in RFC 1035, +// Section "5.1. Format") that "@" represents the +// apex (bare) domain. i.e. AddOrigin("@", "foo.com.") returns "foo.com.". +func AddOrigin(s, origin string) string { + // ("foo.", "origin.") -> "foo." (already a FQDN) + // ("foo", "origin.") -> "foo.origin." + // ("foo"), "origin" -> "foo.origin" + // ("@", "origin.") -> "origin." (@ represents the apex (bare) domain) + // ("", "origin.") -> "origin." (not obvious) + // ("foo", "") -> "foo" (not obvious) + + if dns.IsFqdn(s) { + return s // s is already a FQDN, no need to mess with it. + } + if len(origin) == 0 { + return s // Nothing to append. + } + if s == "@" || len(s) == 0 { + return origin // Expand apex. + } + + if origin == "." { + return s + origin // AddOrigin(s, ".") is an expensive way to add a ".". + } + + return s + "." + origin // The simple case. +} + +// TrimDomainName trims origin from s if s is a subdomain. +// This function will never return "", but returns "@" instead (@ represents the apex (bare) domain). +func TrimDomainName(s, origin string) string { + // An apex (bare) domain is always returned as "@". + // If the return value ends in a ".", the domain was not the suffix. + // origin can end in "." or not. Either way the results should be the same. + + if len(s) == 0 { + return "@" // Return the apex (@) rather than "". + } + // Someone is using TrimDomainName(s, ".") to remove a dot if it exists. + if origin == "." { + return strings.TrimSuffix(s, origin) + } + + // Dude, you aren't even if the right subdomain! + if !dns.IsSubDomain(origin, s) { + return s + } + + slabels := dns.Split(s) + olabels := dns.Split(origin) + m := dns.CompareDomainName(s, origin) + if len(olabels) == m { + if len(olabels) == len(slabels) { + return "@" // origin == s + } + if (s[0] == '.') && (len(slabels) == (len(olabels) + 1)) { + return "@" // TrimDomainName(".foo.", "foo.") + } + } + + // Return the first (len-m) labels: + return s[:slabels[len(slabels)-m]-1] +} diff --git a/dnsutil/util_test.go b/dnsutil/util_test.go new file mode 100644 index 00000000..0f1ecec8 --- /dev/null +++ b/dnsutil/util_test.go @@ -0,0 +1,130 @@ +package dnsutil + +import "testing" + +func TestAddOrigin(t *testing.T) { + var tests = []struct{ e1, e2, expected string }{ + {"@", "example.com", "example.com"}, + {"foo", "example.com", "foo.example.com"}, + {"foo.", "example.com", "foo."}, + {"@", "example.com.", "example.com."}, + {"foo", "example.com.", "foo.example.com."}, + {"foo.", "example.com.", "foo."}, + // Oddball tests: + // In general origin should not be "" or "." but at least + // these tests verify we don't crash and will keep results + // from changing unexpectedly. + {"*.", "", "*."}, + {"@", "", "@"}, + {"foobar", "", "foobar"}, + {"foobar.", "", "foobar."}, + {"*.", ".", "*."}, + {"@", ".", "."}, + {"foobar", ".", "foobar."}, + {"foobar.", ".", "foobar."}, + } + for _, test := range tests { + actual := AddOrigin(test.e1, test.e2) + if test.expected != actual { + t.Errorf("AddOrigin(%#v, %#v) expected %#v, go %#v\n", test.e1, test.e2, test.expected, actual) + } + } +} + +func TestTrimDomainName(t *testing.T) { + + // Basic tests. + // Try trimming "example.com" and "example.com." from typical use cases. + var tests_examplecom = []struct{ experiment, expected string }{ + {"foo.example.com", "foo"}, + {"foo.example.com.", "foo"}, + {".foo.example.com", ".foo"}, + {".foo.example.com.", ".foo"}, + {"*.example.com", "*"}, + {"example.com", "@"}, + {"example.com.", "@"}, + {"com.", "com."}, + {"foo.", "foo."}, + {"serverfault.com.", "serverfault.com."}, + {"serverfault.com", "serverfault.com"}, + {".foo.ronco.com", ".foo.ronco.com"}, + {".foo.ronco.com.", ".foo.ronco.com."}, + } + for _, dom := range []string{"example.com", "example.com."} { + for i, test := range tests_examplecom { + actual := TrimDomainName(test.experiment, dom) + if test.expected != actual { + t.Errorf("%d TrimDomainName(%#v, %#v): expected (%v) got (%v)\n", i, test.experiment, dom, test.expected, actual) + } + } + } + + // Paranoid tests. + // These test shouldn't be needed but I was weary of off-by-one errors. + // In theory, these can't happen because there are no single-letter TLDs, + // but it is good to exercize the code this way. + var tests = []struct{ experiment, expected string }{ + {"", "@"}, + {".", "."}, + {"a.b.c.d.e.f.", "a.b.c.d.e"}, + {"b.c.d.e.f.", "b.c.d.e"}, + {"c.d.e.f.", "c.d.e"}, + {"d.e.f.", "d.e"}, + {"e.f.", "e"}, + {"f.", "@"}, + {".a.b.c.d.e.f.", ".a.b.c.d.e"}, + {".b.c.d.e.f.", ".b.c.d.e"}, + {".c.d.e.f.", ".c.d.e"}, + {".d.e.f.", ".d.e"}, + {".e.f.", ".e"}, + {".f.", "@"}, + {"a.b.c.d.e.f", "a.b.c.d.e"}, + {"a.b.c.d.e.", "a.b.c.d.e."}, + {"a.b.c.d.e", "a.b.c.d.e"}, + {"a.b.c.d.", "a.b.c.d."}, + {"a.b.c.d", "a.b.c.d"}, + {"a.b.c.", "a.b.c."}, + {"a.b.c", "a.b.c"}, + {"a.b.", "a.b."}, + {"a.b", "a.b"}, + {"a.", "a."}, + {"a", "a"}, + {".a.b.c.d.e.f", ".a.b.c.d.e"}, + {".a.b.c.d.e.", ".a.b.c.d.e."}, + {".a.b.c.d.e", ".a.b.c.d.e"}, + {".a.b.c.d.", ".a.b.c.d."}, + {".a.b.c.d", ".a.b.c.d"}, + {".a.b.c.", ".a.b.c."}, + {".a.b.c", ".a.b.c"}, + {".a.b.", ".a.b."}, + {".a.b", ".a.b"}, + {".a.", ".a."}, + {".a", ".a"}, + } + for _, dom := range []string{"f", "f."} { + for i, test := range tests { + actual := TrimDomainName(test.experiment, dom) + if test.expected != actual { + t.Errorf("%d TrimDomainName(%#v, %#v): expected (%v) got (%v)\n", i, test.experiment, dom, test.expected, actual) + } + } + } + + // Test cases for bugs found in the wild. + // These test cases provide both origin, s, and the expected result. + // If you find a bug in the while, this is probably the easiest place + // to add it as a test case. + var tests_wild = []struct{ e1, e2, expected string }{ + {"mathoverflow.net.", ".", "mathoverflow.net"}, + {"mathoverflow.net", ".", "mathoverflow.net"}, + {"", ".", "@"}, + {"@", ".", "@"}, + } + for i, test := range tests_wild { + actual := TrimDomainName(test.e1, test.e2) + if test.expected != actual { + t.Errorf("%d TrimDomainName(%#v, %#v): expected (%v) got (%v)\n", i, test.e1, test.e2, test.expected, actual) + } + } + +} diff --git a/labels.go b/labels.go index 3944dd06..b1ee85ac 100644 --- a/labels.go +++ b/labels.go @@ -4,6 +4,7 @@ package dns // SplitDomainName splits a name string into it's labels. // www.miek.nl. returns []string{"www", "miek", "nl"} +// .www.miek.nl. returns []string{"", "www", "miek", "nl"}, // The root label (.) returns nil. Note that using // strings.Split(s) will work in most cases, but does not handle // escaped dots (\.) for instance. @@ -102,6 +103,7 @@ func CountLabel(s string) (labels int) { // Split splits a name s into its label indexes. // www.miek.nl. returns []int{0, 4, 9}, www.miek.nl also returns []int{0, 4, 9}. +// .www.miek.nl. returns []int{0, 1, 5, 10} // The root name (.) returns nil. Also see SplitDomainName. func Split(s string) []int { if s == "." { diff --git a/labels_test.go b/labels_test.go index a0616509..50f31163 100644 --- a/labels_test.go +++ b/labels_test.go @@ -1,8 +1,6 @@ package dns -import ( - "testing" -) +import "testing" func TestCompareDomainName(t *testing.T) { s1 := "www.miek.nl." @@ -132,6 +130,7 @@ func TestSplitDomainName(t *testing.T) { "www..miek.nl": {"www", "", "miek", "nl"}, `www\.miek.nl`: {`www\.miek`, "nl"}, `www\\.miek.nl`: {`www\\`, "miek", "nl"}, + ".www.miek.nl.": {"", "www", "miek", "nl"}, } domainLoop: for domain, splits := range labels {