calculate tsig in the normal query path too

This commit is contained in:
Miek Gieben 2012-10-16 08:50:53 +02:00
parent 39b9f93167
commit 3089111fa4
2 changed files with 8 additions and 18 deletions

View File

@ -101,6 +101,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var ( var (
n int n int
mac string
out []byte out []byte
) )
w := new(reply) w := new(reply)
@ -108,7 +109,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
if _, ok := w.client.TsigSecret[t.Hdr.Name]; !ok { if _, ok := w.client.TsigSecret[t.Hdr.Name]; !ok {
return nil, 0, ErrSecret return nil, 0, ErrSecret
} }
out, _, err = TsigGenerate(m, c.TsigSecret[t.Hdr.Name], "", false) out, mac, err = TsigGenerate(m, c.TsigSecret[t.Hdr.Name], "", false)
} else { } else {
out, err = m.Pack() out, err = m.Pack()
} }
@ -138,16 +139,13 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
} }
if t := r.IsTsig(); t != nil { if t := r.IsTsig(); t != nil {
secret := t.Hdr.Name secret := t.Hdr.Name
if _, ok := client.TsigSecret[secret]; !ok { if _, ok := c.TsigSecret[secret]; !ok {
w.tsigStatus = ErrSecret return r, w.rtt, ErrSecret
return m, nil
} }
// Need to work on the original message p, as that was used to calculate the tsig. // Need to work on the original message p, as that was used to calculate the tsig.
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly) err = TsigVerify(in, c.TsigSecret[secret], mac, false)
} }
return r, w.rtt, err
return r, w.rtt, nil
} }
// dial connects to the address addr for the network set in c.Net // dial connects to the address addr for the network set in c.Net

12
zone.go
View File

@ -318,16 +318,8 @@ func (z *Zone) Remove(r RR) error {
func (z *Zone) RemoveName(s string) error { func (z *Zone) RemoveName(s string) error {
key := toRadixName(s) key := toRadixName(s)
z.Lock() z.Lock()
zd, exact := z.Radix.Find(key) defer z.Unlock()
if !exact { z.Radix.Remove(key)
defer z.Unlock()
return nil
}
z.Unlock()
zd.Value.(*ZoneData).mutex.Lock()
defer zd.Value.(*ZoneData).mutex.Unlock()
zd.Value = nil // remove the lot
if len(s) > 1 && s[0] == '*' && s[1] == '.' { if len(s) > 1 && s[0] == '*' && s[1] == '.' {
z.Wildcard-- z.Wildcard--
if z.Wildcard < 0 { if z.Wildcard < 0 {