Added packLen

packLen() returns the length of an uncompressed packet buffer, this
is used when packing a packet. This is needed for compression. When
compression is used, we first create the full packet and *then*
compress it. If we use Len() which accounts for compression, we can
get buffer overruns, when packing the (then still uncompressed) packet.
This commit is contained in:
Miek Gieben 2013-06-26 22:18:09 +01:00
parent d83e816f30
commit 1ad76fe65b
2 changed files with 51 additions and 50 deletions

View File

@ -159,49 +159,34 @@ func TestCompressLength(t *testing.T) {
// Does the predicted length match final packed length
func TestMsgLenTest(t *testing.T) {
var (
// util function to build messages
makeMsg = func(question string, ans, ns, e []RR) *Msg {
var msg Msg
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return &msg
}
// util function to build messages
makeMsg := func(question string, ans, ns, e []RR) *Msg {
msg := new(Msg)
msg.SetQuestion(Fqdn(question), TypeANY)
msg.Answer = append(msg.Answer, ans...)
msg.Ns = append(msg.Ns, ns...)
msg.Extra = append(msg.Extra, e...)
msg.Compress = true
return msg
}
name = "12345678901234567890123456789012345.12345678.123."
rrA, _ = NewRR(name + " 3600 IN A 192.0.2.1")
rrMx, _ = NewRR(name + " 3600 IN MX 10 " + name)
rrTxt, _ = NewRR(name + ` 3600 IN TXT "I am a TXT"`)
tests = []*Msg{
makeMsg(name, nil, nil, nil),
makeMsg(name, []RR{rrA}, nil, nil),
makeMsg(name, []RR{rrMx}, nil, nil),
makeMsg(name, []RR{rrTxt}, nil, nil),
makeMsg(name, []RR{rrA, rrA}, nil, nil),
makeMsg(name, []RR{rrMx, rrMx}, nil, nil),
makeMsg(name, []RR{rrTxt, rrTxt}, nil, nil),
makeMsg(name, []RR{rrA}, []RR{rrA}, nil),
makeMsg(name, []RR{rrMx}, []RR{rrMx}, nil),
makeMsg(name, []RR{rrTxt}, []RR{rrTxt}, nil),
makeMsg(name, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt}, nil),
makeMsg(name, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt}, []RR{rrA, rrMx, rrTxt})}
)
name1 := "12345678901234567890123456789012345.12345678.123."
rrA, _ := NewRR(name1 + " 3600 IN A 192.0.2.1")
rrMx, _ := NewRR(name1 + " 3600 IN MX 10 " + name1)
tests := []*Msg{
makeMsg(name1, []RR{rrA}, nil, nil),
makeMsg(name1, []RR{rrMx, rrMx}, nil, nil)}
for _, msg := range tests {
var (
predicted = msg.Len()
buf, err = msg.Pack()
actual = len(buf)
)
predicted := msg.Len()
buf, err := msg.Pack()
if err != nil {
t.Error(err)
t.Fail()
}
if predicted != actual {
t.Errorf("Predicted length is wrong: predicted %d, actual %d\n%s", predicted, actual, msg)
if predicted != len(buf) {
t.Errorf("Predicted length is wrong: predicted %s (len=%d) %d, actual %d\n",
msg.Question[0].Name, len(msg.Answer), predicted, len(buf))
t.Fail()
}
}

44
msg.go
View File

@ -283,7 +283,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
// the offset of the current name, because that's
// where we need to insert the pointer later
// If compress is true, we're allowed to compress this dname
// If compress is true, we're allowed to compress this dname
if pointer == -1 && compress {
pointer = p // Where to point to
nameoffset = offset // Where to point from
@ -298,7 +298,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
if len(bs) == 1 && bs[0] == '.' {
return off, nil
}
// If we did compression and we find something at the pointer here
// If we did compression and we find something add the pointer here
if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in
msg[nameoffset], msg[nameoffset+1] = packUint16(uint16(pointer ^ 0xC000))
@ -1255,9 +1255,7 @@ func (dns *Msg) Pack() (msg []byte, err error) {
dh.Nscount = uint16(len(ns))
dh.Arcount = uint16(len(extra))
// TODO(mg): still a little too much, but better than 64K...
msg = make([]byte, dns.Len()+10)
msg = make([]byte, dns.packLen()+10) // TODO(miekg): +10 should go sometimses
// Pack it in: header and then the pieces.
off := 0
off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
@ -1393,10 +1391,28 @@ func (dns *Msg) String() string {
return s
}
// Len return the message length when in (un)compressed wire format.
// packLen returns the message length when in UNcompressed wire format.
func (dns *Msg) packLen() int {
// Message header is always 12 bytes
l := 12
for i := 0; i < len(dns.Question); i++ {
l += dns.Question[i].len()
}
for i := 0; i < len(dns.Answer); i++ {
l += dns.Answer[i].len()
}
for i := 0; i < len(dns.Ns); i++ {
l += dns.Ns[i].len()
}
for i := 0; i < len(dns.Extra); i++ {
l += dns.Extra[i].len()
}
return l
}
// Len returns the message length when in (un)compressed wire format.
// If dns.Compress is true compression it is taken into account, currently
// this only counts owner name compression. There is no check for
// nil valued sections (allocated, but contain no RRs).
// this only counts owner name compression.
func (dns *Msg) Len() int {
// Message header is always 12 bytes
l := 12
@ -1413,8 +1429,8 @@ func (dns *Msg) Len() int {
}
for i := 0; i < len(dns.Answer); i++ {
if dns.Compress {
if v, ok := compression[dns.Answer[i].Header().Name]; ok {
l += dns.Answer[i].len() - v
if _, ok := compression[dns.Answer[i].Header().Name]; ok {
l += dns.Answer[i].len() - len(dns.Answer[i].Header().Name) + 2
continue
}
compressionHelper(compression, dns.Answer[i].Header().Name)
@ -1423,8 +1439,8 @@ func (dns *Msg) Len() int {
}
for i := 0; i < len(dns.Ns); i++ {
if dns.Compress {
if v, ok := compression[dns.Ns[i].Header().Name]; ok {
l += dns.Ns[i].len() - v
if _, ok := compression[dns.Ns[i].Header().Name]; ok {
l += dns.Ns[i].len() - len(dns.Ns[i].Header().Name) + 2
continue
}
compressionHelper(compression, dns.Ns[i].Header().Name)
@ -1433,8 +1449,8 @@ func (dns *Msg) Len() int {
}
for i := 0; i < len(dns.Extra); i++ {
if dns.Compress {
if v, ok := compression[dns.Extra[i].Header().Name]; ok {
l += dns.Extra[i].len() - v
if _, ok := compression[dns.Extra[i].Header().Name]; ok {
l += dns.Extra[i].len() - len(dns.Extra[i].Header().Name) + 2
continue
}
compressionHelper(compression, dns.Extra[i].Header().Name)