dns/client.go

233 lines
5.6 KiB
Go
Raw Normal View History

2013-05-13 00:09:52 +10:00
// Copyright 2011 Miek Gieben. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
2011-04-13 05:44:56 +10:00
package dns
2013-01-29 06:41:17 +11:00
// A client implementation.
2011-04-13 05:44:56 +10:00
import (
2011-04-16 07:55:27 +10:00
"io"
"net"
2012-01-20 22:13:47 +11:00
"time"
2011-04-13 05:44:56 +10:00
)
2012-10-16 22:18:59 +11:00
// Order of events:
// Client.Exchange() -> reply -> dial() -> send() -> write() -> receive() -> read()
// (Client.ExchangeConn() follows the same path but skips dial()).
2013-01-29 06:41:17 +11:00
2012-11-18 23:12:11 +11:00
// Do I want make this an interface thingy?
2011-04-13 06:39:38 +10:00
type reply struct {
2011-04-19 06:08:12 +10:00
client *Client
addr string
req *Msg
conn net.Conn
2011-04-23 00:37:26 +10:00
tsigRequestMAC string
2011-04-19 06:08:12 +10:00
tsigTimersOnly bool
tsigStatus error
2012-05-05 07:18:29 +10:00
rtt time.Duration
t time.Time
}
2012-09-02 01:06:24 +10:00
// A Client defines parameter for a DNS client. A nil
// Client is usable for sending queries.
2011-04-13 05:44:56 +10:00
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
2012-11-18 23:03:11 +11:00
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defaults to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
Inflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
}
2011-08-04 19:27:56 +10:00
// Exchange performs an synchronous query. It sends the message m to the address
2012-12-02 20:14:53 +11:00
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
2012-05-06 04:47:23 +10:00
//
2012-05-26 18:28:32 +10:00
// c := new(dns.Client)
2012-11-18 22:29:40 +11:00
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
2013-01-29 06:40:41 +11:00
//
2012-11-18 22:29:40 +11:00
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
2013-01-29 06:30:13 +11:00
w := &reply{client: c, addr: a}
if err = w.dial(); err != nil {
return nil, 0, err
2011-08-04 19:27:56 +10:00
}
defer w.conn.Close()
if err = w.send(m); err != nil {
return nil, 0, err
2011-08-08 21:10:35 +10:00
}
r, err = w.receive()
return r, w.rtt, err
}
2013-01-29 06:30:13 +11:00
// ExchangeConn performs an synchronous query. It sends the message m trough the
2013-01-29 06:32:36 +11:00
// connection s and waits for a reply.
func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
2013-01-29 07:49:23 +11:00
w := &reply{client: c, conn: s}
if err = w.send(m); err != nil {
2013-01-29 06:30:13 +11:00
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
2012-08-07 04:34:09 +10:00
// dial connects to the address addr for the network set in c.Net
func (w *reply) dial() (err error) {
2012-05-26 20:02:37 +10:00
var conn net.Conn
if w.client.Net == "" {
conn, err = net.DialTimeout("udp", w.addr, 5*1e9)
} else {
conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9)
2012-05-26 20:02:37 +10:00
}
if err != nil {
return err
2011-08-08 21:10:35 +10:00
}
w.conn = conn
2012-10-17 18:05:26 +11:00
return
2011-08-08 21:10:35 +10:00
}
2012-08-07 04:34:09 +10:00
func (w *reply) receive() (*Msg, error) {
var p []byte
m := new(Msg)
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-19 02:29:46 +10:00
p = make([]byte, MaxMsgSize)
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
// OPT! TODO(mg)
p = make([]byte, DefaultMsgSize)
2011-04-19 06:08:12 +10:00
}
n, err := w.read(p)
if err != nil && n == 0 {
2011-04-19 06:08:12 +10:00
return nil, err
}
p = p[:n]
if err := m.Unpack(p); err != nil {
return nil, err
2011-04-19 06:08:12 +10:00
}
2012-05-05 07:18:29 +10:00
w.rtt = time.Since(w.t)
if t := m.IsTsig(); t != nil {
secret := t.Hdr.Name
2012-08-17 16:45:26 +10:00
if _, ok := w.client.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret
2012-10-16 18:42:38 +11:00
return m, ErrSecret
2011-04-23 00:37:26 +10:00
}
2012-02-26 07:42:08 +11:00
// Need to work on the original message p, as that was used to calculate the tsig.
2012-08-17 16:45:26 +10:00
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
2011-04-23 00:37:26 +10:00
}
2012-10-16 18:42:38 +11:00
return m, w.tsigStatus
}
2011-04-16 07:55:27 +10:00
func (w *reply) read(p []byte) (n int, err error) {
if w.conn == nil {
2011-11-03 09:06:54 +11:00
return 0, ErrConnEmpty
}
if len(p) < 2 {
2012-05-26 18:24:47 +10:00
return 0, io.ErrShortBuffer
}
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
setTimeouts(w)
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 {
return n, err
}
l, _ := unpackUint16(p[0:2], 0)
if l == 0 {
return 0, ErrShortRead
}
if int(l) > len(p) {
return int(l), io.ErrShortBuffer
}
n, err = w.conn.(*net.TCPConn).Read(p[:l])
if err != nil {
return n, err
}
i := n
for i < int(l) {
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)])
2011-04-19 06:08:12 +10:00
if err != nil {
return i, err
}
i += j
2011-04-19 06:08:12 +10:00
}
n = i
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
setTimeouts(w)
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil {
return n, err
2011-04-17 18:54:34 +10:00
}
}
return n, err
2011-04-16 07:55:27 +10:00
}
2012-08-07 04:34:09 +10:00
// send sends a dns msg to the address specified in w.
2011-04-19 06:08:12 +10:00
// If the message m contains a TSIG record the transaction
// signature is calculated.
2012-08-07 04:34:09 +10:00
func (w *reply) send(m *Msg) (err error) {
2012-03-03 07:19:37 +11:00
var out []byte
if t := m.IsTsig(); t != nil {
2012-03-03 07:19:37 +11:00
mac := ""
name := t.Hdr.Name
2012-08-17 16:45:26 +10:00
if _, ok := w.client.TsigSecret[name]; !ok {
return ErrSecret
2011-04-23 00:37:26 +10:00
}
2012-08-17 16:45:26 +10:00
out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = mac
2012-03-03 07:19:37 +11:00
} else {
out, err = m.Pack()
}
if err != nil {
return err
2011-04-16 07:55:27 +10:00
}
2012-05-05 07:18:29 +10:00
w.t = time.Now()
if _, err = w.write(out); err != nil {
2012-03-03 07:19:37 +11:00
return err
}
2011-04-16 07:55:27 +10:00
return nil
}
func (w *reply) write(p []byte) (n int, err error) {
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-16 07:55:27 +10:00
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
setTimeouts(w)
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p)))
p = append(l, p...)
n, err := w.conn.Write(p)
if err != nil {
return n, err
}
i := n
if i < len(p) {
j, err := w.conn.Write(p[i:len(p)])
2011-04-16 07:55:27 +10:00
if err != nil {
return i, err
}
i += j
2011-04-16 07:55:27 +10:00
}
n = i
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
setTimeouts(w)
n, err = w.conn.(*net.UDPConn).Write(p)
if err != nil {
return n, err
2011-04-16 07:55:27 +10:00
}
}
2012-01-24 06:35:14 +11:00
return
2011-04-16 07:55:27 +10:00
}
2012-05-05 07:18:29 +10:00
2012-05-26 18:24:47 +10:00
func setTimeouts(w *reply) {
2012-08-17 16:45:26 +10:00
if w.client.ReadTimeout == 0 {
2012-05-26 18:24:47 +10:00
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
} else {
2012-08-17 16:45:26 +10:00
w.conn.SetReadDeadline(time.Now().Add(w.client.ReadTimeout))
2012-05-26 18:24:47 +10:00
}
2012-08-17 16:45:26 +10:00
if w.client.WriteTimeout == 0 {
2012-05-26 18:24:47 +10:00
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
} else {
2012-08-17 16:45:26 +10:00
w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout))
2012-05-26 18:24:47 +10:00
}
}