Avoid setting the Rdlength field when packing records (#859)

* Avoid setting the Rdlength field when packing

The offset of the end of the header is now returned from the RR.pack
method, with the RDLENGTH record field being written in packRR.

To maintain compatability with callers of PackRR who might be relying
on this old behaviour, PackRR will now set rr.Header().Rdlength for
external callers. Care must be taken by callers to ensure this won't
cause a data-race.

* Prevent panic if TestClientLocalAddress fails

This came up during testing of the previous change.

* Change style of overflow check in packRR
This commit is contained in:
Tom Thorogood 2018-12-02 18:53:35 +10:30 committed by Miek Gieben
parent 470f08e191
commit ff7d445081
8 changed files with 527 additions and 709 deletions

View File

@ -93,7 +93,7 @@ func TestClientLocalAddress(t *testing.T) {
t.Errorf("failed to get an valid answer\n%v", r)
}
if len(r.Extra) != 1 {
t.Errorf("failed to get additional answers\n%v", r)
t.Fatalf("failed to get additional answers\n%v", r)
}
txt := r.Extra[0].(*TXT)
if txt == nil {

13
dns.go
View File

@ -42,7 +42,7 @@ type RR interface {
len(off int, compression map[string]struct{}) int
// pack packs an RR into wire format.
pack([]byte, int, compressionMap, bool) (int, error)
pack(msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error)
}
// RR_Header is the header all DNS resource records share.
@ -84,19 +84,20 @@ func (h *RR_Header) len(off int, compression map[string]struct{}) int {
// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
func (rr *RFC3597) ToRFC3597(r RR) error {
buf := make([]byte, Len(r)*2)
off, err := PackRR(r, buf, 0, nil, false)
headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
if err != nil {
return err
}
buf = buf[:off]
if int(r.Header().Rdlength) > off {
return ErrBuf
}
rfc3597, _, err := unpackRFC3597(*r.Header(), buf, off-int(r.Header().Rdlength))
hdr := *r.Header()
hdr.Rdlength = uint16(off - headerEnd)
rfc3597, _, err := unpackRFC3597(hdr, buf, headerEnd)
if err != nil {
return err
}
*rr = *rfc3597.(*RFC3597)
return nil
}

34
msg.go
View File

@ -619,23 +619,33 @@ func intToBytes(i *big.Int, length int) []byte {
// PackRR packs a resource record rr into msg[off:].
// See PackDomainName for documentation about the compression.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
return packRR(rr, msg, off, compressionMap{ext: compression}, compress)
headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
if err == nil {
// packRR no longer sets the Rdlength field on the rr, but
// callers might be expecting it so we set it here.
rr.Header().Rdlength = uint16(off1 - headerEnd)
}
return off1, err
}
func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
if rr == nil {
return len(msg), &Error{err: "nil rr"}
return len(msg), len(msg), &Error{err: "nil rr"}
}
off1, err = rr.pack(msg, off, compression, compress)
headerEnd, off1, err = rr.pack(msg, off, compression, compress)
if err != nil {
return len(msg), err
return headerEnd, len(msg), err
}
// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
if rawSetRdlength(msg, off, off1) {
return off1, nil
rdlength := off1 - headerEnd
if int(uint16(rdlength)) != rdlength { // overflow
return headerEnd, len(msg), ErrRdata
}
return off, ErrRdata
// The RDLENGTH field is the last field in the header and we set it here.
binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
return headerEnd, off1, nil
}
// UnpackRR unpacks msg[off:] into an RR.
@ -821,19 +831,19 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compression
}
}
for _, r := range dns.Answer {
off, err = packRR(r, msg, off, compression, compress)
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
for _, r := range dns.Ns {
off, err = packRR(r, msg, off, compression, compress)
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
for _, r := range dns.Extra {
off, err = packRR(r, msg, off, compression, compress)
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}

View File

@ -80,18 +80,17 @@ func main() {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {\n", name)
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {\n", name)
fmt.Fprint(b, `headerEnd, off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, err
return headerEnd, off, err
}
headerEnd := off
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return off, err
return headerEnd, off, err
}
`)
}
@ -145,7 +144,7 @@ return off, err
if rr.%s != "-" {
off, err = packStringHex(rr.%s, msg, off)
if err != nil {
return off, err
return headerEnd, off, err
}
}
`, field, field)
@ -176,9 +175,7 @@ if rr.%s != "-" {
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
}
// We have packed everything, only now we know the rdlength of this RR
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
fmt.Fprintln(b, "return off, nil }\n")
fmt.Fprintln(b, "return headerEnd, off, nil }\n")
}
fmt.Fprint(b, "// unpack*() functions\n\n")

View File

@ -101,32 +101,32 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
// pack packs an RR header, returning the offset to the end of the header.
// See PackDomainName for documentation about the compression.
func (hdr RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
func (hdr RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {
if off == len(msg) {
return off, nil
return off, off, nil
}
off, _, err = packDomainName(hdr.Name, msg, off, compression, compress)
off, _, err := packDomainName(hdr.Name, msg, off, compression, compress)
if err != nil {
return len(msg), err
return off, len(msg), err
}
off, err = packUint16(hdr.Rrtype, msg, off)
if err != nil {
return len(msg), err
return off, len(msg), err
}
off, err = packUint16(hdr.Class, msg, off)
if err != nil {
return len(msg), err
return off, len(msg), err
}
off, err = packUint32(hdr.Ttl, msg, off)
if err != nil {
return len(msg), err
return off, len(msg), err
}
off, err = packUint16(hdr.Rdlength, msg, off)
off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
if err != nil {
return len(msg), err
return off, len(msg), err
}
return off, nil
return off, off, nil
}
// helper helper functions.

View File

@ -69,19 +69,18 @@ func (r *PrivateRR) copy() RR {
}
return rr
}
func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := r.Hdr.pack(msg, off, compression, compress)
func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {
headerEnd, off, err := r.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, err
return off, off, err
}
headerEnd := off
n, err := r.Data.Pack(msg[off:])
if err != nil {
return len(msg), err
return headerEnd, len(msg), err
}
off += n
r.Header().Rdlength = uint16(off - headerEnd)
return off, nil
return headerEnd, off, nil
}
// PrivateHandle registers a private resource record type. It requires

View File

@ -1,49 +0,0 @@
package dns
import "encoding/binary"
// rawSetRdlength sets the rdlength in the header of
// the RR. The offset 'off' must be positioned at the
// start of the header of the RR, 'end' must be the
// end of the RR.
func rawSetRdlength(msg []byte, off, end int) bool {
l := len(msg)
Loop:
for {
if off+1 > l {
return false
}
c := int(msg[off])
off++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// End of the domainname
break Loop
}
if off+c > l {
return false
}
off += c
case 0xC0:
// pointer, next byte included, ends domainname
off++
break Loop
}
}
// The domainname has been seen, we at the start of the fixed part in the header.
// Type is 2 bytes, class is 2 bytes, ttl 4 and then 2 bytes for the length.
off += 2 + 2 + 4
if off+2 > l {
return false
}
//off+1 is the end of the header, 'end' is the end of the rr
//so 'end' - 'off+2' is the length of the rdata
rdatalen := end - (off + 2)
if rdatalen > 0xFFFF {
return false
}
binary.BigEndian.PutUint16(msg[off:], uint16(rdatalen))
return true
}

1090
zmsg.go

File diff suppressed because it is too large Load Diff