diff --git a/TODO.markdown b/TODO.markdown index 7eaf036a..e961b8f3 100644 --- a/TODO.markdown +++ b/TODO.markdown @@ -3,8 +3,6 @@ * outgoing [AI]xfr * zonereader that extracts glue (or at least signals it) and other usefull stuff? * a complete dnssec resolver -* the outgoing channel for resolver isn't usefull - remove it. - ## Nice to have diff --git a/client.go b/client.go index 9c8172be..d8bb7ce5 100644 --- a/client.go +++ b/client.go @@ -169,14 +169,15 @@ func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) { h.QueryDNS(w, r) } +// A nil Client is usable. type Client struct { - Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one - Attempts int // number of attempts + Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP) + Attempts int // number of attempts, if not set defaults to 1 Retry bool // retry with TCP Request chan *Request // read DNS request from this channel Reply chan *Exchange // write replies to this channel - ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns) - WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns) + ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9 + WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9 TsigSecret map[string]string // secret(s) for Tsig map[] Hijacked net.Conn // if set the calling code takes care of the connection // LocalAddr string // Local address to use @@ -245,8 +246,16 @@ func ListenAndQueryRequest(request chan *Request, handler QueryHandler) { // reply channel of the client. func (w *reply) Write(m *Msg) error { if w.conn == nil { + if w.Client().Reply == nil { + QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt} + return nil + } w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt} } else { + if w.Client().Reply == nil { + QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()} + return nil + } w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()} } return nil @@ -330,7 +339,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, addr switch c.Net { case "tcp", "tcp4", "tcp6": in = make([]byte, MaxMsgSize) - case "udp", "udp4", "udp6": + case "", "udp", "udp4", "udp6": size := UDPMsgSize for _, r := range m.Extra { if r.Header().Rrtype == TypeOPT { @@ -365,7 +374,7 @@ func (w *reply) Receive() (*Msg, error) { switch w.Client().Net { case "tcp", "tcp4", "tcp6": p = make([]byte, MaxMsgSize) - case "udp", "udp4", "udp6": + case "", "udp", "udp4", "udp6": p = make([]byte, DefaultMsgSize) } n, err := w.readClient(p) @@ -393,15 +402,17 @@ func (w *reply) readClient(p []byte) (n int, err error) { if w.conn == nil { return 0, ErrConnEmpty } + if len(p) < 1 { + return 0, io.ErrShortBuffer + } + attempts := w.Client().Attempts + if attempts == 0 { + attempts = 1 + } switch w.Client().Net { case "tcp", "tcp4", "tcp6": - if len(p) < 1 { - return 0, io.ErrShortBuffer - } - for a := 0; a < w.Client().Attempts; a++ { - w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) - w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) - + setTimeouts(w) + for a := 0; a < attempts; a++ { n, err = w.conn.(*net.TCPConn).Read(p[0:2]) if err != nil || n != 2 { if e, ok := err.(net.Error); ok && e.Timeout() { @@ -437,11 +448,9 @@ func (w *reply) readClient(p []byte) (n int, err error) { } n = i } - case "udp", "udp4", "udp6": - for a := 0; a < w.Client().Attempts; a++ { - w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) - w.conn.SetWriteDeadline(time.Now().Add(w.Client().ReadTimeout)) - + case "", "udp", "udp4", "udp6": + for a := 0; a < attempts; a++ { + setTimeouts(w) n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { @@ -485,11 +494,9 @@ func (w *reply) Send(m *Msg) (err error) { } func (w *reply) writeClient(p []byte) (n int, err error) { - if w.Client().Attempts == 0 { - panic("c.Attempts 0") - } - if w.Client().Net == "" { - panic("c.Net empty") + attempts := w.Client().Attempts + if attempts == 0 { + attempts = 1 } if w.Client().Hijacked == nil { if err = w.Dial(); err != nil { @@ -501,10 +508,8 @@ func (w *reply) writeClient(p []byte) (n int, err error) { if len(p) < 2 { return 0, io.ErrShortBuffer } - for a := 0; a < w.Client().Attempts; a++ { - w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) - w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) - + for a := 0; a < attempts; a++ { + setTimeouts(w) a, b := packUint16(uint16(len(p))) n, err = w.conn.Write([]byte{a, b}) if err != nil { @@ -537,11 +542,9 @@ func (w *reply) writeClient(p []byte) (n int, err error) { } n = i } - case "udp", "udp4", "udp6": - for a := 0; a < w.Client().Attempts; a++ { - w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) - w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) - + case "", "udp", "udp4", "udp6": + for a := 0; a < attempts; a++ { + setTimeouts(w) n, err = w.conn.(*net.UDPConn).Write(p) if err != nil { if e, ok := err.(net.Error); ok && e.Timeout() { @@ -554,6 +557,20 @@ func (w *reply) writeClient(p []byte) (n int, err error) { return } +func setTimeouts(w *reply) { + if w.Client().ReadTimeout == 0 { + w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) + } else { + w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout)) + } + + if w.Client().WriteTimeout == 0 { + w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9)) + } else { + w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout)) + } +} + // Close implents the RequestWriter.Close method func (w *reply) Close() (err error) { return w.conn.Close() }