diff --git a/internal/repository/lock_file.go b/internal/repository/lock_file.go index 45685a165..9e467f3db 100644 --- a/internal/repository/lock_file.go +++ b/internal/repository/lock_file.go @@ -304,25 +304,32 @@ func delayedCancelContext(parentCtx context.Context, delay time.Duration) (conte // timestamp. Afterwards the old lock is removed. func (l *lockHandle) refresh(ctx context.Context) error { debug.Log("refreshing lock %v", l.lockID) - l.mu.Lock() - l.Time = time.Now() - l.mu.Unlock() - id, err := l.createLock(ctx) + id, err := l.createReplacementLock(ctx) if err != nil { return err } + ctx, cancel := delayedCancelContext(ctx, unlockCancelDelay) + defer cancel() + return l.adoptReplacementLock(ctx, id) +} + +func (l *lockHandle) createReplacementLock(ctx context.Context) (restic.ID, error) { + l.mu.Lock() + l.Time = time.Now() + l.mu.Unlock() + return l.createLock(ctx) +} + +func (l *lockHandle) adoptReplacementLock(ctx context.Context, id restic.ID) error { l.mu.Lock() defer l.mu.Unlock() debug.Log("new lock ID %v", id) - oldLockID := l.lockID + oldID := *l.lockID l.lockID = &id - ctx, cancel := delayedCancelContext(ctx, unlockCancelDelay) - defer cancel() - - return l.repo.RemoveUnpacked(ctx, restic.LockFile, *oldLockID) + return l.repo.RemoveUnpacked(ctx, restic.LockFile, oldID) } // refreshStaleLock is an extended variant of refresh that can also refresh stale lock files. @@ -338,10 +345,7 @@ func (l *lockHandle) refreshStaleLock(ctx context.Context) error { return errRemovedLock } - l.mu.Lock() - l.Time = time.Now() - l.mu.Unlock() - id, err := l.createLock(ctx) + id, err := l.createReplacementLock(ctx) if err != nil { return err } @@ -365,14 +369,7 @@ func (l *lockHandle) refreshStaleLock(ctx context.Context) error { return errRemovedLock } - l.mu.Lock() - defer l.mu.Unlock() - - debug.Log("new lock ID %v", id) - oldLockID := l.lockID - l.lockID = &id - - return l.repo.RemoveUnpacked(ctx, restic.LockFile, *oldLockID) + return l.adoptReplacementLock(ctx, id) } func (l *lockHandle) checkExistence(ctx context.Context) (bool, error) {