Handles the assignment case/logging. Might need some testing?

This commit is contained in:
Trenton H
2026-03-02 16:08:15 -08:00
parent c3dd7615e0
commit 01144b4b1a
2 changed files with 187 additions and 57 deletions
+123
View File
@@ -6,12 +6,14 @@ Provides automatic progress bar and multiprocessing support with minimal boilerp
from __future__ import annotations
import logging
import os
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sized
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any
@@ -23,6 +25,7 @@ from django import db
from django.core.management import CommandError
from django.db.models import QuerySet
from django_rich.management import RichCommand
from rich import box
from rich.console import Console
from rich.console import Group
from rich.console import RenderableType
@@ -34,6 +37,8 @@ from rich.progress import SpinnerColumn
from rich.progress import TextColumn
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn
from rich.table import Table
from rich.text import Text
if TYPE_CHECKING:
from collections.abc import Generator
@@ -45,6 +50,78 @@ T = TypeVar("T")
R = TypeVar("R")
@dataclass(slots=True, frozen=True)
class _BufferedRecord:
level: int
name: str
message: str
class BufferingLogHandler(logging.Handler):
"""Captures log records during a command run for deferred rendering.
Attach to a logger before a long operation and call ``render()``
afterwards to emit the buffered records via Rich, optionally filtered
by minimum level.
"""
def __init__(self) -> None:
super().__init__()
self._records: list[_BufferedRecord] = []
def emit(self, record: logging.LogRecord) -> None:
self._records.append(
_BufferedRecord(
level=record.levelno,
name=record.name,
message=self.format(record),
),
)
def render(
self,
console: Console,
*,
min_level: int = logging.DEBUG,
title: str = "Log Output",
) -> None:
records = [r for r in self._records if r.level >= min_level]
if not records:
return
table = Table(
title=title,
show_header=True,
header_style="bold",
show_lines=False,
box=box.SIMPLE,
)
table.add_column("Level", style="bold", width=8)
table.add_column("Logger", style="dim")
table.add_column("Message", no_wrap=False)
_level_styles: dict[int, str] = {
logging.DEBUG: "dim",
logging.INFO: "cyan",
logging.WARNING: "yellow",
logging.ERROR: "red",
logging.CRITICAL: "bold red",
}
for record in records:
style = _level_styles.get(record.level, "")
table.add_row(
Text(logging.getLevelName(record.level), style=style),
record.name,
record.message,
)
console.print(table)
def clear(self) -> None:
self._records.clear()
@dataclass(frozen=True, slots=True)
class ProcessResult(Generic[T, R]):
"""
@@ -161,6 +238,46 @@ class PaperlessCommand(RichCommand):
return super().execute(*args, **options)
@contextmanager
def buffered_logging(
self,
*logger_names: str,
level: int = logging.DEBUG,
) -> Generator[BufferingLogHandler, None, None]:
"""Context manager that captures log output from named loggers.
Installs a ``BufferingLogHandler`` on each named logger for the
duration of the block, suppressing propagation to avoid interleaving
with the Rich live display. The handler is removed on exit regardless
of whether an exception occurred.
Usage::
with self.buffered_logging("paperless", "documents") as log_buf:
# ... run progress loop ...
if options["verbose"]:
log_buf.render(self.console)
"""
handler = BufferingLogHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
loggers: list[logging.Logger] = []
original_propagate: dict[str, bool] = {}
for name in logger_names:
log = logging.getLogger(name)
log.addHandler(handler)
original_propagate[name] = log.propagate
log.propagate = False
loggers.append(log)
try:
yield handler
finally:
for log in loggers:
log.removeHandler(handler)
log.propagate = original_propagate[log.name]
@staticmethod
def _progress_columns() -> tuple[Any, ...]:
"""
@@ -381,8 +498,10 @@ class PaperlessCommand(RichCommand):
total = len(items)
if self.process_count == 1:
# Sequential execution in main process - critical for testing, so we don't fork in fork, etc
yield from self._process_sequential(fn, items, description, total)
else:
# Parallel execution with ProcessPoolExecutor
yield from self._process_parallel(fn, items, description, total)
def _process_sequential(
@@ -408,14 +527,18 @@ class PaperlessCommand(RichCommand):
total: int,
) -> Generator[ProcessResult[T, R], None, None]:
"""Process items in parallel using ProcessPoolExecutor."""
# Close database connections before forking - required for PostgreSQL
db.connections.close_all()
with self._create_progress(description) as progress:
task_id = progress.add_task(description, total=total)
with ProcessPoolExecutor(max_workers=self.process_count) as executor:
# Submit all tasks and map futures back to items
future_to_item = {executor.submit(fn, item): item for item in items}
# Yield results as they complete
for future in as_completed(future_to_item):
item = future_to_item[future]
try:
@@ -260,69 +260,74 @@ class Command(PaperlessCommand):
def render_stats() -> RenderableType:
return _build_stats_table(stats, suggest=suggest)
for document in self.track_with_stats(
documents,
description="Retagging...",
stats_renderer=render_stats,
):
suggestion = DocumentSuggestion(document=document)
with self.buffered_logging(
"paperless",
"paperless.handlers",
"documents",
) as log_buf:
for document in self.track_with_stats(
documents,
description="Retagging...",
stats_renderer=render_stats,
):
suggestion = DocumentSuggestion(document=document)
if do_correspondent:
correspondent = set_correspondent(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if correspondent is not None:
stats.correspondents += 1
suggestion.correspondent = correspondent
if do_correspondent:
correspondent = set_correspondent(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if correspondent is not None:
stats.correspondents += 1
suggestion.correspondent = correspondent
if do_document_type:
document_type = set_document_type(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if document_type is not None:
stats.document_types += 1
suggestion.document_type = document_type
if do_document_type:
document_type = set_document_type(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if document_type is not None:
stats.document_types += 1
suggestion.document_type = document_type
if do_tags:
tags_to_add, tags_to_remove = set_tags(
None,
document,
classifier=classifier,
replace=overwrite,
dry_run=suggest,
)
stats.tags_added += len(tags_to_add)
stats.tags_removed += len(tags_to_remove)
suggestion.tags_to_add = frozenset(tags_to_add)
suggestion.tags_to_remove = frozenset(tags_to_remove)
if do_tags:
tags_to_add, tags_to_remove = set_tags(
None,
document,
classifier=classifier,
replace=overwrite,
dry_run=suggest,
)
stats.tags_added += len(tags_to_add)
stats.tags_removed += len(tags_to_remove)
suggestion.tags_to_add = frozenset(tags_to_add)
suggestion.tags_to_remove = frozenset(tags_to_remove)
if do_storage_path:
storage_path = set_storage_path(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if storage_path is not None:
stats.storage_paths += 1
suggestion.storage_path = storage_path
if do_storage_path:
storage_path = set_storage_path(
None,
document,
classifier=classifier,
replace=overwrite,
use_first=use_first,
dry_run=suggest,
)
if storage_path is not None:
stats.storage_paths += 1
suggestion.storage_path = storage_path
stats.documents_processed += 1
stats.documents_processed += 1
if suggest:
suggestions.append(suggestion)
if suggest:
suggestions.append(suggestion)
# Post-loop output
if suggest:
@@ -333,3 +338,5 @@ class Command(PaperlessCommand):
self.console.print("[green]No changes suggested.[/green]")
else:
self.console.print(_build_summary_table(stats))
log_buf.render(self.console, min_level=logging.INFO, title="Retagger Log")