Use a new type to send back request,answer

a type Exchange has been added, which makes communicating with
the resolver more strait forward.

This will also be used in the *xfr functions.
This commit is contained in:
Miek Gieben 2011-09-10 14:26:08 +02:00
parent b7ca96e7d4
commit 19bf874769
2 changed files with 58 additions and 53 deletions

View File

@ -21,7 +21,6 @@ type QueryHandler interface {
// The RequestWriter interface is used by a DNS query handler to
// construct a DNS request.
type RequestWriter interface {
WriteMessages([]*Msg)
Write(*Msg)
Send(*Msg) os.Error
Receive() (*Msg, os.Error)
@ -60,12 +59,16 @@ func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
// DefaultQueryMux is the default QueryMux used by Query.
var DefaultQueryMux = NewQueryMux()
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
func newQueryChan() chan *Request { return make(chan *Request) }
func newQueryChanSlice() chan *Exchange { return make(chan *Exchange) }
func newQueryChan() chan *Request { return make(chan *Request) }
// Default channels to use for the resolver
var DefaultReplyChan = newQueryChanSlice()
var DefaultQueryChan = newQueryChan()
var (
DefaultReplyChan = newQueryChanSlice() // DefaultReplyChan is the channel on which the replies are
// coming back. Is it a channel of *Exchange, so that the original
// question is included with the answer.
DefaultQueryChan = newQueryChan() // DefaultQueryChan is the channel were you can send the questions to.
)
// The HandlerQueryFunc type is an adapter to allow the use of
// ordinary functions as DNS query handlers. If f is a function
@ -122,11 +125,11 @@ type Client struct {
Attempts int // number of attempts
Retry bool // retry with TCP
ChannelQuery chan *Request // read DNS request from this channel
ChannelReply chan []*Msg // read DNS request from this channel
ChannelReply chan *Exchange // write the reply (together with the DNS request) to this channel
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
Hijacked net.Conn // if set the calling code takes care of the connection
Hijacked net.Conn // if set the calling code takes care of the connection
// LocalAddr string // Local address to use
}
@ -136,8 +139,8 @@ func NewClient() *Client {
c.Net = "udp"
c.Attempts = 1
c.ChannelReply = DefaultReplyChan
c.ReadTimeout = 5000
c.WriteTimeout = 5000
c.ReadTimeout = 5000
c.WriteTimeout = 5000
return c
}
@ -172,8 +175,7 @@ func (q *Query) ListenAndQuery() os.Error {
return q.Query()
}
// Start listener for firing off the queries. If
// ListenAndQuery starts the listener for firing off the queries. If
// c is nil DefaultQueryChan is used. If handler is nil
// DefaultQueryMux is used.
func ListenAndQuery(c chan *Request, handler QueryHandler) {
@ -181,28 +183,31 @@ func ListenAndQuery(c chan *Request, handler QueryHandler) {
go q.ListenAndQuery()
}
// Write returns the original question and the answer on the reply channel of the
// client.
func (w *reply) Write(m *Msg) {
w.Client().ChannelReply <- []*Msg{w.req, m}
w.Client().ChannelReply <- &Exchange{Request: w.req, Reply: m}
}
// Dial dials a remote server and set... TODO
func (c *Client) Dial(addr string) os.Error {
conn, err := net.Dial(c.Net, addr)
if err != nil {
return err
}
c.Hijacked = conn
return nil
conn, err := net.Dial(c.Net, addr)
if err != nil {
return err
}
c.Hijacked = conn
return nil
}
func (c *Client) Close() os.Error {
if c.Hijacked == nil {
return nil // TODO
}
return c.Hijacked.Close()
if c.Hijacked == nil {
return nil // TODO
}
return c.Hijacked.Close()
}
// Do performs an asynchronous query. The result is returned on the
// channel set in the c. If no channel is set DefaultQueryChan is used.
// channel set in the Client c. If no channel is set DefaultQueryChan is used.
func (c *Client) Do(m *Msg, a string) {
if c.ChannelQuery == nil {
DefaultQueryChan <- &Request{Client: c, Addr: a, Request: m}
@ -211,21 +216,21 @@ func (c *Client) Do(m *Msg, a string) {
}
}
// ExchangeBuf performs a synchronous query. It sends the buffer m to the
// ExchangeBuffer performs a synchronous query. It sends the buffer m to the
// address (net.Addr?) contained in a
func (c *Client) ExchangeBuffer(inbuf []byte, a string, outbuf []byte) (n int, err os.Error) {
w := new(reply)
w.client = c
w.addr = a
if c.Hijacked == nil {
if err = w.Dial(); err != nil {
return 0, err
}
if c.Hijacked == nil {
if err = w.Dial(); err != nil {
return 0, err
}
defer w.Close()
}
if c.Hijacked != nil {
w.conn = c.Hijacked
}
if c.Hijacked != nil {
w.conn = c.Hijacked
}
if n, err = w.writeClient(inbuf); err != nil {
return 0, err
}
@ -243,13 +248,13 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err os.Error) {
if !ok {
panic("failed to pack message")
}
var in []byte
switch c.Net {
case "tcp":
in = make([]byte, MaxMsgSize)
case "udp":
in = make([]byte, DefaultMsgSize)
}
var in []byte
switch c.Net {
case "tcp":
in = make([]byte, MaxMsgSize)
case "udp":
in = make([]byte, DefaultMsgSize)
}
if n, err = c.ExchangeBuffer(out, a, in); err != nil {
return nil, err
}
@ -275,12 +280,6 @@ func (w *reply) Close() (err os.Error) {
return w.conn.Close()
}
func (w *reply) WriteMessages(m []*Msg) {
m1 := append([]*Msg{w.req}, m...)
w.Client().ChannelReply <- m1
}
func (w *reply) Client() *Client {
return w.client
}

24
dns.go
View File

@ -144,6 +144,12 @@ func (s RRset) Ok() bool {
return true
}
// Exchange is used in communicating with the resolver.
type Exchange struct {
Request *Msg // The question sent.
Reply *Msg // The answer to the question that was sent.
}
// DNS resource records.
// There are many types of messages,
// but they all share the same header.
@ -215,13 +221,13 @@ func zoneMatch(pattern, zone string) (ok bool) {
}
// DnameLength returns the length of a packed dname.
func DomainNameLength(s string) int { // TODO better name
// Add trailing dot to canonicalize name.
if n := len(s); n == 0 || s[n-1] != '.' {
return n+1
} else {
return n+1
}
panic("not reached")
return 0
func DomainNameLength(s string) int { // TODO better name
// Add trailing dot to canonicalize name.
if n := len(s); n == 0 || s[n-1] != '.' {
return n + 1
} else {
return n + 1
}
panic("not reached")
return 0
}