diff --git a/client.go b/client.go index d36d3322..f085e4e0 100644 --- a/client.go +++ b/client.go @@ -16,12 +16,12 @@ const dnsTimeout time.Duration = 2 * 1e9 // A Conn represents a connection (which may be short lived) to a DNS server. type Conn struct { - net.Conn // a net.Conn holding the connection - UDPSize uint16 // Minimum receive buffer for UDP messages - TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified - rtt time.Duration - t time.Time - requestMAC string + net.Conn // a net.Conn holding the connection + UDPSize uint16 // Minimum receive buffer for UDP messages + TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be fully qualified + rtt time.Duration + t time.Time + tsigRequestMAC string } // A Client defines parameters for a DNS client. A nil Client is usable for sending queries. @@ -150,7 +150,7 @@ func (co *Conn) ReadMsg() (*Msg, error) { return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.requestMAC, false) + err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) } return m, err } @@ -208,9 +208,9 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { if _, ok := co.TsigSecret[t.Hdr.Name]; !ok { return ErrSecret } - out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.requestMAC, false) + out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false) // Set for the next read, allthough only used in zone transfers - co.requestMAC = mac + co.tsigRequestMAC = mac } else { out, err = m.Pack() } diff --git a/xfr.go b/xfr.go index 155572ae..50ef9652 100644 --- a/xfr.go +++ b/xfr.go @@ -20,7 +20,7 @@ type Transfer struct { DialTimeout time.Duration // net.DialTimeout (ns), defaults to 2 * 1e9 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections (ns), defaults to 2 * 1e9 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections (ns), defaults to 2 * 1e9 - timersOnly bool + tsigTimersOnly bool } // In performs a [AI]XFR request (depends on the message's Qtype). It returns @@ -69,27 +69,27 @@ func (t *Transfer) InAxfr(id uint16, c chan *Envelope) { c <- &Envelope{nil, err} return } - if id != q.Id { + if id != in.Id { c <- &Envelope{in.Answer, ErrId} return } if first { - if !checkSOA(in, true) { + if !isSOAFirst(in) { c <- &Envelope{in.Answer, ErrSoa} return } first = !first // only one answer that is SOA, receive more if len(in.Answer) == 1 { - w.tsigTimersOnly = true + t.tsigTimersOnly = true c <- &Envelope{in.Answer, nil} continue } } if !first { - w.tsigTimersOnly = true // Subsequent envelopes use this. - if checkSOA(in, false) { + t.tsigTimersOnly = true // Subsequent envelopes use this. + if isSOALast(in) { c <- &Envelope{in.Answer, nil} return } @@ -99,6 +99,7 @@ func (t *Transfer) InAxfr(id uint16, c chan *Envelope) { panic("dns: not reached") } +/* // re-read 'n stuff must be pushed down timeout = dnsTimeout if t.ReadTimeout != 0 { @@ -112,11 +113,24 @@ func (t *Transfer) InAxfr(id uint16, c chan *Envelope) { co.SetWriteDeadline(time.Now().Add(dnsTimeout)) defer co.Close() return nil -} +*/ func (t *Transfer) Out(w ResponseWriter, q *Msg, a string) (chan *Envelope, error) { ch := make(chan *Envelope) - + r := new(Msg) + r.SetReply(q) + r.Authoritative = true + go func() { + for x := range ch { + // assume it fits TODO(miek): fix + r.Answer = append(r.Answer, x.RR...) + if err := w.WriteMsg(r); err != nil { + return + } + } +// w.TsigTimersOnly(true) +// rep.Answer = nil + }() return ch, nil } @@ -137,7 +151,7 @@ func (t *Transfer) ReadMsg() (*Msg, error) { return m, ErrSecret } // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.requestMAC, t.timersOnly) + err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) } return m, err } @@ -149,7 +163,7 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) { if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { return ErrSecret } - out, t.requestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.requestMAC, t.timersOnly) + out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) } else { out, err = m.Pack() } @@ -211,51 +225,18 @@ func (w *reply) ixfrIn(q *Msg, c chan *Envelope) { } panic("dns: not reached") } +*/ -/* - -func checkFirstSOA(in *Msg) bool { +func isSOAFirst(in *Msg) bool { if len(in.Answer) > 0 { return in.Answer[0].Header().Rrtype == TypeSOA } return false } -func checkLastSOA(in *Msg) bool { +func isSOALast(in *Msg) bool { if len(in.Answer) > 0 { return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA } return false } - - -/* -func TransferOut(w ResponseWriter, q *Msg, c chan *Envelope, e *error) error { - switch q.Question[0].Qtype { - case TypeAXFR, TypeIXFR: - go xfrOut(w, q, c, e) - return nil - default: - return nil - } - panic("dns: not reached") -} - -// TODO(mg): count the RRs and the resulting size. -func xfrOut(w ResponseWriter, req *Msg, c chan *Envelope, e *error) { - rep := new(Msg) - rep.SetReply(req) - rep.Authoritative = true - - for x := range c { - // assume it fits - rep.Answer = append(rep.Answer, x.RR...) - if err := w.WriteMsg(rep); e != nil { - *e = err - return - } - w.TsigTimersOnly(true) - rep.Answer = nil - } -} -*/