Make the default async queries more simpler to use

If you don't want to setup your own channels things should
now be simpler. Still the power for eleborate setups is there.
This commit is contained in:
Miek Gieben 2012-05-21 20:58:41 +02:00
parent 7b2745c51d
commit 06d424549e
4 changed files with 43 additions and 24 deletions

View File

@ -70,12 +70,12 @@ func newQueryChan() chan *Request { return make(chan *Request) }
// Default channels to use for the resolver
var (
// DefaultReplyChan is the channel on which the replies are
// Incoming is the channel on which the replies are
// coming back. Is it a channel of *Exchange, so that the original
// question is included with the answer.
Incoming = newQueryChanSlice()
// defaultQueryChan is the channel were you can send the questions to.
defaultQueryChan = newQueryChan()
IncomingQuery = newQueryChanSlice()
// RequestQuery is the channel were you can send the questions to.
RequestQuery = newQueryChan()
)
// The HandlerQueryFunc type is an adapter to allow the use of
@ -171,7 +171,8 @@ type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
Attempts int // number of attempts
Retry bool // retry with TCP
QueryChan chan *Request // read DNS request from this channel
Request chan *Request // read DNS request from this channel
Incoming chan *Exchange // write replies to this channel
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns)
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns)
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
@ -186,14 +187,23 @@ func NewClient() *Client {
c := new(Client)
c.Net = "udp"
c.Attempts = 1
c.QueryChan = defaultQueryChan
c.Request = RequestQuery
c.Incoming = IncomingQuery
c.ReadTimeout = 2 * 1e9
c.WriteTimeout = 2 * 1e9
return c
}
// NewClientRequest allows the setting of a request channel which
// will be used instead of the default.
func NewClientRequest(c chan *Request) *Client {
c1 := NewClient()
c1.Request = c
return c1
}
type Query struct {
QueryChan chan *Request // read DNS request from this channel
Request chan *Request // read DNS request from this channel
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
}
@ -204,7 +214,7 @@ func (q *Query) Query() error {
}
for {
select {
case in := <-q.QueryChan:
case in := <-q.Request:
w := new(reply)
w.req = in.Request
w.addr = in.Addr
@ -216,27 +226,36 @@ func (q *Query) Query() error {
}
func (q *Query) ListenAndQuery() error {
if q.QueryChan == nil {
q.QueryChan = defaultQueryChan
if q.Request == nil {
q.Request = RequestQuery
}
return q.Query()
}
// ListenAndQuery starts the listener for firing off the queries. If
// c is nil defaultQueryChan is used. If handler is nil
// DefaultQueryMux is used.
func ListenAndQuery(request chan *Request, handler QueryHandler) {
q := &Query{QueryChan: request, Handler: handler}
// ListenAndQuery starts the listener for firing off the queries.
// If handler is nil DefaultQueryMux is used.
func ListenAndQuery(handler QueryHandler) {
q := &Query{Request: nil, Handler: handler}
go q.ListenAndQuery()
}
// ListenAndQueryRequest starts the listener for firing off the queries. If
// c is nil defaultQueryChan is used. If handler is nil
// DefaultQueryMux is used.
func ListenAndQueryRequest(request chan *Request, handler QueryHandler) {
q := &Query{Request: request, Handler: handler}
go q.ListenAndQuery()
}
// Write returns the original question and the answer on the
// reply channel of the client.
func (w *reply) Write(m *Msg) error {
if w.conn == nil {
Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
} else {
Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
}
return nil
}
@ -259,7 +278,7 @@ func (w *reply) RemoteAddr() net.Addr {
//
// r is of type Exchange.
func (c *Client) Do(m *Msg, a string) {
c.QueryChan <- &Request{Client: c, Addr: a, Request: m}
c.Request <- &Request{Client: c, Addr: a, Request: m}
}
// exchangeBuffer performs a synchronous query. It sends the buffer m to the

View File

@ -27,7 +27,7 @@ func helloMiek(w RequestWriter, r *Msg) {
func TestClientASync(t *testing.T) {
HandleQueryFunc("miek.nl.", helloMiek) // All queries for miek.nl will be handled by HelloMiek
ListenAndQuery(nil, nil) // Detect if this isn't running
ListenAndQuery(nil) // Detect if this isn't running
m := new(Msg)
m.SetQuestion("miek.nl.", TypeSOA)
@ -38,7 +38,7 @@ func TestClientASync(t *testing.T) {
forever:
for {
select {
case n := <-c.ReplyChan:
case n := <-c.Incoming:
if n.Reply != nil && n.Reply.Rcode != RcodeSuccess {
t.Log("Failed to get an valid answer")
t.Fail()

View File

@ -52,7 +52,7 @@ func qhandler(w dns.RequestWriter, m *dns.Msg) {
func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string {
dns.HandleQueryFunc(os.Args[1], qhandler)
dns.ListenAndQuery(nil, nil)
dns.ListenAndQuery(nil)
m4 := new(dns.Msg)
m4.SetQuestion(dns.Fqdn(os.Args[1]), dns.TypeA)
@ -66,7 +66,7 @@ func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string {
forever:
for {
select {
case r := <-dns.Incoming:
case r := <-c.Incoming:
if r.Reply != nil && r.Reply.Rcode == dns.RcodeSuccess {
for _, aa := range r.Reply.Answer {
switch aa.(type) {

View File

@ -136,7 +136,7 @@ Flags:
// We use the async query handling, just to show how it is to be used.
dns.HandleQuery(".", q)
dns.ListenAndQuery(nil, nil)
dns.ListenAndQuery(nil)
c := dns.NewClient()
if *tcp {
c.Net = "tcp"
@ -199,7 +199,7 @@ Flags:
forever:
for {
select {
case r := <-dns.Incoming:
case r := <-c.Incoming:
if r.Reply != nil {
if r.Reply.Rcode == dns.RcodeSuccess {
if r.Request.Id != r.Reply.Id {