From 9e8b5ddf0858625af71accb2eb636271373df537 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:26:49 -0700 Subject: [PATCH] Refactor: consolidate IterWrapper/identity into documents.utils Move the duplicated `IterWrapper` type alias and `identity` function from tasks.py, _backend.py, sanity_checker.py, and paperless_ai/indexing.py into a single location in documents/utils.py. All four callers now import from there. Co-Authored-By: Claude Sonnet 4.6 --- src/documents/sanity_checker.py | 17 +++-------------- src/documents/search/_backend.py | 9 ++------- src/documents/tasks.py | 14 +++----------- src/documents/utils.py | 13 +++++++++++++ src/paperless_ai/indexing.py | 14 +++----------- 5 files changed, 24 insertions(+), 43 deletions(-) diff --git a/src/documents/sanity_checker.py b/src/documents/sanity_checker.py index b53ed8cfb..0b3dea368 100644 --- a/src/documents/sanity_checker.py +++ b/src/documents/sanity_checker.py @@ -9,19 +9,14 @@ to wrap the document queryset (e.g., with a progress bar). The default is an identity function that adds no overhead. """ -from __future__ import annotations - import logging import uuid from collections import defaultdict -from collections.abc import Callable -from collections.abc import Iterable from collections.abc import Iterator from pathlib import Path from typing import TYPE_CHECKING from typing import Final from typing import TypedDict -from typing import TypeVar from celery import states from django.conf import settings @@ -29,14 +24,13 @@ from django.utils import timezone from documents.models import Document from documents.models import PaperlessTask +from documents.utils import IterWrapper from documents.utils import compute_checksum +from documents.utils import identity from paperless.config import GeneralConfig logger = logging.getLogger("paperless.sanity_checker") -_T = TypeVar("_T") -IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] - class MessageEntry(TypedDict): """A single sanity check message with its severity level.""" @@ -45,11 +39,6 @@ class MessageEntry(TypedDict): message: str -def _identity(iterable: Iterable[_T]) -> Iterable[_T]: - """Pass through an iterable unchanged (default iter_wrapper).""" - return iterable - - class SanityCheckMessages: """Collects sanity check messages grouped by document primary key. @@ -296,7 +285,7 @@ def _check_document( def check_sanity( *, scheduled: bool = True, - iter_wrapper: IterWrapper[Document] = _identity, + iter_wrapper: IterWrapper[Document] = identity, ) -> SanityCheckMessages: """Run a full sanity check on the document archive. diff --git a/src/documents/search/_backend.py b/src/documents/search/_backend.py index 6bee65f1e..d263c7175 100644 --- a/src/documents/search/_backend.py +++ b/src/documents/search/_backend.py @@ -26,10 +26,10 @@ from documents.search._schema import build_schema from documents.search._schema import open_or_rebuild_index from documents.search._schema import wipe_index from documents.search._tokenizer import register_tokenizers +from documents.utils import identity if TYPE_CHECKING: from collections.abc import Callable - from collections.abc import Iterable from pathlib import Path from django.contrib.auth.base_user import AbstractBaseUser @@ -45,11 +45,6 @@ _AUTOCOMPLETE_REGEX_TIMEOUT = 1.0 # seconds; guards against ReDoS on untrusted T = TypeVar("T") -def _identity(iterable: Iterable[T]) -> Iterable[T]: - """Default iter_wrapper that passes documents through unchanged for indexing.""" - return iterable - - def _ascii_fold(s: str) -> str: """ Normalize unicode to ASCII equivalent characters for search consistency. @@ -764,7 +759,7 @@ class TantivyBackend: self._ensure_open() return WriteBatch(self, lock_timeout) - def rebuild(self, documents: QuerySet, iter_wrapper: Callable = _identity) -> None: + def rebuild(self, documents: QuerySet, iter_wrapper: Callable = identity) -> None: """ Rebuild the entire search index from scratch. diff --git a/src/documents/tasks.py b/src/documents/tasks.py index f86837c95..ae65a5fbe 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -4,11 +4,9 @@ import shutil import uuid import zipfile from collections.abc import Callable -from collections.abc import Iterable from pathlib import Path from tempfile import TemporaryDirectory from tempfile import mkstemp -from typing import TypeVar from celery import Task from celery import shared_task @@ -58,7 +56,9 @@ from documents.signals import document_updated from documents.signals.handlers import cleanup_document_deletion from documents.signals.handlers import run_workflows from documents.signals.handlers import send_websocket_document_updated +from documents.utils import IterWrapper from documents.utils import compute_checksum +from documents.utils import identity from documents.workflows.utils import get_workflows_for_trigger from paperless.config import AIConfig from paperless.parsers import ParserContext @@ -67,19 +67,11 @@ from paperless_ai.indexing import llm_index_add_or_update_document from paperless_ai.indexing import llm_index_remove_document from paperless_ai.indexing import update_llm_index -_T = TypeVar("_T") -IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] - - if settings.AUDIT_LOG_ENABLED: from auditlog.models import LogEntry logger = logging.getLogger("paperless.tasks") -def _identity(iterable: Iterable[_T]) -> Iterable[_T]: - return iterable - - @shared_task def index_optimize() -> None: logger.info( @@ -622,7 +614,7 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None: @shared_task def llmindex_index( *, - iter_wrapper: IterWrapper[Document] = _identity, + iter_wrapper: IterWrapper[Document] = identity, rebuild=False, scheduled=True, auto=False, diff --git a/src/documents/utils.py b/src/documents/utils.py index 975185a5f..2ed6758dd 100644 --- a/src/documents/utils.py +++ b/src/documents/utils.py @@ -1,14 +1,27 @@ import hashlib import logging import shutil +from collections.abc import Callable +from collections.abc import Iterable from os import utime from pathlib import Path from subprocess import CompletedProcess from subprocess import run +from typing import TypeVar from django.conf import settings from PIL import Image +_T = TypeVar("_T") + +# A function that wraps an iterable — typically used to inject a progress bar. +IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] + + +def identity(iterable: Iterable[_T]) -> Iterable[_T]: + """Return the iterable unchanged; the no-op default for IterWrapper.""" + return iterable + def _coerce_to_path( source: Path | str, diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index bee8f0dd9..a54492f1f 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -1,11 +1,8 @@ import logging import shutil -from collections.abc import Callable -from collections.abc import Iterable from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING -from typing import TypeVar from celery import states from django.conf import settings @@ -13,22 +10,17 @@ from django.utils import timezone from documents.models import Document from documents.models import PaperlessTask +from documents.utils import IterWrapper +from documents.utils import identity from paperless_ai.embedding import build_llm_index_text from paperless_ai.embedding import get_embedding_dim from paperless_ai.embedding import get_embedding_model -_T = TypeVar("_T") -IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] - if TYPE_CHECKING: from llama_index.core import VectorStoreIndex from llama_index.core.schema import BaseNode -def _identity(iterable: Iterable[_T]) -> Iterable[_T]: - return iterable - - logger = logging.getLogger("paperless_ai.indexing") @@ -176,7 +168,7 @@ def vector_store_file_exists(): def update_llm_index( *, - iter_wrapper: IterWrapper[Document] = _identity, + iter_wrapper: IterWrapper[Document] = identity, rebuild=False, ) -> str: """