Remove reflection (#376)
Everything is generated. Remove all uses of packStruct/unpackStruct and make the library reflectionless.
This commit is contained in:
parent
dbffa4b057
commit
b51e305bc6
@ -284,9 +284,11 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
|
||||
|
||||
p = p[:n]
|
||||
if hdr != nil {
|
||||
if _, err = UnpackStruct(hdr, p, 0); err != nil {
|
||||
dh, _, err := unpackMsgHdr(p, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*hdr = dh
|
||||
}
|
||||
return p, err
|
||||
}
|
||||
|
13
dns.go
13
dns.go
@ -84,19 +84,22 @@ func (h *RR_Header) len() int {
|
||||
return l
|
||||
}
|
||||
|
||||
// ToRFC3597 converts a known RR to the unknown RR representation
|
||||
// from RFC 3597.
|
||||
// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
|
||||
func (rr *RFC3597) ToRFC3597(r RR) error {
|
||||
buf := make([]byte, r.len()*2)
|
||||
off, err := PackStruct(r, buf, 0)
|
||||
off, err := PackRR(r, buf, 0, nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[:off]
|
||||
rawSetRdlength(buf, 0, off)
|
||||
_, err = UnpackStruct(rr, buf, 0)
|
||||
if int(r.Header().Rdlength) > off {
|
||||
return ErrBuf
|
||||
}
|
||||
|
||||
rfc3597, _, err := unpackRFC3597(*r.Header(), buf, off-int(r.Header().Rdlength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*rr = *rfc3597.(*RFC3597)
|
||||
return nil
|
||||
}
|
||||
|
@ -315,7 +315,13 @@ func TestToRFC3597(t *testing.T) {
|
||||
x := new(RFC3597)
|
||||
x.ToRFC3597(a)
|
||||
if x.String() != `miek.nl. 3600 CLASS1 TYPE1 \# 4 0a000101` {
|
||||
t.Error("string mismatch")
|
||||
t.Errorf("string mismatch, got: %s", x)
|
||||
}
|
||||
|
||||
b, _ := NewRR("miek.nl. IN MX 10 mx.miek.nl.")
|
||||
x.ToRFC3597(b)
|
||||
if x.String() != `miek.nl. 3600 CLASS1 TYPE15 \# 14 000a026d78046d69656b026e6c00` {
|
||||
t.Errorf("string mismatch, got: %s", x)
|
||||
}
|
||||
}
|
||||
|
||||
|
70
dnssec.go
70
dnssec.go
@ -104,9 +104,7 @@ const (
|
||||
ZONE = 1 << 8
|
||||
)
|
||||
|
||||
// The RRSIG needs to be converted to wireformat with some of
|
||||
// the rdata (the signature) missing. Use this struct to ease
|
||||
// the conversion (and re-use the pack/unpack functions).
|
||||
// The RRSIG needs to be converted to wireformat with some of the rdata (the signature) missing.
|
||||
type rrsigWireFmt struct {
|
||||
TypeCovered uint16
|
||||
Algorithm uint8
|
||||
@ -155,7 +153,7 @@ func (k *DNSKEY) KeyTag() uint16 {
|
||||
keywire.Algorithm = k.Algorithm
|
||||
keywire.PublicKey = k.PublicKey
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, err := PackStruct(keywire, wire, 0)
|
||||
n, err := packKeyWire(keywire, wire)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
@ -193,7 +191,7 @@ func (k *DNSKEY) ToDS(h uint8) *DS {
|
||||
keywire.Algorithm = k.Algorithm
|
||||
keywire.PublicKey = k.PublicKey
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, err := PackStruct(keywire, wire, 0)
|
||||
n, err := packKeyWire(keywire, wire)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@ -290,7 +288,7 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
|
||||
|
||||
// Create the desired binary blob
|
||||
signdata := make([]byte, DefaultMsgSize)
|
||||
n, err := PackStruct(sigwire, signdata, 0)
|
||||
n, err := packSigWire(sigwire, signdata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -408,7 +406,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
|
||||
sigwire.SignerName = strings.ToLower(rr.SignerName)
|
||||
// Create the desired binary blob
|
||||
signeddata := make([]byte, DefaultMsgSize)
|
||||
n, err := PackStruct(sigwire, signeddata, 0)
|
||||
n, err := packSigWire(sigwire, signeddata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -663,3 +661,61 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func packSigWire(sw *rrsigWireFmt, msg []byte) (int, error) {
|
||||
// copied from zmsg.go RRSIG packing
|
||||
off, err := packUint16(sw.TypeCovered, msg, 0)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint8(sw.Algorithm, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint8(sw.Labels, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint32(sw.OrigTtl, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint32(sw.Expiration, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint32(sw.Inception, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(sw.KeyTag, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = PackDomainName(sw.SignerName, msg, off, nil, false)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func packKeyWire(dw *dnskeyWireFmt, msg []byte) (int, error) {
|
||||
// copied from zmsg.go DNSKEY packing
|
||||
off, err := packUint16(dw.Flags, msg, 0)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint8(dw.Protocol, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint8(dw.Algorithm, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packStringBase64(dw.PublicKey, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
594
msg.go
594
msg.go
@ -13,11 +13,8 @@ package dns
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@ -549,591 +546,6 @@ func unpackTxtString(msg []byte, offset int) (string, int, error) {
|
||||
return string(s), offset, nil
|
||||
}
|
||||
|
||||
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
|
||||
// slices and other (often anonymous) structs.
|
||||
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
var txtTmp []byte
|
||||
lenmsg := len(msg)
|
||||
numfield := val.NumField()
|
||||
for i := 0; i < numfield; i++ {
|
||||
typefield := val.Type().Field(i)
|
||||
if typefield.Tag == `dns:"-"` {
|
||||
continue
|
||||
}
|
||||
switch fv := val.Field(i); fv.Kind() {
|
||||
default:
|
||||
return lenmsg, &Error{err: "bad kind packing"}
|
||||
case reflect.Interface:
|
||||
// PrivateRR is the only RR implementation that has interface field.
|
||||
// therefore it's expected that this interface would be PrivateRdata
|
||||
switch data := fv.Interface().(type) {
|
||||
case PrivateRdata:
|
||||
n, err := data.Pack(msg[off:])
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
off += n
|
||||
default:
|
||||
return lenmsg, &Error{err: "bad kind interface packing"}
|
||||
}
|
||||
case reflect.Slice:
|
||||
switch typefield.Tag {
|
||||
default:
|
||||
return lenmsg, &Error{"bad tag packing slice: " + typefield.Tag.Get("dns")}
|
||||
case `dns:"domain-name"`:
|
||||
for j := 0; j < val.Field(i).Len(); j++ {
|
||||
element := val.Field(i).Index(j).String()
|
||||
off, err = PackDomainName(element, msg, off, compression, false && compress)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
}
|
||||
case `dns:"a"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, must be 1
|
||||
if val.Field(2).Uint() != 1 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// It must be a slice of 4, even if it is 16, we encode
|
||||
// only the first 4
|
||||
if off+net.IPv4len > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing a"}
|
||||
}
|
||||
switch fv.Len() {
|
||||
case net.IPv6len:
|
||||
msg[off] = byte(fv.Index(12).Uint())
|
||||
msg[off+1] = byte(fv.Index(13).Uint())
|
||||
msg[off+2] = byte(fv.Index(14).Uint())
|
||||
msg[off+3] = byte(fv.Index(15).Uint())
|
||||
off += net.IPv4len
|
||||
case net.IPv4len:
|
||||
msg[off] = byte(fv.Index(0).Uint())
|
||||
msg[off+1] = byte(fv.Index(1).Uint())
|
||||
msg[off+2] = byte(fv.Index(2).Uint())
|
||||
msg[off+3] = byte(fv.Index(3).Uint())
|
||||
off += net.IPv4len
|
||||
case 0:
|
||||
// Allowed, for dynamic updates
|
||||
default:
|
||||
return lenmsg, &Error{err: "overflow packing a"}
|
||||
}
|
||||
case `dns:"aaaa"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, must be 2
|
||||
if val.Field(2).Uint() != 2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if fv.Len() == 0 {
|
||||
break
|
||||
}
|
||||
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing aaaa"}
|
||||
}
|
||||
for j := 0; j < net.IPv6len; j++ {
|
||||
msg[off] = byte(fv.Index(j).Uint())
|
||||
off++
|
||||
}
|
||||
case `dns:"nsec"`: // NSEC/NSEC3
|
||||
// This is the uint16 type bitmap
|
||||
if val.Field(i).Len() == 0 {
|
||||
// Do absolutely nothing
|
||||
break
|
||||
}
|
||||
var lastwindow, lastlength uint16
|
||||
for j := 0; j < val.Field(i).Len(); j++ {
|
||||
t := uint16(fv.Index(j).Uint())
|
||||
window := t / 256
|
||||
length := (t-window*256)/8 + 1
|
||||
if window > lastwindow && lastlength != 0 {
|
||||
// New window, jump to the new offset
|
||||
off += int(lastlength) + 2
|
||||
lastlength = 0
|
||||
}
|
||||
if window < lastwindow || length < lastlength {
|
||||
return len(msg), &Error{err: "nsec bits out of order"}
|
||||
}
|
||||
if off+2+int(length) > len(msg) {
|
||||
return len(msg), &Error{err: "overflow packing nsec"}
|
||||
}
|
||||
// Setting the window #
|
||||
msg[off] = byte(window)
|
||||
// Setting the octets length
|
||||
msg[off+1] = byte(length)
|
||||
// Setting the bit value for the type in the right octet
|
||||
msg[off+1+int(length)] |= byte(1 << (7 - (t % 8)))
|
||||
lastwindow, lastlength = window, length
|
||||
}
|
||||
off += int(lastlength) + 2
|
||||
}
|
||||
case reflect.Struct:
|
||||
off, err = packStructValue(fv, msg, off, compression, compress)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if off+1 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing uint8"}
|
||||
}
|
||||
msg[off] = byte(fv.Uint())
|
||||
off++
|
||||
case reflect.Uint16:
|
||||
if off+2 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing uint16"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 8)
|
||||
msg[off+1] = byte(i)
|
||||
off += 2
|
||||
case reflect.Uint32:
|
||||
if off+4 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing uint32"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 24)
|
||||
msg[off+1] = byte(i >> 16)
|
||||
msg[off+2] = byte(i >> 8)
|
||||
msg[off+3] = byte(i)
|
||||
off += 4
|
||||
case reflect.Uint64:
|
||||
switch typefield.Tag {
|
||||
default:
|
||||
if off+8 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing uint64"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 56)
|
||||
msg[off+1] = byte(i >> 48)
|
||||
msg[off+2] = byte(i >> 40)
|
||||
msg[off+3] = byte(i >> 32)
|
||||
msg[off+4] = byte(i >> 24)
|
||||
msg[off+5] = byte(i >> 16)
|
||||
msg[off+6] = byte(i >> 8)
|
||||
msg[off+7] = byte(i)
|
||||
off += 8
|
||||
case `dns:"uint48"`:
|
||||
// Used in TSIG, where it stops at 48 bits, so we discard the upper 16
|
||||
if off+6 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing uint64 as uint48"}
|
||||
}
|
||||
i := fv.Uint()
|
||||
msg[off] = byte(i >> 40)
|
||||
msg[off+1] = byte(i >> 32)
|
||||
msg[off+2] = byte(i >> 24)
|
||||
msg[off+3] = byte(i >> 16)
|
||||
msg[off+4] = byte(i >> 8)
|
||||
msg[off+5] = byte(i)
|
||||
off += 6
|
||||
}
|
||||
case reflect.String:
|
||||
// There are multiple string encodings.
|
||||
// The tag distinguishes ordinary strings from domain names.
|
||||
s := fv.String()
|
||||
switch typefield.Tag {
|
||||
default:
|
||||
return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")}
|
||||
case `dns:"base64"`:
|
||||
b64, err := fromBase64([]byte(s))
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
if off+len(b64) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing base64"}
|
||||
}
|
||||
copy(msg[off:off+len(b64)], b64)
|
||||
off += len(b64)
|
||||
case `dns:"domain-name"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, 1 and 2 or used for addresses
|
||||
x := val.Field(2).Uint()
|
||||
if x == 1 || x == 2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case `dns:"cdomain-name"`:
|
||||
if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case `dns:"size-base32"`:
|
||||
// This is purely for NSEC3 atm, the previous byte must
|
||||
// holds the length of the encoded string. As NSEC3
|
||||
// is only defined to SHA1, the hashlength is 20 (160 bits)
|
||||
msg[off-1] = 20
|
||||
fallthrough
|
||||
case `dns:"base32"`:
|
||||
b32, err := fromBase32([]byte(s))
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
if off+len(b32) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing base32"}
|
||||
}
|
||||
copy(msg[off:off+len(b32)], b32)
|
||||
off += len(b32)
|
||||
case `dns:"size-hex"`:
|
||||
fallthrough
|
||||
case `dns:"hex"`:
|
||||
// There is no length encoded here
|
||||
h, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
if off+hex.DecodedLen(len(s)) > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow packing hex"}
|
||||
}
|
||||
copy(msg[off:off+hex.DecodedLen(len(s))], h)
|
||||
off += hex.DecodedLen(len(s))
|
||||
case `dns:"size"`:
|
||||
// TODO(miek): WTF? size?
|
||||
// the size is already encoded in the RR, we can safely use the
|
||||
// length of string. String is RAW (not encoded in hex, nor base64)
|
||||
copy(msg[off:off+len(s)], s)
|
||||
off += len(s)
|
||||
case `dns:"octet"`:
|
||||
bytesTmp := make([]byte, 256)
|
||||
off, err = packOctetString(fv.String(), msg, off, bytesTmp)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case `dns:"txt"`:
|
||||
fallthrough
|
||||
case "":
|
||||
if txtTmp == nil {
|
||||
txtTmp = make([]byte, 256*4+1)
|
||||
}
|
||||
off, err = packTxtString(fv.String(), msg, off, txtTmp)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func structValue(any interface{}) reflect.Value {
|
||||
return reflect.ValueOf(any).Elem()
|
||||
}
|
||||
|
||||
// PackStruct packs any structure to wire format.
|
||||
func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
|
||||
off, err = packStructValue(structValue(any), msg, off, nil, false)
|
||||
return off, err
|
||||
}
|
||||
|
||||
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
|
||||
off, err = packStructValue(structValue(any), msg, off, compression, compress)
|
||||
return off, err
|
||||
}
|
||||
|
||||
// Unpack a reflect.StructValue from msg.
|
||||
// Same restrictions as packStructValue.
|
||||
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
|
||||
lenmsg := len(msg)
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
if off > lenmsg {
|
||||
return lenmsg, &Error{"bad offset unpacking"}
|
||||
}
|
||||
switch fv := val.Field(i); fv.Kind() {
|
||||
default:
|
||||
return lenmsg, &Error{err: "bad kind unpacking"}
|
||||
case reflect.Interface:
|
||||
// PrivateRR is the only RR implementation that has interface field.
|
||||
// therefore it's expected that this interface would be PrivateRdata
|
||||
switch data := fv.Interface().(type) {
|
||||
case PrivateRdata:
|
||||
n, err := data.Unpack(msg[off:])
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
off += n
|
||||
default:
|
||||
return lenmsg, &Error{err: "bad kind interface unpacking"}
|
||||
}
|
||||
case reflect.Slice:
|
||||
switch val.Type().Field(i).Tag {
|
||||
default:
|
||||
return lenmsg, &Error{"bad tag unpacking slice: " + val.Type().Field(i).Tag.Get("dns")}
|
||||
case `dns:"domain-name"`:
|
||||
// HIP record slice of name (or none)
|
||||
var servers []string
|
||||
var s string
|
||||
for off < lenmsg {
|
||||
s, off, err = UnpackDomainName(msg, off)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
}
|
||||
fv.Set(reflect.ValueOf(servers))
|
||||
case `dns:"a"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, must be 1
|
||||
if val.Field(2).Uint() != 1 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if off == lenmsg {
|
||||
break // dyn. update
|
||||
}
|
||||
if off+net.IPv4len > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking a"}
|
||||
}
|
||||
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
|
||||
off += net.IPv4len
|
||||
case `dns:"aaaa"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, must be 2
|
||||
if val.Field(2).Uint() != 2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
if off+net.IPv6len > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking aaaa"}
|
||||
}
|
||||
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
|
||||
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
|
||||
msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]}))
|
||||
off += net.IPv6len
|
||||
case `dns:"nsec"`: // NSEC/NSEC3
|
||||
if off == len(msg) {
|
||||
break
|
||||
}
|
||||
// Rest of the record is the type bitmap
|
||||
var nsec []uint16
|
||||
length := 0
|
||||
window := 0
|
||||
lastwindow := -1
|
||||
for off < len(msg) {
|
||||
if off+2 > len(msg) {
|
||||
return len(msg), &Error{err: "overflow unpacking nsecx"}
|
||||
}
|
||||
window = int(msg[off])
|
||||
length = int(msg[off+1])
|
||||
off += 2
|
||||
if window <= lastwindow {
|
||||
// RFC 4034: Blocks are present in the NSEC RR RDATA in
|
||||
// increasing numerical order.
|
||||
return len(msg), &Error{err: "out of order NSEC block"}
|
||||
}
|
||||
if length == 0 {
|
||||
// RFC 4034: Blocks with no types present MUST NOT be included.
|
||||
return len(msg), &Error{err: "empty NSEC block"}
|
||||
}
|
||||
if length > 32 {
|
||||
return len(msg), &Error{err: "NSEC block too long"}
|
||||
}
|
||||
if off+length > len(msg) {
|
||||
return len(msg), &Error{err: "overflowing NSEC block"}
|
||||
}
|
||||
|
||||
// Walk the bytes in the window and extract the type bits
|
||||
for j := 0; j < length; j++ {
|
||||
b := msg[off+j]
|
||||
// Check the bits one by one, and set the type
|
||||
if b&0x80 == 0x80 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+0))
|
||||
}
|
||||
if b&0x40 == 0x40 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+1))
|
||||
}
|
||||
if b&0x20 == 0x20 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+2))
|
||||
}
|
||||
if b&0x10 == 0x10 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+3))
|
||||
}
|
||||
if b&0x8 == 0x8 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+4))
|
||||
}
|
||||
if b&0x4 == 0x4 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+5))
|
||||
}
|
||||
if b&0x2 == 0x2 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+6))
|
||||
}
|
||||
if b&0x1 == 0x1 {
|
||||
nsec = append(nsec, uint16(window*256+j*8+7))
|
||||
}
|
||||
}
|
||||
off += length
|
||||
lastwindow = window
|
||||
}
|
||||
fv.Set(reflect.ValueOf(nsec))
|
||||
}
|
||||
case reflect.Struct:
|
||||
off, err = unpackStructValue(fv, msg, off)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
if val.Type().Field(i).Name == "Hdr" {
|
||||
lenrd := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
|
||||
if lenrd > lenmsg {
|
||||
return lenmsg, &Error{err: "overflowing header size"}
|
||||
}
|
||||
msg = msg[:lenrd]
|
||||
lenmsg = len(msg)
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
if off+1 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint8"}
|
||||
}
|
||||
fv.SetUint(uint64(uint8(msg[off])))
|
||||
off++
|
||||
case reflect.Uint16:
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
var i uint16
|
||||
if off+2 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint16"}
|
||||
}
|
||||
i = binary.BigEndian.Uint16(msg[off:])
|
||||
off += 2
|
||||
fv.SetUint(uint64(i))
|
||||
case reflect.Uint32:
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
if off+4 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint32"}
|
||||
}
|
||||
fv.SetUint(uint64(binary.BigEndian.Uint32(msg[off:])))
|
||||
off += 4
|
||||
case reflect.Uint64:
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
switch val.Type().Field(i).Tag {
|
||||
default:
|
||||
if off+8 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint64"}
|
||||
}
|
||||
fv.SetUint(binary.BigEndian.Uint64(msg[off:]))
|
||||
off += 8
|
||||
case `dns:"uint48"`:
|
||||
// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
|
||||
if off+6 > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking uint64 as uint48"}
|
||||
}
|
||||
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
|
||||
uint64(msg[off+4])<<8 | uint64(msg[off+5])))
|
||||
off += 6
|
||||
}
|
||||
case reflect.String:
|
||||
var s string
|
||||
if off == lenmsg {
|
||||
break
|
||||
}
|
||||
switch val.Type().Field(i).Tag {
|
||||
default:
|
||||
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
|
||||
case `dns:"octet"`:
|
||||
s = string(msg[off:])
|
||||
off = lenmsg
|
||||
case `dns:"hex"`:
|
||||
hexend := lenmsg
|
||||
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
|
||||
hexend = off + int(val.FieldByName("HitLength").Uint())
|
||||
}
|
||||
if hexend > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking HIP hex"}
|
||||
}
|
||||
s = hex.EncodeToString(msg[off:hexend])
|
||||
off = hexend
|
||||
case `dns:"base64"`:
|
||||
// Rest of the RR is base64 encoded value
|
||||
b64end := lenmsg
|
||||
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
|
||||
b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
|
||||
}
|
||||
if b64end > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking HIP base64"}
|
||||
}
|
||||
s = toBase64(msg[off:b64end])
|
||||
off = b64end
|
||||
case `dns:"cdomain-name"`:
|
||||
fallthrough
|
||||
case `dns:"domain-name"`:
|
||||
if val.Type().String() == "dns.IPSECKEY" {
|
||||
// Field(2) is GatewayType, 1 and 2 or used for addresses
|
||||
x := val.Field(2).Uint()
|
||||
if x == 1 || x == 2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 {
|
||||
// zero rdata is ok for dyn updates, but only if rdlength is 0
|
||||
break
|
||||
}
|
||||
s, off, err = UnpackDomainName(msg, off)
|
||||
if err != nil {
|
||||
return lenmsg, err
|
||||
}
|
||||
case `dns:"size-base32"`:
|
||||
var size int
|
||||
switch val.Type().Name() {
|
||||
case "NSEC3":
|
||||
switch val.Type().Field(i).Name {
|
||||
case "NextDomain":
|
||||
name := val.FieldByName("HashLength")
|
||||
size = int(name.Uint())
|
||||
}
|
||||
}
|
||||
if off+size > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking base32"}
|
||||
}
|
||||
s = toBase32(msg[off : off+size])
|
||||
off += size
|
||||
case `dns:"size-hex"`:
|
||||
// a "size" string, but it must be encoded in hex in the string
|
||||
var size int
|
||||
switch val.Type().Name() {
|
||||
case "NSEC3":
|
||||
switch val.Type().Field(i).Name {
|
||||
case "Salt":
|
||||
name := val.FieldByName("SaltLength")
|
||||
size = int(name.Uint())
|
||||
case "NextDomain":
|
||||
name := val.FieldByName("HashLength")
|
||||
size = int(name.Uint())
|
||||
}
|
||||
case "TSIG":
|
||||
switch val.Type().Field(i).Name {
|
||||
case "MAC":
|
||||
name := val.FieldByName("MACSize")
|
||||
size = int(name.Uint())
|
||||
case "OtherData":
|
||||
name := val.FieldByName("OtherLen")
|
||||
size = int(name.Uint())
|
||||
}
|
||||
}
|
||||
if off+size > lenmsg {
|
||||
return lenmsg, &Error{err: "overflow unpacking hex"}
|
||||
}
|
||||
s = hex.EncodeToString(msg[off : off+size])
|
||||
off += size
|
||||
case `dns:"txt"`:
|
||||
fallthrough
|
||||
case "":
|
||||
s, off, err = unpackTxtString(msg, off)
|
||||
}
|
||||
fv.SetString(s)
|
||||
}
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
// Helpers for dealing with escaped bytes
|
||||
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
|
||||
|
||||
@ -1141,12 +553,6 @@ func dddToByte(s []byte) byte {
|
||||
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
|
||||
}
|
||||
|
||||
// UnpackStruct unpacks a binary message from offset off to the interface
|
||||
// value given.
|
||||
func UnpackStruct(any interface{}, msg []byte, off int) (int, error) {
|
||||
return unpackStructValue(structValue(any), msg, off)
|
||||
}
|
||||
|
||||
// Helper function for packing and unpacking
|
||||
func intToBytes(i *big.Int, length int) []byte {
|
||||
buf := i.Bytes()
|
||||
|
@ -61,7 +61,7 @@ func main() {
|
||||
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
|
||||
continue
|
||||
}
|
||||
if name == "PrivateRR" || name == "WKS" {
|
||||
if name == "PrivateRR" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
10
nsecx.go
10
nsecx.go
@ -17,7 +17,7 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
|
||||
saltwire := new(saltWireFmt)
|
||||
saltwire.Salt = salt
|
||||
wire := make([]byte, DefaultMsgSize)
|
||||
n, err := PackStruct(saltwire, wire, 0)
|
||||
n, err := packSaltWire(saltwire, wire)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@ -110,3 +110,11 @@ func (rr *NSEC3) Match(name string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func packSaltWire(sw *saltWireFmt, msg []byte) (int, error) {
|
||||
off, err := packStringHex(sw.Salt, msg, 0)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
139
tsig.go
139
tsig.go
@ -72,8 +72,7 @@ type tsigWireFmt struct {
|
||||
OtherData string `dns:"size-hex:OtherLen"`
|
||||
}
|
||||
|
||||
// If we have the MAC use this type to convert it to wiredata.
|
||||
// Section 3.4.3. Request MAC
|
||||
// If we have the MAC use this type to convert it to wiredata. Section 3.4.3. Request MAC
|
||||
type macWireFmt struct {
|
||||
MACSize uint16
|
||||
MAC string `dns:"size-hex:MACSize"`
|
||||
@ -213,7 +212,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
|
||||
m.MACSize = uint16(len(requestMAC) / 2)
|
||||
m.MAC = requestMAC
|
||||
buf = make([]byte, len(requestMAC)) // long enough
|
||||
n, _ := PackStruct(m, buf, 0)
|
||||
n, _ := packMacWire(m, buf)
|
||||
buf = buf[:n]
|
||||
}
|
||||
|
||||
@ -222,7 +221,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
|
||||
tsig := new(timerWireFmt)
|
||||
tsig.TimeSigned = rr.TimeSigned
|
||||
tsig.Fudge = rr.Fudge
|
||||
n, _ := PackStruct(tsig, tsigvar, 0)
|
||||
n, _ := packTimerWire(tsig, tsigvar)
|
||||
tsigvar = tsigvar[:n]
|
||||
} else {
|
||||
tsig := new(tsigWireFmt)
|
||||
@ -235,7 +234,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
|
||||
tsig.Error = rr.Error
|
||||
tsig.OtherLen = rr.OtherLen
|
||||
tsig.OtherData = rr.OtherData
|
||||
n, _ := PackStruct(tsig, tsigvar, 0)
|
||||
n, _ := packTsigWire(tsig, tsigvar)
|
||||
tsigvar = tsigvar[:n]
|
||||
}
|
||||
|
||||
@ -250,57 +249,51 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
|
||||
|
||||
// Strip the TSIG from the raw message.
|
||||
func stripTsig(msg []byte) ([]byte, *TSIG, error) {
|
||||
// Copied from msg.go's Unpack()
|
||||
// Header.
|
||||
var dh Header
|
||||
var err error
|
||||
dns := new(Msg)
|
||||
rr := new(TSIG)
|
||||
off := 0
|
||||
tsigoff := 0
|
||||
if off, err = UnpackStruct(&dh, msg, off); err != nil {
|
||||
// Copied from msg.go's Unpack() Header, but modified.
|
||||
var (
|
||||
dh Header
|
||||
err error
|
||||
)
|
||||
off, tsigoff := 0, 0
|
||||
|
||||
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if dh.Arcount == 0 {
|
||||
return nil, nil, ErrNoSig
|
||||
}
|
||||
|
||||
// Rcode, see msg.go Unpack()
|
||||
if int(dh.Bits&0xF) == RcodeNotAuth {
|
||||
return nil, nil, ErrAuth
|
||||
}
|
||||
|
||||
// Arrays.
|
||||
dns.Question = make([]Question, dh.Qdcount)
|
||||
dns.Answer = make([]RR, dh.Ancount)
|
||||
dns.Ns = make([]RR, dh.Nscount)
|
||||
dns.Extra = make([]RR, dh.Arcount)
|
||||
for i := 0; i < int(dh.Qdcount); i++ {
|
||||
_, off, err = unpackQuestion(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(dns.Question); i++ {
|
||||
off, err = UnpackStruct(&dns.Question[i], msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_, off, err = unpackRRslice(int(dh.Ancount), msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := 0; i < len(dns.Answer); i++ {
|
||||
dns.Answer[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
_, off, err = unpackRRslice(int(dh.Nscount), msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := 0; i < len(dns.Ns); i++ {
|
||||
dns.Ns[i], off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(dns.Extra); i++ {
|
||||
|
||||
rr := new(TSIG)
|
||||
var extra RR
|
||||
for i := 0; i < int(dh.Arcount); i++ {
|
||||
tsigoff = off
|
||||
dns.Extra[i], off, err = UnpackRR(msg, off)
|
||||
extra, off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if dns.Extra[i].Header().Rrtype == TypeTSIG {
|
||||
rr = dns.Extra[i].(*TSIG)
|
||||
if extra.Header().Rrtype == TypeTSIG {
|
||||
rr = extra.(*TSIG)
|
||||
// Adjust Arcount.
|
||||
arcount := binary.BigEndian.Uint16(msg[10:])
|
||||
binary.BigEndian.PutUint16(msg[10:], arcount-1)
|
||||
@ -319,3 +312,71 @@ func tsigTimeToString(t uint64) string {
|
||||
ti := time.Unix(int64(t), 0).UTC()
|
||||
return ti.Format("20060102150405")
|
||||
}
|
||||
|
||||
func packTsigWire(tw *tsigWireFmt, msg []byte) (int, error) {
|
||||
// copied from zmsg.go TSIG packing
|
||||
// RR_Header
|
||||
off, err := PackDomainName(tw.Name, msg, 0, nil, false)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(tw.Class, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint32(tw.Ttl, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
|
||||
off, err = PackDomainName(tw.Algorithm, msg, off, nil, false)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint48(tw.TimeSigned, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(tw.Fudge, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
|
||||
off, err = packUint16(tw.Error, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(tw.OtherLen, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packStringHex(tw.OtherData, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func packMacWire(mw *macWireFmt, msg []byte) (int, error) {
|
||||
off, err := packUint16(mw.MACSize, msg, 0)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packStringHex(mw.MAC, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
||||
func packTimerWire(tw *timerWireFmt, msg []byte) (int, error) {
|
||||
off, err := packUint48(tw.TimeSigned, msg, 0)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
off, err = packUint16(tw.Fudge, msg, off)
|
||||
if err != nil {
|
||||
return off, err
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user