replace FindFilteredSnapshots() with (*data.SnapshotFilter).FindAll() (#21912)

Co-authored-by: Michael Eischer <michael.eischer@fau.de>
This commit is contained in:
Winfried Plappert
2026-06-28 14:31:06 +01:00
committed by GitHub
parent 75de8b54e6
commit 62991338ea
9 changed files with 103 additions and 77 deletions
+32 -13
View File
@@ -72,14 +72,22 @@ func (opts *CopyOptions) AddFlags(f *pflag.FlagSet) {
initMultiSnapshotFilter(f, &opts.SnapshotFilter, true)
}
var errSentinelEndIteration = errors.New("end iteration")
// collectAllSnapshots: select all snapshot trees to be copied
func collectAllSnapshots(ctx context.Context, opts CopyOptions,
srcSnapshotLister restic.Lister, srcRepo restic.Repository,
dstSnapshotByOriginal map[restic.ID][]*data.Snapshot, args []string, printer restic.Printer,
) iter.Seq[*data.Snapshot] {
return func(yield func(*data.Snapshot) bool) {
for sn := range FindFilteredSnapshots(ctx, srcSnapshotLister, srcRepo, &opts.SnapshotFilter, args, printer) {
) iter.Seq2[*data.Snapshot, error] {
return func(yield func(*data.Snapshot, error) bool) {
err := opts.SnapshotFilter.FindAll(ctx, srcSnapshotLister, srcRepo, args, func(_ string, sn *data.Snapshot, err error) error {
// check whether the destination has a snapshot with the same persistent ID which has similar snapshot fields
if err != nil {
if !yield(nil, err) {
return errSentinelEndIteration
}
return nil
}
srcOriginal := *sn.ID()
if sn.Original != nil {
srcOriginal = *sn.Original
@@ -95,12 +103,16 @@ func collectAllSnapshots(ctx context.Context, opts CopyOptions,
}
}
if isCopy {
continue
return nil
}
}
if !yield(sn) {
return
if !yield(sn, nil) {
return errSentinelEndIteration
}
return nil
})
if err != nil && !errors.Is(err, errSentinelEndIteration) {
yield(nil, err)
}
}
}
@@ -148,15 +160,19 @@ func runCopy(ctx context.Context, opts CopyOptions, gopts global.Options, args [
}
dstSnapshotByOriginal := make(map[restic.ID][]*data.Snapshot)
for sn := range FindFilteredSnapshots(ctx, dstSnapshotLister, dstRepo, &opts.SnapshotFilter, nil, printer) {
err = opts.SnapshotFilter.FindAll(ctx, dstSnapshotLister, dstRepo, nil, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
if sn.Original != nil && !sn.Original.IsNull() {
dstSnapshotByOriginal[*sn.Original] = append(dstSnapshotByOriginal[*sn.Original], sn)
}
// also consider identical snapshot copies
dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn)
}
if ctx.Err() != nil {
return ctx.Err()
return nil
})
if err != nil {
return err
}
selectedSnapshots := collectAllSnapshots(ctx, opts, srcSnapshotLister, srcRepo, dstSnapshotByOriginal, args, printer)
@@ -190,7 +206,7 @@ func similarSnapshots(sna *data.Snapshot, snb *data.Snapshot) bool {
// copyTreeBatched copies multiple snapshots in one go. Snapshots are written after
// data equivalent to at least 10 packfiles was written.
func copyTreeBatched(ctx context.Context, srcRepo *repository.Repository, dstRepo restic.Repository,
selectedSnapshots iter.Seq[*data.Snapshot], printer restic.Printer) error {
selectedSnapshots iter.Seq2[*data.Snapshot, error], printer restic.Printer) error {
// remember already processed trees across all snapshots
visitedTrees := srcRepo.NewAssociatedBlobSet()
@@ -199,7 +215,7 @@ func copyTreeBatched(ctx context.Context, srcRepo *repository.Repository, dstRep
minDuration := 1 * time.Minute
// use pull-based iterator to allow iteration in multiple steps
next, stop := iter.Pull(selectedSnapshots)
next, stop := iter.Pull2(selectedSnapshots)
defer stop()
for {
@@ -210,7 +226,10 @@ func copyTreeBatched(ctx context.Context, srcRepo *repository.Repository, dstRep
// call WithBlobUploader() once and then loop over all selectedSnapshots
err := dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
for batchSize < targetSize || time.Since(startTime) < minDuration {
sn, ok := next()
sn, err, ok := next()
if err != nil {
return err
}
if !ok {
break
}
+8 -4
View File
@@ -675,11 +675,15 @@ func runFind(ctx context.Context, opts FindOptions, gopts global.Options, args [
}
var filteredSnapshots []*data.Snapshot
for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, opts.Snapshots, printer) {
err = opts.SnapshotFilter.FindAll(ctx, snapshotLister, repo, opts.Snapshots, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
filteredSnapshots = append(filteredSnapshots, sn)
}
if ctx.Err() != nil {
return ctx.Err()
return nil
})
if err != nil {
return err
}
sort.Slice(filteredSnapshots, func(i, j int) bool {
+8 -4
View File
@@ -200,11 +200,15 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption
var snapshots data.Snapshots
removeSnIDs := restic.NewIDSet()
for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args, printer) {
err = opts.SnapshotFilter.FindAll(ctx, repo, repo, args, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
snapshots = append(snapshots, sn)
}
if ctx.Err() != nil {
return ctx.Err()
return nil
})
if err != nil {
return err
}
var jsonGroups []*ForgetGroup
+8 -4
View File
@@ -327,7 +327,10 @@ func runRewrite(ctx context.Context, opts RewriteOptions, gopts global.Options,
}
changedCount := 0
for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, args, printer) {
err = opts.SnapshotFilter.FindAll(ctx, snapshotLister, repo, args, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
printer.P("\n%v", sn)
changed, err := rewriteSnapshot(ctx, repo, sn, opts, printer)
if err != nil {
@@ -336,9 +339,10 @@ func runRewrite(ctx context.Context, opts RewriteOptions, gopts global.Options,
if changed {
changedCount++
}
}
if ctx.Err() != nil {
return ctx.Err()
return nil
})
if err != nil {
return err
}
printer.P("")
+9 -4
View File
@@ -90,12 +90,17 @@ func runSnapshots(ctx context.Context, opts SnapshotOptions, gopts global.Option
defer unlock()
var snapshots data.Snapshots
for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args, printer) {
err = opts.SnapshotFilter.FindAll(ctx, repo, repo, args, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
snapshots = append(snapshots, sn)
return nil
})
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
snapshotGroups, grouped, err := data.GroupSnapshots(snapshots, opts.GroupBy)
if err != nil {
return err
+10 -3
View File
@@ -31,13 +31,13 @@ func newStatsCommand(globalOptions *global.Options) *cobra.Command {
Short: "Scan the repository and show basic statistics",
Long: `
The "stats" command walks one or multiple snapshots in a repository
and accumulates statistics about the data stored therein. It reports
and accumulates statistics about the data stored therein. It reports
on the number of unique files and their sizes, according to one of
the counting modes as given by the --mode flag.
It operates on all snapshots matching the selection criteria or all
snapshots if nothing is specified. The special snapshot ID "latest"
is also supported. Some modes make more sense over
is also supported. Some modes make more sense over
just a single snapshot, while others are useful across all snapshots,
depending on what you are trying to calculate.
@@ -134,8 +134,15 @@ func runStats(ctx context.Context, opts StatsOptions, gopts global.Options, args
}
var snapshots data.Snapshots
for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, args, printer) {
err = opts.SnapshotFilter.FindAll(ctx, snapshotLister, repo, args, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
snapshots = append(snapshots, sn)
return nil
})
if err != nil {
return err
}
statsProgress := statsui.NewProgress(term, gopts.Quiet, gopts.JSON, uint64(len(snapshots)))
+9 -6
View File
@@ -159,19 +159,22 @@ func runTag(ctx context.Context, opts TagOptions, gopts global.Options, term ui.
}
}
for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args, printer) {
err = opts.SnapshotFilter.FindAll(ctx, repo, repo, args, func(_ string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
changed, err := changeTags(ctx, repo, sn, opts.SetTags.Flatten(), opts.AddTags.Flatten(), opts.RemoveTags.Flatten(), printFunc)
if err != nil {
printer.E("unable to modify the tags for snapshot ID %q, ignoring: %v", sn.ID(), err)
continue
return nil
}
if changed {
summary.ChangedSnapshots++
}
}
if ctx.Err() != nil {
return ctx.Err()
return nil
})
if err != nil {
return err
}
printSummary(summary)
+1 -33
View File
@@ -1,16 +1,14 @@
package main
import (
"context"
"os"
"github.com/restic/restic/internal/data"
"github.com/restic/restic/internal/restic"
"github.com/spf13/pflag"
)
// initMultiSnapshotFilter is used for commands that work on multiple snapshots
// MUST be combined with FindFilteredSnapshots
// MUST be combined with (*data,SnapshotFilter).FindAll
// MUST be followed by finalizeSnapshotFilter after flag parsing
func initMultiSnapshotFilter(flags *pflag.FlagSet, filt *data.SnapshotFilter, addHostShorthand bool) {
hostShorthand := "H"
@@ -46,33 +44,3 @@ func finalizeSnapshotFilter(filt *data.SnapshotFilter) {
filt.Hosts = nil
}
}
// FindFilteredSnapshots yields Snapshots, either given explicitly by `snapshotIDs` or filtered from the list of all snapshots.
func FindFilteredSnapshots(ctx context.Context, be restic.Lister, loader restic.LoaderUnpacked, f *data.SnapshotFilter, snapshotIDs []string, printer restic.Printer) <-chan *data.Snapshot {
out := make(chan *data.Snapshot)
go func() {
defer close(out)
be, err := restic.MemorizeList(ctx, be, restic.SnapshotFile)
if err != nil {
printer.E("could not load snapshots: %v", err)
return
}
err = f.FindAll(ctx, be, loader, snapshotIDs, func(id string, sn *data.Snapshot, err error) error {
if err != nil {
printer.E("Ignoring %q: %v", id, err)
} else {
select {
case <-ctx.Done():
return ctx.Err()
case out <- sn:
}
}
return nil
})
if err != nil {
printer.E("could not load snapshots: %v", err)
}
}()
return out
}
+18 -6
View File
@@ -174,13 +174,25 @@ func TestFindListOnce(t *testing.T) {
snapshotIDs = restic.NewIDSet()
// specify the two oldest snapshots explicitly and use "latest" to reference the newest one
for sn := range FindFilteredSnapshots(ctx, repo, repo, &data.SnapshotFilter{}, []string{
secondSnapshot[0].String(),
secondSnapshot[1].String()[:8],
"latest",
}, printer) {
snapshotIDs.Insert(*sn.ID())
err = (&data.SnapshotFilter{}).FindAll(ctx, repo, repo,
[]string{
secondSnapshot[0].String(),
secondSnapshot[1].String()[:8],
"latest",
},
func(id string, sn *data.Snapshot, err error) error {
if err != nil {
return err
}
snapshotIDs.Insert(*sn.ID())
return nil
})
if err != nil {
return err
}
return nil
}))