Added function for lazy message reading per #222

This commit is contained in:
Alex Sergeyev 2015-06-24 15:09:46 -04:00
parent ad7777796e
commit 2f3bcbd506
2 changed files with 90 additions and 3 deletions

View File

@ -201,14 +201,16 @@ func (co *Conn) ReadMsg() (*Msg, error) {
}
}
n, err := co.Read(p)
if err != nil && n == 0 {
if err != nil {
return nil, err
} else if n < 12 {
return nil, ErrShortRead
}
p = p[:n]
if err := m.Unpack(p); err != nil {
return nil, err
}
co.rtt = time.Since(co.t)
if t := m.IsTsig(); t != nil {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return m, ErrSecret
@ -219,6 +221,40 @@ func (co *Conn) ReadMsg() (*Msg, error) {
return m, err
}
// ReadMsgBytes reads a message bytes from the connection. Parses and fills
// dns message wire header (passing nil would skip header parsing) and
// returns message bytes to process them later.
//
// Note that this function would not be able to report TSIG error or
// check it got actual DNS payload.
func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) {
var p []byte
if _, ok := co.Conn.(*net.TCPConn); ok {
p = make([]byte, MaxMsgSize)
} else {
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
} else {
p = make([]byte, MinMsgSize)
}
}
n, err := co.Read(p)
if err != nil {
return nil, err
} else if n < 12 {
return nil, ErrShortRead
}
p = p[:n]
if hdr != nil {
if _, err = UnpackStruct(hdr, p, 0); err != nil {
return nil, err
}
}
return p, err
}
// Read implements the net.Conn read method.
func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil {
@ -259,6 +295,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
if err != nil {
return n, err
}
co.rtt = time.Since(co.t)
return n, err
}

View File

@ -32,7 +32,7 @@ func TestClientSync(t *testing.T) {
if err != nil {
t.Errorf("failed to exchange: %v", err)
}
if r != nil && r.Rcode != RcodeSuccess {
if r == nil || r.Rcode != RcodeSuccess {
t.Errorf("failed to get an valid answer\n%v", r)
}
}
@ -235,3 +235,52 @@ func ExampleUpdateLeaseTSIG(t *testing.T) {
t.Error(err)
}
}
func TestClientConn(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
// This uses TCP just to make it slightly different than TestClientSync
s, addrstr, err := RunLocalTCPServer("127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to run test server: %v", err)
}
defer s.Shutdown()
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSOA)
cn, err := Dial("tcp", addrstr)
if err != nil {
t.Errorf("failed to dial %s: %v", addrstr, err)
}
err = cn.WriteMsg(m)
if err != nil {
t.Errorf("failed to exchange: %v", err)
}
r, err := cn.ReadMsg()
if r == nil || r.Rcode != RcodeSuccess {
t.Errorf("failed to get an valid answer\n%v", r)
}
err = cn.WriteMsg(m)
if err != nil {
t.Errorf("failed to exchange: %v", err)
}
h := new(Header)
buf, err := cn.ReadMsgBytes(h)
if buf == nil {
t.Errorf("failed to get an valid answer\n%v", r)
}
if int(h.Bits&0xF) != RcodeSuccess {
t.Errorf("failed to get an valid answer in ReadMsgBytes\n%v", r)
}
if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
t.Errorf("expected to have question and additional in response; got something else: %+v", h)
}
if err = r.Unpack(buf); err != nil {
t.Errorf("unable to unpack message fully: %v", err)
}
}