Use proper error in packing and unpacking

All the relevant functions now return an error instead of
a simple boolean. This greatly approves the feedback to coders.

Spotted some fishy error handling along the way and fix that too.
This commit is contained in:
Miek Gieben 2012-10-09 21:17:54 +02:00
parent 099c19d5b2
commit 570bf8dc69
6 changed files with 211 additions and 252 deletions

View File

@ -101,9 +101,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var n int var n int
var w *reply var w *reply
out, ok := m.Pack() out, err := m.Pack()
if !ok { if err != nil {
return nil, 0, ErrPack return nil, 0, err
} }
var in []byte var in []byte
switch c.Net { switch c.Net {
@ -123,8 +123,8 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
} }
r = new(Msg) r = new(Msg)
r.Size = n r.Size = n
if ok := r.Unpack(in[:n]); !ok { if err := r.Unpack(in[:n]); err != nil {
return nil, w.rtt, ErrUnpack return nil, w.rtt, err
} }
return r, w.rtt, nil return r, w.rtt, nil
} }
@ -158,8 +158,8 @@ func (w *reply) receive() (*Msg, error) {
return nil, err return nil, err
} }
p = p[:n] p = p[:n]
if ok := m.Unpack(p); !ok { if err := m.Unpack(p); err != nil {
return nil, ErrUnpack return nil, err
} }
w.rtt = time.Since(w.t) w.rtt = time.Since(w.t)
m.Size = n m.Size = n
@ -260,10 +260,9 @@ func (w *reply) send(m *Msg) (err error) {
} }
w.tsigRequestMAC = mac w.tsigRequestMAC = mac
} else { } else {
ok := false out, err = m.Pack()
out, ok = m.Pack() if err != nil {
if !ok { return err
return ErrPack
} }
} }
w.t = time.Now() w.t = time.Now()

View File

