diff --git a/msg.go b/msg.go index 19d9147e..5bbc504b 100644 --- a/msg.go +++ b/msg.go @@ -686,6 +686,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str off++ } case `dns:"wks"`: + // TODO(miek): this is wrong should be lenrd if off == lenmsg { break // dyn. updates } @@ -893,7 +894,7 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { - var rdend int + var lenrd int lenmsg := len(msg) for i := 0; i < val.NumField(); i++ { if off > lenmsg { @@ -907,7 +908,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // therefore it's expected that this interface would be PrivateRdata switch data := fv.Interface().(type) { case PrivateRdata: - n, err := data.Unpack(msg[off:rdend]) + n, err := data.Unpack(msg[off:lenrd]) if err != nil { return lenmsg, err } @@ -923,7 +924,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // HIP record slice of name (or none) servers := make([]string, 0) var s string - for off < rdend { + for off < lenrd { s, off, err = UnpackDomainName(msg, off) if err != nil { return lenmsg, err @@ -932,17 +933,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } fv.Set(reflect.ValueOf(servers)) case `dns:"txt"`: - if off == lenmsg || rdend == off { + if off == lenmsg || lenrd == off { break } var txt []string - txt, off, err = unpackTxt(msg, off, rdend) + txt, off, err = unpackTxt(msg, off, lenrd) if err != nil { return lenmsg, err } fv.Set(reflect.ValueOf(txt)) case `dns:"opt"`: // edns0 - if off == rdend { + if off == lenrd { // This is an EDNS0 (OPT Record) with no rdata // We can safely return here. break @@ -955,7 +956,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } code, off = unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off) - if off1+int(optlen) > rdend { + if off1+int(optlen) > lenrd { return lenmsg, &Error{err: "overflow unpacking opt"} } switch code { @@ -1043,7 +1044,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // Rest of the record is the bitmap serv := make([]uint16, 0) j := 0 - for off < rdend { + for off < lenrd { if off+1 > lenmsg { return lenmsg, &Error{err: "overflow unpacking wks"} } @@ -1078,17 +1079,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } fv.Set(reflect.ValueOf(serv)) case `dns:"nsec"`: // NSEC/NSEC3 - if off == rdend { + if off == lenrd { break } // Rest of the record is the type bitmap - if off+2 > rdend || off+2 > lenmsg { + if off+2 > lenrd || off+2 > lenmsg { return lenmsg, &Error{err: "overflow unpacking nsecx"} } nsec := make([]uint16, 0) length := 0 window := 0 - for off+2 < rdend { + for off+2 < lenrd { window = int(msg[off]) length = int(msg[off+1]) //println("off, windows, length, end", off, window, length, endrr) @@ -1145,7 +1146,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er return lenmsg, err } if val.Type().Field(i).Name == "Hdr" { - rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) + lenrd = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) } case reflect.Uint8: if off == lenmsg { @@ -1202,22 +1203,22 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er default: return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} case `dns:"hex"`: - hexend := rdend + hexend := lenrd if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { hexend = off + int(val.FieldByName("HitLength").Uint()) } - if hexend > rdend || hexend > lenmsg { + if hexend > lenrd || hexend > lenmsg { return lenmsg, &Error{err: "overflow unpacking hex"} } s = hex.EncodeToString(msg[off:hexend]) off = hexend case `dns:"base64"`: // Rest of the RR is base64 encoded value - b64end := rdend + b64end := lenrd if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) } - if b64end > rdend || b64end > lenmsg { + if b64end > lenrd || b64end > lenmsg { return lenmsg, &Error{err: "overflow unpacking base64"} } s = unpackBase64(msg[off:b64end]) diff --git a/update_test.go b/update_test.go index 8206d2d1..00e269b3 100644 --- a/update_test.go +++ b/update_test.go @@ -30,7 +30,7 @@ func TestDynamicUpdateUnpack(t *testing.T) { msg := new(Msg) err := msg.Unpack(buf) if err != nil { - t.Log("failed to unpack: " + err.Error() + "\n" + msg.String())) + t.Log("failed to unpack: " + err.Error() + "\n" + msg.String()) t.Fail() } }