diff --git a/lib/kv/bolt.go b/lib/kv/bolt.go index 44a59da7e..4b0efc7dc 100644 --- a/lib/kv/bolt.go +++ b/lib/kv/bolt.go @@ -43,8 +43,9 @@ type DB struct { } var ( - dbMap = map[string]*DB{} - dbMut = sync.Mutex{} + dbMap = map[string]*DB{} + dbMut sync.Mutex + atExit bool ) // Supported returns true on supported OSes @@ -66,7 +67,9 @@ func makeName(facility string, f fs.Fs) string { // Start a new key-value database func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) { - if db := Get(facility, f); db != nil { + dbMut.Lock() + defer dbMut.Unlock() + if db := lockedGet(facility, f); db != nil { return db, nil } @@ -101,44 +104,29 @@ func Start(ctx context.Context, facility string, f fs.Fs) (*DB, error) { return nil, errors.Wrapf(err, "cannot open db: %s", db.path) } - // Initialization above was performed without locks.. - dbMut.Lock() - defer dbMut.Unlock() - if dbOther := dbMap[name]; dbOther != nil { - // Races between concurrent Start's are rare but possible, the 1st one wins. - _ = db.close() - return dbOther, nil - } - go db.loop() // Start queue handling + dbMap[name] = db + go db.loop() return db, nil } // Get returns database record for given filesystem and facility func Get(facility string, f fs.Fs) *DB { - name := makeName(facility, f) dbMut.Lock() + defer dbMut.Unlock() + return lockedGet(facility, f) +} + +func lockedGet(facility string, f fs.Fs) *DB { + name := makeName(facility, f) db := dbMap[name] if db != nil { db.mu.Lock() db.refs++ db.mu.Unlock() } - dbMut.Unlock() return db } -// free database record -func (db *DB) free() { - dbMut.Lock() - db.mu.Lock() - db.refs-- - if db.refs <= 0 { - delete(dbMap, db.name) - } - db.mu.Unlock() - dbMut.Unlock() -} - // Path returns database path func (db *DB) Path() string { return db.path } @@ -201,18 +189,28 @@ func (db *DB) close() (err error) { // loop over database operations sequentially func (db *DB) loop() { ctx := context.Background() - for db.queue != nil { + var req *request + quit := false + for !quit { select { - case req := <-db.queue: - req.handle(ctx, db) - _ = db.idleTimer.Reset(db.idleTime) + case req = <-db.queue: + if quit = req.handle(ctx, db); !quit { + req.wg.Done() + _ = db.idleTimer.Reset(db.idleTime) + } case <-db.idleTimer.C: _ = db.close() case <-db.lockTimer.C: _ = db.close() } } - db.free() + db.queue = nil + if !atExit { + dbMut.Lock() + delete(dbMap, db.name) + dbMut.Unlock() + } + req.wg.Done() } // Do a key-value operation and return error when done @@ -239,8 +237,10 @@ type request struct { } // handle a key-value request with given DB -func (r *request) handle(ctx context.Context, db *DB) { +// returns true as a signal to quit the loop +func (r *request) handle(ctx context.Context, db *DB) bool { db.mu.Lock() + defer db.mu.Unlock() if op, stop := r.op.(*opStop); stop { r.err = db.close() if op.remove { @@ -248,12 +248,11 @@ func (r *request) handle(ctx context.Context, db *DB) { r.err = err } } - db.queue = nil - } else { - r.err = db.execute(ctx, r.op, r.wr) + db.refs-- + return db.refs <= 0 } - db.mu.Unlock() - r.wg.Done() + r.err = db.execute(ctx, r.op, r.wr) + return false } // execute a key-value DB operation @@ -302,11 +301,15 @@ func (*opStop) Do(context.Context, Bucket) error { return nil } -// Exit stops all databases +// Exit immediately stops all databases func Exit() { dbMut.Lock() + atExit = true for _, s := range dbMap { + s.refs = 0 _ = s.Stop(false) } + dbMap = map[string]*DB{} + atExit = false dbMut.Unlock() } diff --git a/lib/kv/internal_test.go b/lib/kv/internal_test.go new file mode 100644 index 000000000..ae40c6ed4 --- /dev/null +++ b/lib/kv/internal_test.go @@ -0,0 +1,71 @@ +//go:build !plan9 && !js +// +build !plan9,!js + +package kv + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKvConcurrency(t *testing.T) { + require.Equal(t, 0, len(dbMap), "no databases can be started initially") + + const threadNum = 5 + const facility = "test" + var wg sync.WaitGroup + ctx := context.Background() + results := make([]*DB, threadNum) + wg.Add(threadNum) + for i := 0; i < threadNum; i++ { + go func(i int) { + db, err := Start(ctx, "test", nil) + require.NoError(t, err) + require.NotNil(t, db) + results[i] = db + wg.Done() + }(i) + } + wg.Wait() + + // must have a single multi-referenced db + db := results[0] + assert.Equal(t, 1, len(dbMap)) + assert.Equal(t, threadNum, db.refs) + for i := 0; i < threadNum; i++ { + assert.Equal(t, db, results[i]) + } + + for i := 0; i < threadNum; i++ { + assert.Equal(t, 1, len(dbMap)) + err := db.Stop(false) + assert.NoError(t, err, "unexpected error %v at retry %d", err, i) + } + + assert.Equal(t, 0, len(dbMap), "must be closed in the end") + err := db.Stop(false) + assert.ErrorIs(t, err, ErrInactive, "missing expected stop indication") +} + +func TestKvExit(t *testing.T) { + require.Equal(t, 0, len(dbMap), "no databases can be started initially") + const dbNum = 5 + const openNum = 2 + ctx := context.Background() + for i := 0; i < dbNum; i++ { + facility := fmt.Sprintf("test-%d", i) + for j := 0; j <= i; j++ { + db, err := Start(ctx, facility, nil) + require.NoError(t, err) + require.NotNil(t, db) + } + } + assert.Equal(t, dbNum, len(dbMap)) + Exit() + assert.Equal(t, 0, len(dbMap)) +}