@ -119,8 +119,8 @@ func (k *RR_DNSKEY) KeyTag() uint16 {
keywire.Algorithm = k.Algorithm keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize) wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(keywire, wire, 0) n, err := PackStruct(keywire, wire, 0)
if !ok { if err != nil {
return 0 return 0
} }
wire = wire[:n] wire = wire[:n]
@ -157,15 +157,15 @@ func (k *RR_DNSKEY) ToDS(h int) *RR_DS {
keywire.Algorithm = k.Algorithm keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize) wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(keywire, wire, 0) n, err := PackStruct(keywire, wire, 0)
if !ok { if err != nil {
return nil return nil
} }
wire = wire[:n] wire = wire[:n]
owner := make([]byte, 255) owner := make([]byte, 255)
off, ok1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false) off, err1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
if !ok1 { if err1 != nil {
return nil return nil
} }
owner = owner[:off] owner = owner[:off]
@ -237,9 +237,9 @@ func (rr *RR_RRSIG) Sign(k PrivateKey, rrset []RR) error {
// Create the desired binary blob // Create the desired binary blob
signdata := make([]byte, DefaultMsgSize) signdata := make([]byte, DefaultMsgSize)
n, ok := PackStruct(sigwire, signdata, 0) n, err := PackStruct(sigwire, signdata, 0)
if !ok { if err != nil {
return ErrPack return err
} }
signdata = signdata[:n] signdata = signdata[:n]
wire := rawSignatureData(rrset, rr) wire := rawSignatureData(rrset, rr)
@ -349,9 +349,9 @@ func (rr *RR_RRSIG) Verify(k *RR_DNSKEY, rrset []RR) error {
sigwire.SignerName = strings.ToLower(rr.SignerName) sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob // Create the desired binary blob
signeddata := make([]byte, DefaultMsgSize) signeddata := make([]byte, DefaultMsgSize)
n, ok := PackStruct(sigwire, signeddata, 0) n, err := PackStruct(sigwire, signeddata, 0)
if !ok { if err != nil {
return ErrPack return err
} }
signeddata = signeddata[:n] signeddata = signeddata[:n]
wire := rawSignatureData(rrset, rr) wire := rawSignatureData(rrset, rr)
@ -684,8 +684,8 @@ func rawSignatureData(rrset []RR, s *RR_RRSIG) (buf []byte) {
} }
// 6.2. Canonical RR Form. (5) - origTTL // 6.2. Canonical RR Form. (5) - origTTL
wire := make([]byte, r.Len()*2) wire := make([]byte, r.Len()*2)
off, ok1 := PackRR(r1, wire, 0, nil, false) off, err1 := PackRR(r1, wire, 0, nil, false)
if !ok1 { if err1 != nil {
return nil return nil
} }
wire = wire[:off] wire = wire[:off]

348
msg.go
View File

@ -28,8 +28,10 @@ const maxCompressionOffset = 2 << 13 // We have 14 bits for the compression poin
var ( var (
ErrUnpack error = &Error{Err: "unpacking failed"} ErrUnpack error = &Error{Err: "unpacking failed"}
ErrPack error = &Error{Err: "packing failed"} ErrPack error = &Error{Err: "packing failed"}
ErrFqdn error = &Error{Err: "domain must be fully qualified"}
ErrId error = &Error{Err: "id mismatch"} ErrId error = &Error{Err: "id mismatch"}
ErrBuf error = &Error{Err: "buffer size too large"} ErrRdata error = &Error{Err: "bad rdata"}
ErrBuf error = &Error{Err: "buffer size too small"}
ErrShortRead error = &Error{Err: "short read"} ErrShortRead error = &Error{Err: "short read"}
ErrConn error = &Error{Err: "conn holds both UDP and TCP connection"} ErrConn error = &Error{Err: "conn holds both UDP and TCP connection"}
ErrConnEmpty error = &Error{Err: "conn has no connection"} ErrConnEmpty error = &Error{Err: "conn has no connection"}
@ -46,10 +48,7 @@ var (
ErrSigGen error = &Error{Err: "bad signature generation"} ErrSigGen error = &Error{Err: "bad signature generation"}
ErrAuth error = &Error{Err: "bad authentication"} ErrAuth error = &Error{Err: "bad authentication"}
ErrSoa error = &Error{Err: "no SOA"} ErrSoa error = &Error{Err: "no SOA"}
ErrHandle error = &Error{Err: "handle is nil"} ErrRRset error = &Error{Err: "bad rrset"}
ErrChan error = &Error{Err: "channel is nil"}
ErrName error = &Error{Err: "type not found for name"}
ErrRRset error = &Error{Err: "invalid rrset"}
ErrDenialNsec3 error = &Error{Err: "no NSEC3 records"} ErrDenialNsec3 error = &Error{Err: "no NSEC3 records"}
ErrDenialCe error = &Error{Err: "no matching closest encloser found"} ErrDenialCe error = &Error{Err: "no matching closest encloser found"}
ErrDenialNc error = &Error{Err: "no covering NSEC3 found for next closer"} ErrDenialNc error = &Error{Err: "no covering NSEC3 found for next closer"}
@ -200,13 +199,12 @@ var Rcode_str = map[int]string{
// If compression is wanted compress must be true and the compression // If compression is wanted compress must be true and the compression
// map needs to hold a mapping between domain names and offsets // map needs to hold a mapping between domain names and offsets
// pointing into msg[]. // pointing into msg[].
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) { func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
// Add trailing dot to canonicalize name.
lenmsg := len(msg) lenmsg := len(msg)
ls := len(s) ls := len(s)
// If not fully qualified, error out
if ls == 0 || s[ls-1] != '.' { if ls == 0 || s[ls-1] != '.' {
//println("dns: name not fully qualified") return lenmsg, ErrFqdn
return lenmsg, false
} }
// Each dot ends a segment of the name. // Each dot ends a segment of the name.
@ -234,30 +232,22 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
if bs[i] == '.' { if bs[i] == '.' {
if i-begin >= 1<<6 { // top two bits of length must be clear if i-begin >= 1<<6 { // top two bits of length must be clear
return lenmsg, false return lenmsg, ErrRdata
} }
// off can already (we're in a loop) be bigger than len(msg) // off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified // this happens when a name isn't fully qualified
if off+1 > len(msg) { if off+1 > len(msg) {
return lenmsg, false return lenmsg, ErrBuf
} }
msg[off] = byte(i - begin) msg[off] = byte(i - begin)
offset := off offset := off
off++ off++
// TODO(mg): because of the new check above, this can go. But
// just leave it as is for the moment.
// if off > lenmsg {
// return lenmsg, false
// }
for j := begin; j < i; j++ { for j := begin; j < i; j++ {
if off+1 > len(msg) { if off+1 > len(msg) {
return lenmsg, false return lenmsg, ErrBuf
} }
msg[off] = bs[j] msg[off] = bs[j]
off++ off++
// if off > lenmsg {
// return lenmsg, false
// }
} }
// Dont try to compress '.' // Dont try to compress '.'
if compression != nil && string(bs[begin:]) != ".'" { if compression != nil && string(bs[begin:]) != ".'" {
@ -285,7 +275,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
} }
// Root label is special // Root label is special
if string(bs) == "." { if string(bs) == "." {
return off, true return off, nil
} }
// If we did compression and we find something at the pointer here // If we did compression and we find something at the pointer here
if pointer != -1 { if pointer != -1 {
@ -297,7 +287,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
msg[off] = 0 msg[off] = 0
End: End:
off++ off++
return off, true return off, nil
} }
// Unpack a domain name. // Unpack a domain name.
@ -315,14 +305,14 @@ End:
// We let them jump anywhere and stop jumping after a while. // We let them jump anywhere and stop jumping after a while.
// UnpackDomainName unpacks a domain name into a string. // UnpackDomainName unpacks a domain name into a string.
func UnpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { func UnpackDomainName(msg []byte, off int) (s string, off1 int, err error) {
s = "" s = ""
lenmsg := len(msg) lenmsg := len(msg)
ptr := 0 // number of pointers followed ptr := 0 // number of pointers followed
Loop: Loop:
for { for {
if off >= lenmsg { if off >= lenmsg {
return "", lenmsg, false return "", lenmsg, ErrBuf
} }
c := int(msg[off]) c := int(msg[off])
off++ off++
@ -331,13 +321,13 @@ Loop:
if c == 0x00 { if c == 0x00 {
// end of name // end of name
if s == "" { if s == "" {
return ".", off, true return ".", off, nil
} }
break Loop break Loop
} }
// literal string // literal string
if off+c > lenmsg { if off+c > lenmsg {
return "", lenmsg, false return "", lenmsg, ErrBuf
} }
for j := off; j < off+c; j++ { for j := off; j < off+c; j++ {
if msg[j] == '.' { if msg[j] == '.' {
@ -356,7 +346,7 @@ Loop:
// also, don't follow too many pointers -- // also, don't follow too many pointers --
// maybe there's a loop. // maybe there's a loop.
if off >= lenmsg { if off >= lenmsg {
return "", lenmsg, false return "", lenmsg, ErrBuf
} }
c1 := msg[off] c1 := msg[off]
off++ off++
@ -364,41 +354,38 @@ Loop:
off1 = off off1 = off
} }
if ptr++; ptr > 10 { if ptr++; ptr > 10 {
return "", lenmsg, false return "", lenmsg, &Error{Err: "too many compression pointers"}
} }
off = (c^0xC0)<<8 | int(c1) off = (c^0xC0)<<8 | int(c1)
default: default:
// 0x80 and 0x40 are reserved // 0x80 and 0x40 are reserved
return "", lenmsg, false return "", lenmsg, ErrRdata
} }
} }
if ptr == 0 { if ptr == 0 {
off1 = off off1 = off
} }
return s, off1, true return s, off1, nil
} }
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string, // Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
// slices and other (often anonymous) structs. // slices and other (often anonymous) structs.
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) { func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
for i := 0; i < val.NumField(); i++ {
// f := val.Type().Field(i)
lenmsg := len(msg) lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
switch fv := val.Field(i); fv.Kind() { switch fv := val.Field(i); fv.Kind() {
default: default:
return lenmsg, false return lenmsg, &Error{Err: "bad kind packing"}
case reflect.Slice: case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag.Get("dns") {
default: default:
// println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"') return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing slice"}
return lenmsg, false
case "domain-name": case "domain-name":
for j := 0; j < val.Field(i).Len(); j++ { for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String() element := val.Field(i).Index(j).String()
off, ok = PackDomainName(element, msg, off, compression, false && compress) off, err = PackDomainName(element, msg, off, compression, false && compress)
if !ok { if err != nil {
// println("dns: overflow packing domain-name", off) return lenmsg, err
return lenmsg, false
} }
} }
case "txt": case "txt":
@ -406,8 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).String() element := val.Field(i).Index(j).String()
// Counted string: 1 byte length. // Counted string: 1 byte length.
if len(element) > 255 || off+1+len(element) > lenmsg { if len(element) > 255 || off+1+len(element) > lenmsg {
// println("dns: overflow packing TXT string") return lenmsg, &Error{Err: "overflow packing txt"}
return lenmsg, false
} }
msg[off] = byte(len(element)) msg[off] = byte(len(element))
off++ off++
@ -421,8 +407,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).Interface() element := val.Field(i).Index(j).Interface()
b, e := element.(EDNS0).pack() b, e := element.(EDNS0).pack()
if e != nil { if e != nil {
// println("dns: failure packing OPT") return lenmsg, &Error{Err: "overflow packing opt"}
return lenmsg, false
} }
// Option code // Option code
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option()) msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
@ -436,22 +421,17 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "a": case "a":
// It must be a slice of 4, even if it is 16, we encode // It must be a slice of 4, even if it is 16, we encode
// only the first 4 // only the first 4
if off+net.IPv4len > lenmsg {
return lenmsg, &Error{Err: "overflow packing a"}
}
switch fv.Len() { switch fv.Len() {
case net.IPv6len: case net.IPv6len:
if off+net.IPv4len > lenmsg {
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(12).Uint()) msg[off] = byte(fv.Index(12).Uint())
msg[off+1] = byte(fv.Index(13).Uint()) msg[off+1] = byte(fv.Index(13).Uint())
msg[off+2] = byte(fv.Index(14).Uint()) msg[off+2] = byte(fv.Index(14).Uint())
msg[off+3] = byte(fv.Index(15).Uint()) msg[off+3] = byte(fv.Index(15).Uint())
off += net.IPv4len off += net.IPv4len
case net.IPv4len: case net.IPv4len:
if off+net.IPv4len > lenmsg {
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(0).Uint()) msg[off] = byte(fv.Index(0).Uint())
msg[off+1] = byte(fv.Index(1).Uint()) msg[off+1] = byte(fv.Index(1).Uint())
msg[off+2] = byte(fv.Index(2).Uint()) msg[off+2] = byte(fv.Index(2).Uint())
@ -460,13 +440,11 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case 0: case 0:
// Allowed, for dynamic updates // Allowed, for dynamic updates
default: default:
// println("dns: overflow packing A") return lenmsg, &Error{Err: "overflow packing a"}
return lenmsg, false
} }
case "aaaa": case "aaaa":
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg { if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
// println("dns: overflow packing AAAA") return lenmsg, &Error{Err: "overflow packing aaaa"}
return lenmsg, false
} }
for j := 0; j < net.IPv6len; j++ { for j := 0; j < net.IPv6len; j++ {
msg[off] = byte(fv.Index(j).Uint()) msg[off] = byte(fv.Index(j).Uint())
@ -481,8 +459,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
serv := uint16((fv.Index(j).Uint())) serv := uint16((fv.Index(j).Uint()))
bitmapbyte = uint16(serv / 8) bitmapbyte = uint16(serv / 8)
if int(bitmapbyte) > lenmsg { if int(bitmapbyte) > lenmsg {
// println("dns: overflow packing WKS") return lenmsg, &Error{Err: "overflow packing wks"}
return lenmsg, false
} }
bit := uint16(serv) - bitmapbyte*8 bit := uint16(serv) - bitmapbyte*8
msg[bitmapbyte] = byte(1 << (7 - bit)) msg[bitmapbyte] = byte(1 << (7 - bit))
@ -498,8 +475,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
lastwindow := uint16(0) lastwindow := uint16(0)
length := uint16(0) length := uint16(0)
if off+2 > lenmsg { if off+2 > lenmsg {
// println("dns: overflow packing NSECx bitmap") return lenmsg, &Error{Err: "overflow packing nsecx"}
return lenmsg, false
} }
for j := 0; j < val.Field(i).Len(); j++ { for j := 0; j < val.Field(i).Len(); j++ {
t := uint16((fv.Index(j).Uint())) t := uint16((fv.Index(j).Uint()))
@ -508,15 +484,13 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// New window, jump to the new offset // New window, jump to the new offset
off += int(length) + 3 off += int(length) + 3
if off > lenmsg { if off > lenmsg {
// println("dns: overflow packing NSECx bitmap") return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
return lenmsg, false
} }
} }
length = (t - window*256) / 8 length = (t - window*256) / 8
bit := t - (window * 256) - (length * 8) bit := t - (window * 256) - (length * 8)
if off+2+int(length) > lenmsg { if off+2+int(length) > lenmsg {
// println("dns: overflow packing NSECx bitmap") return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
return lenmsg, false
} }
// Setting the window # // Setting the window #
@ -530,23 +504,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2 + int(length) off += 2 + int(length)
off++ off++
if off > lenmsg { if off > lenmsg {
// println("dns: overflow packing NSECx bitmap") return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
return lenmsg, false
} }
} }
case reflect.Struct: case reflect.Struct:
off, ok = packStructValue(fv, msg, off, compression, compress) off, err = packStructValue(fv, msg, off, compression, compress)
case reflect.Uint8: case reflect.Uint8:
if off+1 > lenmsg { if off+1 > lenmsg {
// println("dns: overflow packing uint8") return lenmsg, &Error{Err: "overflow packing uint8"}
return lenmsg, false
} }
msg[off] = byte(fv.Uint()) msg[off] = byte(fv.Uint())
off++ off++
case reflect.Uint16: case reflect.Uint16:
if off+2 > lenmsg { if off+2 > lenmsg {
// println("dns: overflow packing uint16") return lenmsg, &Error{Err: "overflow packing uint16"}
return lenmsg, false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 8) msg[off] = byte(i >> 8)
@ -554,8 +525,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2 off += 2
case reflect.Uint32: case reflect.Uint32:
if off+4 > lenmsg { if off+4 > lenmsg {
// println("dns: overflow packing uint32") return lenmsg, &Error{Err: "overflow packing uint32"}
return lenmsg, false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 24) msg[off] = byte(i >> 24)
@ -566,8 +536,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case reflect.Uint64: case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16 // Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > lenmsg { if off+6 > lenmsg {
// println("dns: overflow packing uint64") return lenmsg, &Error{Err: "overflow packing uint64 as uint48"}
return lenmsg, false
} }
i := fv.Uint() i := fv.Uint()
msg[off] = byte(i >> 40) msg[off] = byte(i >> 40)
@ -583,24 +552,21 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
s := fv.String() s := fv.String()
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag.Get("dns") {
default: default:
return lenmsg, false return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing string"}
case "base64": case "base64":
b64, err := packBase64([]byte(s)) b64, err := packBase64([]byte(s))
if err != nil { if err != nil {
// println("dns: overflow packing base64") return lenmsg, &Error{Err: "overflow packing base64"}
return lenmsg, false
} }
copy(msg[off:off+len(b64)], b64) copy(msg[off:off+len(b64)], b64)
off += len(b64) off += len(b64)
case "domain-name": case "domain-name":
if off, ok = PackDomainName(s, msg, off, compression, false && compress); !ok { if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
// println("dns: overflow packing domain-name", off) return lenmsg, err
return lenmsg, false
} }
case "cdomain-name": case "cdomain-name":
if off, ok = PackDomainName(s, msg, off, compression, true && compress); !ok { if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
// println("dns: overflow packing domain-name", off) return lenmsg, err
return lenmsg, false
} }
case "size-base32": case "size-base32":
// This is purely for NSEC3 atm, the previous byte must // This is purely for NSEC3 atm, the previous byte must
@ -611,8 +577,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "base32": case "base32":
b32, err := packBase32([]byte(s)) b32, err := packBase32([]byte(s))
if err != nil { if err != nil {
// println("dns: overflow packing base32") return lenmsg, &Error{Err: "overflow packing base32"}
return lenmsg, false
} }
copy(msg[off:off+len(b32)], b32) copy(msg[off:off+len(b32)], b32)
off += len(b32) off += len(b32)
@ -622,12 +587,10 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// There is no length encoded here // There is no length encoded here
h, e := hex.DecodeString(s) h, e := hex.DecodeString(s)
if e != nil { if e != nil {
// println("dns: overflow packing (size-)hex string") return lenmsg, &Error{Err: "overflow packing hex"}
return lenmsg, false
} }
if off+hex.DecodedLen(len(s)) > lenmsg { if off+hex.DecodedLen(len(s)) > lenmsg {
// Overflow return lenmsg, &Error{Err: "overflow packing hex"}
return lenmsg, false
} }
copy(msg[off:off+hex.DecodedLen(len(s))], h) copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s)) off += hex.DecodedLen(len(s))
@ -641,8 +604,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "": case "":
// Counted string: 1 byte length. // Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > lenmsg { if len(s) > 255 || off+1+len(s) > lenmsg {
// println("dns: overflow packing string") return lenmsg, &Error{Err: "overflow packing string"}
return lenmsg, false
} }
msg[off] = byte(len(s)) msg[off] = byte(len(s))
off++ off++
@ -653,48 +615,44 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
} }
} }
} }
return off, true return off, nil
} }
func structValue(any interface{}) reflect.Value { func structValue(any interface{}) reflect.Value {
return reflect.ValueOf(any).Elem() return reflect.ValueOf(any).Elem()
} }
func PackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, ok = packStructValue(structValue(any), msg, off, nil, false) off, err = packStructValue(structValue(any), msg, off, nil, false)
return off, ok return off, err
} }
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) { func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off, ok = packStructValue(structValue(any), msg, off, compression, compress) off, err = packStructValue(structValue(any), msg, off, compression, compress)
return off, ok return off, err
} }
// Unpack a reflect.StructValue from msg. // Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue. // Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) { func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdstart int var rdstart int
for i := 0; i < val.NumField(); i++ {
// f := val.Type().Field(i)
lenmsg := len(msg) lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
switch fv := val.Field(i); fv.Kind() { switch fv := val.Field(i); fv.Kind() {
default: default:
// println("dns: unknown case unpacking struct") return lenmsg, &Error{Err: "bad kind unpacking"}
return lenmsg, false
case reflect.Slice: case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag.Get("dns") {
default: default:
// println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag) return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking slice"}
return lenmsg, false
case "domain-name": case "domain-name":
// HIP record slice of name (or none) // HIP record slice of name (or none)
servers := make([]string, 0) servers := make([]string, 0)
var s string var s string
for off < lenmsg { for off < lenmsg {
s, off, ok = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if !ok { if err != nil {
// println("dns: failure unpacking domain-name") return lenmsg, err
return lenmsg, false
} }
servers = append(servers, s) servers = append(servers, s)
} }
@ -705,8 +663,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
Txts: Txts:
l := int(msg[off]) l := int(msg[off])
if off+l+1 > lenmsg { if off+l+1 > lenmsg {
// println("dns: failure unpacking txt strings") return lenmsg, &Error{Err: "overflow unpacking txt"}
return lenmsg, false
} }
txt = append(txt, string(msg[off+1:off+l+1])) txt = append(txt, string(msg[off+1:off+l+1]))
off += l + 1 off += l + 1
@ -724,14 +681,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
break break
} }
edns := make([]EDNS0, 0) edns := make([]EDNS0, 0)
// Goto to this place, when there is a goto
code := uint16(0) code := uint16(0)
if off+2 > lenmsg {
code, off = unpackUint16(msg, off) // Overflow? TODO return lenmsg, &Error{Err: "overflow unpacking opt"}
}
code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off) optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > off+rdlength { if off1+int(optlen) > off+rdlength {
// println("dns: overflow unpacking OPT") return lenmsg, &Error{Err: "overflow unpacking opt"}
return lenmsg, false
} }
switch code { switch code {
case EDNS0NSID: case EDNS0NSID:
@ -746,18 +703,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
off = off1 + int(optlen) off = off1 + int(optlen)
} }
fv.Set(reflect.ValueOf(edns)) fv.Set(reflect.ValueOf(edns))
// goto ?? // multiple EDNS codes?
case "a": case "a":
if off+net.IPv4len > len(msg) { if off+net.IPv4len > len(msg) {
// println("dns: overflow unpacking A") return lenmsg, &Error{Err: "overflow unpacking a"}
return lenmsg, false
} }
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]))) fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len off += net.IPv4len
case "aaaa": case "aaaa":
if off+net.IPv6len > lenmsg { if off+net.IPv6len > lenmsg {
// println("dns: overflow unpacking AAAA") return lenmsg, &Error{Err: "overflow unpacking aaaa"}
return lenmsg, false
} }
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4], fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10], msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
@ -806,8 +761,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
endrr := rdstart + rdlength endrr := rdstart + rdlength
if off+2 > lenmsg { if off+2 > lenmsg {
// println("dns: overflow unpacking NSEC") return lenmsg, &Error{Err: "overflow unpacking nsecx"}
return lenmsg, false
} }
nsec := make([]uint16, 0) nsec := make([]uint16, 0)
length := 0 length := 0
@ -820,15 +774,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// A length window of zero is strange. If there // A length window of zero is strange. If there
// the window should not have been specified. Bail out // the window should not have been specified. Bail out
// println("dns: length == 0 when unpacking NSEC") // println("dns: length == 0 when unpacking NSEC")
return lenmsg, false return lenmsg, ErrRdata
} }
if length > 32 { if length > 32 {
// println("dns: length > 32 when unpacking NSEC") return lenmsg, ErrRdata
return lenmsg, false
} }
// Walk the bytes in the window - and check the bit // Walk the bytes in the window - and check the bit settings...
// setting..
off += 2 off += 2
for j := 0; j < length; j++ { for j := 0; j < length; j++ {
b := msg[off+j] b := msg[off+j]
@ -863,29 +815,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.Set(reflect.ValueOf(nsec)) fv.Set(reflect.ValueOf(nsec))
} }
case reflect.Struct: case reflect.Struct:
off, ok = unpackStructValue(fv, msg, off) off, err = unpackStructValue(fv, msg, off)
if val.Type().Field(i).Name == "Hdr" { if val.Type().Field(i).Name == "Hdr" {
rdstart = off rdstart = off
} }
case reflect.Uint8: case reflect.Uint8:
if off+1 > lenmsg { if off+1 > lenmsg {
// println("dns: overflow unpacking uint8") return lenmsg, &Error{Err: "overflow unpacking uint8"}
return lenmsg, false
} }
fv.SetUint(uint64(uint8(msg[off]))) fv.SetUint(uint64(uint8(msg[off])))
off++ off++
case reflect.Uint16: case reflect.Uint16:
var i uint16 var i uint16
if off+2 > lenmsg { if off+2 > lenmsg {
// println("dns: overflow unpacking uint16") return lenmsg, &Error{Err: "overflow unpacking uint16"}
return lenmsg, false
} }
i, off = unpackUint16(msg, off) i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i)) fv.SetUint(uint64(i))
case reflect.Uint32: case reflect.Uint32:
if off+4 > lenmsg { if off+4 > lenmsg {
// println("dns: overflow unpacking uint32") return lenmsg, &Error{Err: "overflow unpacking uint32"}
return lenmsg, false
} }
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))) fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
off += 4 off += 4
@ -893,8 +842,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// This is *only* used in TSIG where the last 48 bits are occupied // This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes) // So for now, assume a uint48 (6 bytes)
if off+6 > lenmsg { if off+6 > lenmsg {
// println("dns: overflow unpacking uint64") return lenmsg, &Error{Err: "overflow unpacking uint64 as uint48"}
return lenmsg, false
} }
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5]))) uint64(msg[off+4])<<8 | uint64(msg[off+5])))
@ -903,15 +851,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
var s string var s string
switch val.Type().Field(i).Tag.Get("dns") { switch val.Type().Field(i).Tag.Get("dns") {
default: default:
// println("dns: unknown tag unpacking string") return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking string"}
return lenmsg, false
case "hex": case "hex":
// Rest of the RR is hex encoded, network order an issue here? // Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
if endrr > lenmsg { if endrr > lenmsg {
// println("dns: overflow when unpacking hex string") return lenmsg, &Error{Err: "overflow unpacking hex"}
return lenmsg, false
} }
s = hex.EncodeToString(msg[off:endrr]) s = hex.EncodeToString(msg[off:endrr])
off = endrr off = endrr
@ -920,18 +866,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength endrr := rdstart + rdlength
if endrr > lenmsg { if endrr > lenmsg {
// println("dns: failure unpacking base64") return lenmsg, &Error{Err: "overflow unpacking base64"}
return lenmsg, false
} }
s = unpackBase64(msg[off:endrr]) s = unpackBase64(msg[off:endrr])
off = endrr off = endrr
case "cdomain-name": case "cdomain-name":
fallthrough fallthrough
case "domain-name": case "domain-name":
s, off, ok = UnpackDomainName(msg, off) s, off, err = UnpackDomainName(msg, off)
if !ok { if err != nil {
// println("dns: failure unpacking domain-name") return lenmsg, err
return lenmsg, false
} }
case "size-base32": case "size-base32":
var size int var size int
@ -944,8 +888,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
} }
} }
if off+size > lenmsg { if off+size > lenmsg {
// println("dns: failure unpacking size-base32 string") return lenmsg, &Error{Err: "overflow unpacking base32"}
return lenmsg, false
} }
s = unpackBase32(msg[off : off+size]) s = unpackBase32(msg[off : off+size])
off += size off += size
@ -973,8 +916,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
} }
} }
if off+size > lenmsg { if off+size > lenmsg {
// println("dns: failure unpacking size-hex string") return lenmsg, &Error{Err: "overflow unpacking hex"}
return lenmsg, false
} }
s = hex.EncodeToString(msg[off : off+size]) s = hex.EncodeToString(msg[off : off+size])
off += size off += size
@ -983,8 +925,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt: Txt:
if off >= lenmsg || off+1+int(msg[off]) > lenmsg { if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
// println("dns: failure unpacking txt string") return lenmsg, &Error{Err: "overflow unpacking txt"}
return lenmsg, false
} }
n := int(msg[off]) n := int(msg[off])
off++ off++
@ -998,8 +939,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
} }
case "": case "":
if off >= lenmsg || off+1+int(msg[off]) > lenmsg { if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
// println("dns: failure unpacking string") return lenmsg, &Error{Err: "overflow unpacking string"}
return lenmsg, false
} }
n := int(msg[off]) n := int(msg[off])
off++ off++
@ -1011,7 +951,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.SetString(s) fv.SetString(s)
} }
} }
return off, true return off, nil
} }
// Helper function for unpacking // Helper function for unpacking
@ -1021,9 +961,9 @@ func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
return return
} }
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, ok = unpackStructValue(structValue(any), msg, off) off, err = unpackStructValue(structValue(any), msg, off)
return off, ok return off, err
} }
func unpackBase32(b []byte) string { func unpackBase32(b []byte) string {
@ -1067,28 +1007,26 @@ func packBase32(s []byte) ([]byte, error) {
} }
// Resource record packer. // Resource record packer.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) { func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
if rr == nil { if rr == nil {
return len(msg), false return len(msg), &Error{Err: "nil rr"}
} }
off1, ok = packStructCompress(rr, msg, off, compression, compress) off1, err = packStructCompress(rr, msg, off, compression, compress)
if !ok { if err != nil {
return len(msg), false return len(msg), err
} }
if !rawSetRdlength(msg, off, off1) { rawSetRdlength(msg, off, off1)
return len(msg), false return off1, nil
}
return off1, true
} }
// Resource record unpacker. // Resource record unpacker.
func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) { func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// unpack just the header, to find the rr type and length // unpack just the header, to find the rr type and length
var h RR_Header var h RR_Header
off0 := off off0 := off
if off, ok = UnpackStruct(&h, msg, off); !ok { if off, err = UnpackStruct(&h, msg, off); err != nil {
return nil, len(msg), false return nil, len(msg), err
} }
end := off + int(h.Rdlength) end := off + int(h.Rdlength)
// make an rr of that type and re-unpack. // make an rr of that type and re-unpack.
@ -1098,11 +1036,11 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
} else { } else {
rr = mk() rr = mk()
} }
off, ok = UnpackStruct(rr, msg, off0) off, err = UnpackStruct(rr, msg, off0)
if off != end { if off != end {
return &h, end, true return &h, end, nil
} }
return rr, off, ok return rr, off, err
} }
// Reverse a map // Reverse a map
@ -1176,9 +1114,9 @@ func (h *MsgHdr) String() string {
// Pack packs a Msg: it is converted to to wire format. // Pack packs a Msg: it is converted to to wire format.
// If the dns.Compress is true the message will be in compressed wire format. // If the dns.Compress is true the message will be in compressed wire format.
func (dns *Msg) Pack() (msg []byte, ok bool) { func (dns *Msg) Pack() (msg []byte, err error) {
if dns == nil { if dns == nil {
return nil, false return nil, &Error{Err: "nil message"}
} }
var dh Header var dh Header
compression := make(map[string]int) // Compression pointer mappings compression := make(map[string]int) // Compression pointer mappings
@ -1227,34 +1165,41 @@ func (dns *Msg) Pack() (msg []byte, ok bool) {
// Pack it in: header and then the pieces. // Pack it in: header and then the pieces.
off := 0 off := 0
off, ok = packStructCompress(&dh, msg, off, compression, dns.Compress) off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
for i := 0; i < len(question); i++ { for i := 0; i < len(question); i++ {
off, ok = packStructCompress(&question[i], msg, off, compression, dns.Compress) off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
} }
for i := 0; i < len(answer); i++ { for i := 0; i < len(answer); i++ {
off, ok = PackRR(answer[i], msg, off, compression, dns.Compress) off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
} }
for i := 0; i < len(ns); i++ { for i := 0; i < len(ns); i++ {
off, ok = PackRR(ns[i], msg, off, compression, dns.Compress) off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
}
} }
for i := 0; i < len(extra); i++ { for i := 0; i < len(extra); i++ {
off, ok = PackRR(extra[i], msg, off, compression, dns.Compress) off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
if err != nil {
return nil, err
} }
if !ok {
return nil, false
} }
//println("allocated", dns.Len()+1, "used", off) return msg[:off], nil
return msg[:off], true
} }
// Unpack unpacks a binary message to a Msg structure. // Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) bool { func (dns *Msg) Unpack(msg []byte) (err error) {
// Header. // Header.
var dh Header var dh Header
off := 0 off := 0
var ok bool if off, err = UnpackStruct(&dh, msg, off); err != nil {
if off, ok = UnpackStruct(&dh, msg, off); !ok { return err
return false
} }
dns.Id = dh.Id dns.Id = dh.Id
dns.Response = (dh.Bits & _QR) != 0 dns.Response = (dh.Bits & _QR) != 0
@ -1275,25 +1220,34 @@ func (dns *Msg) Unpack(msg []byte) bool {
dns.Extra = make([]RR, dh.Arcount) dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ { for i := 0; i < len(dns.Question); i++ {
off, ok = UnpackStruct(&dns.Question[i], msg, off) off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return err
}
} }
for i := 0; i < len(dns.Answer); i++ { for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = UnpackRR(msg, off) dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
}
} }
for i := 0; i < len(dns.Ns); i++ { for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, ok = UnpackRR(msg, off) dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
}
} }
for i := 0; i < len(dns.Extra); i++ { for i := 0; i < len(dns.Extra); i++ {
dns.Extra[i], off, ok = UnpackRR(msg, off) dns.Extra[i], off, err = UnpackRR(msg, off)
if err != nil {
return err
} }
if !ok {
return false
} }
if off != len(msg) { if off != len(msg) {
// TODO(mg) remove eventually // TODO(mg) remove eventually
// println("extra bytes in dns packet", off, "<", len(msg)) // println("extra bytes in dns packet", off, "<", len(msg))
} }
return true return nil
} }
// Convert a complete message to a string with dig-like output. // Convert a complete message to a string with dig-like output.

