Merge branch 'master' of github.com:miekg/dns
This commit is contained in:
commit
0460860e89
22
defaults.go
22
defaults.go
|
@ -146,10 +146,9 @@ func (dns *Msg) IsTsig() *TSIG {
|
|||
// record in the additional section will do. It returns the OPT record
|
||||
// found or nil.
|
||||
func (dns *Msg) IsEdns0() *OPT {
|
||||
// EDNS0 is at the end of the additional section, start there.
|
||||
// We might want to change this to *only* look at the last two
|
||||
// records. So we see TSIG and/or OPT - this a slightly bigger
|
||||
// change though.
|
||||
// RFC 6891, Section 6.1.1 allows the OPT record to appear
|
||||
// anywhere in the additional record section, but it's usually at
|
||||
// the end so start there.
|
||||
for i := len(dns.Extra) - 1; i >= 0; i-- {
|
||||
if dns.Extra[i].Header().Rrtype == TypeOPT {
|
||||
return dns.Extra[i].(*OPT)
|
||||
|
@ -158,6 +157,21 @@ func (dns *Msg) IsEdns0() *OPT {
|
|||
return nil
|
||||
}
|
||||
|
||||
// popEdns0 is like IsEdns0, but it removes the record from the message.
|
||||
func (dns *Msg) popEdns0() *OPT {
|
||||
// RFC 6891, Section 6.1.1 allows the OPT record to appear
|
||||
// anywhere in the additional record section, but it's usually at
|
||||
// the end so start there.
|
||||
for i := len(dns.Extra) - 1; i >= 0; i-- {
|
||||
if dns.Extra[i].Header().Rrtype == TypeOPT {
|
||||
opt := dns.Extra[i].(*OPT)
|
||||
dns.Extra = append(dns.Extra[:i], dns.Extra[i+1:]...)
|
||||
return opt
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDomainName checks if s is a valid domain name, it returns the number of
|
||||
// labels and true, when a domain name is valid. Note that non fully qualified
|
||||
// domain name is considered valid, in this case the last label is counted in
|
||||
|
|
|
@ -25,12 +25,13 @@ func unpackDataA(msg []byte, off int) (net.IP, int, error) {
|
|||
}
|
||||
|
||||
func packDataA(a net.IP, msg []byte, off int) (int, error) {
|
||||
// It must be a slice of 4, even if it is 16, we encode only the first 4
|
||||
if off+net.IPv4len > len(msg) {
|
||||
return len(msg), &Error{err: "overflow packing a"}
|
||||
}
|
||||
switch len(a) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
// It must be a slice of 4, even if it is 16, we encode only the first 4
|
||||
if off+net.IPv4len > len(msg) {
|
||||
return len(msg), &Error{err: "overflow packing a"}
|
||||
}
|
||||
|
||||
copy(msg[off:], a.To4())
|
||||
off += net.IPv4len
|
||||
case 0:
|
||||
|
@ -51,12 +52,12 @@ func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
|
|||
}
|
||||
|
||||
func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
|
||||
if off+net.IPv6len > len(msg) {
|
||||
return len(msg), &Error{err: "overflow packing aaaa"}
|
||||
}
|
||||
|
||||
switch len(aaaa) {
|
||||
case net.IPv6len:
|
||||
if off+net.IPv6len > len(msg) {
|
||||
return len(msg), &Error{err: "overflow packing aaaa"}
|
||||
}
|
||||
|
||||
copy(msg[off:], aaaa)
|
||||
off += net.IPv6len
|
||||
case 0:
|
||||
|
|
17
msg_test.go
17
msg_test.go
|
@ -306,3 +306,20 @@ func TestPackUnpackManyCompressionPointers(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLenDynamicA(t *testing.T) {
|
||||
for _, rr := range []RR{
|
||||
testRR("example.org. A"),
|
||||
testRR("example.org. AAAA"),
|
||||
testRR("example.org. L32"),
|
||||
} {
|
||||
msg := make([]byte, Len(rr))
|
||||
off, err := PackRR(rr, msg, 0, nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("PackRR failed for %T: %v", rr, err)
|
||||
}
|
||||
if off != len(msg) {
|
||||
t.Errorf("Len(rr) wrong for %T: Len(rr) = %d, PackRR(rr) = %d", rr, len(msg), off)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
package dns
|
||||
|
||||
// Truncate ensures the reply message will fit into the requested buffer
|
||||
// size by removing records that exceed the requested size.
|
||||
//
|
||||
// It will first check if the reply fits without compression and then with
|
||||
// compression. If it won't fit with compression, Scrub then walks the
|
||||
// record adding as many records as possible without exceeding the
|
||||
// requested buffer size.
|
||||
//
|
||||
// The TC bit will be set if any answer records were excluded from the
|
||||
// message. This indicates to that the client should retry over TCP.
|
||||
//
|
||||
// The appropriate buffer size can be retrieved from the requests OPT
|
||||
// record, if present, and is transport specific otherwise. dns.MinMsgSize
|
||||
// should be used for UDP requests without an OPT record, and
|
||||
// dns.MaxMsgSize for TCP requests without an OPT record.
|
||||
func (dns *Msg) Truncate(size int) {
|
||||
if dns.IsTsig() != nil {
|
||||
// To simplify this implementation, we don't perform
|
||||
// truncation on responses with a TSIG record.
|
||||
return
|
||||
}
|
||||
|
||||
// RFC 6891 mandates that the payload size in an OPT record
|
||||
// less than 512 bytes must be treated as equal to 512 bytes.
|
||||
//
|
||||
// For ease of use, we impose that restriction here.
|
||||
if size < 512 {
|
||||
size = 512
|
||||
}
|
||||
|
||||
l := msgLenWithCompressionMap(dns, nil) // uncompressed length
|
||||
if l <= size {
|
||||
// Don't waste effort compressing this message.
|
||||
dns.Compress = false
|
||||
return
|
||||
}
|
||||
|
||||
dns.Compress = true
|
||||
|
||||
edns0 := dns.popEdns0()
|
||||
if edns0 != nil {
|
||||
// Account for the OPT record that gets added at the end,
|
||||
// by subtracting that length from our budget.
|
||||
//
|
||||
// The EDNS(0) OPT record must have the root domain and
|
||||
// it's length is thus unaffected by compression.
|
||||
size -= Len(edns0)
|
||||
}
|
||||
|
||||
compression := make(map[string]struct{})
|
||||
|
||||
l = headerSize
|
||||
for _, r := range dns.Question {
|
||||
l += r.len(l, compression)
|
||||
}
|
||||
|
||||
var numAnswer int
|
||||
if l < size {
|
||||
l, numAnswer = truncateLoop(dns.Answer, size, l, compression)
|
||||
}
|
||||
|
||||
var numNS int
|
||||
if l < size {
|
||||
l, numNS = truncateLoop(dns.Ns, size, l, compression)
|
||||
}
|
||||
|
||||
var numExtra int
|
||||
if l < size {
|
||||
l, numExtra = truncateLoop(dns.Extra, size, l, compression)
|
||||
}
|
||||
|
||||
// According to RFC 2181, the TC bit should only be set if not all
|
||||
// of the answer RRs can be included in the response.
|
||||
dns.Truncated = len(dns.Answer) > numAnswer
|
||||
|
||||
dns.Answer = dns.Answer[:numAnswer]
|
||||
dns.Ns = dns.Ns[:numNS]
|
||||
dns.Extra = dns.Extra[:numExtra]
|
||||
|
||||
if edns0 != nil {
|
||||
// Add the OPT record back onto the additional section.
|
||||
dns.Extra = append(dns.Extra, edns0)
|
||||
}
|
||||
}
|
||||
|
||||
func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) {
|
||||
for i, r := range rrs {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
l += r.len(l, compression)
|
||||
if l > size {
|
||||
// Return size, rather than l prior to this record,
|
||||
// to prevent any further records being added.
|
||||
return size, i
|
||||
}
|
||||
if l == size {
|
||||
return l, i + 1
|
||||
}
|
||||
}
|
||||
|
||||
return l, len(rrs)
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestTruncateAnswer(t *testing.T) {
|
||||
m := new(Msg)
|
||||
m.SetQuestion("large.example.com.", TypeSRV)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 200; i++ {
|
||||
reply.Answer = append(reply.Answer, testRR(
|
||||
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
|
||||
}
|
||||
|
||||
reply.Truncate(MinMsgSize)
|
||||
if want, got := MinMsgSize, reply.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
if !reply.Truncated {
|
||||
t.Errorf("truncated bit should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestTruncateExtra(t *testing.T) {
|
||||
m := new(Msg)
|
||||
m.SetQuestion("large.example.com.", TypeSRV)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 200; i++ {
|
||||
reply.Extra = append(reply.Extra, testRR(
|
||||
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
|
||||
}
|
||||
|
||||
reply.Truncate(MinMsgSize)
|
||||
if want, got := MinMsgSize, reply.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
if reply.Truncated {
|
||||
t.Errorf("truncated bit should not be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestTruncateExtraEdns0(t *testing.T) {
|
||||
const size = 4096
|
||||
|
||||
m := new(Msg)
|
||||
m.SetQuestion("large.example.com.", TypeSRV)
|
||||
m.SetEdns0(size, true)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 200; i++ {
|
||||
reply.Extra = append(reply.Extra, testRR(
|
||||
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
|
||||
}
|
||||
reply.SetEdns0(size, true)
|
||||
|
||||
reply.Truncate(size)
|
||||
if want, got := size, reply.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
if reply.Truncated {
|
||||
t.Errorf("truncated bit should not be set")
|
||||
}
|
||||
opt := reply.Extra[len(reply.Extra)-1]
|
||||
if opt.Header().Rrtype != TypeOPT {
|
||||
t.Errorf("expected last RR to be OPT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestTruncateExtraRegression(t *testing.T) {
|
||||
const size = 2048
|
||||
|
||||
m := new(Msg)
|
||||
m.SetQuestion("large.example.com.", TypeSRV)
|
||||
m.SetEdns0(size, true)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 33; i++ {
|
||||
reply.Answer = append(reply.Answer, testRR(
|
||||
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
|
||||
}
|
||||
for i := 1; i < 33; i++ {
|
||||
reply.Extra = append(reply.Extra, testRR(
|
||||
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
|
||||
}
|
||||
reply.SetEdns0(size, true)
|
||||
|
||||
reply.Truncate(size)
|
||||
if want, got := size, reply.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
if reply.Truncated {
|
||||
t.Errorf("truncated bit should not be set")
|
||||
}
|
||||
opt := reply.Extra[len(reply.Extra)-1]
|
||||
if opt.Header().Rrtype != TypeOPT {
|
||||
t.Errorf("expected last RR to be OPT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncation(t *testing.T) {
|
||||
reply := new(Msg)
|
||||
|
||||
for i := 0; i < 61; i++ {
|
||||
reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("http.service.tcp.srv.k8s.example.org. 5 IN SRV 0 0 80 10-144-230-%d.default.pod.k8s.example.org.", i)))
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
reply.Extra = append(reply.Extra, testRR(fmt.Sprintf("ip-10-10-52-5%d.subdomain.example.org. 5 IN A 10.10.52.5%d", i, i)))
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
reply.Ns = append(reply.Ns, testRR(fmt.Sprintf("srv.subdomain.example.org. 5 IN NS ip-10-10-33-6%d.subdomain.example.org.", i)))
|
||||
}
|
||||
|
||||
for bufsize := 1024; bufsize <= 4096; bufsize += 12 {
|
||||
m := new(Msg)
|
||||
m.SetQuestion("http.service.tcp.srv.k8s.example.org.", TypeSRV)
|
||||
m.SetEdns0(uint16(bufsize), true)
|
||||
|
||||
copy := reply.Copy()
|
||||
copy.SetReply(m)
|
||||
|
||||
copy.Truncate(bufsize)
|
||||
if want, got := bufsize, copy.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestTruncateAnswerExact(t *testing.T) {
|
||||
const size = 867 // Bit fiddly, but this hits the rl == size break clause in Truncate, 52 RRs should remain.
|
||||
|
||||
m := new(Msg)
|
||||
m.SetQuestion("large.example.com.", TypeSRV)
|
||||
m.SetEdns0(size, false)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 200; i++ {
|
||||
reply.Answer = append(reply.Answer, testRR(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i)))
|
||||
}
|
||||
|
||||
reply.Truncate(size)
|
||||
if want, got := size, reply.Len(); want < got {
|
||||
t.Errorf("message length should be bellow %d bytes, got %d bytes", want, got)
|
||||
}
|
||||
if expected := 52; len(reply.Answer) != expected {
|
||||
t.Errorf("wrong number of answers; expected %d, got %d", expected, len(reply.Answer))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMsgTruncate(b *testing.B) {
|
||||
const size = 2048
|
||||
|
||||
m := new(Msg)
|
||||
m.SetQuestion("example.com.", TypeA)
|
||||
m.SetEdns0(size, true)
|
||||
|
||||
reply := new(Msg)
|
||||
reply.SetReply(m)
|
||||
for i := 1; i < 33; i++ {
|
||||
reply.Answer = append(reply.Answer, testRR(
|
||||
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
|
||||
}
|
||||
for i := 1; i < 33; i++ {
|
||||
reply.Extra = append(reply.Extra, testRR(
|
||||
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
copy := reply.Copy()
|
||||
b.StartTimer()
|
||||
|
||||
copy.Truncate(size)
|
||||
}
|
||||
}
|
|
@ -196,9 +196,9 @@ func main() {
|
|||
case st.Tag(i) == `dns:"any"`:
|
||||
o("l += len(rr.%s)\n")
|
||||
case st.Tag(i) == `dns:"a"`:
|
||||
o("l += net.IPv4len // %s\n")
|
||||
o("if len(rr.%s) != 0 { l += net.IPv4len }\n")
|
||||
case st.Tag(i) == `dns:"aaaa"`:
|
||||
o("l += net.IPv6len // %s\n")
|
||||
o("if len(rr.%s) != 0 { l += net.IPv6len }\n")
|
||||
case st.Tag(i) == `dns:"txt"`:
|
||||
o("for _, t := range rr.%s { l += len(t) + 1 }\n")
|
||||
case st.Tag(i) == `dns:"uint48"`:
|
||||
|
|
14
ztypes.go
14
ztypes.go
|
@ -240,12 +240,16 @@ func (rr *X25) Header() *RR_Header { return &rr.Hdr }
|
|||
// len() functions
|
||||
func (rr *A) len(off int, compression map[string]struct{}) int {
|
||||
l := rr.Hdr.len(off, compression)
|
||||
l += net.IPv4len // A
|
||||
if len(rr.A) != 0 {
|
||||
l += net.IPv4len
|
||||
}
|
||||
return l
|
||||
}
|
||||
func (rr *AAAA) len(off int, compression map[string]struct{}) int {
|
||||
l := rr.Hdr.len(off, compression)
|
||||
l += net.IPv6len // AAAA
|
||||
if len(rr.AAAA) != 0 {
|
||||
l += net.IPv6len
|
||||
}
|
||||
return l
|
||||
}
|
||||
func (rr *AFSDB) len(off int, compression map[string]struct{}) int {
|
||||
|
@ -364,8 +368,10 @@ func (rr *KX) len(off int, compression map[string]struct{}) int {
|
|||
}
|
||||
func (rr *L32) len(off int, compression map[string]struct{}) int {
|
||||
l := rr.Hdr.len(off, compression)
|
||||
l += 2 // Preference
|
||||
l += net.IPv4len // Locator32
|
||||
l += 2 // Preference
|
||||
if len(rr.Locator32) != 0 {
|
||||
l += net.IPv4len
|
||||
}
|
||||
return l
|
||||
}
|
||||
func (rr *L64) len(off int, compression map[string]struct{}) int {
|
||||
|
|
Loading…
Reference in New Issue