Change low level read/write for TLS connection

As tlc.Conn is just a TCP connection after the handshake, we will modify the
TCP functions to work with an io.Reader/io.Writer parameter instead of a
net.TCPConn so we can reuse them.

See #297
This commit is contained in:
Rafael Dantas Justo 2016-01-07 13:27:07 -02:00
parent 124839738d
commit 020f925824
1 changed files with 21 additions and 11 deletions

View File

@ -253,15 +253,19 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
err error
)
if t, ok := co.Conn.(*net.TCPConn); ok {
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
r := t.(io.Reader)
// First two bytes specify the length of the entire message.
l, err := tcpMsgLen(t)
l, err := tcpMsgLen(r)
if err != nil {
return nil, err
}
p = make([]byte, l)
n, err = tcpRead(t, p)
} else {
n, err = tcpRead(r, p)
default:
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
} else {
@ -286,7 +290,7 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
}
// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
func tcpMsgLen(t *net.TCPConn) (int, error) {
func tcpMsgLen(t io.Reader) (int, error) {
p := []byte{0, 0}
n, err := t.Read(p)
if err != nil {
@ -303,7 +307,7 @@ func tcpMsgLen(t *net.TCPConn) (int, error) {
}
// tcpRead calls TCPConn.Read enough times to fill allocated buffer.
func tcpRead(t *net.TCPConn, p []byte) (int, error) {
func tcpRead(t io.Reader, p []byte) (int, error) {
n, err := t.Read(p)
if err != nil {
return n, err
@ -326,15 +330,18 @@ func (co *Conn) Read(p []byte) (n int, err error) {
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
if t, ok := co.Conn.(*net.TCPConn); ok {
l, err := tcpMsgLen(t)
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
r := t.(io.Reader)
l, err := tcpMsgLen(r)
if err != nil {
return 0, err
}
if l > len(p) {
return int(l), io.ErrShortBuffer
}
return tcpRead(t, p[:l])
return tcpRead(r, p[:l])
}
// UDP connection
n, err = co.Conn.Read(p)
@ -374,7 +381,10 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
// Write implements the net.Conn Write method.
func (co *Conn) Write(p []byte) (n int, err error) {
if t, ok := co.Conn.(*net.TCPConn); ok {
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
w := t.(io.Writer)
lp := len(p)
if lp < 2 {
return 0, io.ErrShortBuffer
@ -385,7 +395,7 @@ func (co *Conn) Write(p []byte) (n int, err error) {
l := make([]byte, 2, lp+2)
l[0], l[1] = packUint16(uint16(lp))
p = append(l, p...)
n, err := io.Copy(t, bytes.NewReader(p))
n, err := io.Copy(w, bytes.NewReader(p))
return int(n), err
}
n, err = co.Conn.(*net.UDPConn).Write(p)