diff --git a/client.go b/client.go index d982de7f..1a1370c0 100644 --- a/client.go +++ b/client.go @@ -164,8 +164,8 @@ func (w *reply) receive() (*Msg, error) { } w.rtt = time.Since(w.t) m.Size = n - if m.IsTsig() { - secret := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name + if t := m.IsTsig(); t != nil { + secret := t.Hdr.Name if _, ok := w.client.TsigSecret[secret]; !ok { w.tsigStatus = ErrSecret return m, nil @@ -249,9 +249,9 @@ func (w *reply) readClient(p []byte) (n int, err error) { // signature is calculated. func (w *reply) send(m *Msg) (err error) { var out []byte - if m.IsTsig() { + if t := m.IsTsig(); t != nil { mac := "" - name := m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name + name := t.Hdr.Name if _, ok := w.client.TsigSecret[name]; !ok { return ErrSecret } diff --git a/defaults.go b/defaults.go index b5f6b61d..91a1ea80 100644 --- a/defaults.go +++ b/defaults.go @@ -124,23 +124,26 @@ func (dns *Msg) SetEdns0(udpsize uint16, do bool) *Msg { } // IsTsig checks if the message has a TSIG record as the last record -// in the additional section. -func (dns *Msg) IsTsig() (ok bool) { +// in the additional section. It returns the TSIG record found or nil. +func (dns *Msg) IsTsig() *RR_TSIG { if len(dns.Extra) > 0 { - return dns.Extra[len(dns.Extra)-1].Header().Rrtype == TypeTSIG + if dns.Extra[len(dns.Extra)-1].Header().Rrtype == TypeTSIG { + return dns.Extra[len(dns.Extra)-1].(*RR_TSIG) + } } - return + return nil } // IsEdns0 checks if the message has a EDNS0 (OPT) record, any EDNS0 -// record in the additional section will do. -func (dns *Msg) IsEdns0() (ok bool) { +// record in the additional section will do. It returns the OPT record +// found or nil. +func (dns *Msg) IsEdns0() *RR_OPT { for _, r := range dns.Extra { if r.Header().Rrtype == TypeOPT { - return true + return r.(*RR_OPT) } } - return + return nil } // IsDomainName checks if s is a valid domainname, it returns diff --git a/msg.go b/msg.go index c8ced0da..e01eecd2 100644 --- a/msg.go +++ b/msg.go @@ -79,9 +79,8 @@ type MsgHdr struct { // The layout of a DNS message. type Msg struct { MsgHdr - Compress bool // If true, the message will be compressed when converted to wire format. - Size int // Number of octects in the message received from the wire. - // Remote addr? back + Compress bool // If true, the message will be compressed when converted to wire format. + Size int // Number of octects in the message received from the wire. Question []Question // Holds the RR(s) of the question section. Answer []RR // Holds the RR(s) of the answer section. Ns []RR // Holds the RR(s) of the authority section. diff --git a/parse_test.go b/parse_test.go index c66c8eed..5bb186ff 100644 --- a/parse_test.go +++ b/parse_test.go @@ -380,14 +380,17 @@ func TestParseFailure(t *testing.T) { // A bit useless, how to use b.N? func BenchmarkZoneParsing(b *testing.B) { + b.StopTimer() f, err := os.Open("t/miek.nl.signed_test") if err != nil { return } defer f.Close() - to := ParseZone(f, "", "t/miek.nl.signed_test") - for x := range to { - x = x + b.StartTimer() + for i := 0; i < b.N; i++ { + to := ParseZone(f, "", "t/miek.nl.signed_test") + for _ = range to { + } } } diff --git a/server.go b/server.go index a8bbd1c0..7dc02cab 100644 --- a/server.go +++ b/server.go @@ -330,8 +330,8 @@ func (c *conn) serve() { } w.tsigStatus = nil - if req.IsTsig() { - secret := req.Extra[len(req.Extra)-1].(*RR_TSIG).Hdr.Name + if t := req.IsTsig(); t != nil { + secret := t.Hdr.Name if _, ok := w.conn.tsigSecret[secret]; !ok { w.tsigStatus = ErrKeyAlg } @@ -360,8 +360,8 @@ func (w *response) Write(m *Msg) (err error) { if m == nil { return &Error{Err: "nil message"} } - if m.IsTsig() { - data, w.tsigRequestMAC, err = TsigGenerate(m, w.conn.tsigSecret[m.Extra[len(m.Extra)-1].(*RR_TSIG).Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) + if t := m.IsTsig(); t != nil { + data, w.tsigRequestMAC, err = TsigGenerate(m, w.conn.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) if err != nil { return err } diff --git a/tsig.go b/tsig.go index 6ce2a1b6..43ed9fb5 100644 --- a/tsig.go +++ b/tsig.go @@ -154,7 +154,7 @@ type timerWireFmt struct { // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { - if !m.IsTsig() { + if m.IsTsig() == nil { panic("TSIG not last RR in additional") } // If we barf here, the caller is to blame diff --git a/zone.go b/zone.go index b78e9b57..077c77a8 100644 --- a/zone.go +++ b/zone.go @@ -5,25 +5,27 @@ package dns import ( "github.com/miekg/radix" "strings" + "sync" ) -// Zone represents a DNS zone. Currently there is no locking implemented. +// Zone represents a DNS zone. The structure is safe for concurrent access. type Zone struct { Origin string // Origin of the zone Wildcard int // Whenever we see a wildcard name, this is incremented *radix.Radix // Zone data + mutex *sync.RWMutex } -// ZoneData holds all the RRs having their ownername equal to Name. +// ZoneData holds all the RRs having their owner name equal to Name. type ZoneData struct { Name string // Domain name for this node RR map[uint16][]RR // Map of the RR type to the RR Signatures map[uint16][]*RR_RRSIG // DNSSEC signatures for the RRs, stored under type covered - // Always false, except for NSsets that differ from z.Origin - NonAuth bool + NonAuth bool // Always false, except for NSsets that differ from z.Origin + mutex *sync.RWMutex } -// toRadixName reverses a domainname so that when we store it in the radix tree +// toRadixName reverses a domain name so that when we store it in the radix tree // we preserve the nsec ordering of the zone (this idea was stolen from NSD). // each label is also lowercased. func toRadixName(d string) string { @@ -46,21 +48,24 @@ func NewZone(origin string) *Zone { return nil } z := new(Zone) + z.mutex = new(sync.RWMutex) z.Origin = Fqdn(origin) z.Radix = radix.New() return z } -// Insert inserts an RR into the zone. Duplicate data overwrites the old data without -// warning. +// Insert inserts an RR into the zone. There is no check for duplicate data, although +// Remove will remove all duplicates. func (z *Zone) Insert(r RR) error { if !IsSubDomain(z.Origin, r.Header().Name) { return &Error{Err: "out of zone data", Name: r.Header().Name} } key := toRadixName(r.Header().Name) + z.mutex.Lock() zd := z.Radix.Find(key) if zd == nil { + defer z.mutex.Unlock() // Check if its a wildcard name if len(r.Header().Name) > 1 && r.Header().Name[0] == '*' && r.Header().Name[1] == '.' { z.Wildcard++ @@ -69,6 +74,7 @@ func (z *Zone) Insert(r RR) error { zd.Name = r.Header().Name zd.RR = make(map[uint16][]RR) zd.Signatures = make(map[uint16][]*RR_RRSIG) + zd.mutex = new(sync.RWMutex) switch t := r.Header().Rrtype; t { case TypeRRSIG: sigtype := r.(*RR_RRSIG).TypeCovered @@ -85,6 +91,9 @@ func (z *Zone) Insert(r RR) error { z.Radix.Insert(key, zd) return nil } + z.mutex.Unlock() + zd.Value.(*ZoneData).mutex.Lock() + defer zd.Value.(*ZoneData).mutex.Unlock() // Name already there switch t := r.Header().Rrtype; t { case TypeRRSIG: @@ -102,9 +111,43 @@ func (z *Zone) Insert(r RR) error { } // Remove removes the RR r from the zone. If there RR can not be found, -// this is a no-op. TODO(mg): not implemented. +// this is a no-op. func (z *Zone) Remove(r RR) error { - // Wildcards + key := toRadixName(r.Header().Name) + z.mutex.Lock() + zd := z.Radix.Find(key) + if zd == nil { + defer z.mutex.Unlock() + return nil + } + z.mutex.Unlock() + zd.Value.(*ZoneData).mutex.Lock() + defer zd.Value.(*ZoneData).mutex.Unlock() + remove := false + switch t := r.Header().Rrtype; t { + case TypeRRSIG: + sigtype := r.(*RR_RRSIG).TypeCovered + for i, zr := range zd.Value.(*ZoneData).RR[sigtype] { + if r == zr { + zd.Value.(*ZoneData).RR[sigtype] = append(zd.Value.(*ZoneData).RR[sigtype][:i], zd.Value.(*ZoneData).RR[sigtype][i+1:]...) + remove = true + } + } + default: + for i, zr := range zd.Value.(*ZoneData).RR[t] { + if r == zr { + zd.Value.(*ZoneData).RR[t] = append(zd.Value.(*ZoneData).RR[t][:i], zd.Value.(*ZoneData).RR[t][i+1:]...) + remove = true + } + } + } + if remove && len(r.Header().Name) > 1 && r.Header().Name[0] == '*' && r.Header().Name[1] == '.' { + z.Wildcard-- + if z.Wildcard < 0 { + z.Wildcard = 0 + } + } + // TODO(mg): what to do if the whole structure is empty? Set it to nil? return nil } diff --git a/zone_test.go b/zone_test.go new file mode 100644 index 00000000..341c0f0a --- /dev/null +++ b/zone_test.go @@ -0,0 +1,10 @@ +package dns + +import ( + "testing" +) + +func TestInsert(t *testing.T) { +} +func TestRemove(t *testing.T) { +}