diff --git a/client.go b/client.go index 7868cac2..fc4a613d 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package dns import ( "bytes" "crypto/tls" + "encoding/binary" "io" "net" "time" @@ -300,7 +301,7 @@ func tcpMsgLen(t io.Reader) (int, error) { if n != 2 { return 0, ErrShortRead } - l, _ := unpackUint16Msg(p, 0) + l := binary.BigEndian.Uint16(p) if l == 0 { return 0, ErrShortRead } @@ -392,7 +393,7 @@ func (co *Conn) Write(p []byte) (n int, err error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, lp+2) - l[0], l[1] = packUint16Msg(uint16(lp)) + binary.BigEndian.PutUint16(l, uint16(lp)) p = append(l, p...) n, err := io.Copy(w, bytes.NewReader(p)) return int(n), err diff --git a/dnssec.go b/dnssec.go index 94010497..09c16f09 100644 --- a/dnssec.go +++ b/dnssec.go @@ -13,6 +13,7 @@ import ( _ "crypto/sha256" _ "crypto/sha512" "encoding/asn1" + "encoding/binary" "encoding/hex" "math/big" "sort" @@ -144,7 +145,7 @@ func (k *DNSKEY) KeyTag() uint16 { // at the base64 values. But I'm lazy. modulus, _ := fromBase64([]byte(k.PublicKey)) if len(modulus) > 1 { - x, _ := unpackUint16Msg(modulus, len(modulus)-2) + x := binary.BigEndian.Uint16(modulus[len(modulus)-2:]) keytag = int(x) } default: diff --git a/edns.go b/edns.go index 5634cd70..9d94614f 100644 --- a/edns.go +++ b/edns.go @@ -1,6 +1,7 @@ package dns import ( + "encoding/binary" "encoding/hex" "errors" "net" @@ -213,7 +214,7 @@ func (e *EDNS0_SUBNET) Option() uint16 { func (e *EDNS0_SUBNET) pack() ([]byte, error) { b := make([]byte, 4) - b[0], b[1] = packUint16Msg(e.Family) + binary.BigEndian.PutUint16(b[0:], e.Family) b[2] = e.SourceNetmask b[3] = e.SourceScope switch e.Family { @@ -247,7 +248,7 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error { if len(b) < 4 { return ErrBuf } - e.Family, _ = unpackUint16Msg(b, 0) + e.Family = binary.BigEndian.Uint16(b) e.SourceNetmask = b[2] e.SourceScope = b[3] switch e.Family { @@ -339,10 +340,7 @@ func (e *EDNS0_UL) String() string { return strconv.FormatUint(uint64(e.Lease), // Copied: http://golang.org/src/pkg/net/dnsmsg.go func (e *EDNS0_UL) pack() ([]byte, error) { b := make([]byte, 4) - b[0] = byte(e.Lease >> 24) - b[1] = byte(e.Lease >> 16) - b[2] = byte(e.Lease >> 8) - b[3] = byte(e.Lease) + binary.BigEndian.PutUint32(b, e.Lease) return b, nil } @@ -350,7 +348,7 @@ func (e *EDNS0_UL) unpack(b []byte) error { if len(b) < 4 { return ErrBuf } - e.Lease = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + e.Lease = binary.BigEndian.Uint32(b) return nil } @@ -369,21 +367,11 @@ func (e *EDNS0_LLQ) Option() uint16 { return EDNS0LLQ } func (e *EDNS0_LLQ) pack() ([]byte, error) { b := make([]byte, 18) - b[0], b[1] = packUint16Msg(e.Version) - b[2], b[3] = packUint16Msg(e.Opcode) - b[4], b[5] = packUint16Msg(e.Error) - b[6] = byte(e.Id >> 56) - b[7] = byte(e.Id >> 48) - b[8] = byte(e.Id >> 40) - b[9] = byte(e.Id >> 32) - b[10] = byte(e.Id >> 24) - b[11] = byte(e.Id >> 16) - b[12] = byte(e.Id >> 8) - b[13] = byte(e.Id) - b[14] = byte(e.LeaseLife >> 24) - b[15] = byte(e.LeaseLife >> 16) - b[16] = byte(e.LeaseLife >> 8) - b[17] = byte(e.LeaseLife) + binary.BigEndian.PutUint16(b[0:], e.Version) + binary.BigEndian.PutUint16(b[2:], e.Opcode) + binary.BigEndian.PutUint16(b[4:], e.Error) + binary.BigEndian.PutUint64(b[6:], e.Id) + binary.BigEndian.PutUint32(b[14:], e.LeaseLife) return b, nil } @@ -391,12 +379,11 @@ func (e *EDNS0_LLQ) unpack(b []byte) error { if len(b) < 18 { return ErrBuf } - e.Version, _ = unpackUint16Msg(b, 0) - e.Opcode, _ = unpackUint16Msg(b, 2) - e.Error, _ = unpackUint16Msg(b, 4) - e.Id = uint64(b[6])<<56 | uint64(b[6+1])<<48 | uint64(b[6+2])<<40 | - uint64(b[6+3])<<32 | uint64(b[6+4])<<24 | uint64(b[6+5])<<16 | uint64(b[6+6])<<8 | uint64(b[6+7]) - e.LeaseLife = uint32(b[14])<<24 | uint32(b[14+1])<<16 | uint32(b[14+2])<<8 | uint32(b[14+3]) + e.Version = binary.BigEndian.Uint16(b[0:]) + e.Opcode = binary.BigEndian.Uint16(b[2:]) + e.Error = binary.BigEndian.Uint16(b[4:]) + e.Id = binary.BigEndian.Uint64(b[6:]) + e.LeaseLife = binary.BigEndian.Uint32(b[14:]) return nil } @@ -492,7 +479,7 @@ func (e *EDNS0_EXPIRE) unpack(b []byte) error { if len(b) < 4 { return ErrBuf } - e.Expire = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + e.Expire = binary.BigEndian.Uint32(b) return nil } diff --git a/msg.go b/msg.go index 79b40f29..3051d52b 100644 --- a/msg.go +++ b/msg.go @@ -294,7 +294,7 @@ func packDomainName(s string, msg []byte, off int, compression map[string]int, c if pointer != -1 { // We have two bytes (14 bits) to put the pointer in // if msg == nil, we will never do compression - msg[nameoffset], msg[nameoffset+1] = packUint16Msg(uint16(pointer ^ 0xC000)) + binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000)) off = nameoffset + 1 goto End } @@ -603,9 +603,9 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str return lenmsg, &Error{err: "overflow packing opt"} } // Option code - msg[off], msg[off+1] = packUint16Msg(element.(EDNS0).Option()) + binary.BigEndian.PutUint16(msg[off:], element.(EDNS0).Option()) // Length - msg[off+2], msg[off+3] = packUint16Msg(uint16(len(b))) + binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) off += 4 if off+len(b) > lenmsg { copy(msg[off:], b) @@ -920,8 +920,10 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+4 > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } - code, off = unpackUint16Msg(msg, off) - optlen, off1 := unpackUint16Msg(msg, off) + code = binary.BigEndian.Uint16(msg[off:]) + off += 2 + optlen := binary.BigEndian.Uint16(msg[off:]) + off1 := off + 2 if off1+int(optlen) > lenmsg { return lenmsg, &Error{err: "overflow unpacking opt"} } @@ -1126,7 +1128,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+2 > lenmsg { return lenmsg, &Error{err: "overflow unpacking uint16"} } - i, off = unpackUint16Msg(msg, off) + i = binary.BigEndian.Uint16(msg[off:]) + off += 2 fv.SetUint(uint64(i)) case reflect.Uint32: if off == lenmsg { @@ -1135,7 +1138,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+4 > lenmsg { return lenmsg, &Error{err: "overflow unpacking uint32"} } - fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))) + fv.SetUint(uint64(binary.BigEndian.Uint32(msg[off:]))) off += 4 case reflect.Uint64: if off == lenmsg { @@ -1146,8 +1149,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er if off+8 > lenmsg { return lenmsg, &Error{err: "overflow unpacking uint64"} } - fv.SetUint(uint64(uint64(msg[off])<<56 | uint64(msg[off+1])<<48 | uint64(msg[off+2])<<40 | - uint64(msg[off+3])<<32 | uint64(msg[off+4])<<24 | uint64(msg[off+5])<<16 | uint64(msg[off+6])<<8 | uint64(msg[off+7]))) + fv.SetUint(binary.BigEndian.Uint64(msg[off:])) off += 8 case `dns:"uint48"`: // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) diff --git a/msg_helpers.go b/msg_helpers.go index a6d10672..34e5ea1f 100644 --- a/msg_helpers.go +++ b/msg_helpers.go @@ -3,6 +3,7 @@ package dns import ( "encoding/base32" "encoding/base64" + "encoding/binary" "encoding/hex" "net" "strconv" @@ -186,20 +187,6 @@ func fromBase64(s []byte) (buf []byte, err error) { func toBase64(b []byte) string { return base64.StdEncoding.EncodeToString(b) } -func unpackUint16Msg(msg []byte, off int) (uint16, int) { - return uint16(msg[off])<<8 | uint16(msg[off+1]), off + 2 -} - -func packUint16Msg(i uint16) (byte, byte) { return byte(i >> 8), byte(i) } - -func unpackUint32Msg(msg []byte, off int) (uint32, int) { - return uint32(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]))), off + 4 -} - -func packUint32Msg(i uint32) (byte, byte, byte, byte) { - return byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i) -} - // dynamicUpdate returns true if the Rdlength is zero. func noRdata(h RR_Header) bool { return h.Rdlength == 0 } @@ -222,15 +209,14 @@ func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) { if off+2 > len(msg) { return 0, len(msg), &Error{err: "overflow unpacking uint16"} } - i, off = unpackUint16Msg(msg, off) - return i, off, nil + return binary.BigEndian.Uint16(msg[off:]), off + 2, nil } func packUint16(i uint16, msg []byte, off int) (off1 int, err error) { if off+2 > len(msg) { return len(msg), &Error{err: "overflow packing uint16"} } - msg[off], msg[off+1] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg[off:], i) return off + 2, nil } @@ -238,15 +224,14 @@ func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) { if off+4 > len(msg) { return 0, len(msg), &Error{err: "overflow unpacking uint32"} } - i, off = unpackUint32Msg(msg, off) - return i, off, nil + return binary.BigEndian.Uint32(msg[off:]), off + 4, nil } func packUint32(i uint32, msg []byte, off int) (off1 int, err error) { if off+4 > len(msg) { return len(msg), &Error{err: "overflow packing uint32"} } - msg[off], msg[off+1], msg[off+2], msg[off+3] = packUint32Msg(i) + binary.BigEndian.PutUint32(msg[off:], i) return off + 4, nil } @@ -279,24 +264,14 @@ func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) { if off+8 > len(msg) { return 0, len(msg), &Error{err: "overflow unpacking uint64"} } - i = (uint64(uint64(msg[off])<<56 | uint64(msg[off+1])<<48 | uint64(msg[off+2])<<40 | - uint64(msg[off+3])<<32 | uint64(msg[off+4])<<24 | uint64(msg[off+5])<<16 | uint64(msg[off+6])<<8 | uint64(msg[off+7]))) - off += 8 - return i, off, nil + return binary.BigEndian.Uint64(msg[off:]), off + 8, nil } func packUint64(i uint64, msg []byte, off int) (off1 int, err error) { if off+8 > len(msg) { return len(msg), &Error{err: "overflow packing uint64"} } - msg[off] = byte(i >> 56) - msg[off+1] = byte(i >> 48) - msg[off+2] = byte(i >> 40) - msg[off+3] = byte(i >> 32) - msg[off+4] = byte(i >> 24) - msg[off+5] = byte(i >> 16) - msg[off+6] = byte(i >> 8) - msg[off+7] = byte(i) + binary.BigEndian.PutUint64(msg[off:], i) off += 8 return off, nil } diff --git a/rawmsg.go b/rawmsg.go index e4e5374d..1ffed35b 100644 --- a/rawmsg.go +++ b/rawmsg.go @@ -1,5 +1,7 @@ package dns +import "encoding/binary" + // These raw* functions do not use reflection, they directly set the values // in the buffer. There are faster than their reflection counterparts. @@ -8,7 +10,7 @@ func rawSetId(msg []byte, i uint16) bool { if len(msg) < 2 { return false } - msg[0], msg[1] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg, i) return true } @@ -17,7 +19,7 @@ func rawSetQuestionLen(msg []byte, i uint16) bool { if len(msg) < 6 { return false } - msg[4], msg[5] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg[4:], i) return true } @@ -26,7 +28,7 @@ func rawSetAnswerLen(msg []byte, i uint16) bool { if len(msg) < 8 { return false } - msg[6], msg[7] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg[6:], i) return true } @@ -35,7 +37,7 @@ func rawSetNsLen(msg []byte, i uint16) bool { if len(msg) < 10 { return false } - msg[8], msg[9] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg[8:], i) return true } @@ -44,7 +46,7 @@ func rawSetExtraLen(msg []byte, i uint16) bool { if len(msg) < 12 { return false } - msg[10], msg[11] = packUint16Msg(i) + binary.BigEndian.PutUint16(msg[10:], i) return true } @@ -90,6 +92,6 @@ Loop: if rdatalen > 0xFFFF { return false } - msg[off], msg[off+1] = packUint16Msg(uint16(rdatalen)) + binary.BigEndian.PutUint16(msg[off:], uint16(rdatalen)) return true } diff --git a/server.go b/server.go index 158cd3b7..b61dd674 100644 --- a/server.go +++ b/server.go @@ -5,6 +5,7 @@ package dns import ( "bytes" "crypto/tls" + "encoding/binary" "io" "net" "sync" @@ -615,7 +616,7 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } return nil, ErrShortRead } - length, _ := unpackUint16Msg(l, 0) + length := binary.BigEndian.Uint16(l) if length == 0 { return nil, ErrShortRead } @@ -690,7 +691,7 @@ func (w *response) Write(m []byte) (int, error) { return 0, &Error{err: "message too large"} } l := make([]byte, 2, 2+lm) - l[0], l[1] = packUint16Msg(uint16(lm)) + binary.BigEndian.PutUint16(l, uint16(lm)) m = append(l, m...) n, err := io.Copy(w.tcp, bytes.NewReader(m)) diff --git a/sig0.go b/sig0.go index ea3e5b67..2dce06af 100644 --- a/sig0.go +++ b/sig0.go @@ -5,6 +5,7 @@ import ( "crypto/dsa" "crypto/ecdsa" "crypto/rsa" + "encoding/binary" "math/big" "strings" "time" @@ -67,13 +68,13 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { } // Adjust sig data length rdoff := len(mbuf) + 1 + 2 + 2 + 4 - rdlen, _ := unpackUint16Msg(buf, rdoff) + rdlen := binary.BigEndian.Uint16(buf[rdoff:]) rdlen += uint16(len(sig)) - buf[rdoff], buf[rdoff+1] = packUint16Msg(rdlen) + binary.BigEndian.PutUint16(buf[rdoff:], rdlen) // Adjust additional count - adc, _ := unpackUint16Msg(buf, 10) + adc := binary.BigEndian.Uint16(buf[10:]) adc++ - buf[10], buf[11] = packUint16Msg(adc) + binary.BigEndian.PutUint16(buf[10:], adc) return buf, nil } @@ -103,10 +104,11 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { hasher := hash.New() buflen := len(buf) - qdc, _ := unpackUint16Msg(buf, 4) - anc, _ := unpackUint16Msg(buf, 6) - auc, _ := unpackUint16Msg(buf, 8) - adc, offset := unpackUint16Msg(buf, 10) + qdc := binary.BigEndian.Uint16(buf[4:]) + anc := binary.BigEndian.Uint16(buf[6:]) + auc := binary.BigEndian.Uint16(buf[8:]) + adc := binary.BigEndian.Uint16(buf[10:]) + offset := 12 var err error for i := uint16(0); i < qdc && offset < buflen; i++ { _, offset, err = UnpackDomainName(buf, offset) @@ -127,7 +129,8 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { continue } var rdlen uint16 - rdlen, offset = unpackUint16Msg(buf, offset) + rdlen = binary.BigEndian.Uint16(buf[offset:]) + offset += 2 offset += int(rdlen) } if offset >= buflen { @@ -149,9 +152,9 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { if offset+4+4 >= buflen { return &Error{err: "overflow unpacking signed message"} } - expire := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3]) + expire := binary.BigEndian.Uint32(buf[offset:]) offset += 4 - incept := uint32(buf[offset])<<24 | uint32(buf[offset+1])<<16 | uint32(buf[offset+2])<<8 | uint32(buf[offset+3]) + incept := binary.BigEndian.Uint32(buf[offset:]) offset += 4 now := uint32(time.Now().Unix()) if now < incept || now > expire { diff --git a/tsig.go b/tsig.go index 7a089ba2..b07aab4c 100644 --- a/tsig.go +++ b/tsig.go @@ -6,6 +6,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "encoding/binary" "encoding/hex" "hash" "io" @@ -301,8 +302,8 @@ func stripTsig(msg []byte) ([]byte, *TSIG, error) { if dns.Extra[i].Header().Rrtype == TypeTSIG { rr = dns.Extra[i].(*TSIG) // Adjust Arcount. - arcount, _ := unpackUint16Msg(msg, 10) - msg[10], msg[11] = packUint16Msg(arcount - 1) + arcount := binary.BigEndian.Uint16(msg[10:]) + binary.BigEndian.PutUint16(msg[10:], arcount-1) break } }