Fix merge conflict from net branch

This commit is contained in:
Miek Gieben 2013-10-12 17:59:46 +01:00
commit 68083bc956
10 changed files with 377 additions and 320 deletions

254
client.go
View File

@ -12,39 +12,77 @@ import (
"time"
)
// Order of events:
// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read()
const dnsTimeout time.Duration = 2 * 1e9
// Do I want make this an interface thingy?
type reply struct {
client *Client
addr string
req *Msg
conn net.Conn
tsigRequestMAC string
tsigTimersOnly bool
tsigStatus error
// A Conn represents a connection to a DNS server.
type Conn struct {
net.Conn // a net.Conn holding the connection
UDPSize uint16 // Minimum receive buffer for UDP messages
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
rtt time.Duration
t time.Time
tsigRequestMAC string
}
// A Client defines parameter for a DNS client. A nil
// Client is usable for sending queries.
// A Client defines parameters for a DNS client. A nil Client is usable for sending queries.
type Client struct {
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defaults to 2 * 1e9
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
}
func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) {
// Exchange performs a synchronous UDP query. It sends the message m to the address
// contained in a and waits for an reply.
func Exchange(m *Msg, a string) (r *Msg, err error) {
co := new(Conn)
co.Conn, err = net.DialTimeout("udp", a, dnsTimeout)
if err != nil {
return nil, err
}
defer co.Close()
co.SetReadDeadline(time.Now().Add(dnsTimeout))
co.SetWriteDeadline(time.Now().Add(dnsTimeout))
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
}
// ExchangeConn performs a synchronous query. It sends the message m via the connection
// c and waits for a reply. The connection c is not closed by ExchangeConn.
// This function is going away, but can easily be mimicked:
//
// co := new(dns.Conn)
// co.Conn = c // c is your net.Conn
// co.WriteMsg(m)
// in, _ := co.ReadMsg()
//
func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
println("dns: this function is deprecated")
co := new(Conn)
co.Conn = c
if err = co.WriteMsg(m); err != nil {
return nil, err
}
r, err = co.ReadMsg()
return r, err
}
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
if s == nil {
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
return c.exchange(m, a)
}
// This adds a bunch of garbage, TODO(miek).
t := "nop"
@ -56,86 +94,71 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
if s == nil {
return c.exchange(m, a)
}
return c.exchangeConn(m, s)
return c.exchange(m, a)
})
if err != nil {
return r, rtt, err
}
if shared {
r1 := r.copy()
r1.Id = r.Id // Copy Id!
r = r1
}
return r, rtt, nil
}
// Exchange performs an synchronous query. It sends the message m to the address
// contained in a and waits for an reply. Basic use pattern with a *dns.Client:
//
// c := new(dns.Client)
// in, rtt, err := c.Exchange(message, "127.0.0.1:53")
//
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, a, nil)
}
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, addr: a}
if err = w.dial(); err != nil {
return nil, 0, err
co := new(Conn)
timeout := dnsTimeout
if c.DialTimeout != 0 {
timeout = c.DialTimeout
}
defer w.conn.Close()
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// ExchangeConn performs an synchronous query. It sends the message m trough the
// connection s and waits for a reply.
func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
return c.exchangeMerge(m, "", s)
}
func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) {
w := &reply{client: c, conn: s}
if err = w.send(m); err != nil {
return nil, 0, err
}
r, err = w.receive()
return r, w.rtt, err
}
// dial connects to the address addr for the network set in c.Net
func (w *reply) dial() (err error) {
var conn net.Conn
if w.client.Net == "" {
conn, err = net.DialTimeout("udp", w.addr, 5*1e9)
if c.Net == "" {
co.Conn, err = net.DialTimeout("udp", a, timeout)
} else {
conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9)
co.Conn, err = net.DialTimeout(c.Net, a, timeout)
}
if err != nil {
return err
return nil, 0, err
}
w.conn = conn
return
timeout = dnsTimeout
if c.ReadTimeout != 0 {
timeout = c.ReadTimeout
}
co.SetReadDeadline(time.Now().Add(timeout))
timeout = dnsTimeout
if c.WriteTimeout != 0 {
timeout = c.WriteTimeout
}
co.SetWriteDeadline(time.Now().Add(timeout))
defer co.Close()
opt := m.IsEdns0()
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
co.TsigSecret = c.TsigSecret
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
r, err = co.ReadMsg()
return r, co.rtt, err
}
func (w *reply) receive() (*Msg, error) {
// ReadMsg reads a message from the connection co.
// If the received message contains a TSIG record the transaction
// signature is verified.
func (co *Conn) ReadMsg() (*Msg, error) {
var p []byte
m := new(Msg)
switch w.client.Net {
case "tcp", "tcp4", "tcp6":
if _, ok := co.Conn.(*net.TCPConn); ok {
p = make([]byte, MaxMsgSize)
case "", "udp", "udp4", "udp6":
// OPT! TODO(mg)
p = make([]byte, DefaultMsgSize)
} else {
if co.UDPSize >= 512 {
p = make([]byte, co.UDPSize)
} else {
p = make([]byte, MinMsgSize)
}
}
n, err := w.read(p)
n, err := co.Read(p)
if err != nil && n == 0 {
return nil, err
}
@ -143,21 +166,20 @@ func (w *reply) receive() (*Msg, error) {
if err := m.Unpack(p); err != nil {
return nil, err
}
w.rtt = time.Since(w.t)
co.rtt = time.Since(co.t)
if t := m.IsTsig(); t != nil {
secret := t.Hdr.Name
if _, ok := w.client.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
return m, w.tsigStatus
return m, err
}
func (w *reply) read(p []byte) (n int, err error) {
if w.conn == nil {
// Read implements the net.Conn read method.
func (co *Conn) Read(p []byte) (n int, err error) {
if co.Conn == nil {
return 0, ErrConnEmpty
}
if len(p) < 2 {
@ -200,75 +222,83 @@ func (w *reply) read(p []byte) (n int, err error) {
return n, err
}
// send sends a dns msg to the address specified in w.
// WriteMsg sends a message throught the connection co.
// If the message m contains a TSIG record the transaction
// signature is calculated.
func (w *reply) send(m *Msg) (err error) {
func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
mac := ""
name := t.Hdr.Name
if _, ok := w.client.TsigSecret[name]; !ok {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return ErrSecret
}
out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = mac
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
// Set for the next read, allthough only used in zone transfers
co.tsigRequestMAC = mac
} else {
out, err = m.Pack()
}
if err != nil {
return err
}
w.t = time.Now()
if _, err = w.write(out); err != nil {
co.t = time.Now()
if _, err = co.Write(out); err != nil {
return err
}
return nil
}
func (w *reply) write(p []byte) (n int, err error) {
switch w.client.Net {
case "tcp", "tcp4", "tcp6":
// Write implements the net.Conn Write method.
func (co *Conn) Write(p []byte) (n int, err error) {
if t, ok := co.Conn.(*net.TCPConn); ok {
if len(p) < 2 {
return 0, io.ErrShortBuffer
}
setTimeouts(w)
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(p)))
p = append(l, p...)
n, err := w.conn.Write(p)
n, err := t.Write(p)
if err != nil {
return n, err
}
i := n
if i < len(p) {
j, err := w.conn.Write(p[i:len(p)])
j, err := t.Write(p[i:len(p)])
if err != nil {
return i, err
}
i += j
}
n = i
<<<<<<< HEAD
case "", "udp", "udp4", "udp6":
setTimeouts(w)
n, err = w.conn.Write(p)
if err != nil {
return n, err
}
=======
return n, err
>>>>>>> net
}
return
n, err = co.Conn.(*net.UDPConn).Write(p)
return n, err
}
func setTimeouts(w *reply) {
if w.client.ReadTimeout == 0 {
w.conn.SetReadDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetReadDeadline(time.Now().Add(w.client.ReadTimeout))
}
// Close implements the net.Conn Close method.
func (co *Conn) Close() error { return co.Conn.Close() }
if w.client.WriteTimeout == 0 {
w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9))
} else {
w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout))
}
}
// LocalAddr implements the net.Conn LocalAddr method.
func (co *Conn) LocalAddr() net.Addr { return co.Conn.LocalAddr() }
// RemoteAddr implements the net.Conn RemoteAddr method.
func (co *Conn) RemoteAddr() net.Addr { return co.Conn.RemoteAddr() }
// SetDeadline implements the net.Conn SetDeadline method.
func (co *Conn) SetDeadline(t time.Time) error { return co.Conn.SetDeadline(t) }
// SetReadDeadline implements the net.Conn SetReadDeadline method.
func (co *Conn) SetReadDeadline(t time.Time) error { return co.Conn.SetReadDeadline(t) }
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
func (co *Conn) SetWriteDeadline(t time.Time) error { return co.Conn.SetWriteDeadline(t) }

View File

@ -82,14 +82,13 @@ Loop:
/*
func TestClientTsigAXFR(t *testing.T) {
m := new(Msg)
m.SetAxfr("miek.nl.")
m.SetAxfr("example.nl.")
m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix())
c := new(Client)
c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
c.Net = "tcp"
tr := new(Transfer)
tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
if a, err := c.TransferIn(m, "37.251.95.53:53"); err != nil {
if a, err := tr.In(m, "176.58.119.54:53"); err != nil {
t.Log("Failed to setup axfr: " + err.Error())
t.Fatal()
} else {
@ -106,14 +105,12 @@ func TestClientTsigAXFR(t *testing.T) {
}
}
func TestClientAXFRMultipleMessages(t *testing.T) {
func TestClientAXFRMultipleEnvelopes(t *testing.T) {
m := new(Msg)
m.SetAxfr("dnsex.nl.")
m.SetAxfr("nlnetlabs.nl.")
c := new(Client)
c.Net = "tcp"
if a, err := c.TransferIn(m, "37.251.95.53:53"); err != nil {
tr := new(Transfer)
if a, err := tr.In(m, "213.154.224.1:53"); err != nil {
t.Log("Failed to setup axfr" + err.Error())
t.Fail()
return
@ -130,7 +127,7 @@ func TestClientAXFRMultipleMessages(t *testing.T) {
*/
// not really a test, but shows how to use update leases
func TestUpdateLeaseTSIG(t *testing.T) {
func ExampleUpdateLeaseTSIG(t *testing.T) {
m := new(Msg)
m.SetUpdate("t.local.ip6.io.")
rr, _ := NewRR("t.local.ip6.io. 30 A 127.0.0.1")
@ -151,16 +148,9 @@ func TestUpdateLeaseTSIG(t *testing.T) {
m.SetTsig("polvi.", HmacMD5, 300, time.Now().Unix())
c.TsigSecret = map[string]string{"polvi.": "pRZgBrBvI4NAHZYhxmhs/Q=="}
w := new(reply)
w.client = c
w.addr = "127.0.0.1:53"
w.req = m
if err := w.dial(); err != nil {
_, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Log(err.Error())
t.Fail()
}
if err := w.send(m); err != nil {
t.Fail()
}
}

9
dns.go
View File

@ -36,7 +36,7 @@
// In the DNS messages are exchanged, these messages contain resource
// records (sets). Use pattern for creating a message:
//
// m := dns.new(Msg)
// m := new(dns.Msg)
// m.SetQuestion("miek.nl.", dns.TypeMX)
//
// Or when not certain if the domain name is fully qualified:
@ -66,6 +66,11 @@
//
// c.SingleInflight = true
//
// If these "advanced" features are not needed, a simple UDP query can be send,
// with:
//
// in, err := dns.Exchange(m1, "127.0.0.1:53")
//
// A dns message consists out of four sections.
// The question section: in.Question, the answer section: in.Answer,
// the authority section: in.Ns and the additional section: in.Extra.
@ -86,7 +91,7 @@ import (
const (
year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits.
DefaultMsgSize = 4096 // Standard default for larger than 512 packets.
udpMsgSize = 512 // Default buffer size for servers receiving UDP packets.
MinMsgSize = 512 // Minimal size of a DNS packet.
MaxMsgSize = 65536 // Largest possible DNS packet.
defaultTtl = 3600 // Default TTL.
)

View File

@ -9,9 +9,6 @@ import (
"testing"
)
// Query with way to long name
//./q mx bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.miek.nl.miek.nl.miek123.nl.
func TestPackUnpack(t *testing.T) {
out := new(Msg)
out.Answer = make([]RR, 1)
@ -245,7 +242,7 @@ func TestNoRdataPack(t *testing.T) {
data := make([]byte, 1024)
for typ, fn := range rr_mk {
if typ == TypeCAA {
continue // known broken, will fix. TODO(miek)
continue // TODO(miek): known ommision
}
r := fn()
*r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600}

View File

@ -443,7 +443,6 @@ func (e *EDNS0_DAU) String() string {
}
}
return s
}
type EDNS0_DHU struct {

View File

@ -141,6 +141,14 @@ Flags:
nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port)
}
c := new(dns.Client)
t := new(dns.Transfer)
c.Net = "udp"
if *four {
c.Net = "udp4"
}
if *six {
c.Net = "udp6"
}
if *tcp {
c.Net = "tcp"
if *four {
@ -149,14 +157,6 @@ Flags:
if *six {
c.Net = "tcp6"
}
} else {
c.Net = "udp"
if *four {
c.Net = "udp4"
}
if *six {
c.Net = "udp6"
}
}
m := new(dns.Msg)
@ -204,14 +204,15 @@ Flags:
m.Extra = append(m.Extra, o)
}
query:
for _, v := range qname {
m.Question[0] = dns.Question{dns.Fqdn(v), qtype, qclass}
m.Id = dns.Id()
// Add tsig
if *tsig != "" {
if algo, name, secret, ok := tsigKeyParse(*tsig); ok {
m.SetTsig(name, algo, 300, time.Now().Unix())
c.TsigSecret = map[string]string{name: secret}
t.TsigSecret = map[string]string{name: secret}
} else {
fmt.Fprintf(os.Stderr, "TSIG key data error\n")
return
@ -221,13 +222,26 @@ Flags:
fmt.Printf("%s", m.String())
fmt.Printf("\n;; size: %d bytes\n\n", m.Len())
}
if qtype == dns.TypeAXFR {
c.Net = "tcp"
doXfr(c, m, nameserver)
continue
}
if qtype == dns.TypeIXFR {
doXfr(c, m, nameserver)
if qtype == dns.TypeAXFR || qtype == dns.TypeIXFR {
env, err := t.In(m, nameserver)
if err != nil {
fmt.Printf(";; %s\n", err.Error())
continue
}
envelope := 0
record := 0
for e := range env {
if e.Error != nil {
fmt.Printf(";; %s\n", e.Error.Error())
continue query
}
for _, r := range e.RR {
fmt.Printf("%s\n", r)
}
record+=len(e.RR)
envelope++
}
fmt.Printf("\n;; xfr size: %d records (envelopes %d)\n", record, envelope)
continue
}
r, rtt, e := c.Exchange(m, nameserver)
@ -280,15 +294,15 @@ func tsigKeyParse(s string) (algo, name, secret string, ok bool) {
s1 := strings.SplitN(s, ":", 3)
switch len(s1) {
case 2:
return "hmac-md5.sig-alg.reg.int.", s1[0], s1[1], true
return "hmac-md5.sig-alg.reg.int.", dns.Fqdn(s1[0]), s1[1], true
case 3:
switch s1[0] {
case "hmac-md5":
return "hmac-md5.sig-alg.reg.int.", s1[1], s1[2], true
return "hmac-md5.sig-alg.reg.int.", dns.Fqdn(s1[1]), s1[2], true
case "hmac-sha1":
return "hmac-sha1.", s1[1], s1[2], true
return "hmac-sha1.", dns.Fqdn(s1[1]), s1[2], true
case "hmac-sha256":
return "hmac-sha256.", s1[1], s1[2], true
return "hmac-sha256.", dns.Fqdn(s1[1]), s1[2], true
}
}
return
@ -402,20 +416,22 @@ func shortRR(r dns.RR) dns.RR {
}
func doXfr(c *dns.Client, m *dns.Msg, nameserver string) {
if t, e := c.TransferIn(m, nameserver); e == nil {
for r := range t {
if r.Error == nil {
for _, rr := range r.RR {
if *short {
rr = shortRR(rr)
/*
if t, e := c.TransferIn(m, nameserver); e == nil {
for r := range t {
if r.Error == nil {
for _, rr := range r.RR {
if *short {
rr = shortRR(rr)
}
fmt.Printf("%v\n", rr)
}
fmt.Printf("%v\n", rr)
} else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", r.Error.Error())
}
} else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", r.Error.Error())
}
} else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error())
}
} else {
fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error())
}
*/
}

View File

@ -103,24 +103,6 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
t.Txt = []string{str}
switch r.Question[0].Qtype {
case dns.TypeAXFR:
c := make(chan *dns.Envelope)
var e *error
if err := dns.TransferOut(w, r, c, e); err != nil {
close(c)
return
}
soa, _ := dns.NewRR(`whoami.miek.nl. IN SOA elektron.atoom.net. miekg.atoom.net. (
2009032802
21600
7200
604800
3600)`)
c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}}
close(c)
w.Hijack()
// w.Close() // Client closes
return
case dns.TypeTXT:
m.Answer = append(m.Answer, t)
m.Extra = append(m.Extra, rr)
@ -129,6 +111,21 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
case dns.TypeAAAA, dns.TypeA:
m.Answer = append(m.Answer, rr)
m.Extra = append(m.Extra, t)
case dns.TypeAXFR, dns.TypeIXFR:
c := make(chan *dns.Envelope)
tr := new(dns.Transfer)
defer close(c)
err := tr.Out(w, r, c)
if err != nil {
return
}
soa, _ := dns.NewRR(`whoami.miek.nl. 0 IN SOA linode.atoom.net. miek.miek.nl. 2009032802 21600 7200 604800 3600`)
c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}}
w.Hijack()
// w.Close() // Client closes connection
return
}
if r.IsTsig() != nil {
@ -161,7 +158,7 @@ func serve(net, name, secret string) {
}
func main() {
runtime.GOMAXPROCS(runtime.NumCPU()*4)
runtime.GOMAXPROCS(runtime.NumCPU() * 4)
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file")
printf = flag.Bool("print", false, "print replies")
compress = flag.Bool("compress", false, "compress replies")

View File

@ -160,17 +160,3 @@ func PrevLabel(s string, n int) (i int, start bool) {
}
return lab[len(lab)-n], false
}
func LenLabels(s string) int {
println("LenLabels is to be removed in future versions, for the better named CountLabel")
return CountLabel(s)
}
func SplitLabels(s string) []string {
println("SplitLabels is to be removed in future versions, for the better named SplitDomainName")
return SplitDomainName(s)
}
func CompareLabels(s1, s2 string) (n int) {
println("CompareLabels is to be removed in future versions, for better named CompareDomainName")
return CompareDomainName(s1, s2)
}

View File

@ -357,7 +357,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
handler = DefaultServeMux
}
if srv.UDPSize == 0 {
srv.UDPSize = udpMsgSize
srv.UDPSize = MinMsgSize
}
for {
if srv.ReadTimeout != 0 {

263
xfr.go
View File

@ -6,84 +6,91 @@
package dns
// New Transfer
import (
"net"
"time"
)
// Envelope is used when doing [IA]XFR with a remote server.
// Envelope is used when doing a transfer with a remote server.
type Envelope struct {
RR []RR // The set of RRs in the answer section of the AXFR reply message.
Error error // If something went wrong, this contains the error.
}
// TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns
// a channel of *Envelope on which the replies from the server are sent. At the end of
// the transfer the channel is closed.
// The messages are TSIG checked if
// needed, no other post-processing is performed. The caller must dissect the returned
// messages.
//
// Basic use pattern for receiving an AXFR:
//
// // m contains the AXFR request
// t, e := c.TransferIn(m, "127.0.0.1:53")
// for r := range t {
// // ... deal with r.RR or r.Error
// }
func (c *Client) TransferIn(q *Msg, a string) (chan *Envelope, error) {
w := new(reply)
w.client = c
w.addr = a
w.req = q
if err := w.dial(); err != nil {
return nil, err
}
if err := w.send(q); err != nil {
return nil, err
}
e := make(chan *Envelope)
switch q.Question[0].Qtype {
case TypeAXFR:
go w.axfrIn(q, e)
return e, nil
case TypeIXFR:
go w.ixfrIn(q, e)
return e, nil
default:
return nil, nil
}
panic("dns: not reached")
// A Transfer defines parameters that are used during a zone transfer.
type Transfer struct {
*Conn
DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
tsigTimersOnly bool
}
func (w *reply) axfrIn(q *Msg, c chan *Envelope) {
// In performs an incoming transfer with the server in a.
func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
t.Conn = new(Conn)
timeout := dnsTimeout
if t.DialTimeout != 0 {
timeout = t.DialTimeout
}
t.Conn.Conn, err = net.DialTimeout("tcp", a, timeout)
if err != nil {
return nil, err
}
if err := t.WriteMsg(q); err != nil {
return nil, err
}
env = make(chan *Envelope)
go func() {
if q.Question[0].Qtype == TypeAXFR {
go t.inAxfr(q.Id, env)
return
}
if q.Question[0].Qtype == TypeIXFR {
go t.inIxfr(q.Id, env)
return
}
}()
return env, nil
}
func (t *Transfer) inAxfr(id uint16, c chan *Envelope) {
first := true
defer w.conn.Close()
defer t.Close()
defer close(c)
timeout := dnsTimeout
if t.ReadTimeout != 0 {
timeout = t.ReadTimeout
}
for {
in, err := w.receive()
t.Conn.SetReadDeadline(time.Now().Add(timeout))
in, err := t.ReadMsg()
if err != nil {
c <- &Envelope{nil, err}
return
}
if in.Id != q.Id {
if id != in.Id {
c <- &Envelope{in.Answer, ErrId}
return
}
if first {
if !checkXfrSOA(in, true) {
if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa}
return
}
first = !first
// only one answer that is SOA, receive more
if len(in.Answer) == 1 {
w.tsigTimersOnly = true
t.tsigTimersOnly = true
c <- &Envelope{in.Answer, nil}
continue
}
}
if !first {
w.tsigTimersOnly = true // Subsequent envelopes use this.
if checkXfrSOA(in, false) {
t.tsigTimersOnly = true // Subsequent envelopes use this.
if isSOALast(in) {
c <- &Envelope{in.Answer, nil}
return
}
@ -93,30 +100,35 @@ func (w *reply) axfrIn(q *Msg, c chan *Envelope) {
panic("dns: not reached")
}
func (w *reply) ixfrIn(q *Msg, c chan *Envelope) {
var serial uint32 // The first serial seen is the current server serial
func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
serial := uint32(0) // The first serial seen is the current server serial
first := true
defer w.conn.Close()
defer t.Close()
defer close(c)
timeout := dnsTimeout
if t.ReadTimeout != 0 {
timeout = t.ReadTimeout
}
for {
in, err := w.receive()
t.SetReadDeadline(time.Now().Add(timeout))
in, err := t.ReadMsg()
if err != nil {
c <- &Envelope{in.Answer, err}
return
}
if q.Id != in.Id {
if id != in.Id {
c <- &Envelope{in.Answer, ErrId}
return
}
if first {
// A single SOA RR signals "no changes"
if len(in.Answer) == 1 && checkXfrSOA(in, true) {
if len(in.Answer) == 1 && isSOAFirst(in) {
c <- &Envelope{in.Answer, nil}
return
}
// Check if the returned answer is ok
if !checkXfrSOA(in, true) {
if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa}
return
}
@ -127,7 +139,7 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) {
// Now we need to check each message for SOA records, to see what we need to do
if !first {
w.tsigTimersOnly = true
t.tsigTimersOnly = true
// If the last record in the IXFR contains the servers' SOA, we should quit
if v, ok := in.Answer[len(in.Answer)-1].(*SOA); ok {
if v.Serial == serial {
@ -138,70 +150,95 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) {
c <- &Envelope{in.Answer, nil}
}
}
panic("dns: not reached")
}
// Check if he SOA record exists in the Answer section of
// the packet. If first is true the first RR must be a SOA
// if false, the last one should be a SOA.
func checkXfrSOA(in *Msg, first bool) bool {
if len(in.Answer) > 0 {
if first {
return in.Answer[0].Header().Rrtype == TypeSOA
} else {
return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
// Out performs an outgoing transfer with the client connecting in w.
// Basic use pattern:
//
// ch := make(chan *dns.Envelope)
// tr := new(dns.Transfer)
// tr.Out(w, r, ch)
// c <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
// close(ch)
// w.Hijack()
// // w.Close() // Client closes connection
//
// The server is responsible for sending the correct sequence of RRs through the
// channel ch.
func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
r := new(Msg)
// Compress?
r.SetReply(q)
r.Authoritative = true
go func() {
for x := range ch {
// assume it fits TODO(miek): fix
r.Answer = append(r.Answer, x.RR...)
if err := w.WriteMsg(r); err != nil {
return
}
}
w.TsigTimersOnly(true)
r.Answer = nil
}()
return nil
}
// ReadMsg reads a message from the transfer connection t.
func (t *Transfer) ReadMsg() (*Msg, error) {
m := new(Msg)
p := make([]byte, MaxMsgSize)
n, err := t.Read(p)
if err != nil && n == 0 {
return nil, err
}
p = p[:n]
if err := m.Unpack(p); err != nil {
return nil, err
}
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
}
return m, err
}
// WriteMsg write a message throught the transfer connection t.
func (t *Transfer) WriteMsg(m *Msg) (err error) {
var out []byte
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return ErrSecret
}
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
} else {
out, err = m.Pack()
}
if err != nil {
return err
}
if _, err = t.Write(out); err != nil {
return err
}
return nil
}
func isSOAFirst(in *Msg) bool {
if len(in.Answer) > 0 {
return in.Answer[0].Header().Rrtype == TypeSOA
}
return false
}
// TransferOut performs an outgoing [AI]XFR depending on the request message. The
// caller is responsible for sending the correct sequence of RR sets through
// the channel c. For reasons of symmetry Envelope is re-used.
// Errors are signaled via the error pointer, when an error occurs the function
// sets the error and returns (it does not close the channel).
// TSIG and enveloping is handled by TransferOut.
//
// Basic use pattern for sending an AXFR:
//
// // q contains the AXFR request
// c := make(chan *Envelope)
// var e *error
// err := TransferOut(w, q, c, e)
// w.Hijack() // hijack the connection so that the package doesn't close it
// for _, rrset := range rrsets { // rrsets is a []RR
// c <- &{Envelope{RR: rrset}
// if e != nil {
// close(c)
// break
// }
// }
// // w.Close() // Don't! Let the client close the connection
func TransferOut(w ResponseWriter, q *Msg, c chan *Envelope, e *error) error {
switch q.Question[0].Qtype {
case TypeAXFR, TypeIXFR:
go xfrOut(w, q, c, e)
return nil
default:
return nil
}
panic("dns: not reached")
}
// TODO(mg): count the RRs and the resulting size.
func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) {
rep := new(Msg)
rep.SetReply(req)
rep.Authoritative = true
for x := range c {
// assume it fits
rep.Answer = append(rep.Answer, x.RR...)
if err := w.WriteMsg(rep); e != nil {
*e = err
return
}
w.TsigTimersOnly(true)
rep.Answer = nil
func isSOALast(in *Msg) bool {
if len(in.Answer) > 0 {
return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
}
return false
}