From 4bde528be5c29a9f8bfd5e941d76d66dbf2b21db Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Sat, 28 Sep 2013 21:58:08 +0100 Subject: [PATCH] Export dns.Conn and make it more like net.Conn Export lowlevel function and types so that they may be used. They higher level Exchange function is still there. ExchangeConn is gone, because it is not needed. --- client.go | 213 ++++++++++++++++++++++-------------------------------- ex/q/q.go | 2 + xfr.go | 2 + 3 files changed, 91 insertions(+), 126 deletions(-) diff --git a/client.go b/client.go index ea437afb..10bbdce2 100644 --- a/client.go +++ b/client.go @@ -12,17 +12,9 @@ import ( "time" ) -// Order of events: -// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read() - - -// A Conn represents a connection (which may be short lived) to a DNS -// server. +// A Conn represents a connection (which may be short lived) to a DNS server. type Conn struct { net.Conn - client *Client - addr string - req *Msg tsigRequestMAC string tsigTimersOnly bool tsigStatus error @@ -30,8 +22,7 @@ type Conn struct { t time.Time } -// A Client defines parameter for a DNS client. A nil -// Client is usable for sending queries. +// A Client defines parameters for a DNS client. A nil Client is usable for sending queries. type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 @@ -41,16 +32,31 @@ type Client struct { group singleflight } -func Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err) { - +// Exchange performs an synchronous UDP query. It sends the message m to the address +// contained in a and waits for an reply. +func Exchange(m *Msg, a string) (r *Msg, err error) { + co := new(Conn) + co.Conn, err = net.DialTimeout("udp", a, 5*1e9) + if err != nil { + return nil, err + } + defer co.Close() + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err } -func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) { +// Exchange performs an synchronous query. It sends the message m to the address +// contained in a and waits for an reply. Basic use pattern with a *dns.Client: +// +// c := new(dns.Client) +// in, rtt, err := c.Exchange(message, "127.0.0.1:53") +// +func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) } // This adds a bunch of garbage, TODO(miek). t := "nop" @@ -62,86 +68,47 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D cl = cl1 } r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) }) if err != nil { return r, rtt, err } if shared { r1 := r.copy() - r1.Id = r.Id // Copy Id! +// not needed r1.Id = r.Id // Copy Id! r = r1 } return r, rtt, nil } -// Exchange performs an synchronous query. It sends the message m to the address -// contained in a and waits for an reply. Basic use pattern with a *dns.Client: -// -// c := new(dns.Client) -// in, rtt, err := c.Exchange(message, "127.0.0.1:53") -// -func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, a, nil) -} - func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, addr: a} - if err = w.dial(); err != nil { - return nil, 0, err - } - defer w.conn.Close() - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// ExchangeConn performs an synchronous query. It sends the message m trough the -// connection s and waits for a reply. -func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, "", s) -} - -func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, conn: s} - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// dial connects to the address addr for the network set in c.Net -func (w *reply) dial() (err error) { - var conn net.Conn - if w.client.Net == "" { - conn, err = net.DialTimeout("udp", w.addr, 5*1e9) + co := new(Conn) + if c.Net == "" { + co.Conn, err = net.DialTimeout("udp", a, 5*1e9) } else { - conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9) + co.Conn, err = net.DialTimeout(c.Net, a, 5*1e9) } if err != nil { - return err + return nil, 0, err } - w.conn = conn - return + defer co.Close() + if err = co.WriteMsg(m); err != nil { + return nil, 0, err + } + r, err = co.ReadMsg() + return r, co.rtt, err } -func (w *reply) receive() (*Msg, error) { +func (co *Conn) ReadMsg() (*Msg, error) { var p []byte m := new(Msg) - switch w.client.Net { - case "tcp", "tcp4", "tcp6": + if _, ok := co.Conn.(*net.TCPConn); ok { p = make([]byte, MaxMsgSize) - case "", "udp", "udp4", "udp6": + } else { // OPT! TODO(mg) p = make([]byte, DefaultMsgSize) } - n, err := w.read(p) + n, err := co.Read(p) if err != nil && n == 0 { return nil, err } @@ -149,30 +116,28 @@ func (w *reply) receive() (*Msg, error) { if err := m.Unpack(p); err != nil { return nil, err } - w.rtt = time.Since(w.t) - if t := m.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.client.TsigSecret[secret]; !ok { - w.tsigStatus = ErrSecret - return m, ErrSecret - } - // Need to work on the original message p, as that was used to calculate the tsig. - w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) - } - return m, w.tsigStatus + co.rtt = time.Since(co.t) +// if t := m.IsTsig(); t != nil { +// secret := t.Hdr.Name +// if _, ok := w.client.TsigSecret[secret]; !ok { +// w.tsigStatus = ErrSecret +// return m, ErrSecret +// } +// // Need to work on the original message p, as that was used to calculate the tsig. +// w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) +// } + return m, nil } -func (w *reply) read(p []byte) (n int, err error) { - if w.conn == nil { +func (co *Conn) Read(p []byte) (n int, err error) { + if co.Conn == nil { return 0, ErrConnEmpty } if len(p) < 2 { return 0, io.ErrShortBuffer } - switch w.client.Net { - case "tcp", "tcp4", "tcp6": - setTimeouts(w) - n, err = w.conn.(*net.TCPConn).Read(p[0:2]) + if t, ok := co.Conn.(*net.TCPConn); ok { + n, err = t.Read(p[0:2]) if err != nil || n != 2 { return n, err } @@ -183,25 +148,25 @@ func (w *reply) read(p []byte) (n int, err error) { if int(l) > len(p) { return int(l), io.ErrShortBuffer } - n, err = w.conn.(*net.TCPConn).Read(p[:l]) + n, err = t.Read(p[:l]) if err != nil { return n, err } i := n for i < int(l) { - j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)]) + j, err := t.Read(p[i:int(l)]) if err != nil { return i, err } i += j } n = i - case "", "udp", "udp4", "udp6": - setTimeouts(w) - n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) - if err != nil { - return n, err - } + return n, err + } + // assume udp connection + n, _, err = co.Conn.(*net.UDPConn).ReadFromUDP(p) + if err != nil { + return n, err } return n, err } @@ -209,62 +174,57 @@ func (w *reply) read(p []byte) (n int, err error) { // send sends a dns msg to the address specified in w. // If the message m contains a TSIG record the transaction // signature is calculated. -func (w *reply) send(m *Msg) (err error) { +func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte - if t := m.IsTsig(); t != nil { - mac := "" - name := t.Hdr.Name - if _, ok := w.client.TsigSecret[name]; !ok { - return ErrSecret - } - out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) - w.tsigRequestMAC = mac - } else { - out, err = m.Pack() - } +// if t := m.IsTsig(); t != nil { +// mac := "" +// name := t.Hdr.Name +// if _, ok := w.client.TsigSecret[name]; !ok { +// return ErrSecret +// } +// out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) +// w.tsigRequestMAC = mac +// } else { + out, err = m.Pack() +// } if err != nil { return err } - w.t = time.Now() - if _, err = w.write(out); err != nil { + co.t = time.Now() + if _, err = co.Write(out); err != nil { return err } return nil } -func (w *reply) write(p []byte) (n int, err error) { - switch w.client.Net { - case "tcp", "tcp4", "tcp6": +func (co *Conn) Write(p []byte) (n int, err error) { + if t, ok := co.Conn.(*net.TCPConn); ok { if len(p) < 2 { return 0, io.ErrShortBuffer } - setTimeouts(w) l := make([]byte, 2) l[0], l[1] = packUint16(uint16(len(p))) p = append(l, p...) - n, err := w.conn.Write(p) + n, err := t.Write(p) if err != nil { return n, err } i := n if i < len(p) { - j, err := w.conn.Write(p[i:len(p)]) + j, err := t.Write(p[i:len(p)]) if err != nil { return i, err } i += j } n = i - case "", "udp", "udp4", "udp6": - setTimeouts(w) - n, err = w.conn.(*net.UDPConn).Write(p) - if err != nil { - return n, err - } + return n, err } - return + n, err = co.Conn.(*net.UDPConn).Write(p) + return n, err } +/* func setTimeouts(w *reply) { if w.client.ReadTimeout == 0 { w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) @@ -278,3 +238,4 @@ func setTimeouts(w *reply) { w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout)) } } +*/ diff --git a/ex/q/q.go b/ex/q/q.go index 04f6b86e..22031298 100644 --- a/ex/q/q.go +++ b/ex/q/q.go @@ -396,6 +396,7 @@ func shortRR(r dns.RR) dns.RR { } func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { + /* if t, e := c.TransferIn(m, nameserver); e == nil { for r := range t { if r.Error == nil { @@ -412,4 +413,5 @@ func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { } else { fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error()) } + */ } diff --git a/xfr.go b/xfr.go index 77ea7e0f..a2101ee4 100644 --- a/xfr.go +++ b/xfr.go @@ -10,6 +10,7 @@ type Envelope struct { Error error // If something went wrong, this contains the error. } +/* // TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns // a channel of *Envelope on which the replies from the server are sent. At the end of // the transfer the channel is closed. @@ -201,3 +202,4 @@ func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) { rep.Answer = nil } } +*/