remove the need for NewClient()

This commit is contained in:
Miek Gieben 2012-05-26 10:24:47 +02:00
parent eb7b7c9745
commit 458a74b8ce
2 changed files with 49 additions and 34 deletions

View File

@ -3,8 +3,6 @@
* outgoing [AI]xfr * outgoing [AI]xfr
* zonereader that extracts glue (or at least signals it) and other usefull stuff? * zonereader that extracts glue (or at least signals it) and other usefull stuff?
* a complete dnssec resolver * a complete dnssec resolver
* the outgoing channel for resolver isn't usefull - remove it.
## Nice to have ## Nice to have

View File

@ -169,14 +169,15 @@ func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
h.QueryDNS(w, r) h.QueryDNS(w, r)
} }
// A nil Client is usable.
type Client struct { type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP)
Attempts int // number of attempts Attempts int // number of attempts, if not set defaults to 1
Retry bool // retry with TCP Retry bool // retry with TCP
Request chan *Request // read DNS request from this channel Request chan *Request // read DNS request from this channel
Reply chan *Exchange // write replies to this channel Reply chan *Exchange // write replies to this channel
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns) ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns) WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret> TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
Hijacked net.Conn // if set the calling code takes care of the connection Hijacked net.Conn // if set the calling code takes care of the connection
// LocalAddr string // Local address to use // LocalAddr string // Local address to use
@ -245,8 +246,16 @@ func ListenAndQueryRequest(request chan *Request, handler QueryHandler) {
// reply channel of the client. // reply channel of the client.
func (w *reply) Write(m *Msg) error { func (w *reply) Write(m *Msg) error {
if w.conn == nil { if w.conn == nil {
if w.Client().Reply == nil {
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
return nil
}
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt} w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
} else { } else {
if w.Client().Reply == nil {
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
return nil
}
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()} w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
} }
return nil return nil
@ -330,7 +339,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, addr
switch c.Net { switch c.Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
in = make([]byte, MaxMsgSize) in = make([]byte, MaxMsgSize)
case "udp", "udp4", "udp6": case "", "udp", "udp4", "udp6":
size := UDPMsgSize size := UDPMsgSize
for _, r := range m.Extra { for _, r := range m.Extra {
if r.Header().Rrtype == TypeOPT { if r.Header().Rrtype == TypeOPT {
@ -365,7 +374,7 @@ func (w *reply) Receive() (*Msg, error) {
switch w.Client().Net { switch w.Client().Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
p = make([]byte, MaxMsgSize) p = make([]byte, MaxMsgSize)
case "udp", "udp4", "udp6": case "", "udp", "udp4", "udp6":
p = make([]byte, DefaultMsgSize) p = make([]byte, DefaultMsgSize)
} }
n, err := w.readClient(p) n, err := w.readClient(p)
@ -393,15 +402,17 @@ func (w *reply) readClient(p []byte) (n int, err error) {
if w.conn == nil { if w.conn == nil {
return 0, ErrConnEmpty return 0, ErrConnEmpty
} }
if len(p) < 1 {
return 0, io.ErrShortBuffer
}
attempts := w.Client().Attempts
if attempts == 0 {
attempts = 1
}
switch w.Client().Net { switch w.Client().Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if len(p) < 1 { setTimeouts(w)
return 0, io.ErrShortBuffer for a := 0; a < attempts; a++ {
}
for a := 0; a < w.Client().Attempts; a++ {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
n, err = w.conn.(*net.TCPConn).Read(p[0:2]) n, err = w.conn.(*net.TCPConn).Read(p[0:2])
if err != nil || n != 2 { if err != nil || n != 2 {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
@ -437,11 +448,9 @@ func (w *reply) readClient(p []byte) (n int, err error) {
} }
n = i n = i
} }
case "udp", "udp4", "udp6": case "", "udp", "udp4", "udp6":
for a := 0; a < w.Client().Attempts; a++ { for a := 0; a < attempts; a++ {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) setTimeouts(w)
w.conn.SetWriteDeadline(time.Now().Add(w.Client().ReadTimeout))
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
@ -485,11 +494,9 @@ func (w *reply) Send(m *Msg) (err error) {
} }
func (w *reply) writeClient(p []byte) (n int, err error) { func (w *reply) writeClient(p []byte) (n int, err error) {
if w.Client().Attempts == 0 { attempts := w.Client().Attempts
panic("c.Attempts 0") if attempts == 0 {
} attempts = 1
if w.Client().Net == "" {
panic("c.Net empty")
} }
if w.Client().Hijacked == nil { if w.Client().Hijacked == nil {
if err = w.Dial(); err != nil { if err = w.Dial(); err != nil {
@ -501,10 +508,8 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
if len(p) < 2 { if len(p) < 2 {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
for a := 0; a < w.Client().Attempts; a++ { for a := 0; a < attempts; a++ {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) setTimeouts(w)
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
a, b := packUint16(uint16(len(p))) a, b := packUint16(uint16(len(p)))
n, err = w.conn.Write([]byte{a, b}) n, err = w.conn.Write([]byte{a, b})
if err != nil { if err != nil {
@ -537,11 +542,9 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
} }
n = i n = i
} }
case "udp", "udp4", "udp6": case "", "udp", "udp4", "udp6":
for a := 0; a < w.Client().Attempts; a++ { for a := 0; a < attempts; a++ {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) setTimeouts(w)
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
n, err = w.conn.(*net.UDPConn).Write(p) n, err = w.conn.(*net.UDPConn).Write(p)
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() { if e, ok := err.(net.Error); ok && e.Timeout() {
@ -554,6 +557,20 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
return return
} }
func setTimeouts(w *reply) {
if w.Client().ReadTimeout == 0 {
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
}
if w.Client().WriteTimeout == 0 {
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
}
}
// Close implents the RequestWriter.Close method // Close implents the RequestWriter.Close method
func (w *reply) Close() (err error) { return w.conn.Close() } func (w *reply) Close() (err error) { return w.conn.Close() }