Remove reflection (#376)

Everything is generated. Remove all uses of packStruct/unpackStruct and
make the library reflectionless.
This commit is contained in:
Miek Gieben 2016-06-12 21:06:46 +01:00 committed by GitHub
parent dbffa4b057
commit b51e305bc6
8 changed files with 191 additions and 649 deletions

View File

@ -284,9 +284,11 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
p = p[:n]
if hdr != nil {
if _, err = UnpackStruct(hdr, p, 0); err != nil {
dh, _, err := unpackMsgHdr(p, 0)
if err != nil {
return nil, err
}
*hdr = dh
}
return p, err
}

13
dns.go
View File

@ -84,19 +84,22 @@ func (h *RR_Header) len() int {
return l
}
// ToRFC3597 converts a known RR to the unknown RR representation
// from RFC 3597.
// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
func (rr *RFC3597) ToRFC3597(r RR) error {
buf := make([]byte, r.len()*2)
off, err := PackStruct(r, buf, 0)
off, err := PackRR(r, buf, 0, nil, false)
if err != nil {
return err
}
buf = buf[:off]
rawSetRdlength(buf, 0, off)
_, err = UnpackStruct(rr, buf, 0)
if int(r.Header().Rdlength) > off {
return ErrBuf
}
rfc3597, _, err := unpackRFC3597(*r.Header(), buf, off-int(r.Header().Rdlength))
if err != nil {
return err
}
*rr = *rfc3597.(*RFC3597)
return nil
}

View File

@ -315,7 +315,13 @@ func TestToRFC3597(t *testing.T) {
x := new(RFC3597)
x.ToRFC3597(a)
if x.String() != `miek.nl. 3600 CLASS1 TYPE1 \# 4 0a000101` {
t.Error("string mismatch")
t.Errorf("string mismatch, got: %s", x)
}
b, _ := NewRR("miek.nl. IN MX 10 mx.miek.nl.")
x.ToRFC3597(b)
if x.String() != `miek.nl. 3600 CLASS1 TYPE15 \# 14 000a026d78046d69656b026e6c00` {
t.Errorf("string mismatch, got: %s", x)
}
}

View File

@ -104,9 +104,7 @@ const (
ZONE = 1 << 8
)
// The RRSIG needs to be converted to wireformat with some of
// the rdata (the signature) missing. Use this struct to ease
// the conversion (and re-use the pack/unpack functions).
// The RRSIG needs to be converted to wireformat with some of the rdata (the signature) missing.
type rrsigWireFmt struct {
TypeCovered uint16
Algorithm uint8
@ -155,7 +153,7 @@ func (k *DNSKEY) KeyTag() uint16 {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(keywire, wire, 0)
n, err := packKeyWire(keywire, wire)
if err != nil {
return 0
}
@ -193,7 +191,7 @@ func (k *DNSKEY) ToDS(h uint8) *DS {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(keywire, wire, 0)
n, err := packKeyWire(keywire, wire)
if err != nil {
return nil
}
@ -290,7 +288,7 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
// Create the desired binary blob
signdata := make([]byte, DefaultMsgSize)
n, err := PackStruct(sigwire, signdata, 0)
n, err := packSigWire(sigwire, signdata)
if err != nil {
return err
}
@ -408,7 +406,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob
signeddata := make([]byte, DefaultMsgSize)
n, err := PackStruct(sigwire, signeddata, 0)
n, err := packSigWire(sigwire, signeddata)
if err != nil {
return err
}
@ -663,3 +661,61 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
}
return buf, nil
}
func packSigWire(sw *rrsigWireFmt, msg []byte) (int, error) {
// copied from zmsg.go RRSIG packing
off, err := packUint16(sw.TypeCovered, msg, 0)
if err != nil {
return off, err
}
off, err = packUint8(sw.Algorithm, msg, off)
if err != nil {
return off, err
}
off, err = packUint8(sw.Labels, msg, off)
if err != nil {
return off, err
}
off, err = packUint32(sw.OrigTtl, msg, off)
if err != nil {
return off, err
}
off, err = packUint32(sw.Expiration, msg, off)
if err != nil {
return off, err
}
off, err = packUint32(sw.Inception, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(sw.KeyTag, msg, off)
if err != nil {
return off, err
}
off, err = PackDomainName(sw.SignerName, msg, off, nil, false)
if err != nil {
return off, err
}
return off, nil
}
func packKeyWire(dw *dnskeyWireFmt, msg []byte) (int, error) {
// copied from zmsg.go DNSKEY packing
off, err := packUint16(dw.Flags, msg, 0)
if err != nil {
return off, err
}
off, err = packUint8(dw.Protocol, msg, off)
if err != nil {
return off, err
}
off, err = packUint8(dw.Algorithm, msg, off)
if err != nil {
return off, err
}
off, err = packStringBase64(dw.PublicKey, msg, off)
if err != nil {
return off, err
}
return off, nil
}

594
msg.go
View File

@ -13,11 +13,8 @@ package dns
import (
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"math/big"
"math/rand"
"net"
"reflect"
"strconv"
)
@ -549,591 +546,6 @@ func unpackTxtString(msg []byte, offset int) (string, int, error) {
return string(s), offset, nil
}
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
// slices and other (often anonymous) structs.
func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
var txtTmp []byte
lenmsg := len(msg)
numfield := val.NumField()
for i := 0; i < numfield; i++ {
typefield := val.Type().Field(i)
if typefield.Tag == `dns:"-"` {
continue
}
switch fv := val.Field(i); fv.Kind() {
default:
return lenmsg, &Error{err: "bad kind packing"}
case reflect.Interface:
// PrivateRR is the only RR implementation that has interface field.
// therefore it's expected that this interface would be PrivateRdata
switch data := fv.Interface().(type) {
case PrivateRdata:
n, err := data.Pack(msg[off:])
if err != nil {
return lenmsg, err
}
off += n
default:
return lenmsg, &Error{err: "bad kind interface packing"}
}
case reflect.Slice:
switch typefield.Tag {
default:
return lenmsg, &Error{"bad tag packing slice: " + typefield.Tag.Get("dns")}
case `dns:"domain-name"`:
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String()
off, err = PackDomainName(element, msg, off, compression, false && compress)
if err != nil {
return lenmsg, err
}
}
case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
// It must be a slice of 4, even if it is 16, we encode
// only the first 4
if off+net.IPv4len > lenmsg {
return lenmsg, &Error{err: "overflow packing a"}
}
switch fv.Len() {
case net.IPv6len:
msg[off] = byte(fv.Index(12).Uint())
msg[off+1] = byte(fv.Index(13).Uint())
msg[off+2] = byte(fv.Index(14).Uint())
msg[off+3] = byte(fv.Index(15).Uint())
off += net.IPv4len
case net.IPv4len:
msg[off] = byte(fv.Index(0).Uint())
msg[off+1] = byte(fv.Index(1).Uint())
msg[off+2] = byte(fv.Index(2).Uint())
msg[off+3] = byte(fv.Index(3).Uint())
off += net.IPv4len
case 0:
// Allowed, for dynamic updates
default:
return lenmsg, &Error{err: "overflow packing a"}
}
case `dns:"aaaa"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if fv.Len() == 0 {
break
}
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
return lenmsg, &Error{err: "overflow packing aaaa"}
}
for j := 0; j < net.IPv6len; j++ {
msg[off] = byte(fv.Index(j).Uint())
off++
}
case `dns:"nsec"`: // NSEC/NSEC3
// This is the uint16 type bitmap
if val.Field(i).Len() == 0 {
// Do absolutely nothing
break
}
var lastwindow, lastlength uint16
for j := 0; j < val.Field(i).Len(); j++ {
t := uint16(fv.Index(j).Uint())
window := t / 256
length := (t-window*256)/8 + 1
if window > lastwindow && lastlength != 0 {
// New window, jump to the new offset
off += int(lastlength) + 2
lastlength = 0
}
if window < lastwindow || length < lastlength {
return len(msg), &Error{err: "nsec bits out of order"}
}
if off+2+int(length) > len(msg) {
return len(msg), &Error{err: "overflow packing nsec"}
}
// Setting the window #
msg[off] = byte(window)
// Setting the octets length
msg[off+1] = byte(length)
// Setting the bit value for the type in the right octet
msg[off+1+int(length)] |= byte(1 << (7 - (t % 8)))
lastwindow, lastlength = window, length
}
off += int(lastlength) + 2
}
case reflect.Struct:
off, err = packStructValue(fv, msg, off, compression, compress)
if err != nil {
return lenmsg, err
}
case reflect.Uint8:
if off+1 > lenmsg {
return lenmsg, &Error{err: "overflow packing uint8"}
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow packing uint16"}
}
i := fv.Uint()
msg[off] = byte(i >> 8)
msg[off+1] = byte(i)
off += 2
case reflect.Uint32:
if off+4 > lenmsg {
return lenmsg, &Error{err: "overflow packing uint32"}
}
i := fv.Uint()
msg[off] = byte(i >> 24)
msg[off+1] = byte(i >> 16)
msg[off+2] = byte(i >> 8)
msg[off+3] = byte(i)
off += 4
case reflect.Uint64:
switch typefield.Tag {
default:
if off+8 > lenmsg {
return lenmsg, &Error{err: "overflow packing uint64"}
}
i := fv.Uint()
msg[off] = byte(i >> 56)
msg[off+1] = byte(i >> 48)
msg[off+2] = byte(i >> 40)
msg[off+3] = byte(i >> 32)
msg[off+4] = byte(i >> 24)
msg[off+5] = byte(i >> 16)
msg[off+6] = byte(i >> 8)
msg[off+7] = byte(i)
off += 8
case `dns:"uint48"`:
// Used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > lenmsg {
return lenmsg, &Error{err: "overflow packing uint64 as uint48"}
}
i := fv.Uint()
msg[off] = byte(i >> 40)
msg[off+1] = byte(i >> 32)
msg[off+2] = byte(i >> 24)
msg[off+3] = byte(i >> 16)
msg[off+4] = byte(i >> 8)
msg[off+5] = byte(i)
off += 6
}
case reflect.String:
// There are multiple string encodings.
// The tag distinguishes ordinary strings from domain names.
s := fv.String()
switch typefield.Tag {
default:
return lenmsg, &Error{"bad tag packing string: " + typefield.Tag.Get("dns")}
case `dns:"base64"`:
b64, err := fromBase64([]byte(s))
if err != nil {
return lenmsg, err
}
if off+len(b64) > lenmsg {
return lenmsg, &Error{err: "overflow packing base64"}
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
return lenmsg, err
}
case `dns:"cdomain-name"`:
if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
return lenmsg, err
}
case `dns:"size-base32"`:
// This is purely for NSEC3 atm, the previous byte must
// holds the length of the encoded string. As NSEC3
// is only defined to SHA1, the hashlength is 20 (160 bits)
msg[off-1] = 20
fallthrough
case `dns:"base32"`:
b32, err := fromBase32([]byte(s))
if err != nil {
return lenmsg, err
}
if off+len(b32) > lenmsg {
return lenmsg, &Error{err: "overflow packing base32"}
}
copy(msg[off:off+len(b32)], b32)
off += len(b32)
case `dns:"size-hex"`:
fallthrough
case `dns:"hex"`:
// There is no length encoded here
h, err := hex.DecodeString(s)
if err != nil {
return lenmsg, err
}
if off+hex.DecodedLen(len(s)) > lenmsg {
return lenmsg, &Error{err: "overflow packing hex"}
}
copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s))
case `dns:"size"`:
// TODO(miek): WTF? size?
// the size is already encoded in the RR, we can safely use the
// length of string. String is RAW (not encoded in hex, nor base64)
copy(msg[off:off+len(s)], s)
off += len(s)
case `dns:"octet"`:
bytesTmp := make([]byte, 256)
off, err = packOctetString(fv.String(), msg, off, bytesTmp)
if err != nil {
return lenmsg, err
}
case `dns:"txt"`:
fallthrough
case "":
if txtTmp == nil {
txtTmp = make([]byte, 256*4+1)
}
off, err = packTxtString(fv.String(), msg, off, txtTmp)
if err != nil {
return lenmsg, err
}
}
}
}
return off, nil
}
func structValue(any interface{}) reflect.Value {
return reflect.ValueOf(any).Elem()
}
// PackStruct packs any structure to wire format.
func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
off, err = packStructValue(structValue(any), msg, off, nil, false)
return off, err
}
func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off, err = packStructValue(structValue(any), msg, off, compression, compress)
return off, err
}
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
if off > lenmsg {
return lenmsg, &Error{"bad offset unpacking"}
}
switch fv := val.Field(i); fv.Kind() {
default:
return lenmsg, &Error{err: "bad kind unpacking"}
case reflect.Interface:
// PrivateRR is the only RR implementation that has interface field.
// therefore it's expected that this interface would be PrivateRdata
switch data := fv.Interface().(type) {
case PrivateRdata:
n, err := data.Unpack(msg[off:])
if err != nil {
return lenmsg, err
}
off += n
default:
return lenmsg, &Error{err: "bad kind interface unpacking"}
}
case reflect.Slice:
switch val.Type().Field(i).Tag {
default:
return lenmsg, &Error{"bad tag unpacking slice: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"domain-name"`:
// HIP record slice of name (or none)
var servers []string
var s string
for off < lenmsg {
s, off, err = UnpackDomainName(msg, off)
if err != nil {
return lenmsg, err
}
servers = append(servers, s)
}
fv.Set(reflect.ValueOf(servers))
case `dns:"a"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 1
if val.Field(2).Uint() != 1 {
continue
}
}
if off == lenmsg {
break // dyn. update
}
if off+net.IPv4len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking a"}
}
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len
case `dns:"aaaa"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, must be 2
if val.Field(2).Uint() != 2 {
continue
}
}
if off == lenmsg {
break
}
if off+net.IPv6len > lenmsg {
return lenmsg, &Error{err: "overflow unpacking aaaa"}
}
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
msg[off+11], msg[off+12], msg[off+13], msg[off+14], msg[off+15]}))
off += net.IPv6len
case `dns:"nsec"`: // NSEC/NSEC3
if off == len(msg) {
break
}
// Rest of the record is the type bitmap
var nsec []uint16
length := 0
window := 0
lastwindow := -1
for off < len(msg) {
if off+2 > len(msg) {
return len(msg), &Error{err: "overflow unpacking nsecx"}
}
window = int(msg[off])
length = int(msg[off+1])
off += 2
if window <= lastwindow {
// RFC 4034: Blocks are present in the NSEC RR RDATA in
// increasing numerical order.
return len(msg), &Error{err: "out of order NSEC block"}
}
if length == 0 {
// RFC 4034: Blocks with no types present MUST NOT be included.
return len(msg), &Error{err: "empty NSEC block"}
}
if length > 32 {
return len(msg), &Error{err: "NSEC block too long"}
}
if off+length > len(msg) {
return len(msg), &Error{err: "overflowing NSEC block"}
}
// Walk the bytes in the window and extract the type bits
for j := 0; j < length; j++ {
b := msg[off+j]
// Check the bits one by one, and set the type
if b&0x80 == 0x80 {
nsec = append(nsec, uint16(window*256+j*8+0))
}
if b&0x40 == 0x40 {
nsec = append(nsec, uint16(window*256+j*8+1))
}
if b&0x20 == 0x20 {
nsec = append(nsec, uint16(window*256+j*8+2))
}
if b&0x10 == 0x10 {
nsec = append(nsec, uint16(window*256+j*8+3))
}
if b&0x8 == 0x8 {
nsec = append(nsec, uint16(window*256+j*8+4))
}
if b&0x4 == 0x4 {
nsec = append(nsec, uint16(window*256+j*8+5))
}
if b&0x2 == 0x2 {
nsec = append(nsec, uint16(window*256+j*8+6))
}
if b&0x1 == 0x1 {
nsec = append(nsec, uint16(window*256+j*8+7))
}
}
off += length
lastwindow = window
}
fv.Set(reflect.ValueOf(nsec))
}
case reflect.Struct:
off, err = unpackStructValue(fv, msg, off)
if err != nil {
return lenmsg, err
}
if val.Type().Field(i).Name == "Hdr" {
lenrd := off + int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
if lenrd > lenmsg {
return lenmsg, &Error{err: "overflowing header size"}
}
msg = msg[:lenrd]
lenmsg = len(msg)
}
case reflect.Uint8:
if off == lenmsg {
break
}
if off+1 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint8"}
}
fv.SetUint(uint64(uint8(msg[off])))
off++
case reflect.Uint16:
if off == lenmsg {
break
}
var i uint16
if off+2 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint16"}
}
i = binary.BigEndian.Uint16(msg[off:])
off += 2
fv.SetUint(uint64(i))
case reflect.Uint32:
if off == lenmsg {
break
}
if off+4 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint32"}
}
fv.SetUint(uint64(binary.BigEndian.Uint32(msg[off:])))
off += 4
case reflect.Uint64:
if off == lenmsg {
break
}
switch val.Type().Field(i).Tag {
default:
if off+8 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint64"}
}
fv.SetUint(binary.BigEndian.Uint64(msg[off:]))
off += 8
case `dns:"uint48"`:
// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
if off+6 > lenmsg {
return lenmsg, &Error{err: "overflow unpacking uint64 as uint48"}
}
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])))
off += 6
}
case reflect.String:
var s string
if off == lenmsg {
break
}
switch val.Type().Field(i).Tag {
default:
return lenmsg, &Error{"bad tag unpacking string: " + val.Type().Field(i).Tag.Get("dns")}
case `dns:"octet"`:
s = string(msg[off:])
off = lenmsg
case `dns:"hex"`:
hexend := lenmsg
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
hexend = off + int(val.FieldByName("HitLength").Uint())
}
if hexend > lenmsg {
return lenmsg, &Error{err: "overflow unpacking HIP hex"}
}
s = hex.EncodeToString(msg[off:hexend])
off = hexend
case `dns:"base64"`:
// Rest of the RR is base64 encoded value
b64end := lenmsg
if val.FieldByName("Hdr").FieldByName("Rrtype").Uint() == uint64(TypeHIP) {
b64end = off + int(val.FieldByName("PublicKeyLength").Uint())
}
if b64end > lenmsg {
return lenmsg, &Error{err: "overflow unpacking HIP base64"}
}
s = toBase64(msg[off:b64end])
off = b64end
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
if val.Type().String() == "dns.IPSECKEY" {
// Field(2) is GatewayType, 1 and 2 or used for addresses
x := val.Field(2).Uint()
if x == 1 || x == 2 {
continue
}
}
if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 {
// zero rdata is ok for dyn updates, but only if rdlength is 0
break
}
s, off, err = UnpackDomainName(msg, off)
if err != nil {
return lenmsg, err
}
case `dns:"size-base32"`:
var size int
switch val.Type().Name() {
case "NSEC3":
switch val.Type().Field(i).Name {
case "NextDomain":
name := val.FieldByName("HashLength")
size = int(name.Uint())
}
}
if off+size > lenmsg {
return lenmsg, &Error{err: "overflow unpacking base32"}
}
s = toBase32(msg[off : off+size])
off += size
case `dns:"size-hex"`:
// a "size" string, but it must be encoded in hex in the string
var size int
switch val.Type().Name() {
case "NSEC3":
switch val.Type().Field(i).Name {
case "Salt":
name := val.FieldByName("SaltLength")
size = int(name.Uint())
case "NextDomain":
name := val.FieldByName("HashLength")
size = int(name.Uint())
}
case "TSIG":
switch val.Type().Field(i).Name {
case "MAC":
name := val.FieldByName("MACSize")
size = int(name.Uint())
case "OtherData":
name := val.FieldByName("OtherLen")
size = int(name.Uint())
}
}
if off+size > lenmsg {
return lenmsg, &Error{err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off : off+size])
off += size
case `dns:"txt"`:
fallthrough
case "":
s, off, err = unpackTxtString(msg, off)
}
fv.SetString(s)
}
}
return off, nil
}
// Helpers for dealing with escaped bytes
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
@ -1141,12 +553,6 @@ func dddToByte(s []byte) byte {
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
// UnpackStruct unpacks a binary message from offset off to the interface
// value given.
func UnpackStruct(any interface{}, msg []byte, off int) (int, error) {
return unpackStructValue(structValue(any), msg, off)
}
// Helper function for packing and unpacking
func intToBytes(i *big.Int, length int) []byte {
buf := i.Bytes()

View File

@ -61,7 +61,7 @@ func main() {
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
continue
}
if name == "PrivateRR" || name == "WKS" {
if name == "PrivateRR" {
continue
}

View File

@ -17,7 +17,7 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt)
saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize)
n, err := PackStruct(saltwire, wire, 0)
n, err := packSaltWire(saltwire, wire)
if err != nil {
return ""
}
@ -110,3 +110,11 @@ func (rr *NSEC3) Match(name string) bool {
}
return false
}
func packSaltWire(sw *saltWireFmt, msg []byte) (int, error) {
off, err := packStringHex(sw.Salt, msg, 0)
if err != nil {
return off, err
}
return off, nil
}

