Fix HIP record unpacking

* limit decoding of Hit to HitLength
* limit decoding of PublicKey to PublicKeyLength
* limit decoding of RendezvousServers to rdata's length
This commit is contained in:
Andrew Tunnell-Jones 2014-02-22 05:28:48 +00:00
parent 979b2ea731
commit c500de0e7a
2 changed files with 69 additions and 35 deletions

66
msg.go
View File

@ -748,7 +748,7 @@ func packStructCompress(any interface{}, msg []byte, off int, compression map[st
// Unpack a reflect.StructValue from msg. // Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue. // Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) { func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdstart int var rdend int
lenmsg := len(msg) lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ { for i := 0; i < val.NumField(); i++ {
switch fv := val.Field(i); fv.Kind() { switch fv := val.Field(i); fv.Kind() {
@ -762,7 +762,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// HIP record slice of name (or none) // HIP record slice of name (or none)
servers := make([]string, 0) servers := make([]string, 0)
var s string var s string
for off < lenmsg { for off < rdend {
s, off, err = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if err != nil { if err != nil {
return lenmsg, err return lenmsg, err
@ -772,9 +772,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
fv.Set(reflect.ValueOf(servers)) fv.Set(reflect.ValueOf(servers))
case `dns:"txt"`: case `dns:"txt"`:
txt := make([]string, 0) txt := make([]string, 0)
rdlength := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txts: Txts:
if off == lenmsg || rdlength == off { // dyn. updates, no rdata is OK if off == lenmsg || rdend == off { // dyn. updates, no rdata is OK
break break
} }
l := int(msg[off]) l := int(msg[off])
@ -783,15 +782,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
txt = append(txt, string(msg[off+1:off+l+1])) txt = append(txt, string(msg[off+1:off+l+1]))
off += l + 1 off += l + 1
if off < rdlength { if off < rdend {
// More // More
goto Txts goto Txts
} }
fv.Set(reflect.ValueOf(txt)) fv.Set(reflect.ValueOf(txt))
case `dns:"opt"`: // edns0 case `dns:"opt"`: // edns0
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) if off == rdend {
endrr := off + rdlength
if rdlength == 0 {
// This is an EDNS0 (OPT Record) with no rdata // This is an EDNS0 (OPT Record) with no rdata
// We can safely return here. // We can safely return here.
break break
@ -804,7 +801,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
code, off = unpackUint16(msg, off) code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > off+rdlength { if off1+int(optlen) > rdend {
return lenmsg, &Error{err: "overflow unpacking opt"} return lenmsg, &Error{err: "overflow unpacking opt"}
} }
switch code { switch code {
@ -864,7 +861,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
// do nothing? // do nothing?
off = off1 + int(optlen) off = off1 + int(optlen)
} }
if off < endrr { if off < rdend {
goto Option goto Option
} }
fv.Set(reflect.ValueOf(edns)) fv.Set(reflect.ValueOf(edns))
@ -872,16 +869,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
if off == lenmsg { if off == lenmsg {
break // dyn. update break // dyn. update
} }
if off+net.IPv4len > lenmsg { if off+net.IPv4len > rdend {
return lenmsg, &Error{err: "overflow unpacking a"} return lenmsg, &Error{err: "overflow unpacking a"}
} }
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len off += net.IPv4len
case `dns:"aaaa"`: case `dns:"aaaa"`:
if off == lenmsg { if off == rdend {
break break
} }
if off+net.IPv6len > lenmsg { if off+net.IPv6len > rdend {
return lenmsg, &Error{err: "overflow unpacking aaaa"} return lenmsg, &Error{err: "overflow unpacking aaaa"}
} }
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
@ -890,11 +887,9 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
off += net.IPv6len off += net.IPv6len
case `dns:"wks"`: case `dns:"wks"`:
// Rest of the record is the bitmap // Rest of the record is the bitmap
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
serv := make([]uint16, 0) serv := make([]uint16, 0)
j := 0 j := 0
for off < endrr { for off < rdend {
b := msg[off] b := msg[off]
// Check the bits one by one, and set the type // Check the bits one by one, and set the type
if b&0x80 == 0x80 { if b&0x80 == 0x80 {
@ -926,19 +921,17 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
} }
fv.Set(reflect.ValueOf(serv)) fv.Set(reflect.ValueOf(serv))
case `dns:"nsec"`: // NSEC/NSEC3 case `dns:"nsec"`: // NSEC/NSEC3
if off == lenmsg { if off == rdend {
break break
} }
// Rest of the record is the type bitmap // Rest of the record is the type bitmap
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) if off+2 > rdend {
endrr := rdstart + rdlength
if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking nsecx"} return lenmsg, &Error{err: "overflow unpacking nsecx"}
} }
nsec := make([]uint16, 0) nsec := make([]uint16, 0)
length := 0 length := 0
window := 0 window := 0
for off+2 < endrr { for off+2 < rdend {
window = int(msg[off]) window = int(msg[off])
length = int(msg[off+1]) length = int(msg[off+1])
//println("off, windows, length, end", off, window, length, endrr) //println("off, windows, length, end", off, window, length, endrr)
@ -992,7 +985,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
return lenmsg, err return lenmsg, err
} }
if val.Type().Field(i).Name == "Hdr" { if val.Type().Field(i).Name == "Hdr" {
rdstart = off rdend = off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
} }
case reflect.Uint8: case reflect.Uint8:
if off == lenmsg { if off == lenmsg {
@ -1050,22 +1043,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")} return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"hex"`: case `dns:"hex"`:
// Rest of the RR is hex encoded, network order an issue here? // Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) hexend := rdend
endrr := rdstart + rdlength if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
if endrr > lenmsg { hexend = off + int(val.FieldByName("HitLength").Uint())
}
if hexend > rdend {
return lenmsg, &Error{err: "overflow unpacking hex"} return lenmsg, &Error{err: "overflow unpacking hex"}
} }
s = hex.EncodeToString(msg[off:endrr]) s = hex.EncodeToString(msg[off:hexend])
off = endrr off = hexend
case `dns:"base64"`: case `dns:"base64"`:
// Rest of the RR is base64 encoded value // Rest of the RR is base64 encoded value
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) b64end := rdend
endrr := rdstart + rdlength if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
if endrr > lenmsg { b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
}
if b64end > rdend {
return lenmsg, &Error{err: "overflow unpacking base64"} return lenmsg, &Error{err: "overflow unpacking base64"}
} }
s = unpackBase64(msg[off:endrr]) s = unpackBase64(msg[off:b64end])
off = endrr off = b64end
case `dns:"cdomain-name"`: case `dns:"cdomain-name"`:
fallthrough fallthrough
case `dns:"domain-name"`: case `dns:"domain-name"`:
@ -1121,9 +1118,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
s = hex.EncodeToString(msg[off : off+size]) s = hex.EncodeToString(msg[off : off+size])
off += size off += size
case `dns:"txt"`: case `dns:"txt"`:
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt: Txt:
if off >= lenmsg || off+1+int(msg[off]) > lenmsg { if off >= lenmsg || off+1+int(msg[off]) > rdend {
return lenmsg, &Error{err: "overflow unpacking txt"} return lenmsg, &Error{err: "overflow unpacking txt"}
} }
n := int(msg[off]) n := int(msg[off])
@ -1132,7 +1128,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
s += string(msg[off+i]) s += string(msg[off+i])
} }
off += n off += n
if off < rdlength { if off < rdend {
// More to // More to
goto Txt goto Txt
} }

View File

@ -391,6 +391,44 @@ b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D
// www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com. // www.example.com. 3600 IN HIP 2 200100107B1A74DF365639CC39F1D578 AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQb1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D rvs.example.com.
} }
func TestHIP(t *testing.T) {
h := `www.example.com. IN HIP ( 2 200100107B1A74DF365639CC39F1D578
AwEAAbdxyhNuSutc5EMzxTs9LBPCIkOFH8cIvM4p
9+LrV4e19WzK00+CI6zBCQTdtWsuxKbWIy87UOoJTwkUs7lBu+Upr1gsNrut79ryra+bSRGQ
b1slImA8YVJyuIDsj7kwzG7jnERNqnWxZ48AWkskmdHaVDP4BcelrTI3rMXdXF5D
rvs1.example.com.
rvs2.example.com. )`
rr, err := NewRR(h)
if err != nil {
t.Fatalf("Failed to parse RR: %s", err)
}
t.Logf("RR: %s", rr)
msg := new(Msg)
msg.Answer = []RR{rr, rr}
bytes, err := msg.Pack()
if err != nil {
t.Fatalf("Failed to pack msg: %s", err)
}
if err := msg.Unpack(bytes); err != nil {
t.Fatalf("Failed to unpack msg: %s", err)
}
if len(msg.Answer) != 2 {
t.Fatalf("2 answers expected: %V", msg)
}
for i, rr := range msg.Answer {
rr := rr.(*HIP)
t.Logf("RR: %s", rr)
if l := len(rr.RendezvousServers); l != 2 {
t.Fatalf("2 servers expected, only %d in record %d:\n%V", l, i, msg)
}
for j, s := range []string{"rvs1.example.com.", "rvs2.example.com."} {
if rr.RendezvousServers[j] != s {
t.Fatalf("Expected server %d of record %d to be %s:\n%V", j, i, s, msg)
}
}
}
}
func ExampleSOA() { func ExampleSOA() {
s := "example.com. 1000 SOA master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100" s := "example.com. 1000 SOA master.example.com. admin.example.com. 1 4294967294 4294967293 4294967295 100"
if soa, err := NewRR(s); err == nil { if soa, err := NewRR(s); err == nil {