From 0f4236cb396d0e63c00c39335459577afca6b52f Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 14 Jun 2026 10:50:38 +0200 Subject: [PATCH] repository: return unlock func from LockRepo Drop the Unlocker interface and return the unlock callback directly from LockRepo, simplifying callers that only need to defer unlock(). --- cmd/restic/lock.go | 6 +---- internal/repository/lock.go | 12 +++------- internal/repository/lock_test.go | 38 ++++++++++++++++---------------- 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/cmd/restic/lock.go b/cmd/restic/lock.go index 8264c6b9e..eb2fc47ed 100644 --- a/cmd/restic/lock.go +++ b/cmd/restic/lock.go @@ -16,9 +16,7 @@ func internalOpenWithLocked(ctx context.Context, gopts global.Options, dryRun bo unlock := func() {} if !dryRun { - var lock repository.Unlocker - - lock, ctx, err = repository.LockRepo(ctx, repo, exclusive, gopts.RetryLock, func(msg string) { + unlock, ctx, err = repository.LockRepo(ctx, repo, exclusive, gopts.RetryLock, func(msg string) { if !gopts.JSON { printer.P("%s", msg) } @@ -26,8 +24,6 @@ func internalOpenWithLocked(ctx context.Context, gopts global.Options, dryRun bo if err != nil { return nil, nil, nil, err } - - unlock = lock.Unlock } else { repo.SetDryRun() } diff --git a/internal/repository/lock.go b/internal/repository/lock.go index 6f43b5e6a..9a3e8aba5 100644 --- a/internal/repository/lock.go +++ b/internal/repository/lock.go @@ -13,10 +13,6 @@ import ( "github.com/restic/restic/internal/restic" ) -type Unlocker interface { - Unlock() -} - type unlocker struct { lock *lockHandle cancel context.CancelFunc @@ -28,8 +24,6 @@ func (l *unlocker) Unlock() { l.refreshWG.Wait() } -var _ Unlocker = &unlocker{} - type locker struct { retrySleepStart time.Duration retrySleepMax time.Duration @@ -50,11 +44,11 @@ var lockerInst = &locker{ // LockRepo acquires a repository lock. The returned context is cancelled when // Unlock is called; cancelling the original context stops lock refresh. -func LockRepo(ctx context.Context, repo *Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (Unlocker, context.Context, error) { +func LockRepo(ctx context.Context, repo *Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (func(), context.Context, error) { return lockerInst.Lock(ctx, repo, exclusive, retryLock, printRetry, logger) } -func (l *locker) Lock(ctx context.Context, r *Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (*unlocker, context.Context, error) { +func (l *locker) Lock(ctx context.Context, r *Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (func(), context.Context, error) { var lock *lockHandle var err error @@ -113,7 +107,7 @@ retryLoop: go l.refreshLocks(ctx, repo.be, unlocker, refreshChan, forceRefreshChan, logger) go l.monitorLockRefresh(ctx, unlocker, refreshChan, forceRefreshChan, logger) - return unlocker, ctx, nil + return unlocker.Unlock, ctx, nil } func minDuration(a, b time.Duration) time.Duration { diff --git a/internal/repository/lock_test.go b/internal/repository/lock_test.go index 6f29b7a18..847be76e5 100644 --- a/internal/repository/lock_test.go +++ b/internal/repository/lock_test.go @@ -34,19 +34,19 @@ func openLockTestRepo(t *testing.T, wrapper backendWrapper) (*Repository, backen return TestOpenBackend(t, be), be } -func checkedLockRepo(ctx context.Context, t *testing.T, repo *Repository, lockerInst *locker, retryLock time.Duration) (Unlocker, context.Context) { - lock, wrappedCtx, err := lockerInst.Lock(ctx, repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) +func checkedLockRepo(ctx context.Context, t *testing.T, repo *Repository, lockerInst *locker, retryLock time.Duration) (func(), context.Context) { + unlock, wrappedCtx, err := lockerInst.Lock(ctx, repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) rtest.OK(t, err) rtest.OK(t, wrappedCtx.Err()) - return lock, wrappedCtx + return unlock, wrappedCtx } func TestLock(t *testing.T) { t.Parallel() repo, _ := openLockTestRepo(t, nil) - lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, lockerInst, 0) - lock.Unlock() + unlock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, lockerInst, 0) + unlock() if wrappedCtx.Err() == nil { t.Fatal("unlock did not cancel context") } @@ -65,7 +65,7 @@ func TestLockCancel(t *testing.T) { } // Unlock should not crash - lock.Unlock() + lock() } func TestLockConflict(t *testing.T) { @@ -73,9 +73,9 @@ func TestLockConflict(t *testing.T) { repo, be := openLockTestRepo(t, nil) repo2 := TestOpenBackend(t, be) - lock, _, err := LockRepo(context.Background(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) + unlock, _, err := LockRepo(context.Background(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) rtest.OK(t, err) - defer lock.Unlock() + defer unlock() _, _, err = LockRepo(context.Background(), repo2, false, 0, func(msg string) {}, func(format string, args ...interface{}) {}) if err == nil { t.Fatal("second lock should have failed") @@ -109,7 +109,7 @@ func TestLockFailedRefresh(t *testing.T) { refreshInterval: 20 * time.Millisecond, refreshabilityTimeout: 100 * time.Millisecond, } - lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) + unlock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) select { case <-wrappedCtx.Done(): @@ -118,7 +118,7 @@ func TestLockFailedRefresh(t *testing.T) { t.Fatal("failed lock refresh did not cause context cancellation") } // Unlock should not crash - lock.Unlock() + unlock() } type loggingBackend struct { @@ -150,7 +150,7 @@ func TestLockSuccessfulRefresh(t *testing.T) { refreshInterval: 60 * time.Millisecond, refreshabilityTimeout: 500 * time.Millisecond, } - lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) + unlock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) select { case <-wrappedCtx.Done(): @@ -167,7 +167,7 @@ func TestLockSuccessfulRefresh(t *testing.T) { // expected lock refresh to work } // Unlock should not crash - lock.Unlock() + unlock() } type slowBackend struct { @@ -201,7 +201,7 @@ func TestLockSuccessfulStaleRefresh(t *testing.T) { refreshabilityTimeout: 50 * time.Millisecond, } - lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) + unlock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, li, 0) // delay lock refreshing long enough that the lock would expire sb.m.Lock() sb.sleep = li.refreshabilityTimeout + li.refreshInterval @@ -230,7 +230,7 @@ func TestLockSuccessfulStaleRefresh(t *testing.T) { } // Unlock should not crash - lock.Unlock() + unlock() } func TestLockWaitTimeout(t *testing.T) { @@ -239,7 +239,7 @@ func TestLockWaitTimeout(t *testing.T) { elock, _, err := LockRepo(context.TODO(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) rtest.OK(t, err) - defer elock.Unlock() + defer elock() retryLock := 200 * time.Millisecond @@ -261,7 +261,7 @@ func TestLockWaitCancel(t *testing.T) { elock, _, err := LockRepo(context.TODO(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) rtest.OK(t, err) - defer elock.Unlock() + defer elock() retryLock := 200 * time.Millisecond cancelAfter := 40 * time.Millisecond @@ -292,12 +292,12 @@ func TestLockWaitSuccess(t *testing.T) { unlockAfter := 40 * time.Millisecond time.AfterFunc(unlockAfter, func() { - elock.Unlock() + elock() }) - lock, _, err := LockRepo(context.TODO(), repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) + unlock, _, err := LockRepo(context.TODO(), repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) rtest.OK(t, err) - lock.Unlock() + unlock() } func createFakeLock(repo *Repository, t time.Time, pid int) (restic.ID, error) {