Simplify TCP reading (#935)

* Simplify Server.readTCP

This slightly alters the error behaviour, but it should not be
observable outside of a decorated reader. I don't believe the old
behaviour was either obvious, documented or correct.

* Simplify TCP reading in client Conn

This alters the error behaviour in possibly observable ways, though
this is quite subtle and may not actually be readily observable.

Conn.ReadMsgHeader should behave the same way and still returns
ErrShortRead for length being too short.

Conn.Read will no longer return ErrShortRead if the length == 0,
otherwise it should be largely similar.

* Remove redundant error check in Conn.ReadMsgHeader
This commit is contained in:
Tom Thorogood 2019-03-11 21:29:25 +10:30 committed by Miek Gieben
parent 337216f9a7
commit 834f456fff
2 changed files with 24 additions and 93 deletions

View File

@ -219,18 +219,15 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
n int n int
err error err error
) )
switch co.Conn.(type) {
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn: case *net.TCPConn, *tls.Conn:
r := t.(io.Reader) var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
// First two bytes specify the length of the entire message.
l, err := tcpMsgLen(r)
if err != nil {
return nil, err return nil, err
} }
p = make([]byte, l)
n, err = tcpRead(r, p) p = make([]byte, length)
n, err = io.ReadFull(co.Conn, p)
default: default:
if co.UDPSize > MinMsgSize { if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize) p = make([]byte, co.UDPSize)
@ -257,72 +254,26 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
return p, err return p, err
} }
// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
func tcpMsgLen(t io.Reader) (int, error) {
p := []byte{0, 0}
n, err := t.Read(p)
if err != nil {
return 0, err
}
// As seen with my local router/switch, returns 1 byte on the above read,
// resulting a a ShortRead. Just write it out (instead of loop) and read the
// other byte.
if n == 1 {
n1, err := t.Read(p[1:])
if err != nil {
return 0, err
}
n += n1
}
if n != 2 {
return 0, ErrShortRead
}
l := binary.BigEndian.Uint16(p)
if l == 0 {
return 0, ErrShortRead
}
return int(l), nil
}
// tcpRead calls TCPConn.Read enough times to fill allocated buffer.
func tcpRead(t io.Reader, p []byte) (int, error) {
n, err := t.Read(p)
if err != nil {
return n, err
}
for n < len(p) {
j, err := t.Read(p[n:])
if err != nil {
return n, err
}
n += j
}
return n, err
}
// Read implements the net.Conn read method. // Read implements the net.Conn read method.
func (co *Conn) Read(p []byte) (n int, err error) { func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil { if co.Conn == nil {
return 0, ErrConnEmpty return 0, ErrConnEmpty
} }
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
r := t.(io.Reader)
l, err := tcpMsgLen(r) switch co.Conn.(type) {
if err != nil { case *net.TCPConn, *tls.Conn:
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return 0, err return 0, err
} }
if l > len(p) { if int(length) > len(p) {
return l, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
return tcpRead(r, p[:l])
n, err := io.ReadFull(co.Conn, p[:length])
return int(n), err
} }
// UDP connection // UDP connection
return co.Conn.Read(p) return co.Conn.Read(p)
} }

View File

@ -614,36 +614,16 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
} }
srv.lock.RUnlock() srv.lock.RUnlock()
l := make([]byte, 2) var length uint16
n, err := conn.Read(l) if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
if err != nil || n != 2 { return nil, err
if err != nil {
return nil, err
}
return nil, ErrShortRead
} }
length := binary.BigEndian.Uint16(l)
if length == 0 { m := make([]byte, length)
return nil, ErrShortRead if _, err := io.ReadFull(conn, m); err != nil {
return nil, err
} }
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
}
n = i
m = m[:n]
return m, nil return m, nil
} }