Simplify the Write() for TCP based connections

Simplify the code path by using io.Copy to handle partial writes. Allocate `l` large enough to avoid a re-allocation. Potential short write fix.
This commit is contained in:
Armon Dadgar 2014-01-03 15:19:35 -08:00
parent 451c12da09
commit 0cf549278c
1 changed files with 7 additions and 16 deletions

View File

@ -9,6 +9,8 @@
package dns package dns
import ( import (
"bytes"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -412,26 +414,15 @@ func (w *response) Write(m []byte) (int, error) {
return n, err return n, err
case w.tcp != nil: case w.tcp != nil:
lm := len(m) lm := len(m)
if len(m) > MaxMsgSize { if lm > MaxMsgSize {
return 0, &Error{err: "message too large"} return 0, &Error{err: "message too large"}
} }
l := make([]byte, 2) l := make([]byte, 2, 2+lm)
l[0], l[1] = packUint16(uint16(lm)) l[0], l[1] = packUint16(uint16(lm))
m = append(l, m...) m = append(l, m...)
n, err := w.tcp.Write(m)
if err != nil { n, err := io.Copy(w.tcp, bytes.NewReader(m))
return n, err return int(n), err
}
i := n
if i < lm {
j, err := w.tcp.Write(m[i:lm])
if err != nil {
return i, err
}
i += j
}
n = i
return i, nil
} }
panic("not reached") panic("not reached")
} }