Add a message truncation implementation (#854)

* Add a message truncation implementation

* Remove OPT if-statement at end of Scrub

* Impose RFC 6891 payload size limit in Scrub

* Remove *Msg receiver from truncateLoop

* Remove OPT record creation from Scrub

* Test that TestRequestScrubAnswerExact has correct record count

* Rename (*Msg).Scrub to Truncate

This better reflects it's purpose.

* Remove comment reference to scrubbing in Truncate

* Properly calculate the length of OPT record in Truncate

* Correct comment in IsEdns0 in regards to RFC 6891

* Handle the OPT record being anywhere in Truncate

* Slight cleanup of Msg.Truncate
This commit is contained in:
Tom Thorogood 2019-03-24 19:50:11 +10:30 committed by Miek Gieben
parent d8ff986484
commit d051b464e9
3 changed files with 311 additions and 4 deletions

View File

@ -146,10 +146,9 @@ func (dns *Msg) IsTsig() *TSIG {
// record in the additional section will do. It returns the OPT record
// found or nil.
func (dns *Msg) IsEdns0() *OPT {
// EDNS0 is at the end of the additional section, start there.
// We might want to change this to *only* look at the last two
// records. So we see TSIG and/or OPT - this a slightly bigger
// change though.
// RFC 6891, Section 6.1.1 allows the OPT record to appear
// anywhere in the additional record section, but it's usually at
// the end so start there.
for i := len(dns.Extra) - 1; i >= 0; i-- {
if dns.Extra[i].Header().Rrtype == TypeOPT {
return dns.Extra[i].(*OPT)
@ -158,6 +157,21 @@ func (dns *Msg) IsEdns0() *OPT {
return nil
}
// popEdns0 is like IsEdns0, but it removes the record from the message.
func (dns *Msg) popEdns0() *OPT {
// RFC 6891, Section 6.1.1 allows the OPT record to appear
// anywhere in the additional record section, but it's usually at
// the end so start there.
for i := len(dns.Extra) - 1; i >= 0; i-- {
if dns.Extra[i].Header().Rrtype == TypeOPT {
opt := dns.Extra[i].(*OPT)
dns.Extra = append(dns.Extra[:i], dns.Extra[i+1:]...)
return opt
}
}
return nil
}
// IsDomainName checks if s is a valid domain name, it returns the number of
// labels and true, when a domain name is valid. Note that non fully qualified
// domain name is considered valid, in this case the last label is counted in

106
msg_truncate.go Normal file
View File

@ -0,0 +1,106 @@
package dns
// Truncate ensures the reply message will fit into the requested buffer
// size by removing records that exceed the requested size.
//
// It will first check if the reply fits without compression and then with
// compression. If it won't fit with compression, Scrub then walks the
// record adding as many records as possible without exceeding the
// requested buffer size.
//
// The TC bit will be set if any answer records were excluded from the
// message. This indicates to that the client should retry over TCP.
//
// The appropriate buffer size can be retrieved from the requests OPT
// record, if present, and is transport specific otherwise. dns.MinMsgSize
// should be used for UDP requests without an OPT record, and
// dns.MaxMsgSize for TCP requests without an OPT record.
func (dns *Msg) Truncate(size int) {
if dns.IsTsig() != nil {
// To simplify this implementation, we don't perform
// truncation on responses with a TSIG record.
return
}
// RFC 6891 mandates that the payload size in an OPT record
// less than 512 bytes must be treated as equal to 512 bytes.
//
// For ease of use, we impose that restriction here.
if size < 512 {
size = 512
}
l := msgLenWithCompressionMap(dns, nil) // uncompressed length
if l <= size {
// Don't waste effort compressing this message.
dns.Compress = false
return
}
dns.Compress = true
edns0 := dns.popEdns0()
if edns0 != nil {
// Account for the OPT record that gets added at the end,
// by subtracting that length from our budget.
//
// The EDNS(0) OPT record must have the root domain and
// it's length is thus unaffected by compression.
size -= Len(edns0)
}
compression := make(map[string]struct{})
l = headerSize
for _, r := range dns.Question {
l += r.len(l, compression)
}
var numAnswer int
if l < size {
l, numAnswer = truncateLoop(dns.Answer, size, l, compression)
}
var numNS int
if l < size {
l, numNS = truncateLoop(dns.Ns, size, l, compression)
}
var numExtra int
if l < size {
l, numExtra = truncateLoop(dns.Extra, size, l, compression)
}
// According to RFC 2181, the TC bit should only be set if not all
// of the answer RRs can be included in the response.
dns.Truncated = len(dns.Answer) > numAnswer
dns.Answer = dns.Answer[:numAnswer]
dns.Ns = dns.Ns[:numNS]
dns.Extra = dns.Extra[:numExtra]
if edns0 != nil {
// Add the OPT record back onto the additional section.
dns.Extra = append(dns.Extra, edns0)
}
}
func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) {
for i, r := range rrs {
if r == nil {
continue
}
l += r.len(l, compression)
if l > size {
// Return size, rather than l prior to this record,
// to prevent any further records being added.
return size, i
}
if l == size {
return l, i + 1
}
}
return l, len(rrs)
}

187
msg_truncate_test.go Normal file
View File

@ -0,0 +1,187 @@
package dns
import (
"fmt"
"testing"
)
func TestRequestTruncateAnswer(t *testing.T) {
m := new(Msg)
m.SetQuestion("large.example.com.", TypeSRV)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 200; i++ {
reply.Answer = append(reply.Answer, testRR(
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
reply.Truncate(MinMsgSize)
if want, got := MinMsgSize, reply.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
if !reply.Truncated {
t.Errorf("truncated bit should be set")
}
}
func TestRequestTruncateExtra(t *testing.T) {
m := new(Msg)
m.SetQuestion("large.example.com.", TypeSRV)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 200; i++ {
reply.Extra = append(reply.Extra, testRR(
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
reply.Truncate(MinMsgSize)
if want, got := MinMsgSize, reply.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
if reply.Truncated {
t.Errorf("truncated bit should not be set")
}
}
func TestRequestTruncateExtraEdns0(t *testing.T) {
const size = 4096
m := new(Msg)
m.SetQuestion("large.example.com.", TypeSRV)
m.SetEdns0(size, true)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 200; i++ {
reply.Extra = append(reply.Extra, testRR(
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
reply.SetEdns0(size, true)
reply.Truncate(size)
if want, got := size, reply.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
if reply.Truncated {
t.Errorf("truncated bit should not be set")
}
opt := reply.Extra[len(reply.Extra)-1]
if opt.Header().Rrtype != TypeOPT {
t.Errorf("expected last RR to be OPT")
}
}
func TestRequestTruncateExtraRegression(t *testing.T) {
const size = 2048
m := new(Msg)
m.SetQuestion("large.example.com.", TypeSRV)
m.SetEdns0(size, true)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 33; i++ {
reply.Answer = append(reply.Answer, testRR(
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
for i := 1; i < 33; i++ {
reply.Extra = append(reply.Extra, testRR(
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
}
reply.SetEdns0(size, true)
reply.Truncate(size)
if want, got := size, reply.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
if reply.Truncated {
t.Errorf("truncated bit should not be set")
}
opt := reply.Extra[len(reply.Extra)-1]
if opt.Header().Rrtype != TypeOPT {
t.Errorf("expected last RR to be OPT")
}
}
func TestTruncation(t *testing.T) {
reply := new(Msg)
for i := 0; i < 61; i++ {
reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("http.service.tcp.srv.k8s.example.org. 5 IN SRV 0 0 80 10-144-230-%d.default.pod.k8s.example.org.", i)))
}
for i := 0; i < 5; i++ {
reply.Extra = append(reply.Extra, testRR(fmt.Sprintf("ip-10-10-52-5%d.subdomain.example.org. 5 IN A 10.10.52.5%d", i, i)))
}
for i := 0; i < 5; i++ {
reply.Ns = append(reply.Ns, testRR(fmt.Sprintf("srv.subdomain.example.org. 5 IN NS ip-10-10-33-6%d.subdomain.example.org.", i)))
}
for bufsize := 1024; bufsize <= 4096; bufsize += 12 {
m := new(Msg)
m.SetQuestion("http.service.tcp.srv.k8s.example.org.", TypeSRV)
m.SetEdns0(uint16(bufsize), true)
copy := reply.Copy()
copy.SetReply(m)
copy.Truncate(bufsize)
if want, got := bufsize, copy.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
}
}
func TestRequestTruncateAnswerExact(t *testing.T) {
const size = 867 // Bit fiddly, but this hits the rl == size break clause in Truncate, 52 RRs should remain.
m := new(Msg)
m.SetQuestion("large.example.com.", TypeSRV)
m.SetEdns0(size, false)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 200; i++ {
reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i)))
}
reply.Truncate(size)
if want, got := size, reply.Len(); want < got {
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
}
if expected := 52; len(reply.Answer) != expected {
t.Errorf("wrong number of answers; expected %d, got %d", expected, len(reply.Answer))
}
}
func BenchmarkMsgTruncate(b *testing.B) {
const size = 2048
m := new(Msg)
m.SetQuestion("example.com.", TypeA)
m.SetEdns0(size, true)
reply := new(Msg)
reply.SetReply(m)
for i := 1; i < 33; i++ {
reply.Answer = append(reply.Answer, testRR(
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
}
for i := 1; i < 33; i++ {
reply.Extra = append(reply.Extra, testRR(
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
copy := reply.Copy()
b.StartTimer()
copy.Truncate(size)
}
}