diff --git a/msg.go b/msg.go index 29306f61..03a79a82 100644 --- a/msg.go +++ b/msg.go @@ -748,7 +748,7 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { - var rdstart int + var rdend int lenmsg := len(msg) for i := 0; i < val.NumField(); i++ { switch fv := val.Field(i); fv.Kind() { @@ -762,7 +762,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // HIP record slice of name (or none) servers := make([]string, 0) var s string - for off < lenmsg { + for off < rdend { s, off, err = UnpackDomainName(msg, off) if err != nil { return lenmsg, err @@ -772,9 +772,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er fv.Set(reflect.ValueOf(servers)) case `dns:"txt"`: txt := make([]string, 0) - rdlength := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) Txts: - if off == lenmsg || rdlength == off { // dyn. updates, no rdata is OK + if off == lenmsg || rdend == off { // dyn. updates, no rdata is OK break } l := int(msg[off]) @@ -783,15 +782,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } txt = append(txt, string(msg[off+1:off+l+1])) off += l + 1 - if off < rdlength { + if off < rdend { // More goto Txts } fv.Set(reflect.ValueOf(txt)) case `dns:"opt"`: // edns0 - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - endrr := off + rdlength - if rdlength == 0 { + if off == rdend { // This is an EDNS0 (OPT Record) with no rdata // We can safely return here. break @@ -804,7 +801,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } code, off = unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off) - if off1+int(optlen) > off+rdlength { + if off1+int(optlen) > rdend { return lenmsg, &Error{err: "overflow unpacking opt"} } switch code { @@ -864,7 +861,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er // do nothing? off = off1 + int(optlen) } - if off < endrr { + if off < rdend { goto Option } fv.Set(reflect.ValueOf(edns)) @@ -872,16 +869,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off == lenmsg { break // dyn. update } - if off+net.IPv4len > lenmsg { + if off+net.IPv4len > rdend { return lenmsg, &Error{err: "overflow unpacking a"} } fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) off += net.IPv4len case `dns:"aaaa"`: - if off == lenmsg { + if off == rdend { break } - if off+net.IPv6len > lenmsg { + if off+net.IPv6len > rdend { return lenmsg, &Error{err: "overflow unpacking aaaa"} } fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], @@ -890,11 +887,9 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er off += net.IPv6len case `dns:"wks"`: // Rest of the record is the bitmap - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - endrr := rdstart + rdlength serv := make([]uint16, 0) j := 0 - for off < endrr { + for off < rdend { b := msg[off] // Check the bits one by one, and set the type if b&0x80 == 0x80 { @@ -926,19 +921,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er } fv.Set(reflect.ValueOf(serv)) case `dns:"nsec"`: // NSEC/NSEC3 - if off == lenmsg { + if off == rdend { break } // Rest of the record is the type bitmap - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - endrr := rdstart + rdlength - if off+2 > lenmsg { + if off+2 > rdend { return lenmsg, &Error{err: "overflow unpacking nsecx"} } nsec := make([]uint16, 0) length := 0 window := 0 - for off+2 < endrr { + for off+2 < rdend { window = int(msg[off]) length = int(msg[off+1]) //println("off, windows, length, end", off, window, length, endrr) @@ -992,7 +985,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er return lenmsg, err } if val.Type().Field(i).Name == "Hdr" { - rdstart = off + rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) } case reflect.Uint8: if off == lenmsg { @@ -1050,22 +1043,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} case `dns:"hex"`: // Rest of the RR is hex encoded, network order an issue here? - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - endrr := rdstart + rdlength - if endrr > lenmsg { + hexend := rdend + if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { + hexend = off + int(val.FieldByName("HitLength").Uint()) + } + if hexend > rdend { return lenmsg, &Error{err: "overflow unpacking hex"} } - s = hex.EncodeToString(msg[off:endrr]) - off = endrr + s = hex.EncodeToString(msg[off:hexend]) + off = hexend case `dns:"base64"`: // Rest of the RR is base64 encoded value - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) - endrr := rdstart + rdlength - if endrr > lenmsg { + b64end := rdend + if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) { + b64end = off + int(val.FieldByName("PublicKeyLength").Uint()) + } + if b64end > rdend { return lenmsg, &Error{err: "overflow unpacking base64"} } - s = unpackBase64(msg[off:endrr]) - off = endrr + s = unpackBase64(msg[off:b64end]) + off = b64end case `dns:"cdomain-name"`: fallthrough case `dns:"domain-name"`: @@ -1121,9 +1118,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er s = hex.EncodeToString(msg[off : off+size]) off += size case `dns:"txt"`: - rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) Txt: - if off >= lenmsg || off+1+int(msg[off]) > lenmsg { + if off >= lenmsg || off+1+int(msg[off]) > rdend { return lenmsg, &Error{err: "overflow unpacking txt"} } n := int(msg[off]) @@ -1132,7 +1128,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er s += string(msg[off+i]) } off += n - if off < rdlength { + if off < rdend { // More to goto Txt } diff --git a/parse_test.go b/parse_test.go index 84c4bfe8..f157653a 100644 --- a/parse_test.go +++ b/parse_test.go @@ -391,6 +391,44 @@ b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D // www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com. } +func TestHIP(t *testing.T) { + h := `www.example.com. IN HIP ( 2 200100107B1A74DF365639CC39F1D578 + AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p +9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ +b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D + rvs1.example.com. + rvs2.example.com. )` + rr, err := NewRR(h) + if err != nil { + t.Fatalf("Failed to parse RR: %s", err) + } + t.Logf("RR: %s", rr) + msg := new(Msg) + msg.Answer = []RR{rr, rr} + bytes, err := msg.Pack() + if err != nil { + t.Fatalf("Failed to pack msg: %s", err) + } + if err := msg.Unpack(bytes); err != nil { + t.Fatalf("Failed to unpack msg: %s", err) + } + if len(msg.Answer) != 2 { + t.Fatalf("2 answers expected: %V", msg) + } + for i, rr := range msg.Answer { + rr := rr.(*HIP) + t.Logf("RR: %s", rr) + if l := len(rr.RendezvousServers); l != 2 { + t.Fatalf("2 servers expected, only %d in record %d:\n%V", l, i, msg) + } + for j, s := range []string{"rvs1.example.com.", "rvs2.example.com."} { + if rr.RendezvousServers[j] != s { + t.Fatalf("Expected server %d of record %d to be %s:\n%V", j, i, s, msg) + } + } + } +} + func ExampleSOA() { s := "example.com. 1000 SOA master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100" if soa, err := NewRR(s); err == nil {