diff --git a/src/documents/management/commands/document_create_classifier.py b/src/documents/management/commands/document_create_classifier.py index b662195a7..3fa7bdb29 100644 --- a/src/documents/management/commands/document_create_classifier.py +++ b/src/documents/management/commands/document_create_classifier.py @@ -22,7 +22,6 @@ class Command(PaperlessCommand): self.buffered_logging("paperless.classifier"), ): train_classifier( - scheduled=False, status_callback=lambda msg: self.console.print(f" {msg}"), ) diff --git a/src/documents/management/commands/document_llmindex.py b/src/documents/management/commands/document_llmindex.py index 3b9e3440b..9823b1b87 100644 --- a/src/documents/management/commands/document_llmindex.py +++ b/src/documents/management/commands/document_llmindex.py @@ -17,7 +17,6 @@ class Command(PaperlessCommand): def handle(self, *args: Any, **options: Any) -> None: llmindex_index( rebuild=options["command"] == "rebuild", - scheduled=False, iter_wrapper=lambda docs: self.track( docs, description="Indexing documents...", diff --git a/src/documents/management/commands/document_sanity_checker.py b/src/documents/management/commands/document_sanity_checker.py index 598ddf7bb..c8921d064 100644 --- a/src/documents/management/commands/document_sanity_checker.py +++ b/src/documents/management/commands/document_sanity_checker.py @@ -111,7 +111,6 @@ class Command(PaperlessCommand): def handle(self, *args: Any, **options: Any) -> None: messages = check_sanity( - scheduled=False, iter_wrapper=lambda docs: self.track( docs, description="Checking documents...", diff --git a/src/documents/sanity_checker.py b/src/documents/sanity_checker.py index ee05fe16c..b824e6683 100644 --- a/src/documents/sanity_checker.py +++ b/src/documents/sanity_checker.py @@ -10,7 +10,6 @@ is an identity function that adds no overhead. """ import logging -import uuid from collections import defaultdict from collections.abc import Iterator from pathlib import Path @@ -19,10 +18,8 @@ from typing import Final from typing import TypedDict from django.conf import settings -from django.utils import timezone from documents.models import Document -from documents.models import PaperlessTask from documents.utils import IterWrapper from documents.utils import compute_checksum from documents.utils import identity @@ -286,33 +283,17 @@ def _check_document( def check_sanity( *, - scheduled: bool = True, iter_wrapper: IterWrapper[Document] = identity, ) -> SanityCheckMessages: """Run a full sanity check on the document archive. Args: - scheduled: Whether this is a scheduled (automatic) or manual check. - Controls the task type recorded in the database. iter_wrapper: A callable that wraps the document iterable, e.g., for progress bar display. Defaults to identity (no wrapping). Returns: A SanityCheckMessages instance containing all detected issues. """ - paperless_task = PaperlessTask.objects.create( - task_id=uuid.uuid4(), - trigger_source=( - PaperlessTask.TriggerSource.SCHEDULED - if scheduled - else PaperlessTask.TriggerSource.MANUAL - ), - task_type=PaperlessTask.TaskType.SANITY_CHECK, - status=PaperlessTask.Status.STARTED, - date_created=timezone.now(), - date_started=timezone.now(), - ) - messages = SanityCheckMessages() present_files = _build_present_files() @@ -331,26 +312,4 @@ def check_sanity( for extra_file in present_files: messages.warning(None, f"Orphaned file in media dir: {extra_file}") - paperless_task.status = ( - PaperlessTask.Status.SUCCESS - if not messages.has_error - else PaperlessTask.Status.FAILURE - ) - if messages.total_issue_count == 0: - paperless_task.result_message = "No issues found." - else: - parts: list[str] = [] - if messages.document_error_count: - parts.append(f"{messages.document_error_count} document(s) with errors") - if messages.document_warning_count: - parts.append(f"{messages.document_warning_count} document(s) with warnings") - if messages.global_warning_count: - parts.append(f"{messages.global_warning_count} global warning(s)") - paperless_task.result_message = ", ".join(parts) + " found." - if messages.has_error: - paperless_task.result_message += " Check logs for details." - - paperless_task.date_done = timezone.now() - paperless_task.save(update_fields=["status", "result_message", "date_done"]) - return messages diff --git a/src/documents/tasks.py b/src/documents/tasks.py index f94596e17..3af067d11 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -40,7 +40,6 @@ from documents.models import Correspondent from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType -from documents.models import PaperlessTask from documents.models import ShareLink from documents.models import ShareLinkBundle from documents.models import StoragePath @@ -83,19 +82,8 @@ def index_optimize() -> None: @shared_task def train_classifier( *, - scheduled=True, status_callback: Callable[[str], None] | None = None, -) -> None: - task = PaperlessTask.objects.create( - trigger_source=PaperlessTask.TriggerSource.SCHEDULED - if scheduled - else PaperlessTask.TriggerSource.MANUAL, - task_id=uuid.uuid4(), - task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER, - status=PaperlessTask.Status.STARTED, - date_created=timezone.now(), - date_started=timezone.now(), - ) +) -> str: if ( not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() and not DocumentType.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() @@ -109,37 +97,22 @@ def train_classifier( if settings.MODEL_FILE.exists(): logger.info(f"Removing {settings.MODEL_FILE} so it won't be used") settings.MODEL_FILE.unlink() - task.status = PaperlessTask.Status.SUCCESS - task.result_message = result - task.date_done = timezone.now() - task.save() - return + return result classifier = load_classifier() if not classifier: classifier = DocumentClassifier() - try: - if classifier.train(status_callback=status_callback): - logger.info( - f"Saving updated classifier model to {settings.MODEL_FILE}...", - ) - classifier.save() - task.result_message = "Training completed successfully" - else: - logger.debug("Training data unchanged.") - task.result_message = "Training data unchanged" - - task.status = PaperlessTask.Status.SUCCESS - - except Exception as e: - logger.warning("Classifier error: " + str(e)) - task.status = PaperlessTask.Status.FAILURE - task.result_message = str(e) - - task.date_done = timezone.now() - task.save(update_fields=["status", "result_message", "date_done"]) + if classifier.train(status_callback=status_callback): + logger.info( + f"Saving updated classifier model to {settings.MODEL_FILE}...", + ) + classifier.save() + return "Training completed successfully" + else: + logger.debug("Training data unchanged.") + return "Training data unchanged" @shared_task(bind=True) @@ -230,8 +203,8 @@ def consume_file( @shared_task -def sanity_check(*, scheduled=True, raise_on_error=True): - messages = sanity_checker.check_sanity(scheduled=scheduled) +def sanity_check(*, raise_on_error: bool = True) -> str: + messages = sanity_checker.check_sanity() messages.log_messages() if not messages.has_error and not messages.has_warning and not messages.has_info: @@ -634,42 +607,19 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None: def llmindex_index( *, iter_wrapper: IterWrapper[Document] = identity, - rebuild=False, - scheduled=True, - auto=False, -) -> None: + rebuild: bool = False, +) -> str | None: ai_config = AIConfig() - if ai_config.llm_index_enabled: - task = PaperlessTask.objects.create( - trigger_source=PaperlessTask.TriggerSource.SCHEDULED - if scheduled - else PaperlessTask.TriggerSource.SYSTEM - if auto - else PaperlessTask.TriggerSource.MANUAL, - task_id=uuid.uuid4(), - task_type=PaperlessTask.TaskType.LLM_INDEX, - status=PaperlessTask.Status.STARTED, - date_created=timezone.now(), - date_started=timezone.now(), - ) - from paperless_ai.indexing import update_llm_index - - try: - result = update_llm_index( - iter_wrapper=iter_wrapper, - rebuild=rebuild, - ) - task.status = PaperlessTask.Status.SUCCESS - task.result_message = result - except Exception as e: - logger.error("LLM index error: " + str(e)) - task.status = PaperlessTask.Status.FAILURE - task.result_message = str(e) - - task.date_done = timezone.now() - task.save(update_fields=["status", "result_message", "date_done"]) - else: + if not ai_config.llm_index_enabled: logger.info("LLM index is disabled, skipping update.") + return None + + from paperless_ai.indexing import update_llm_index + + return update_llm_index( + iter_wrapper=iter_wrapper, + rebuild=rebuild, + ) @shared_task diff --git a/src/documents/tests/test_tasks.py b/src/documents/tests/test_tasks.py index 4502423b3..0db6a9559 100644 --- a/src/documents/tests/test_tasks.py +++ b/src/documents/tests/test_tasks.py @@ -4,7 +4,6 @@ from pathlib import Path from unittest import mock import pytest -from celery import states from django.conf import settings from django.test import TestCase from django.test import override_settings @@ -14,7 +13,6 @@ from documents import tasks from documents.models import Correspondent from documents.models import Document from documents.models import DocumentType -from documents.models import PaperlessTask from documents.models import Tag from documents.sanity_checker import SanityCheckFailedException from documents.sanity_checker import SanityCheckMessages @@ -40,7 +38,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def test_train_classifier_with_auto_tag(self, load_classifier) -> None: load_classifier.return_value = None Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") - tasks.train_classifier() + with self.assertRaises(ValueError): + tasks.train_classifier() load_classifier.assert_called_once() self.assertIsNotFile(settings.MODEL_FILE) @@ -48,7 +47,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def test_train_classifier_with_auto_type(self, load_classifier) -> None: load_classifier.return_value = None DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") - tasks.train_classifier() + with self.assertRaises(ValueError): + tasks.train_classifier() load_classifier.assert_called_once() self.assertIsNotFile(settings.MODEL_FILE) @@ -56,7 +56,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase): def test_train_classifier_with_auto_correspondent(self, load_classifier) -> None: load_classifier.return_value = None Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test") - tasks.train_classifier() + with self.assertRaises(ValueError): + tasks.train_classifier() load_classifier.assert_called_once() self.assertIsNotFile(settings.MODEL_FILE) @@ -298,7 +299,7 @@ class TestAIIndex(DirectoriesMixin, TestCase): WHEN: - llmindex_index task is called THEN: - - update_llm_index is called, and the task is marked as success + - update_llm_index is called and its result is returned """ Document.objects.create( title="test", @@ -308,13 +309,9 @@ class TestAIIndex(DirectoriesMixin, TestCase): # lazy-loaded so mock the actual function with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index: update_llm_index.return_value = "LLM index updated successfully." - tasks.llmindex_index() + result = tasks.llmindex_index() update_llm_index.assert_called_once() - task = PaperlessTask.objects.get( - task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, - ) - self.assertEqual(task.status, states.SUCCESS) - self.assertEqual(task.result, "LLM index updated successfully.") + self.assertEqual(result, "LLM index updated successfully.") @override_settings( AI_ENABLED=True, @@ -325,9 +322,9 @@ class TestAIIndex(DirectoriesMixin, TestCase): GIVEN: - Document exists, AI is enabled, llm index backend is set WHEN: - - llmindex_index task is called + - llmindex_index task is called and update_llm_index raises an exception THEN: - - update_llm_index raises an exception, and the task is marked as failure + - the exception propagates to the caller """ Document.objects.create( title="test", @@ -337,13 +334,9 @@ class TestAIIndex(DirectoriesMixin, TestCase): # lazy-loaded so mock the actual function with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index: update_llm_index.side_effect = Exception("LLM index update failed.") - tasks.llmindex_index() + with self.assertRaises(Exception, msg="LLM index update failed."): + tasks.llmindex_index() update_llm_index.assert_called_once() - task = PaperlessTask.objects.get( - task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, - ) - self.assertEqual(task.status, states.FAILURE) - self.assertIn("LLM index update failed.", task.result) def test_update_document_in_llm_index(self) -> None: """ diff --git a/src/paperless/views.py b/src/paperless/views.py index c9ded4c0d..beec77358 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -427,10 +427,9 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]): and not vector_store_file_exists() ): # AI index was just enabled and vector store file does not exist - llmindex_index.delay( - rebuild=True, - scheduled=False, - auto=True, + llmindex_index.apply_async( + kwargs={"rebuild": True}, + headers={"trigger_source": "system"}, ) diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index d9b1a7f90..81f2bff3f 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -37,7 +37,10 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool: if has_running or has_recent: return False - llmindex_index.delay(rebuild=rebuild, scheduled=False, auto=True) + llmindex_index.apply_async( + kwargs={"rebuild": rebuild}, + headers={"trigger_source": "system"}, + ) logger.warning( "Queued LLM index update%s: %s", " (rebuild)" if rebuild else "",