remove the need for NewClient()

This commit is contained in:
Miek Gieben 2012-05-26 10:24:47 +02:00
parent eb7b7c9745
commit 458a74b8ce
2 changed files with 49 additions and 34 deletions

View File

@ -3,8 +3,6 @@
* outgoing [AI]xfr
* zonereader that extracts glue (or at least signals it) and other usefull stuff?
* a complete dnssec resolver
* the outgoing channel for resolver isn't usefull - remove it.
## Nice to have

View File

@ -169,14 +169,15 @@ func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
h.QueryDNS(w, r)
}
// A nil Client is usable.
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
Attempts int // number of attempts
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP)
Attempts int // number of attempts, if not set defaults to 1
Retry bool // retry with TCP
Request chan *Request // read DNS request from this channel
Reply chan *Exchange // write replies to this channel
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns)
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns)
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
Hijacked net.Conn // if set the calling code takes care of the connection
// LocalAddr string // Local address to use
@ -245,8 +246,16 @@ func ListenAndQueryRequest(request chan *Request, handler QueryHandler) {
// reply channel of the client.
func (w *reply) Write(m *Msg) error {
if w.conn == nil {
if w.Client().Reply == nil {
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
return nil
}
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
} else {
if w.Client().Reply == nil {
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
return nil
}
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
}
return nil
@ -330,7 +339,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, addr
switch c.Net {
case "tcp", "tcp4", "tcp6":
in = make([]byte, MaxMsgSize)
case "udp", "udp4", "udp6":
case "", "udp", "udp4", "udp6":
size := UDPMsgSize
for _, r := range m.Extra {
if r.Header().Rrtype == TypeOPT {
@ -365,7 +374,7 @@ func (w *reply) Receive() (*Msg, error) {
switch w.Client().Net {
case "tcp", "tcp4", "tcp6":
p = make([]byte, MaxMsgSize)
case "udp", "udp4", "udp6":
case "", "udp", "udp4", "udp6":
p = make([]byte, DefaultMsgSize)
}
n, err := w.readClient(p)
@ -393,15 +402,17 @@ func (w *reply) readClient(p []byte) (n int, err error) {
if w.conn == nil {
return 0, ErrConnEmpty
}
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
attempts := w.Client().Attempts
if attempts == 0 {
attempts = 1
}
switch w.Client().Net {
case "tcp", "tcp4", "tcp6":
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
for a := 0; a < w.Client().Attempts; a++ {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
setTimeouts(w)
for a := 0; a < attempts; a++ {
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 {
if e, ok := err.(net.Error); ok && e.Timeout() {
@ -437,11 +448,9 @@ func (w *reply) readClient(p []byte) (n int, err error) {
}
n = i
}
case "udp", "udp4", "udp6":
for a := 0; a < w.Client().Attempts; a++ {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
w.conn.SetWriteDeadline(time.Now().Add(w.Client().ReadTimeout))
case "", "udp", "udp4", "udp6":
for a := 0; a < attempts; a++ {
setTimeouts(w)
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
@ -485,11 +494,9 @@ func (w *reply) Send(m *Msg) (err error) {
}
func (w *reply) writeClient(p []byte) (n int, err error) {
if w.Client().Attempts == 0 {
panic("c.Attempts 0")
}
if w.Client().Net == "" {
panic("c.Net empty")
attempts := w.Client().Attempts
if attempts == 0 {
attempts = 1
}
if w.Client().Hijacked == nil {
if err = w.Dial(); err != nil {
@ -501,10 +508,8 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
for a := 0; a < w.Client().Attempts; a++ {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
for a := 0; a < attempts; a++ {
setTimeouts(w)
a, b := packUint16(uint16(len(p)))
n, err = w.conn.Write([]byte{a, b})
if err != nil {
@ -537,11 +542,9 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
}
n = i
}
case "udp", "udp4", "udp6":
for a := 0; a < w.Client().Attempts; a++ {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
case "", "udp", "udp4", "udp6":
for a := 0; a < attempts; a++ {
setTimeouts(w)
n, err = w.conn.(*net.UDPConn).Write(p)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
@ -554,6 +557,20 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
return
}
func setTimeouts(w *reply) {
if w.Client().ReadTimeout == 0 {
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
}
if w.Client().WriteTimeout == 0 {
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
}
}
// Close implents the RequestWriter.Close method
func (w *reply) Close() (err error) { return w.conn.Close() }