remove the need for NewClient()
This commit is contained in:
parent
eb7b7c9745
commit
458a74b8ce
|
@ -3,8 +3,6 @@
|
|||
* outgoing [AI]xfr
|
||||
* zonereader that extracts glue (or at least signals it) and other usefull stuff?
|
||||
* a complete dnssec resolver
|
||||
* the outgoing channel for resolver isn't usefull - remove it.
|
||||
|
||||
|
||||
## Nice to have
|
||||
|
||||
|
|
81
client.go
81
client.go
|
@ -169,14 +169,15 @@ func (mux *QueryMux) QueryDNS(w RequestWriter, r *Msg) {
|
|||
h.QueryDNS(w, r)
|
||||
}
|
||||
|
||||
// A nil Client is usable.
|
||||
type Client struct {
|
||||
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
|
||||
Attempts int // number of attempts
|
||||
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "", is UDP)
|
||||
Attempts int // number of attempts, if not set defaults to 1
|
||||
Retry bool // retry with TCP
|
||||
Request chan *Request // read DNS request from this channel
|
||||
Reply chan *Exchange // write replies to this channel
|
||||
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns)
|
||||
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns)
|
||||
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defauls to 2 * 1e9
|
||||
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defauls to 2 * 1e9
|
||||
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
|
||||
Hijacked net.Conn // if set the calling code takes care of the connection
|
||||
// LocalAddr string // Local address to use
|
||||
|
@ -245,8 +246,16 @@ func ListenAndQueryRequest(request chan *Request, handler QueryHandler) {
|
|||
// reply channel of the client.
|
||||
func (w *reply) Write(m *Msg) error {
|
||||
if w.conn == nil {
|
||||
if w.Client().Reply == nil {
|
||||
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
|
||||
return nil
|
||||
}
|
||||
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt}
|
||||
} else {
|
||||
if w.Client().Reply == nil {
|
||||
QueryReply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
|
||||
return nil
|
||||
}
|
||||
w.Client().Reply <- &Exchange{Request: w.req, Reply: m, Rtt: w.rtt, RemoteAddr: w.conn.RemoteAddr()}
|
||||
}
|
||||
return nil
|
||||
|
@ -330,7 +339,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, addr
|
|||
switch c.Net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
in = make([]byte, MaxMsgSize)
|
||||
case "udp", "udp4", "udp6":
|
||||
case "", "udp", "udp4", "udp6":
|
||||
size := UDPMsgSize
|
||||
for _, r := range m.Extra {
|
||||
if r.Header().Rrtype == TypeOPT {
|
||||
|
@ -365,7 +374,7 @@ func (w *reply) Receive() (*Msg, error) {
|
|||
switch w.Client().Net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
p = make([]byte, MaxMsgSize)
|
||||
case "udp", "udp4", "udp6":
|
||||
case "", "udp", "udp4", "udp6":
|
||||
p = make([]byte, DefaultMsgSize)
|
||||
}
|
||||
n, err := w.readClient(p)
|
||||
|
@ -393,15 +402,17 @@ func (w *reply) readClient(p []byte) (n int, err error) {
|
|||
if w.conn == nil {
|
||||
return 0, ErrConnEmpty
|
||||
}
|
||||
if len(p) < 1 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
attempts := w.Client().Attempts
|
||||
if attempts == 0 {
|
||||
attempts = 1
|
||||
}
|
||||
switch w.Client().Net {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
if len(p) < 1 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
for a := 0; a < w.Client().Attempts; a++ {
|
||||
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
|
||||
|
||||
setTimeouts(w)
|
||||
for a := 0; a < attempts; a++ {
|
||||
n, err = w.conn.(*net.TCPConn).Read(p[0:2])
|
||||
if err != nil || n != 2 {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
|
@ -437,11 +448,9 @@ func (w *reply) readClient(p []byte) (n int, err error) {
|
|||
}
|
||||
n = i
|
||||
}
|
||||
case "udp", "udp4", "udp6":
|
||||
for a := 0; a < w.Client().Attempts; a++ {
|
||||
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
w.conn.SetWriteDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
|
||||
case "", "udp", "udp4", "udp6":
|
||||
for a := 0; a < attempts; a++ {
|
||||
setTimeouts(w)
|
||||
n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
|
@ -485,11 +494,9 @@ func (w *reply) Send(m *Msg) (err error) {
|
|||
}
|
||||
|
||||
func (w *reply) writeClient(p []byte) (n int, err error) {
|
||||
if w.Client().Attempts == 0 {
|
||||
panic("c.Attempts 0")
|
||||
}
|
||||
if w.Client().Net == "" {
|
||||
panic("c.Net empty")
|
||||
attempts := w.Client().Attempts
|
||||
if attempts == 0 {
|
||||
attempts = 1
|
||||
}
|
||||
if w.Client().Hijacked == nil {
|
||||
if err = w.Dial(); err != nil {
|
||||
|
@ -501,10 +508,8 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
|
|||
if len(p) < 2 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
for a := 0; a < w.Client().Attempts; a++ {
|
||||
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
|
||||
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
|
||||
for a := 0; a < attempts; a++ {
|
||||
setTimeouts(w)
|
||||
a, b := packUint16(uint16(len(p)))
|
||||
n, err = w.conn.Write([]byte{a, b})
|
||||
if err != nil {
|
||||
|
@ -537,11 +542,9 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
|
|||
}
|
||||
n = i
|
||||
}
|
||||
case "udp", "udp4", "udp6":
|
||||
for a := 0; a < w.Client().Attempts; a++ {
|
||||
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
|
||||
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
|
||||
case "", "udp", "udp4", "udp6":
|
||||
for a := 0; a < attempts; a++ {
|
||||
setTimeouts(w)
|
||||
n, err = w.conn.(*net.UDPConn).Write(p)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
|
@ -554,6 +557,20 @@ func (w *reply) writeClient(p []byte) (n int, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func setTimeouts(w *reply) {
|
||||
if w.Client().ReadTimeout == 0 {
|
||||
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
|
||||
} else {
|
||||
w.conn.SetReadDeadline(time.Now().Add(w.Client().ReadTimeout))
|
||||
}
|
||||
|
||||
if w.Client().WriteTimeout == 0 {
|
||||
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
|
||||
} else {
|
||||
w.conn.SetWriteDeadline(time.Now().Add(w.Client().WriteTimeout))
|
||||
}
|
||||
}
|
||||
|
||||
// Close implents the RequestWriter.Close method
|
||||
func (w *reply) Close() (err error) { return w.conn.Close() }
|
||||
|
||||
|
|
Loading…
Reference in New Issue