diff --git a/client.go b/client.go index 14a68453..5882b100 100644 --- a/client.go +++ b/client.go @@ -70,12 +70,12 @@ func newQueryChan() chan *Request { return make(chan *Request) } // Default channels to use for the resolver var ( - // DefaultReplyChan is the channel on which the replies are + // Incoming is the channel on which the replies are // coming back. Is it a channel of *Exchange, so that the original // question is included with the answer. - Incoming = newQueryChanSlice() - // defaultQueryChan is the channel were you can send the questions to. - defaultQueryChan = newQueryChan() + IncomingQuery = newQueryChanSlice() + // RequestQuery is the channel were you can send the questions to. + RequestQuery = newQueryChan() ) // The HandlerQueryFunc type is an adapter to allow the use of @@ -171,7 +171,8 @@ type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one Attempts int // number of attempts Retry bool // retry with TCP - QueryChan chan *Request // read DNS request from this channel + Request chan *Request // read DNS request from this channel + Incoming chan *Exchange // write replies to this channel ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns) WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns) TsigSecret map[string]string // secret(s) for Tsig map[] @@ -186,14 +187,23 @@ func NewClient() *Client { c := new(Client) c.Net = "udp" c.Attempts = 1 - c.QueryChan = defaultQueryChan + c.Request = RequestQuery + c.Incoming = IncomingQuery c.ReadTimeout = 2 * 1e9 c.WriteTimeout = 2 * 1e9 return c } +// NewClientRequest allows the setting of a request channel which +// will be used instead of the default. +func NewClientRequest(c chan *Request) *Client { + c1 := NewClient() + c1.Request = c + return c1 +} + type Query struct { - QueryChan chan *Request // read DNS request from this channel + Request chan *Request // read DNS request from this channel Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil } @@ -204,7 +214,7 @@ func (q *Query) Query() error { } for { select { - case in := <-q.QueryChan: + case in := <-q.Request: w := new(reply) w.req = in.Request w.addr = in.Addr @@ -216,27 +226,36 @@ func (q *Query) Query() error { } func (q *Query) ListenAndQuery() error { - if q.QueryChan == nil { - q.QueryChan = defaultQueryChan + if q.Request == nil { + q.Request = RequestQuery } return q.Query() } -// ListenAndQuery starts the listener for firing off the queries. If -// c is nil defaultQueryChan is used. If handler is nil -// DefaultQueryMux is used. -func ListenAndQuery(request chan *Request, handler QueryHandler) { - q := &Query{QueryChan: request, Handler: handler} +// ListenAndQuery starts the listener for firing off the queries. +// If handler is nil DefaultQueryMux is used. +func ListenAndQuery(handler QueryHandler) { + q := &Query{Request: nil, Handler: handler} go q.ListenAndQuery() } +// ListenAndQueryRequest starts the listener for firing off the queries. If +// c is nil defaultQueryChan is used. If handler is nil +// DefaultQueryMux is used. +func ListenAndQueryRequest(request chan *Request, handler QueryHandler) { + q := &Query{Request: request, Handler: handler} + go q.ListenAndQuery() +} + + + // Write returns the original question and the answer on the // reply channel of the client. func (w *reply) Write(m *Msg) error { if w.conn == nil { - Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt} + w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt} } else { - Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()} + w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()} } return nil } @@ -259,7 +278,7 @@ func (w *reply) RemoteAddr() net.Addr { // // r is of type Exchange. func (c *Client) Do(m *Msg, a string) { - c.QueryChan <- &Request{Client: c, Addr: a, Request: m} + c.Request <- &Request{Client: c, Addr: a, Request: m} } // exchangeBuffer performs a synchronous query. It sends the buffer m to the diff --git a/client_test.go b/client_test.go index 900ff696..d848100a 100644 --- a/client_test.go +++ b/client_test.go @@ -27,7 +27,7 @@ func helloMiek(w RequestWriter, r *Msg) { func TestClientASync(t *testing.T) { HandleQueryFunc("miek.nl.", helloMiek) // All queries for miek.nl will be handled by HelloMiek - ListenAndQuery(nil, nil) // Detect if this isn't running + ListenAndQuery(nil) // Detect if this isn't running m := new(Msg) m.SetQuestion("miek.nl.", TypeSOA) @@ -38,7 +38,7 @@ func TestClientASync(t *testing.T) { forever: for { select { - case n := <-c.ReplyChan: + case n := <-c.Incoming: if n.Reply != nil && n.Reply.Rcode != RcodeSuccess { t.Log("Failed to get an valid answer") t.Fail() diff --git a/ex/chaos/chaos.go b/ex/chaos/chaos.go index 280a2a6e..389a0b38 100644 --- a/ex/chaos/chaos.go +++ b/ex/chaos/chaos.go @@ -52,7 +52,7 @@ func qhandler(w dns.RequestWriter, m *dns.Msg) { func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string { dns.HandleQueryFunc(os.Args[1], qhandler) - dns.ListenAndQuery(nil, nil) + dns.ListenAndQuery(nil) m4 := new(dns.Msg) m4.SetQuestion(dns.Fqdn(os.Args[1]), dns.TypeA) @@ -66,7 +66,7 @@ func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string { forever: for { select { - case r := <-dns.Incoming: + case r := <-c.Incoming: if r.Reply != nil && r.Reply.Rcode == dns.RcodeSuccess { for _, aa := range r.Reply.Answer { switch aa.(type) { diff --git a/ex/q/q.go b/ex/q/q.go index 384ee558..f1e7404b 100644 --- a/ex/q/q.go +++ b/ex/q/q.go @@ -136,7 +136,7 @@ Flags: // We use the async query handling, just to show how it is to be used. dns.HandleQuery(".", q) - dns.ListenAndQuery(nil, nil) + dns.ListenAndQuery(nil) c := dns.NewClient() if *tcp { c.Net = "tcp" @@ -199,7 +199,7 @@ Flags: forever: for { select { - case r := <-dns.Incoming: + case r := <-c.Incoming: if r.Reply != nil { if r.Reply.Rcode == dns.RcodeSuccess { if r.Request.Id != r.Reply.Id {