From 834f456fff2f4edb6cced331f50e4cbfe7e1056b Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Mon, 11 Mar 2019 21:29:25 +1030 Subject: [PATCH] Simplify TCP reading (#935) * Simplify Server.readTCP This slightly alters the error behaviour, but it should not be observable outside of a decorated reader. I don't believe the old behaviour was either obvious, documented or correct. * Simplify TCP reading in client Conn This alters the error behaviour in possibly observable ways, though this is quite subtle and may not actually be readily observable. Conn.ReadMsgHeader should behave the same way and still returns ErrShortRead for length being too short. Conn.Read will no longer return ErrShortRead if the length == 0, otherwise it should be largely similar. * Remove redundant error check in Conn.ReadMsgHeader --- client.go | 81 +++++++++++-------------------------------------------- server.go | 36 ++++++------------------- 2 files changed, 24 insertions(+), 93 deletions(-) diff --git a/client.go b/client.go index 86894ec2..9b225d5e 100644 --- a/client.go +++ b/client.go @@ -219,18 +219,15 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { n int err error ) - - switch t := co.Conn.(type) { + switch co.Conn.(type) { case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - - // First two bytes specify the length of the entire message. - l, err := tcpMsgLen(r) - if err != nil { + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { return nil, err } - p = make([]byte, l) - n, err = tcpRead(r, p) + + p = make([]byte, length) + n, err = io.ReadFull(co.Conn, p) default: if co.UDPSize > MinMsgSize { p = make([]byte, co.UDPSize) @@ -257,72 +254,26 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { return p, err } -// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length. -func tcpMsgLen(t io.Reader) (int, error) { - p := []byte{0, 0} - n, err := t.Read(p) - if err != nil { - return 0, err - } - - // As seen with my local router/switch, returns 1 byte on the above read, - // resulting a a ShortRead. Just write it out (instead of loop) and read the - // other byte. - if n == 1 { - n1, err := t.Read(p[1:]) - if err != nil { - return 0, err - } - n += n1 - } - - if n != 2 { - return 0, ErrShortRead - } - l := binary.BigEndian.Uint16(p) - if l == 0 { - return 0, ErrShortRead - } - return int(l), nil -} - -// tcpRead calls TCPConn.Read enough times to fill allocated buffer. -func tcpRead(t io.Reader, p []byte) (int, error) { - n, err := t.Read(p) - if err != nil { - return n, err - } - for n < len(p) { - j, err := t.Read(p[n:]) - if err != nil { - return n, err - } - n += j - } - return n, err -} - // Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { return 0, ErrConnEmpty } - if len(p) < 2 { - return 0, io.ErrShortBuffer - } - switch t := co.Conn.(type) { - case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - l, err := tcpMsgLen(r) - if err != nil { + switch co.Conn.(type) { + case *net.TCPConn, *tls.Conn: + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { return 0, err } - if l > len(p) { - return l, io.ErrShortBuffer + if int(length) > len(p) { + return 0, io.ErrShortBuffer } - return tcpRead(r, p[:l]) + + n, err := io.ReadFull(co.Conn, p[:length]) + return int(n), err } + // UDP connection return co.Conn.Read(p) } diff --git a/server.go b/server.go index 4bd4674d..b09d3717 100644 --- a/server.go +++ b/server.go @@ -614,36 +614,16 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } srv.lock.RUnlock() - l := make([]byte, 2) - n, err := conn.Read(l) - if err != nil || n != 2 { - if err != nil { - return nil, err - } - return nil, ErrShortRead + var length uint16 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + return nil, err } - length := binary.BigEndian.Uint16(l) - if length == 0 { - return nil, ErrShortRead + + m := make([]byte, length) + if _, err := io.ReadFull(conn, m); err != nil { + return nil, err } - m := make([]byte, int(length)) - n, err = conn.Read(m[:int(length)]) - if err != nil || n == 0 { - if err != nil { - return nil, err - } - return nil, ErrShortRead - } - i := n - for i < int(length) { - j, err := conn.Read(m[i:int(length)]) - if err != nil { - return nil, err - } - i += j - } - n = i - m = m[:n] + return m, nil }