Handle all net.Conn connections correctly (#957)

* Change switch to if condition

* Update switch to if in read function
This commit is contained in:
Pepijnvi 2019-05-22 15:38:57 +02:00 committed by Miek Gieben
parent a2c73fb86d
commit fbd426fefa
1 changed files with 33 additions and 35 deletions

View File

@ -215,8 +215,15 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
n int
err error
)
switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
if _, ok := co.Conn.(net.PacketConn); ok {
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
} else {
p = make([]byte, MinMsgSize)
}
n, err = co.Read(p)
} else {
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return nil, err
@ -224,13 +231,6 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
p = make([]byte, length)
n, err = io.ReadFull(co.Conn, p)
default:
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
} else {
p = make([]byte, MinMsgSize)
}
n, err = co.Read(p)
}
if err != nil {
@ -256,21 +256,20 @@ func (co *Conn) Read(p []byte) (n int, err error) {
return 0, ErrConnEmpty
}
switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return 0, err
}
if int(length) > len(p) {
return 0, io.ErrShortBuffer
}
return io.ReadFull(co.Conn, p[:length])
if _, ok := co.Conn.(net.PacketConn); ok {
// UDP connection
return co.Conn.Read(p)
}
// UDP connection
return co.Conn.Read(p)
var length uint16
if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
return 0, err
}
if int(length) > len(p) {
return 0, io.ErrShortBuffer
}
return io.ReadFull(co.Conn, p[:length])
}
// WriteMsg sends a message through the connection co.
@ -297,21 +296,20 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
}
// Write implements the net.Conn Write method.
func (co *Conn) Write(p []byte) (n int, err error) {
switch co.Conn.(type) {
case *net.TCPConn, *tls.Conn:
if len(p) > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
l := make([]byte, 2)
binary.BigEndian.PutUint16(l, uint16(len(p)))
n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
return int(n), err
func (co *Conn) Write(p []byte) (int, error) {
if len(p) > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
return co.Conn.Write(p)
if _, ok := co.Conn.(net.PacketConn); ok {
return co.Conn.Write(p)
}
l := make([]byte, 2)
binary.BigEndian.PutUint16(l, uint16(len(p)))
n, err := (&net.Buffers{l, p}).WriteTo(co.Conn)
return int(n), err
}
// Return the appropriate timeout for a specific request