diff --git a/client.go b/client.go index ea437afb..10bbdce2 100644 --- a/client.go +++ b/client.go @@ -12,17 +12,9 @@ import ( "time" ) -// Order of events: -// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read() - - -// A Conn represents a connection (which may be short lived) to a DNS -// server. +// A Conn represents a connection (which may be short lived) to a DNS server. type Conn struct { net.Conn - client *Client - addr string - req *Msg tsigRequestMAC string tsigTimersOnly bool tsigStatus error @@ -30,8 +22,7 @@ type Conn struct { t time.Time } -// A Client defines parameter for a DNS client. A nil -// Client is usable for sending queries. +// A Client defines parameters for a DNS client. A nil Client is usable for sending queries. type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 @@ -41,16 +32,31 @@ type Client struct { group singleflight } -func Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err) { - +// Exchange performs an synchronous UDP query. It sends the message m to the address +// contained in a and waits for an reply. +func Exchange(m *Msg, a string) (r *Msg, err error) { + co := new(Conn) + co.Conn, err = net.DialTimeout("udp", a, 5*1e9) + if err != nil { + return nil, err + } + defer co.Close() + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err } -func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) { +// Exchange performs an synchronous query. It sends the message m to the address +// contained in a and waits for an reply. Basic use pattern with a *dns.Client: +// +// c := new(dns.Client) +// in, rtt, err := c.Exchange(message, "127.0.0.1:53") +// +func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) } // This adds a bunch of garbage, TODO(miek). t := "nop" @@ -62,86 +68,47 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D cl = cl1 } r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) }) if err != nil { return r, rtt, err } if shared { r1 := r.copy() - r1.Id = r.Id // Copy Id! +// not needed r1.Id = r.Id // Copy Id! r = r1 } return r, rtt, nil } -// Exchange performs an synchronous query. It sends the message m to the address -// contained in a and waits for an reply. Basic use pattern with a *dns.Client: -// -// c := new(dns.Client) -// in, rtt, err := c.Exchange(message, "127.0.0.1:53") -// -func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, a, nil) -} - func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, addr: a} - if err = w.dial(); err != nil { - return nil, 0, err - } - defer w.conn.Close() - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// ExchangeConn performs an synchronous query. It sends the message m trough the -// connection s and waits for a reply. -func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, "", s) -} - -func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, conn: s} - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// dial connects to the address addr for the network set in c.Net -func (w *reply) dial() (err error) { - var conn net.Conn - if w.client.Net == "" { - conn, err = net.DialTimeout("udp", w.addr, 5*1e9) + co := new(Conn) + if c.Net == "" { + co.Conn, err = net.DialTimeout("udp", a, 5*1e9) } else { - conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9) + co.Conn, err = net.DialTimeout(c.Net, a, 5*1e9) } if err != nil { - return err + return nil, 0, err } - w.conn = conn - return + defer co.Close() + if err = co.WriteMsg(m); err != nil { + return nil, 0, err + } + r, err = co.ReadMsg() + return r, co.rtt, err } -func (w *reply) receive() (*Msg, error) { +func (co *Conn) ReadMsg() (*Msg, error) { var p []byte m := new(Msg) - switch w.client.Net { - case "tcp", "tcp4", "tcp6": + if _, ok := co.Conn.(*net.TCPConn); ok { p = make([]byte, MaxMsgSize) - case "", "udp", "udp4", "udp6": + } else { // OPT! TODO(mg) p = make([]byte, DefaultMsgSize) } - n, err := w.read(p) + n, err := co.Read(p) if err != nil && n == 0 { return nil, err } @@ -149,30 +116,28 @@ func (w *reply) receive() (*Msg, error) { if err := m.Unpack(p); err != nil { return nil, err } - w.rtt = time.Since(w.t) - if t := m.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.client.TsigSecret[secret]; !ok { - w.tsigStatus = ErrSecret - return m, ErrSecret - } - // Need to work on the original message p, as that was used to calculate the tsig. - w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) - } - return m, w.tsigStatus + co.rtt = time.Since(co.t) +// if t := m.IsTsig(); t != nil { +// secret := t.Hdr.Name +// if _, ok := w.client.TsigSecret[secret]; !ok { +// w.tsigStatus = ErrSecret +// return m, ErrSecret +// } +// // Need to work on the original message p, as that was used to calculate the tsig. +// w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) +// } + return m, nil } -func (w *reply) read(p []byte) (n int, err error) { - if w.conn == nil { +func (co *Conn) Read(p []byte) (n int, err error) { + if co.Conn == nil { return 0, ErrConnEmpty } if len(p) < 2 { return 0, io.ErrShortBuffer } - switch w.client.Net { - case "tcp", "tcp4", "tcp6": - setTimeouts(w) - n, err = w.conn.(*net.TCPConn).Read(p[0:2]) + if t, ok := co.Conn.(*net.TCPConn); ok { + n, err = t.Read(p[0:2]) if err != nil || n != 2 { return n, err } @@ -183,25 +148,25 @@ func (w *reply) read(p []byte) (n int, err error) { if int(l) > len(p) { return int(l), io.ErrShortBuffer } - n, err = w.conn.(*net.TCPConn).Read(p[:l]) + n, err = t.Read(p[:l]) if err != nil { return n, err } i := n for i < int(l) { - j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)]) + j, err := t.Read(p[i:int(l)]) if err != nil { return i, err } i += j } n = i - case "", "udp", "udp4", "udp6": - setTimeouts(w) - n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) - if err != nil { - return n, err - } + return n, err + } + // assume udp connection + n, _, err = co.Conn.(*net.UDPConn).ReadFromUDP(p) + if err != nil { + return n, err } return n, err } @@ -209,62 +174,57 @@ func (w *reply) read(p []byte) (n int, err error) { // send sends a dns msg to the address specified in w. // If the message m contains a TSIG record the transaction // signature is calculated. -func (w *reply) send(m *Msg) (err error) { +func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte - if t := m.IsTsig(); t != nil { - mac := "" - name := t.Hdr.Name - if _, ok := w.client.TsigSecret[name]; !ok { - return ErrSecret - } - out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) - w.tsigRequestMAC = mac - } else { - out, err = m.Pack() - } +// if t := m.IsTsig(); t != nil { +// mac := "" +// name := t.Hdr.Name +// if _, ok := w.client.TsigSecret[name]; !ok { +// return ErrSecret +// } +// out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) +// w.tsigRequestMAC = mac +// } else { + out, err = m.Pack() +// } if err != nil { return err } - w.t = time.Now() - if _, err = w.write(out); err != nil { + co.t = time.Now() + if _, err = co.Write(out); err != nil { return err } return nil } -func (w *reply) write(p []byte) (n int, err error) { - switch w.client.Net { - case "tcp", "tcp4", "tcp6": +func (co *Conn) Write(p []byte) (n int, err error) { + if t, ok := co.Conn.(*net.TCPConn); ok { if len(p) < 2 { return 0, io.ErrShortBuffer } - setTimeouts(w) l := make([]byte, 2) l[0], l[1] = packUint16(uint16(len(p))) p = append(l, p...) - n, err := w.conn.Write(p) + n, err := t.Write(p) if err != nil { return n, err } i := n if i < len(p) { - j, err := w.conn.Write(p[i:len(p)]) + j, err := t.Write(p[i:len(p)]) if err != nil { return i, err } i += j } n = i - case "", "udp", "udp4", "udp6": - setTimeouts(w) - n, err = w.conn.(*net.UDPConn).Write(p) - if err != nil { - return n, err - } + return n, err } - return + n, err = co.Conn.(*net.UDPConn).Write(p) + return n, err } +/* func setTimeouts(w *reply) { if w.client.ReadTimeout == 0 { w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) @@ -278,3 +238,4 @@ func setTimeouts(w *reply) { w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout)) } } +*/ diff --git a/ex/q/q.go b/ex/q/q.go index 04f6b86e..22031298 100644 --- a/ex/q/q.go +++ b/ex/q/q.go @@ -396,6 +396,7 @@ func shortRR(r dns.RR) dns.RR { } func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { + /* if t, e := c.TransferIn(m, nameserver); e == nil { for r := range t { if r.Error == nil { @@ -412,4 +413,5 @@ func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { } else { fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error()) } + */ } diff --git a/xfr.go b/xfr.go index 77ea7e0f..a2101ee4 100644 --- a/xfr.go +++ b/xfr.go @@ -10,6 +10,7 @@ type Envelope struct { Error error // If something went wrong, this contains the error. } +/* // TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns // a channel of *Envelope on which the replies from the server are sent. At the end of // the transfer the channel is closed. @@ -201,3 +202,4 @@ func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) { rep.Answer = nil } } +*/