dns/client.go

355 lines
8.3 KiB
Go
Raw Normal View History

2011-04-13 05:44:56 +10:00
package dns
// A concurrent client implementation.
import (
2011-04-16 07:55:27 +10:00
"io"
"net"
2012-01-20 22:13:47 +11:00
"time"
2011-04-13 05:44:56 +10:00
)
// hijacked connections...?
2011-04-13 06:39:38 +10:00
type reply struct {
2011-04-19 06:08:12 +10:00
client *Client
addr string
req *Msg
conn net.Conn
2011-04-23 00:37:26 +10:00
tsigRequestMAC string
2011-04-19 06:08:12 +10:00
tsigTimersOnly bool
tsigStatus error
2012-05-05 07:18:29 +10:00
rtt time.Duration
t time.Time
}
2012-05-26 18:24:47 +10:00
// Client holds the settings used to query nameservers.
//
// The zero value (new(Client) or Client{}) is ready to use: it queries over
// UDP, makes a single attempt per read/write and uses two-second read and
// write timeouts. A nil *Client is NOT usable: the query path reads its
// fields directly.
type Client struct {
	Net          string            // "tcp", "tcp4" or "tcp6" selects TCP; "", "udp", "udp4" or "udp6" selects UDP ("" is UDP)
	Attempts     int               // number of read/write attempts on timeout; 0 means 1
	Retry        bool              // retry with TCP; NOTE(review): not consulted by this file — confirm it is honoured elsewhere
	ReadTimeout  time.Duration     // read deadline applied before each read attempt; 0 means 2 seconds
	WriteTimeout time.Duration     // write deadline applied before each write attempt; 0 means 2 seconds
	TsigSecret   map[string]string // TSIG secrets: map[<TSIG key name>]<base64 secret>
}
// RemoteAddr returns the remote address of the connection held by w, or nil
// when no connection has been established yet.
func (w *reply) RemoteAddr() net.Addr {
	if w.conn == nil {
		return nil
	}
	return w.conn.RemoteAddr()
}
// Do performs an asynchronous query. The msg *Msg is the question to ask, the
2012-08-08 17:38:23 +10:00
// string addr is the address of the nameserver, the parameter data is used
// in the callback function. The call backback function is called with the
2012-08-08 17:38:23 +10:00
// original query, the answer returned from the nameserver an optional error and
// data.
func (c *Client) Do(msg *Msg, addr string, data interface{}, callback func(*Msg, *Msg, error, interface{})) {
go func() {
r, err := c.Exchange(msg, addr)
callback(msg, r, err, data)
}()
}
2012-08-08 17:38:23 +10:00
// DoRtt is equivalent to Do, except that is calls ExchangeRtt.
2012-08-08 17:45:31 +10:00
func (c *Client) DoRtt(msg *Msg, addr string, data interface{}, callback func(*Msg, *Msg, time.Duration, error, interface{})) {
2012-08-08 17:38:23 +10:00
go func() {
r, rtt, err := c.ExchangeRtt(msg, addr)
callback(msg, r, rtt, err, data)
}()
}
2012-05-06 00:09:57 +10:00
// exchangeBuffer performs a synchronous query. It sends the buffer m to the
2011-12-17 05:35:37 +11:00
// address contained in a.
2012-05-06 00:09:57 +10:00
func (c *Client) exchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, w *reply, err error) {
w = new(reply)
2011-04-18 17:28:56 +10:00
w.client = c
w.addr = a
2012-08-07 04:34:09 +10:00
if err = w.dial(); err != nil {
return 0, w, err
}
2012-08-23 18:33:33 +10:00
defer w.conn.Close()
2012-05-05 07:18:29 +10:00
w.t = time.Now()
2011-08-08 21:10:35 +10:00
if n, err = w.writeClient(inbuf); err != nil {
2012-05-06 00:09:57 +10:00
return 0, w, err
2011-04-18 17:28:56 +10:00
}
2011-08-08 21:10:35 +10:00
if n, err = w.readClient(outbuf); err != nil {
2012-05-06 00:09:57 +10:00
return n, w, err
2011-04-18 17:28:56 +10:00
}
2012-05-05 07:18:29 +10:00
w.rtt = time.Since(w.t)
2012-05-06 00:09:57 +10:00
return n, w, nil
2011-08-04 19:27:56 +10:00
}
// Exchange performs an synchronous query. It sends the message m to the address
2012-05-06 04:47:23 +10:00
// contained in a and waits for an reply. Basic use pattern with a *Client:
//
2012-05-26 18:28:32 +10:00
// c := new(dns.Client)
2012-05-07 23:50:13 +10:00
// in, err := c.Exchange(message, "127.0.0.1:53")
2012-05-07 23:53:35 +10:00
//
2012-05-22 16:47:47 +10:00
// See Client.ExchangeRtt(...) to get the round trip time.
2012-05-07 23:50:13 +10:00
func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
r, _, err = c.ExchangeRtt(m, a)
2012-05-07 23:50:13 +10:00
return
}
2012-05-22 16:47:47 +10:00
// ExchangeRtt performs an synchronous query. It sends the message m to the address
2012-05-07 23:50:13 +10:00
// contained in a and waits for an reply. Basic use pattern with a *Client:
//
2012-05-26 18:28:32 +10:00
// c := new(dns.Client)
2012-08-07 20:33:31 +10:00
// in, rtt, err := c.ExchangeRtt(message, "127.0.0.1:53")
2012-05-06 04:47:23 +10:00
//
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
2011-08-08 21:10:35 +10:00
var n int
2012-05-06 00:09:57 +10:00
var w *reply
2011-08-04 19:27:56 +10:00
out, ok := m.Pack()
if !ok {
return nil, 0, ErrPack
2011-08-04 19:27:56 +10:00
}
var in []byte
switch c.Net {
2012-05-23 04:15:30 +10:00
case "tcp", "tcp4", "tcp6":
in = make([]byte, MaxMsgSize)
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
2012-08-17 16:31:38 +10:00
size := udpMsgSize
for _, r := range m.Extra {
if r.Header().Rrtype == TypeOPT {
size = int(r.(*RR_OPT).UDPSize())
}
}
2012-01-29 10:20:56 +11:00
in = make([]byte, size)
}
2012-05-06 00:09:57 +10:00
if n, w, err = c.exchangeBuffer(out, a, in); err != nil {
return nil, 0, err
2011-08-08 21:10:35 +10:00
}
r = new(Msg)
2012-06-01 18:05:27 +10:00
r.Size = n
2011-08-08 21:10:35 +10:00
if ok := r.Unpack(in[:n]); !ok {
return nil, w.rtt, ErrUnpack
}
return r, w.rtt, nil
}
2012-08-07 04:34:09 +10:00
// dial connects to the address addr for the network set in c.Net
func (w *reply) dial() (err error) {
2012-05-26 20:02:37 +10:00
var conn net.Conn
2012-08-17 16:45:26 +10:00
if w.client.Net == "" {
2012-05-26 20:02:37 +10:00
conn, err = net.Dial("udp", w.addr)
} else {
2012-08-17 16:45:26 +10:00
conn, err = net.Dial(w.client.Net, w.addr)
2012-05-26 20:02:37 +10:00
}
2011-08-08 21:10:35 +10:00
if err != nil {
2012-05-26 20:02:37 +10:00
return
2011-08-08 21:10:35 +10:00
}
w.conn = conn
return nil
}
2012-08-07 04:34:09 +10:00
func (w *reply) receive() (*Msg, error) {
var p []byte
m := new(Msg)
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-19 02:29:46 +10:00
p = make([]byte, MaxMsgSize)
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
p = make([]byte, DefaultMsgSize)
2011-04-19 06:08:12 +10:00
}
n, err := w.readClient(p)
2012-05-05 07:18:29 +10:00
if err != nil || n == 0 {
2011-04-19 06:08:12 +10:00
return nil, err
}
p = p[:n]
if ok := m.Unpack(p); !ok {
return nil, ErrUnpack
}
2012-05-05 07:18:29 +10:00
w.rtt = time.Since(w.t)
2012-06-01 18:05:27 +10:00
m.Size = n
2011-04-23 00:37:26 +10:00
if m.IsTsig() {
secret := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name
2012-08-17 16:45:26 +10:00
if _, ok := w.client.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret
return m, nil
2011-04-23 00:37:26 +10:00
}
2012-02-26 07:42:08 +11:00
// Need to work on the original message p, as that was used to calculate the tsig.
2012-08-17 16:45:26 +10:00
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
2011-04-23 00:37:26 +10:00
}
return m, nil
}
2011-04-16 07:55:27 +10:00
2011-11-03 09:06:54 +11:00
func (w *reply) readClient(p []byte) (n int, err error) {
if w.conn == nil {
2011-11-03 09:06:54 +11:00
return 0, ErrConnEmpty
}
2012-05-26 18:24:47 +10:00
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
2012-08-17 16:45:26 +10:00
attempts := w.client.Attempts
2012-05-26 18:24:47 +10:00
if attempts == 0 {
attempts = 1
}
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2012-05-26 18:24:47 +10:00
setTimeouts(w)
for a := 0; a < attempts; a++ {
2012-01-24 06:35:14 +11:00
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
l, _ := unpackUint16(p[0:2], 0)
if l == 0 {
return 0, ErrShortRead
}
if int(l) > len(p) {
return int(l), io.ErrShortBuffer
}
n, err = w.conn.(*net.TCPConn).Read(p[:l])
2011-04-19 06:08:12 +10:00
if err != nil {
2012-01-24 06:35:14 +11:00
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
2011-04-19 06:08:12 +10:00
}
2012-01-24 06:35:14 +11:00
i := n
for i < int(l) {
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)])
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
// We are half way in our read...
continue
}
return i, err
}
i += j
}
n = i
2011-04-19 06:08:12 +10:00
}
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
for a := 0; a < attempts; a++ {
setTimeouts(w)
2012-01-24 06:35:14 +11:00
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
2012-08-06 02:36:36 +10:00
if err == nil {
return n, err
}
2012-01-24 06:35:14 +11:00
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
2011-04-17 18:54:34 +10:00
}
}
return
2011-04-16 07:55:27 +10:00
}
2012-08-07 04:34:09 +10:00
// send sends a dns msg to the address specified in w.
2011-04-19 06:08:12 +10:00
// If the message m contains a TSIG record the transaction
// signature is calculated.
2012-08-07 04:34:09 +10:00
func (w *reply) send(m *Msg) (err error) {
2012-03-03 07:19:37 +11:00
var out []byte
2011-04-19 06:08:12 +10:00
if m.IsTsig() {
2012-03-03 07:19:37 +11:00
mac := ""
name := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name
2012-08-17 16:45:26 +10:00
if _, ok := w.client.TsigSecret[name]; !ok {
return ErrSecret
2011-04-23 00:37:26 +10:00
}
2012-08-17 16:45:26 +10:00
out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
w.tsigRequestMAC = mac
2012-03-03 07:19:37 +11:00
} else {
ok := false
out, ok = m.Pack()
if !ok {
return ErrPack
2011-11-03 09:06:54 +11:00
}
2011-04-16 07:55:27 +10:00
}
2012-05-05 07:18:29 +10:00
w.t = time.Now()
2012-03-03 07:19:37 +11:00
if _, err = w.writeClient(out); err != nil {
return err
}
2011-04-16 07:55:27 +10:00
return nil
}
2011-11-03 09:06:54 +11:00
func (w *reply) writeClient(p []byte) (n int, err error) {
2012-08-17 16:45:26 +10:00
attempts := w.client.Attempts
2012-05-26 18:24:47 +10:00
if attempts == 0 {
attempts = 1
}
2012-08-07 04:34:09 +10:00
if err = w.dial(); err != nil {
return 0, err
2011-04-16 07:55:27 +10:00
}
2012-08-17 16:45:26 +10:00
switch w.client.Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-16 07:55:27 +10:00
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
2012-05-26 18:24:47 +10:00
for a := 0; a < attempts; a++ {
setTimeouts(w)
2011-08-09 21:15:25 +10:00
a, b := packUint16(uint16(len(p)))
n, err = w.conn.Write([]byte{a, b})
2011-04-16 07:55:27 +10:00
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
if n != 2 {
return n, io.ErrShortWrite
}
2011-08-08 21:10:35 +10:00
n, err = w.conn.Write(p)
2011-04-16 07:55:27 +10:00
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
i := n
if i < len(p) {
2011-08-08 21:10:35 +10:00
j, err := w.conn.Write(p[i:len(p)])
2011-04-16 07:55:27 +10:00
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
// We are half way in our write...
continue
}
return i, err
}
i += j
}
n = i
}
2012-05-26 18:24:47 +10:00
case "", "udp", "udp4", "udp6":
for a := 0; a < attempts; a++ {
setTimeouts(w)
n, err = w.conn.(*net.UDPConn).Write(p)
2012-08-06 02:13:23 +10:00
if err == nil {
return
}
2011-04-16 07:55:27 +10:00
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
2012-01-24 06:35:14 +11:00
return n, err
2011-04-16 07:55:27 +10:00
}
}
}
2012-01-24 06:35:14 +11:00
return
2011-04-16 07:55:27 +10:00
}
2012-05-05 07:18:29 +10:00
2012-05-26 18:24:47 +10:00
func setTimeouts(w *reply) {
2012-08-17 16:45:26 +10:00
if w.client.ReadTimeout == 0 {
2012-05-26 18:24:47 +10:00
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
} else {
2012-08-17 16:45:26 +10:00
w.conn.SetReadDeadline(time.Now().Add(w.client.ReadTimeout))
2012-05-26 18:24:47 +10:00
}
2012-08-17 16:45:26 +10:00
if w.client.WriteTimeout == 0 {
2012-05-26 18:24:47 +10:00
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
} else {
2012-08-17 16:45:26 +10:00
w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout))
2012-05-26 18:24:47 +10:00
}
}