diff --git a/client.go b/client.go index 28834dc6..8d0ef7b8 100644 --- a/client.go +++ b/client.go @@ -197,6 +197,12 @@ func (co *Conn) ReadMsg() (*Msg, error) { m := new(Msg) if err := m.Unpack(p); err != nil { + // If ErrTruncated was returned, we still want to allow the user to use + // the message, but naively they can just check err if they don't want + // to use a truncated message + if err == ErrTruncated { + return m, err + } return nil, err } if t := m.IsTsig(); t != nil { diff --git a/client_test.go b/client_test.go index eb980576..0f7a33bc 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,8 @@ package dns import ( + "fmt" + "net" "strconv" "testing" "time" @@ -236,3 +238,146 @@ func TestClientConn(t *testing.T) { t.Errorf("unable to unpack message fully: %v", err) } } + +func TestTruncatedMsg(t *testing.T) { + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSRV) + cnt := 10 + for i := 0; i < cnt; i++ { + r := &SRV{ + Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0}, + Port: uint16(i + 8000), + Target: "target.miek.nl.", + } + m.Answer = append(m.Answer, r) + + re := &A{ + Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0}, + A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(), + } + m.Extra = append(m.Extra, re) + } + buf, err := m.Pack() + if err != nil { + t.Errorf("failed to pack: %v", err) + } + + r := new(Msg) + if err = r.Unpack(buf); err != nil { + t.Errorf("unable to unpack message: %v", err) + } + if len(r.Answer) != cnt { + t.Logf("answer count after regular unpack doesn't match: %d", len(r.Answer)) + t.Fail() + } + if len(r.Extra) != cnt { + t.Logf("extra count after regular unpack doesn't match: %d", len(r.Extra)) + t.Fail() + } + + m.Truncated = true + buf, err = m.Pack() + if err != nil { + t.Errorf("failed to pack truncated: %v", err) + } + + r = new(Msg) + if err = r.Unpack(buf); err != nil && err != ErrTruncated { + t.Errorf("unable to unpack truncated message: %v", err) + } + if !r.Truncated { + t.Log("truncated message wasn't unpacked as truncated") + t.Fail() + } + if len(r.Answer) != cnt { + t.Logf("answer count after truncated unpack doesn't match: %d", len(r.Answer)) + t.Fail() + } + if len(r.Extra) != cnt { + t.Logf("extra count after truncated unpack doesn't match: %d", len(r.Extra)) + t.Fail() + } + + // Now we want to remove almost all of the extra records + // We're going to loop over the extra to get the count of the size of all + // of them + off := 0 + buf1 := make([]byte, m.Len()) + for i := 0; i < len(m.Extra); i++ { + off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress) + if err != nil { + t.Errorf("failed to pack extra: %v", err) + } + } + + // Remove all of the extra bytes but 10 bytes from the end of buf + off -= 10 + buf1 = buf[:len(buf)-off] + + r = new(Msg) + if err = r.Unpack(buf1); err != nil && err != ErrTruncated { + t.Errorf("unable to unpack cutoff message: %v", err) + } + if !r.Truncated { + t.Log("truncated cutoff message wasn't unpacked as truncated") + t.Fail() + } + if len(r.Answer) != cnt { + t.Logf("answer count after cutoff unpack doesn't match: %d", len(r.Answer)) + t.Fail() + } + if len(r.Extra) != 0 { + t.Logf("extra count after cutoff unpack is not zero: %d", len(r.Extra)) + t.Fail() + } + + // Now we want to remove almost all of the answer records too + buf1 = make([]byte, m.Len()) + as := 0 + for i := 0; i < len(m.Extra); i++ { + off1 := off + off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress) + as = off - off1 + if err != nil { + t.Errorf("failed to pack extra: %v", err) + } + } + + // Keep exactly one answer left + // This should still cause Answer to be nil + off -= as + buf1 = buf[:len(buf)-off] + + r = new(Msg) + if err = r.Unpack(buf1); err != nil && err != ErrTruncated { + t.Errorf("unable to unpack cutoff message: %v", err) + } + if !r.Truncated { + t.Log("truncated cutoff message wasn't unpacked as truncated") + t.Fail() + } + if len(r.Answer) != 0 { + t.Logf("answer count after second cutoff unpack is not zero: %d", len(r.Answer)) + t.Fail() + } + + // Now leave only 1 byte of the question + // Since the header is always 12 bytes, we just need to keep 13 + buf1 = buf[:13] + + r = new(Msg) + err = r.Unpack(buf1) + if err == nil || err == ErrTruncated { + t.Logf("error should not be ErrTruncated from question cutoff unpack: %v", err) + t.Fail() + } + + // Finally, if we only have the header, we should still return an error + buf1 = buf[:12] + + r = new(Msg) + if err = r.Unpack(buf1); err == nil || err != ErrTruncated { + t.Logf("error not ErrTruncated from header-only unpack: %v", err) + t.Fail() + } +} diff --git a/msg.go b/msg.go index 6f0d69c6..73ffb616 100644 --- a/msg.go +++ b/msg.go @@ -54,6 +54,9 @@ var ( ErrSoa error = &Error{err: "no SOA"} // ErrTime indicates a timing error in TSIG authentication. ErrTime error = &Error{err: "bad time"} + // ErrTruncated indicates that we failed to unpack a truncated message. + // We unpacked as much as we had so Msg can still be used, if desired. + ErrTruncated error = &Error{err: "failed to unpack truncated message"} ) // Id, by default, returns a 16 bits random number to be used as a @@ -1238,8 +1241,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er continue } } - if off == lenmsg { - // zero rdata foo, OK for dyn. updates + if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 { + // zero rdata is ok for dyn updates, but only if rdlength is 0 break } s, off, err = UnpackDomainName(msg, off) @@ -1396,6 +1399,32 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { return rr, off, err } +// unpackRRslice unpacks msg[off:] into an []RR. +// If we cannot unpack the whole array, then it will return nil +func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) { + var r RR + // Optimistically make dst be the length that was sent + dst := make([]RR, 0, l) + for i := 0; i < l; i++ { + off1 := off + r, off, err = UnpackRR(msg, off) + if err != nil { + off = len(msg) + break + } + // If offset does not increase anymore, l is a lie + if off1 == off { + l = i + break + } + dst = append(dst, r) + } + if err != nil && off == len(msg) { + dst = nil + } + return dst, off, err +} + // Reverse a map func reverseInt8(m map[uint8]string) map[string]uint8 { n := make(map[string]uint8) @@ -1594,84 +1623,48 @@ func (dns *Msg) Unpack(msg []byte) (err error) { dns.CheckingDisabled = (dh.Bits & _CD) != 0 dns.Rcode = int(dh.Bits & 0xF) - // Don't pre-alloc these arrays, the incoming lengths are from the network. - dns.Question = make([]Question, 0, 1) - dns.Answer = make([]RR, 0, 10) - dns.Ns = make([]RR, 0, 10) - dns.Extra = make([]RR, 0, 10) + // Optimistically use the count given to us in the header + dns.Question = make([]Question, 0, int(dh.Qdcount)) var q Question for i := 0; i < int(dh.Qdcount); i++ { off1 := off off, err = UnpackStruct(&q, msg, off) if err != nil { + // Even if Truncated is set, we only will set ErrTruncated if we + // actually got the questions return err } if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! dh.Qdcount = uint16(i) break } - dns.Question = append(dns.Question, q) - - } - // If we see a TC bit being set we return here, without - // an error, because technically it isn't an error. So return - // without parsing the potentially corrupt packet and hitting an error. - // TODO(miek): this isn't the best strategy! - // Better stragey would be: set boolean indicating truncated message, go forth and parse - // until we hit an error, return the message without the latest parsed rr if this boolean - // is true. - if dns.Truncated { - dns.Answer = nil - dns.Ns = nil - dns.Extra = nil - return nil } - var r RR - for i := 0; i < int(dh.Ancount); i++ { - off1 := off - r, off, err = UnpackRR(msg, off) - if err != nil { - return err - } - if off1 == off { // Offset does not increase anymore, dh.Ancount is a lie! - dh.Ancount = uint16(i) - break - } - dns.Answer = append(dns.Answer, r) + dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off) + // The header counts might have been wrong so we need to update it + dh.Ancount = uint16(len(dns.Answer)) + if err == nil { + dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off) } - for i := 0; i < int(dh.Nscount); i++ { - off1 := off - r, off, err = UnpackRR(msg, off) - if err != nil { - return err - } - if off1 == off { // Offset does not increase anymore, dh.Nscount is a lie! - dh.Nscount = uint16(i) - break - } - dns.Ns = append(dns.Ns, r) - } - for i := 0; i < int(dh.Arcount); i++ { - off1 := off - r, off, err = UnpackRR(msg, off) - if err != nil { - return err - } - if off1 == off { // Offset does not increase anymore, dh.Arcount is a lie! - dh.Arcount = uint16(i) - break - } - dns.Extra = append(dns.Extra, r) + // The header counts might have been wrong so we need to update it + dh.Nscount = uint16(len(dns.Ns)) + if err == nil { + dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off) } + // The header counts might have been wrong so we need to update it + dh.Arcount = uint16(len(dns.Extra)) if off != len(msg) { // TODO(miek) make this an error? // use PackOpt to let people tell how detailed the error reporting should be? // println("dns: extra bytes in dns packet", off, "<", len(msg)) + } else if dns.Truncated { + // Whether we ran into a an error or not, we want to return that it + // was truncated + err = ErrTruncated } - return nil + return err } // Convert a complete message to a string with dig-like output.