mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-20 07:55:57 +00:00
* Refactor: migrate exporter/importer from tqdm to PaperlessCommand.track() Replace direct tqdm usage in document_exporter and document_importer with the PaperlessCommand base class and its track() method, which is backed by Rich and handles --no-progress-bar automatically. Also removes the unused ProgressBarMixin from mixins.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Refactor: add explicit supports_progress_bar and supports_multiprocessing to all PaperlessCommand subclasses Each management command now explicitly declares both class attributes rather than relying on defaults, making intent unambiguous at a glance. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
551 lines
18 KiB
Python
551 lines
18 KiB
Python
"""
|
|
Base command class for Paperless-ngx management commands.
|
|
|
|
Provides automatic progress bar and multiprocessing support with minimal boilerplate.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from collections.abc import Callable
|
|
from collections.abc import Iterable
|
|
from collections.abc import Sized
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from concurrent.futures import as_completed
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import ClassVar
|
|
from typing import Generic
|
|
from typing import TypeVar
|
|
|
|
from django import db
|
|
from django.core.management import CommandError
|
|
from django.db.models import QuerySet
|
|
from django_rich.management import RichCommand
|
|
from rich import box
|
|
from rich.console import Console
|
|
from rich.console import Group
|
|
from rich.console import RenderableType
|
|
from rich.live import Live
|
|
from rich.progress import BarColumn
|
|
from rich.progress import MofNCompleteColumn
|
|
from rich.progress import Progress
|
|
from rich.progress import SpinnerColumn
|
|
from rich.progress import TextColumn
|
|
from rich.progress import TimeElapsedColumn
|
|
from rich.progress import TimeRemainingColumn
|
|
from rich.table import Table
|
|
from rich.text import Text
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
from collections.abc import Sequence
|
|
|
|
from django.core.management import CommandParser
|
|
|
|
# Type of each input item handed to a processing function (ProcessResult.item,
# the argument of the ``fn`` passed to PaperlessCommand.process_parallel).
T = TypeVar("T")
# Type of the value that processing function returns (ProcessResult.result).
R = TypeVar("R")
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
|
|
class _BufferedRecord:
|
|
level: int
|
|
name: str
|
|
message: str
|
|
|
|
|
|
class BufferingLogHandler(logging.Handler):
    """Captures log records during a command run for deferred rendering.

    Attach to a logger before a long operation and call ``render()``
    afterwards to emit the buffered records via Rich, optionally filtered
    by minimum level. ``clear()`` discards everything captured so far.
    """

    # Rich style applied to the "Level" cell for each standard logging level.
    # Levels not listed here (e.g. custom levels) are rendered unstyled.
    _LEVEL_STYLES: ClassVar[dict[int, str]] = {
        logging.DEBUG: "dim",
        logging.INFO: "cyan",
        logging.WARNING: "yellow",
        logging.ERROR: "red",
        logging.CRITICAL: "bold red",
    }

    def __init__(self) -> None:
        super().__init__()
        self._records: list[_BufferedRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        """Format *record* with this handler's formatter and buffer it.

        Follows the ``logging.Handler.emit`` contract: if formatting fails,
        the error is reported through ``handleError`` instead of propagating
        into the code that issued the logging call.
        """
        try:
            message = self.format(record)
        except RecursionError:
            # Mirrors the stdlib handlers: never mask a recursion error.
            raise
        except Exception:
            self.handleError(record)
            return

        self._records.append(
            _BufferedRecord(
                level=record.levelno,
                name=record.name,
                message=message,
            ),
        )

    def render(
        self,
        console: Console,
        *,
        min_level: int = logging.DEBUG,
        title: str = "Log Output",
    ) -> None:
        """Print buffered records at or above *min_level* as a Rich table.

        Prints nothing when no buffered record passes the filter. The buffer
        is left intact; call ``clear()`` to discard it.

        Args:
            console: Rich console the table is printed to.
            min_level: Records whose numeric level is below this are skipped.
            title: Title shown above the table.
        """
        records = [r for r in self._records if r.level >= min_level]
        if not records:
            return

        table = Table(
            title=title,
            show_header=True,
            header_style="bold",
            show_lines=False,
            box=box.SIMPLE,
        )
        table.add_column("Level", style="bold", width=8)
        table.add_column("Logger", style="dim")
        table.add_column("Message", no_wrap=False)

        for record in records:
            table.add_row(
                Text(
                    logging.getLevelName(record.level),
                    style=self._LEVEL_STYLES.get(record.level, ""),
                ),
                # Logger names and messages are arbitrary text. Wrapping them
                # in Text makes square brackets (e.g. "[/path/to/file]") show
                # literally instead of being parsed as Rich console markup,
                # which could mis-style the output or raise a MarkupError.
                Text(record.name),
                Text(record.message),
            )

        console.print(table)

    def clear(self) -> None:
        """Discard all buffered records."""
        self._records.clear()
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
class ProcessResult(Generic[T, R]):
    """
    Outcome of applying a processing function to one input item.

    As produced by ``PaperlessCommand.process_parallel``: on success ``error``
    is None and ``result`` holds the function's return value; on failure
    ``result`` is None and ``error`` holds the raised exception.

    Attributes:
        item: The input item the function was applied to.
        result: Return value of the function, or None if it raised.
        error: The exception raised while processing, or None on success.
    """

    item: T
    result: R | None
    error: BaseException | None

    @property
    def success(self) -> bool:
        """Whether processing finished without recording an exception."""
        failed = self.error is not None
        return not failed
|
|
|
|
|
|
class PaperlessCommand(RichCommand):
    """
    Base command class with automatic progress bar and multiprocessing support.

    Features are opt-in via class attributes:
        supports_progress_bar: Adds --no-progress-bar argument (default: True)
        supports_multiprocessing: Adds --processes argument (default: False)

    Example usage:

        class Command(PaperlessCommand):
            help = "Process all documents"

            def handle(self, *args, **options):
                documents = Document.objects.all()
                for doc in self.track(documents, description="Processing..."):
                    process_document(doc)

        class Command(PaperlessCommand):
            help = "Regenerate thumbnails"
            supports_multiprocessing = True

            def handle(self, *args, **options):
                ids = list(Document.objects.values_list("id", flat=True))
                for result in self.process_parallel(process_doc, ids):
                    if result.error:
                        self.console.print(f"[red]Failed: {result.error}[/red]")

        class Command(PaperlessCommand):
            help = "Import documents with live stats"

            def handle(self, *args, **options):
                stats = ImportStats()

                def render_stats() -> Table:
                    ...  # build Rich Table from stats

                for item in self.track_with_stats(
                    items,
                    description="Importing...",
                    stats_renderer=render_stats,
                ):
                    result = import_item(item)
                    stats.imported += 1
    """

    supports_progress_bar: ClassVar[bool] = True
    supports_multiprocessing: ClassVar[bool] = False

    # Instance attributes set by execute() before handle() runs
    no_progress_bar: bool
    process_count: int

    def add_arguments(self, parser: CommandParser) -> None:
        """Add arguments based on supported features.

        ``--no-progress-bar`` is added when ``supports_progress_bar`` is set.
        ``--processes`` is added when ``supports_multiprocessing`` is set; its
        default is a quarter of the CPU count, but never less than 1.
        """
        super().add_arguments(parser)

        if self.supports_progress_bar:
            parser.add_argument(
                "--no-progress-bar",
                default=False,
                action="store_true",
                help="Disable the progress bar",
            )

        if self.supports_multiprocessing:
            default_processes = max(1, (os.cpu_count() or 1) // 4)
            parser.add_argument(
                "--processes",
                default=default_processes,
                type=int,
                help=f"Number of processes to use (default: {default_processes})",
            )

    def execute(self, *args: Any, **options: Any) -> str | None:
        """
        Set up instance state before handle() is called.

        This is called by Django's command infrastructure after argument parsing
        but before handle(). We use it to set instance attributes from options.

        When a feature is not supported by the subclass, its attribute is
        forced to the "off" value (no progress bar, a single process).

        Raises:
            CommandError: If ``--processes`` is less than 1.
        """
        if self.supports_progress_bar:
            self.no_progress_bar = options.get("no_progress_bar", False)
        else:
            self.no_progress_bar = True

        if self.supports_multiprocessing:
            self.process_count = options.get("processes", 1)
            if self.process_count < 1:
                raise CommandError("--processes must be at least 1")
        else:
            self.process_count = 1

        return super().execute(*args, **options)

    @contextmanager
    def buffered_logging(
        self,
        *logger_names: str,
        level: int = logging.DEBUG,
    ) -> Generator[BufferingLogHandler, None, None]:
        """Context manager that captures log output from named loggers.

        Installs a ``BufferingLogHandler`` on each named logger for the
        duration of the block, suppressing propagation to avoid interleaving
        with the Rich live display. The handler is removed and each logger's
        original ``propagate`` flag is restored on exit, regardless of whether
        an exception occurred.

        Args:
            *logger_names: Names passed to ``logging.getLogger``. Duplicates
                (including different names resolving to the same logger, such
                as ``""`` and the root logger) are installed only once.
            level: Minimum level the buffering handler accepts. The loggers'
                own levels are not changed, so records they filter out never
                reach the handler.

        Usage::

            with self.buffered_logging("paperless", "documents") as log_buf:
                # ... run progress loop ...
            if options["verbose"]:
                log_buf.render(self.console)
        """
        handler = BufferingLogHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        handler.setLevel(level)

        # Track (logger, original propagate flag) by logger object rather than
        # by requested name: getLogger("") returns the root logger whose .name
        # is "root", and a repeated name would otherwise record the already
        # suppressed flag (False) as the value to restore.
        installed: list[tuple[logging.Logger, bool]] = []
        for name in logger_names:
            log = logging.getLogger(name)
            if any(existing is log for existing, _ in installed):
                continue
            installed.append((log, log.propagate))
            log.addHandler(handler)
            log.propagate = False

        try:
            yield handler
        finally:
            for log, original_propagate in installed:
                log.removeHandler(handler)
                log.propagate = original_propagate

    @staticmethod
    def _progress_columns() -> tuple[Any, ...]:
        """
        Return the standard set of progress bar columns.

        Extracted so both _create_progress (standalone) and track_with_stats
        (inside Live) use identical column configuration without duplication.
        """
        return (
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        )

    def _create_progress(self, description: str) -> Progress:
        """
        Create a standalone Progress instance with its own stderr Console.

        Use this for track(). For track_with_stats(), Progress is created
        directly inside a Live context instead.

        Progress output is directed to stderr because progress bars are
        transient UI feedback, not command output. This keeps progress
        rendering from interfering with stdout-based assertions in tests or
        with piped command output.

        Args:
            description: Not used here; callers pass the description to
                ``Progress.add_task()`` on the returned instance.

        Returns:
            A Progress instance configured with appropriate columns.
        """
        return Progress(
            *self._progress_columns(),
            console=Console(stderr=True),
            transient=False,
        )

    def _get_iterable_length(self, iterable: Iterable[object]) -> int | None:
        """
        Attempt to determine the length of an iterable without consuming it.

        Tries .count() first (for Django querysets - executes SELECT COUNT(*)),
        then falls back to len() for sequences.

        Args:
            iterable: The iterable to measure.

        Returns:
            The length if determinable, None otherwise.
        """
        if isinstance(iterable, QuerySet):
            return iterable.count()

        if isinstance(iterable, Sized):
            return len(iterable)

        return None

    def track(
        self,
        iterable: Iterable[T],
        *,
        description: str = "Processing...",
        total: int | None = None,
    ) -> Generator[T, None, None]:
        """
        Iterate over items with an optional progress bar.

        Respects --no-progress-bar flag. When disabled, simply yields items
        without any progress display. The bar is advanced after the caller
        has finished with each yielded item.

        Args:
            iterable: The items to iterate over.
            description: Text to display alongside the progress bar.
            total: Total number of items. If None, attempts to determine
                automatically via .count() (for querysets) or len().

        Yields:
            Items from the iterable.

        Example:
            for doc in self.track(documents, description="Renaming..."):
                process(doc)
        """
        if self.no_progress_bar:
            yield from iterable
            return

        if total is None:
            total = self._get_iterable_length(iterable)

        with self._create_progress(description) as progress:
            task_id = progress.add_task(description, total=total)
            for item in iterable:
                yield item
                progress.advance(task_id)

    def track_with_stats(
        self,
        iterable: Iterable[T],
        *,
        description: str = "Processing...",
        stats_renderer: Callable[[], RenderableType],
        total: int | None = None,
    ) -> Generator[T, None, None]:
        """
        Iterate over items with a progress bar and a live-updating stats display.

        The progress bar and stats renderable are combined in a single Live
        context, so the stats panel re-renders in place below the progress bar
        after each item is processed.

        Respects --no-progress-bar flag. When disabled, yields items without
        any display (stats are still updated by the caller's loop body, so
        they will be accurate for any post-loop summary the caller prints).

        Args:
            iterable: The items to iterate over.
            description: Text to display alongside the progress bar.
            stats_renderer: Zero-argument callable that returns a Rich
                renderable. Called after each item to refresh the display.
                The caller typically closes over a mutable dataclass and
                rebuilds a Table from it on each call.
            total: Total number of items. If None, attempts to determine
                automatically via .count() (for querysets) or len().

        Yields:
            Items from the iterable.

        Example:
            @dataclass
            class Stats:
                processed: int = 0
                failed: int = 0

            stats = Stats()

            def render_stats() -> Table:
                table = Table(box=None)
                table.add_column("Processed")
                table.add_column("Failed")
                table.add_row(str(stats.processed), str(stats.failed))
                return table

            for item in self.track_with_stats(
                items,
                description="Importing...",
                stats_renderer=render_stats,
            ):
                try:
                    import_item(item)
                    stats.processed += 1
                except Exception:
                    stats.failed += 1
        """
        if self.no_progress_bar:
            yield from iterable
            return

        if total is None:
            total = self._get_iterable_length(iterable)

        stderr_console = Console(stderr=True)

        # Progress is created without its own console so Live controls rendering.
        progress = Progress(*self._progress_columns())
        task_id = progress.add_task(description, total=total)

        with Live(
            Group(progress, stats_renderer()),
            console=stderr_console,
            refresh_per_second=4,
        ) as live:
            for item in iterable:
                yield item
                progress.advance(task_id)
                live.update(Group(progress, stats_renderer()))

    def process_parallel(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        *,
        description: str = "Processing...",
    ) -> Generator[ProcessResult[T, R], None, None]:
        """
        Process items in parallel with progress tracking.

        When --processes=1, runs sequentially in the main process without
        spawning subprocesses. This is critical for testing, as multiprocessing
        breaks fixtures, mocks, and database transactions.

        When --processes > 1, uses ProcessPoolExecutor and automatically closes
        database connections before spawning workers (required for PostgreSQL).
        Results are then yielded in completion order, not input order.

        Args:
            fn: Function to apply to each item. Must be picklable for parallel
                execution (i.e., defined at module level, not a lambda or closure).
            items: Sequence of items to process.
            description: Text to display alongside the progress bar.

        Yields:
            ProcessResult for each item, containing the item, result, and any error.

        Example:
            def regenerate_thumbnail(doc_id: int) -> Path:
                ...

            for result in self.process_parallel(regenerate_thumbnail, doc_ids):
                if result.error:
                    self.console.print(f"[red]Failed {result.item}[/red]")
        """
        total = len(items)

        if self.process_count == 1:
            # Sequential execution in main process - critical for testing, so we don't fork in fork, etc
            yield from self._process_sequential(fn, items, description, total)
        else:
            # Parallel execution with ProcessPoolExecutor
            yield from self._process_parallel(fn, items, description, total)

    def _process_sequential(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        description: str,
        total: int,
    ) -> Generator[ProcessResult[T, R], None, None]:
        """Process items sequentially in the main process.

        An ``Exception`` raised by *fn* is captured in the yielded result;
        the yield itself sits outside the ``try`` so that exceptions thrown
        into this generator are not misattributed to the current item.
        """
        for item in self.track(items, description=description, total=total):
            outcome: ProcessResult[T, R]
            try:
                value = fn(item)
            except Exception as exc:
                outcome = ProcessResult(item=item, result=None, error=exc)
            else:
                outcome = ProcessResult(item=item, result=value, error=None)
            yield outcome

    def _process_parallel(
        self,
        fn: Callable[[T], R],
        items: Sequence[T],
        description: str,
        total: int,
    ) -> Generator[ProcessResult[T, R], None, None]:
        """Process items in parallel using ProcessPoolExecutor.

        All items are submitted up front; results are yielded as their
        futures complete. If the consumer stops iterating early, futures
        that have not started yet are cancelled so shutting down the
        executor does not wait for the whole remaining queue.
        """

        # Close database connections before forking - required for PostgreSQL
        db.connections.close_all()

        with self._create_progress(description) as progress:
            task_id = progress.add_task(description, total=total)

            with ProcessPoolExecutor(max_workers=self.process_count) as executor:
                # Submit all tasks and map futures back to items
                future_to_item = {executor.submit(fn, item): item for item in items}

                try:
                    # Yield results as they complete
                    for future in as_completed(future_to_item):
                        item = future_to_item[future]
                        outcome: ProcessResult[T, R]
                        try:
                            value = future.result()
                        except Exception as exc:
                            outcome = ProcessResult(item=item, result=None, error=exc)
                        else:
                            outcome = ProcessResult(item=item, result=value, error=None)

                        try:
                            yield outcome
                        finally:
                            # Advance once the consumer has handled this result,
                            # matching track(), which advances after each item.
                            progress.advance(task_id)
                finally:
                    # No-op for futures that already finished or are running;
                    # only matters when iteration ended early.
                    for pending in future_to_item:
                        pending.cancel()
|