diff --git a/edns.go b/edns.go index 0f6c4995..32047151 100644 --- a/edns.go +++ b/edns.go @@ -273,22 +273,16 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { if e.SourceNetmask > net.IPv4len*8 || e.SourceScope > net.IPv4len*8 { return errors.New("dns: bad netmask") } - addr := make([]byte, net.IPv4len) - for i := 0; i < net.IPv4len && 4+i < len(b); i++ { - addr[i] = b[4+i] - } - e.Address = net.IPv4(addr[0], addr[1], addr[2], addr[3]) + addr := make(net.IP, net.IPv4len) + copy(addr, b[4:]) + e.Address = addr.To16() case 2: if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 { return errors.New("dns: bad netmask") } - addr := make([]byte, net.IPv6len) - for i := 0; i < net.IPv6len && 4+i < len(b); i++ { - addr[i] = b[4+i] - } - e.Address = net.IP{addr[0], addr[1], addr[2], addr[3], addr[4], - addr[5], addr[6], addr[7], addr[8], addr[9], addr[10], - addr[11], addr[12], addr[13], addr[14], addr[15]} + addr := make(net.IP, net.IPv6len) + copy(addr, b[4:]) + e.Address = addr default: return errors.New("dns: bad address family") } diff --git a/edns_test.go b/edns_test.go index 6d27a4aa..69682a54 100644 --- a/edns_test.go +++ b/edns_test.go @@ -1,6 +1,9 @@ package dns -import "testing" +import ( + "net" + "testing" +) func TestOPTTtl(t *testing.T) { e := &OPT{} @@ -63,8 +66,8 @@ func TestOPTTtl(t *testing.T) { e.SetExtendedRcode(42) // ExtendedRcode has the last 4 bits set to 0. - if e.ExtendedRcode() != 42 & 0xFFFFFFF0 { - t.Errorf("set 42, expected %d, got %d", 42 & 0xFFFFFFF0, e.ExtendedRcode()) + if e.ExtendedRcode() != 42&0xFFFFFFF0 { + t.Errorf("set 42, expected %d, got %d", 42&0xFFFFFFF0, e.ExtendedRcode()) } // This will reset the 8 upper bits of the extended rcode @@ -73,3 +76,36 @@ func TestOPTTtl(t *testing.T) { t.Errorf("Setting a non-extended rcode is expected to set extended rcode to 0, got: %d", e.ExtendedRcode()) } } + +func TestEDNS0_SUBNETUnpack(t *testing.T) { + for _, ip := range []net.IP{ + net.IPv4(0xde, 0xad, 0xbe, 0xef), + net.ParseIP("192.0.2.1"), + net.ParseIP("2001:db8::68"), + } { + var s1 EDNS0_SUBNET + s1.Address = ip + + if ip.To4() == nil { + s1.Family = 2 + s1.SourceNetmask = net.IPv6len * 8 + } else { + s1.Family = 1 + s1.SourceNetmask = net.IPv4len * 8 + } + + b, err := s1.pack() + if err != nil { + t.Fatalf("failed to pack: %v", err) + } + + var s2 EDNS0_SUBNET + if err := s2.unpack(b); err != nil { + t.Fatalf("failed to unpack: %v", err) + } + + if !ip.Equal(s2.Address) { + t.Errorf("address different after unpacking; expected %s, got %s", ip, s2.Address) + } + } +}