Export dns.Conn and make it more like net.Conn

Export lowlevel function and types so that they may be used.
They higher level Exchange function is still there. ExchangeConn
is gone, because it is not needed.
This commit is contained in:
Miek Gieben 2013-09-28 21:58:08 +01:00
parent 22977491c3
commit 4bde528be5
3 changed files with 91 additions and 126 deletions

213
client.go
View File

@ -12,17 +12,9 @@ import (
"time"
)
// Order of events:
// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read()
// A Conn represents a connection (which may be short lived) to a DNS
// server.
// A Conn represents a connection (which may be short lived) to a DNS server.
type Conn struct {
net.Conn
client *Client
addr string
req *Msg
tsigRequestMAC string
tsigTimersOnly bool
tsigStatus error
@ -30,8 +22,7 @@ type Conn struct {
t time.Time
}
// A Client defines parameter for a DNS client. A nil
// Client is usable for sending queries.
// A Client defines parameters for a DNS client. A nil Client is usable for sending queries.
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9
@ -41,16 +32,31 @@ type Client struct {
group singleflight
}
func Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err) {
// Exchange performs an synchronous UDP query. It sends the message m to the address
// contained in a and waits for an reply.
func Exchange(m *Msg, a string) (r *Msg, err error) {
co := new(Conn)
co.Conn, err = net.DialTimeout("udp", a, 5*1e9)
if err != nil {
return nil, err
}
defer co.Close()
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
}
func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) {
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
if s == nil {
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
return c.exchange(m, a)
}
// This adds a bunch of garbage, TODO(miek).
t := "nop"
@ -62,86 +68,47 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
if s == nil {
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
return c.exchange(m, a)
})
if err != nil {
return r, rtt, err
}
if shared {
r1 := r.copy()
r1.Id = r.Id // Copy Id!
// not needed r1.Id = r.Id // Copy Id!
r = r1
}
return r, rtt, nil
}
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, a, nil)
}
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, addr: a}
if err = w.dial(); err != nil {
return nil, 0, err
}
defer w.conn.Close()
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// ExchangeConn performs an synchronous query. It sends the message m trough the
// connection s and waits for a reply.
func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, "", s)
}
func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, conn: s}
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// dial connects to the address addr for the network set in c.Net
func (w *reply) dial() (err error) {
var conn net.Conn
if w.client.Net == "" {
conn, err = net.DialTimeout("udp", w.addr, 5*1e9)
co := new(Conn)
if c.Net == "" {
co.Conn, err = net.DialTimeout("udp", a, 5*1e9)
} else {
conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9)
co.Conn, err = net.DialTimeout(c.Net, a, 5*1e9)
}
if err != nil {
return err
return nil, 0, err
}
w.conn = conn
return
defer co.Close()
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
r, err = co.ReadMsg()
return r, co.rtt, err
}
func (w *reply) receive() (*Msg, error) {
func (co *Conn) ReadMsg() (*Msg, error) {
var p []byte
m := new(Msg)
switch w.client.Net {
case "tcp", "tcp4", "tcp6":
if _, ok := co.Conn.(*net.TCPConn); ok {
p = make([]byte, MaxMsgSize)
case "", "udp", "udp4", "udp6":
} else {
// OPT! TODO(mg)
p = make([]byte, DefaultMsgSize)
}
n, err := w.read(p)
n, err := co.Read(p)
if err != nil && n == 0 {
return nil, err
}
@ -149,30 +116,28 @@ func (w *reply) receive() (*Msg, error) {
if err := m.Unpack(p); err != nil {
return nil, err
}
w.rtt = time.Since(w.t)
if t := m.IsTsig(); t != nil {
secret := t.Hdr.Name
if _, ok := w.client.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
}
return m, w.tsigStatus
co.rtt = time.Since(co.t)
// if t := m.IsTsig(); t != nil {
// secret := t.Hdr.Name
// if _, ok := w.client.TsigSecret[secret]; !ok {
// w.tsigStatus = ErrSecret
// return m, ErrSecret
// }
// // Need to work on the original message p, as that was used to calculate the tsig.
// w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
// }
return m, nil
}
func (w *reply) read(p []byte) (n int, err error) {
if w.conn == nil {
func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil {
return 0, ErrConnEmpty
}
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
switch w.client.Net {
case "tcp", "tcp4", "tcp6":
setTimeouts(w)
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if t, ok := co.Conn.(*net.TCPConn); ok {
n, err = t.Read(p[0:2])
if err != nil || n != 2 {
return n, err
}
@ -183,25 +148,25 @@ func (w *reply) read(p []byte) (n int, err error) {
if int(l) > len(p) {
return int(l), io.ErrShortBuffer
}
n, err = w.conn.(*net.TCPConn).Read(p[:l])
n, err = t.Read(p[:l])
if err != nil {
return n, err
}
i := n
for i < int(l) {
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)])
j, err := t.Read(p[i:int(l)])
if err != nil {
return i, err
}
i += j
}
n = i
case "", "udp", "udp4", "udp6":
setTimeouts(w)
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil {
return n, err
}
return n, err
}
// assume udp connection
n, _, err = co.Conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil {
return n, err
}
return n, err
}
@ -209,62 +174,57 @@ func (w *reply) read(p []byte) (n int, err error) {
// send sends a dns msg to the address specified in w.
// If the message m contains a TSIG record the transaction
// signature is calculated.
func (w *reply) send(m *Msg) (err error) {
func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
mac := ""
name := t.Hdr.Name
if _, ok := w.client.TsigSecret[name]; !ok {
return ErrSecret
}
out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = mac
} else {
out, err = m.Pack()
}
// if t := m.IsTsig(); t != nil {
// mac := ""
// name := t.Hdr.Name
// if _, ok := w.client.TsigSecret[name]; !ok {
// return ErrSecret
// }
// out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
// w.tsigRequestMAC = mac
// } else {
out, err = m.Pack()
// }
if err != nil {
return err
}
w.t = time.Now()
if _, err = w.write(out); err != nil {
co.t = time.Now()
if _, err = co.Write(out); err != nil {
return err
}
return nil
}
func (w *reply) write(p []byte) (n int, err error) {
switch w.client.Net {
case "tcp", "tcp4", "tcp6":
func (co *Conn) Write(p []byte) (n int, err error) {
if t, ok := co.Conn.(*net.TCPConn); ok {
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
setTimeouts(w)
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p)))
p = append(l, p...)
n, err := w.conn.Write(p)
n, err := t.Write(p)
if err != nil {
return n, err
}
i := n
if i < len(p) {
j, err := w.conn.Write(p[i:len(p)])
j, err := t.Write(p[i:len(p)])
if err != nil {
return i, err
}
i += j
}
n = i
case "", "udp", "udp4", "udp6":
setTimeouts(w)
n, err = w.conn.(*net.UDPConn).Write(p)
if err != nil {
return n, err
}
return n, err
}
return
n, err = co.Conn.(*net.UDPConn).Write(p)
return n, err
}
/*
func setTimeouts(w *reply) {
if w.client.ReadTimeout == 0 {
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
@ -278,3 +238,4 @@ func setTimeouts(w *reply) {
w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout))
}
}
*/

View File

@ -396,6 +396,7 @@ func shortRR(r dns.RR) dns.RR {
}
func doXfr(c *dns.Client, m *dns.Msg, nameserver string) {
/*
if t, e := c.TransferIn(m, nameserver); e == nil {
for r := range t {
if r.Error == nil {
@ -412,4 +413,5 @@ func doXfr(c *dns.Client, m *dns.Msg, nameserver string) {
} else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error())
}
*/
}

2
xfr.go
View File

@ -10,6 +10,7 @@ type Envelope struct {
Error error // If something went wrong, this contains the error.
}
/*
// TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns
// a channel of *Envelope on which the replies from the server are sent. At the end of
// the transfer the channel is closed.
@ -201,3 +202,4 @@ func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) {
rep.Answer = nil
}
}
*/