View File

@ -38,14 +38,14 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt) saltwire := new(saltWireFmt)
saltwire.Salt = salt saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize) wire := make([]byte, DefaultMsgSize)
n, ok := PackStruct(saltwire, wire, 0) n, err := PackStruct(saltwire, wire, 0)
if !ok { if err != nil {
return "" return ""
} }
wire = wire[:n] wire = wire[:n]
name := make([]byte, 255) name := make([]byte, 255)
off, ok1 := PackDomainName(strings.ToLower(label), name, 0, nil, false) off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if !ok1 { if err != nil {
return "" return ""
} }
name = name[:off] name = name[:off]

View File

@ -396,7 +396,7 @@ func (c *conn) serve() {
w := new(response) w := new(response)
w.conn = c w.conn = c
req := new(Msg) req := new(Msg)
if !req.Unpack(c.request) { if req.Unpack(c.request) != nil {
// Send a format error back // Send a format error back
x := new(Msg) x := new(Msg)
x.SetRcodeFormatError(req) x.SetRcodeFormatError(req)
@ -436,10 +436,7 @@ func (c *conn) serve() {
// Write implements the ResponseWriter.Write method. // Write implements the ResponseWriter.Write method.
func (w *response) Write(m *Msg) (err error) { func (w *response) Write(m *Msg) (err error) {
var ( var data []byte
data []byte
ok bool
)
if m == nil { if m == nil {
return &Error{Err: "nil message"} return &Error{Err: "nil message"}
} }
@ -449,9 +446,9 @@ func (w *response) Write(m *Msg) (err error) {
return err return err
} }
} else { } else {
data, ok = m.Pack() data, err = m.Pack()
if !ok { if err != nil {
return ErrPack return err
} }
} }
return w.WriteBuf(data) return w.WriteBuf(data)
@ -470,7 +467,7 @@ func (w *response) WriteBuf(m []byte) (err error) {
} }
case w.conn._TCP != nil: case w.conn._TCP != nil:
if len(m) > MaxMsgSize { if len(m) > MaxMsgSize {
return ErrBuf return &Error{Err: "message too large"}
} }
l := make([]byte, 2) l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(m))) l[0], l[1] = packUint16(uint16(len(m)))

