diff --git a/dns_bench_test.go b/dns_bench_test.go index a34fe994..0f5a1628 100644 --- a/dns_bench_test.go +++ b/dns_bench_test.go @@ -233,7 +233,7 @@ func BenchmarkUnpackMX(b *testing.B) { } func BenchmarkPackAAAAA(b *testing.B) { - aaaa := testRR(". IN A ::1") + aaaa := testRR(". IN AAAA ::1") buf := make([]byte, Len(aaaa)) b.ReportAllocs() @@ -244,7 +244,7 @@ func BenchmarkPackAAAAA(b *testing.B) { } func BenchmarkUnpackAAAA(b *testing.B) { - aaaa := testRR(". IN A ::1") + aaaa := testRR(". IN AAAA ::1") buf := make([]byte, Len(aaaa)) PackRR(aaaa, buf, 0, nil, false) diff --git a/scan_rr.go b/scan_rr.go index f48ff789..6096f9b0 100644 --- a/scan_rr.go +++ b/scan_rr.go @@ -133,7 +133,12 @@ func (rr *A) parse(c *zlexer, o, f string) *ParseError { } rr.A = net.ParseIP(l.token) - if rr.A == nil || l.err { + // IPv4 addresses cannot include ":". + // We do this rather than use net.IP's To4() because + // To4() treats IPv4-mapped IPv6 addresses as being + // IPv4. + isIPv4 := !strings.Contains(l.token, ":") + if rr.A == nil || !isIPv4 || l.err { return &ParseError{f, "bad A A", l} } return slurpRemainder(c, f) @@ -146,7 +151,10 @@ func (rr *AAAA) parse(c *zlexer, o, f string) *ParseError { } rr.AAAA = net.ParseIP(l.token) - if rr.AAAA == nil || l.err { + // IPv6 addresses must include ":", and IPv4 + // addresses cannot include ":". + isIPv6 := strings.Contains(l.token, ":") + if rr.AAAA == nil || !isIPv6 || l.err { return &ParseError{f, "bad AAAA AAAA", l} } return slurpRemainder(c, f) diff --git a/scan_test.go b/scan_test.go index a939d211..006a869b 100644 --- a/scan_test.go +++ b/scan_test.go @@ -121,6 +121,51 @@ func TestZoneParserIncludeDisallowed(t *testing.T) { } } +func TestZoneParserAddressAAAA(t *testing.T) { + tests := []struct { + record string + want *AAAA + }{ + { + record: "1.example.org. 600 IN AAAA ::1", + want: &AAAA{Hdr: RR_Header{Name: "1.example.org."}, AAAA: net.IPv6loopback}, + }, + { + record: "2.example.org. 600 IN AAAA ::FFFF:127.0.0.1", + want: &AAAA{Hdr: RR_Header{Name: "2.example.org."}, AAAA: net.ParseIP("::FFFF:127.0.0.1")}, + }, + } + + for _, tc := range tests { + got, err := NewRR(tc.record) + if err != nil { + t.Fatalf("expected no error, but got %s", err) + } + aaaa, ok := got.(*AAAA) + if !ok { + t.Fatalf("expected *AAAA RR, but got %T", aaaa) + } + if g, w := aaaa.AAAA, tc.want.AAAA; !g.Equal(w) { + t.Fatalf("expected AAAA with IP %v, but got %v", g, w) + } + } +} + +func TestZoneParserAddressBad(t *testing.T) { + records := []string{ + "1.bad.example.org. 600 IN A ::1", + "2.bad.example.org. 600 IN A ::FFFF:127.0.0.1", + "3.bad.example.org. 600 IN AAAA 127.0.0.1", + } + + for _, record := range records { + const expect = "bad A" + if got, err := NewRR(record); err == nil || !strings.Contains(err.Error(), expect) { + t.Errorf("NewRR(%v) = %v, want err to contain %q", record, got, expect) + } + } +} + func TestParseTA(t *testing.T) { rr, err := NewRR(` Ta 0 0 0`) if err != nil {