From 25f2e3d7e8c3899079fbf8e4da7cb02fb66f8d09 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Thu, 23 Dec 2010 09:51:43 +0100 Subject: [PATCH] Look at the domain to see if it is edns throw this information back upwards so the the Edns bool can be set. --- msg.go | 69 +++++++++++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/msg.go b/msg.go index 0955e74e..81123a13 100644 --- a/msg.go +++ b/msg.go @@ -118,13 +118,13 @@ func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { // which is where the next record will start. // In theory, the pointers are only allowed to jump backward. // We let them jump anywhere and stop jumping after a while. -func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok, edns bool) { s = "" ptr := 0 // number of pointers followed Loop: for { if off >= len(msg) { - return "", len(msg), false + return "", len(msg), false, false } c := int(msg[off]) off++ @@ -136,7 +136,7 @@ Loop: } // literal string if off+c > len(msg) { - return "", len(msg), false + return "", len(msg), false, false } s += string(msg[off:off+c]) + "." off += c @@ -148,6 +148,7 @@ Loop: // but the parsing here (is for now) relatively simple // The name must be the root label aka 00 // TODO check! MG + edns = true s = "" off++ break Loop @@ -158,7 +159,7 @@ Loop: // also, don't follow too many pointers -- // maybe there's a loop. if off >= len(msg) { - return "", len(msg), false + return "", len(msg), false, false } c1 := msg[off] off++ @@ -166,18 +167,18 @@ Loop: off1 = off } if ptr++; ptr > 10 { - return "", len(msg), false + return "", len(msg), false, false } off = (c^0xC0)<<8 | int(c1) default: // 0x80 and 0x40 are reserved - return "", len(msg), false + return "", len(msg), false, false } } if ptr == 0 { off1 = off } - return s, off1, true + return s, off1, true, edns } // TODO(rsc): Move into generic library? @@ -189,12 +190,15 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o switch fv := val.Field(i).(type) { default: BadType: - fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v\n", f.Type) return len(msg), false + case *reflect.BoolValue: + // Used internally for Edns, not present in the DNS + continue; case *reflect.SliceValue: switch f.Tag { default: - fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v", f.Tag) + fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v\n", f.Tag) return len(msg), false case "OPT": // edns for j := 0; j < val.Field(i).(*reflect.SliceValue).Len(); j++ { @@ -313,29 +317,32 @@ func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { // Unpack a reflect.StructValue from msg. // Same restrictions as packStructValue. -func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { +func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok, edns bool) { for i := 0; i < val.NumField(); i++ { f := val.Type().(*reflect.StructType).Field(i) switch fv := val.Field(i).(type) { default: BadType: fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) - return len(msg), false + return len(msg), false, false + case *reflect.BoolValue: + // Used internally for Edns, not present in the DNS + continue; case *reflect.SliceValue: switch f.Tag { default: fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v", f.Tag) - return len(msg), false + return len(msg), false, false case "A": if off+net.IPv4len > len(msg) { - return len(msg), false + return len(msg), false, false } b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) fv.Set(reflect.NewValue(b).(*reflect.SliceValue)) off += net.IPv4len case "AAAA": if off+net.IPv6len > len(msg) { - return len(msg), false + return len(msg), false, false } p := make(net.IP, net.IPv6len) copy(p, msg[off:off+net.IPv6len]) @@ -346,28 +353,28 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, // do it here } case *reflect.StructValue: - off, ok = unpackStructValue(fv, msg, off) + off, ok, edns = unpackStructValue(fv, msg, off) case *reflect.UintValue: switch fv.Type().Kind() { default: goto BadType case reflect.Uint8: if off+1 > len(msg) { - return len(msg), false + return len(msg), false, false } i := uint8(msg[off]) fv.Set(uint64(i)) off++ case reflect.Uint16: if off+2 > len(msg) { - return len(msg), false + return len(msg), false, false } i := uint16(msg[off])<<8 | uint16(msg[off+1]) fv.Set(uint64(i)) off += 2 case reflect.Uint32: if off+4 > len(msg) { - return len(msg), false + return len(msg), false, false } i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) fv.Set(uint64(i)) @@ -378,7 +385,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, switch f.Tag { default: fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) - return len(msg), false + return len(msg), false, false case "hex": // Rest of the RR is hex encoded rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get()) @@ -414,13 +421,13 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, s = string(b64) off += rdlength - consumed case "domain-name": - s, off, ok = unpackDomainName(msg, off) + s, off, ok, edns = unpackDomainName(msg, off) if !ok { - return len(msg), false + return len(msg), false, false } case "": if off >= len(msg) || off+1+int(msg[off]) > len(msg) { - return len(msg), false + return len(msg), false, false } n := int(msg[off]) off++ @@ -434,12 +441,12 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, fv.Set(s) } } - return off, true + return off, true, edns } -func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { - off, ok = unpackStructValue(structValue(any), msg, off) - return off, ok +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok, edns bool) { + off, ok, edns = unpackStructValue(structValue(any), msg, off) + return off, ok, edns } // THIS can GO TODO @@ -493,10 +500,12 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) { // unpack just the header, to find the rr type and length // check if we have an edns packet, and set h.Edns to true var h RR_Header + var edns bool off0 := off - if off, ok = unpackStruct(&h, msg, off); !ok { + if off, ok, edns = unpackStruct(&h, msg, off); !ok { return nil, len(msg), false } + h.Edns = edns // set Edns if found end := off + int(h.Rdlength) // make an rr of that type and re-unpack. @@ -507,7 +516,7 @@ func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) { } rr = mk() - off, ok = unpackStruct(rr, msg, off0) + off, ok, _ = unpackStruct(rr, msg, off0) // don't care about edns? if off != end { // added MG // println("Hier gaat het dan fout, echt waar en was if off0", off0) @@ -658,7 +667,7 @@ func (dns *Msg) Unpack(msg []byte) bool { var dh Header off := 0 var ok bool - if off, ok = unpackStruct(&dh, msg, off); !ok { + if off, ok, _ = unpackStruct(&dh, msg, off); !ok { return false } dns.Id = dh.Id @@ -677,7 +686,7 @@ func (dns *Msg) Unpack(msg []byte) bool { dns.Extra = make([]RR, dh.Arcount) for i := 0; i < len(dns.Question); i++ { - off, ok = unpackStruct(&dns.Question[i], msg, off) + off, ok, _ = unpackStruct(&dns.Question[i], msg, off) } for i := 0; i < len(dns.Answer); i++ { dns.Answer[i], off, ok = unpackRR(msg, off)