Use proper error in packing and unpacking

All the relevant functions now return an error instead of
a simple boolean. This greatly approves the feedback to coders.

Spotted some fishy error handling along the way and fix that too.
This commit is contained in:
Miek Gieben 2012-10-09 21:17:54 +02:00
parent 099c19d5b2
commit 570bf8dc69
6 changed files with 211 additions and 252 deletions

View File

@ -101,9 +101,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var n int
var w *reply
out, ok := m.Pack()
if !ok {
return nil, 0, ErrPack
out, err := m.Pack()
if err != nil {
return nil, 0, err
}
var in []byte
switch c.Net {
@ -123,8 +123,8 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
}
r = new(Msg)
r.Size = n
if ok := r.Unpack(in[:n]); !ok {
return nil, w.rtt, ErrUnpack
if err := r.Unpack(in[:n]); err != nil {
return nil, w.rtt, err
}
return r, w.rtt, nil
}
@ -158,8 +158,8 @@ func (w *reply) receive() (*Msg, error) {
return nil, err
}
p = p[:n]
if ok := m.Unpack(p); !ok {
return nil, ErrUnpack
if err := m.Unpack(p); err != nil {
return nil, err
}
w.rtt = time.Since(w.t)
m.Size = n
@ -260,10 +260,9 @@ func (w *reply) send(m *Msg) (err error) {
}
w.tsigRequestMAC = mac
} else {
ok := false
out, ok = m.Pack()
if !ok {
return ErrPack
out, err = m.Pack()
if err != nil {
return err
}
}
w.t = time.Now()

View File

@ -119,8 +119,8 @@ func (k *RR_DNSKEY) KeyTag() uint16 {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(keywire, wire, 0)
if !ok {
n, err := PackStruct(keywire, wire, 0)
if err != nil {
return 0
}
wire = wire[:n]
@ -157,15 +157,15 @@ func (k *RR_DNSKEY) ToDS(h int) *RR_DS {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(keywire, wire, 0)
if !ok {
n, err := PackStruct(keywire, wire, 0)
if err != nil {
return nil
}
wire = wire[:n]
owner := make([]byte, 255)
off, ok1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
if !ok1 {
off, err1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
if err1 != nil {
return nil
}
owner = owner[:off]
@ -237,9 +237,9 @@ func (rr *RR_RRSIG) Sign(k PrivateKey, rrset []RR) error {
// Create the desired binary blob
signdata := make([]byte, DefaultMsgSize)
n, ok := PackStruct(sigwire, signdata, 0)
if !ok {
return ErrPack
n, err := PackStruct(sigwire, signdata, 0)
if err != nil {
return err
}
signdata = signdata[:n]
wire := rawSignatureData(rrset, rr)
@ -349,9 +349,9 @@ func (rr *RR_RRSIG) Verify(k *RR_DNSKEY, rrset []RR) error {
sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob
signeddata := make([]byte, DefaultMsgSize)
n, ok := PackStruct(sigwire, signeddata, 0)
if !ok {
return ErrPack
n, err := PackStruct(sigwire, signeddata, 0)
if err != nil {
return err
}
signeddata = signeddata[:n]
wire := rawSignatureData(rrset, rr)
@ -684,8 +684,8 @@ func rawSignatureData(rrset []RR, s *RR_RRSIG) (buf []byte) {
}
// 6.2. Canonical RR Form. (5) - origTTL
wire := make([]byte, r.Len()*2)
off, ok1 := PackRR(r1, wire, 0, nil, false)
if !ok1 {
off, err1 := PackRR(r1, wire, 0, nil, false)
if err1 != nil {
return nil
}
wire = wire[:off]

352
msg.go
View File

@ -28,8 +28,10 @@ const maxCompressionOffset = 2 << 13 // We have 14 bits for the compression poin
var (
ErrUnpack error = &Error{Err: "unpacking failed"}
ErrPack error = &Error{Err: "packing failed"}
ErrFqdn error = &Error{Err: "domain must be fully qualified"}
ErrId error = &Error{Err: "id mismatch"}
ErrBuf error = &Error{Err: "buffer size too large"}
ErrRdata error = &Error{Err: "bad rdata"}
ErrBuf error = &Error{Err: "buffer size too small"}
ErrShortRead error = &Error{Err: "short read"}
ErrConn error = &Error{Err: "conn holds both UDP and TCP connection"}
ErrConnEmpty error = &Error{Err: "conn has no connection"}
@ -46,10 +48,7 @@ var (
ErrSigGen error = &Error{Err: "bad signature generation"}
ErrAuth error = &Error{Err: "bad authentication"}
ErrSoa error = &Error{Err: "no SOA"}
ErrHandle error = &Error{Err: "handle is nil"}
ErrChan error = &Error{Err: "channel is nil"}
ErrName error = &Error{Err: "type not found for name"}
ErrRRset error = &Error{Err: "invalid rrset"}
ErrRRset error = &Error{Err: "bad rrset"}
ErrDenialNsec3 error = &Error{Err: "no NSEC3 records"}
ErrDenialCe error = &Error{Err: "no matching closest encloser found"}
ErrDenialNc error = &Error{Err: "no covering NSEC3 found for next closer"}
@ -200,13 +199,12 @@ var Rcode_str = map[int]string{
// If compression is wanted compress must be true and the compression
// map needs to hold a mapping between domain names and offsets
// pointing into msg[].
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
// Add trailing dot to canonicalize name.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
lenmsg := len(msg)
ls := len(s)
// If not fully qualified, error out
if ls == 0 || s[ls-1] != '.' {
//println("dns: name not fully qualified")
return lenmsg, false
return lenmsg, ErrFqdn
}
// Each dot ends a segment of the name.
@ -234,30 +232,22 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
if bs[i] == '.' {
if i-begin >= 1<<6 { // top two bits of length must be clear
return lenmsg, false
return lenmsg, ErrRdata
}
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
if off+1 > len(msg) {
return lenmsg, false
return lenmsg, ErrBuf
}
msg[off] = byte(i - begin)
offset := off
off++
// TODO(mg): because of the new check above, this can go. But
// just leave it as is for the moment.
// if off > lenmsg {
// return lenmsg, false
// }
for j := begin; j < i; j++ {
if off+1 > len(msg) {
return lenmsg, false
return lenmsg, ErrBuf
}
msg[off] = bs[j]
off++
// if off > lenmsg {
// return lenmsg, false
// }
}
// Dont try to compress '.'
if compression != nil && string(bs[begin:]) != ".'" {
@ -285,7 +275,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
}
// Root label is special
if string(bs) == "." {
return off, true
return off, nil
}
// If we did compression and we find something at the pointer here
if pointer != -1 {
@ -297,7 +287,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
msg[off] = 0
End:
off++
return off, true
return off, nil
}
// Unpack a domain name.
@ -315,14 +305,14 @@ End:
// We let them jump anywhere and stop jumping after a while.
// UnpackDomainName unpacks a domain name into a string.
func UnpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
func UnpackDomainName(msg []byte, off int) (s string, off1 int, err error) {
s = ""
lenmsg := len(msg)
ptr := 0 // number of pointers followed
Loop:
for {
if off >= lenmsg {
return "", lenmsg, false
return "", lenmsg, ErrBuf
}
c := int(msg[off])
off++
@ -331,13 +321,13 @@ Loop:
if c == 0x00 {
// end of name
if s == "" {
return ".", off, true
return ".", off, nil
}
break Loop
}
// literal string
if off+c > lenmsg {
return "", lenmsg, false
return "", lenmsg, ErrBuf
}
for j := off; j < off+c; j++ {
if msg[j] == '.' {
@ -356,7 +346,7 @@ Loop:
// also, don't follow too many pointers --
// maybe there's a loop.
if off >= lenmsg {
return "", lenmsg, false
return "", lenmsg, ErrBuf
}
c1 := msg[off]
off++
@ -364,41 +354,38 @@ Loop:
off1 = off
}
if ptr++; ptr > 10 {
return "", lenmsg, false
return "", lenmsg, &Error{Err: "too many compression pointers"}
}
off = (c^0xC0)<<8 | int(c1)
default:
// 0x80 and 0x40 are reserved
return "", lenmsg, false
return "", lenmsg, ErrRdata
}
}
if ptr == 0 {
off1 = off
}
return s, off1, true
return s, off1, nil
}
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
// slices and other (often anonymous) structs.
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
// f := val.Type().Field(i)
lenmsg := len(msg)
switch fv := val.Field(i); fv.Kind() {
default:
return lenmsg, false
return lenmsg, &Error{Err: "bad kind packing"}
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
// println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"')
return lenmsg, false
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing slice"}
case "domain-name":
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String()
off, ok = PackDomainName(element, msg, off, compression, false && compress)
if !ok {
// println("dns: overflow packing domain-name", off)
return lenmsg, false
off, err = PackDomainName(element, msg, off, compression, false && compress)
if err != nil {
return lenmsg, err
}
}
case "txt":
@ -406,8 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).String()
// Counted string: 1 byte length.
if len(element) > 255 || off+1+len(element) > lenmsg {
// println("dns: overflow packing TXT string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing txt"}
}
msg[off] = byte(len(element))
off++
@ -421,8 +407,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).Interface()
b, e := element.(EDNS0).pack()
if e != nil {
// println("dns: failure packing OPT")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing opt"}
}
// Option code
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
@ -436,22 +421,17 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "a":
// It must be a slice of 4, even if it is 16, we encode
// only the first 4
if off+net.IPv4len > lenmsg {
return lenmsg, &Error{Err: "overflow packing a"}
}
switch fv.Len() {
case net.IPv6len:
if off+net.IPv4len > lenmsg {
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(12).Uint())
msg[off+1] = byte(fv.Index(13).Uint())
msg[off+2] = byte(fv.Index(14).Uint())
msg[off+3] = byte(fv.Index(15).Uint())
off += net.IPv4len
case net.IPv4len:
if off+net.IPv4len > lenmsg {
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(0).Uint())
msg[off+1] = byte(fv.Index(1).Uint())
msg[off+2] = byte(fv.Index(2).Uint())
@ -460,13 +440,11 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case 0:
// Allowed, for dynamic updates
default:
// println("dns: overflow packing A")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing a"}
}
case "aaaa":
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
// println("dns: overflow packing AAAA")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing aaaa"}
}
for j := 0; j < net.IPv6len; j++ {
msg[off] = byte(fv.Index(j).Uint())
@ -481,8 +459,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
serv := uint16((fv.Index(j).Uint()))
bitmapbyte = uint16(serv / 8)
if int(bitmapbyte) > lenmsg {
// println("dns: overflow packing WKS")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing wks"}
}
bit := uint16(serv) - bitmapbyte*8
msg[bitmapbyte] = byte(1 << (7 - bit))
@ -498,8 +475,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
lastwindow := uint16(0)
length := uint16(0)
if off+2 > lenmsg {
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing nsecx"}
}
for j := 0; j < val.Field(i).Len(); j++ {
t := uint16((fv.Index(j).Uint()))
@ -508,15 +484,13 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// New window, jump to the new offset
off += int(length) + 3
if off > lenmsg {
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
}
length = (t - window*256) / 8
bit := t - (window * 256) - (length * 8)
if off+2+int(length) > lenmsg {
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
// Setting the window #
@ -530,23 +504,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2 + int(length)
off++
if off > lenmsg {
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
}
case reflect.Struct:
off, ok = packStructValue(fv, msg, off, compression, compress)
off, err = packStructValue(fv, msg, off, compression, compress)
case reflect.Uint8:
if off+1 > lenmsg {
// println("dns: overflow packing uint8")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing uint8"}
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > lenmsg {
// println("dns: overflow packing uint16")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing uint16"}
}
i := fv.Uint()
msg[off] = byte(i >> 8)
@ -554,8 +525,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2
case reflect.Uint32:
if off+4 > lenmsg {
// println("dns: overflow packing uint32")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing uint32"}
}
i := fv.Uint()
msg[off] = byte(i >> 24)
@ -566,8 +536,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > lenmsg {
// println("dns: overflow packing uint64")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing uint64 as uint48"}
}
i := fv.Uint()
msg[off] = byte(i >> 40)
@ -583,24 +552,21 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
s := fv.String()
switch val.Type().Field(i).Tag.Get("dns") {
default:
return lenmsg, false
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing string"}
case "base64":
b64, err := packBase64([]byte(s))
if err != nil {
// println("dns: overflow packing base64")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing base64"}
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case "domain-name":
if off, ok = PackDomainName(s, msg, off, compression, false && compress); !ok {
// println("dns: overflow packing domain-name", off)
return lenmsg, false
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
return lenmsg, err
}
case "cdomain-name":
if off, ok = PackDomainName(s, msg, off, compression, true && compress); !ok {
// println("dns: overflow packing domain-name", off)
return lenmsg, false
if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
return lenmsg, err
}
case "size-base32":
// This is purely for NSEC3 atm, the previous byte must
@ -611,8 +577,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "base32":
b32, err := packBase32([]byte(s))
if err != nil {
// println("dns: overflow packing base32")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing base32"}
}
copy(msg[off:off+len(b32)], b32)
off += len(b32)
@ -622,12 +587,10 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// There is no length encoded here
h, e := hex.DecodeString(s)
if e != nil {
// println("dns: overflow packing (size-)hex string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing hex"}
}
if off+hex.DecodedLen(len(s)) > lenmsg {
// Overflow
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing hex"}
}
copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s))
@ -641,8 +604,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "":
// Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > lenmsg {
// println("dns: overflow packing string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow packing string"}
}
msg[off] = byte(len(s))
off++
@ -653,48 +615,44 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
}
}
}
return off, true
return off, nil
}
func structValue(any interface{}) reflect.Value {
return reflect.ValueOf(any).Elem()
}
func PackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
off, ok = packStructValue(structValue(any), msg, off, nil, false)
return off, ok
func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, err = packStructValue(structValue(any), msg, off, nil, false)
return off, err
}
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
off, ok = packStructValue(structValue(any), msg, off, compression, compress)
return off, ok
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off, err = packStructValue(structValue(any), msg, off, compression, compress)
return off, err
}
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdstart int
lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
// f := val.Type().Field(i)
lenmsg := len(msg)
switch fv := val.Field(i); fv.Kind() {
default:
// println("dns: unknown case unpacking struct")
return lenmsg, false
return lenmsg, &Error{Err: "bad kind unpacking"}
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
// println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag)
return lenmsg, false
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking slice"}
case "domain-name":
// HIP record slice of name (or none)
servers := make([]string, 0)
var s string
for off < lenmsg {
s, off, ok = UnpackDomainName(msg, off)
if !ok {
// println("dns: failure unpacking domain-name")
return lenmsg, false
s, off, err = UnpackDomainName(msg, off)
if err != nil {
return lenmsg, err
}
servers = append(servers, s)
}
@ -705,8 +663,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
Txts:
l := int(msg[off])
if off+l+1 > lenmsg {
// println("dns: failure unpacking txt strings")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking txt"}
}
txt = append(txt, string(msg[off+1:off+l+1]))
off += l + 1
@ -724,14 +681,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
break
}
edns := make([]EDNS0, 0)
// Goto to this place, when there is a goto
code := uint16(0)
code, off = unpackUint16(msg, off) // Overflow? TODO
if off+2 > lenmsg {
return lenmsg, &Error{Err: "overflow unpacking opt"}
}
code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > off+rdlength {
// println("dns: overflow unpacking OPT")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking opt"}
}
switch code {
case EDNS0NSID:
@ -746,18 +703,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
off = off1 + int(optlen)
}
fv.Set(reflect.ValueOf(edns))
// goto ??
// multiple EDNS codes?
case "a":
if off+net.IPv4len > len(msg) {
// println("dns: overflow unpacking A")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking a"}
}
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len
case "aaaa":
if off+net.IPv6len > lenmsg {
// println("dns: overflow unpacking AAAA")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking aaaa"}
}
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
@ -806,8 +761,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
endrr := rdstart + rdlength
if off+2 > lenmsg {
// println("dns: overflow unpacking NSEC")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking nsecx"}
}
nsec := make([]uint16, 0)
length := 0
@ -820,15 +774,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// A length window of zero is strange. If there
// the window should not have been specified. Bail out
// println("dns: length == 0 when unpacking NSEC")
return lenmsg, false
return lenmsg, ErrRdata
}
if length > 32 {
// println("dns: length > 32 when unpacking NSEC")
return lenmsg, false
return lenmsg, ErrRdata
}
// Walk the bytes in the window - and check the bit
// setting..
// Walk the bytes in the window - and check the bit settings...
off += 2
for j := 0; j < length; j++ {
b := msg[off+j]
@ -863,29 +815,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.Set(reflect.ValueOf(nsec))
}
case reflect.Struct:
off, ok = unpackStructValue(fv, msg, off)
off, err = unpackStructValue(fv, msg, off)
if val.Type().Field(i).Name == "Hdr" {
rdstart = off
}
case reflect.Uint8:
if off+1 > lenmsg {
// println("dns: overflow unpacking uint8")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking uint8"}
}
fv.SetUint(uint64(uint8(msg[off])))
off++
case reflect.Uint16:
var i uint16
if off+2 > lenmsg {
// println("dns: overflow unpacking uint16")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking uint16"}
}
i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > lenmsg {
// println("dns: overflow unpacking uint32")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking uint32"}
}
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
off += 4
@ -893,8 +842,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes)
if off+6 > lenmsg {
// println("dns: overflow unpacking uint64")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking uint64 as uint48"}
}
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])))
@ -903,15 +851,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
var s string
switch val.Type().Field(i).Tag.Get("dns") {
default:
// println("dns: unknown tag unpacking string")
return lenmsg, false
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking string"}
case "hex":
// Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
// println("dns: overflow when unpacking hex string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off:endrr])
off = endrr
@ -920,18 +866,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
// println("dns: failure unpacking base64")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking base64"}
}
s = unpackBase64(msg[off:endrr])
off = endrr
case "cdomain-name":
fallthrough
case "domain-name":
s, off, ok = UnpackDomainName(msg, off)
if !ok {
// println("dns: failure unpacking domain-name")
return lenmsg, false
s, off, err = UnpackDomainName(msg, off)
if err != nil {
return lenmsg, err
}
case "size-base32":
var size int
@ -944,8 +888,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
// println("dns: failure unpacking size-base32 string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking base32"}
}
s = unpackBase32(msg[off : off+size])
off += size
@ -973,8 +916,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
// println("dns: failure unpacking size-hex string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off : off+size])
off += size
@ -983,8 +925,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt:
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
// println("dns: failure unpacking txt string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking txt"}
}
n := int(msg[off])
off++
@ -998,8 +939,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
case "":
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
// println("dns: failure unpacking string")
return lenmsg, false
return lenmsg, &Error{Err: "overflow unpacking string"}
}
n := int(msg[off])
off++
@ -1011,7 +951,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.SetString(s)
}
}
return off, true
return off, nil
}
// Helper function for unpacking
@ -1021,9 +961,9 @@ func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
return
}
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
off, ok = unpackStructValue(structValue(any), msg, off)
return off, ok
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, err = unpackStructValue(structValue(any), msg, off)
return off, err
}
func unpackBase32(b []byte) string {
@ -1067,28 +1007,26 @@ func packBase32(s []byte) ([]byte, error) {
}
// Resource record packer.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
if rr == nil {
return len(msg), false
return len(msg), &Error{Err: "nil rr"}
}
off1, ok = packStructCompress(rr, msg, off, compression, compress)
if !ok {
return len(msg), false
off1, err = packStructCompress(rr, msg, off, compression, compress)
if err != nil {
return len(msg), err
}
if !rawSetRdlength(msg, off, off1) {
return len(msg), false
}
return off1, true
rawSetRdlength(msg, off, off1)
return off1, nil
}
// Resource record unpacker.
func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// unpack just the header, to find the rr type and length
var h RR_Header
off0 := off
if off, ok = UnpackStruct(&h, msg, off); !ok {
return nil, len(msg), false
if off, err = UnpackStruct(&h, msg, off); err != nil {
return nil, len(msg), err
}
end := off + int(h.Rdlength)
// make an rr of that type and re-unpack.
@ -1098,11 +1036,11 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
} else {
rr = mk()
}
off, ok = UnpackStruct(rr, msg, off0)
off, err = UnpackStruct(rr, msg, off0)
if off != end {
return &h, end, true
return &h, end, nil
}
return rr, off, ok
return rr, off, err
}
// Reverse a map
@ -1176,9 +1114,9 @@ func (h *MsgHdr) String() string {
// Pack packs a Msg: it is converted to to wire format.
// If the dns.Compress is true the message will be in compressed wire format.
func (dns *Msg) Pack() (msg []byte, ok bool) {
func (dns *Msg) Pack() (msg []byte, err error) {
if dns == nil {
return nil, false
return nil, &Error{Err: "nil message"}
}
var dh Header
compression := make(map[string]int) // Compression pointer mappings
@ -1227,34 +1165,41 @@ func (dns *Msg) Pack() (msg []byte, ok bool) {
// Pack it in: header and then the pieces.
off := 0
off, ok = packStructCompress(&dh, msg, off, compression, dns.Compress)
off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
for i := 0; i < len(question); i++ {
off, ok = packStructCompress(&question[i], msg, off, compression, dns.Compress)
off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(answer); i++ {
off, ok = PackRR(answer[i], msg, off, compression, dns.Compress)
off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(ns); i++ {
off, ok = PackRR(ns[i], msg, off, compression, dns.Compress)
off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(extra); i++ {
off, ok = PackRR(extra[i], msg, off, compression, dns.Compress)
off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
}
if !ok {
return nil, false
}
//println("allocated", dns.Len()+1, "used", off)
return msg[:off], true
return msg[:off], nil
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) bool {
func (dns *Msg) Unpack(msg []byte) (err error) {
// Header.
var dh Header
off := 0
var ok bool
if off, ok = UnpackStruct(&dh, msg, off); !ok {
return false
if off, err = UnpackStruct(&dh, msg, off); err != nil {
return err
}
dns.Id = dh.Id
dns.Response = (dh.Bits & _QR) != 0
@ -1275,25 +1220,34 @@ func (dns *Msg) Unpack(msg []byte) bool {
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
off, ok = UnpackStruct(&dns.Question[i], msg, off)
off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return err
}
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = UnpackRR(msg, off)
dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
}
}
for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, ok = UnpackRR(msg, off)
dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
}
}
for i := 0; i < len(dns.Extra); i++ {
dns.Extra[i], off, ok = UnpackRR(msg, off)
}
if !ok {
return false
dns.Extra[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
}
}
if off != len(msg) {
// TODO(mg) remove eventually
// println("extra bytes in dns packet", off, "<", len(msg))
}
return true
return nil
}
// Convert a complete message to a string with dig-like output.

View File

@ -38,14 +38,14 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt)
saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(saltwire, wire, 0)
if !ok {
n, err := PackStruct(saltwire, wire, 0)
if err != nil {
return ""
}
wire = wire[:n]
name := make([]byte, 255)
off, ok1 := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if !ok1 {
off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if err != nil {
return ""
}
name = name[:off]

View File

@ -396,7 +396,7 @@ func (c *conn) serve() {
w := new(response)
w.conn = c
req := new(Msg)
if !req.Unpack(c.request) {
if req.Unpack(c.request) != nil {
// Send a format error back
x := new(Msg)
x.SetRcodeFormatError(req)
@ -436,10 +436,7 @@ func (c *conn) serve() {
// Write implements the ResponseWriter.Write method.
func (w *response) Write(m *Msg) (err error) {
var (
data []byte
ok bool
)
var data []byte
if m == nil {
return &Error{Err: "nil message"}
}
@ -449,9 +446,9 @@ func (w *response) Write(m *Msg) (err error) {
return err
}
} else {
data, ok = m.Pack()
if !ok {
return ErrPack
data, err = m.Pack()
if err != nil {
return err
}
}
return w.WriteBuf(data)
@ -470,7 +467,7 @@ func (w *response) WriteBuf(m []byte) (err error) {
}
case w.conn._TCP != nil:
if len(m) > MaxMsgSize {
return ErrBuf
return &Error{Err: "message too large"}
}
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(m)))

39
tsig.go
View File

@ -165,9 +165,9 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
rr := m.Extra[len(m.Extra)-1].(*RR_TSIG)
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
mbuf, ok := m.Pack()
if !ok {
return nil, "", ErrPack
mbuf, err := m.Pack()
if err != nil {
return nil, "", err
}
buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
@ -194,10 +194,10 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.OrigId = m.Id
tbuf := make([]byte, t.Len())
if off, ok := PackRR(t, tbuf, 0, nil, false); ok {
if off, err := PackRR(t, tbuf, 0, nil, false); err != nil {
tbuf = tbuf[:off] // reset to actual size used
} else {
return nil, "", ErrPack
return nil, "", err
}
mbuf = append(mbuf, tbuf...)
rawSetExtraLen(mbuf, uint16(len(m.Extra)+1))
@ -298,13 +298,13 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
// Copied from msg.go's Unpack()
// Header.
var dh Header
var err error
dns := new(Msg)
rr := new(RR_TSIG)
off := 0
tsigoff := 0
var ok bool
if off, ok = UnpackStruct(&dh, msg, off); !ok {
return nil, nil, ErrUnpack
if off, err = UnpackStruct(&dh, msg, off); err !=nil {
return nil, nil, err
}
if dh.Arcount == 0 {
return nil, nil, ErrNoSig
@ -321,17 +321,29 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
off, ok = UnpackStruct(&dns.Question[i], msg, off)
off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = UnpackRR(msg, off)
dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, ok = UnpackRR(msg, off)
dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Extra); i++ {
tsigoff = off
dns.Extra[i], off, ok = UnpackRR(msg, off)
dns.Extra[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
if dns.Extra[i].Header().Rrtype == TypeTSIG {
rr = dns.Extra[i].(*RR_TSIG)
// Adjust Arcount.
@ -340,9 +352,6 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
break
}
}
if !ok {
return nil, nil, ErrUnpack
}
if rr == nil {
return nil, nil, ErrNoSig
}