Instead of removing all RRs on Truncated, attempt to unpack
This commit is contained in:
parent
6a1556664f
commit
2d2c2ebcfc
|
@ -197,6 +197,12 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
|||
|
||||
m := new(Msg)
|
||||
if err := m.Unpack(p); err != nil {
|
||||
// If ErrTruncated was returned, we still want to allow the user to use
|
||||
// the message, but naively they can just check err if they don't want
|
||||
// to use a truncated message
|
||||
if err == ErrTruncated {
|
||||
return m, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if t := m.IsTsig(); t != nil {
|
||||
|
|
145
client_test.go
145
client_test.go
|
@ -1,6 +1,8 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -236,3 +238,146 @@ func TestClientConn(t *testing.T) {
|
|||
t.Errorf("unable to unpack message fully: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncatedMsg(t *testing.T) {
|
||||
m := new(Msg)
|
||||
m.SetQuestion("miek.nl.", TypeSRV)
|
||||
cnt := 10
|
||||
for i := 0; i < cnt; i++ {
|
||||
r := &SRV{
|
||||
Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0},
|
||||
Port: uint16(i + 8000),
|
||||
Target: "target.miek.nl.",
|
||||
}
|
||||
m.Answer = append(m.Answer, r)
|
||||
|
||||
re := &A{
|
||||
Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0},
|
||||
A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(),
|
||||
}
|
||||
m.Extra = append(m.Extra, re)
|
||||
}
|
||||
buf, err := m.Pack()
|
||||
if err != nil {
|
||||
t.Errorf("failed to pack: %v", err)
|
||||
}
|
||||
|
||||
r := new(Msg)
|
||||
if err = r.Unpack(buf); err != nil {
|
||||
t.Errorf("unable to unpack message: %v", err)
|
||||
}
|
||||
if len(r.Answer) != cnt {
|
||||
t.Logf("answer count after regular unpack doesn't match: %d", len(r.Answer))
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Extra) != cnt {
|
||||
t.Logf("extra count after regular unpack doesn't match: %d", len(r.Extra))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
m.Truncated = true
|
||||
buf, err = m.Pack()
|
||||
if err != nil {
|
||||
t.Errorf("failed to pack truncated: %v", err)
|
||||
}
|
||||
|
||||
r = new(Msg)
|
||||
if err = r.Unpack(buf); err != nil && err != ErrTruncated {
|
||||
t.Errorf("unable to unpack truncated message: %v", err)
|
||||
}
|
||||
if !r.Truncated {
|
||||
t.Log("truncated message wasn't unpacked as truncated")
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Answer) != cnt {
|
||||
t.Logf("answer count after truncated unpack doesn't match: %d", len(r.Answer))
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Extra) != cnt {
|
||||
t.Logf("extra count after truncated unpack doesn't match: %d", len(r.Extra))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Now we want to remove almost all of the extra records
|
||||
// We're going to loop over the extra to get the count of the size of all
|
||||
// of them
|
||||
off := 0
|
||||
buf1 := make([]byte, m.Len())
|
||||
for i := 0; i < len(m.Extra); i++ {
|
||||
off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
|
||||
if err != nil {
|
||||
t.Errorf("failed to pack extra: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all of the extra bytes but 10 bytes from the end of buf
|
||||
off -= 10
|
||||
buf1 = buf[:len(buf)-off]
|
||||
|
||||
r = new(Msg)
|
||||
if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
|
||||
t.Errorf("unable to unpack cutoff message: %v", err)
|
||||
}
|
||||
if !r.Truncated {
|
||||
t.Log("truncated cutoff message wasn't unpacked as truncated")
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Answer) != cnt {
|
||||
t.Logf("answer count after cutoff unpack doesn't match: %d", len(r.Answer))
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Extra) != 0 {
|
||||
t.Logf("extra count after cutoff unpack is not zero: %d", len(r.Extra))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Now we want to remove almost all of the answer records too
|
||||
buf1 = make([]byte, m.Len())
|
||||
as := 0
|
||||
for i := 0; i < len(m.Extra); i++ {
|
||||
off1 := off
|
||||
off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
|
||||
as = off - off1
|
||||
if err != nil {
|
||||
t.Errorf("failed to pack extra: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep exactly one answer left
|
||||
// This should still cause Answer to be nil
|
||||
off -= as
|
||||
buf1 = buf[:len(buf)-off]
|
||||
|
||||
r = new(Msg)
|
||||
if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
|
||||
t.Errorf("unable to unpack cutoff message: %v", err)
|
||||
}
|
||||
if !r.Truncated {
|
||||
t.Log("truncated cutoff message wasn't unpacked as truncated")
|
||||
t.Fail()
|
||||
}
|
||||
if len(r.Answer) != 0 {
|
||||
t.Logf("answer count after second cutoff unpack is not zero: %d", len(r.Answer))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Now leave only 1 byte of the question
|
||||
// Since the header is always 12 bytes, we just need to keep 13
|
||||
buf1 = buf[:13]
|
||||
|
||||
r = new(Msg)
|
||||
err = r.Unpack(buf1)
|
||||
if err == nil || err == ErrTruncated {
|
||||
t.Logf("error should not be ErrTruncated from question cutoff unpack: %v", err)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
// Finally, if we only have the header, we should still return an error
|
||||
buf1 = buf[:12]
|
||||
|
||||
r = new(Msg)
|
||||
if err = r.Unpack(buf1); err == nil || err != ErrTruncated {
|
||||
t.Logf("error not ErrTruncated from header-only unpack: %v", err)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
|
109
msg.go
109
msg.go
|
@ -54,6 +54,9 @@ var (
|
|||
ErrSoa error = &Error{err: "no SOA"}
|
||||
// ErrTime indicates a timing error in TSIG authentication.
|
||||
ErrTime error = &Error{err: "bad time"}
|
||||
// ErrTruncated indicates that we failed to unpack a truncated message.
|
||||
// We unpacked as much as we had so Msg can still be used, if desired.
|
||||
ErrTruncated error = &Error{err: "failed to unpack truncated message"}
|
||||
)
|
||||
|
||||
// Id, by default, returns a 16 bits random number to be used as a
|
||||
|
@ -1238,8 +1241,8 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er
|
|||
continue
|
||||
}
|
||||
}
|
||||
if off == lenmsg {
|
||||
// zero rdata foo, OK for dyn. updates
|
||||
if off == lenmsg && int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint()) == 0 {
|
||||
// zero rdata is ok for dyn updates, but only if rdlength is 0
|
||||
break
|
||||
}
|
||||
s, off, err = UnpackDomainName(msg, off)
|
||||
|
@ -1396,6 +1399,32 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
|
|||
return rr, off, err
|
||||
}
|
||||
|
||||
// unpackRRslice unpacks msg[off:] into an []RR.
|
||||
// If we cannot unpack the whole array, then it will return nil
|
||||
func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
|
||||
var r RR
|
||||
// Optimistically make dst be the length that was sent
|
||||
dst := make([]RR, 0, l)
|
||||
for i := 0; i < l; i++ {
|
||||
off1 := off
|
||||
r, off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
off = len(msg)
|
||||
break
|
||||
}
|
||||
// If offset does not increase anymore, l is a lie
|
||||
if off1 == off {
|
||||
l = i
|
||||
break
|
||||
}
|
||||
dst = append(dst, r)
|
||||
}
|
||||
if err != nil && off == len(msg) {
|
||||
dst = nil
|
||||
}
|
||||
return dst, off, err
|
||||
}
|
||||
|
||||
// Reverse a map
|
||||
func reverseInt8(m map[uint8]string) map[string]uint8 {
|
||||
n := make(map[string]uint8)
|
||||
|
@ -1594,84 +1623,48 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
|
|||
dns.CheckingDisabled = (dh.Bits & _CD) != 0
|
||||
dns.Rcode = int(dh.Bits & 0xF)
|
||||
|
||||
// Don't pre-alloc these arrays, the incoming lengths are from the network.
|
||||
dns.Question = make([]Question, 0, 1)
|
||||
dns.Answer = make([]RR, 0, 10)
|
||||
dns.Ns = make([]RR, 0, 10)
|
||||
dns.Extra = make([]RR, 0, 10)
|
||||
// Optimistically use the count given to us in the header
|
||||
dns.Question = make([]Question, 0, int(dh.Qdcount))
|
||||
|
||||
var q Question
|
||||
for i := 0; i < int(dh.Qdcount); i++ {
|
||||
off1 := off
|
||||
off, err = UnpackStruct(&q, msg, off)
|
||||
if err != nil {
|
||||
// Even if Truncated is set, we only will set ErrTruncated if we
|
||||
// actually got the questions
|
||||
return err
|
||||
}
|
||||
if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
|
||||
dh.Qdcount = uint16(i)
|
||||
break
|
||||
}
|
||||
|
||||
dns.Question = append(dns.Question, q)
|
||||
|
||||
}
|
||||
// If we see a TC bit being set we return here, without
|
||||
// an error, because technically it isn't an error. So return
|
||||
// without parsing the potentially corrupt packet and hitting an error.
|
||||
// TODO(miek): this isn't the best strategy!
|
||||
// Better stragey would be: set boolean indicating truncated message, go forth and parse
|
||||
// until we hit an error, return the message without the latest parsed rr if this boolean
|
||||
// is true.
|
||||
if dns.Truncated {
|
||||
dns.Answer = nil
|
||||
dns.Ns = nil
|
||||
dns.Extra = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var r RR
|
||||
for i := 0; i < int(dh.Ancount); i++ {
|
||||
off1 := off
|
||||
r, off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
|
||||
// The header counts might have been wrong so we need to update it
|
||||
dh.Ancount = uint16(len(dns.Answer))
|
||||
if err == nil {
|
||||
dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
|
||||
}
|
||||
if off1 == off { // Offset does not increase anymore, dh.Ancount is a lie!
|
||||
dh.Ancount = uint16(i)
|
||||
break
|
||||
}
|
||||
dns.Answer = append(dns.Answer, r)
|
||||
}
|
||||
for i := 0; i < int(dh.Nscount); i++ {
|
||||
off1 := off
|
||||
r, off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if off1 == off { // Offset does not increase anymore, dh.Nscount is a lie!
|
||||
dh.Nscount = uint16(i)
|
||||
break
|
||||
}
|
||||
dns.Ns = append(dns.Ns, r)
|
||||
}
|
||||
for i := 0; i < int(dh.Arcount); i++ {
|
||||
off1 := off
|
||||
r, off, err = UnpackRR(msg, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if off1 == off { // Offset does not increase anymore, dh.Arcount is a lie!
|
||||
dh.Arcount = uint16(i)
|
||||
break
|
||||
}
|
||||
dns.Extra = append(dns.Extra, r)
|
||||
// The header counts might have been wrong so we need to update it
|
||||
dh.Nscount = uint16(len(dns.Ns))
|
||||
if err == nil {
|
||||
dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
|
||||
}
|
||||
// The header counts might have been wrong so we need to update it
|
||||
dh.Arcount = uint16(len(dns.Extra))
|
||||
if off != len(msg) {
|
||||
// TODO(miek) make this an error?
|
||||
// use PackOpt to let people tell how detailed the error reporting should be?
|
||||
// println("dns: extra bytes in dns packet", off, "<", len(msg))
|
||||
} else if dns.Truncated {
|
||||
// Whether we ran into a an error or not, we want to return that it
|
||||
// was truncated
|
||||
err = ErrTruncated
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert a complete message to a string with dig-like output.
|
||||
|
|
Loading…
Reference in New Issue