Use net.Buffers for writing TCP message (#934)

This commit is contained in:
Tom Thorogood 2019-03-11 00:16:14 +10:30 committed by Miek Gieben
parent 1a5555c783
commit 337216f9a7
2 changed files with 13 additions and 23 deletions

View File

@ -3,7 +3,6 @@ package dns
// A client implementation. // A client implementation.
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
@ -353,23 +352,19 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
// Write implements the net.Conn Write method. // Write implements the net.Conn Write method.
func (co *Conn) Write(p []byte) (n int, err error) { func (co *Conn) Write(p []byte) (n int, err error) {
switch t := co.Conn.(type) { switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn: case *net.TCPConn, *tls.Conn:
w := t.(io.Writer) if len(p) > MaxMsgSize {
lp := len(p)
if lp < 2 {
return 0, io.ErrShortBuffer
}
if lp > MaxMsgSize {
return 0, &Error{err: "message too large"} return 0, &Error{err: "message too large"}
} }
l := make([]byte, 2, lp+2)
binary.BigEndian.PutUint16(l, uint16(lp)) l := make([]byte, 2)
p = append(l, p...) binary.BigEndian.PutUint16(l, uint16(len(p)))
n, err := io.Copy(w, bytes.NewReader(p))
n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
return int(n), err return int(n), err
} }
return co.Conn.Write(p) return co.Conn.Write(p)
} }

View File

@ -3,7 +3,6 @@
package dns package dns
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
@ -701,18 +700,14 @@ func (w *response) Write(m []byte) (int, error) {
case w.udp != nil: case w.udp != nil:
return WriteToSessionUDP(w.udp, m, w.udpSession) return WriteToSessionUDP(w.udp, m, w.udpSession)
case w.tcp != nil: case w.tcp != nil:
lm := len(m) if len(m) > MaxMsgSize {
if lm < 2 {
return 0, io.ErrShortBuffer
}
if lm > MaxMsgSize {
return 0, &Error{err: "message too large"} return 0, &Error{err: "message too large"}
} }
l := make([]byte, 2, 2+lm)
binary.BigEndian.PutUint16(l, uint16(lm))
m = append(l, m...)
n, err := io.Copy(w.tcp, bytes.NewReader(m)) l := make([]byte, 2)
binary.BigEndian.PutUint16(l, uint16(len(m)))
n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
return int(n), err return int(n), err
default: default:
panic("dns: internal error: udp and tcp both nil") panic("dns: internal error: udp and tcp both nil")