dns/client.go

413 lines
9.7 KiB
Go
Raw Normal View History

2011-04-13 05:44:56 +10:00
package dns
// A concurrent client implementation.
// A Client sends a query to a channel, which
// will then handle the query. Returned replies
// are returned on another channel. This mirrors the
// server setup: a handler function gets run
// when the query returns.
import (
2011-04-13 06:21:09 +10:00
"os"
2011-04-16 07:55:27 +10:00
"io"
"net"
2011-04-13 05:44:56 +10:00
)
type QueryHandler interface {
2011-04-13 06:21:09 +10:00
QueryDNS(w RequestWriter, q *Msg)
2011-04-13 05:44:56 +10:00
}
2011-08-01 21:15:15 +10:00
// The RequestWriter interface is used by a DNS query handler to
// construct a DNS request.
2011-04-13 05:44:56 +10:00
type RequestWriter interface {
WriteMessages([]*Msg)
2011-04-16 07:55:27 +10:00
Write(*Msg)
2011-04-18 17:28:56 +10:00
Send(*Msg) os.Error
Receive() (*Msg, os.Error)
2011-04-13 06:21:09 +10:00
}
// hijacked connections...?
2011-04-13 06:39:38 +10:00
type reply struct {
2011-04-19 06:08:12 +10:00
client *Client
addr string
req *Msg
conn net.Conn
2011-04-23 00:37:26 +10:00
tsigRequestMAC string
2011-04-19 06:08:12 +10:00
tsigTimersOnly bool
}
2011-08-01 21:15:15 +10:00
// A Request is a incoming message from a Client
type Request struct {
Request *Msg
Addr string
Client *Client
2011-04-13 06:21:09 +10:00
}
2011-04-13 05:44:56 +10:00
// QueryMux is an DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns add calls the handler for the pattern
// that most closely matches the zone name.
type QueryMux struct {
2011-04-13 06:21:09 +10:00
m map[string]QueryHandler
2011-04-13 05:44:56 +10:00
}
// NewQueryMux allocates and returns a new QueryMux.
func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
// DefaultQueryMux is the default QueryMux used by Query.
var DefaultQueryMux = NewQueryMux()
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
func newQueryChan() chan *Request { return make(chan *Request) }
2011-04-14 04:41:16 +10:00
2011-07-06 05:08:22 +10:00
// Default channels to use for the resolver
var DefaultReplyChan = newQueryChanSlice()
2011-04-14 04:41:16 +10:00
var DefaultQueryChan = newQueryChan()
2011-04-13 05:44:56 +10:00
// The HandlerQueryFunc type is an adapter to allow the use of
// ordinary functions as DNS query handlers. If f is a function
// with the appropriate signature, HandlerQueryFunc(f) is a
// QueryHandler object that calls f.
2011-04-13 05:44:56 +10:00
type HandlerQueryFunc func(RequestWriter, *Msg)
// QueryDNS calls f(w, reg)
func (f HandlerQueryFunc) QueryDNS(w RequestWriter, r *Msg) {
go f(w, r)
}
func HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
DefaultQueryMux.HandleQueryFunc(pattern, handler)
2011-04-13 05:44:56 +10:00
}
// reusing zoneMatch from server.go
func (mux *QueryMux) match(zone string) QueryHandler {
2011-04-13 06:21:09 +10:00
var h QueryHandler
var n = 0
for k, v := range mux.m {
if !zoneMatch(k, zone) {
continue
}
if h == nil || len(k) > n {
n = len(k)
h = v
}
}
return h
2011-04-13 05:44:56 +10:00
}
func (mux *QueryMux) Handle(pattern string, handler QueryHandler) {
2011-04-13 06:21:09 +10:00
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
2011-07-05 06:27:23 +10:00
mux.m[pattern] = handler
2011-04-13 05:44:56 +10:00
}
func (mux *QueryMux) HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
2011-04-13 06:21:09 +10:00
mux.Handle(pattern, HandlerQueryFunc(handler))
2011-04-13 05:44:56 +10:00
}
func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
h := mux.match(r.Question[0].Name)
2011-04-13 06:21:09 +10:00
if h == nil {
2011-07-24 07:43:43 +10:00
panic("dns: no handler found for " + r.Question[0].Name)
2011-04-13 06:21:09 +10:00
}
h.QueryDNS(w, r)
2011-04-13 05:44:56 +10:00
}
type Client struct {
2011-04-19 06:08:12 +10:00
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
Attempts int // number of attempts
Retry bool // retry with TCP
ChannelQuery chan *Request // read DNS request from this channel
ChannelReply chan []*Msg // read DNS request from this channel
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
2011-07-24 07:43:43 +10:00
// LocalAddr string // Local address to use
}
2011-07-31 22:33:13 +10:00
// NewClient returns a Client that queries over "udp", makes a single
// attempt and delivers replies on DefaultReplyChan.
func NewClient() *Client {
	return &Client{
		Net:          "udp",
		Attempts:     1,
		ChannelReply: DefaultReplyChan,
	}
}
type Query struct {
ChannelQuery chan *Request // read DNS request from this channel
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
}
func (q *Query) Query() os.Error {
handler := q.Handler
2011-04-13 06:21:09 +10:00
if handler == nil {
handler = DefaultQueryMux
}
2011-07-24 07:43:43 +10:00
//forever:
for {
select {
case in := <-q.ChannelQuery:
w := new(reply)
w.req = in.Request
w.addr = in.Addr
w.client = in.Client
handler.QueryDNS(w, in.Request)
}
}
2011-04-13 06:21:09 +10:00
return nil
2011-04-13 05:44:56 +10:00
}
func (q *Query) ListenAndQuery() os.Error {
if q.ChannelQuery == nil {
q.ChannelQuery = DefaultQueryChan
}
return q.Query()
}
2011-07-05 06:27:23 +10:00
// Start listener for firing off the queries. If
// c is nil DefaultQueryChan is used. If handler is nil
// DefaultQueryMux is used.
func ListenAndQuery(c chan *Request, handler QueryHandler) {
q := &Query{ChannelQuery: c, Handler: handler}
go q.ListenAndQuery()
}
// Write delivers m as the answer to w's request by sending the pair
// {request, m} on the client's ChannelReply.
func (w *reply) Write(m *Msg) {
	pair := []*Msg{w.req, m}
	w.Client().ChannelReply <- pair
}
2011-07-05 05:38:50 +10:00
// Do performs an asynchronous query. The result is returned on the
2011-07-31 22:33:13 +10:00
// channel set in the c. If no channel is set DefaultQueryChan is used.
func (c *Client) Do(m *Msg, a string) {
if c.ChannelQuery == nil {
DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m}
} else {
c.ChannelQuery <- &Request{Client: c, Addr: a, Request: m}
2011-04-16 07:55:27 +10:00
}
}
2011-08-04 19:27:56 +10:00
// ExchangeBuf performs a synchronous query. It sends the buffer m to the
// address (net.Addr?) contained in a
func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) bool {
2011-04-18 17:28:56 +10:00
w := new(reply)
w.client = c
w.addr = a
_, err := w.writeClient(inbuf)
2011-08-04 21:59:15 +10:00
defer w.closeClient() // XXX here?? what about TCP which should remain open
2011-04-18 17:28:56 +10:00
if err != nil {
2011-08-04 19:27:56 +10:00
println(err.String())
return false
2011-04-18 17:28:56 +10:00
}
2011-08-04 19:27:56 +10:00
// udp / tcp TODO
n, err := w.readClient(outbuf)
2011-04-18 17:28:56 +10:00
if err != nil {
return false
2011-04-18 17:28:56 +10:00
}
outbuf = outbuf[:n]
return true
2011-08-04 19:27:56 +10:00
}
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply.
func (c *Client) Exchange(m *Msg, a string) *Msg {
out, ok := m.Pack()
if !ok {
panic("failed to pack message")
}
in := make([]byte, DefaultMsgSize)
if ok := c.ExchangeBuffer(out, a, in); !ok {
2011-08-04 19:27:56 +10:00
return nil
}
2011-07-24 07:43:43 +10:00
r := new(Msg)
if ok := r.Unpack(in); !ok {
2011-04-18 17:28:56 +10:00
return nil
}
2011-08-04 19:27:56 +10:00
return r
}
func (w *reply) WriteMessages(m []*Msg) {
m1 := append([]*Msg{w.req}, m...)
w.Client().ChannelReply <- m1
2011-04-13 05:44:56 +10:00
}
2011-04-16 07:55:27 +10:00
func (w *reply) Client() *Client {
return w.client
2011-04-17 18:54:34 +10:00
}
func (w *reply) Request() *Msg {
return w.req
}
2011-04-18 17:28:56 +10:00
func (w *reply) Receive() (*Msg, os.Error) {
var p []byte
m := new(Msg)
switch w.Client().Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-19 02:29:46 +10:00
p = make([]byte, MaxMsgSize)
2011-07-06 04:55:05 +10:00
case "udp", "udp4", "udp6":
p = make([]byte, DefaultMsgSize)
2011-04-19 06:08:12 +10:00
}
n, err := w.readClient(p)
if err != nil {
return nil, err
}
p = p[:n]
if ok := m.Unpack(p); !ok {
return nil, ErrUnpack
}
2011-04-23 00:37:26 +10:00
// Tsig
if m.IsTsig() {
secret := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name
_, ok := w.Client().TsigSecret[secret]
if !ok {
return m, ErrNoSig
}
ok, err := TsigVerify(p, w.Client().TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
if !ok {
return m, err
}
}
return m, nil
}
2011-04-16 07:55:27 +10:00
func (w *reply) readClient(p []byte) (n int, err os.Error) {
if w.conn == nil {
panic("no connection")
}
switch w.Client().Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-19 02:27:59 +10:00
if len(p) < 1 {
2011-04-19 06:08:12 +10:00
return 0, io.ErrShortBuffer
}
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 {
return n, err
}
l, _ := unpackUint16(p[0:2], 0)
if l == 0 {
return 0, ErrShortRead
}
if int(l) > len(p) {
return int(l), io.ErrShortBuffer
}
n, err = w.conn.(*net.TCPConn).Read(p[:l])
if err != nil {
return n, err
}
i := n
for i < int(l) {
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)])
if err != nil {
return i, err
}
i += j
}
n = i
2011-07-06 04:55:05 +10:00
case "udp", "udp4", "udp6":
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
2011-04-17 18:54:34 +10:00
if err != nil {
return n, err
}
}
return
2011-04-16 07:55:27 +10:00
}
2011-04-19 06:08:12 +10:00
// Send a msg to the address specified in w.
// If the message m contains a TSIG record the transaction
// signature is calculated.
2011-04-18 17:28:56 +10:00
func (w *reply) Send(m *Msg) os.Error {
2011-04-19 06:08:12 +10:00
if m.IsTsig() {
2011-04-23 00:37:26 +10:00
secret := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name
_, ok := w.Client().TsigSecret[secret]
if !ok {
return ErrNoSig
}
m, _ = TsigGenerate(m, w.Client().TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = m.Extra[len(m.Extra)-1].(*RR_TSIG).MAC // Safe the requestMAC
2011-04-19 06:08:12 +10:00
}
2011-04-16 07:55:27 +10:00
out, ok := m.Pack()
if !ok {
return ErrPack
}
_, err := w.writeClient(out)
2011-04-16 07:55:27 +10:00
if err != nil {
return err
}
return nil
}
func (w *reply) writeClient(p []byte) (n int, err os.Error) {
2011-08-04 21:49:40 +10:00
if w.Client().Attempts == 0 {
panic("c.Attempts 0")
}
2011-08-04 21:49:40 +10:00
if w.Client().Net == "" {
panic("c.Net empty")
}
2011-08-04 21:49:40 +10:00
conn, err := net.Dial(w.Client().Net, w.addr)
2011-04-16 07:55:27 +10:00
if err != nil {
return 0, err
}
w.conn = conn
2011-08-04 21:49:40 +10:00
switch w.Client().Net {
2011-07-06 04:55:05 +10:00
case "tcp", "tcp4", "tcp6":
2011-04-16 07:55:27 +10:00
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
2011-08-04 21:49:40 +10:00
for a := 0; a < w.Client().Attempts; a++ {
2011-04-16 07:55:27 +10:00
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p)))
n, err = conn.Write(l)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
if n != 2 {
return n, io.ErrShortWrite
}
n, err = conn.Write(p)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return n, err
}
i := n
if i < len(p) {
j, err := conn.Write(p[i:len(p)])
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
// We are half way in our write...
continue
}
return i, err
}
i += j
}
n = i
}
2011-07-06 04:55:05 +10:00
case "udp", "udp4", "udp6":
2011-08-04 21:49:40 +10:00
for a := 0; a < w.Client().Attempts; a++ {
2011-04-16 07:55:27 +10:00
n, err = conn.(*net.UDPConn).WriteTo(p, conn.RemoteAddr())
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
continue
}
return 0, err
}
}
}
return 0, nil
}
2011-08-04 19:27:56 +10:00
// closeClient closes the connection opened by writeClient and forgets it,
// so that a repeated call is a no-op. It returns nil when there is no
// connection (for example when dialing failed), otherwise the error from
// Close.
func (w *reply) closeClient() (err os.Error) {
	if w.conn == nil {
		return nil
	}
	err = w.conn.Close()
	w.conn = nil
	return err
}