diff --git a/client.go b/client.go index c8d720dc..d6739257 100644 --- a/client.go +++ b/client.go @@ -12,9 +12,13 @@ import ( "time" ) +const dnsTimeout time.Duration = 2 * 1e9 + // A Conn represents a connection (which may be short lived) to a DNS server. type Conn struct { net.Conn + UDPSize uint16 // Minimum reveive buffer for UDP messages, if > 512 EDNS0, 0 MinMsgSize. + TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified rtt time.Duration t time.Time requestMAC string @@ -23,8 +27,9 @@ type Conn struct { // A Client defines parameters for a DNS client. A nil Client is usable for sending queries. type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) - ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 - WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defaults to 2 * 1e9 + DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for new connections (ns), defaults to 2 * 1e9 TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight @@ -34,15 +39,15 @@ type Client struct { // contained in a and waits for an reply. func Exchange(m *Msg, a string) (r *Msg, err error) { co := new(Conn) - co.Conn, err = net.DialTimeout("udp", a, 5*1e9) + co.Conn, err = net.DialTimeout("udp", a, dnsTimeout) if err != nil { return nil, err } defer co.Close() - if err = co.WriteMsg(m, nil); err != nil { + if err = co.WriteMsg(m); err != nil { return nil, err } - r, err = co.ReadMsg(nil) + r, err = co.ReadMsg() return r, err } @@ -80,31 +85,43 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { co := new(Conn) + timeout := dnsTimeout + if c.DialTimeout != 0 { + timeout = c.DialTimeout + } if c.Net == "" { - co.Conn, err = net.DialTimeout("udp", a, 5*1e9) + co.Conn, err = net.DialTimeout("udp", a, timeout) } else { - co.Conn, err = net.DialTimeout(c.Net, a, 5*1e9) + co.Conn, err = net.DialTimeout(c.Net, a, timeout) } if err != nil { return nil, 0, err } defer co.Close() - if err = co.WriteMsg(m, c.TsigSecret); err != nil { + opt := m.IsEdns0() + if opt != nil && opt.UDPSize() >= MinMsgSize { + co.UDPSize = opt.UDPSize() + } + co.TsigSecret = c.TsigSecret + if err = co.WriteMsg(m); err != nil { return nil, 0, err } - r, err = co.ReadMsg(c.TsigSecret) + r, err = co.ReadMsg() return r, co.rtt, err } -// Add bufsize -func (co *Conn) ReadMsg(tsigSecret map[string]string) (*Msg, error) { +// ReadMsg reads a message from the connection co. +func (co *Conn) ReadMsg() (*Msg, error) { var p []byte m := new(Msg) if _, ok := co.Conn.(*net.TCPConn); ok { p = make([]byte, MaxMsgSize) } else { - // OPT! TODO(miek): needs function change - p = make([]byte, DefaultMsgSize) + if co.UDPSize >= 512 { + p = make([]byte, co.UDPSize) + } else { + p = make([]byte, MinMsgSize) + } } n, err := co.Read(p) if err != nil && n == 0 { @@ -116,15 +133,16 @@ func (co *Conn) ReadMsg(tsigSecret map[string]string) (*Msg, error) { } co.rtt = time.Since(co.t) if t := m.IsTsig(); t != nil { - if _, ok := tsigSecret[t.Hdr.Name]; !ok { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, tsigSecret[t.Hdr.Name], co.requestMAC, false) + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.requestMAC, false) } return m, err } +// Read implements the net.Conn read method. func (co *Conn) Read(p []byte) (n int, err error) { if co.Conn == nil { return 0, ErrConnEmpty @@ -167,18 +185,18 @@ func (co *Conn) Read(p []byte) (n int, err error) { return n, err } -// send sends a dns msg to the address specified in w. +// WriteMsg send a dns message throught the connection co. // If the message m contains a TSIG record the transaction // signature is calculated. -func (co *Conn) WriteMsg(m *Msg, tsigSecret map[string]string) (err error) { +func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - if _, ok := tsigSecret[t.Hdr.Name]; !ok { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return ErrSecret } - out, mac, err = TsigGenerate(m, tsigSecret[t.Hdr.Name], co.requestMAC, false) - // Set for the next read + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.requestMAC, false) + // Set for the next read, allthough only used in zone transfers co.requestMAC = mac } else { out, err = m.Pack() @@ -193,6 +211,7 @@ func (co *Conn) WriteMsg(m *Msg, tsigSecret map[string]string) (err error) { return nil } +// Write implements the net.Conn Write method. func (co *Conn) Write(p []byte) (n int, err error) { if t, ok := co.Conn.(*net.TCPConn); ok { if len(p) < 2 { @@ -220,18 +239,20 @@ func (co *Conn) Write(p []byte) (n int, err error) { return n, err } -/* -func setTimeouts(w *reply) { - if w.client.ReadTimeout == 0 { - w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) - } else { - w.conn.SetReadDeadline(time.Now().Add(w.client.ReadTimeout)) - } +// Close implements the net.Conn Close method. +func (co *Conn) Close() error { return co.Conn.Close() } - if w.client.WriteTimeout == 0 { - w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9)) - } else { - w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout)) - } -} -*/ +// LocalAddr implements the net.Conn LocalAddr method. +func (co *Conn) LocalAddr() net.Addr { return co.Conn.LocalAddr() } + +// RemoteAddr implements the net.Conn RemoteAddr method. +func (co *Conn) RemoteAddr() net.Addr { return co.Conn.RemoteAddr() } + +// SetDeadline implements the net.Conn SetDeadline method. +func (co *Conn) SetDeadline(t time.Time) error { return co.Conn.SetDeadline(t) } + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (co *Conn) SetReadDeadline(t time.Time) error { return co.Conn.SetReadDeadline(t) } + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (co *Conn) SetWriteDeadline(t time.Time) error { return co.Conn.SetWriteDeadline(t) } diff --git a/dns.go b/dns.go index 4e606d79..13776d92 100644 --- a/dns.go +++ b/dns.go @@ -91,7 +91,7 @@ import ( const ( year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits. DefaultMsgSize = 4096 // Standard default for larger than 512 packets. - udpMsgSize = 512 // Default buffer size for servers receiving UDP packets. + MinMsgSize = 512 // Minimal size of a DNS packet. MaxMsgSize = 65536 // Largest possible DNS packet. defaultTtl = 3600 // Default TTL. ) diff --git a/server.go b/server.go index a6eeff43..33a5fe90 100644 --- a/server.go +++ b/server.go @@ -355,7 +355,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { handler = DefaultServeMux } if srv.UDPSize == 0 { - srv.UDPSize = udpMsgSize + srv.UDPSize = MinMsgSize } for { if srv.ReadTimeout != 0 {