diff --git a/zone.go b/zone.go index adc63757..08ecc4a7 100644 --- a/zone.go +++ b/zone.go @@ -5,13 +5,15 @@ package dns import ( "github.com/miekg/radix" "strings" + "sync" ) -// Zone represents a DNS zone. Currently there is no locking implemented. +// Zone represents a DNS zone. type Zone struct { Origin string // Origin of the zone Wildcard int // Whenever we see a wildcard name, this is incremented *radix.Radix // Zone data + mutex *sync.RWMutex } // ZoneData holds all the RRs having their owner name equal to Name. @@ -20,6 +22,7 @@ type ZoneData struct { RR map[uint16][]RR // Map of the RR type to the RR Signatures map[uint16][]*RR_RRSIG // DNSSEC signatures for the RRs, stored under type covered NonAuth bool // Always false, except for NSsets that differ from z.Origin + mutex *sync.RWMutex } // toRadixName reverses a domain name so that when we store it in the radix tree @@ -45,6 +48,7 @@ func NewZone(origin string) *Zone { return nil } z := new(Zone) + z.mutex = new(sync.RWMutex) z.Origin = Fqdn(origin) z.Radix = radix.New() return z @@ -58,8 +62,10 @@ func (z *Zone) Insert(r RR) error { } key := toRadixName(r.Header().Name) + z.mutex.Lock() zd := z.Radix.Find(key) if zd == nil { + defer z.mutex.Unlock() // Check if its a wildcard name if len(r.Header().Name) > 1 && r.Header().Name[0] == '*' && r.Header().Name[1] == '.' { z.Wildcard++ @@ -68,6 +74,7 @@ func (z *Zone) Insert(r RR) error { zd.Name = r.Header().Name zd.RR = make(map[uint16][]RR) zd.Signatures = make(map[uint16][]*RR_RRSIG) + zd.mutex = new(sync.RWMutex) switch t := r.Header().Rrtype; t { case TypeRRSIG: sigtype := r.(*RR_RRSIG).TypeCovered @@ -84,6 +91,9 @@ func (z *Zone) Insert(r RR) error { z.Radix.Insert(key, zd) return nil } + z.mutex.Unlock() + zd.Value.(*ZoneData).mutex.Lock() + defer zd.Value.(*ZoneData).mutex.Unlock() // Name already there switch t := r.Header().Rrtype; t { case TypeRRSIG: @@ -104,10 +114,15 @@ func (z *Zone) Insert(r RR) error { // this is a no-op. func (z *Zone) Remove(r RR) error { key := toRadixName(r.Header().Name) + z.mutex.Lock() zd := z.Radix.Find(key) if zd == nil { + defer z.mutex.Unlock() return nil } + z.mutex.Unlock() + zd.Value.(*ZoneData).mutex.Lock() + defer zd.Value.(*ZoneData).mutex.Unlock() remove := false switch t := r.Header().Rrtype; t { case TypeRRSIG: