diff --git a/idn/punycode.go b/idn/punycode.go index b5c29a07..d9c9ff27 100644 --- a/idn/punycode.go +++ b/idn/punycode.go @@ -31,31 +31,34 @@ const ( // unicode strings. func ToPunycode(s string) string { tokens := dns.SplitDomainName(s) - if s[len(s)-1] == '.' { - if tokens == nil { - tokens = []string{"", ""} - } else { - tokens = append(tokens, "") - } + switch { + case s == "": + return "" + case tokens == nil: // s == . + return "." + case s[len(s)-1] == '.': + tokens = append(tokens, "") } + for i := range tokens { - tokens[i] = string(encodeBytes([]byte(tokens[i]))) + tokens[i] = string(encode([]byte(tokens[i]))) } return strings.Join(tokens, ".") } -// FromPunycode returns uncode domain name from provided punycode string. +// FromPunycode returns unicode domain name from provided punycode string. func FromPunycode(s string) string { tokens := dns.SplitDomainName(s) - if s[len(s)-1] == '.' { - if tokens == nil { - tokens = []string{"", ""} - } else { - tokens = append(tokens, "") - } + switch { + case s == "": + return "" + case tokens == nil: // s == . + return "." + case s[len(s)-1] == '.': + tokens = append(tokens, "") } for i := range tokens { - tokens[i] = string(decodeBytes([]byte(tokens[i]))) + tokens[i] = string(decode([]byte(tokens[i]))) } return strings.Join(tokens, ".") } @@ -117,7 +120,7 @@ func next(b []rune, boundary rune) rune { } // preprune converts unicode rune to lower case. At this time it's not -// supporting all things described RFC3454. +// supporting all things described in RFCs func preprune(r rune) rune { if unicode.IsUpper(r) { r = unicode.ToLower(r) @@ -136,8 +139,8 @@ func tfunc(k, bias rune) rune { return k - bias } -// encodeBytes transforms Unicode input bytes (that represent DNS label) into punycode bytestream -func encodeBytes(input []byte) []byte { +// encode transforms Unicode input bytes (that represent DNS label) into punycode bytestream +func encode(input []byte) []byte { n, bias := _N, _BIAS b := bytes.Runes(input) @@ -204,8 +207,8 @@ func encodeBytes(input []byte) []byte { return out.Bytes() } -// decodeBytes transforms punycode input bytes (that represent DNS label) into Unicode bytestream -func decodeBytes(b []byte) []byte { +// decode transforms punycode input bytes (that represent DNS label) into Unicode bytestream +func decode(b []byte) []byte { src := b // b would move and we need to keep it n, bias := _N, _BIAS diff --git a/idn/punycode_test.go b/idn/punycode_test.go index bd764a9d..3202450a 100644 --- a/idn/punycode_test.go +++ b/idn/punycode_test.go @@ -28,11 +28,11 @@ var testcases = [][2]string{ func TestEncodeDecodePunycode(t *testing.T) { for _, tst := range testcases { - enc := encodeBytes([]byte(tst[0])) + enc := encode([]byte(tst[0])) if string(enc) != tst[1] { t.Errorf("%s encodeded as %s but should be %s", tst[0], enc, tst[1]) } - dec := decodeBytes([]byte(tst[1])) + dec := decode([]byte(tst[1])) if string(dec) != strings.ToLower(tst[0]) { t.Errorf("%s decoded as %s but should be %s", tst[1], dec, strings.ToLower(tst[0])) } @@ -66,6 +66,15 @@ func TestEncodeDecodeFinalPeriod(t *testing.T) { if decoded != strings.ToLower(tst[0]+".") { t.Errorf("invalid result from string conversion to punycode when period added, %#v and should be %#v", decoded, tst[0]+".") } + full = ToPunycode(tst[0]) + if full != tst[1] { + t.Errorf("invalid result from string conversion to punycode when no period added at the end, %#v and should be %#v", full, tst[1]+".") + } + // assert punycode.com. == unicode.com. + decoded = FromPunycode(tst[1]) + if decoded != strings.ToLower(tst[0]) { + t.Errorf("invalid result from string conversion to punycode when no period added, %#v and should be %#v", decoded, tst[0]+".") + } } }