diff --git a/client.go b/client.go index e66216cf..2096928e 100644 --- a/client.go +++ b/client.go @@ -12,39 +12,77 @@ import ( "time" ) -// Order of events: -// *client -> *reply -> Exchange() -> dial()/send()->write()/receive()->read() +const dnsTimeout time.Duration = 2 * 1e9 -// Do I want make this an interface thingy? -type reply struct { - client *Client - addr string - req *Msg - conn net.Conn - tsigRequestMAC string - tsigTimersOnly bool - tsigStatus error +// A Conn represents a connection to a DNS server. +type Conn struct { + net.Conn // a net.Conn holding the connection + UDPSize uint16 // Minimum receive buffer for UDP messages + TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified rtt time.Duration t time.Time + tsigRequestMAC string } -// A Client defines parameter for a DNS client. A nil -// Client is usable for sending queries. +// A Client defines parameters for a DNS client. A nil Client is usable for sending queries. type Client struct { Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one (default is "" for UDP) - ReadTimeout time.Duration // the net.Conn.SetReadTimeout value for new connections (ns), defaults to 2 * 1e9 - WriteTimeout time.Duration // the net.Conn.SetWriteTimeout value for new connections (ns), defaults to 2 * 1e9 + DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass group singleflight } -func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.Duration, err error) { +// Exchange performs a synchronous UDP query. It sends the message m to the address +// contained in a and waits for an reply. +func Exchange(m *Msg, a string) (r *Msg, err error) { + co := new(Conn) + co.Conn, err = net.DialTimeout("udp", a, dnsTimeout) + if err != nil { + return nil, err + } + + defer co.Close() + co.SetReadDeadline(time.Now().Add(dnsTimeout)) + co.SetWriteDeadline(time.Now().Add(dnsTimeout)) + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err +} + +// ExchangeConn performs a synchronous query. It sends the message m via the connection +// c and waits for a reply. The connection c is not closed by ExchangeConn. +// This function is going away, but can easily be mimicked: +// +// co := new(dns.Conn) +// co.Conn = c // c is your net.Conn +// co.WriteMsg(m) +// in, _ := co.ReadMsg() +// +func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { + println("dns: this function is deprecated") + co := new(Conn) + co.Conn = c + if err = co.WriteMsg(m); err != nil { + return nil, err + } + r, err = co.ReadMsg() + return r, err +} + +// Exchange performs an synchronous query. It sends the message m to the address +// contained in a and waits for an reply. Basic use pattern with a *dns.Client: +// +// c := new(dns.Client) +// in, rtt, err := c.Exchange(message, "127.0.0.1:53") +// +func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { if !c.SingleInflight { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) } // This adds a bunch of garbage, TODO(miek). t := "nop" @@ -56,86 +94,71 @@ func (c *Client) exchangeMerge(m *Msg, a string, s net.Conn) (r *Msg, rtt time.D cl = cl1 } r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { - if s == nil { - return c.exchange(m, a) - } - return c.exchangeConn(m, s) + return c.exchange(m, a) }) if err != nil { return r, rtt, err } if shared { r1 := r.copy() - r1.Id = r.Id // Copy Id! r = r1 } return r, rtt, nil } -// Exchange performs an synchronous query. It sends the message m to the address -// contained in a and waits for an reply. Basic use pattern with a *dns.Client: -// -// c := new(dns.Client) -// in, rtt, err := c.Exchange(message, "127.0.0.1:53") -// -func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, a, nil) -} - func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, addr: a} - if err = w.dial(); err != nil { - return nil, 0, err + co := new(Conn) + timeout := dnsTimeout + if c.DialTimeout != 0 { + timeout = c.DialTimeout } - defer w.conn.Close() - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// ExchangeConn performs an synchronous query. It sends the message m trough the -// connection s and waits for a reply. -func (c *Client) ExchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - return c.exchangeMerge(m, "", s) -} - -func (c *Client) exchangeConn(m *Msg, s net.Conn) (r *Msg, rtt time.Duration, err error) { - w := &reply{client: c, conn: s} - if err = w.send(m); err != nil { - return nil, 0, err - } - r, err = w.receive() - return r, w.rtt, err -} - -// dial connects to the address addr for the network set in c.Net -func (w *reply) dial() (err error) { - var conn net.Conn - if w.client.Net == "" { - conn, err = net.DialTimeout("udp", w.addr, 5*1e9) + if c.Net == "" { + co.Conn, err = net.DialTimeout("udp", a, timeout) } else { - conn, err = net.DialTimeout(w.client.Net, w.addr, 5*1e9) + co.Conn, err = net.DialTimeout(c.Net, a, timeout) } if err != nil { - return err + return nil, 0, err } - w.conn = conn - return + timeout = dnsTimeout + if c.ReadTimeout != 0 { + timeout = c.ReadTimeout + } + co.SetReadDeadline(time.Now().Add(timeout)) + timeout = dnsTimeout + if c.WriteTimeout != 0 { + timeout = c.WriteTimeout + } + co.SetWriteDeadline(time.Now().Add(timeout)) + defer co.Close() + opt := m.IsEdns0() + if opt != nil && opt.UDPSize() >= MinMsgSize { + co.UDPSize = opt.UDPSize() + } + co.TsigSecret = c.TsigSecret + if err = co.WriteMsg(m); err != nil { + return nil, 0, err + } + r, err = co.ReadMsg() + return r, co.rtt, err } -func (w *reply) receive() (*Msg, error) { +// ReadMsg reads a message from the connection co. +// If the received message contains a TSIG record the transaction +// signature is verified. +func (co *Conn) ReadMsg() (*Msg, error) { var p []byte m := new(Msg) - switch w.client.Net { - case "tcp", "tcp4", "tcp6": + if _, ok := co.Conn.(*net.TCPConn); ok { p = make([]byte, MaxMsgSize) - case "", "udp", "udp4", "udp6": - // OPT! TODO(mg) - p = make([]byte, DefaultMsgSize) + } else { + if co.UDPSize >= 512 { + p = make([]byte, co.UDPSize) + } else { + p = make([]byte, MinMsgSize) + } } - n, err := w.read(p) + n, err := co.Read(p) if err != nil && n == 0 { return nil, err } @@ -143,21 +166,20 @@ func (w *reply) receive() (*Msg, error) { if err := m.Unpack(p); err != nil { return nil, err } - w.rtt = time.Since(w.t) + co.rtt = time.Since(co.t) if t := m.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.client.TsigSecret[secret]; !ok { - w.tsigStatus = ErrSecret + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. - w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } - return m, w.tsigStatus + return m, err } -func (w *reply) read(p []byte) (n int, err error) { - if w.conn == nil { +// Read implements the net.Conn read method. +func (co *Conn) Read(p []byte) (n int, err error) { + if co.Conn == nil { return 0, ErrConnEmpty } if len(p) < 2 { @@ -200,75 +222,83 @@ func (w *reply) read(p []byte) (n int, err error) { return n, err } -// send sends a dns msg to the address specified in w. +// WriteMsg sends a message throught the connection co. // If the message m contains a TSIG record the transaction // signature is calculated. -func (w *reply) send(m *Msg) (err error) { +func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { mac := "" - name := t.Hdr.Name - if _, ok := w.client.TsigSecret[name]; !ok { + if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return ErrSecret } - out, mac, err = TsigGenerate(m, w.client.TsigSecret[name], w.tsigRequestMAC, w.tsigTimersOnly) - w.tsigRequestMAC = mac + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) + // Set for the next read, allthough only used in zone transfers + co.tsigRequestMAC = mac } else { out, err = m.Pack() } if err != nil { return err } - w.t = time.Now() - if _, err = w.write(out); err != nil { + co.t = time.Now() + if _, err = co.Write(out); err != nil { return err } return nil } -func (w *reply) write(p []byte) (n int, err error) { - switch w.client.Net { - case "tcp", "tcp4", "tcp6": +// Write implements the net.Conn Write method. +func (co *Conn) Write(p []byte) (n int, err error) { + if t, ok := co.Conn.(*net.TCPConn); ok { if len(p) < 2 { return 0, io.ErrShortBuffer } - setTimeouts(w) l := make([]byte, 2) l[0], l[1] = packUint16(uint16(len(p))) p = append(l, p...) - n, err := w.conn.Write(p) + n, err := t.Write(p) if err != nil { return n, err } i := n if i < len(p) { - j, err := w.conn.Write(p[i:len(p)]) + j, err := t.Write(p[i:len(p)]) if err != nil { return i, err } i += j } n = i +<<<<<<< HEAD case "", "udp", "udp4", "udp6": setTimeouts(w) n, err = w.conn.Write(p) if err != nil { return n, err } +======= + return n, err +>>>>>>> net } - return + n, err = co.Conn.(*net.UDPConn).Write(p) + return n, err } -func setTimeouts(w *reply) { - if w.client.ReadTimeout == 0 { - w.conn.SetReadDeadline(time.Now().Add(2 * 1e9)) - } else { - w.conn.SetReadDeadline(time.Now().Add(w.client.ReadTimeout)) - } +// Close implements the net.Conn Close method. +func (co *Conn) Close() error { return co.Conn.Close() } - if w.client.WriteTimeout == 0 { - w.conn.SetWriteDeadline(time.Now().Add(2 * 1e9)) - } else { - w.conn.SetWriteDeadline(time.Now().Add(w.client.WriteTimeout)) - } -} +// LocalAddr implements the net.Conn LocalAddr method. +func (co *Conn) LocalAddr() net.Addr { return co.Conn.LocalAddr() } + +// RemoteAddr implements the net.Conn RemoteAddr method. +func (co *Conn) RemoteAddr() net.Addr { return co.Conn.RemoteAddr() } + +// SetDeadline implements the net.Conn SetDeadline method. +func (co *Conn) SetDeadline(t time.Time) error { return co.Conn.SetDeadline(t) } + +// SetReadDeadline implements the net.Conn SetReadDeadline method. +func (co *Conn) SetReadDeadline(t time.Time) error { return co.Conn.SetReadDeadline(t) } + +// SetWriteDeadline implements the net.Conn SetWriteDeadline method. +func (co *Conn) SetWriteDeadline(t time.Time) error { return co.Conn.SetWriteDeadline(t) } diff --git a/client_test.go b/client_test.go index 64f84b08..a84fcd06 100644 --- a/client_test.go +++ b/client_test.go @@ -82,14 +82,13 @@ Loop: /* func TestClientTsigAXFR(t *testing.T) { m := new(Msg) - m.SetAxfr("miek.nl.") + m.SetAxfr("example.nl.") m.SetTsig("axfr.", HmacMD5, 300, time.Now().Unix()) - c := new(Client) - c.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} - c.Net = "tcp" + tr := new(Transfer) + tr.TsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="} - if a, err := c.TransferIn(m, "37.251.95.53:53"); err != nil { + if a, err := tr.In(m, "176.58.119.54:53"); err != nil { t.Log("Failed to setup axfr: " + err.Error()) t.Fatal() } else { @@ -106,14 +105,12 @@ func TestClientTsigAXFR(t *testing.T) { } } -func TestClientAXFRMultipleMessages(t *testing.T) { +func TestClientAXFRMultipleEnvelopes(t *testing.T) { m := new(Msg) - m.SetAxfr("dnsex.nl.") + m.SetAxfr("nlnetlabs.nl.") - c := new(Client) - c.Net = "tcp" - - if a, err := c.TransferIn(m, "37.251.95.53:53"); err != nil { + tr := new(Transfer) + if a, err := tr.In(m, "213.154.224.1:53"); err != nil { t.Log("Failed to setup axfr" + err.Error()) t.Fail() return @@ -130,7 +127,7 @@ func TestClientAXFRMultipleMessages(t *testing.T) { */ // not really a test, but shows how to use update leases -func TestUpdateLeaseTSIG(t *testing.T) { +func ExampleUpdateLeaseTSIG(t *testing.T) { m := new(Msg) m.SetUpdate("t.local.ip6.io.") rr, _ := NewRR("t.local.ip6.io. 30 A 127.0.0.1") @@ -151,16 +148,9 @@ func TestUpdateLeaseTSIG(t *testing.T) { m.SetTsig("polvi.", HmacMD5, 300, time.Now().Unix()) c.TsigSecret = map[string]string{"polvi.": "pRZgBrBvI4NAHZYhxmhs/Q=="} - w := new(reply) - w.client = c - w.addr = "127.0.0.1:53" - w.req = m - - if err := w.dial(); err != nil { + _, _, err := c.Exchange(m, "127.0.0.1:53") + if err != nil { + t.Log(err.Error()) t.Fail() } - if err := w.send(m); err != nil { - t.Fail() - } - } diff --git a/dns.go b/dns.go index cc442f17..8c746264 100644 --- a/dns.go +++ b/dns.go @@ -36,7 +36,7 @@ // In the DNS messages are exchanged, these messages contain resource // records (sets). Use pattern for creating a message: // -// m := dns.new(Msg) +// m := new(dns.Msg) // m.SetQuestion("miek.nl.", dns.TypeMX) // // Or when not certain if the domain name is fully qualified: @@ -66,6 +66,11 @@ // // c.SingleInflight = true // +// If these "advanced" features are not needed, a simple UDP query can be send, +// with: +// +// in, err := dns.Exchange(m1, "127.0.0.1:53") +// // A dns message consists out of four sections. // The question section: in.Question, the answer section: in.Answer, // the authority section: in.Ns and the additional section: in.Extra. @@ -86,7 +91,7 @@ import ( const ( year68 = 1 << 31 // For RFC1982 (Serial Arithmetic) calculations in 32 bits. DefaultMsgSize = 4096 // Standard default for larger than 512 packets. - udpMsgSize = 512 // Default buffer size for servers receiving UDP packets. + MinMsgSize = 512 // Minimal size of a DNS packet. MaxMsgSize = 65536 // Largest possible DNS packet. defaultTtl = 3600 // Default TTL. ) diff --git a/dns_test.go b/dns_test.go index 0e508f93..aa208a99 100644 --- a/dns_test.go +++ b/dns_test.go @@ -9,9 +9,6 @@ import ( "testing" ) -// Query with way to long name -//./q mx bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.bla.miek.nl.miek.nl.miek123.nl. - func TestPackUnpack(t *testing.T) { out := new(Msg) out.Answer = make([]RR, 1) @@ -245,7 +242,7 @@ func TestNoRdataPack(t *testing.T) { data := make([]byte, 1024) for typ, fn := range rr_mk { if typ == TypeCAA { - continue // known broken, will fix. TODO(miek) + continue // TODO(miek): known ommision } r := fn() *r.Header() = RR_Header{Name: "miek.nl.", Rrtype: typ, Class: ClassINET, Ttl: 3600} diff --git a/edns.go b/edns.go index 7e716a6b..8852ae28 100644 --- a/edns.go +++ b/edns.go @@ -443,7 +443,6 @@ func (e *EDNS0_DAU) String() string { } } return s - } type EDNS0_DHU struct { diff --git a/ex/q/q.go b/ex/q/q.go index 3fbd2b28..48c183ac 100644 --- a/ex/q/q.go +++ b/ex/q/q.go @@ -141,6 +141,14 @@ Flags: nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port) } c := new(dns.Client) + t := new(dns.Transfer) + c.Net = "udp" + if *four { + c.Net = "udp4" + } + if *six { + c.Net = "udp6" + } if *tcp { c.Net = "tcp" if *four { @@ -149,14 +157,6 @@ Flags: if *six { c.Net = "tcp6" } - } else { - c.Net = "udp" - if *four { - c.Net = "udp4" - } - if *six { - c.Net = "udp6" - } } m := new(dns.Msg) @@ -204,14 +204,15 @@ Flags: m.Extra = append(m.Extra, o) } +query: for _, v := range qname { m.Question[0] = dns.Question{dns.Fqdn(v), qtype, qclass} m.Id = dns.Id() - // Add tsig if *tsig != "" { if algo, name, secret, ok := tsigKeyParse(*tsig); ok { m.SetTsig(name, algo, 300, time.Now().Unix()) c.TsigSecret = map[string]string{name: secret} + t.TsigSecret = map[string]string{name: secret} } else { fmt.Fprintf(os.Stderr, "TSIG key data error\n") return @@ -221,13 +222,26 @@ Flags: fmt.Printf("%s", m.String()) fmt.Printf("\n;; size: %d bytes\n\n", m.Len()) } - if qtype == dns.TypeAXFR { - c.Net = "tcp" - doXfr(c, m, nameserver) - continue - } - if qtype == dns.TypeIXFR { - doXfr(c, m, nameserver) + if qtype == dns.TypeAXFR || qtype == dns.TypeIXFR { + env, err := t.In(m, nameserver) + if err != nil { + fmt.Printf(";; %s\n", err.Error()) + continue + } + envelope := 0 + record := 0 + for e := range env { + if e.Error != nil { + fmt.Printf(";; %s\n", e.Error.Error()) + continue query + } + for _, r := range e.RR { + fmt.Printf("%s\n", r) + } + record+=len(e.RR) + envelope++ + } + fmt.Printf("\n;; xfr size: %d records (envelopes %d)\n", record, envelope) continue } r, rtt, e := c.Exchange(m, nameserver) @@ -280,15 +294,15 @@ func tsigKeyParse(s string) (algo, name, secret string, ok bool) { s1 := strings.SplitN(s, ":", 3) switch len(s1) { case 2: - return "hmac-md5.sig-alg.reg.int.", s1[0], s1[1], true + return "hmac-md5.sig-alg.reg.int.", dns.Fqdn(s1[0]), s1[1], true case 3: switch s1[0] { case "hmac-md5": - return "hmac-md5.sig-alg.reg.int.", s1[1], s1[2], true + return "hmac-md5.sig-alg.reg.int.", dns.Fqdn(s1[1]), s1[2], true case "hmac-sha1": - return "hmac-sha1.", s1[1], s1[2], true + return "hmac-sha1.", dns.Fqdn(s1[1]), s1[2], true case "hmac-sha256": - return "hmac-sha256.", s1[1], s1[2], true + return "hmac-sha256.", dns.Fqdn(s1[1]), s1[2], true } } return @@ -402,20 +416,22 @@ func shortRR(r dns.RR) dns.RR { } func doXfr(c *dns.Client, m *dns.Msg, nameserver string) { - if t, e := c.TransferIn(m, nameserver); e == nil { - for r := range t { - if r.Error == nil { - for _, rr := range r.RR { - if *short { - rr = shortRR(rr) + /* + if t, e := c.TransferIn(m, nameserver); e == nil { + for r := range t { + if r.Error == nil { + for _, rr := range r.RR { + if *short { + rr = shortRR(rr) + } + fmt.Printf("%v\n", rr) } - fmt.Printf("%v\n", rr) + } else { + fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", r.Error.Error()) } - } else { - fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", r.Error.Error()) } + } else { + fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error()) } - } else { - fmt.Fprintf(os.Stderr, "Failure to read XFR: %s\n", e.Error()) - } + */ } diff --git a/ex/reflect/reflect.go b/ex/reflect/reflect.go index bc33226b..b41f43f7 100644 --- a/ex/reflect/reflect.go +++ b/ex/reflect/reflect.go @@ -103,24 +103,6 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) { t.Txt = []string{str} switch r.Question[0].Qtype { - case dns.TypeAXFR: - c := make(chan *dns.Envelope) - var e *error - if err := dns.TransferOut(w, r, c, e); err != nil { - close(c) - return - } - soa, _ := dns.NewRR(`whoami.miek.nl. IN SOA elektron.atoom.net. miekg.atoom.net. ( - 2009032802 - 21600 - 7200 - 604800 - 3600)`) - c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}} - close(c) - w.Hijack() - // w.Close() // Client closes - return case dns.TypeTXT: m.Answer = append(m.Answer, t) m.Extra = append(m.Extra, rr) @@ -129,6 +111,21 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) { case dns.TypeAAAA, dns.TypeA: m.Answer = append(m.Answer, rr) m.Extra = append(m.Extra, t) + + case dns.TypeAXFR, dns.TypeIXFR: + c := make(chan *dns.Envelope) + tr := new(dns.Transfer) + defer close(c) + err := tr.Out(w, r, c) + if err != nil { + return + } + soa, _ := dns.NewRR(`whoami.miek.nl. 0 IN SOA linode.atoom.net. miek.miek.nl. 2009032802 21600 7200 604800 3600`) + c <- &dns.Envelope{RR: []dns.RR{soa, t, rr, soa}} + w.Hijack() + // w.Close() // Client closes connection + return + } if r.IsTsig() != nil { @@ -161,7 +158,7 @@ func serve(net, name, secret string) { } func main() { - runtime.GOMAXPROCS(runtime.NumCPU()*4) + runtime.GOMAXPROCS(runtime.NumCPU() * 4) cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file") printf = flag.Bool("print", false, "print replies") compress = flag.Bool("compress", false, "compress replies") diff --git a/labels.go b/labels.go index 7742dbd9..e7707b2e 100644 --- a/labels.go +++ b/labels.go @@ -160,17 +160,3 @@ func PrevLabel(s string, n int) (i int, start bool) { } return lab[len(lab)-n], false } - -func LenLabels(s string) int { - println("LenLabels is to be removed in future versions, for the better named CountLabel") - return CountLabel(s) -} -func SplitLabels(s string) []string { - println("SplitLabels is to be removed in future versions, for the better named SplitDomainName") - return SplitDomainName(s) -} - -func CompareLabels(s1, s2 string) (n int) { - println("CompareLabels is to be removed in future versions, for better named CompareDomainName") - return CompareDomainName(s1, s2) -} diff --git a/server.go b/server.go index bc363800..7aa39168 100644 --- a/server.go +++ b/server.go @@ -357,7 +357,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { handler = DefaultServeMux } if srv.UDPSize == 0 { - srv.UDPSize = udpMsgSize + srv.UDPSize = MinMsgSize } for { if srv.ReadTimeout != 0 { diff --git a/xfr.go b/xfr.go index bd440476..9fb11447 100644 --- a/xfr.go +++ b/xfr.go @@ -6,84 +6,91 @@ package dns -// New Transfer +import ( + "net" + "time" +) -// Envelope is used when doing [IA]XFR with a remote server. +// Envelope is used when doing a transfer with a remote server. type Envelope struct { RR []RR // The set of RRs in the answer section of the AXFR reply message. Error error // If something went wrong, this contains the error. } -// TransferIn performs a [AI]XFR request (depends on the message's Qtype). It returns -// a channel of *Envelope on which the replies from the server are sent. At the end of -// the transfer the channel is closed. -// The messages are TSIG checked if -// needed, no other post-processing is performed. The caller must dissect the returned -// messages. -// -// Basic use pattern for receiving an AXFR: -// -// // m contains the AXFR request -// t, e := c.TransferIn(m, "127.0.0.1:53") -// for r := range t { -// // ... deal with r.RR or r.Error -// } -func (c *Client) TransferIn(q *Msg, a string) (chan *Envelope, error) { - w := new(reply) - w.client = c - w.addr = a - w.req = q - if err := w.dial(); err != nil { - return nil, err - } - if err := w.send(q); err != nil { - return nil, err - } - e := make(chan *Envelope) - switch q.Question[0].Qtype { - case TypeAXFR: - go w.axfrIn(q, e) - return e, nil - case TypeIXFR: - go w.ixfrIn(q, e) - return e, nil - default: - return nil, nil - } - panic("dns: not reached") +// A Transfer defines parameters that are used during a zone transfer. +type Transfer struct { + *Conn + DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 + TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified + tsigTimersOnly bool } -func (w *reply) axfrIn(q *Msg, c chan *Envelope) { +// In performs an incoming transfer with the server in a. +func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { + t.Conn = new(Conn) + timeout := dnsTimeout + if t.DialTimeout != 0 { + timeout = t.DialTimeout + } + t.Conn.Conn, err = net.DialTimeout("tcp", a, timeout) + if err != nil { + return nil, err + } + if err := t.WriteMsg(q); err != nil { + return nil, err + } + env = make(chan *Envelope) + go func() { + if q.Question[0].Qtype == TypeAXFR { + go t.inAxfr(q.Id, env) + return + } + if q.Question[0].Qtype == TypeIXFR { + go t.inIxfr(q.Id, env) + return + } + }() + return env, nil +} + +func (t *Transfer) inAxfr(id uint16, c chan *Envelope) { first := true - defer w.conn.Close() + defer t.Close() defer close(c) + timeout := dnsTimeout + if t.ReadTimeout != 0 { + timeout = t.ReadTimeout + } for { - in, err := w.receive() + t.Conn.SetReadDeadline(time.Now().Add(timeout)) + in, err := t.ReadMsg() if err != nil { c <- &Envelope{nil, err} return } - if in.Id != q.Id { + if id != in.Id { c <- &Envelope{in.Answer, ErrId} return } if first { - if !checkXfrSOA(in, true) { + if !isSOAFirst(in) { c <- &Envelope{in.Answer, ErrSoa} return } first = !first // only one answer that is SOA, receive more if len(in.Answer) == 1 { - w.tsigTimersOnly = true + t.tsigTimersOnly = true c <- &Envelope{in.Answer, nil} continue } } if !first { - w.tsigTimersOnly = true // Subsequent envelopes use this. - if checkXfrSOA(in, false) { + t.tsigTimersOnly = true // Subsequent envelopes use this. + if isSOALast(in) { c <- &Envelope{in.Answer, nil} return } @@ -93,30 +100,35 @@ func (w *reply) axfrIn(q *Msg, c chan *Envelope) { panic("dns: not reached") } -func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { - var serial uint32 // The first serial seen is the current server serial +func (t *Transfer) inIxfr(id uint16, c chan *Envelope) { + serial := uint32(0) // The first serial seen is the current server serial first := true - defer w.conn.Close() + defer t.Close() defer close(c) + timeout := dnsTimeout + if t.ReadTimeout != 0 { + timeout = t.ReadTimeout + } for { - in, err := w.receive() + t.SetReadDeadline(time.Now().Add(timeout)) + in, err := t.ReadMsg() if err != nil { c <- &Envelope{in.Answer, err} return } - if q.Id != in.Id { + if id != in.Id { c <- &Envelope{in.Answer, ErrId} return } if first { // A single SOA RR signals "no changes" - if len(in.Answer) == 1 && checkXfrSOA(in, true) { + if len(in.Answer) == 1 && isSOAFirst(in) { c <- &Envelope{in.Answer, nil} return } // Check if the returned answer is ok - if !checkXfrSOA(in, true) { + if !isSOAFirst(in) { c <- &Envelope{in.Answer, ErrSoa} return } @@ -127,7 +139,7 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { // Now we need to check each message for SOA records, to see what we need to do if !first { - w.tsigTimersOnly = true + t.tsigTimersOnly = true // If the last record in the IXFR contains the servers' SOA, we should quit if v, ok := in.Answer[len(in.Answer)-1].(*SOA); ok { if v.Serial == serial { @@ -138,70 +150,95 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { c <- &Envelope{in.Answer, nil} } } - panic("dns: not reached") } -// Check if he SOA record exists in the Answer section of -// the packet. If first is true the first RR must be a SOA -// if false, the last one should be a SOA. -func checkXfrSOA(in *Msg, first bool) bool { - if len(in.Answer) > 0 { - if first { - return in.Answer[0].Header().Rrtype == TypeSOA - } else { - return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA + + +// Out performs an outgoing transfer with the client connecting in w. +// Basic use pattern: +// +// ch := make(chan *dns.Envelope) +// tr := new(dns.Transfer) +// tr.Out(w, r, ch) +// c <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}} +// close(ch) +// w.Hijack() +// // w.Close() // Client closes connection +// +// The server is responsible for sending the correct sequence of RRs through the +// channel ch. +func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { + r := new(Msg) + // Compress? + r.SetReply(q) + r.Authoritative = true + + go func() { + for x := range ch { + // assume it fits TODO(miek): fix + r.Answer = append(r.Answer, x.RR...) + if err := w.WriteMsg(r); err != nil { + return + } } + w.TsigTimersOnly(true) + r.Answer = nil + }() + return nil +} + +// ReadMsg reads a message from the transfer connection t. +func (t *Transfer) ReadMsg() (*Msg, error) { + m := new(Msg) + p := make([]byte, MaxMsgSize) + n, err := t.Read(p) + if err != nil && n == 0 { + return nil, err + } + p = p[:n] + if err := m.Unpack(p); err != nil { + return nil, err + } + if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + } + return m, err +} + +// WriteMsg write a message throught the transfer connection t. +func (t *Transfer) WriteMsg(m *Msg) (err error) { + var out []byte + if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return ErrSecret + } + out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) + } else { + out, err = m.Pack() + } + if err != nil { + return err + } + if _, err = t.Write(out); err != nil { + return err + } + return nil +} + +func isSOAFirst(in *Msg) bool { + if len(in.Answer) > 0 { + return in.Answer[0].Header().Rrtype == TypeSOA } return false } -// TransferOut performs an outgoing [AI]XFR depending on the request message. The -// caller is responsible for sending the correct sequence of RR sets through -// the channel c. For reasons of symmetry Envelope is re-used. -// Errors are signaled via the error pointer, when an error occurs the function -// sets the error and returns (it does not close the channel). -// TSIG and enveloping is handled by TransferOut. -// -// Basic use pattern for sending an AXFR: -// -// // q contains the AXFR request -// c := make(chan *Envelope) -// var e *error -// err := TransferOut(w, q, c, e) -// w.Hijack() // hijack the connection so that the package doesn't close it -// for _, rrset := range rrsets { // rrsets is a []RR -// c <- &{Envelope{RR: rrset} -// if e != nil { -// close(c) -// break -// } -// } -// // w.Close() // Don't! Let the client close the connection -func TransferOut(w ResponseWriter, q *Msg, c chan *Envelope, e *error) error { - switch q.Question[0].Qtype { - case TypeAXFR, TypeIXFR: - go xfrOut(w, q, c, e) - return nil - default: - return nil - } - panic("dns: not reached") -} - -// TODO(mg): count the RRs and the resulting size. -func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) { - rep := new(Msg) - rep.SetReply(req) - rep.Authoritative = true - - for x := range c { - // assume it fits - rep.Answer = append(rep.Answer, x.RR...) - if err := w.WriteMsg(rep); e != nil { - *e = err - return - } - w.TsigTimersOnly(true) - rep.Answer = nil +func isSOALast(in *Msg) bool { + if len(in.Answer) > 0 { + return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA } + return false }