Use proper error in packing and unpacking
All the relevant functions now return an error instead of a simple boolean. This greatly approves the feedback to coders. Spotted some fishy error handling along the way and fix that too.
This commit is contained in:
parent
099c19d5b2
commit
570bf8dc69
21
client.go
21
client.go
|
@ -101,9 +101,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
|
|||
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||
var n int
|
||||
var w *reply
|
||||
out, ok := m.Pack()
|
||||
if !ok {
|
||||
return nil, 0, ErrPack
|
||||
out, err := m.Pack()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var in []byte
|
||||
switch c.Net {
|
||||
|
@ -123,8 +123,8 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
|
|||
}
|
||||
r = new(Msg)
|
||||
r.Size = n
|
||||
if ok := r.Unpack(in[:n]); !ok {
|
||||
return nil, w.rtt, ErrUnpack
|
||||
if err := r.Unpack(in[:n]); err != nil {
|
||||
return nil, w.rtt, err
|
||||
}
|
||||
return r, w.rtt, nil
|
||||
}
|
||||
|
@ -158,8 +158,8 @@ func (w *reply) receive() (*Msg, error) {
|
|||
return nil, err
|
||||
}
|
||||
p = p[:n]
|
||||
if ok := m.Unpack(p); !ok {
|
||||
return nil, ErrUnpack
|
||||
if err := m.Unpack(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w.rtt = time.Since(w.t)
|
||||
m.Size = n
|
||||
|
@ -260,10 +260,9 @@ func (w *reply) send(m *Msg) (err error) {
|
|||
}
|
||||
w.tsigRequestMAC = mac
|
||||
} else {
|
||||
ok := false
|
||||
out, ok = m.Pack()
|
||||
if !ok {
|
||||
return ErrPack
|
||||
out, err = m.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.t = time.Now()
|
||||
|
|
28
dnssec.go
28
dnssec.go
|
@ -119,8 +119,8 @@ func (k *RR_DNSKEY) KeyTag() uint16 {
|
|||
keywire.Algorithm = k.Algorithm
|
||||
keywire.PublicKey = k.PublicKey
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, ok := PackStruct(keywire, wire, 0)
|
||||
if !ok {
|
||||
n, err := PackStruct(keywire, wire, 0)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
wire = wire[:n]
|
||||
|
@ -157,15 +157,15 @@ func (k *RR_DNSKEY) ToDS(h int) *RR_DS {
|
|||
keywire.Algorithm = k.Algorithm
|
||||
keywire.PublicKey = k.PublicKey
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, ok := PackStruct(keywire, wire, 0)
|
||||
if !ok {
|
||||
n, err := PackStruct(keywire, wire, 0)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
wire = wire[:n]
|
||||
|
||||
owner := make([]byte, 255)
|
||||
off, ok1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
|
||||
if !ok1 {
|
||||
off, err1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
|
||||
if err1 != nil {
|
||||
return nil
|
||||
}
|
||||
owner = owner[:off]
|
||||
|
@ -237,9 +237,9 @@ func (rr *RR_RRSIG) Sign(k PrivateKey, rrset []RR) error {
|
|||
|
||||
// Create the desired binary blob
|
||||
signdata := make([]byte, DefaultMsgSize)
|
||||
n, ok := PackStruct(sigwire, signdata, 0)
|
||||
if !ok {
|
||||
return ErrPack
|
||||
n, err := PackStruct(sigwire, signdata, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signdata = signdata[:n]
|
||||
wire := rawSignatureData(rrset, rr)
|
||||
|
@ -349,9 +349,9 @@ func (rr *RR_RRSIG) Verify(k *RR_DNSKEY, rrset []RR) error {
|
|||
sigwire.SignerName = strings.ToLower(rr.SignerName)
|
||||
// Create the desired binary blob
|
||||
signeddata := make([]byte, DefaultMsgSize)
|
||||
n, ok := PackStruct(sigwire, signeddata, 0)
|
||||
if !ok {
|
||||
return ErrPack
|
||||
n, err := PackStruct(sigwire, signeddata, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signeddata = signeddata[:n]
|
||||
wire := rawSignatureData(rrset, rr)
|
||||
|
@ -684,8 +684,8 @@ func rawSignatureData(rrset []RR, s *RR_RRSIG) (buf []byte) {
|
|||
}
|
||||
// 6.2. Canonical RR Form. (5) - origTTL
|
||||
wire := make([]byte, r.Len()*2)
|
||||
off, ok1 := PackRR(r1, wire, 0, nil, false)
|
||||
if !ok1 {
|
||||
off, err1 := PackRR(r1, wire, 0, nil, false)
|
||||
if err1 != nil {
|
||||
return nil
|
||||
}
|
||||
wire = wire[:off]
|
||||
|
|
352
msg.go
352
msg.go
|
@ -28,8 +28,10 @@ const maxCompressionOffset = 2 << 13 // We have 14 bits for the compression poin
|
|||
var (
|
||||
ErrUnpack error = &Error{Err: "unpacking failed"}
|
||||
ErrPack error = &Error{Err: "packing failed"}
|
||||
ErrFqdn error = &Error{Err: "domain must be fully qualified"}
|
||||
ErrId error = &Error{Err: "id mismatch"}
|
||||
ErrBuf error = &Error{Err: "buffer size too large"}
|
||||
ErrRdata error = &Error{Err: "bad rdata"}
|
||||
ErrBuf error = &Error{Err: "buffer size too small"}
|
||||
ErrShortRead error = &Error{Err: "short read"}
|
||||
ErrConn error = &Error{Err: "conn holds both UDP and TCP connection"}
|
||||
ErrConnEmpty error = &Error{Err: "conn has no connection"}
|
||||
|
@ -46,10 +48,7 @@ var (
|
|||
ErrSigGen error = &Error{Err: "bad signature generation"}
|
||||
ErrAuth error = &Error{Err: "bad authentication"}
|
||||
ErrSoa error = &Error{Err: "no SOA"}
|
||||
ErrHandle error = &Error{Err: "handle is nil"}
|
||||
ErrChan error = &Error{Err: "channel is nil"}
|
||||
ErrName error = &Error{Err: "type not found for name"}
|
||||
ErrRRset error = &Error{Err: "invalid rrset"}
|
||||
ErrRRset error = &Error{Err: "bad rrset"}
|
||||
ErrDenialNsec3 error = &Error{Err: "no NSEC3 records"}
|
||||
ErrDenialCe error = &Error{Err: "no matching closest encloser found"}
|
||||
ErrDenialNc error = &Error{Err: "no covering NSEC3 found for next closer"}
|
||||
|
@ -200,13 +199,12 @@ var Rcode_str = map[int]string{
|
|||
// If compression is wanted compress must be true and the compression
|
||||
// map needs to hold a mapping between domain names and offsets
|
||||
// pointing into msg[].
|
||||
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
|
||||
// Add trailing dot to canonicalize name.
|
||||
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
lenmsg := len(msg)
|
||||
ls := len(s)
|
||||
// If not fully qualified, error out
|
||||
if ls == 0 || s[ls-1] != '.' {
|
||||
//println("dns: name not fully qualified")
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrFqdn
|
||||
}
|
||||
|
||||
// Each dot ends a segment of the name.
|
||||
|
@ -234,30 +232,22 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
|
|||
|
||||
if bs[i] == '.' {
|
||||
if i-begin >= 1<<6 { // top two bits of length must be clear
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrRdata
|
||||
}
|
||||
// off can already (we're in a loop) be bigger than len(msg)
|
||||
// this happens when a name isn't fully qualified
|
||||
if off+1 > len(msg) {
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrBuf
|
||||
}
|
||||
msg[off] = byte(i - begin)
|
||||
offset := off
|
||||
off++
|
||||
// TODO(mg): because of the new check above, this can go. But
|
||||
// just leave it as is for the moment.
|
||||
// if off > lenmsg {
|
||||
// return lenmsg, false
|
||||
// }
|
||||
for j := begin; j < i; j++ {
|
||||
if off+1 > len(msg) {
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrBuf
|
||||
}
|
||||
msg[off] = bs[j]
|
||||
off++
|
||||
// if off > lenmsg {
|
||||
// return lenmsg, false
|
||||
// }
|
||||
}
|
||||
// Dont try to compress '.'
|
||||
if compression != nil && string(bs[begin:]) != ".'" {
|
||||
|
@ -285,7 +275,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
|
|||
}
|
||||
// Root label is special
|
||||
if string(bs) == "." {
|
||||
return off, true
|
||||
return off, nil
|
||||
}
|
||||
// If we did compression and we find something at the pointer here
|
||||
if pointer != -1 {
|
||||
|
@ -297,7 +287,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
|
|||
msg[off] = 0
|
||||
End:
|
||||
off++
|
||||
return off, true
|
||||
return off, nil
|
||||
}
|
||||
|
||||
// Unpack a domain name.
|
||||
|
@ -315,14 +305,14 @@ End:
|
|||
// We let them jump anywhere and stop jumping after a while.
|
||||
|
||||
// UnpackDomainName unpacks a domain name into a string.
|
||||
func UnpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
|
||||
func UnpackDomainName(msg []byte, off int) (s string, off1 int, err error) {
|
||||
s = ""
|
||||
lenmsg := len(msg)
|
||||
ptr := 0 // number of pointers followed
|
||||
Loop:
|
||||
for {
|
||||
if off >= lenmsg {
|
||||
return "", lenmsg, false
|
||||
return "", lenmsg, ErrBuf
|
||||
}
|
||||
c := int(msg[off])
|
||||
off++
|
||||
|
@ -331,13 +321,13 @@ Loop:
|
|||
if c == 0x00 {
|
||||
// end of name
|
||||
if s == "" {
|
||||
return ".", off, true
|
||||
return ".", off, nil
|
||||
}
|
||||
break Loop
|
||||
}
|
||||
// literal string
|
||||
if off+c > lenmsg {
|
||||
return "", lenmsg, false
|
||||
return "", lenmsg, ErrBuf
|
||||
}
|
||||
for j := off; j < off+c; j++ {
|
||||
if msg[j] == '.' {
|
||||
|
@ -356,7 +346,7 @@ Loop:
|
|||
// also, don't follow too many pointers --
|
||||
// maybe there's a loop.
|
||||
if off >= lenmsg {
|
||||
return "", lenmsg, false
|
||||
return "", lenmsg, ErrBuf
|
||||
}
|
||||
c1 := msg[off]
|
||||
off++
|
||||
|
@ -364,41 +354,38 @@ Loop:
|
|||
off1 = off
|
||||
}
|
||||
if ptr++; ptr > 10 {
|
||||
return "", lenmsg, false
|
||||
return "", lenmsg, &Error{Err: "too many compression pointers"}
|
||||
}
|
||||
off = (c^0xC0)<<8 | int(c1)
|
||||
default:
|
||||
// 0x80 and 0x40 are reserved
|
||||
return "", lenmsg, false
|
||||
return "", lenmsg, ErrRdata
|
||||
}
|
||||
}
|
||||
if ptr == 0 {
|
||||
off1 = off
|
||||
}
|
||||
return s, off1, true
|
||||
return s, off1, nil
|
||||
}
|
||||
|
||||
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
|
||||
// slices and other (often anonymous) structs.
|
||||
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
|
||||
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
lenmsg := len(msg)
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
// f := val.Type().Field(i)
|
||||
lenmsg := len(msg)
|
||||
switch fv := val.Field(i); fv.Kind() {
|
||||
default:
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "bad kind packing"}
|
||||
case reflect.Slice:
|
||||
switch val.Type().Field(i).Tag.Get("dns") {
|
||||
default:
|
||||
// println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"')
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing slice"}
|
||||
case "domain-name":
|
||||
for j := 0; j < val.Field(i).Len(); j++ {
|
||||
element := val.Field(i).Index(j).String()
|
||||
off, ok = PackDomainName(element, msg, off, compression, false && compress)
|
||||
if !ok {
|
||||
// println("dns: overflow packing domain-name", off)
|
||||
return lenmsg, false
|
||||
off, err = PackDomainName(element, msg, off, compression, false && compress)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
}
|
||||
case "txt":
|
||||
|
@ -406,8 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
element := val.Field(i).Index(j).String()
|
||||
// Counted string: 1 byte length.
|
||||
if len(element) > 255 || off+1+len(element) > lenmsg {
|
||||
// println("dns: overflow packing TXT string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing txt"}
|
||||
}
|
||||
msg[off] = byte(len(element))
|
||||
off++
|
||||
|
@ -421,8 +407,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
element := val.Field(i).Index(j).Interface()
|
||||
b, e := element.(EDNS0).pack()
|
||||
if e != nil {
|
||||
// println("dns: failure packing OPT")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing opt"}
|
||||
}
|
||||
// Option code
|
||||
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
|
||||
|
@ -436,22 +421,17 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
case "a":
|
||||
// It must be a slice of 4, even if it is 16, we encode
|
||||
// only the first 4
|
||||
if off+net.IPv4len > lenmsg {
|
||||
return lenmsg, &Error{Err: "overflow packing a"}
|
||||
}
|
||||
switch fv.Len() {
|
||||
case net.IPv6len:
|
||||
if off+net.IPv4len > lenmsg {
|
||||
// println("dns: overflow packing A", off, lenmsg)
|
||||
return lenmsg, false
|
||||
}
|
||||
msg[off] = byte(fv.Index(12).Uint())
|
||||
msg[off+1] = byte(fv.Index(13).Uint())
|
||||
msg[off+2] = byte(fv.Index(14).Uint())
|
||||
msg[off+3] = byte(fv.Index(15).Uint())
|
||||
off += net.IPv4len
|
||||
case net.IPv4len:
|
||||
if off+net.IPv4len > lenmsg {
|
||||
// println("dns: overflow packing A", off, lenmsg)
|
||||
return lenmsg, false
|
||||
}
|
||||
msg[off] = byte(fv.Index(0).Uint())
|
||||
msg[off+1] = byte(fv.Index(1).Uint())
|
||||
msg[off+2] = byte(fv.Index(2).Uint())
|
||||
|
@ -460,13 +440,11 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
case 0:
|
||||
// Allowed, for dynamic updates
|
||||
default:
|
||||
// println("dns: overflow packing A")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing a"}
|
||||
}
|
||||
case "aaaa":
|
||||
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
|
||||
// println("dns: overflow packing AAAA")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing aaaa"}
|
||||
}
|
||||
for j := 0; j < net.IPv6len; j++ {
|
||||
msg[off] = byte(fv.Index(j).Uint())
|
||||
|
@ -481,8 +459,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
serv := uint16((fv.Index(j).Uint()))
|
||||
bitmapbyte = uint16(serv / 8)
|
||||
if int(bitmapbyte) > lenmsg {
|
||||
// println("dns: overflow packing WKS")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing wks"}
|
||||
}
|
||||
bit := uint16(serv) - bitmapbyte*8
|
||||
msg[bitmapbyte] = byte(1 << (7 - bit))
|
||||
|
@ -498,8 +475,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
lastwindow := uint16(0)
|
||||
length := uint16(0)
|
||||
if off+2 > lenmsg {
|
||||
// println("dns: overflow packing NSECx bitmap")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing nsecx"}
|
||||
}
|
||||
for j := 0; j < val.Field(i).Len(); j++ {
|
||||
t := uint16((fv.Index(j).Uint()))
|
||||
|
@ -508,15 +484,13 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
// New window, jump to the new offset
|
||||
off += int(length) + 3
|
||||
if off > lenmsg {
|
||||
// println("dns: overflow packing NSECx bitmap")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
|
||||
}
|
||||
}
|
||||
length = (t - window*256) / 8
|
||||
bit := t - (window * 256) - (length * 8)
|
||||
if off+2+int(length) > lenmsg {
|
||||
// println("dns: overflow packing NSECx bitmap")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
|
||||
}
|
||||
|
||||
// Setting the window #
|
||||
|
@ -530,23 +504,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
off += 2 + int(length)
|
||||
off++
|
||||
if off > lenmsg {
|
||||
// println("dns: overflow packing NSECx bitmap")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
off, ok = packStructValue(fv, msg, off, compression, compress)
|
||||
off, err = packStructValue(fv, msg, off, compression, compress)
|
||||
case reflect.Uint8:
|
||||
if off+1 > lenmsg {
|
||||
// println("dns: overflow packing uint8")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing uint8"}
|
||||
}
|
||||
msg[off] = byte(fv.Uint())
|
||||
off++
|
||||
case reflect.Uint16:
|
||||
if off+2 > lenmsg {
|
||||
// println("dns: overflow packing uint16")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing uint16"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 8)
|
||||
|
@ -554,8 +525,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
off += 2
|
||||
case reflect.Uint32:
|
||||
if off+4 > lenmsg {
|
||||
// println("dns: overflow packing uint32")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing uint32"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 24)
|
||||
|
@ -566,8 +536,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
case reflect.Uint64:
|
||||
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
|
||||
if off+6 > lenmsg {
|
||||
// println("dns: overflow packing uint64")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing uint64 as uint48"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 40)
|
||||
|
@ -583,24 +552,21 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
s := fv.String()
|
||||
switch val.Type().Field(i).Tag.Get("dns") {
|
||||
default:
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing string"}
|
||||
case "base64":
|
||||
b64, err := packBase64([]byte(s))
|
||||
if err != nil {
|
||||
// println("dns: overflow packing base64")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing base64"}
|
||||
}
|
||||
copy(msg[off:off+len(b64)], b64)
|
||||
off += len(b64)
|
||||
case "domain-name":
|
||||
if off, ok = PackDomainName(s, msg, off, compression, false && compress); !ok {
|
||||
// println("dns: overflow packing domain-name", off)
|
||||
return lenmsg, false
|
||||
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case "cdomain-name":
|
||||
if off, ok = PackDomainName(s, msg, off, compression, true && compress); !ok {
|
||||
// println("dns: overflow packing domain-name", off)
|
||||
return lenmsg, false
|
||||
if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case "size-base32":
|
||||
// This is purely for NSEC3 atm, the previous byte must
|
||||
|
@ -611,8 +577,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
case "base32":
|
||||
b32, err := packBase32([]byte(s))
|
||||
if err != nil {
|
||||
// println("dns: overflow packing base32")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing base32"}
|
||||
}
|
||||
copy(msg[off:off+len(b32)], b32)
|
||||
off += len(b32)
|
||||
|
@ -622,12 +587,10 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
// There is no length encoded here
|
||||
h, e := hex.DecodeString(s)
|
||||
if e != nil {
|
||||
// println("dns: overflow packing (size-)hex string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing hex"}
|
||||
}
|
||||
if off+hex.DecodedLen(len(s)) > lenmsg {
|
||||
// Overflow
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing hex"}
|
||||
}
|
||||
copy(msg[off:off+hex.DecodedLen(len(s))], h)
|
||||
off += hex.DecodedLen(len(s))
|
||||
|
@ -641,8 +604,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
case "":
|
||||
// Counted string: 1 byte length.
|
||||
if len(s) > 255 || off+1+len(s) > lenmsg {
|
||||
// println("dns: overflow packing string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow packing string"}
|
||||
}
|
||||
msg[off] = byte(len(s))
|
||||
off++
|
||||
|
@ -653,48 +615,44 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
|
|||
}
|
||||
}
|
||||
}
|
||||
return off, true
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func structValue(any interface{}) reflect.Value {
|
||||
return reflect.ValueOf(any).Elem()
|
||||
}
|
||||
|
||||
func PackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
|
||||
off, ok = packStructValue(structValue(any), msg, off, nil, false)
|
||||
return off, ok
|
||||
func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
|
||||
off, err = packStructValue(structValue(any), msg, off, nil, false)
|
||||
return off, err
|
||||
}
|
||||
|
||||
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
|
||||
off, ok = packStructValue(structValue(any), msg, off, compression, compress)
|
||||
return off, ok
|
||||
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
off, err = packStructValue(structValue(any), msg, off, compression, compress)
|
||||
return off, err
|
||||
}
|
||||
|
||||
// Unpack a reflect.StructValue from msg.
|
||||
// Same restrictions as packStructValue.
|
||||
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
|
||||
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
|
||||
var rdstart int
|
||||
lenmsg := len(msg)
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
// f := val.Type().Field(i)
|
||||
lenmsg := len(msg)
|
||||
switch fv := val.Field(i); fv.Kind() {
|
||||
default:
|
||||
// println("dns: unknown case unpacking struct")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "bad kind unpacking"}
|
||||
case reflect.Slice:
|
||||
switch val.Type().Field(i).Tag.Get("dns") {
|
||||
default:
|
||||
// println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag)
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking slice"}
|
||||
case "domain-name":
|
||||
// HIP record slice of name (or none)
|
||||
servers := make([]string, 0)
|
||||
var s string
|
||||
for off < lenmsg {
|
||||
s, off, ok = UnpackDomainName(msg, off)
|
||||
if !ok {
|
||||
// println("dns: failure unpacking domain-name")
|
||||
return lenmsg, false
|
||||
s, off, err = UnpackDomainName(msg, off)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
@ -705,8 +663,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
Txts:
|
||||
l := int(msg[off])
|
||||
if off+l+1 > lenmsg {
|
||||
// println("dns: failure unpacking txt strings")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking txt"}
|
||||
}
|
||||
txt = append(txt, string(msg[off+1:off+l+1]))
|
||||
off += l + 1
|
||||
|
@ -724,14 +681,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
break
|
||||
}
|
||||
edns := make([]EDNS0, 0)
|
||||
// Goto to this place, when there is a goto
|
||||
code := uint16(0)
|
||||
|
||||
code, off = unpackUint16(msg, off) // Overflow? TODO
|
||||
if off+2 > lenmsg {
|
||||
return lenmsg, &Error{Err: "overflow unpacking opt"}
|
||||
}
|
||||
code, off = unpackUint16(msg, off)
|
||||
optlen, off1 := unpackUint16(msg, off)
|
||||
if off1+int(optlen) > off+rdlength {
|
||||
// println("dns: overflow unpacking OPT")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking opt"}
|
||||
}
|
||||
switch code {
|
||||
case EDNS0NSID:
|
||||
|
@ -746,18 +703,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
off = off1 + int(optlen)
|
||||
}
|
||||
fv.Set(reflect.ValueOf(edns))
|
||||
// goto ??
|
||||
// multiple EDNS codes?
|
||||
case "a":
|
||||
if off+net.IPv4len > len(msg) {
|
||||
// println("dns: overflow unpacking A")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking a"}
|
||||
}
|
||||
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
|
||||
off += net.IPv4len
|
||||
case "aaaa":
|
||||
if off+net.IPv6len > lenmsg {
|
||||
// println("dns: overflow unpacking AAAA")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking aaaa"}
|
||||
}
|
||||
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
|
||||
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
|
||||
|
@ -806,8 +761,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
endrr := rdstart + rdlength
|
||||
|
||||
if off+2 > lenmsg {
|
||||
// println("dns: overflow unpacking NSEC")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking nsecx"}
|
||||
}
|
||||
nsec := make([]uint16, 0)
|
||||
length := 0
|
||||
|
@ -820,15 +774,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
// A length window of zero is strange. If there
|
||||
// the window should not have been specified. Bail out
|
||||
// println("dns: length == 0 when unpacking NSEC")
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrRdata
|
||||
}
|
||||
if length > 32 {
|
||||
// println("dns: length > 32 when unpacking NSEC")
|
||||
return lenmsg, false
|
||||
return lenmsg, ErrRdata
|
||||
}
|
||||
|
||||
// Walk the bytes in the window - and check the bit
|
||||
// setting..
|
||||
// Walk the bytes in the window - and check the bit settings...
|
||||
off += 2
|
||||
for j := 0; j < length; j++ {
|
||||
b := msg[off+j]
|
||||
|
@ -863,29 +815,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
fv.Set(reflect.ValueOf(nsec))
|
||||
}
|
||||
case reflect.Struct:
|
||||
off, ok = unpackStructValue(fv, msg, off)
|
||||
off, err = unpackStructValue(fv, msg, off)
|
||||
if val.Type().Field(i).Name == "Hdr" {
|
||||
rdstart = off
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if off+1 > lenmsg {
|
||||
// println("dns: overflow unpacking uint8")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking uint8"}
|
||||
}
|
||||
fv.SetUint(uint64(uint8(msg[off])))
|
||||
off++
|
||||
case reflect.Uint16:
|
||||
var i uint16
|
||||
if off+2 > lenmsg {
|
||||
// println("dns: overflow unpacking uint16")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking uint16"}
|
||||
}
|
||||
i, off = unpackUint16(msg, off)
|
||||
fv.SetUint(uint64(i))
|
||||
case reflect.Uint32:
|
||||
if off+4 > lenmsg {
|
||||
// println("dns: overflow unpacking uint32")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking uint32"}
|
||||
}
|
||||
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
|
||||
off += 4
|
||||
|
@ -893,8 +842,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
// This is *only* used in TSIG where the last 48 bits are occupied
|
||||
// So for now, assume a uint48 (6 bytes)
|
||||
if off+6 > lenmsg {
|
||||
// println("dns: overflow unpacking uint64")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking uint64 as uint48"}
|
||||
}
|
||||
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
|
||||
uint64(msg[off+4])<<8 | uint64(msg[off+5])))
|
||||
|
@ -903,15 +851,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
var s string
|
||||
switch val.Type().Field(i).Tag.Get("dns") {
|
||||
default:
|
||||
// println("dns: unknown tag unpacking string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking string"}
|
||||
case "hex":
|
||||
// Rest of the RR is hex encoded, network order an issue here?
|
||||
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
|
||||
endrr := rdstart + rdlength
|
||||
if endrr > lenmsg {
|
||||
// println("dns: overflow when unpacking hex string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking hex"}
|
||||
}
|
||||
s = hex.EncodeToString(msg[off:endrr])
|
||||
off = endrr
|
||||
|
@ -920,18 +866,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
|
||||
endrr := rdstart + rdlength
|
||||
if endrr > lenmsg {
|
||||
// println("dns: failure unpacking base64")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking base64"}
|
||||
}
|
||||
s = unpackBase64(msg[off:endrr])
|
||||
off = endrr
|
||||
case "cdomain-name":
|
||||
fallthrough
|
||||
case "domain-name":
|
||||
s, off, ok = UnpackDomainName(msg, off)
|
||||
if !ok {
|
||||
// println("dns: failure unpacking domain-name")
|
||||
return lenmsg, false
|
||||
s, off, err = UnpackDomainName(msg, off)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case "size-base32":
|
||||
var size int
|
||||
|
@ -944,8 +888,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
}
|
||||
}
|
||||
if off+size > lenmsg {
|
||||
// println("dns: failure unpacking size-base32 string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking base32"}
|
||||
}
|
||||
s = unpackBase32(msg[off : off+size])
|
||||
off += size
|
||||
|
@ -973,8 +916,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
}
|
||||
}
|
||||
if off+size > lenmsg {
|
||||
// println("dns: failure unpacking size-hex string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking hex"}
|
||||
}
|
||||
s = hex.EncodeToString(msg[off : off+size])
|
||||
off += size
|
||||
|
@ -983,8 +925,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
|
||||
Txt:
|
||||
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
|
||||
// println("dns: failure unpacking txt string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking txt"}
|
||||
}
|
||||
n := int(msg[off])
|
||||
off++
|
||||
|
@ -998,8 +939,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
}
|
||||
case "":
|
||||
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
|
||||
// println("dns: failure unpacking string")
|
||||
return lenmsg, false
|
||||
return lenmsg, &Error{Err: "overflow unpacking string"}
|
||||
}
|
||||
n := int(msg[off])
|
||||
off++
|
||||
|
@ -1011,7 +951,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
|
|||
fv.SetString(s)
|
||||
}
|
||||
}
|
||||
return off, true
|
||||
return off, nil
|
||||
}
|
||||
|
||||
// Helper function for unpacking
|
||||
|
@ -1021,9 +961,9 @@ func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
|
|||
return
|
||||
}
|
||||
|
||||
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
|
||||
off, ok = unpackStructValue(structValue(any), msg, off)
|
||||
return off, ok
|
||||
func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
|
||||
off, err = unpackStructValue(structValue(any), msg, off)
|
||||
return off, err
|
||||
}
|
||||
|
||||
func unpackBase32(b []byte) string {
|
||||
|
@ -1067,28 +1007,26 @@ func packBase32(s []byte) ([]byte, error) {
|
|||
}
|
||||
|
||||
// Resource record packer.
|
||||
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
|
||||
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
if rr == nil {
|
||||
return len(msg), false
|
||||
return len(msg), &Error{Err: "nil rr"}
|
||||
}
|
||||
|
||||
off1, ok = packStructCompress(rr, msg, off, compression, compress)
|
||||
if !ok {
|
||||
return len(msg), false
|
||||
off1, err = packStructCompress(rr, msg, off, compression, compress)
|
||||
if err != nil {
|
||||
return len(msg), err
|
||||
}
|
||||
if !rawSetRdlength(msg, off, off1) {
|
||||
return len(msg), false
|
||||
}
|
||||
return off1, true
|
||||
rawSetRdlength(msg, off, off1)
|
||||
return off1, nil
|
||||
}
|
||||
|
||||
// Resource record unpacker.
|
||||
func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
|
||||
func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
|
||||
// unpack just the header, to find the rr type and length
|
||||
var h RR_Header
|
||||
off0 := off
|
||||
if off, ok = UnpackStruct(&h, msg, off); !ok {
|
||||
return nil, len(msg), false
|
||||
if off, err = UnpackStruct(&h, msg, off); err != nil {
|
||||
return nil, len(msg), err
|
||||
}
|
||||
end := off + int(h.Rdlength)
|
||||
// make an rr of that type and re-unpack.
|
||||
|
@ -1098,11 +1036,11 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
|
|||
} else {
|
||||
rr = mk()
|
||||
}
|
||||
off, ok = UnpackStruct(rr, msg, off0)
|
||||
off, err = UnpackStruct(rr, msg, off0)
|
||||
if off != end {
|
||||
return &h, end, true
|
||||
return &h, end, nil
|
||||
}
|
||||
return rr, off, ok
|
||||
return rr, off, err
|
||||
}
|
||||
|
||||
// Reverse a map
|
||||
|
@ -1176,9 +1114,9 @@ func (h *MsgHdr) String() string {
|
|||
|
||||
// Pack packs a Msg: it is converted to to wire format.
|
||||
// If the dns.Compress is true the message will be in compressed wire format.
|
||||
func (dns *Msg) Pack() (msg []byte, ok bool) {
|
||||
func (dns *Msg) Pack() (msg []byte, err error) {
|
||||
if dns == nil {
|
||||
return nil, false
|
||||
return nil, &Error{Err: "nil message"}
|
||||
}
|
||||
var dh Header
|
||||
compression := make(map[string]int) // Compression pointer mappings
|
||||
|
@ -1227,34 +1165,41 @@ func (dns *Msg) Pack() (msg []byte, ok bool) {
|
|||
|
||||
// Pack it in: header and then the pieces.
|
||||
off := 0
|
||||
off, ok = packStructCompress(&dh, msg, off, compression, dns.Compress)
|
||||
off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
|
||||
for i := 0; i < len(question); i++ {
|
||||
off, ok = packStructCompress(&question[i], msg, off, compression, dns.Compress)
|
||||
off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(answer); i++ {
|
||||
off, ok = PackRR(answer[i], msg, off, compression, dns.Compress)
|
||||
off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(ns); i++ {
|
||||
off, ok = PackRR(ns[i], msg, off, compression, dns.Compress)
|
||||
off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(extra); i++ {
|
||||
off, ok = PackRR(extra[i], msg, off, compression, dns.Compress)
|
||||
off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
//println("allocated", dns.Len()+1, "used", off)
|
||||
return msg[:off], true
|
||||
return msg[:off], nil
|
||||
}
|
||||
|
||||
// Unpack unpacks a binary message to a Msg structure.
|
||||
func (dns *Msg) Unpack(msg []byte) bool {
|
||||
func (dns *Msg) Unpack(msg []byte) (err error) {
|
||||
// Header.
|
||||
var dh Header
|
||||
off := 0
|
||||
var ok bool
|
||||
if off, ok = UnpackStruct(&dh, msg, off); !ok {
|
||||
return false
|
||||
if off, err = UnpackStruct(&dh, msg, off); err != nil {
|
||||
return err
|
||||
}
|
||||
dns.Id = dh.Id
|
||||
dns.Response = (dh.Bits & _QR) != 0
|
||||
|
@ -1275,25 +1220,34 @@ func (dns *Msg) Unpack(msg []byte) bool {
|
|||
dns.Extra = make([]RR, dh.Arcount)
|
||||
|
||||
for i := 0; i < len(dns.Question); i++ {
|
||||
off, ok = UnpackStruct(&dns.Question[i], msg, off)
|
||||
off, err = UnpackStruct(&dns.Question[i], msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Answer); i++ {
|
||||
dns.Answer[i], off, ok = UnpackRR(msg, off)
|
||||
dns.Answer[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Ns); i++ {
|
||||
dns.Ns[i], off, ok = UnpackRR(msg, off)
|
||||
dns.Ns[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Extra); i++ {
|
||||
dns.Extra[i], off, ok = UnpackRR(msg, off)
|
||||
}
|
||||
if !ok {
|
||||
return false
|
||||
dns.Extra[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if off != len(msg) {
|
||||
// TODO(mg) remove eventually
|
||||
// println("extra bytes in dns packet", off, "<", len(msg))
|
||||
}
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert a complete message to a string with dig-like output.
|
||||
|
|
8
nsecx.go
8
nsecx.go
|
@ -38,14 +38,14 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
|
|||
saltwire := new(saltWireFmt)
|
||||
saltwire.Salt = salt
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, ok := PackStruct(saltwire, wire, 0)
|
||||
if !ok {
|
||||
n, err := PackStruct(saltwire, wire, 0)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
wire = wire[:n]
|
||||
name := make([]byte, 255)
|
||||
off, ok1 := PackDomainName(strings.ToLower(label), name, 0, nil, false)
|
||||
if !ok1 {
|
||||
off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
name = name[:off]
|
||||
|
|
15
server.go
15
server.go
|
@ -396,7 +396,7 @@ func (c *conn) serve() {
|
|||
w := new(response)
|
||||
w.conn = c
|
||||
req := new(Msg)
|
||||
if !req.Unpack(c.request) {
|
||||
if req.Unpack(c.request) != nil {
|
||||
// Send a format error back
|
||||
x := new(Msg)
|
||||
x.SetRcodeFormatError(req)
|
||||
|
@ -436,10 +436,7 @@ func (c *conn) serve() {
|
|||
|
||||
// Write implements the ResponseWriter.Write method.
|
||||
func (w *response) Write(m *Msg) (err error) {
|
||||
var (
|
||||
data []byte
|
||||
ok bool
|
||||
)
|
||||
var data []byte
|
||||
if m == nil {
|
||||
return &Error{Err: "nil message"}
|
||||
}
|
||||
|
@ -449,9 +446,9 @@ func (w *response) Write(m *Msg) (err error) {
|
|||
return err
|
||||
}
|
||||
} else {
|
||||
data, ok = m.Pack()
|
||||
if !ok {
|
||||
return ErrPack
|
||||
data, err = m.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.WriteBuf(data)
|
||||
|
@ -470,7 +467,7 @@ func (w *response) WriteBuf(m []byte) (err error) {
|
|||
}
|
||||
case w.conn._TCP != nil:
|
||||
if len(m) > MaxMsgSize {
|
||||
return ErrBuf
|
||||
return &Error{Err: "message too large"}
|
||||
}
|
||||
l := make([]byte, 2)
|
||||
l[0], l[1] = packUint16(uint16(len(m)))
|
||||
|
|
39
tsig.go
39
tsig.go
|
@ -165,9 +165,9 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
|
|||
|
||||
rr := m.Extra[len(m.Extra)-1].(*RR_TSIG)
|
||||
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
|
||||
mbuf, ok := m.Pack()
|
||||
if !ok {
|
||||
return nil, "", ErrPack
|
||||
mbuf, err := m.Pack()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
|
||||
|
||||
|
@ -194,10 +194,10 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
|
|||
t.OrigId = m.Id
|
||||
|
||||
tbuf := make([]byte, t.Len())
|
||||
if off, ok := PackRR(t, tbuf, 0, nil, false); ok {
|
||||
if off, err := PackRR(t, tbuf, 0, nil, false); err != nil {
|
||||
tbuf = tbuf[:off] // reset to actual size used
|
||||
} else {
|
||||
return nil, "", ErrPack
|
||||
return nil, "", err
|
||||
}
|
||||
mbuf = append(mbuf, tbuf...)
|
||||
rawSetExtraLen(mbuf, uint16(len(m.Extra)+1))
|
||||
|
@ -298,13 +298,13 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
|
|||
// Copied from msg.go's Unpack()
|
||||
// Header.
|
||||
var dh Header
|
||||
var err error
|
||||
dns := new(Msg)
|
||||
rr := new(RR_TSIG)
|
||||
off := 0
|
||||
tsigoff := 0
|
||||
var ok bool
|
||||
if off, ok = UnpackStruct(&dh, msg, off); !ok {
|
||||
return nil, nil, ErrUnpack
|
||||
if off, err = UnpackStruct(&dh, msg, off); err !=nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if dh.Arcount == 0 {
|
||||
return nil, nil, ErrNoSig
|
||||
|
@ -321,17 +321,29 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
|
|||
dns.Extra = make([]RR, dh.Arcount)
|
||||
|
||||
for i := 0; i < len(dns.Question); i++ {
|
||||
off, ok = UnpackStruct(&dns.Question[i], msg, off)
|
||||
off, err = UnpackStruct(&dns.Question[i], msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Answer); i++ {
|
||||
dns.Answer[i], off, ok = UnpackRR(msg, off)
|
||||
dns.Answer[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Ns); i++ {
|
||||
dns.Ns[i], off, ok = UnpackRR(msg, off)
|
||||
dns.Ns[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Extra); i++ {
|
||||
tsigoff = off
|
||||
dns.Extra[i], off, ok = UnpackRR(msg, off)
|
||||
dns.Extra[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if dns.Extra[i].Header().Rrtype == TypeTSIG {
|
||||
rr = dns.Extra[i].(*RR_TSIG)
|
||||
// Adjust Arcount.
|
||||
|
@ -340,9 +352,6 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
|
|||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, nil, ErrUnpack
|
||||
}
|
||||
if rr == nil {
|
||||
return nil, nil, ErrNoSig
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue