Simplify TCP reading (#935)

* Simplify Server.readTCP

This slightly alters the error behaviour, but it should not be
observable outside of a decorated reader. I don't believe the old
behaviour was either obvious, documented or correct.

* Simplify TCP reading in client Conn

This alters the error behaviour in possibly observable ways, though
this is quite subtle and may not actually be readily observable.

Conn.ReadMsgHeader should behave the same way and still returns
ErrShortRead for length being too short.

Conn.Read will no longer return ErrShortRead if the length == 0,
otherwise it should be largely similar.

* Remove redundant error check in Conn.ReadMsgHeader
This commit is contained in:
Tom Thorogood 2019-03-11 21:29:25 +10:30 committed by Miek Gieben
parent 337216f9a7
commit 834f456fff
2 changed files with 24 additions and 93 deletions

View File

@ -219,18 +219,15 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
n int
err error
)
switch t := co.Conn.(type) {
switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
r := t.(io.Reader)
// First two bytes specify the length of the entire message.
l, err := tcpMsgLen(r)
if err != nil {
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return nil, err
}
p = make([]byte, l)
n, err = tcpRead(r, p)
p = make([]byte, length)
n, err = io.ReadFull(co.Conn, p)
default:
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
@ -257,72 +254,26 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
return p, err
}
// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
func tcpMsgLen(t io.Reader) (int, error) {
p := []byte{0, 0}
n, err := t.Read(p)
if err != nil {
return 0, err
}
// As seen with my local router/switch, returns 1 byte on the above read,
// resulting a a ShortRead. Just write it out (instead of loop) and read the
// other byte.
if n == 1 {
n1, err := t.Read(p[1:])
if err != nil {
return 0, err
}
n += n1
}
if n != 2 {
return 0, ErrShortRead
}
l := binary.BigEndian.Uint16(p)
if l == 0 {
return 0, ErrShortRead
}
return int(l), nil
}
// tcpRead calls TCPConn.Read enough times to fill allocated buffer.
func tcpRead(t io.Reader, p []byte) (int, error) {
n, err := t.Read(p)
if err != nil {
return n, err
}
for n < len(p) {
j, err := t.Read(p[n:])
if err != nil {
return n, err
}
n += j
}
return n, err
}
// Read implements the net.Conn read method.
func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil {
return 0, ErrConnEmpty
}
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
switch t := co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
r := t.(io.Reader)
l, err := tcpMsgLen(r)
if err != nil {
switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return 0, err
}
if l > len(p) {
return l, io.ErrShortBuffer
if int(length) > len(p) {
return 0, io.ErrShortBuffer
}
return tcpRead(r, p[:l])
n, err := io.ReadFull(co.Conn, p[:length])
return int(n), err
}
// UDP connection
return co.Conn.Read(p)
}

View File

@ -614,36 +614,16 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
}
srv.lock.RUnlock()
l := make([]byte, 2)
n, err := conn.Read(l)
if err != nil || n != 2 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
var length uint16
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
return nil, err
}
length := binary.BigEndian.Uint16(l)
if length == 0 {
return nil, ErrShortRead
m := make([]byte, length)
if _, err := io.ReadFull(conn, m); err != nil {
return nil, err
}
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
}
n = i
m = m[:n]
return m, nil
}