39
tsig.go
View File

@ -165,9 +165,9 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
rr := m.Extra[len(m.Extra)-1].(*RR_TSIG) rr := m.Extra[len(m.Extra)-1].(*RR_TSIG)
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
mbuf, ok := m.Pack() mbuf, err := m.Pack()
if !ok { if err != nil {
return nil, "", ErrPack return nil, "", err
} }
buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly) buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
@ -194,10 +194,10 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.OrigId = m.Id t.OrigId = m.Id
tbuf := make([]byte, t.Len()) tbuf := make([]byte, t.Len())
if off, ok := PackRR(t, tbuf, 0, nil, false); ok { if off, err := PackRR(t, tbuf, 0, nil, false); err != nil {
tbuf = tbuf[:off] // reset to actual size used tbuf = tbuf[:off] // reset to actual size used
} else { } else {
return nil, "", ErrPack return nil, "", err
} }
mbuf = append(mbuf, tbuf...) mbuf = append(mbuf, tbuf...)
rawSetExtraLen(mbuf, uint16(len(m.Extra)+1)) rawSetExtraLen(mbuf, uint16(len(m.Extra)+1))
@ -298,13 +298,13 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
// Copied from msg.go's Unpack() // Copied from msg.go's Unpack()
// Header. // Header.
var dh Header var dh Header
var err error
dns := new(Msg) dns := new(Msg)
rr := new(RR_TSIG) rr := new(RR_TSIG)
off := 0 off := 0
tsigoff := 0 tsigoff := 0
var ok bool if off, err = UnpackStruct(&dh, msg, off); err !=nil {
if off, ok = UnpackStruct(&dh, msg, off); !ok { return nil, nil, err
return nil, nil, ErrUnpack
} }
if dh.Arcount == 0 { if dh.Arcount == 0 {
return nil, nil, ErrNoSig return nil, nil, ErrNoSig
@ -321,17 +321,29 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
dns.Extra = make([]RR, dh.Arcount) dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ { for i := 0; i < len(dns.Question); i++ {
off, ok = UnpackStruct(&dns.Question[i], msg, off) off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return nil, nil, err
}
} }
for i := 0; i < len(dns.Answer); i++ { for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = UnpackRR(msg, off) dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
} }
for i := 0; i < len(dns.Ns); i++ { for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, ok = UnpackRR(msg, off) dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
} }
for i := 0; i < len(dns.Extra); i++ { for i := 0; i < len(dns.Extra); i++ {
tsigoff = off tsigoff = off
dns.Extra[i], off, ok = UnpackRR(msg, off) dns.Extra[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
if dns.Extra[i].Header().Rrtype == TypeTSIG { if dns.Extra[i].Header().Rrtype == TypeTSIG {
rr = dns.Extra[i].(*RR_TSIG) rr = dns.Extra[i].(*RR_TSIG)
// Adjust Arcount. // Adjust Arcount.
@ -340,9 +352,6 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
break break
} }
} }
if !ok {
return nil, nil, ErrUnpack
}
if rr == nil { if rr == nil {
return nil, nil, ErrNoSig return nil, nil, ErrNoSig
} }