pool the transient maps used for Msg Pack, Truncate and Len (#1006)

This improves runtime by 20-40% and more importantly significantly
reduces memory usage and allocations.
This commit is contained in:
Charlie Vieth 2019-09-23 02:16:26 -04:00 committed by Miek Gieben
parent b733ad8671
commit 9578caeab0
2 changed files with 32 additions and 5 deletions

30
msg.go
View File

@ -754,13 +754,24 @@ func (dns *Msg) Pack() (msg []byte, err error) {
return dns.PackBuffer(nil) return dns.PackBuffer(nil)
} }
var compressionPackPool = sync.Pool{
New: func() interface{} {
return make(map[string]uint16)
},
}
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated. // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
// If this message can't be compressed, avoid filling the // If this message can't be compressed, avoid filling the
// compression map and creating garbage. // compression map and creating garbage.
if dns.Compress && dns.isCompressible() { if dns.Compress && dns.isCompressible() {
compression := make(map[string]uint16) // Compression pointer mappings. compression := compressionPackPool.Get().(map[string]uint16)
return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true) msg, err := dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
for k := range compression {
delete(compression, k)
}
compressionPackPool.Put(compression)
return msg, err
} }
return dns.packBufferWithCompressionMap(buf, compressionMap{}, false) return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
@ -972,6 +983,12 @@ func (dns *Msg) isCompressible() bool {
len(dns.Ns) > 0 || len(dns.Extra) > 0 len(dns.Ns) > 0 || len(dns.Extra) > 0
} }
var compressionPool = sync.Pool{
New: func() interface{} {
return make(map[string]struct{})
},
}
// Len returns the message length when in (un)compressed wire format. // Len returns the message length when in (un)compressed wire format.
// If dns.Compress is true compression it is taken into account. Len() // If dns.Compress is true compression it is taken into account. Len()
// is provided to be a faster way to get the size of the resulting packet, // is provided to be a faster way to get the size of the resulting packet,
@ -980,8 +997,13 @@ func (dns *Msg) Len() int {
// If this message can't be compressed, avoid filling the // If this message can't be compressed, avoid filling the
// compression map and creating garbage. // compression map and creating garbage.
if dns.Compress && dns.isCompressible() { if dns.Compress && dns.isCompressible() {
compression := make(map[string]struct{}) compression := compressionPool.Get().(map[string]struct{})
return msgLenWithCompressionMap(dns, compression) n := msgLenWithCompressionMap(dns, compression)
for k := range compression {
delete(compression, k)
}
compressionPool.Put(compression)
return n
} }
return msgLenWithCompressionMap(dns, nil) return msgLenWithCompressionMap(dns, nil)

View File

@ -54,7 +54,7 @@ func (dns *Msg) Truncate(size int) {
size -= Len(edns0) size -= Len(edns0)
} }
compression := make(map[string]struct{}) compression := compressionPool.Get().(map[string]struct{})
l = headerSize l = headerSize
for _, r := range dns.Question { for _, r := range dns.Question {
@ -88,6 +88,11 @@ func (dns *Msg) Truncate(size int) {
// Add the OPT record back onto the additional section. // Add the OPT record back onto the additional section.
dns.Extra = append(dns.Extra, edns0) dns.Extra = append(dns.Extra, edns0)
} }
for k := range compression {
delete(compression, k)
}
compressionPool.Put(compression)
} }
func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) { func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) {