Not allocating 64K buffers for reading
This commit is contained in:
parent
53dfadf090
commit
b625f190ce
78
client.go
78
client.go
|
@ -215,20 +215,29 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
|||
// Note that this function would not be able to report TSIG error or
|
||||
// check it got actual DNS payload.
|
||||
func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) {
|
||||
var p []byte
|
||||
var (
|
||||
p []byte
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
if _, ok := co.Conn.(*net.TCPConn); ok {
|
||||
// we got two byte
|
||||
p = make([]byte, MaxMsgSize)
|
||||
if t, ok := co.Conn.(*net.TCPConn); ok {
|
||||
// we got two byte header to know how much to receive...
|
||||
l, err := tcpMsgLen(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p = make([]byte, l)
|
||||
n, err = tcpRead(t, p)
|
||||
} else {
|
||||
if co.UDPSize > MinMsgSize {
|
||||
p = make([]byte, co.UDPSize)
|
||||
} else {
|
||||
p = make([]byte, MinMsgSize)
|
||||
}
|
||||
n, err = co.Read(p)
|
||||
}
|
||||
|
||||
n, err := co.Read(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if n < _HBytes {
|
||||
|
@ -244,6 +253,38 @@ func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) {
|
|||
return p, err
|
||||
}
|
||||
|
||||
// tcpMsgLen - helper func to read first two bytes of stream as uint16 packet length
|
||||
func tcpMsgLen(t *net.TCPConn) (int, error) {
|
||||
p := [2]byte{0, 0}
|
||||
n, err := t.Read(p[:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if n != 2 {
|
||||
return 0, ErrShortRead
|
||||
}
|
||||
l, _ := unpackUint16(p[:], 0)
|
||||
if l == 0 {
|
||||
return 0, ErrShortRead
|
||||
}
|
||||
return int(l), nil
|
||||
}
|
||||
|
||||
// tcpRead - calls TCPConn.Read enough times to fill allocated buffer
|
||||
func tcpRead(t *net.TCPConn, p []byte) (int, error) {
|
||||
n, err := t.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
for n < len(p) {
|
||||
j, err := t.Read(p[n:])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
n += j
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Read implements the net.Conn read method.
|
||||
func (co *Conn) Read(p []byte) (n int, err error) {
|
||||
if co.Conn == nil {
|
||||
|
@ -253,31 +294,14 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
|||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
if t, ok := co.Conn.(*net.TCPConn); ok {
|
||||
n, err = t.Read(p[0:2])
|
||||
if err != nil || n != 2 {
|
||||
return n, err
|
||||
l, err := tcpMsgLen(t)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
l, _ := unpackUint16(p[0:2], 0)
|
||||
if l == 0 {
|
||||
return 0, ErrShortRead
|
||||
}
|
||||
if int(l) > len(p) {
|
||||
if l > len(p) {
|
||||
return int(l), io.ErrShortBuffer
|
||||
}
|
||||
n, err = t.Read(p[:l])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
i := n
|
||||
for i < int(l) {
|
||||
j, err := t.Read(p[i:int(l)])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
return n, err
|
||||
return tcpRead(t, p[:l])
|
||||
}
|
||||
// UDP connection
|
||||
n, err = co.Conn.Read(p)
|
||||
|
|
Loading…
Reference in New Issue