calculate tsig in the normal query path too
This commit is contained in:
parent
39b9f93167
commit
3089111fa4
14
client.go
14
client.go
|
@ -101,6 +101,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
|
||||||
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
|
||||||
var (
|
var (
|
||||||
n int
|
n int
|
||||||
|
mac string
|
||||||
out []byte
|
out []byte
|
||||||
)
|
)
|
||||||
w := new(reply)
|
w := new(reply)
|
||||||
|
@ -108,7 +109,7 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
|
||||||
if _, ok := w.client.TsigSecret[t.Hdr.Name]; !ok {
|
if _, ok := w.client.TsigSecret[t.Hdr.Name]; !ok {
|
||||||
return nil, 0, ErrSecret
|
return nil, 0, ErrSecret
|
||||||
}
|
}
|
||||||
out, _, err = TsigGenerate(m, c.TsigSecret[t.Hdr.Name], "", false)
|
out, mac, err = TsigGenerate(m, c.TsigSecret[t.Hdr.Name], "", false)
|
||||||
} else {
|
} else {
|
||||||
out, err = m.Pack()
|
out, err = m.Pack()
|
||||||
}
|
}
|
||||||
|
@ -138,16 +139,13 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
|
||||||
}
|
}
|
||||||
if t := r.IsTsig(); t != nil {
|
if t := r.IsTsig(); t != nil {
|
||||||
secret := t.Hdr.Name
|
secret := t.Hdr.Name
|
||||||
if _, ok := client.TsigSecret[secret]; !ok {
|
if _, ok := c.TsigSecret[secret]; !ok {
|
||||||
w.tsigStatus = ErrSecret
|
return r, w.rtt, ErrSecret
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
// Need to work on the original message p, as that was used to calculate the tsig.
|
// Need to work on the original message p, as that was used to calculate the tsig.
|
||||||
w.tsigStatus = TsigVerify(p, w.client.TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
|
err = TsigVerify(in, c.TsigSecret[secret], mac, false)
|
||||||
}
|
}
|
||||||
|
return r, w.rtt, err
|
||||||
|
|
||||||
return r, w.rtt, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial connects to the address addr for the network set in c.Net
|
// dial connects to the address addr for the network set in c.Net
|
||||||
|
|
12
zone.go
12
zone.go
|
@ -318,16 +318,8 @@ func (z *Zone) Remove(r RR) error {
|
||||||
func (z *Zone) RemoveName(s string) error {
|
func (z *Zone) RemoveName(s string) error {
|
||||||
key := toRadixName(s)
|
key := toRadixName(s)
|
||||||
z.Lock()
|
z.Lock()
|
||||||
zd, exact := z.Radix.Find(key)
|
defer z.Unlock()
|
||||||
if !exact {
|
z.Radix.Remove(key)
|
||||||
defer z.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
z.Unlock()
|
|
||||||
zd.Value.(*ZoneData).mutex.Lock()
|
|
||||||
defer zd.Value.(*ZoneData).mutex.Unlock()
|
|
||||||
zd.Value = nil // remove the lot
|
|
||||||
|
|
||||||
if len(s) > 1 && s[0] == '*' && s[1] == '.' {
|
if len(s) > 1 && s[0] == '*' && s[1] == '.' {
|
||||||
z.Wildcard--
|
z.Wildcard--
|
||||||
if z.Wildcard < 0 {
|
if z.Wildcard < 0 {
|
||||||
|
|
Loading…
Reference in New Issue