diff --git a/client.go b/client.go index aa2c49d3..000dc013 100644 --- a/client.go +++ b/client.go @@ -340,11 +340,10 @@ func (co *Conn) Write(p []byte) (int, error) { return co.Conn.Write(p) } - l := make([]byte, 2) - binary.BigEndian.PutUint16(l, uint16(len(p))) - - n, err := (&net.Buffers{l, p}).WriteTo(co.Conn) - return int(n), err + msg := make([]byte, 2+len(p)) + binary.BigEndian.PutUint16(msg, uint16(len(p))) + copy(msg[2:], p) + return co.Conn.Write(msg) } // Return the appropriate timeout for a specific request diff --git a/client_test.go b/client_test.go index 3ac8354c..4bff6eb4 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,24 @@ func TestClientConn(t *testing.T) { } } +func TestClientConnWriteSinglePacket(t *testing.T) { + c := &countingConn{} + conn := Conn{ + Conn: c, + } + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + err := conn.WriteMsg(m) + + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + if c.writes != 1 { + t.Fatalf("incorrect number of Write calls") + } +} + func TestTruncatedMsg(t *testing.T) { m := new(Msg) m.SetQuestion("miek.nl.", TypeSRV) diff --git a/server.go b/server.go index 30dfd41d..eec02ef9 100644 --- a/server.go +++ b/server.go @@ -752,11 +752,10 @@ func (w *response) Write(m []byte) (int, error) { return 0, &Error{err: "message too large"} } - l := make([]byte, 2) - binary.BigEndian.PutUint16(l, uint16(len(m))) - - n, err := (&net.Buffers{l, m}).WriteTo(w.tcp) - return int(n), err + msg := make([]byte, 2+len(m)) + binary.BigEndian.PutUint16(msg, uint16(len(m))) + copy(msg[2:], m) + return w.tcp.Write(msg) default: panic("dns: internal error: udp and tcp both nil") } diff --git a/server_test.go b/server_test.go index 6f4bb5a6..e3bbb5a2 100644 --- a/server_test.go +++ b/server_test.go @@ -1100,6 +1100,37 @@ func TestResponseDoubleClose(t *testing.T) { } } +type countingConn struct { + net.Conn + writes int +} + +func (c *countingConn) Write(p []byte) (int, error) { + c.writes++ + return len(p), nil +} + +func TestResponseWriteSinglePacket(t *testing.T) { + c := &countingConn{} + rw := &response{ + tcp: c, + } + rw.writer = rw + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + m.Response = true + err := rw.WriteMsg(m) + + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + if c.writes != 1 { + t.Fatalf("incorrect number of Write calls") + } +} + type ExampleFrameLengthWriter struct { Writer }