kill a packStruct

This commit is contained in:
Miek Gieben 2011-08-08 14:09:05 +02:00
parent 472935c3d2
commit 982f277a44
3 changed files with 106 additions and 82 deletions

View File

@ -19,6 +19,7 @@ GOFILES=\
nsec3.go \
qnamestring.go\
server.go \
rawmsg.go \
tsig.go\
types.go\
xfr.go\

162
msg.go
View File

@ -47,7 +47,7 @@ var (
ErrXfrSoa os.Error = &Error{Error: "no SOA seen"}
ErrHandle os.Error = &Error{Error: "handle is nil"}
ErrChan os.Error = &Error{Error: "channel is nil"}
ErrName os.Error = &Error{Error: "type not found for name"}
ErrName os.Error = &Error{Error: "type not found for name"}
)
// A manually-unpacked version of (id, bits).
@ -360,47 +360,47 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool)
}
case reflect.Struct:
off, ok = packStructValue(fv, msg, off)
case reflect.Uint8:
if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint8")
return len(msg), false
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint16")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 8)
msg[off+1] = byte(i)
off += 2
case reflect.Uint32:
if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint32")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i)
off += 4
case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint64")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 40)
msg[off+1] = byte(i >> 32)
msg[off+2] = byte(i >> 24)
msg[off+3] = byte(i >> 16)
msg[off+4] = byte(i >> 8)
msg[off+5] = byte(i)
off += 6
case reflect.Uint8:
if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint8")
return len(msg), false
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint16")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 8)
msg[off+1] = byte(i)
off += 2
case reflect.Uint32:
if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint32")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i)
off += 4
case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint64")
return len(msg), false
}
i := fv.Uint()
msg[off] = byte(i >> 40)
msg[off+1] = byte(i >> 32)
msg[off+2] = byte(i >> 24)
msg[off+3] = byte(i >> 16)
msg[off+4] = byte(i >> 8)
msg[off+5] = byte(i)
off += 6
case reflect.String:
// There are multiple string encodings.
// The tag distinguishes ordinary strings from domain names.
@ -591,41 +591,41 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
case reflect.Struct:
off, ok = unpackStructValue(fv, msg, off)
case reflect.Uint8:
if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8")
return len(msg), false
}
i := uint8(msg[off])
fv.SetUint(uint64(i))
off++
case reflect.Uint16:
var i uint16
if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16")
return len(msg), false
}
i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32")
return len(msg), false
}
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
fv.SetUint(uint64(i))
off += 4
case reflect.Uint64:
// This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes)
if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64")
return len(msg), false
}
i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])
fv.SetUint(uint64(i))
off += 6
case reflect.Uint8:
if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8")
return len(msg), false
}
i := uint8(msg[off])
fv.SetUint(uint64(i))
off++
case reflect.Uint16:
var i uint16
if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16")
return len(msg), false
}
i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32")
return len(msg), false
}
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
fv.SetUint(uint64(i))
off += 4
case reflect.Uint64:
// This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes)
if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64")
return len(msg), false
}
i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])
fv.SetUint(uint64(i))
off += 6
case reflect.String:
var s string
switch f.Tag {
@ -829,11 +829,9 @@ func packRR(rr RR, msg []byte, off int) (off2 int, ok bool) {
if !ok {
return len(msg), false
}
// TODO make this quicker
// pack a third time; redo header with correct data length
rr.Header().Rdlength = uint16(off2 - off1)
packStruct(rr.Header(), msg, off)
if !RawSetRdlength(msg, uint16(off2-off1)) {
return len(msg), false
}
return off2, true
}

25
rawmsg.go Normal file
View File

@ -0,0 +1,25 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
/* Function defined in this subpackage work on []byte and but still
* provide some higher level functions
*/
// SetRdlength sets the length of the length of the rdata
// directly at the correct position in the buffer buf.
// If buf does not look like a DNS message false is returned,
// otherwise true.
func RawSetRdlength(buf []byte, i uint16) bool {
var off int
var ok bool
if _, off, ok = unpackDomainName(buf, 0); !ok {
return false
}
// off + type(2) + class(2) + ttl(4) -> rdlength
buf[off+2+2+4], buf[off+2+2+4+1] = packUint16(i)
return true
}