From 2f3bcbd50606772286055e61c8bce40a405d4e98 Mon Sep 17 00:00:00 2001 From: Alex Sergeyev Date: Wed, 24 Jun 2015 15:09:46 -0400 Subject: [PATCH] Added function for lazy message reading per #222 --- client.go | 42 +++++++++++++++++++++++++++++++++++++++-- client_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 53ae6eab..a8a6b0ec 100644 --- a/client.go +++ b/client.go @@ -201,14 +201,16 @@ func (co *Conn) ReadMsg() (*Msg, error) { } } n, err := co.Read(p) - if err != nil && n == 0 { + if err != nil { return nil, err + } else if n < 12 { + return nil, ErrShortRead } + p = p[:n] if err := m.Unpack(p); err != nil { return nil, err } - co.rtt = time.Since(co.t) if t := m.IsTsig(); t != nil { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return m, ErrSecret @@ -219,6 +221,40 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, err } +// ReadMsgBytes reads a message bytes from the connection. Parses and fills +// dns message wire header (passing nil would skip header parsing) and +// returns message bytes to process them later. +// +// Note that this function would not be able to report TSIG error or +// check it got actual DNS payload. +func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) { + var p []byte + if _, ok := co.Conn.(*net.TCPConn); ok { + p = make([]byte, MaxMsgSize) + } else { + if co.UDPSize > MinMsgSize { + p = make([]byte, co.UDPSize) + } else { + p = make([]byte, MinMsgSize) + } + } + + n, err := co.Read(p) + if err != nil { + return nil, err + } else if n < 12 { + return nil, ErrShortRead + } + + p = p[:n] + if hdr != nil { + if _, err = UnpackStruct(hdr, p, 0); err != nil { + return nil, err + } + } + return p, err +} + // Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { @@ -259,6 +295,8 @@ func (co *Conn) Read(p []byte) (n int, err error) { if err != nil { return n, err } + + co.rtt = time.Since(co.t) return n, err } diff --git a/client_test.go b/client_test.go index 8a70c7ea..15fe7701 100644 --- a/client_test.go +++ b/client_test.go @@ -32,7 +32,7 @@ func TestClientSync(t *testing.T) { if err != nil { t.Errorf("failed to exchange: %v", err) } - if r != nil && r.Rcode != RcodeSuccess { + if r == nil || r.Rcode != RcodeSuccess { t.Errorf("failed to get an valid answer\n%v", r) } } @@ -235,3 +235,52 @@ func ExampleUpdateLeaseTSIG(t *testing.T) { t.Error(err) } } + +func TestClientConn(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + defer HandleRemove("miek.nl.") + + // This uses TCP just to make it slightly different than TestClientSync + s, addrstr, err := RunLocalTCPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %v", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeSOA) + + cn, err := Dial("tcp", addrstr) + if err != nil { + t.Errorf("failed to dial %s: %v", addrstr, err) + } + + err = cn.WriteMsg(m) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + r, err := cn.ReadMsg() + if r == nil || r.Rcode != RcodeSuccess { + t.Errorf("failed to get an valid answer\n%v", r) + } + + err = cn.WriteMsg(m) + if err != nil { + t.Errorf("failed to exchange: %v", err) + } + h := new(Header) + buf, err := cn.ReadMsgBytes(h) + if buf == nil { + t.Errorf("failed to get an valid answer\n%v", r) + } + if int(h.Bits&0xF) != RcodeSuccess { + t.Errorf("failed to get an valid answer in ReadMsgBytes\n%v", r) + } + if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 { + t.Errorf("expected to have question and additional in response; got something else: %+v", h) + } + if err = r.Unpack(buf); err != nil { + t.Errorf("unable to unpack message fully: %v", err) + } + +}