diff --git a/client.go b/client.go index 8d0ef7b8..3bc58f28 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ package dns import ( "bytes" + "crypto/tls" "io" "net" "time" @@ -26,6 +27,7 @@ type Conn struct { type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) UDPSize uint16 // minimum receive buffer for UDP messages + TLS bool // enables TLS connection (port 853) DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds @@ -152,11 +154,18 @@ func (c *Client) writeTimeout() time.Duration { func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { var co *Conn - if c.Net == "" { - co, err = DialTimeout("udp", a, c.dialTimeout()) - } else { - co, err = DialTimeout(c.Net, a, c.dialTimeout()) + + network := "udp" + if c.Net != "" { + network = c.Net } + + if c.TLS { + co, err = DialTimeoutWithTLS(network, a, c.dialTimeout()) + } else { + co, err = DialTimeout(network, a, c.dialTimeout()) + } + if err != nil { return nil, 0, err } @@ -383,3 +392,26 @@ func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, er } return conn, nil } + +// DialWithTLS connects to the address on the named network with TLS. +func DialWithTLS(network, address string) (conn *Conn, err error) { + conn = new(Conn) + conn.Conn, err = tls.Dial(network, address, nil) + if err != nil { + return nil, err + } + return conn, nil +} + +// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. +func DialTimeoutWithTLS(network, address string, timeout time.Duration) (conn *Conn, err error) { + var dialer net.Dialer + dialer.Timeout = timeout + + conn = new(Conn) + conn.Conn, err = tls.DialWithDialer(&dialer, network, address, nil) + if err != nil { + return nil, err + } + return conn, nil +}