diff --git a/src/documents/management/commands/base.py b/src/documents/management/commands/base.py index 39510e419..535fb36ff 100644 --- a/src/documents/management/commands/base.py +++ b/src/documents/management/commands/base.py @@ -7,6 +7,7 @@ Provides automatic progress bar and multiprocessing support with minimal boilerp from __future__ import annotations import os +from collections.abc import Callable from collections.abc import Iterable from collections.abc import Sized from concurrent.futures import ProcessPoolExecutor @@ -23,6 +24,9 @@ from django.core.management import CommandError from django.db.models import QuerySet from django_rich.management import RichCommand from rich.console import Console +from rich.console import Group +from rich.console import RenderableType +from rich.live import Live from rich.progress import BarColumn from rich.progress import MofNCompleteColumn from rich.progress import Progress @@ -32,9 +36,7 @@ from rich.progress import TimeElapsedColumn from rich.progress import TimeRemainingColumn if TYPE_CHECKING: - from collections.abc import Callable from collections.abc import Generator - from collections.abc import Iterable from collections.abc import Sequence from django.core.management import CommandParser @@ -91,6 +93,23 @@ class PaperlessCommand(RichCommand): for result in self.process_parallel(process_doc, ids): if result.error: self.console.print(f"[red]Failed: {result.error}[/red]") + + class Command(PaperlessCommand): + help = "Import documents with live stats" + + def handle(self, *args, **options): + stats = ImportStats() + + def render_stats() -> Table: + ... 
# build Rich Table from stats + + for item in self.track_with_stats( + items, + description="Importing...", + stats_renderer=render_stats, + ): + result = import_item(item) + stats.imported += 1 """ supports_progress_bar: ClassVar[bool] = True @@ -128,13 +147,11 @@ class PaperlessCommand(RichCommand): This is called by Django's command infrastructure after argument parsing but before handle(). We use it to set instance attributes from options. """ - # Set progress bar state if self.supports_progress_bar: self.no_progress_bar = options.get("no_progress_bar", False) else: self.no_progress_bar = True - # Set multiprocessing state if self.supports_multiprocessing: self.process_count = options.get("processes", 1) if self.process_count < 1: @@ -144,9 +161,29 @@ class PaperlessCommand(RichCommand): return super().execute(*args, **options) + @staticmethod + def _progress_columns() -> tuple[Any, ...]: + """ + Return the standard set of progress bar columns. + + Extracted so both _create_progress (standalone) and track_with_stats + (inside Live) use identical column configuration without duplication. + """ + return ( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + ) + def _create_progress(self, description: str) -> Progress: """ - Create a configured Progress instance. + Create a standalone Progress instance with its own stderr Console. + + Use this for track(). For track_with_stats(), Progress is created + directly inside a Live context instead. Progress output is directed to stderr to match the convention that progress bars are transient UI feedback, not command output. This @@ -161,12 +198,7 @@ class PaperlessCommand(RichCommand): A Progress instance configured with appropriate columns. 
""" return Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - MofNCompleteColumn(), - TimeElapsedColumn(), - TimeRemainingColumn(), + *self._progress_columns(), console=Console(stderr=True), transient=False, ) @@ -222,7 +254,6 @@ class PaperlessCommand(RichCommand): yield from iterable return - # Attempt to determine total if not provided if total is None: total = self._get_iterable_length(iterable) @@ -232,6 +263,87 @@ class PaperlessCommand(RichCommand): yield item progress.advance(task_id) + def track_with_stats( + self, + iterable: Iterable[T], + *, + description: str = "Processing...", + stats_renderer: Callable[[], RenderableType], + total: int | None = None, + ) -> Generator[T, None, None]: + """ + Iterate over items with a progress bar and a live-updating stats display. + + The progress bar and stats renderable are combined in a single Live + context, so the stats panel re-renders in place below the progress bar + after each item is processed. + + Respects --no-progress-bar flag. When disabled, yields items without + any display (stats are still updated by the caller's loop body, so + they will be accurate for any post-loop summary the caller prints). + + Args: + iterable: The items to iterate over. + description: Text to display alongside the progress bar. + stats_renderer: Zero-argument callable that returns a Rich + renderable. Called after each item to refresh the display. + The caller typically closes over a mutable dataclass and + rebuilds a Table from it on each call. + total: Total number of items. If None, attempts to determine + automatically via .count() (for querysets) or len(). + + Yields: + Items from the iterable. 
+ + Example: + @dataclass + class Stats: + processed: int = 0 + failed: int = 0 + + stats = Stats() + + def render_stats() -> Table: + table = Table(box=None) + table.add_column("Processed") + table.add_column("Failed") + table.add_row(str(stats.processed), str(stats.failed)) + return table + + for item in self.track_with_stats( + items, + description="Importing...", + stats_renderer=render_stats, + ): + try: + import_item(item) + stats.processed += 1 + except Exception: + stats.failed += 1 + """ + if self.no_progress_bar: + yield from iterable + return + + if total is None: + total = self._get_iterable_length(iterable) + + stderr_console = Console(stderr=True) + + # Progress is created without its own console so Live controls rendering. + progress = Progress(*self._progress_columns()) + task_id = progress.add_task(description, total=total) + + with Live( + Group(progress, stats_renderer()), + console=stderr_console, + refresh_per_second=4, + ) as live: + for item in iterable: + yield item + progress.advance(task_id) + live.update(Group(progress, stats_renderer())) + def process_parallel( self, fn: Callable[[T], R], @@ -269,10 +381,8 @@ class PaperlessCommand(RichCommand): total = len(items) if self.process_count == 1: - # Sequential execution in main process - critical for testing yield from self._process_sequential(fn, items, description, total) else: - # Parallel execution with ProcessPoolExecutor yield from self._process_parallel(fn, items, description, total) def _process_sequential( @@ -298,17 +408,14 @@ class PaperlessCommand(RichCommand): total: int, ) -> Generator[ProcessResult[T, R], None, None]: """Process items in parallel using ProcessPoolExecutor.""" - # Close database connections before forking - required for PostgreSQL db.connections.close_all() with self._create_progress(description) as progress: task_id = progress.add_task(description, total=total) with ProcessPoolExecutor(max_workers=self.process_count) as executor: - # Submit all tasks and map 
futures back to items future_to_item = {executor.submit(fn, item): item for item in items} - # Yield results as they complete for future in as_completed(future_to_item): item = future_to_item[future] try: diff --git a/src/documents/management/commands/document_retagger.py b/src/documents/management/commands/document_retagger.py index 32f895d4e..d14f6b8c6 100644 --- a/src/documents/management/commands/document_retagger.py +++ b/src/documents/management/commands/document_retagger.py @@ -1,4 +1,13 @@ +from __future__ import annotations + import logging +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING + +from rich.console import RenderableType +from rich.table import Table +from rich.text import Text from documents.classifier import load_classifier from documents.management.commands.base import PaperlessCommand @@ -8,9 +17,160 @@ from documents.signals.handlers import set_document_type from documents.signals.handlers import set_storage_path from documents.signals.handlers import set_tags +if TYPE_CHECKING: + from documents.models import Correspondent + from documents.models import DocumentType + from documents.models import StoragePath + from documents.models import Tag + logger = logging.getLogger("paperless.management.retagger") +@dataclass(slots=True) +class RetaggerStats: + """Cumulative counters updated as the retagger processes documents. + + Mutable by design -- fields are incremented in the processing loop. + slots=True reduces per-instance memory overhead and speeds attribute access. + """ + + correspondents: int = 0 + document_types: int = 0 + tags_added: int = 0 + tags_removed: int = 0 + storage_paths: int = 0 + documents_processed: int = 0 + + +@dataclass(slots=True) +class DocumentSuggestion: + """Buffered classifier suggestions for a single document (suggest mode only). + + Mutable by design -- fields are assigned incrementally as each setter runs. 
+ """ + + document: Document + correspondent: Correspondent | None = None + document_type: DocumentType | None = None + tags_to_add: frozenset[Tag] = field(default_factory=frozenset) + tags_to_remove: frozenset[Tag] = field(default_factory=frozenset) + storage_path: StoragePath | None = None + + @property + def has_suggestions(self) -> bool: + return bool( + self.correspondent is not None + or self.document_type is not None + or self.tags_to_add + or self.tags_to_remove + or self.storage_path is not None, + ) + + +def _build_stats_table(stats: RetaggerStats, *, suggest: bool) -> Table: + """ + Build the live-updating stats table shown below the progress bar. + + In suggest mode the labels read "would set / would add" to make clear + that nothing has been written to the database. + """ + table = Table(box=None, padding=(0, 2), show_header=True, header_style="bold") + + table.add_column("Documents") + table.add_column("Correspondents") + table.add_column("Doc Types") + table.add_column("Tags (+)") + table.add_column("Tags (-)") + table.add_column("Storage Paths") + + verb = "would set" if suggest else "set" + + table.add_row( + str(stats.documents_processed), + f"{stats.correspondents} {verb}", + f"{stats.document_types} {verb}", + f"+{stats.tags_added}", + f"-{stats.tags_removed}", + f"{stats.storage_paths} {verb}", + ) + + return table + + +def _build_suggestion_table( + suggestions: list[DocumentSuggestion], + base_url: str | None, +) -> Table: + """ + Build the final suggestion table printed after the progress bar completes. + + Only documents with at least one suggestion are included. 
+ """ + table = Table( + title="Suggested Changes", + show_header=True, + header_style="bold cyan", + show_lines=True, + ) + + table.add_column("Document", style="bold", no_wrap=False, min_width=20) + table.add_column("Correspondent") + table.add_column("Doc Type") + table.add_column("Tags") + table.add_column("Storage Path") + + for suggestion in suggestions: + if not suggestion.has_suggestions: + continue + + doc = suggestion.document + + if base_url: + doc_cell = Text() + doc_cell.append(str(doc)) + doc_cell.append(f"\n{base_url}/documents/{doc.pk}", style="dim") + else: + doc_cell = Text(f"{doc} [{doc.pk}]") + + tag_parts: list[str] = [] + for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name): + tag_parts.append(f"[green]+{tag.name}[/green]") + for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name): + tag_parts.append(f"[red]-{tag.name}[/red]") + tag_cell = Text.from_markup(", ".join(tag_parts)) if tag_parts else Text("-") + + table.add_row( + doc_cell, + str(suggestion.correspondent) if suggestion.correspondent else "-", + str(suggestion.document_type) if suggestion.document_type else "-", + tag_cell, + str(suggestion.storage_path) if suggestion.storage_path else "-", + ) + + return table + + +def _build_summary_table(stats: RetaggerStats) -> Table: + """Build the final applied-changes summary table.""" + table = Table( + title="Retagger Summary", + show_header=True, + header_style="bold cyan", + ) + + table.add_column("Metric", style="bold") + table.add_column("Count", justify="right") + + table.add_row("Documents processed", str(stats.documents_processed)) + table.add_row("Correspondents set", str(stats.correspondents)) + table.add_row("Document types set", str(stats.document_types)) + table.add_row("Tags added", str(stats.tags_added)) + table.add_row("Tags removed", str(stats.tags_removed)) + table.add_row("Storage paths set", str(stats.storage_paths)) + + return table + + class Command(PaperlessCommand): help = ( "Using the current 
classification model, assigns correspondents, tags " @@ -19,7 +179,7 @@ class Command(PaperlessCommand): "modified) after their initial import." ) - def add_arguments(self, parser): + def add_arguments(self, parser) -> None: super().add_arguments(parser) parser.add_argument("-c", "--correspondent", default=False, action="store_true") parser.add_argument("-T", "--tags", default=False, action="store_true") @@ -31,9 +191,9 @@ class Command(PaperlessCommand): default=False, action="store_true", help=( - "By default this command won't try to assign a correspondent " - "if more than one matches the document. Use this flag if " - "you'd rather it just pick the first one it finds." + "By default this command will not try to assign a correspondent " + "if more than one matches the document. Use this flag to pick " + "the first match instead." ), ) parser.add_argument( @@ -42,91 +202,133 @@ class Command(PaperlessCommand): default=False, action="store_true", help=( - "If set, the document retagger will overwrite any previously " - "set correspondent, document and remove correspondents, types " - "and tags that do not match anymore due to changed rules." + "Overwrite any previously set correspondent, document type, and " + "remove tags that no longer match due to changed rules." 
), ) parser.add_argument( "--suggest", default=False, action="store_true", - help="Return the suggestion, don't change anything.", + help="Show what would be changed without applying anything.", ) parser.add_argument( "--base-url", - help="The base URL to use to build the link to the documents.", + help="Base URL used to build document links in suggest output.", ) parser.add_argument( "--id-range", - help="A range of document ids on which the retagging should be applied.", + help="Restrict retagging to documents within this ID range (inclusive).", nargs=2, type=int, ) - def handle(self, *args, **options): + def handle(self, *args, **options) -> None: + suggest: bool = options["suggest"] + overwrite: bool = options["overwrite"] + use_first: bool = options["use_first"] + base_url: str | None = options["base_url"] + + do_correspondent: bool = options["correspondent"] + do_document_type: bool = options["document_type"] + do_tags: bool = options["tags"] + do_storage_path: bool = options["storage_path"] + + if not any([do_correspondent, do_document_type, do_tags, do_storage_path]): + self.console.print( + "[yellow]No classifier targets specified. 
" + "Use -c, -T, -t, or -s to select what to retag.[/yellow]", + ) + return + if options["inbox_only"]: queryset = Document.objects.filter(tags__is_inbox_tag=True) else: queryset = Document.objects.all() if options["id_range"]: - queryset = queryset.filter( - id__range=(options["id_range"][0], options["id_range"][1]), - ) + lo, hi = options["id_range"] + queryset = queryset.filter(id__range=(lo, hi)) documents = queryset.distinct() - classifier = load_classifier() - for document in self.track(documents, description="Retagging..."): - if options["correspondent"]: - set_correspondent( - sender=None, - document=document, - classifier=classifier, - replace=options["overwrite"], - use_first=options["use_first"], - suggest=options["suggest"], - base_url=options["base_url"], - stdout=self.stdout, - style_func=self.style, - ) + stats = RetaggerStats() + suggestions: list[DocumentSuggestion] = [] - if options["document_type"]: - set_document_type( - sender=None, - document=document, - classifier=classifier, - replace=options["overwrite"], - use_first=options["use_first"], - suggest=options["suggest"], - base_url=options["base_url"], - stdout=self.stdout, - style_func=self.style, - ) + def render_stats() -> RenderableType: + return _build_stats_table(stats, suggest=suggest) - if options["tags"]: - set_tags( - sender=None, - document=document, - classifier=classifier, - replace=options["overwrite"], - suggest=options["suggest"], - base_url=options["base_url"], - stdout=self.stdout, - style_func=self.style, - ) + for document in self.track_with_stats( + documents, + description="Retagging...", + stats_renderer=render_stats, + ): + suggestion = DocumentSuggestion(document=document) - if options["storage_path"]: - set_storage_path( - sender=None, - document=document, + if do_correspondent: + correspondent = set_correspondent( + None, + document, classifier=classifier, - replace=options["overwrite"], - use_first=options["use_first"], - suggest=options["suggest"], - 
base_url=options["base_url"], - stdout=self.stdout, - style_func=self.style, + replace=overwrite, + use_first=use_first, + dry_run=suggest, ) + if correspondent is not None: + stats.correspondents += 1 + suggestion.correspondent = correspondent + + if do_document_type: + document_type = set_document_type( + None, + document, + classifier=classifier, + replace=overwrite, + use_first=use_first, + dry_run=suggest, + ) + if document_type is not None: + stats.document_types += 1 + suggestion.document_type = document_type + + if do_tags: + tags_to_add, tags_to_remove = set_tags( + None, + document, + classifier=classifier, + replace=overwrite, + dry_run=suggest, + ) + stats.tags_added += len(tags_to_add) + stats.tags_removed += len(tags_to_remove) + suggestion.tags_to_add = frozenset(tags_to_add) + suggestion.tags_to_remove = frozenset(tags_to_remove) + + if do_storage_path: + storage_path = set_storage_path( + None, + document, + classifier=classifier, + replace=overwrite, + use_first=use_first, + dry_run=suggest, + ) + if storage_path is not None: + stats.storage_paths += 1 + suggestion.storage_path = storage_path + + stats.documents_processed += 1 + + if suggest: + suggestions.append(suggestion) + + # Post-loop output + if suggest: + visible = [s for s in suggestions if s.has_suggestions] + if visible: + self.console.print(_build_suggestion_table(visible, base_url)) + else: + self.console.print("[green]No changes suggested.[/green]") + else: + self.console.print(_build_summary_table(stats)) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 5fb7ef8c5..a563255f0 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -4,6 +4,7 @@ import logging import shutil from pathlib import Path from typing import TYPE_CHECKING +from typing import Any from celery import shared_task from celery import states @@ -32,12 +33,14 @@ from documents.file_handling import create_source_path_directory from 
documents.file_handling import delete_empty_directories from documents.file_handling import generate_filename from documents.file_handling import generate_unique_filename +from documents.models import Correspondent from documents.models import CustomField from documents.models import CustomFieldInstance from documents.models import Document -from documents.models import MatchingModel +from documents.models import DocumentType from documents.models import PaperlessTask from documents.models import SavedView +from documents.models import StoragePath from documents.models import Tag from documents.models import UiSettings from documents.models import Workflow @@ -81,47 +84,41 @@ def add_inbox_tags(sender, document: Document, logging_group=None, **kwargs) -> document.add_nested_tags(inbox_tags) -def _suggestion_printer( - stdout, - style_func, - suggestion_type: str, - document: Document, - selected: MatchingModel, - base_url: str | None = None, -) -> None: - """ - Smaller helper to reduce duplication when just outputting suggestions to the console - """ - doc_str = str(document) - if base_url is not None: - stdout.write(style_func.SUCCESS(doc_str)) - stdout.write(style_func.SUCCESS(f"{base_url}/documents/{document.pk}")) - else: - stdout.write(style_func.SUCCESS(f"{doc_str} [{document.pk}]")) - stdout.write(f"Suggest {suggestion_type}: {selected}") - - def set_correspondent( - sender, + sender: object, document: Document, *, - logging_group=None, + logging_group: object = None, classifier: DocumentClassifier | None = None, - replace=False, - use_first=True, - suggest=False, - base_url=None, - stdout=None, - style_func=None, - **kwargs, -) -> None: + replace: bool = False, + use_first: bool = True, + dry_run: bool = False, + **kwargs: Any, +) -> Correspondent | None: + """ + Assign a correspondent to a document based on classifier results. + + Args: + document: The document to classify. + logging_group: Optional logging group for structured log output. 
+ classifier: The trained classifier. If None, only rule-based matching runs. + replace: If True, overwrite an existing correspondent assignment. + use_first: If True, pick the first match when multiple correspondents + match. If False, skip assignment when multiple match. + dry_run: If True, compute and return the selection without saving. + **kwargs: Absorbed for Django signal compatibility (e.g. sender, signal). + + Returns: + The correspondent that was (or would be) assigned, or None if no match + was found or assignment was skipped. + """ if document.correspondent and not replace: - return + return None potential_correspondents = matching.match_correspondents(document, classifier) - potential_count = len(potential_correspondents) selected = potential_correspondents[0] if potential_correspondents else None + if potential_count > 1: if use_first: logger.debug( @@ -135,49 +132,53 @@ def set_correspondent( f"not assigning any correspondent", extra={"group": logging_group}, ) - return + return None - if selected or replace: - if suggest: - _suggestion_printer( - stdout, - style_func, - "correspondent", - document, - selected, - base_url, - ) - else: - logger.info( - f"Assigning correspondent {selected} to {document}", - extra={"group": logging_group}, - ) + if (selected or replace) and not dry_run: + logger.info( + f"Assigning correspondent {selected} to {document}", + extra={"group": logging_group}, + ) + document.correspondent = selected + document.save(update_fields=("correspondent",)) - document.correspondent = selected - document.save(update_fields=("correspondent",)) + return selected def set_document_type( - sender, + sender: object, document: Document, *, - logging_group=None, + logging_group: object = None, classifier: DocumentClassifier | None = None, - replace=False, - use_first=True, - suggest=False, - base_url=None, - stdout=None, - style_func=None, - **kwargs, -) -> None: + replace: bool = False, + use_first: bool = True, + dry_run: bool = False, + 
**kwargs: Any, +) -> DocumentType | None: + """ + Assign a document type to a document based on classifier results. + + Args: + document: The document to classify. + logging_group: Optional logging group for structured log output. + classifier: The trained classifier. If None, only rule-based matching runs. + replace: If True, overwrite an existing document type assignment. + use_first: If True, pick the first match when multiple types match. + If False, skip assignment when multiple match. + dry_run: If True, compute and return the selection without saving. + **kwargs: Absorbed for Django signal compatibility (e.g. sender, signal). + + Returns: + The document type that was (or would be) assigned, or None if no match + was found or assignment was skipped. + """ if document.document_type and not replace: - return + return None - potential_document_type = matching.match_document_types(document, classifier) - - potential_count = len(potential_document_type) - selected = potential_document_type[0] if potential_document_type else None + potential_document_types = matching.match_document_types(document, classifier) + potential_count = len(potential_document_types) + selected = potential_document_types[0] if potential_document_types else None if potential_count > 1: if use_first: @@ -192,42 +193,64 @@ def set_document_type( f"not assigning any document type", extra={"group": logging_group}, ) - return + return None - if selected or replace: - if suggest: - _suggestion_printer( - stdout, - style_func, - "document type", - document, - selected, - base_url, - ) - else: - logger.info( - f"Assigning document type {selected} to {document}", - extra={"group": logging_group}, - ) + if (selected or replace) and not dry_run: + logger.info( + f"Assigning document type {selected} to {document}", + extra={"group": logging_group}, + ) + document.document_type = selected + document.save(update_fields=("document_type",)) - document.document_type = selected - 
document.save(update_fields=("document_type",)) + return selected def set_tags( - sender, + sender: object, document: Document, *, - logging_group=None, + logging_group: object = None, classifier: DocumentClassifier | None = None, - replace=False, - suggest=False, - base_url=None, - stdout=None, - style_func=None, - **kwargs, -) -> None: + replace: bool = False, + dry_run: bool = False, + **kwargs: Any, +) -> tuple[set[Tag], set[Tag]]: + """ + Assign tags to a document based on classifier results. + + When replace=True, existing auto-matched and rule-matched tags are removed + before applying the new set (inbox tags and manually-added tags are preserved). + + Args: + document: The document to classify. + logging_group: Optional logging group for structured log output. + classifier: The trained classifier. If None, only rule-based matching runs. + replace: If True, remove existing classifier-managed tags before applying + new ones. Inbox tags and manually-added tags are always preserved. + dry_run: If True, compute what would change without saving anything. + **kwargs: Absorbed for Django signal compatibility (e.g. sender, signal). + + Returns: + A two-tuple of (tags_added, tags_removed). In non-replace mode, + tags_removed is always an empty set. In dry_run mode, neither set + is applied to the database. + """ + # Compute which tags would be removed under replace mode. + # The filter mirrors the .delete() call below: keep inbox tags and + # manually-added tags (match="" and not auto-matched). 
if replace: + tags_to_remove: set[Tag] = set( + document.tags.exclude( + is_inbox_tag=True, + ).exclude( + Q(match="") & ~Q(matching_algorithm=Tag.MATCH_AUTO), + ), + ) + else: + tags_to_remove = set() + + if replace and not dry_run: Document.tags.through.objects.filter(document=document).exclude( Q(tag__is_inbox_tag=True), ).exclude( @@ -235,65 +258,53 @@ def set_tags( ).delete() current_tags = set(document.tags.all()) - matched_tags = matching.match_tags(document, classifier) + tags_to_add = set(matched_tags) - current_tags - relevant_tags = set(matched_tags) - current_tags - - if suggest: - extra_tags = current_tags - set(matched_tags) - extra_tags = [ - t for t in extra_tags if t.matching_algorithm == MatchingModel.MATCH_AUTO - ] - if not relevant_tags and not extra_tags: - return - doc_str = style_func.SUCCESS(str(document)) - if base_url: - stdout.write(doc_str) - stdout.write(f"{base_url}/documents/{document.pk}") - else: - stdout.write(doc_str + style_func.SUCCESS(f" [{document.pk}]")) - if relevant_tags: - stdout.write("Suggest tags: " + ", ".join([t.name for t in relevant_tags])) - if extra_tags: - stdout.write("Extra tags: " + ", ".join([t.name for t in extra_tags])) - else: - if not relevant_tags: - return - - message = 'Tagging "{}" with "{}"' + if tags_to_add and not dry_run: logger.info( - message.format(document, ", ".join([t.name for t in relevant_tags])), + f'Tagging "{document}" with "{", ".join(t.name for t in tags_to_add)}"', extra={"group": logging_group}, ) + document.add_nested_tags(tags_to_add) - document.add_nested_tags(relevant_tags) + return tags_to_add, tags_to_remove def set_storage_path( - sender, + sender: object, document: Document, *, - logging_group=None, + logging_group: object = None, classifier: DocumentClassifier | None = None, - replace=False, - use_first=True, - suggest=False, - base_url=None, - stdout=None, - style_func=None, - **kwargs, -) -> None: + replace: bool = False, + use_first: bool = True, + dry_run: bool = False, 
+ **kwargs: Any, +) -> StoragePath | None: + """ + Assign a storage path to a document based on classifier results. + + Args: + document: The document to classify. + logging_group: Optional logging group for structured log output. + classifier: The trained classifier. If None, only rule-based matching runs. + replace: If True, overwrite an existing storage path assignment. + use_first: If True, pick the first match when multiple paths match. + If False, skip assignment when multiple match. + dry_run: If True, compute and return the selection without saving. + **kwargs: Absorbed for Django signal compatibility (e.g. sender, signal). + + Returns: + The storage path that was (or would be) assigned, or None if no match + was found or assignment was skipped. + """ if document.storage_path and not replace: - return + return None - potential_storage_path = matching.match_storage_paths( - document, - classifier, - ) - - potential_count = len(potential_storage_path) - selected = potential_storage_path[0] if potential_storage_path else None + potential_storage_paths = matching.match_storage_paths(document, classifier) + potential_count = len(potential_storage_paths) + selected = potential_storage_paths[0] if potential_storage_paths else None if potential_count > 1: if use_first: @@ -308,26 +319,17 @@ def set_storage_path( f"not assigning any storage directory", extra={"group": logging_group}, ) - return + return None - if selected or replace: - if suggest: - _suggestion_printer( - stdout, - style_func, - "storage directory", - document, - selected, - base_url, - ) - else: - logger.info( - f"Assigning storage path {selected} to {document}", - extra={"group": logging_group}, - ) + if (selected or replace) and not dry_run: + logger.info( + f"Assigning storage path {selected} to {document}", + extra={"group": logging_group}, + ) + document.storage_path = selected + document.save(update_fields=("storage_path",)) - document.storage_path = selected - 
document.save(update_fields=("storage_path",)) + return selected # see empty_trash in documents/tasks.py for signal handling diff --git a/src/documents/tests/factories.py b/src/documents/tests/factories.py index de41bbd02..982634db5 100644 --- a/src/documents/tests/factories.py +++ b/src/documents/tests/factories.py @@ -1,17 +1,65 @@ -from factory import Faker +""" +Factory-boy factories for documents app models. +""" + +from __future__ import annotations + +import factory from factory.django import DjangoModelFactory from documents.models import Correspondent from documents.models import Document +from documents.models import DocumentType +from documents.models import MatchingModel +from documents.models import StoragePath +from documents.models import Tag class CorrespondentFactory(DjangoModelFactory): class Meta: model = Correspondent - name = Faker("name") + name = factory.Faker("company") + match = "" + matching_algorithm = MatchingModel.MATCH_NONE + + +class DocumentTypeFactory(DjangoModelFactory): + class Meta: + model = DocumentType + + name = factory.Faker("bs") + match = "" + matching_algorithm = MatchingModel.MATCH_NONE + + +class TagFactory(DjangoModelFactory): + class Meta: + model = Tag + + name = factory.Faker("word") + match = "" + matching_algorithm = MatchingModel.MATCH_NONE + is_inbox_tag = False + + +class StoragePathFactory(DjangoModelFactory): + class Meta: + model = StoragePath + + name = factory.Faker("file_path", depth=2, extension="") + path = factory.LazyAttribute(lambda o: f"{o.name}/{{title}}") + match = "" + matching_algorithm = MatchingModel.MATCH_NONE class DocumentFactory(DjangoModelFactory): class Meta: model = Document + + title = factory.Faker("sentence", nb_words=4) + checksum = factory.Faker("md5") + content = factory.Faker("paragraph") + correspondent = None + document_type = None + storage_path = None diff --git a/src/documents/tests/test_management_retagger.py b/src/documents/tests/test_management_retagger.py index 
"""
Tests for the document_retagger management command.
"""

from __future__ import annotations

import pytest
from django.core.management import call_command
from django.core.management.base import CommandError

from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import DocumentTypeFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
from documents.tests.utils import DirectoriesMixin

# ---------------------------------------------------------------------------
# Module-level type aliases
# ---------------------------------------------------------------------------

StoragePathTuple = tuple[StoragePath, StoragePath, StoragePath]
TagTuple = tuple[Tag, Tag, Tag, Tag, Tag]
CorrespondentTuple = tuple[Correspondent, Correspondent]
DocumentTypeTuple = tuple[DocumentType, DocumentType]
DocumentTuple = tuple[Document, Document, Document, Document]


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def storage_paths(db) -> StoragePathTuple:
    """
    Three storage paths with varying match rules.

    sp1 literal-matches the content "auto document" (document D below),
    sp2 regex-matches content starting with "first" or "unrelated",
    sp3 regex-matches "^blah" (nothing in the fixture documents).
    """
    sp1 = StoragePathFactory(
        # NOTE(review): "{created_data}" is carried over verbatim from the old
        # test; verify it is the intended placeholder (vs e.g. "{created}").
        path="{created_data}/{title}",
        match="auto document",
        matching_algorithm=MatchingModel.MATCH_LITERAL,
    )
    sp2 = StoragePathFactory(
        path="{title}",
        match="^first|^unrelated",
        matching_algorithm=MatchingModel.MATCH_REGEX,
    )
    sp3 = StoragePathFactory(
        path="{title}",
        match="^blah",
        matching_algorithm=MatchingModel.MATCH_REGEX,
    )
    return sp1, sp2, sp3


@pytest.fixture()
def tags(db) -> TagTuple:
    """Tags covering the common matching scenarios."""
    tag_first = TagFactory(match="first", matching_algorithm=Tag.MATCH_ANY)
    tag_second = TagFactory(match="second", matching_algorithm=Tag.MATCH_ANY)
    tag_inbox = TagFactory(is_inbox_tag=True)
    tag_no_match = TagFactory()
    tag_auto = TagFactory(matching_algorithm=Tag.MATCH_AUTO)
    return tag_first, tag_second, tag_inbox, tag_no_match, tag_auto


@pytest.fixture()
def correspondents(db) -> CorrespondentTuple:
    """Two correspondents matching 'first' and 'second' content."""
    c_first = CorrespondentFactory(
        match="first",
        matching_algorithm=MatchingModel.MATCH_ANY,
    )
    c_second = CorrespondentFactory(
        match="second",
        matching_algorithm=MatchingModel.MATCH_ANY,
    )
    return c_first, c_second


@pytest.fixture()
def document_types(db) -> DocumentTypeTuple:
    """Two document types matching 'first' and 'second' content."""
    dt_first = DocumentTypeFactory(
        match="first",
        matching_algorithm=MatchingModel.MATCH_ANY,
    )
    dt_second = DocumentTypeFactory(
        match="second",
        matching_algorithm=MatchingModel.MATCH_ANY,
    )
    return dt_first, dt_second


@pytest.fixture()
def documents(storage_paths: StoragePathTuple, tags: TagTuple) -> DocumentTuple:
    """
    Four documents with varied content used across most retagger tests.

    Creation order matters: tests that reason about id ranges assume the
    documents were created before any extra documents they add themselves.
    """
    sp1, sp2, sp3 = storage_paths
    tag_first, tag_second, tag_inbox, tag_no_match, tag_auto = tags

    d1 = DocumentFactory(checksum="A", title="A", content="first document")
    d2 = DocumentFactory(checksum="B", title="B", content="second document")
    # d3 already has a storage path assigned (sp3); retagger must not replace
    # it unless --overwrite is given.
    d3 = DocumentFactory(
        checksum="C",
        title="C",
        content="unrelated document",
        storage_path=sp3,
    )
    d4 = DocumentFactory(checksum="D", title="D", content="auto document")

    d3.tags.add(tag_inbox, tag_no_match)
    d4.tags.add(tag_auto)

    return d1, d2, d3, d4


def _get_docs() -> DocumentTuple:
    """Re-fetch the four fixture documents by title, in creation order."""
    return (
        Document.objects.get(title="A"),
        Document.objects.get(title="B"),
        Document.objects.get(title="C"),
        Document.objects.get(title="D"),
    )
# ---------------------------------------------------------------------------
# Tag assignment
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerTags(DirectoriesMixin):
    """Tests for document_retagger --tags."""

    @pytest.mark.usefixtures("documents")
    def test_add_tags(self, tags: TagTuple) -> None:
        """Matching tags are applied; existing tags are preserved."""
        tag_first, tag_second, *_ = tags
        call_command("document_retagger", "--tags")
        d_first, d_second, d_unrelated, d_auto = _get_docs()

        assert d_first.tags.count() == 1
        assert d_second.tags.count() == 1
        # d_unrelated keeps its pre-existing inbox + no-match tags.
        assert d_unrelated.tags.count() == 2
        assert d_auto.tags.count() == 1
        assert d_first.tags.first() == tag_first
        assert d_second.tags.first() == tag_second

    def test_overwrite_removes_stale_tags_and_preserves_inbox(
        self,
        documents: DocumentTuple,
        tags: TagTuple,
    ) -> None:
        """
        --overwrite strips tags whose rules no longer match, but keeps the
        inbox and no-rule tags on the unrelated document untouched.
        """
        d1, *_ = documents
        tag_first, tag_second, tag_inbox, tag_no_match, tag_auto = tags
        # Attach a tag that does NOT match d1's content; overwrite must drop it.
        d1.tags.add(tag_second)

        call_command("document_retagger", "--tags", "--overwrite")
        d_first, d_second, d_unrelated, d_auto = _get_docs()

        # Overwriting removes the assignment, not the Tag row itself.
        assert Tag.objects.filter(id=tag_second.id).exists()
        assert list(d_first.tags.values_list("id", flat=True)) == [tag_first.id]
        assert list(d_second.tags.values_list("id", flat=True)) == [tag_second.id]
        assert set(d_unrelated.tags.values_list("id", flat=True)) == {
            tag_inbox.id,
            tag_no_match.id,
        }
        assert d_auto.tags.count() == 0

    @pytest.mark.usefixtures("documents")
    @pytest.mark.parametrize(
        "extra_args",
        [
            pytest.param([], id="no_base_url"),
            pytest.param(["--base-url=http://localhost"], id="with_base_url"),
        ],
    )
    def test_suggest_does_not_apply_tags(self, extra_args: list[str]) -> None:
        """--suggest reports matches but must not modify any document."""
        call_command("document_retagger", "--tags", "--suggest", *extra_args)
        d_first, d_second, _, d_auto = _get_docs()

        assert d_first.tags.count() == 0
        assert d_second.tags.count() == 0
        assert d_auto.tags.count() == 1


# ---------------------------------------------------------------------------
# Document type assignment
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerDocumentType(DirectoriesMixin):
    """Tests for document_retagger --document_type."""

    @pytest.mark.usefixtures("documents")
    def test_add_type(self, document_types: DocumentTypeTuple) -> None:
        """Matching document types are assigned to matching documents."""
        dt_first, dt_second = document_types
        call_command("document_retagger", "--document_type")
        d_first, d_second, _, _ = _get_docs()

        assert d_first.document_type == dt_first
        assert d_second.document_type == dt_second

    @pytest.mark.usefixtures("documents", "document_types")
    @pytest.mark.parametrize(
        "extra_args",
        [
            pytest.param([], id="no_base_url"),
            pytest.param(["--base-url=http://localhost"], id="with_base_url"),
        ],
    )
    def test_suggest_does_not_apply_document_type(self, extra_args: list[str]) -> None:
        """--suggest must not write document type assignments."""
        call_command("document_retagger", "--document_type", "--suggest", *extra_args)
        d_first, d_second, _, _ = _get_docs()

        assert d_first.document_type is None
        assert d_second.document_type is None


# ---------------------------------------------------------------------------
# Correspondent assignment
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerCorrespondent(DirectoriesMixin):
    """Tests for document_retagger --correspondent."""

    @pytest.mark.usefixtures("documents")
    def test_add_correspondent(self, correspondents: CorrespondentTuple) -> None:
        """Matching correspondents are assigned to matching documents."""
        c_first, c_second = correspondents
        call_command("document_retagger", "--correspondent")
        d_first, d_second, _, _ = _get_docs()

        assert d_first.correspondent == c_first
        assert d_second.correspondent == c_second

    @pytest.mark.usefixtures("documents", "correspondents")
    @pytest.mark.parametrize(
        "extra_args",
        [
            pytest.param([], id="no_base_url"),
            pytest.param(["--base-url=http://localhost"], id="with_base_url"),
        ],
    )
    def test_suggest_does_not_apply_correspondent(self, extra_args: list[str]) -> None:
        """--suggest must not write correspondent assignments."""
        call_command("document_retagger", "--correspondent", "--suggest", *extra_args)
        d_first, d_second, _, _ = _get_docs()

        assert d_first.correspondent is None
        assert d_second.correspondent is None
# ---------------------------------------------------------------------------
# Storage path assignment
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerStoragePath(DirectoriesMixin):
    """Tests for document_retagger --storage_path."""

    @pytest.mark.usefixtures("documents")
    def test_add_storage_path(self, storage_paths: StoragePathTuple) -> None:
        """
        Without --overwrite, matching documents receive the matching storage
        path, non-matching documents stay unassigned, and a document that
        already carries a storage path keeps it.
        """
        path_literal, path_first_or_unrelated, path_blah = storage_paths

        call_command("document_retagger", "--storage_path")

        doc_a, doc_b, doc_c, doc_d = _get_docs()
        assert doc_b.storage_path is None
        assert doc_a.storage_path == path_first_or_unrelated
        assert doc_c.storage_path == path_blah
        assert doc_d.storage_path == path_literal

    @pytest.mark.usefixtures("documents")
    def test_overwrite_storage_path(self, storage_paths: StoragePathTuple) -> None:
        """
        With --overwrite, a previously assigned storage path is replaced by
        the path whose rule now matches the document's content.
        """
        path_literal, path_first_or_unrelated, _ = storage_paths

        call_command("document_retagger", "--storage_path", "--overwrite")

        doc_a, doc_b, doc_c, doc_d = _get_docs()
        assert doc_b.storage_path is None
        assert doc_a.storage_path == path_first_or_unrelated
        # The stale assignment on doc_c is replaced by the matching rule.
        assert doc_c.storage_path == path_first_or_unrelated
        assert doc_d.storage_path == path_literal
# ---------------------------------------------------------------------------
# ID range filtering
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerIdRange(DirectoriesMixin):
    """Tests for the --id-range argument of document_retagger."""

    @pytest.mark.parametrize(
        "wide_range",
        [
            pytest.param(False, id="narrow_range_limits_scope"),
            pytest.param(True, id="wide_range_tags_all_matches"),
        ],
    )
    def test_id_range_limits_scope(
        self,
        documents: DocumentTuple,
        tags: TagTuple,
        wide_range: bool,
    ) -> None:
        """
        A range covering only the first document tags just that document;
        a range covering everything also tags the extra matching document.
        """
        d1 = documents[0]
        # This extra document also matches the 'first' tag rule but has the
        # highest primary key, so a narrow range must exclude it.
        extra = DocumentFactory(content="NOT the first document")
        tag_first, *_ = tags

        # Derive the bounds from real primary keys instead of assuming ids
        # start at 1 — sequences are not reset on every database backend
        # or test-isolation mode, so hard-coded ids are flaky.
        upper = extra.pk if wide_range else d1.pk
        call_command(
            "document_retagger",
            "--tags",
            "--id-range",
            str(d1.pk),
            str(upper),
        )

        expected = 2 if wide_range else 1
        assert Document.objects.filter(tags__id=tag_first.id).count() == expected

    @pytest.mark.usefixtures("documents")
    @pytest.mark.parametrize(
        "args",
        [
            pytest.param(["--tags", "--id-range"], id="missing_both_values"),
            pytest.param(["--tags", "--id-range", "a", "b"], id="non_integer_values"),
        ],
    )
    def test_id_range_invalid_arguments_raise(self, args: list[str]) -> None:
        """Malformed --id-range arguments must be rejected by argparse."""
        # argparse may raise SystemExit directly or Django may wrap it in
        # CommandError depending on how the command is invoked.
        with pytest.raises((CommandError, SystemExit)):
            call_command("document_retagger", *args)


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


@pytest.mark.management
@pytest.mark.django_db
class TestRetaggerEdgeCases(DirectoriesMixin):
    """Edge-case behavior of the retagger command."""

    @pytest.mark.usefixtures("documents")
    def test_no_targets_exits_cleanly(self) -> None:
        """Calling the retagger with no classifier targets should not raise."""
        call_command("document_retagger")

    @pytest.mark.usefixtures("documents")
    def test_inbox_only_skips_non_inbox_documents(self) -> None:
        """--inbox-only must restrict processing to documents with an inbox tag."""
        call_command("document_retagger", "--tags", "--inbox-only")
        d_first, _, d_unrelated, _ = _get_docs()

        # d_first has no inbox tag, so it is skipped entirely; d_unrelated is
        # processed but gains no new tags, keeping its original two.
        assert d_first.tags.count() == 0
        assert d_unrelated.tags.count() == 2