diff --git a/src/documents/management/commands/document_index.py b/src/documents/management/commands/document_index.py index 1fa4f5a70..b46d681b8 100644 --- a/src/documents/management/commands/document_index.py +++ b/src/documents/management/commands/document_index.py @@ -1,12 +1,11 @@ -from django.core.management import BaseCommand from django.db import transaction -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.tasks import index_optimize from documents.tasks import index_reindex -class Command(ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = "Manages the document index." def add_arguments(self, parser): @@ -17,6 +16,11 @@ class Command(ProgressBarMixin, BaseCommand): self.handle_progress_bar_mixin(**options) with transaction.atomic(): if options["command"] == "reindex": - index_reindex(progress_bar_disable=self.no_progress_bar) + index_reindex( + iter_wrapper=lambda docs: self.track( + docs, + description="Indexing documents...", + ), + ) elif options["command"] == "optimize": index_optimize() diff --git a/src/documents/management/commands/document_llmindex.py b/src/documents/management/commands/document_llmindex.py index d2df02ed9..6af1c7c9f 100644 --- a/src/documents/management/commands/document_llmindex.py +++ b/src/documents/management/commands/document_llmindex.py @@ -1,22 +1,22 @@ -from django.core.management import BaseCommand -from django.db import transaction +from typing import Any -from documents.management.commands.mixins import ProgressBarMixin +from documents.management.commands.base import PaperlessCommand from documents.tasks import llmindex_index -class Command(ProgressBarMixin, BaseCommand): +class Command(PaperlessCommand): help = "Manages the LLM-based vector index for Paperless." - def add_arguments(self, parser): + def add_arguments(self, parser: Any) -> None: + super().add_arguments(parser) parser.add_argument("command", choices=["rebuild", "update"]) - self.add_argument_progress_bar_mixin(parser) - def handle(self, *args, **options): - self.handle_progress_bar_mixin(**options) - with transaction.atomic(): - llmindex_index( - progress_bar_disable=self.no_progress_bar, - rebuild=options["command"] == "rebuild", - scheduled=False, - ) + def handle(self, *args: Any, **options: Any) -> None: + llmindex_index( + rebuild=options["command"] == "rebuild", + scheduled=False, + iter_wrapper=lambda docs: self.track( + docs, + description="Indexing documents...", + ), + ) diff --git a/src/documents/tasks.py b/src/documents/tasks.py index cee038072..e600d997b 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -4,11 +4,13 @@ import logging import shutil import uuid import zipfile +from collections.abc import Callable +from collections.abc import Iterable from pathlib import Path from tempfile import TemporaryDirectory from tempfile import mkstemp +from typing import TypeVar -import tqdm from celery import Task from celery import shared_task from celery import states @@ -66,11 +68,19 @@ from paperless_ai.indexing import llm_index_add_or_update_document from paperless_ai.indexing import llm_index_remove_document from paperless_ai.indexing import update_llm_index +_T = TypeVar("_T") +IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] + + if settings.AUDIT_LOG_ENABLED: from auditlog.models import LogEntry logger = logging.getLogger("paperless.tasks") +def _identity(iterable: Iterable[_T]) -> Iterable[_T]: + return iterable + + @shared_task def index_optimize() -> None: ix = index.open_index() @@ -78,13 +88,13 @@ def index_optimize() -> None: writer.commit(optimize=True) -def index_reindex(*, progress_bar_disable=False) -> None: +def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None: documents = Document.objects.all() ix = index.open_index(recreate=True) with AsyncWriter(ix) as writer: - for document in tqdm.tqdm(documents, disable=progress_bar_disable): + for document in iter_wrapper(documents): index.update_document(writer, document) @@ -594,7 +604,7 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None: @shared_task def llmindex_index( *, - progress_bar_disable=True, + iter_wrapper: IterWrapper[Document] = _identity, rebuild=False, scheduled=True, auto=False, @@ -617,7 +627,7 @@ def llmindex_index( try: result = update_llm_index( - progress_bar_disable=progress_bar_disable, + iter_wrapper=iter_wrapper, rebuild=rebuild, ) task.status = states.SUCCESS diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 654c56f3b..53e4a9796 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -1,11 +1,13 @@ import logging import shutil +from collections.abc import Callable +from collections.abc import Iterable from datetime import timedelta from pathlib import Path +from typing import TypeVar import faiss import llama_index.core.settings as llama_settings -import tqdm from celery import states from django.conf import settings from django.utils import timezone @@ -29,6 +31,14 @@ from paperless_ai.embedding import build_llm_index_text from paperless_ai.embedding import get_embedding_dim from paperless_ai.embedding import get_embedding_model +_T = TypeVar("_T") +IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] + + +def _identity(iterable: Iterable[_T]) -> Iterable[_T]: + return iterable + + logger = logging.getLogger("paperless_ai.indexing") @@ -156,7 +166,11 @@ def vector_store_file_exists(): return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists() -def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str: +def update_llm_index( + *, + iter_wrapper: IterWrapper[Document] = _identity, + rebuild=False, +) -> str: """ Rebuild or update the LLM index. """ @@ -176,7 +190,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str: embed_model = get_embedding_model() llama_settings.Settings.embed_model = embed_model storage_context = get_or_create_storage_context(rebuild=True) - for document in tqdm.tqdm(documents, disable=progress_bar_disable): + for document in iter_wrapper(documents): document_nodes = build_document_node(document) nodes.extend(document_nodes) @@ -184,7 +198,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str: nodes=nodes, storage_context=storage_context, embed_model=embed_model, - show_progress=not progress_bar_disable, + show_progress=False, ) msg = "LLM index rebuilt successfully." else: @@ -196,7 +210,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str: for node in index.docstore.get_nodes(all_node_ids) } - for document in tqdm.tqdm(documents, disable=progress_bar_disable): + for document in iter_wrapper(documents): doc_id = str(document.id) document_modified = document.modified.isoformat()