Make the default async queries more simpler to use
If you don't want to setup your own channels things should now be simpler. Still the power for eleborate setups is there.
This commit is contained in:
parent
7b2745c51d
commit
06d424549e
55
client.go
55
client.go
|
@ -70,12 +70,12 @@ func newQueryChan() chan *Request { return make(chan *Request) }
|
|||
|
||||
// Default channels to use for the resolver
|
||||
var (
|
||||
// DefaultReplyChan is the channel on which the replies are
|
||||
// Incoming is the channel on which the replies are
|
||||
// coming back. Is it a channel of *Exchange, so that the original
|
||||
// question is included with the answer.
|
||||
Incoming = newQueryChanSlice()
|
||||
// defaultQueryChan is the channel were you can send the questions to.
|
||||
defaultQueryChan = newQueryChan()
|
||||
IncomingQuery = newQueryChanSlice()
|
||||
// RequestQuery is the channel were you can send the questions to.
|
||||
RequestQuery = newQueryChan()
|
||||
)
|
||||
|
||||
// The HandlerQueryFunc type is an adapter to allow the use of
|
||||
|
@ -171,7 +171,8 @@ type Client struct {
|
|||
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
|
||||
Attempts int // number of attempts
|
||||
Retry bool // retry with TCP
|
||||
QueryChan chan *Request // read DNS request from this channel
|
||||
Request chan *Request // read DNS request from this channel
|
||||
Incoming chan *Exchange // write replies to this channel
|
||||
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns)
|
||||
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns)
|
||||
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
|
||||
|
@ -186,14 +187,23 @@ func NewClient() *Client {
|
|||
c := new(Client)
|
||||
c.Net = "udp"
|
||||
c.Attempts = 1
|
||||
c.QueryChan = defaultQueryChan
|
||||
c.Request = RequestQuery
|
||||
c.Incoming = IncomingQuery
|
||||
c.ReadTimeout = 2 * 1e9
|
||||
c.WriteTimeout = 2 * 1e9
|
||||
return c
|
||||
}
|
||||
|
||||
// NewClientRequest allows the setting of a request channel which
|
||||
// will be used instead of the default.
|
||||
func NewClientRequest(c chan *Request) *Client {
|
||||
c1 := NewClient()
|
||||
c1.Request = c
|
||||
return c1
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
QueryChan chan *Request // read DNS request from this channel
|
||||
Request chan *Request // read DNS request from this channel
|
||||
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
|
||||
}
|
||||
|
||||
|
@ -204,7 +214,7 @@ func (q *Query) Query() error {
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case in := <-q.QueryChan:
|
||||
case in := <-q.Request:
|
||||
w := new(reply)
|
||||
w.req = in.Request
|
||||
w.addr = in.Addr
|
||||
|
@ -216,27 +226,36 @@ func (q *Query) Query() error {
|
|||
}
|
||||
|
||||
func (q *Query) ListenAndQuery() error {
|
||||
if q.QueryChan == nil {
|
||||
q.QueryChan = defaultQueryChan
|
||||
if q.Request == nil {
|
||||
q.Request = RequestQuery
|
||||
}
|
||||
return q.Query()
|
||||
}
|
||||
|
||||
// ListenAndQuery starts the listener for firing off the queries. If
|
||||
// c is nil defaultQueryChan is used. If handler is nil
|
||||
// DefaultQueryMux is used.
|
||||
func ListenAndQuery(request chan *Request, handler QueryHandler) {
|
||||
q := &Query{QueryChan: request, Handler: handler}
|
||||
// ListenAndQuery starts the listener for firing off the queries.
|
||||
// If handler is nil DefaultQueryMux is used.
|
||||
func ListenAndQuery(handler QueryHandler) {
|
||||
q := &Query{Request: nil, Handler: handler}
|
||||
go q.ListenAndQuery()
|
||||
}
|
||||
|
||||
// ListenAndQueryRequest starts the listener for firing off the queries. If
|
||||
// c is nil defaultQueryChan is used. If handler is nil
|
||||
// DefaultQueryMux is used.
|
||||
func ListenAndQueryRequest(request chan *Request, handler QueryHandler) {
|
||||
q := &Query{Request: request, Handler: handler}
|
||||
go q.ListenAndQuery()
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Write returns the original question and the answer on the
|
||||
// reply channel of the client.
|
||||
func (w *reply) Write(m *Msg) error {
|
||||
if w.conn == nil {
|
||||
Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
|
||||
w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
|
||||
} else {
|
||||
Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
|
||||
w.Client().Incoming <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -259,7 +278,7 @@ func (w *reply) RemoteAddr() net.Addr {
|
|||
//
|
||||
// r is of type Exchange.
|
||||
func (c *Client) Do(m *Msg, a string) {
|
||||
c.QueryChan <- &Request{Client: c, Addr: a, Request: m}
|
||||
c.Request <- &Request{Client: c, Addr: a, Request: m}
|
||||
}
|
||||
|
||||
// exchangeBuffer performs a synchronous query. It sends the buffer m to the
|
||||
|
|
|
@ -27,7 +27,7 @@ func helloMiek(w RequestWriter, r *Msg) {
|
|||
|
||||
func TestClientASync(t *testing.T) {
|
||||
HandleQueryFunc("miek.nl.", helloMiek) // All queries for miek.nl will be handled by HelloMiek
|
||||
ListenAndQuery(nil, nil) // Detect if this isn't running
|
||||
ListenAndQuery(nil) // Detect if this isn't running
|
||||
|
||||
m := new(Msg)
|
||||
m.SetQuestion("miek.nl.", TypeSOA)
|
||||
|
@ -38,7 +38,7 @@ func TestClientASync(t *testing.T) {
|
|||
forever:
|
||||
for {
|
||||
select {
|
||||
case n := <-c.ReplyChan:
|
||||
case n := <-c.Incoming:
|
||||
if n.Reply != nil && n.Reply.Rcode != RcodeSuccess {
|
||||
t.Log("Failed to get an valid answer")
|
||||
t.Fail()
|
||||
|
|
|
@ -52,7 +52,7 @@ func qhandler(w dns.RequestWriter, m *dns.Msg) {
|
|||
|
||||
func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string {
|
||||
dns.HandleQueryFunc(os.Args[1], qhandler)
|
||||
dns.ListenAndQuery(nil, nil)
|
||||
dns.ListenAndQuery(nil)
|
||||
|
||||
m4 := new(dns.Msg)
|
||||
m4.SetQuestion(dns.Fqdn(os.Args[1]), dns.TypeA)
|
||||
|
@ -66,7 +66,7 @@ func addresses(conf *dns.ClientConfig, c *dns.Client, name string) []string {
|
|||
forever:
|
||||
for {
|
||||
select {
|
||||
case r := <-dns.Incoming:
|
||||
case r := <-c.Incoming:
|
||||
if r.Reply != nil && r.Reply.Rcode == dns.RcodeSuccess {
|
||||
for _, aa := range r.Reply.Answer {
|
||||
switch aa.(type) {
|
||||
|
|
|
@ -136,7 +136,7 @@ Flags:
|
|||
|
||||
// We use the async query handling, just to show how it is to be used.
|
||||
dns.HandleQuery(".", q)
|
||||
dns.ListenAndQuery(nil, nil)
|
||||
dns.ListenAndQuery(nil)
|
||||
c := dns.NewClient()
|
||||
if *tcp {
|
||||
c.Net = "tcp"
|
||||
|
@ -199,7 +199,7 @@ Flags:
|
|||
forever:
|
||||
for {
|
||||
select {
|
||||
case r := <-dns.Incoming:
|
||||
case r := <-c.Incoming:
|
||||
if r.Reply != nil {
|
||||
if r.Reply.Rcode == dns.RcodeSuccess {
|
||||
if r.Request.Id != r.Reply.Id {
|
||||
|
|
Loading…
Reference in New Issue