diff --git a/idn/punycode.go b/idn/punycode.go index a8d16ed2..d3433fbd 100644 --- a/idn/punycode.go +++ b/idn/punycode.go @@ -5,6 +5,7 @@ import ( "bytes" "strings" "unicode" + "unicode/utf8" "github.com/miekg/dns" ) @@ -27,9 +28,15 @@ const ( ) // ToPunycode converts unicode domain names to DNS-appropriate punycode names. -// This function would return an empty string result for domain names with +// This function will return an empty string result for domain names with // invalid unicode strings. This function expects domain names in lowercase. func ToPunycode(s string) string { + // Early check to see if encoding is needed. + // This will prevent making heap allocations when not needed. + if !needToPunycode(s) { + return s + } + tokens := dns.SplitDomainName(s) switch { case s == "": @@ -51,7 +58,14 @@ func ToPunycode(s string) string { } // FromPunycode returns unicode domain name from provided punycode string. +// This function expects punycode strings in lowercase. func FromPunycode(s string) string { + // Early check to see if decoding is needed. + // This will prevent making heap allocations when not needed. + if !needFromPunycode(s) { + return s + } + tokens := dns.SplitDomainName(s) switch { case s == "": @@ -143,6 +157,51 @@ func tfunc(k, bias rune) rune { return k - bias } +// needToPunycode returns true for strings that require punycode encoding +// (contain unicode characters). +func needToPunycode(s string) bool { + // This function is very similar to bytes.Runes. We don't use bytes.Runes + // because it makes a heap allocation that's not needed here. + for i := 0; len(s) > 0; i++ { + r, l := utf8.DecodeRuneInString(s) + if r > 0x7f { + return true + } + s = s[l:] + } + return false +} + +// needFromPunycode returns true for strings that require punycode decoding. +func needFromPunycode(s string) bool { + if s == "." { + return false + } + + off := 0 + end := false + pl := len(_PREFIX) + sl := len(s) + + // If s starts with _PREFIX. + if sl > pl && s[off:off+pl] == _PREFIX { + return true + } + + for { + // Find the part after the next ".". + off, end = dns.NextLabel(s, off) + if end { + return false + } + // If this parts starts with _PREFIX. + if sl-off > pl && s[off:off+pl] == _PREFIX { + return true + } + } + panic("dns: not reached") +} + // encode transforms Unicode input bytes (that represent DNS label) into // punycode bytestream. This function would return nil if there's an invalid // character in the label. diff --git a/idn/punycode_test.go b/idn/punycode_test.go index 0311ecdb..9c9a15f0 100644 --- a/idn/punycode_test.go +++ b/idn/punycode_test.go @@ -8,9 +8,9 @@ import ( var testcases = [][2]string{ {"", ""}, {"a", "a"}, - {"A-B", "a-b"}, - {"A-B-C", "a-b-c"}, - {"AbC", "abc"}, + {"a-b", "a-b"}, + {"a-b-c", "a-b-c"}, + {"abc", "abc"}, {"я", "xn--41a"}, {"zя", "xn--z-0ub"}, {"яZ", "xn--z-zub"},