139
tsig.go
View File

@ -72,8 +72,7 @@ type tsigWireFmt struct {
OtherData string `dns:"size-hex:OtherLen"`
}
// If we have the MAC use this type to convert it to wiredata.
// Section 3.4.3. Request MAC
// If we have the MAC use this type to convert it to wiredata. Section 3.4.3. Request MAC
type macWireFmt struct {
MACSize uint16
MAC string `dns:"size-hex:MACSize"`
@ -213,7 +212,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
m.MACSize = uint16(len(requestMAC) / 2)
m.MAC = requestMAC
buf = make([]byte, len(requestMAC)) // long enough
n, _ := PackStruct(m, buf, 0)
n, _ := packMacWire(m, buf)
buf = buf[:n]
}
@ -222,7 +221,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
tsig := new(timerWireFmt)
tsig.TimeSigned = rr.TimeSigned
tsig.Fudge = rr.Fudge
n, _ := PackStruct(tsig, tsigvar, 0)
n, _ := packTimerWire(tsig, tsigvar)
tsigvar = tsigvar[:n]
} else {
tsig := new(tsigWireFmt)
@ -235,7 +234,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
tsig.Error = rr.Error
tsig.OtherLen = rr.OtherLen
tsig.OtherData = rr.OtherData
n, _ := PackStruct(tsig, tsigvar, 0)
n, _ := packTsigWire(tsig, tsigvar)
tsigvar = tsigvar[:n]
}
@ -250,57 +249,51 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b
// Strip the TSIG from the raw message.
func stripTsig(msg []byte) ([]byte, *TSIG, error) {
// Copied from msg.go's Unpack()
// Header.
var dh Header
var err error
dns := new(Msg)
rr := new(TSIG)
off := 0
tsigoff := 0
if off, err = UnpackStruct(&dh, msg, off); err != nil {
// Copied from msg.go's Unpack() Header, but modified.
var (
dh Header
err error
)
off, tsigoff := 0, 0
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
return nil, nil, err
}
if dh.Arcount == 0 {
return nil, nil, ErrNoSig
}
// Rcode, see msg.go Unpack()
if int(dh.Bits&0xF) == RcodeNotAuth {
return nil, nil, ErrAuth
}
// Arrays.
dns.Question = make([]Question, dh.Qdcount)
dns.Answer = make([]RR, dh.Ancount)
dns.Ns = make([]RR, dh.Nscount)
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < int(dh.Qdcount); i++ {
_, off, err = unpackQuestion(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Question); i++ {
off, err = UnpackStruct(&dns.Question[i], msg, off)
if err != nil {
return nil, nil, err
}
_, off, err = unpackRRslice(int(dh.Ancount), msg, off)
if err != nil {
return nil, nil, err
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
_, off, err = unpackRRslice(int(dh.Nscount), msg, off)
if err != nil {
return nil, nil, err
}
for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
}
for i := 0; i < len(dns.Extra); i++ {
rr := new(TSIG)
var extra RR
for i := 0; i < int(dh.Arcount); i++ {
tsigoff = off
dns.Extra[i], off, err = UnpackRR(msg, off)
extra, off, err = UnpackRR(msg, off)
if err != nil {
return nil, nil, err
}
if dns.Extra[i].Header().Rrtype == TypeTSIG {
rr = dns.Extra[i].(*TSIG)
if extra.Header().Rrtype == TypeTSIG {
rr = extra.(*TSIG)
// Adjust Arcount.
arcount := binary.BigEndian.Uint16(msg[10:])
binary.BigEndian.PutUint16(msg[10:], arcount-1)
@ -319,3 +312,71 @@ func tsigTimeToString(t uint64) string {
ti := time.Unix(int64(t), 0).UTC()
return ti.Format("20060102150405")
}
func packTsigWire(tw *tsigWireFmt, msg []byte) (int, error) {
// copied from zmsg.go TSIG packing
// RR_Header
off, err := PackDomainName(tw.Name, msg, 0, nil, false)
if err != nil {
return off, err
}
off, err = packUint16(tw.Class, msg, off)
if err != nil {
return off, err
}
off, err = packUint32(tw.Ttl, msg, off)
if err != nil {
return off, err
}
off, err = PackDomainName(tw.Algorithm, msg, off, nil, false)
if err != nil {
return off, err
}
off, err = packUint48(tw.TimeSigned, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(tw.Fudge, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(tw.Error, msg, off)
if err != nil {
return off, err
}
off, err = packUint16(tw.OtherLen, msg, off)
if err != nil {
return off, err
}
off, err = packStringHex(tw.OtherData, msg, off)
if err != nil {
return off, err
}
return off, nil
}
func packMacWire(mw *macWireFmt, msg []byte) (int, error) {
off, err := packUint16(mw.MACSize, msg, 0)
if err != nil {
return off, err
}
off, err = packStringHex(mw.MAC, msg, off)
if err != nil {
return off, err
}
return off, nil
}
func packTimerWire(tw *timerWireFmt, msg []byte) (int, error) {
off, err := packUint48(tw.TimeSigned, msg, 0)
if err != nil {
return off, err
}
off, err = packUint16(tw.Fudge, msg, off)
if err != nil {
return off, err
}
return off, nil
}