diff --git a/Makefile b/Makefile index ca9c3f25..9a1786ff 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,7 @@ GOFILES=\ nsec3.go \ qnamestring.go\ server.go \ + rawmsg.go \ tsig.go\ types.go\ xfr.go\ diff --git a/msg.go b/msg.go index d5b288cf..a7f6cf52 100644 --- a/msg.go +++ b/msg.go @@ -47,7 +47,7 @@ var ( ErrXfrSoa os.Error = &Error{Error: "no SOA seen"} ErrHandle os.Error = &Error{Error: "handle is nil"} ErrChan os.Error = &Error{Error: "channel is nil"} - ErrName os.Error = &Error{Error: "type not found for name"} + ErrName os.Error = &Error{Error: "type not found for name"} ) // A manually-unpacked version of (id, bits). @@ -360,47 +360,47 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) } case reflect.Struct: off, ok = packStructValue(fv, msg, off) - case reflect.Uint8: - if off+1 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") - return len(msg), false - } - msg[off] = byte(fv.Uint()) - off++ - case reflect.Uint16: - if off+2 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") - return len(msg), false - } - i := fv.Uint() - msg[off] = byte(i >> 8) - msg[off+1] = byte(i) - off += 2 - case reflect.Uint32: - if off+4 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") - return len(msg), false - } - i := fv.Uint() - msg[off] = byte(i >> 24) - msg[off+1] = byte(i >> 16) - msg[off+2] = byte(i >> 8) - msg[off+3] = byte(i) - off += 4 - case reflect.Uint64: - // Only used in TSIG, where it stops at 48 bits, so we discard the upper 16 - if off+6 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") - return len(msg), false - } - i := fv.Uint() - msg[off] = byte(i >> 40) - msg[off+1] = byte(i >> 32) - msg[off+2] = byte(i >> 24) - msg[off+3] = byte(i >> 16) - msg[off+4] = byte(i >> 8) - msg[off+5] = byte(i) - off += 6 + case reflect.Uint8: + if off+1 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") + return len(msg), false + } + msg[off] = byte(fv.Uint()) + off++ + case reflect.Uint16: + if off+2 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + case reflect.Uint64: + // Only used in TSIG, where it stops at 48 bits, so we discard the upper 16 + if off+6 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") + return len(msg), false + } + i := fv.Uint() + msg[off] = byte(i >> 40) + msg[off+1] = byte(i >> 32) + msg[off+2] = byte(i >> 24) + msg[off+3] = byte(i >> 16) + msg[off+4] = byte(i >> 8) + msg[off+5] = byte(i) + off += 6 case reflect.String: // There are multiple string encodings. // The tag distinguishes ordinary strings from domain names. @@ -591,41 +591,41 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo } case reflect.Struct: off, ok = unpackStructValue(fv, msg, off) - case reflect.Uint8: - if off+1 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") - return len(msg), false - } - i := uint8(msg[off]) - fv.SetUint(uint64(i)) - off++ - case reflect.Uint16: - var i uint16 - if off+2 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") - return len(msg), false - } - i, off = unpackUint16(msg, off) - fv.SetUint(uint64(i)) - case reflect.Uint32: - if off+4 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") - return len(msg), false - } - i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) - fv.SetUint(uint64(i)) - off += 4 - case reflect.Uint64: - // This is *only* used in TSIG where the last 48 bits are occupied - // So for now, assume a uint48 (6 bytes) - if off+6 > len(msg) { - //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") - return len(msg), false - } - i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | - uint64(msg[off+4])<<8 | uint64(msg[off+5]) - fv.SetUint(uint64(i)) - off += 6 + case reflect.Uint8: + if off+1 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") + return len(msg), false + } + i := uint8(msg[off]) + fv.SetUint(uint64(i)) + off++ + case reflect.Uint16: + var i uint16 + if off+2 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") + return len(msg), false + } + i, off = unpackUint16(msg, off) + fv.SetUint(uint64(i)) + case reflect.Uint32: + if off+4 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.SetUint(uint64(i)) + off += 4 + case reflect.Uint64: + // This is *only* used in TSIG where the last 48 bits are occupied + // So for now, assume a uint48 (6 bytes) + if off+6 > len(msg) { + //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") + return len(msg), false + } + i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | + uint64(msg[off+4])<<8 | uint64(msg[off+5]) + fv.SetUint(uint64(i)) + off += 6 case reflect.String: var s string switch f.Tag { @@ -829,11 +829,9 @@ func packRR(rr RR, msg []byte, off int) (off2 int, ok bool) { if !ok { return len(msg), false } - - // TODO make this quicker - // pack a third time; redo header with correct data length - rr.Header().Rdlength = uint16(off2 - off1) - packStruct(rr.Header(), msg, off) + if !RawSetRdlength(msg, uint16(off2-off1)) { + return len(msg), false + } return off2, true } diff --git a/rawmsg.go b/rawmsg.go new file mode 100644 index 00000000..f1b53440 --- /dev/null +++ b/rawmsg.go @@ -0,0 +1,25 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dns + +/* Function defined in this subpackage work on []byte and but still + * provide some higher level functions + */ + + +// SetRdlength sets the length of the length of the rdata +// directly at the correct position in the buffer buf. +// If buf does not look like a DNS message false is returned, +// otherwise true. +func RawSetRdlength(buf []byte, i uint16) bool { + var off int + var ok bool + if _, off, ok = unpackDomainName(buf, 0); !ok { + return false + } + // off + type(2) + class(2) + ttl(4) -> rdlength + buf[off+2+2+4], buf[off+2+2+4+1] = packUint16(i) + return true +}