diff --git a/xfr.go b/xfr.go index c8e78cf7..155572ae 100644 --- a/xfr.go +++ b/xfr.go @@ -9,7 +9,7 @@ import ( "time" ) -// Envelope is used when doing [IA]XFR with a remote server. +// Envelope is used when doing a transfer with a remote server. type Envelope struct { RR []RR // The set of RRs in the answer section of the AXFR reply message. Error error // If something went wrong, this contains the error. @@ -17,11 +17,10 @@ type Envelope struct { type Transfer struct { Conn - DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 - ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 - WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 - TsigSecret map[string]string // secret(s) for Tsig map[], zonename must be fully qualified - tsigTimersOnly bool + DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 + ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 + WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 + timersOnly bool } // In performs a [AI]XFR request (depends on the message's Qtype). It returns @@ -49,118 +48,28 @@ func (t *Transfer) In(q *Msg, a string, env chan *Envelope) (err error) { if err != nil { return err } - // re-read 'n stuff must be pushed down - timeout = dnsTimeout - if t.ReadTimeout != 0 { - timeout = t.ReadTimeout + if q.Question[0].Qtype == TypeAXFR { + go t.InAxfr(q.Id, env) + return nil } - co.SetReadDeadline(time.Now().Add(dnsTimeout)) - timeout = dnsTimeout - if t.WriteTimeout != 0 { - timeout = t.WriteTimeout + if q.Question[0].Qtype == TypeIXFR { + go t.InAxfr(q.Id, env) + return nil } - co.SetWriteDeadline(time.Now().Add(dnsTimeout)) - defer co.Close() - return nil + return nil // TODO(miek): some error } -// Out performs an outgoing [AI]XFR depending on the request message. The -// caller is responsible for sending the correct sequence of RR sets through -// the channel c. For reasons of symmetry Envelope is re-used. -// Errors are signaled via the error pointer, when an error occurs the function -// sets the error and returns (it does not close the channel). -// TSIG and enveloping is handled by TransferOut. -// -// Basic use pattern for sending an AXFR: -// -// // m contains the AXFR request -// t := new(dns.Transfer) -// env := make(chan *dns.Envelope) -// err := t.Out(m, c, e) -// for rrset := range rrsets { // rrsets is a []RR -// c <- &{Envelope{RR: rrset} -// if e != nil { -// close(c) -// break -// } -// } -// // w.Close() // Don't! Let the client close the connection -func (t *Transfer) Out(q *Msg, a string) (chan *Envelope, error) { - return nil, nil -} - -// ReadMsg reads a message from the transfer connection t. -func (t *Transfer) ReadMsg() (*Msg, error) { - m := new(Msg) - p := make([]byte, MaxMsgSize) - n, err := t.Conn.Read(p) - if err != nil && n == 0 { - return nil, err - } - p = p[:n] - if err := m.Unpack(p); err != nil { - return nil, err - } - if ts := m.IsTsig(); t != nil { - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return m, ErrSecret - } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.requestMAC, false) - } - return m, err -} - -// WriteMsg write a message throught the transfer connection t. -func (t *Transfer) WriteMsg(m *Msg) (err error) { - var out []byte - if ts := m.IsTsig(); t != nil { - mac := "" - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return ErrSecret - } - out, mac, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.requestMAC, false) - // Set for the next read, allthough only used in zone transfers - t.requestMAC = mac - } else { - out, err = m.Pack() - } - if err != nil { - return err - } - if _, err = t.Conn.Write(out); err != nil { - return err - } - return nil -} - -/* -func (c *Client) TransferIn(q *Msg, a string) (chan *Envelope, error) { - e := make(chan *Envelope) - switch q.Question[0].Qtype { - case TypeAXFR: - go w.axfrIn(q, e) - return e, nil - case TypeIXFR: - go w.ixfrIn(q, e) - return e, nil - default: - return nil, nil - } - panic("dns: not reached") -} - -func (w *reply) axfrIn(q *Msg, c chan *Envelope) { +func (t *Transfer) InAxfr(id uint16, c chan *Envelope) { first := true - defer w.conn.Close() + defer t.Close() defer close(c) for { - in, err := w.receive() + in, err := t.ReadMsg() if err != nil { c <- &Envelope{nil, err} return } - if in.Id != q.Id { + if id != q.Id { c <- &Envelope{in.Answer, ErrId} return } @@ -190,6 +99,71 @@ func (w *reply) axfrIn(q *Msg, c chan *Envelope) { panic("dns: not reached") } + // re-read 'n stuff must be pushed down + timeout = dnsTimeout + if t.ReadTimeout != 0 { + timeout = t.ReadTimeout + } + co.SetReadDeadline(time.Now().Add(dnsTimeout)) + timeout = dnsTimeout + if t.WriteTimeout != 0 { + timeout = t.WriteTimeout + } + co.SetWriteDeadline(time.Now().Add(dnsTimeout)) + defer co.Close() + return nil +} + +func (t *Transfer) Out(w ResponseWriter, q *Msg, a string) (chan *Envelope, error) { + ch := make(chan *Envelope) + + return ch, nil +} + +// ReadMsg reads a message from the transfer connection t. +func (t *Transfer) ReadMsg() (*Msg, error) { + m := new(Msg) + p := make([]byte, MaxMsgSize) + n, err := t.Read(p) + if err != nil && n == 0 { + return nil, err + } + p = p[:n] + if err := m.Unpack(p); err != nil { + return nil, err + } + if ts := m.IsTsig(); t != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.requestMAC, t.timersOnly) + } + return m, err +} + +// WriteMsg write a message throught the transfer connection t. +func (t *Transfer) WriteMsg(m *Msg) (err error) { + var out []byte + if ts := m.IsTsig(); t != nil { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return ErrSecret + } + out, t.requestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.requestMAC, t.timersOnly) + } else { + out, err = m.Pack() + } + if err != nil { + return err + } + if _, err = t.Write(out); err != nil { + return err + } + return nil +} + +/* + func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { var serial uint32 // The first serial seen is the current server serial first := true @@ -238,20 +212,24 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { panic("dns: not reached") } -// Check if he SOA record exists in the Answer section of -// the packet. If first is true the first RR must be a SOA -// if false, the last one should be a SOA. -func checkSOA(in *Msg, first bool) bool { +/* + +func checkFirstSOA(in *Msg) bool { if len(in.Answer) > 0 { - if first { - return in.Answer[0].Header().Rrtype == TypeSOA - } else { - return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA - } + return in.Answer[0].Header().Rrtype == TypeSOA } return false } +func checkLastSOA(in *Msg) bool { + if len(in.Answer) > 0 { + return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA + } + return false +} + + +/* func TransferOut(w ResponseWriter, q *Msg, c chan *Envelope, e *error) error { switch q.Question[0].Qtype { case TypeAXFR, TypeIXFR: