Look at the domain to see if it is edns

throw this information back upwards so the
the Edns bool can be set.
This commit is contained in:
Miek Gieben 2010-12-23 09:51:43 +01:00
parent 1a50861b43
commit 25f2e3d7e8
1 changed files with 39 additions and 30 deletions

69
msg.go
View File

@ -118,13 +118,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
// which is where the next record will start.
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while.
func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
func unpackDomainName(msg []byte, off int) (s string, off1 int, ok, edns bool) {
s = ""
ptr := 0 // number of pointers followed
Loop:
for {
if off >= len(msg) {
return "", len(msg), false
return "", len(msg), false, false
}
c := int(msg[off])
off++
@ -136,7 +136,7 @@ Loop:
}
// literal string
if off+c > len(msg) {
return "", len(msg), false
return "", len(msg), false, false
}
s += string(msg[off:off+c]) + "."
off += c
@ -148,6 +148,7 @@ Loop:
// but the parsing here (is for now) relatively simple
// The name must be the root label aka 00
// TODO check! MG
edns = true
s = ""
off++
break Loop
@ -158,7 +159,7 @@ Loop:
// also, don't follow too many pointers --
// maybe there's a loop.
if off >= len(msg) {
return "", len(msg), false
return "", len(msg), false, false
}
c1 := msg[off]
off++
@ -166,18 +167,18 @@ Loop:
off1 = off
}
if ptr++; ptr > 10 {
return "", len(msg), false
return "", len(msg), false, false
}
off = (c^0xC0)<<8 | int(c1)
default:
// 0x80 and 0x40 are reserved
return "", len(msg), false
return "", len(msg), false, false
}
}
if ptr == 0 {
off1 = off
}
return s, off1, true
return s, off1, true, edns
}
// TODO(rsc): Move into generic library?
@ -189,12 +190,15 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
switch fv := val.Field(i).(type) {
default:
BadType:
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v\n", f.Type)
return len(msg), false
case *reflect.BoolValue:
// Used internally for Edns, not present in the DNS
continue;
case *reflect.SliceValue:
switch f.Tag {
default:
fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v", f.Tag)
fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v\n", f.Tag)
return len(msg), false
case "OPT": // edns
for j := 0; j < val.Field(i).(*reflect.SliceValue).Len(); j++ {
@ -313,29 +317,32 @@ func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok, edns bool) {
for i := 0; i < val.NumField(); i++ {
f := val.Type().(*reflect.StructType).Field(i)
switch fv := val.Field(i).(type) {
default:
BadType:
fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type)
return len(msg), false
return len(msg), false, false
case *reflect.BoolValue:
// Used internally for Edns, not present in the DNS
continue;
case *reflect.SliceValue:
switch f.Tag {
default:
fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v", f.Tag)
return len(msg), false
return len(msg), false, false
case "A":
if off+net.IPv4len > len(msg) {
return len(msg), false
return len(msg), false, false
}
b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])
fv.Set(reflect.NewValue(b).(*reflect.SliceValue))
off += net.IPv4len
case "AAAA":
if off+net.IPv6len > len(msg) {
return len(msg), false
return len(msg), false, false
}
p := make(net.IP, net.IPv6len)
copy(p, msg[off:off+net.IPv6len])
@ -346,28 +353,28 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
// do it here
}
case *reflect.StructValue:
off, ok = unpackStructValue(fv, msg, off)
off, ok, edns = unpackStructValue(fv, msg, off)
case *reflect.UintValue:
switch fv.Type().Kind() {
default:
goto BadType
case reflect.Uint8:
if off+1 > len(msg) {
return len(msg), false
return len(msg), false, false
}
i := uint8(msg[off])
fv.Set(uint64(i))
off++
case reflect.Uint16:
if off+2 > len(msg) {
return len(msg), false
return len(msg), false, false
}
i := uint16(msg[off])<<8 | uint16(msg[off+1])
fv.Set(uint64(i))
off += 2
case reflect.Uint32:
if off+4 > len(msg) {
return len(msg), false
return len(msg), false, false
}
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
fv.Set(uint64(i))
@ -378,7 +385,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
switch f.Tag {
default:
fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag)
return len(msg), false
return len(msg), false, false
case "hex":
// Rest of the RR is hex encoded
rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get())
@ -414,13 +421,13 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
s = string(b64)
off += rdlength - consumed
case "domain-name":
s, off, ok = unpackDomainName(msg, off)
s, off, ok, edns = unpackDomainName(msg, off)
if !ok {
return len(msg), false
return len(msg), false, false
}
case "":
if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
return len(msg), false
return len(msg), false, false
}
n := int(msg[off])
off++
@ -434,12 +441,12 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
fv.Set(s)
}
}
return off, true
return off, true, edns
}
func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
off, ok = unpackStructValue(structValue(any), msg, off)
return off, ok
func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok, edns bool) {
off, ok, edns = unpackStructValue(structValue(any), msg, off)
return off, ok, edns
}
// THIS can GO TODO
@ -493,10 +500,12 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
// unpack just the header, to find the rr type and length
// check if we have an edns packet, and set h.Edns to true
var h RR_Header
var edns bool
off0 := off
if off, ok = unpackStruct(&h, msg, off); !ok {
if off, ok, edns = unpackStruct(&h, msg, off); !ok {
return nil, len(msg), false
}
h.Edns = edns // set Edns if found
end := off + int(h.Rdlength)
// make an rr of that type and re-unpack.
@ -507,7 +516,7 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
}
rr = mk()
off, ok = unpackStruct(rr, msg, off0)
off, ok, _ = unpackStruct(rr, msg, off0) // don't care about edns?
if off != end {
// added MG
// println("Hier gaat het dan fout, echt waar en was if off0", off0)
@ -658,7 +667,7 @@ func (dns *Msg) Unpack(msg []byte) bool {
var dh Header
off := 0
var ok bool
if off, ok = unpackStruct(&dh, msg, off); !ok {
if off, ok, _ = unpackStruct(&dh, msg, off); !ok {
return false
}
dns.Id = dh.Id
@ -677,7 +686,7 @@ func (dns *Msg) Unpack(msg []byte) bool {
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
off, ok = unpackStruct(&dns.Question[i], msg, off)
off, ok, _ = unpackStruct(&dns.Question[i], msg, off)
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = unpackRR(msg, off)