2011-04-13 05:44:56 +10:00
|
|
|
package dns
|
|
|
|
|
|
|
|
// A concurrent client implementation.
|
|
|
|
// Client sends query to a channel which
|
|
|
|
// will then handle the query. Returned replies
|
|
|
|
// are returned on another channel. Ready for handling --- same
|
|
|
|
// setup for server - a HANDLER function that gets run
|
|
|
|
// when the query returns.
|
|
|
|
|
|
|
|
import (
|
2011-04-13 06:21:09 +10:00
|
|
|
"os"
|
2011-04-16 07:55:27 +10:00
|
|
|
"io"
|
|
|
|
"net"
|
2011-04-13 05:44:56 +10:00
|
|
|
)
|
|
|
|
|
|
|
|
type QueryHandler interface {
|
2011-04-13 06:21:09 +10:00
|
|
|
QueryDNS(w RequestWriter, q *Msg)
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
// A RequestWriter interface is used by an DNS query handler to
|
|
|
|
// construct an DNS request.
|
|
|
|
type RequestWriter interface {
|
2011-04-16 05:42:27 +10:00
|
|
|
WriteMessages([]*Msg)
|
2011-04-16 07:55:27 +10:00
|
|
|
Write(*Msg)
|
2011-04-18 17:28:56 +10:00
|
|
|
Send(*Msg) os.Error
|
|
|
|
Receive() (*Msg, os.Error)
|
2011-04-13 06:21:09 +10:00
|
|
|
}
|
|
|
|
|
2011-04-15 06:11:41 +10:00
|
|
|
// hijacked connections...?
|
2011-04-13 06:39:38 +10:00
|
|
|
type reply struct {
|
2011-04-18 05:56:40 +10:00
|
|
|
client *Client
|
|
|
|
addr string
|
2011-04-15 06:11:41 +10:00
|
|
|
req *Msg
|
2011-04-18 05:56:40 +10:00
|
|
|
conn net.Conn
|
|
|
|
}
|
|
|
|
|
|
|
|
type Request struct {
|
|
|
|
Request *Msg
|
|
|
|
Addr string
|
|
|
|
Client *Client
|
2011-04-13 06:21:09 +10:00
|
|
|
}
|
2011-04-13 05:44:56 +10:00
|
|
|
|
|
|
|
// QueryMux is an DNS request multiplexer. It matches the
|
|
|
|
// zone name of each incoming request against a list of
|
|
|
|
// registered patterns add calls the handler for the pattern
|
|
|
|
// that most closely matches the zone name.
|
|
|
|
type QueryMux struct {
|
2011-04-13 06:21:09 +10:00
|
|
|
m map[string]QueryHandler
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
// NewQueryMux allocates and returns a new QueryMux.
|
|
|
|
func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
|
|
|
|
|
|
|
|
// DefaultQueryMux is the default QueryMux used by Query.
|
|
|
|
var DefaultQueryMux = NewQueryMux()
|
|
|
|
|
2011-04-16 05:42:27 +10:00
|
|
|
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
|
2011-04-18 05:56:40 +10:00
|
|
|
func newQueryChan() chan *Request { return make(chan *Request) }
|
2011-04-14 04:41:16 +10:00
|
|
|
|
|
|
|
// Default channel to use for the resolver
|
2011-04-16 05:42:27 +10:00
|
|
|
var DefaultReplyChan = newQueryChanSlice()
|
2011-04-14 04:41:16 +10:00
|
|
|
var DefaultQueryChan = newQueryChan()
|
|
|
|
|
2011-04-13 05:44:56 +10:00
|
|
|
// The HandlerQueryFunc type is an adapter to allow the use of
|
|
|
|
// ordinary functions as DNS query handlers. If f is a function
|
|
|
|
// with the appropriate signature, HandlerQueryFunc(f) is a
|
2011-04-15 06:11:41 +10:00
|
|
|
// QueryHandler object that calls f.
|
2011-04-13 05:44:56 +10:00
|
|
|
type HandlerQueryFunc func(RequestWriter, *Msg)
|
|
|
|
|
|
|
|
// QueryDNS calls f(w, reg)
|
|
|
|
func (f HandlerQueryFunc) QueryDNS(w RequestWriter, r *Msg) {
|
2011-04-15 06:11:41 +10:00
|
|
|
go f(w, r)
|
|
|
|
}
|
|
|
|
|
|
|
|
func HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
|
|
|
|
DefaultQueryMux.HandleQueryFunc(pattern, handler)
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
// reusing zoneMatch from server.go
|
|
|
|
func (mux *QueryMux) match(zone string) QueryHandler {
|
2011-04-13 06:21:09 +10:00
|
|
|
var h QueryHandler
|
|
|
|
var n = 0
|
|
|
|
for k, v := range mux.m {
|
|
|
|
if !zoneMatch(k, zone) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if h == nil || len(k) > n {
|
|
|
|
n = len(k)
|
|
|
|
h = v
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return h
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
func (mux *QueryMux) Handle(pattern string, handler QueryHandler) {
|
2011-04-13 06:21:09 +10:00
|
|
|
if pattern == "" {
|
|
|
|
panic("dns: invalid pattern " + pattern)
|
|
|
|
}
|
|
|
|
if pattern[len(pattern)-1] != '.' { // no ending .
|
|
|
|
mux.m[pattern+"."] = handler
|
|
|
|
} else {
|
|
|
|
mux.m[pattern] = handler
|
|
|
|
}
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
func (mux *QueryMux) HandleQueryFunc(pattern string, handler func(RequestWriter, *Msg)) {
|
2011-04-13 06:21:09 +10:00
|
|
|
mux.Handle(pattern, HandlerQueryFunc(handler))
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
|
|
|
|
h := mux.match(r.Question[0].Name)
|
2011-04-13 06:21:09 +10:00
|
|
|
if h == nil {
|
|
|
|
// h = RefusedHandler()
|
|
|
|
// something else
|
|
|
|
}
|
2011-04-18 05:56:40 +10:00
|
|
|
h.QueryDNS(w, r)
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 17:28:56 +10:00
|
|
|
// TODO add: LocalAddr
|
2011-04-13 05:44:56 +10:00
|
|
|
type Client struct {
|
2011-04-18 05:56:40 +10:00
|
|
|
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
|
|
|
|
Addr string // address to call
|
|
|
|
Attempts int // number of attempts
|
|
|
|
Retry bool // retry with TCP
|
|
|
|
ChannelQuery chan *Request // read DNS request from this channel
|
|
|
|
ChannelReply chan []*Msg // read DNS request from this channel
|
|
|
|
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
|
|
|
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewClient() *Client {
|
|
|
|
c := new(Client)
|
|
|
|
c.Net = "udp"
|
|
|
|
c.Attempts = 1
|
|
|
|
c.ChannelReply = DefaultReplyChan
|
|
|
|
return c
|
|
|
|
}
|
|
|
|
|
|
|
|
type Query struct {
|
|
|
|
ChannelQuery chan *Request // read DNS request from this channel
|
|
|
|
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (q *Query) Query() os.Error {
|
|
|
|
handler := q.Handler
|
2011-04-13 06:21:09 +10:00
|
|
|
if handler == nil {
|
|
|
|
handler = DefaultQueryMux
|
|
|
|
}
|
2011-04-15 06:11:41 +10:00
|
|
|
forever:
|
|
|
|
for {
|
|
|
|
select {
|
2011-04-18 05:56:40 +10:00
|
|
|
case in := <-q.ChannelQuery:
|
2011-04-15 06:11:41 +10:00
|
|
|
w := new(reply)
|
2011-04-18 05:56:40 +10:00
|
|
|
w.req = in.Request
|
|
|
|
w.addr = in.Addr
|
|
|
|
w.client = in.Client
|
|
|
|
handler.QueryDNS(w, in.Request)
|
2011-04-15 06:11:41 +10:00
|
|
|
}
|
|
|
|
}
|
2011-04-13 06:21:09 +10:00
|
|
|
return nil
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (q *Query) ListenAndQuery() os.Error {
|
|
|
|
if q.ChannelQuery == nil {
|
|
|
|
q.ChannelQuery = DefaultQueryChan
|
2011-04-15 06:11:41 +10:00
|
|
|
}
|
2011-04-18 05:56:40 +10:00
|
|
|
return q.Query()
|
|
|
|
}
|
|
|
|
|
|
|
|
func ListenAndQuery(c chan *Request, handler QueryHandler) {
|
|
|
|
q := &Query{ChannelQuery: c, Handler: handler}
|
|
|
|
go q.ListenAndQuery()
|
2011-04-15 06:11:41 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (w *reply) Write(m *Msg) {
|
|
|
|
w.Client().ChannelReply <- []*Msg{w.req, m}
|
|
|
|
}
|
|
|
|
|
2011-04-18 17:28:56 +10:00
|
|
|
// Do performs an asynchronize query. The result is returned on the
|
|
|
|
// channel set in the c.
|
2011-04-18 05:56:40 +10:00
|
|
|
func (c *Client) Do(m *Msg, a string) {
|
2011-04-15 06:11:41 +10:00
|
|
|
if c.ChannelQuery == nil {
|
2011-04-18 05:56:40 +10:00
|
|
|
DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m}
|
|
|
|
} else {
|
|
|
|
c.ChannelQuery <- &Request{Client: c, Addr: a, Request: m}
|
2011-04-16 07:55:27 +10:00
|
|
|
}
|
2011-04-15 06:11:41 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 17:28:56 +10:00
|
|
|
// A sync query
|
2011-04-18 05:56:40 +10:00
|
|
|
func (c *Client) Exchange(m *Msg, a string) *Msg {
|
2011-04-18 17:28:56 +10:00
|
|
|
w := new(reply)
|
|
|
|
w.client = c
|
|
|
|
w.addr = a
|
2011-04-18 05:56:40 +10:00
|
|
|
out, ok := m.Pack()
|
|
|
|
if !ok {
|
2011-04-18 17:28:56 +10:00
|
|
|
//
|
|
|
|
}
|
|
|
|
_, err := w.writeClient(out)
|
|
|
|
if err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// udp / tcp
|
|
|
|
p := make([]byte, DefaultMsgSize)
|
|
|
|
n, err := w.readClient(p)
|
|
|
|
if err != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
p = p[:n]
|
|
|
|
if ok := m.Unpack(p); !ok {
|
|
|
|
return nil
|
2011-04-18 05:56:40 +10:00
|
|
|
}
|
|
|
|
return m
|
2011-04-16 05:42:27 +10:00
|
|
|
}
|
|
|
|
|
|
|
|
func (w *reply) WriteMessages(m []*Msg) {
|
2011-04-18 05:56:40 +10:00
|
|
|
m1 := append([]*Msg{w.req}, m...)
|
|
|
|
w.Client().ChannelReply <- m1
|
2011-04-13 05:44:56 +10:00
|
|
|
}
|
2011-04-16 07:55:27 +10:00
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (w *reply) Client() *Client {
|
|
|
|
return w.client
|
2011-04-17 18:54:34 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (w *reply) Request() *Msg {
|
|
|
|
return w.req
|
|
|
|
}
|
|
|
|
|
2011-04-18 17:28:56 +10:00
|
|
|
func (w *reply) Receive() (*Msg, os.Error) {
|
2011-04-18 05:56:40 +10:00
|
|
|
var p []byte
|
|
|
|
m := new(Msg)
|
|
|
|
switch w.Client().Net {
|
2011-04-17 18:54:34 +10:00
|
|
|
case "tcp":
|
2011-04-19 02:29:46 +10:00
|
|
|
p = make([]byte, MaxMsgSize)
|
2011-04-18 05:56:40 +10:00
|
|
|
case "udp":
|
|
|
|
p = make([]byte, DefaultMsgSize)
|
2011-04-19 02:29:46 +10:00
|
|
|
}
|
|
|
|
n, err := w.readClient(p)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
p = p[:n]
|
|
|
|
if ok := m.Unpack(p); !ok {
|
|
|
|
return nil, ErrUnpack
|
|
|
|
}
|
2011-04-18 05:56:40 +10:00
|
|
|
return m, nil
|
|
|
|
}
|
2011-04-16 07:55:27 +10:00
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (w *reply) readClient(p []byte) (n int, err os.Error) {
|
|
|
|
if w.conn == nil {
|
|
|
|
panic("no connection")
|
|
|
|
}
|
|
|
|
switch w.Client().Net {
|
|
|
|
case "tcp":
|
2011-04-19 02:27:59 +10:00
|
|
|
if len(p) < 1 {
|
|
|
|
return 0, io.ErrShortBuffer
|
|
|
|
}
|
|
|
|
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
|
|
|
|
if err != nil || n != 2 {
|
|
|
|
return n, err
|
|
|
|
}
|
|
|
|
l, _ := unpackUint16(p[0:2], 0)
|
|
|
|
if l == 0 {
|
|
|
|
return 0, ErrShortRead
|
|
|
|
}
|
|
|
|
if int(l) > len(p) {
|
|
|
|
return int(l), io.ErrShortBuffer
|
|
|
|
}
|
|
|
|
n, err = w.conn.(*net.TCPConn).Read(p[:l])
|
|
|
|
if err != nil {
|
|
|
|
return n, err
|
|
|
|
}
|
|
|
|
i := n
|
|
|
|
for i < int(l) {
|
|
|
|
j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)])
|
|
|
|
if err != nil {
|
|
|
|
return i, err
|
|
|
|
}
|
|
|
|
i += j
|
|
|
|
}
|
|
|
|
n = i
|
2011-04-17 18:54:34 +10:00
|
|
|
case "udp":
|
2011-04-18 05:56:40 +10:00
|
|
|
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
|
2011-04-17 18:54:34 +10:00
|
|
|
if err != nil {
|
|
|
|
return n, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return
|
2011-04-16 07:55:27 +10:00
|
|
|
}
|
|
|
|
|
2011-04-18 17:28:56 +10:00
|
|
|
func (w *reply) Send(m *Msg) os.Error {
|
2011-04-16 07:55:27 +10:00
|
|
|
out, ok := m.Pack()
|
|
|
|
if !ok {
|
|
|
|
return ErrPack
|
|
|
|
}
|
2011-04-18 05:56:40 +10:00
|
|
|
_, err := w.writeClient(out)
|
2011-04-16 07:55:27 +10:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2011-04-18 05:56:40 +10:00
|
|
|
func (w *reply) writeClient(p []byte) (n int, err os.Error) {
|
|
|
|
c := w.Client()
|
|
|
|
if c.Attempts == 0 {
|
|
|
|
panic("c.Attempts 0")
|
|
|
|
}
|
|
|
|
if c.Net == "" {
|
|
|
|
panic("c.Net empty")
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, err := net.Dial(c.Net, "", w.addr)
|
2011-04-16 07:55:27 +10:00
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
2011-04-18 05:56:40 +10:00
|
|
|
w.conn = conn
|
2011-04-16 07:55:27 +10:00
|
|
|
switch c.Net {
|
|
|
|
case "tcp":
|
|
|
|
if len(p) < 2 {
|
|
|
|
return 0, io.ErrShortBuffer
|
|
|
|
}
|
|
|
|
for a := 0; a < c.Attempts; a++ {
|
|
|
|
l := make([]byte, 2)
|
|
|
|
l[0], l[1] = packUint16(uint16(len(p)))
|
|
|
|
n, err = conn.Write(l)
|
|
|
|
if err != nil {
|
|
|
|
if e, ok := err.(net.Error); ok && e.Timeout() {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
return n, err
|
|
|
|
}
|
|
|
|
if n != 2 {
|
|
|
|
return n, io.ErrShortWrite
|
|
|
|
}
|
|
|
|
n, err = conn.Write(p)
|
|
|
|
if err != nil {
|
|
|
|
if e, ok := err.(net.Error); ok && e.Timeout() {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
return n, err
|
|
|
|
}
|
|
|
|
i := n
|
|
|
|
if i < len(p) {
|
|
|
|
j, err := conn.Write(p[i:len(p)])
|
|
|
|
if err != nil {
|
|
|
|
if e, ok := err.(net.Error); ok && e.Timeout() {
|
|
|
|
// We are half way in our write...
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
return i, err
|
|
|
|
}
|
|
|
|
i += j
|
|
|
|
}
|
|
|
|
n = i
|
|
|
|
}
|
|
|
|
case "udp":
|
|
|
|
for a := 0; a < c.Attempts; a++ {
|
|
|
|
n, err = conn.(*net.UDPConn).WriteTo(p, conn.RemoteAddr())
|
|
|
|
if err != nil {
|
|
|
|
if e, ok := err.(net.Error); ok && e.Timeout() {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return 0, nil
|
|
|
|
}
|