diff --git a/msg_helpers.go b/msg_helpers.go index 7a4184dc..ecd9280f 100644 --- a/msg_helpers.go +++ b/msg_helpers.go @@ -25,12 +25,13 @@ func unpackDataA(msg []byte, off int) (net.IP, int, error) { } func packDataA(a net.IP, msg []byte, off int) (int, error) { - // It must be a slice of 4, even if it is 16, we encode only the first 4 - if off+net.IPv4len > len(msg) { - return len(msg), &Error{err: "overflow packing a"} - } switch len(a) { case net.IPv4len, net.IPv6len: + // It must be a slice of 4, even if it is 16, we encode only the first 4 + if off+net.IPv4len > len(msg) { + return len(msg), &Error{err: "overflow packing a"} + } + copy(msg[off:], a.To4()) off += net.IPv4len case 0: @@ -51,12 +52,12 @@ func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) { } func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) { - if off+net.IPv6len > len(msg) { - return len(msg), &Error{err: "overflow packing aaaa"} - } - switch len(aaaa) { case net.IPv6len: + if off+net.IPv6len > len(msg) { + return len(msg), &Error{err: "overflow packing aaaa"} + } + copy(msg[off:], aaaa) off += net.IPv6len case 0: diff --git a/msg_test.go b/msg_test.go index 616e44db..55cc87a1 100644 --- a/msg_test.go +++ b/msg_test.go @@ -306,3 +306,20 @@ func TestPackUnpackManyCompressionPointers(t *testing.T) { } } } + +func TestLenDynamicA(t *testing.T) { + for _, rr := range []RR{ + testRR("example.org. A"), + testRR("example.org. AAAA"), + testRR("example.org. L32"), + } { + msg := make([]byte, Len(rr)) + off, err := PackRR(rr, msg, 0, nil, false) + if err != nil { + t.Fatalf("PackRR failed for %T: %v", rr, err) + } + if off != len(msg) { + t.Errorf("Len(rr) wrong for %T: Len(rr) = %d, PackRR(rr) = %d", rr, len(msg), off) + } + } +} diff --git a/types_generate.go b/types_generate.go index aa05a085..cbb4a00c 100644 --- a/types_generate.go +++ b/types_generate.go @@ -196,9 +196,9 @@ func main() { case st.Tag(i) == `dns:"any"`: o("l += len(rr.%s)\n") case st.Tag(i) == `dns:"a"`: - o("l += net.IPv4len // %s\n") + o("if len(rr.%s) != 0 { l += net.IPv4len }\n") case st.Tag(i) == `dns:"aaaa"`: - o("l += net.IPv6len // %s\n") + o("if len(rr.%s) != 0 { l += net.IPv6len }\n") case st.Tag(i) == `dns:"txt"`: o("for _, t := range rr.%s { l += len(t) + 1 }\n") case st.Tag(i) == `dns:"uint48"`: diff --git a/ztypes.go b/ztypes.go index 19a542d3..495a83e3 100644 --- a/ztypes.go +++ b/ztypes.go @@ -240,12 +240,16 @@ func (rr *X25) Header() *RR_Header { return &rr.Hdr } // len() functions func (rr *A) len(off int, compression map[string]struct{}) int { l := rr.Hdr.len(off, compression) - l += net.IPv4len // A + if len(rr.A) != 0 { + l += net.IPv4len + } return l } func (rr *AAAA) len(off int, compression map[string]struct{}) int { l := rr.Hdr.len(off, compression) - l += net.IPv6len // AAAA + if len(rr.AAAA) != 0 { + l += net.IPv6len + } return l } func (rr *AFSDB) len(off int, compression map[string]struct{}) int { @@ -364,8 +368,10 @@ func (rr *KX) len(off int, compression map[string]struct{}) int { } func (rr *L32) len(off int, compression map[string]struct{}) int { l := rr.Hdr.len(off, compression) - l += 2 // Preference - l += net.IPv4len // Locator32 + l += 2 // Preference + if len(rr.Locator32) != 0 { + l += net.IPv4len + } return l } func (rr *L64) len(off int, compression map[string]struct{}) int {