diff --git a/parse_test.go b/parse_test.go index 63e0fb53..6c2563e0 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1654,6 +1654,7 @@ func TestParseSVCB(t *testing.T) { `example.com. SVCB 16 foo.example.org. alpn=f\\\092oo\092,bar,h2`: `example.com. 3600 IN SVCB 16 foo.example.org. alpn="f\\\092oo\\\044bar,h2"`, // From draft-ietf-add-ddr-06 `_dns.example.net. SVCB 1 example.net. alpn=h2 dohpath=/dns-query{?dns}`: `_dns.example.net. 3600 IN SVCB 1 example.net. alpn="h2" dohpath="/dns-query{?dns}"`, + `_dns.example.net. SVCB 1 example.net. alpn=h2 dohpath=/dns\045query{\?dns}`: `_dns.example.net. 3600 IN SVCB 1 example.net. alpn="h2" dohpath="/dns-query{?dns}"`, } for s, o := range svcbs { rr, err := NewRR(s) diff --git a/svcb.go b/svcb.go index b9c6a952..ea58710d 100644 --- a/svcb.go +++ b/svcb.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "net" "sort" "strconv" @@ -789,7 +790,7 @@ type SVCBDoHPath struct { } func (*SVCBDoHPath) Key() SVCBKey { return SVCB_DOHPATH } -func (s *SVCBDoHPath) String() string { return s.Template } +func (s *SVCBDoHPath) String() string { return svcbParamToStr([]byte(s.Template)) } func (s *SVCBDoHPath) len() int { return len(s.Template) } func (s *SVCBDoHPath) pack() ([]byte, error) { return []byte(s.Template), nil } @@ -799,7 +800,11 @@ func (s *SVCBDoHPath) unpack(b []byte) error { } func (s *SVCBDoHPath) parse(b string) error { - s.Template = b + template, err := svcbParseParam(b) + if err != nil { + return fmt.Errorf("dns: svcbdohpath: %w", err) + } + s.Template = string(template) return nil } @@ -825,6 +830,7 @@ type SVCBLocal struct { } func (s *SVCBLocal) Key() SVCBKey { return s.KeyCode } +func (s *SVCBLocal) String() string { return svcbParamToStr(s.Data) } func (s *SVCBLocal) pack() ([]byte, error) { return append([]byte(nil), s.Data...), nil } func (s *SVCBLocal) len() int { return len(s.Data) } @@ -833,50 +839,10 @@ func (s *SVCBLocal) unpack(b []byte) error { return nil } -func (s *SVCBLocal) String() string { - var str strings.Builder - str.Grow(4 * len(s.Data)) - for _, e := range s.Data { - if ' ' <= e && e <= '~' { - switch e { - case '"', ';', ' ', '\\': - str.WriteByte('\\') - str.WriteByte(e) - default: - str.WriteByte(e) - } - } else { - str.WriteString(escapeByte(e)) - } - } - return str.String() -} - func (s *SVCBLocal) parse(b string) error { - data := make([]byte, 0, len(b)) - for i := 0; i < len(b); { - if b[i] != '\\' { - data = append(data, b[i]) - i++ - continue - } - if i+1 == len(b) { - return errors.New("dns: svcblocal: svcb private/experimental key escape unterminated") - } - if isDigit(b[i+1]) { - if i+3 < len(b) && isDigit(b[i+2]) && isDigit(b[i+3]) { - a, err := strconv.ParseUint(b[i+1:i+4], 10, 8) - if err == nil { - i += 4 - data = append(data, byte(a)) - continue - } - } - return errors.New("dns: svcblocal: svcb private/experimental key bad escaped octet") - } else { - data = append(data, b[i+1]) - i += 2 - } + data, err := svcbParseParam(b) + if err != nil { + return fmt.Errorf("dns: svcblocal: svcb private/experimental key %w", err) } s.Data = data return nil @@ -917,3 +883,53 @@ func areSVCBPairArraysEqual(a []SVCBKeyValue, b []SVCBKeyValue) bool { } return true } + +// svcbParamStr converts the value of an SVCB parameter into a DNS presentation-format string. +func svcbParamToStr(s []byte) string { + var str strings.Builder + str.Grow(4 * len(s)) + for _, e := range s { + if ' ' <= e && e <= '~' { + switch e { + case '"', ';', ' ', '\\': + str.WriteByte('\\') + str.WriteByte(e) + default: + str.WriteByte(e) + } + } else { + str.WriteString(escapeByte(e)) + } + } + return str.String() +} + +// svcbParseParam parses a DNS presentation-format string into an SVCB parameter value. +func svcbParseParam(b string) ([]byte, error) { + data := make([]byte, 0, len(b)) + for i := 0; i < len(b); { + if b[i] != '\\' { + data = append(data, b[i]) + i++ + continue + } + if i+1 == len(b) { + return nil, errors.New("escape unterminated") + } + if isDigit(b[i+1]) { + if i+3 < len(b) && isDigit(b[i+2]) && isDigit(b[i+3]) { + a, err := strconv.ParseUint(b[i+1:i+4], 10, 8) + if err == nil { + i += 4 + data = append(data, byte(a)) + continue + } + } + return nil, errors.New("bad escaped octet") + } else { + data = append(data, b[i+1]) + i += 2 + } + } + return data, nil +}