mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-15 20:48:53 +00:00
refactor(tasks): remove manual PaperlessTask creation and scheduled/auto params
All task records are now created exclusively via Celery signals (Task 2).

- Removed PaperlessTask creation/update from train_classifier, sanity_check, llmindex_index, and check_sanity.
- Removed the scheduled= and auto= parameters from all 7 call sites.
- Callers that previously passed scheduled=/auto= through .delay() now call apply_async() with a trigger_source header instead.
- Exceptions now propagate naturally from the task functions instead of being caught and recorded on a manually created PaperlessTask.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -22,7 +22,6 @@ class Command(PaperlessCommand):
|
||||
self.buffered_logging("paperless.classifier"),
|
||||
):
|
||||
train_classifier(
|
||||
scheduled=False,
|
||||
status_callback=lambda msg: self.console.print(f" {msg}"),
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ class Command(PaperlessCommand):
|
||||
def handle(self, *args: Any, **options: Any) -> None:
|
||||
llmindex_index(
|
||||
rebuild=options["command"] == "rebuild",
|
||||
scheduled=False,
|
||||
iter_wrapper=lambda docs: self.track(
|
||||
docs,
|
||||
description="Indexing documents...",
|
||||
|
||||
@@ -111,7 +111,6 @@ class Command(PaperlessCommand):
|
||||
|
||||
def handle(self, *args: Any, **options: Any) -> None:
|
||||
messages = check_sanity(
|
||||
scheduled=False,
|
||||
iter_wrapper=lambda docs: self.track(
|
||||
docs,
|
||||
description="Checking documents...",
|
||||
|
||||
@@ -10,7 +10,6 @@ is an identity function that adds no overhead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
@@ -19,10 +18,8 @@ from typing import Final
|
||||
from typing import TypedDict
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
from documents.utils import IterWrapper
|
||||
from documents.utils import compute_checksum
|
||||
from documents.utils import identity
|
||||
@@ -286,33 +283,17 @@ def _check_document(
|
||||
|
||||
def check_sanity(
|
||||
*,
|
||||
scheduled: bool = True,
|
||||
iter_wrapper: IterWrapper[Document] = identity,
|
||||
) -> SanityCheckMessages:
|
||||
"""Run a full sanity check on the document archive.
|
||||
|
||||
Args:
|
||||
scheduled: Whether this is a scheduled (automatic) or manual check.
|
||||
Controls the task type recorded in the database.
|
||||
iter_wrapper: A callable that wraps the document iterable, e.g.,
|
||||
for progress bar display. Defaults to identity (no wrapping).
|
||||
|
||||
Returns:
|
||||
A SanityCheckMessages instance containing all detected issues.
|
||||
"""
|
||||
paperless_task = PaperlessTask.objects.create(
|
||||
task_id=uuid.uuid4(),
|
||||
trigger_source=(
|
||||
PaperlessTask.TriggerSource.SCHEDULED
|
||||
if scheduled
|
||||
else PaperlessTask.TriggerSource.MANUAL
|
||||
),
|
||||
task_type=PaperlessTask.TaskType.SANITY_CHECK,
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
|
||||
messages = SanityCheckMessages()
|
||||
present_files = _build_present_files()
|
||||
|
||||
@@ -331,26 +312,4 @@ def check_sanity(
|
||||
for extra_file in present_files:
|
||||
messages.warning(None, f"Orphaned file in media dir: {extra_file}")
|
||||
|
||||
paperless_task.status = (
|
||||
PaperlessTask.Status.SUCCESS
|
||||
if not messages.has_error
|
||||
else PaperlessTask.Status.FAILURE
|
||||
)
|
||||
if messages.total_issue_count == 0:
|
||||
paperless_task.result_message = "No issues found."
|
||||
else:
|
||||
parts: list[str] = []
|
||||
if messages.document_error_count:
|
||||
parts.append(f"{messages.document_error_count} document(s) with errors")
|
||||
if messages.document_warning_count:
|
||||
parts.append(f"{messages.document_warning_count} document(s) with warnings")
|
||||
if messages.global_warning_count:
|
||||
parts.append(f"{messages.global_warning_count} global warning(s)")
|
||||
paperless_task.result_message = ", ".join(parts) + " found."
|
||||
if messages.has_error:
|
||||
paperless_task.result_message += " Check logs for details."
|
||||
|
||||
paperless_task.date_done = timezone.now()
|
||||
paperless_task.save(update_fields=["status", "result_message", "date_done"])
|
||||
|
||||
return messages
|
||||
|
||||
@@ -40,7 +40,6 @@ from documents.models import Correspondent
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import PaperlessTask
|
||||
from documents.models import ShareLink
|
||||
from documents.models import ShareLinkBundle
|
||||
from documents.models import StoragePath
|
||||
@@ -83,19 +82,8 @@ def index_optimize() -> None:
|
||||
@shared_task
|
||||
def train_classifier(
|
||||
*,
|
||||
scheduled=True,
|
||||
status_callback: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
task = PaperlessTask.objects.create(
|
||||
trigger_source=PaperlessTask.TriggerSource.SCHEDULED
|
||||
if scheduled
|
||||
else PaperlessTask.TriggerSource.MANUAL,
|
||||
task_id=uuid.uuid4(),
|
||||
task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER,
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
) -> str:
|
||||
if (
|
||||
not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
and not DocumentType.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
@@ -109,37 +97,22 @@ def train_classifier(
|
||||
if settings.MODEL_FILE.exists():
|
||||
logger.info(f"Removing {settings.MODEL_FILE} so it won't be used")
|
||||
settings.MODEL_FILE.unlink()
|
||||
task.status = PaperlessTask.Status.SUCCESS
|
||||
task.result_message = result
|
||||
task.date_done = timezone.now()
|
||||
task.save()
|
||||
return
|
||||
return result
|
||||
|
||||
classifier = load_classifier()
|
||||
|
||||
if not classifier:
|
||||
classifier = DocumentClassifier()
|
||||
|
||||
try:
|
||||
if classifier.train(status_callback=status_callback):
|
||||
logger.info(
|
||||
f"Saving updated classifier model to {settings.MODEL_FILE}...",
|
||||
)
|
||||
classifier.save()
|
||||
task.result_message = "Training completed successfully"
|
||||
else:
|
||||
logger.debug("Training data unchanged.")
|
||||
task.result_message = "Training data unchanged"
|
||||
|
||||
task.status = PaperlessTask.Status.SUCCESS
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Classifier error: " + str(e))
|
||||
task.status = PaperlessTask.Status.FAILURE
|
||||
task.result_message = str(e)
|
||||
|
||||
task.date_done = timezone.now()
|
||||
task.save(update_fields=["status", "result_message", "date_done"])
|
||||
if classifier.train(status_callback=status_callback):
|
||||
logger.info(
|
||||
f"Saving updated classifier model to {settings.MODEL_FILE}...",
|
||||
)
|
||||
classifier.save()
|
||||
return "Training completed successfully"
|
||||
else:
|
||||
logger.debug("Training data unchanged.")
|
||||
return "Training data unchanged"
|
||||
|
||||
|
||||
@shared_task(bind=True)
|
||||
@@ -230,8 +203,8 @@ def consume_file(
|
||||
|
||||
|
||||
@shared_task
|
||||
def sanity_check(*, scheduled=True, raise_on_error=True):
|
||||
messages = sanity_checker.check_sanity(scheduled=scheduled)
|
||||
def sanity_check(*, raise_on_error: bool = True) -> str:
|
||||
messages = sanity_checker.check_sanity()
|
||||
messages.log_messages()
|
||||
|
||||
if not messages.has_error and not messages.has_warning and not messages.has_info:
|
||||
@@ -634,42 +607,19 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:
|
||||
def llmindex_index(
|
||||
*,
|
||||
iter_wrapper: IterWrapper[Document] = identity,
|
||||
rebuild=False,
|
||||
scheduled=True,
|
||||
auto=False,
|
||||
) -> None:
|
||||
rebuild: bool = False,
|
||||
) -> str | None:
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled:
|
||||
task = PaperlessTask.objects.create(
|
||||
trigger_source=PaperlessTask.TriggerSource.SCHEDULED
|
||||
if scheduled
|
||||
else PaperlessTask.TriggerSource.SYSTEM
|
||||
if auto
|
||||
else PaperlessTask.TriggerSource.MANUAL,
|
||||
task_id=uuid.uuid4(),
|
||||
task_type=PaperlessTask.TaskType.LLM_INDEX,
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
from paperless_ai.indexing import update_llm_index
|
||||
|
||||
try:
|
||||
result = update_llm_index(
|
||||
iter_wrapper=iter_wrapper,
|
||||
rebuild=rebuild,
|
||||
)
|
||||
task.status = PaperlessTask.Status.SUCCESS
|
||||
task.result_message = result
|
||||
except Exception as e:
|
||||
logger.error("LLM index error: " + str(e))
|
||||
task.status = PaperlessTask.Status.FAILURE
|
||||
task.result_message = str(e)
|
||||
|
||||
task.date_done = timezone.now()
|
||||
task.save(update_fields=["status", "result_message", "date_done"])
|
||||
else:
|
||||
if not ai_config.llm_index_enabled:
|
||||
logger.info("LLM index is disabled, skipping update.")
|
||||
return None
|
||||
|
||||
from paperless_ai.indexing import update_llm_index
|
||||
|
||||
return update_llm_index(
|
||||
iter_wrapper=iter_wrapper,
|
||||
rebuild=rebuild,
|
||||
)
|
||||
|
||||
|
||||
@shared_task
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings
|
||||
@@ -14,7 +13,6 @@ from documents import tasks
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import PaperlessTask
|
||||
from documents.models import Tag
|
||||
from documents.sanity_checker import SanityCheckFailedException
|
||||
from documents.sanity_checker import SanityCheckMessages
|
||||
@@ -40,7 +38,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
def test_train_classifier_with_auto_tag(self, load_classifier) -> None:
|
||||
load_classifier.return_value = None
|
||||
Tag.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||
tasks.train_classifier()
|
||||
with self.assertRaises(ValueError):
|
||||
tasks.train_classifier()
|
||||
load_classifier.assert_called_once()
|
||||
self.assertIsNotFile(settings.MODEL_FILE)
|
||||
|
||||
@@ -48,7 +47,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
def test_train_classifier_with_auto_type(self, load_classifier) -> None:
|
||||
load_classifier.return_value = None
|
||||
DocumentType.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||
tasks.train_classifier()
|
||||
with self.assertRaises(ValueError):
|
||||
tasks.train_classifier()
|
||||
load_classifier.assert_called_once()
|
||||
self.assertIsNotFile(settings.MODEL_FILE)
|
||||
|
||||
@@ -56,7 +56,8 @@ class TestClassifier(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
def test_train_classifier_with_auto_correspondent(self, load_classifier) -> None:
|
||||
load_classifier.return_value = None
|
||||
Correspondent.objects.create(matching_algorithm=Tag.MATCH_AUTO, name="test")
|
||||
tasks.train_classifier()
|
||||
with self.assertRaises(ValueError):
|
||||
tasks.train_classifier()
|
||||
load_classifier.assert_called_once()
|
||||
self.assertIsNotFile(settings.MODEL_FILE)
|
||||
|
||||
@@ -298,7 +299,7 @@ class TestAIIndex(DirectoriesMixin, TestCase):
|
||||
WHEN:
|
||||
- llmindex_index task is called
|
||||
THEN:
|
||||
- update_llm_index is called, and the task is marked as success
|
||||
- update_llm_index is called and its result is returned
|
||||
"""
|
||||
Document.objects.create(
|
||||
title="test",
|
||||
@@ -308,13 +309,9 @@ class TestAIIndex(DirectoriesMixin, TestCase):
|
||||
# lazy-loaded so mock the actual function
|
||||
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
|
||||
update_llm_index.return_value = "LLM index updated successfully."
|
||||
tasks.llmindex_index()
|
||||
result = tasks.llmindex_index()
|
||||
update_llm_index.assert_called_once()
|
||||
task = PaperlessTask.objects.get(
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
self.assertEqual(task.status, states.SUCCESS)
|
||||
self.assertEqual(task.result, "LLM index updated successfully.")
|
||||
self.assertEqual(result, "LLM index updated successfully.")
|
||||
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
@@ -325,9 +322,9 @@ class TestAIIndex(DirectoriesMixin, TestCase):
|
||||
GIVEN:
|
||||
- Document exists, AI is enabled, llm index backend is set
|
||||
WHEN:
|
||||
- llmindex_index task is called
|
||||
- llmindex_index task is called and update_llm_index raises an exception
|
||||
THEN:
|
||||
- update_llm_index raises an exception, and the task is marked as failure
|
||||
- the exception propagates to the caller
|
||||
"""
|
||||
Document.objects.create(
|
||||
title="test",
|
||||
@@ -337,13 +334,9 @@ class TestAIIndex(DirectoriesMixin, TestCase):
|
||||
# lazy-loaded so mock the actual function
|
||||
with mock.patch("paperless_ai.indexing.update_llm_index") as update_llm_index:
|
||||
update_llm_index.side_effect = Exception("LLM index update failed.")
|
||||
tasks.llmindex_index()
|
||||
with self.assertRaises(Exception, msg="LLM index update failed."):
|
||||
tasks.llmindex_index()
|
||||
update_llm_index.assert_called_once()
|
||||
task = PaperlessTask.objects.get(
|
||||
task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE,
|
||||
)
|
||||
self.assertEqual(task.status, states.FAILURE)
|
||||
self.assertIn("LLM index update failed.", task.result)
|
||||
|
||||
def test_update_document_in_llm_index(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -427,10 +427,9 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
|
||||
and not vector_store_file_exists()
|
||||
):
|
||||
# AI index was just enabled and vector store file does not exist
|
||||
llmindex_index.delay(
|
||||
rebuild=True,
|
||||
scheduled=False,
|
||||
auto=True,
|
||||
llmindex_index.apply_async(
|
||||
kwargs={"rebuild": True},
|
||||
headers={"trigger_source": "system"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,10 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
if has_running or has_recent:
|
||||
return False
|
||||
|
||||
llmindex_index.delay(rebuild=rebuild, scheduled=False, auto=True)
|
||||
llmindex_index.apply_async(
|
||||
kwargs={"rebuild": rebuild},
|
||||
headers={"trigger_source": "system"},
|
||||
)
|
||||
logger.warning(
|
||||
"Queued LLM index update%s: %s",
|
||||
" (rebuild)" if rebuild else "",
|
||||
|
||||
Reference in New Issue
Block a user