Added function for lazy message reading per #222
This commit is contained in:
parent
ad7777796e
commit
2f3bcbd506
42
client.go
42
client.go
|
@ -201,14 +201,16 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
n, err := co.Read(p)
|
n, err := co.Read(p)
|
||||||
if err != nil && n == 0 {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
} else if n < 12 {
|
||||||
|
return nil, ErrShortRead
|
||||||
}
|
}
|
||||||
|
|
||||||
p = p[:n]
|
p = p[:n]
|
||||||
if err := m.Unpack(p); err != nil {
|
if err := m.Unpack(p); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
co.rtt = time.Since(co.t)
|
|
||||||
if t := m.IsTsig(); t != nil {
|
if t := m.IsTsig(); t != nil {
|
||||||
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
||||||
return m, ErrSecret
|
return m, ErrSecret
|
||||||
|
@ -219,6 +221,40 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
||||||
return m, err
|
return m, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadMsgBytes reads a message bytes from the connection. Parses and fills
|
||||||
|
// dns message wire header (passing nil would skip header parsing) and
|
||||||
|
// returns message bytes to process them later.
|
||||||
|
//
|
||||||
|
// Note that this function would not be able to report TSIG error or
|
||||||
|
// check it got actual DNS payload.
|
||||||
|
func (co *Conn) ReadMsgBytes(hdr *Header) ([]byte, error) {
|
||||||
|
var p []byte
|
||||||
|
if _, ok := co.Conn.(*net.TCPConn); ok {
|
||||||
|
p = make([]byte, MaxMsgSize)
|
||||||
|
} else {
|
||||||
|
if co.UDPSize > MinMsgSize {
|
||||||
|
p = make([]byte, co.UDPSize)
|
||||||
|
} else {
|
||||||
|
p = make([]byte, MinMsgSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := co.Read(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if n < 12 {
|
||||||
|
return nil, ErrShortRead
|
||||||
|
}
|
||||||
|
|
||||||
|
p = p[:n]
|
||||||
|
if hdr != nil {
|
||||||
|
if _, err = UnpackStruct(hdr, p, 0); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
// Read implements the net.Conn read method.
|
// Read implements the net.Conn read method.
|
||||||
func (co *Conn) Read(p []byte) (n int, err error) {
|
func (co *Conn) Read(p []byte) (n int, err error) {
|
||||||
if co.Conn == nil {
|
if co.Conn == nil {
|
||||||
|
@ -259,6 +295,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
co.rtt = time.Since(co.t)
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestClientSync(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to exchange: %v", err)
|
t.Errorf("failed to exchange: %v", err)
|
||||||
}
|
}
|
||||||
if r != nil && r.Rcode != RcodeSuccess {
|
if r == nil || r.Rcode != RcodeSuccess {
|
||||||
t.Errorf("failed to get an valid answer\n%v", r)
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -235,3 +235,52 @@ func ExampleUpdateLeaseTSIG(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientConn(t *testing.T) {
|
||||||
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
|
defer HandleRemove("miek.nl.")
|
||||||
|
|
||||||
|
// This uses TCP just to make it slightly different than TestClientSync
|
||||||
|
s, addrstr, err := RunLocalTCPServer("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to run test server: %v", err)
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
|
||||||
|
m := new(Msg)
|
||||||
|
m.SetQuestion("miek.nl.", TypeSOA)
|
||||||
|
|
||||||
|
cn, err := Dial("tcp", addrstr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial %s: %v", addrstr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cn.WriteMsg(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to exchange: %v", err)
|
||||||
|
}
|
||||||
|
r, err := cn.ReadMsg()
|
||||||
|
if r == nil || r.Rcode != RcodeSuccess {
|
||||||
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cn.WriteMsg(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to exchange: %v", err)
|
||||||
|
}
|
||||||
|
h := new(Header)
|
||||||
|
buf, err := cn.ReadMsgBytes(h)
|
||||||
|
if buf == nil {
|
||||||
|
t.Errorf("failed to get an valid answer\n%v", r)
|
||||||
|
}
|
||||||
|
if int(h.Bits&0xF) != RcodeSuccess {
|
||||||
|
t.Errorf("failed to get an valid answer in ReadMsgBytes\n%v", r)
|
||||||
|
}
|
||||||
|
if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
|
||||||
|
t.Errorf("expected to have question and additional in response; got something else: %+v", h)
|
||||||
|
}
|
||||||
|
if err = r.Unpack(buf); err != nil {
|
||||||
|
t.Errorf("unable to unpack message fully: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue