diff --git a/msg.go b/msg.go index 3b29cf8c..886e3440 100644 --- a/msg.go +++ b/msg.go @@ -254,7 +254,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o case *reflect.SliceValue: switch f.Tag { default: - fmt.Fprintf(os.Stderr, "dns: unknown IP tag %v\n", f.Tag) + fmt.Fprintf(os.Stderr, "dns: unknown packing slice tag %v\n", f.Tag) return len(msg), false case "OPT": // edns for j := 0; j < val.Field(i).(*reflect.SliceValue).Len(); j++ { @@ -263,6 +263,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o // for each code we should do something else h, e := hex.DecodeString(string(element.(*reflect.StructValue).Field(1).(*reflect.StringValue).Get())) if e != nil { + fmt.Fprintf(os.Stderr, "dns: failure packing OTP") return len(msg), false } data := string(h) @@ -319,12 +320,14 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o goto BadType case reflect.Uint8: if off+1 > len(msg) { + fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") return len(msg), false } msg[off] = byte(i) off++ case reflect.Uint16: if off+2 > len(msg) { + fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") return len(msg), false } msg[off] = byte(i >> 8) @@ -332,6 +335,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o off += 2 case reflect.Uint32: if off+4 > len(msg) { + fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") return len(msg), false } msg[off] = byte(i >> 24) @@ -342,6 +346,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o case reflect.Uint64: // Only used in TSIG, where it stops as 48 bits, discard the upper 16 if off+6 > len(msg) { + fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") return len(msg), false } msg[off] = byte(i >> 40) @@ -358,29 +363,33 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o s := fv.Get() switch f.Tag { default: + fmt.Fprintf(os.Stderr, "dns: unknown packing string tag %v", f.Tag) return len(msg), false case "base64": // TODO(mg) use the Len as return from the conversion (not used right now) b64len := base64.StdEncoding.DecodedLen(len(s)) _, err := base64.StdEncoding.Decode(msg[off:off+b64len], []byte(s)) if err != nil { + fmt.Fprintf(os.Stderr, "dns: overflow packing base64") return len(msg), false } off += b64len case "domain-name": off, ok = packDomainName(s, msg, off) if !ok { + fmt.Fprintf(os.Stderr, "dns: overflow packing domain-name") return len(msg), false } case "hex": // There is no length encoded here, for DS at least h, e := hex.DecodeString(s) if e != nil { + fmt.Fprintf(os.Stderr, "dns: overflow packing domain-name") return len(msg), false } copy(msg[off:off+hex.DecodedLen(len(s))], h) off += hex.DecodedLen(len(s)) - case "fixed-sized": + case "fixed-size": // the size is already encoded in the RR, we can safely use the // length of string. String is RAW (not encoded in hex, nor base64) copy(msg[off:off+len(s)], s) @@ -388,6 +397,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o case "": // Counted string: 1 byte length. if len(s) > 255 || off+1+len(s) > len(msg) { + fmt.Fprintf(os.Stderr, "dns: overflow packing string") return len(msg), false } msg[off] = byte(len(s)) @@ -424,7 +434,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, case *reflect.SliceValue: switch f.Tag { default: - fmt.Fprintf(os.Stderr, "dns: unknown IP tag %v", f.Tag) + fmt.Fprintf(os.Stderr, "dns: unknown unpacking slice tag %v", f.Tag) return len(msg), false case "A": if off+net.IPv4len > len(msg) { @@ -561,7 +571,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, var s string switch f.Tag { default: - fmt.Fprintf(os.Stderr, "dns: unknown string tag %v", f.Tag) + fmt.Fprintf(os.Stderr, "dns: unknown unpacking string tag %v", f.Tag) return len(msg), false case "hex": // Rest of the RR is hex encoded, network order an issue here? diff --git a/tsig.go b/tsig.go index 88e0872e..ebffbec5 100644 --- a/tsig.go +++ b/tsig.go @@ -4,6 +4,7 @@ package dns import ( "io" + "fmt" "encoding/base64" "strconv" "strings" @@ -36,7 +37,6 @@ func (rr *RR_TSIG) Header() *RR_Header { func (rr *RR_TSIG) String() string { // It has no official presentation format - println("mac len: ", rr.MACSize) return rr.Hdr.String() + " " + rr.Algorithm + " " + tsigTimeToDate(rr.TimeSigned) + @@ -76,6 +76,7 @@ func (rr *RR_TSIG) Generate(msg *Msg, secret string) bool { return false } rawsecret = rawsecret[:n] + buf, ok := tsigToBuf(rr, msg) if !ok { return false @@ -103,19 +104,38 @@ func (rr *RR_TSIG) Verify(msg *Msg, secret string) bool { if err != nil { return false } - // kill the last rr - copy msg TODO(mg) rawsecret = rawsecret[:n] - buf, ok := tsigToBuf(rr, msg) + + msg2 := msg // TODO deep copy TODO(mg) + if len(msg2.Extra) < 1 { + // nothing in additional + return false + } + tsigrr := msg2.Extra[len(msg2.Extra)-1] + if tsigrr.Header().Rrtype != TypeTSIG { + // not a tsig RR + return false + } + msg2.MsgHdr.Id = rr.OrigId + msg2.Extra = msg2.Extra[:len(msg2.Extra)-1] + // TODO(mg) + fmt.Printf("%v\n", msg2) + // msg2 + buf1, _ := msg2.Pack() + + buf, ok := tsigToBuf(rr, msg2) if !ok { return false } + hmac1 := hmac.NewMD5([]byte(rawsecret)) + io.WriteString(hmac1, string(buf1)) + fmt.Printf("%X\n", hmac1.Sum()) hmac := hmac.NewMD5([]byte(rawsecret)) io.WriteString(hmac, string(buf)) - rr.MAC = string(hmac.Sum()) - rr.MACSize = uint16(len(rr.MAC)) - rr.OrigId = msg.MsgHdr.Id - return true + fmt.Printf("%X\n", hmac.Sum()) + + return false } func tsigToBuf(rr *RR_TSIG, msg *Msg) ([]byte, bool) { @@ -136,7 +156,6 @@ func tsigToBuf(rr *RR_TSIG, msg *Msg) ([]byte, bool) { return nil, false } buf = buf[:n] - msgbuf, ok := msg.Pack() if !ok { return nil, false