kill a packStruct

This commit is contained in:
Miek Gieben 2011-08-08 14:09:05 +02:00
parent 472935c3d2
commit 982f277a44
3 changed files with 106 additions and 82 deletions

View File

@ -19,6 +19,7 @@ GOFILES=\
nsec3.go \ nsec3.go \
qnamestring.go\ qnamestring.go\
server.go \ server.go \
rawmsg.go \
tsig.go\ tsig.go\
types.go\ types.go\
xfr.go\ xfr.go\

162
msg.go
View File

@ -47,7 +47,7 @@ var (
ErrXfrSoa os.Error = &Error{Error: "no SOA seen"} ErrXfrSoa os.Error = &Error{Error: "no SOA seen"}
ErrHandle os.Error = &Error{Error: "handle is nil"} ErrHandle os.Error = &Error{Error: "handle is nil"}
ErrChan os.Error = &Error{Error: "channel is nil"} ErrChan os.Error = &Error{Error: "channel is nil"}
ErrName os.Error = &Error{Error: "type not found for name"} ErrName os.Error = &Error{Error: "type not found for name"}
) )
// A manually-unpacked version of (id, bits). // A manually-unpacked version of (id, bits).
@ -360,47 +360,47 @@ func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool)
} }
case reflect.Struct: case reflect.Struct:
off, ok = packStructValue(fv, msg, off) off, ok = packStructValue(fv, msg, off)
case reflect.Uint8: case reflect.Uint8:
if off+1 > len(msg) { if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint8") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint8")
return len(msg), false return len(msg), false
} }
msg[off] = byte(fv.Uint()) msg[off] = byte(fv.Uint())
off++ off++
case reflect.Uint16: case reflect.Uint16:
if off+2 > len(msg) { if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint16") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint16")
return len(msg), false return len(msg), false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 8) msg[off] = byte(i >> 8)
msg[off+1] = byte(i) msg[off+1] = byte(i)
off += 2 off += 2
case reflect.Uint32: case reflect.Uint32:
if off+4 > len(msg) { if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint32") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint32")
return len(msg), false return len(msg), false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 24) msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16) msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8) msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i) msg[off+3] = byte(i)
off += 4 off += 4
case reflect.Uint64: case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16 // Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > len(msg) { if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow packing uint64") //fmt.Fprintf(os.Stderr, "dns: overflow packing uint64")
return len(msg), false return len(msg), false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 40) msg[off] = byte(i >> 40)
msg[off+1] = byte(i >> 32) msg[off+1] = byte(i >> 32)
msg[off+2] = byte(i >> 24) msg[off+2] = byte(i >> 24)
msg[off+3] = byte(i >> 16) msg[off+3] = byte(i >> 16)
msg[off+4] = byte(i >> 8) msg[off+4] = byte(i >> 8)
msg[off+5] = byte(i) msg[off+5] = byte(i)
off += 6 off += 6
case reflect.String: case reflect.String:
// There are multiple string encodings. // There are multiple string encodings.
// The tag distinguishes ordinary strings from domain names. // The tag distinguishes ordinary strings from domain names.
@ -591,41 +591,41 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
} }
case reflect.Struct: case reflect.Struct:
off, ok = unpackStructValue(fv, msg, off) off, ok = unpackStructValue(fv, msg, off)
case reflect.Uint8: case reflect.Uint8:
if off+1 > len(msg) { if off+1 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8") //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint8")
return len(msg), false return len(msg), false
} }
i := uint8(msg[off]) i := uint8(msg[off])
fv.SetUint(uint64(i)) fv.SetUint(uint64(i))
off++ off++
case reflect.Uint16: case reflect.Uint16:
var i uint16 var i uint16
if off+2 > len(msg) { if off+2 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16") //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint16")
return len(msg), false return len(msg), false
} }
i, off = unpackUint16(msg, off) i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i)) fv.SetUint(uint64(i))
case reflect.Uint32: case reflect.Uint32:
if off+4 > len(msg) { if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32") //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32")
return len(msg), false return len(msg), false
} }
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
fv.SetUint(uint64(i)) fv.SetUint(uint64(i))
off += 4 off += 4
case reflect.Uint64: case reflect.Uint64:
// This is *only* used in TSIG where the last 48 bits are occupied // This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes) // So for now, assume a uint48 (6 bytes)
if off+6 > len(msg) { if off+6 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64") //fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint64")
return len(msg), false return len(msg), false
} }
i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5]) uint64(msg[off+4])<<8 | uint64(msg[off+5])
fv.SetUint(uint64(i)) fv.SetUint(uint64(i))
off += 6 off += 6
case reflect.String: case reflect.String:
var s string var s string
switch f.Tag { switch f.Tag {
@ -829,11 +829,9 @@ func packRR(rr RR, msg []byte, off int) (off2 int, ok bool) {
if !ok { if !ok {
return len(msg), false return len(msg), false
} }
if !RawSetRdlength(msg, uint16(off2-off1)) {
// TODO make this quicker return len(msg), false
// pack a third time; redo header with correct data length }
rr.Header().Rdlength = uint16(off2 - off1)
packStruct(rr.Header(), msg, off)
return off2, true return off2, true
} }

25
rawmsg.go Normal file
View File

@ -0,0 +1,25 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
/* Function defined in this subpackage work on []byte and but still
* provide some higher level functions
*/
// SetRdlength sets the length of the length of the rdata
// directly at the correct position in the buffer buf.
// If buf does not look like a DNS message false is returned,
// otherwise true.
func RawSetRdlength(buf []byte, i uint16) bool {
var off int
var ok bool
if _, off, ok = unpackDomainName(buf, 0); !ok {
return false
}
// off + type(2) + class(2) + ttl(4) -> rdlength
buf[off+2+2+4], buf[off+2+2+4+1] = packUint16(i)
return true
}