Export dns.Conn and make it more like net.Conn

Export lowlevel function and types so that they may be used.
They higher level Exchange function is still there. ExchangeConn
is gone, because it is not needed.
This commit is contained in:
Miek Gieben 2013-09-28 21:58:08 +01:00
parent 22977491c3
commit 4bde528be5
3 changed files with 91 additions and 126 deletions

213
client.go
View File

@ -12,17 +12,9 @@ import (
"time" "time"
) )
// Order of events: // A Conn represents a connection (which may be short lived) to a DNS server.
// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read()
// A Conn represents a connection (which may be short lived) to a DNS
// server.
type Conn struct { type Conn struct {
net.Conn net.Conn
client *Client
addr string
req *Msg
tsigRequestMAC string tsigRequestMAC string
tsigTimersOnly bool tsigTimersOnly bool
tsigStatus error tsigStatus error
@ -30,8 +22,7 @@ type Conn struct {
t time.Time t time.Time
} }
// A Client defines parameter for a DNS client. A nil // A Client defines parameters for a DNS client. A nil Client is usable for sending queries.
// Client is usable for sending queries.
type Client struct { type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9
@ -41,16 +32,31 @@ type Client struct {
group singleflight group singleflight
} }
func Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err) { // Exchange performs an synchronous UDP query. It sends the message m to the address
// contained in a and waits for an reply.
func Exchange(m *Msg, a string) (r *Msg, err error) {
co := new(Conn)
co.Conn, err = net.DialTimeout("udp", a, 5*1e9)
if err != nil {
return nil, err
}
defer co.Close()
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
} }
func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) { // Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight { if !c.SingleInflight {
if s == nil { return c.exchange(m, a)
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
} }
// This adds a bunch of garbage, TODO(miek). // This adds a bunch of garbage, TODO(miek).
t := "nop" t := "nop"
@ -62,86 +68,47 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D
cl = cl1 cl = cl1
} }
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
if s == nil { return c.exchange(m, a)
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
}) })
if err != nil { if err != nil {
return r, rtt, err return r, rtt, err
} }
if shared { if shared {
r1 := r.copy() r1 := r.copy()
r1.Id = r.Id // Copy Id! // not needed r1.Id = r.Id // Copy Id!
r = r1 r = r1
} }
return r, rtt, nil return r, rtt, nil
} }
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, a, nil)
}
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, addr: a} co := new(Conn)
if err = w.dial(); err != nil { if c.Net == "" {
return nil, 0, err co.Conn, err = net.DialTimeout("udp", a, 5*1e9)
}
defer w.conn.Close()
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// ExchangeConn performs an synchronous query. It sends the message m trough the
// connection s and waits for a reply.
func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, "", s)
}
func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, conn: s}
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// dial connects to the address addr for the network set in c.Net
func (w *reply) dial() (err error) {
var conn net.Conn
if w.client.Net == "" {
conn, err = net.DialTimeout("udp", w.addr, 5*1e9)
} else { } else {
conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9) co.Conn, err = net.DialTimeout(c.Net, a, 5*1e9)
} }
if err != nil { if err != nil {
return err return nil, 0, err
} }
w.conn = conn defer co.Close()
return if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
r, err = co.ReadMsg()
return r, co.rtt, err
} }
func (w *reply) receive() (*Msg, error) { func (co *Conn) ReadMsg() (*Msg, error) {
var p []byte var p []byte
m := new(Msg) m := new(Msg)
switch w.client.Net { if _, ok := co.Conn.(*net.TCPConn); ok {
case "tcp", "tcp4", "tcp6":
p = make([]byte, MaxMsgSize) p = make([]byte, MaxMsgSize)
case "", "udp", "udp4", "udp6": } else {
// OPT! TODO(mg) // OPT! TODO(mg)
p = make([]byte, DefaultMsgSize) p = make([]byte, DefaultMsgSize)
} }
n, err := w.read(p) n, err := co.Read(p)
if err != nil && n == 0 { if err != nil && n == 0 {
return nil, err return nil, err
} }
@ -149,30 +116,28 @@ func (w *reply) receive() (*Msg, error) {
if err := m.Unpack(p); err != nil { if err := m.Unpack(p); err != nil {
return nil, err return nil, err
} }
w.rtt = time.Since(w.t) co.rtt = time.Since(co.t)
if t := m.IsTsig(); t != nil { // if t := m.IsTsig(); t != nil {
secret := t.Hdr.Name // secret := t.Hdr.Name
if _, ok := w.client.TsigSecret[secret]; !ok { // if _, ok := w.client.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret // w.tsigStatus = ErrSecret
return m, ErrSecret // return m, ErrSecret
} // }
// Need to work on the original message p, as that was used to calculate the tsig. // // Need to work on the original message p, as that was used to calculate the tsig.
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) // w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
} // }
return m, w.tsigStatus return m, nil
} }
func (w *reply) read(p []byte) (n int, err error) { func (co *Conn) Read(p []byte) (n int, err error) {
if w.conn == nil { if co.Conn == nil {
return 0, ErrConnEmpty return 0, ErrConnEmpty
} }
if len(p) < 2 { if len(p) < 2 {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
switch w.client.Net { if t, ok := co.Conn.(*net.TCPConn); ok {
case "tcp", "tcp4", "tcp6": n, err = t.Read(p[0:2])
setTimeouts(w)
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 { if err != nil || n != 2 {
return n, err return n, err
} }
@ -183,25 +148,25 @@ func (w *reply) read(p []byte) (n int, err error) {
if int(l) > len(p) { if int(l) > len(p) {
return int(l), io.ErrShortBuffer return int(l), io.ErrShortBuffer
} }
n, err = w.conn.(*net.TCPConn).Read(p[:l]) n, err = t.Read(p[:l])
if err != nil { if err != nil {
return n, err return n, err
} }
i := n i := n
for i < int(l) { for i < int(l) {
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)]) j, err := t.Read(p[i:int(l)])
if err != nil { if err != nil {
return i, err return i, err
} }
i += j i += j
} }
n = i n = i
case "", "udp", "udp4", "udp6": return n, err
setTimeouts(w) }
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) // assume udp connection
if err != nil { n, _, err = co.Conn.(*net.UDPConn).ReadFromUDP(p)
return n, err if err != nil {
} return n, err
} }
return n, err return n, err
} }
@ -209,62 +174,57 @@ func (w *reply) read(p []byte) (n int, err error) {
// send sends a dns msg to the address specified in w. // send sends a dns msg to the address specified in w.
// If the message m contains a TSIG record the transaction // If the message m contains a TSIG record the transaction
// signature is calculated. // signature is calculated.
func (w *reply) send(m *Msg) (err error) { func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte var out []byte
if t := m.IsTsig(); t != nil { // if t := m.IsTsig(); t != nil {
mac := "" // mac := ""
name := t.Hdr.Name // name := t.Hdr.Name
if _, ok := w.client.TsigSecret[name]; !ok { // if _, ok := w.client.TsigSecret[name]; !ok {
return ErrSecret // return ErrSecret
} // }
out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) // out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = mac // w.tsigRequestMAC = mac
} else { // } else {
out, err = m.Pack() out, err = m.Pack()
} // }
if err != nil { if err != nil {
return err return err
} }
w.t = time.Now() co.t = time.Now()
if _, err = w.write(out); err != nil { if _, err = co.Write(out); err != nil {
return err return err
} }
return nil return nil
} }
func (w *reply) write(p []byte) (n int, err error) { func (co *Conn) Write(p []byte) (n int, err error) {
switch w.client.Net { if t, ok := co.Conn.(*net.TCPConn); ok {
case "tcp", "tcp4", "tcp6":
if len(p) < 2 { if len(p) < 2 {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
setTimeouts(w)
l := make([]byte, 2) l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p))) l[0], l[1] = packUint16(uint16(len(p)))
p = append(l, p...) p = append(l, p...)
n, err := w.conn.Write(p) n, err := t.Write(p)
if err != nil { if err != nil {
return n, err return n, err
} }
i := n i := n
if i < len(p) { if i < len(p) {
j, err := w.conn.Write(p[i:len(p)]) j, err := t.Write(p[i:len(p)])
if err != nil { if err != nil {
return i, err return i, err
} }
i += j i += j
} }
n = i n = i
case "", "udp", "udp4", "udp6": return n, err
setTimeouts(w)
n, err = w.conn.(*net.UDPConn).Write(p)
if err != nil {
return n, err
}
} }
return n, err = co.Conn.(*net.UDPConn).Write(p)
return n, err
} }
/*
func setTimeouts(w *reply) { func setTimeouts(w *reply) {
if w.client.ReadTimeout == 0 { if w.client.ReadTimeout == 0 {
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
@ -278,3 +238,4 @@ func setTimeouts(w *reply) {
w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout)) w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout))
} }
} }
*/

View File

@ -396,6 +396,7 @@ func shortRR(r dns.RR) dns.RR {
} }
func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { func doXfr(c *dns.Client, m *dns.Msg, nameserver string) {
/*
if t, e := c.TransferIn(m, nameserver); e == nil { if t, e := c.TransferIn(m, nameserver); e == nil {
for r := range t { for r := range t {
if r.Error == nil { if r.Error == nil {
@ -412,4 +413,5 @@ func doXfr(c *dns.Client, m *dns.Msg, nameserver string) {
} else { } else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error()) fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error())
} }
*/
} }

2
xfr.go
View File

@ -10,6 +10,7 @@ type Envelope struct {
Error error // If something went wrong, this contains the error. Error error // If something went wrong, this contains the error.
} }
/*
// TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns // TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns
// a channel of *Envelope on which the replies from the server are sent. At the end of // a channel of *Envelope on which the replies from the server are sent. At the end of
// the transfer the channel is closed. // the transfer the channel is closed.
@ -201,3 +202,4 @@ func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) {
rep.Answer = nil rep.Answer = nil
} }
} }
*/