Merge pull request #224 from asergeyev/master
Added function for lazy message reading per #222
This commit is contained in:
commit
6a8b26eb31
123
client.go
123
client.go
|
@ -189,26 +189,15 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
|
||||||
// If the received message contains a TSIG record the transaction
|
// If the received message contains a TSIG record the transaction
|
||||||
// signature is verified.
|
// signature is verified.
|
||||||
func (co *Conn) ReadMsg() (*Msg, error) {
|
func (co *Conn) ReadMsg() (*Msg, error) {
|
||||||
var p []byte
|
p, err := co.ReadMsgBytes(nil)
|
||||||
m := new(Msg)
|
if err != nil {
|
||||||
if _, ok := co.Conn.(*net.TCPConn); ok {
|
|
||||||
p = make([]byte, MaxMsgSize)
|
|
||||||
} else {
|
|
||||||
if co.UDPSize > MinMsgSize {
|
|
||||||
p = make([]byte, co.UDPSize)
|
|
||||||
} else {
|
|
||||||
p = make([]byte, MinMsgSize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
n, err := co.Read(p)
|
|
||||||
if err != nil && n == 0 {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p = p[:n]
|
|
||||||
|
m := new(Msg)
|
||||||
if err := m.Unpack(p); err != nil {
|
if err := m.Unpack(p); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
co.rtt = time.Since(co.t)
|
|
||||||
if t := m.IsTsig(); t != nil {
|
if t := m.IsTsig(); t != nil {
|
||||||
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
||||||
return m, ErrSecret
|
return m, ErrSecret
|
||||||
|
@ -219,6 +208,81 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
||||||
return m, err
|
return m, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadMsgBytes reads a DNS message, parses and populates hdr (when hdr is not nil).
|
||||||
|
// Returns message as a byte slice to pasrse with Msg.Unpack later on.
|
||||||
|
// Note that error handling on the message body is not possible as only the header is parsed.
|
||||||
|
func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) {
|
||||||
|
var (
|
||||||
|
p []byte
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if t, ok := co.Conn.(*net.TCPConn); ok {
|
||||||
|
// First two bytes specify the length of the entire message.
|
||||||
|
l, err := tcpMsgLen(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p = make([]byte, l)
|
||||||
|
n, err = tcpRead(t, p)
|
||||||
|
} else {
|
||||||
|
if co.UDPSize > MinMsgSize {
|
||||||
|
p = make([]byte, co.UDPSize)
|
||||||
|
} else {
|
||||||
|
p = make([]byte, MinMsgSize)
|
||||||
|
}
|
||||||
|
n, err = co.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if n < _DNSHeaderSize {
|
||||||
|
return nil, ErrShortRead
|
||||||
|
}
|
||||||
|
|
||||||
|
p = p[:n]
|
||||||
|
if hdr != nil {
|
||||||
|
if _, err = UnpackStruct(hdr, p, 0); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
|
||||||
|
func tcpMsgLen(t *net.TCPConn) (int, error) {
|
||||||
|
p := []byte{0, 0}
|
||||||
|
n, err := t.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n != 2 {
|
||||||
|
return 0, ErrShortRead
|
||||||
|
}
|
||||||
|
l, _ := unpackUint16(p, 0)
|
||||||
|
if l == 0 {
|
||||||
|
return 0, ErrShortRead
|
||||||
|
}
|
||||||
|
return int(l), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpRead calls TCPConn.Read enough times to fill allocated buffer.
|
||||||
|
func tcpRead(t *net.TCPConn, p []byte) (int, error) {
|
||||||
|
n, err := t.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
for n < len(p) {
|
||||||
|
j, err := t.Read(p[n:])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
n += j
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
// Read implements the net.Conn read method.
|
// Read implements the net.Conn read method.
|
||||||
func (co *Conn) Read(p []byte) (n int, err error) {
|
func (co *Conn) Read(p []byte) (n int, err error) {
|
||||||
if co.Conn == nil {
|
if co.Conn == nil {
|
||||||
|
@ -228,37 +292,22 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
||||||
return 0, io.ErrShortBuffer
|
return 0, io.ErrShortBuffer
|
||||||
}
|
}
|
||||||
if t, ok := co.Conn.(*net.TCPConn); ok {
|
if t, ok := co.Conn.(*net.TCPConn); ok {
|
||||||
n, err = t.Read(p[0:2])
|
l, err := tcpMsgLen(t)
|
||||||
if err != nil || n != 2 {
|
if err != nil {
|
||||||
return n, err
|
return 0, err
|
||||||
}
|
}
|
||||||
l, _ := unpackUint16(p[0:2], 0)
|
if l > len(p) {
|
||||||
if l == 0 {
|
|
||||||
return 0, ErrShortRead
|
|
||||||
}
|
|
||||||
if int(l) > len(p) {
|
|
||||||
return int(l), io.ErrShortBuffer
|
return int(l), io.ErrShortBuffer
|
||||||
}
|
}
|
||||||
n, err = t.Read(p[:l])
|
return tcpRead(t, p[:l])
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
i := n
|
|
||||||
for i < int(l) {
|
|
||||||
j, err := t.Read(p[i:int(l)])
|
|
||||||
if err != nil {
|
|
||||||
return i, err
|
|
||||||
}
|
|
||||||
i += j
|
|
||||||
}
|
|
||||||
n = i
|
|
||||||
return n, err
|
|
||||||
}
|
}
|
||||||
// UDP connection
|
// UDP connection
|
||||||
n, err = co.Conn.Read(p)
|
n, err = co.Conn.Read(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
co.rtt = time.Since(co.t)
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestClientSync(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to exchange: %v", err)
|
t.Errorf("failed to exchange: %v", err)
|
||||||
}
|
}
|
||||||
if r != nil && r.Rcode != RcodeSuccess {
|
if r == nil || r.Rcode != RcodeSuccess {
|
||||||
t.Errorf("failed to get an valid answer\n%v", r)
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -235,3 +235,52 @@ func ExampleUpdateLeaseTSIG(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientConn(t *testing.T) {
|
||||||
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
|
// This uses TCP just to make it slightly different than TestClientSync
|
||||||
|
s, addrstr, err := RunLocalTCPServer("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
|
||||||
|
m := new(Msg)
|
||||||
|
m.SetQuestion("miek.nl.", TypeSOA)
|
||||||
|
|
||||||
|
cn, err := Dial("tcp", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial %s: %v", addrstr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cn.WriteMsg(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to exchange: %v", err)
|
||||||
|
}
|
||||||
|
r, err := cn.ReadMsg()
|
||||||
|
if r == nil || r.Rcode != RcodeSuccess {
|
||||||
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cn.WriteMsg(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to exchange: %v", err)
|
||||||
|
}
|
||||||
|
h := new(Header)
|
||||||
|
buf, err := cn.ReadMsgBytes(h)
|
||||||
|
if buf == nil {
|
||||||
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
|
}
|
||||||
|
if int(h.Bits&0xF) != RcodeSuccess {
|
||||||
|
t.Errorf("failed to get an valid answer in ReadMsgBytes\n%v", r)
|
||||||
|
}
|
||||||
|
if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
|
||||||
|
t.Errorf("expected to have question and additional in response; got something else: %+v", h)
|
||||||
|
}
|
||||||
|
if err = r.Unpack(buf); err != nil {
|
||||||
|
t.Errorf("unable to unpack message fully: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
8
types.go
8
types.go
|
@ -158,6 +158,9 @@ type Header struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Header Size
|
||||||
|
_DNSHeaderSize = 12
|
||||||
|
|
||||||
// Header.Bits
|
// Header.Bits
|
||||||
_QR = 1 << 15 // query/response (response=1)
|
_QR = 1 << 15 // query/response (response=1)
|
||||||
_AA = 1 << 10 // authoritative
|
_AA = 1 << 10 // authoritative
|
||||||
|
@ -1568,8 +1571,9 @@ type CAA struct {
|
||||||
func (rr *CAA) Header() *RR_Header { return &rr.Hdr }
|
func (rr *CAA) Header() *RR_Header { return &rr.Hdr }
|
||||||
func (rr *CAA) copy() RR { return &CAA{*rr.Hdr.copyHeader(), rr.Flag, rr.Tag, rr.Value} }
|
func (rr *CAA) copy() RR { return &CAA{*rr.Hdr.copyHeader(), rr.Flag, rr.Tag, rr.Value} }
|
||||||
func (rr *CAA) len() int { return rr.Hdr.len() + 1 + len(rr.Tag) + len(rr.Value)/2 }
|
func (rr *CAA) len() int { return rr.Hdr.len() + 1 + len(rr.Tag) + len(rr.Value)/2 }
|
||||||
func (rr *CAA) String() string { return rr.Hdr.String() + strconv.Itoa(int(rr.Flag)) + " " + rr.Tag + " " + sprintCAAValue(rr.Value) }
|
func (rr *CAA) String() string {
|
||||||
|
return rr.Hdr.String() + strconv.Itoa(int(rr.Flag)) + " " + rr.Tag + " " + sprintCAAValue(rr.Value)
|
||||||
|
}
|
||||||
|
|
||||||
type UID struct {
|
type UID struct {
|
||||||
Hdr RR_Header
|
Hdr RR_Header
|
||||||
|
|
Loading…
Reference in New Issue