diff --git a/client.go b/client.go index 86894ec2..9b225d5e 100644 --- a/client.go +++ b/client.go @@ -219,18 +219,15 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { n int err error ) - - switch t := co.Conn.(type) { + switch co.Conn.(type) { case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - - // First two bytes specify the length of the entire message. - l, err := tcpMsgLen(r) - if err != nil { + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { return nil, err } - p = make([]byte, l) - n, err = tcpRead(r, p) + + p = make([]byte, length) + n, err = io.ReadFull(co.Conn, p) default: if co.UDPSize > MinMsgSize { p = make([]byte, co.UDPSize) @@ -257,72 +254,26 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) { return p, err } -// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length. -func tcpMsgLen(t io.Reader) (int, error) { - p := []byte{0, 0} - n, err := t.Read(p) - if err != nil { - return 0, err - } - - // As seen with my local router/switch, returns 1 byte on the above read, - // resulting a a ShortRead. Just write it out (instead of loop) and read the - // other byte. - if n == 1 { - n1, err := t.Read(p[1:]) - if err != nil { - return 0, err - } - n += n1 - } - - if n != 2 { - return 0, ErrShortRead - } - l := binary.BigEndian.Uint16(p) - if l == 0 { - return 0, ErrShortRead - } - return int(l), nil -} - -// tcpRead calls TCPConn.Read enough times to fill allocated buffer. -func tcpRead(t io.Reader, p []byte) (int, error) { - n, err := t.Read(p) - if err != nil { - return n, err - } - for n < len(p) { - j, err := t.Read(p[n:]) - if err != nil { - return n, err - } - n += j - } - return n, err -} - // Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { return 0, ErrConnEmpty } - if len(p) < 2 { - return 0, io.ErrShortBuffer - } - switch t := co.Conn.(type) { - case *net.TCPConn, *tls.Conn: - r := t.(io.Reader) - l, err := tcpMsgLen(r) - if err != nil { + switch co.Conn.(type) { + case *net.TCPConn, *tls.Conn: + var length uint16 + if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil { return 0, err } - if l > len(p) { - return l, io.ErrShortBuffer + if int(length) > len(p) { + return 0, io.ErrShortBuffer } - return tcpRead(r, p[:l]) + + n, err := io.ReadFull(co.Conn, p[:length]) + return int(n), err } + // UDP connection return co.Conn.Read(p) } diff --git a/server.go b/server.go index 4bd4674d..b09d3717 100644 --- a/server.go +++ b/server.go @@ -614,36 +614,16 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) } srv.lock.RUnlock() - l := make([]byte, 2) - n, err := conn.Read(l) - if err != nil || n != 2 { - if err != nil { - return nil, err - } - return nil, ErrShortRead + var length uint16 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + return nil, err } - length := binary.BigEndian.Uint16(l) - if length == 0 { - return nil, ErrShortRead + + m := make([]byte, length) + if _, err := io.ReadFull(conn, m); err != nil { + return nil, err } - m := make([]byte, int(length)) - n, err = conn.Read(m[:int(length)]) - if err != nil || n == 0 { - if err != nil { - return nil, err - } - return nil, ErrShortRead - } - i := n - for i < int(length) { - j, err := conn.Read(m[i:int(length)]) - if err != nil { - return nil, err - } - i += j - } - n = i - m = m[:n] + return m, nil }