mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-16 04:58:53 +00:00
feat(tasks): rewrite signal handlers to track all task types
Replace the old consume_file-only handler with a full rewrite that tracks 6 task types (consume_file, train_classifier, sanity_check, index_optimize, llm_index, mail_fetch) with proper trigger source detection, input data extraction, legacy result string parsing, duration/wait time recording, and structured error capture on failure. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import states
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_failure
|
||||
from celery.signals import task_postrun
|
||||
@@ -31,6 +30,7 @@ from documents import matching
|
||||
from documents.caching import clear_document_caches
|
||||
from documents.caching import invalidate_llm_suggestions_cache
|
||||
from documents.data_models import ConsumableDocument
|
||||
from documents.data_models import DocumentSource
|
||||
from documents.file_handling import create_source_path_directory
|
||||
from documents.file_handling import delete_empty_directories
|
||||
from documents.file_handling import generate_filename
|
||||
@@ -996,67 +996,175 @@ def run_workflows(
|
||||
return overrides, "\n".join(messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task tracking -- Celery signal handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps the Celery task name found in the message's "task" header to the
# PaperlessTask.TaskType recorded for it. before_task_publish_handler looks
# names up here and returns without creating a record when there is no entry.
TRACKED_TASKS: dict[str, PaperlessTask.TaskType] = {
    "documents.tasks.consume_file": PaperlessTask.TaskType.CONSUME_FILE,
    "documents.tasks.train_classifier": PaperlessTask.TaskType.TRAIN_CLASSIFIER,
    "documents.tasks.sanity_check": PaperlessTask.TaskType.SANITY_CHECK,
    "documents.tasks.index_optimize": PaperlessTask.TaskType.INDEX_OPTIMIZE,
    "documents.tasks.llmindex_index": PaperlessTask.TaskType.LLM_INDEX,
    "paperless_mail.tasks.process_mail_accounts": PaperlessTask.TaskType.MAIL_FETCH,
}
|
||||
|
||||
# Maps the DocumentSource of a document handed to consume_file to the trigger
# source stored on its PaperlessTask. _determine_trigger_source falls back to
# API_UPLOAD for any source that is not a key of this mapping.
_DOCUMENT_SOURCE_TO_TRIGGER: dict[Any, PaperlessTask.TriggerSource] = {
    DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME,
    DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD,
    DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME,
    DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI,
}
|
||||
|
||||
|
||||
def _extract_input_data(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
) -> dict:
    """
    Build the ``input_data`` payload for a task about to be published.

    Only two task types carry input data:

    - MAIL_FETCH: ``{"account_ids": ...}``, taken from the first positional
      argument or the ``account_ids`` keyword argument.
    - CONSUME_FILE: filename and MIME type of the input document, plus
      ``source_path``, ``mailrule_id`` and ``overrides`` when they are set.

    Every other task type, and a consume_file call without an input document,
    yields an empty dict.
    """
    if task_type == PaperlessTask.TaskType.MAIL_FETCH:
        return {
            "account_ids": args[0] if args else task_kwargs.get("account_ids"),
        }

    if task_type != PaperlessTask.TaskType.CONSUME_FILE:
        return {}

    # consume_file may be called positionally or by keyword; accept both.
    document = args[0] if args else task_kwargs.get("input_doc")
    metadata_overrides = task_kwargs.get("overrides") if len(args) < 2 else args[1]
    if document is None:
        return {}

    payload: dict = {
        "filename": document.original_file.name,
        "mime_type": document.mime_type,
    }
    if document.original_path:
        payload["source_path"] = str(document.original_path)
    if document.mailrule_id:
        payload["mailrule_id"] = document.mailrule_id

    if metadata_overrides:
        # Keep only public attributes that were actually set.
        # NOTE(review): values are copied as-is; confirm that every override
        # field can be serialized by the field input_data is stored in.
        set_overrides = {}
        for field_name, field_value in vars(metadata_overrides).items():
            if field_value is None or field_name.startswith("_"):
                continue
            set_overrides[field_name] = field_value
        if set_overrides:
            payload["overrides"] = set_overrides

    return payload
|
||||
|
||||
|
||||
def _determine_trigger_source(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
    headers: dict,
) -> PaperlessTask.TriggerSource:
    """
    Work out what triggered a task that is about to be published.

    Resolution order:

    1. An explicit ``trigger_source`` header of ``"scheduled"`` or
       ``"system"`` wins over everything else.
    2. For consume_file, the source of the input document is translated via
       ``_DOCUMENT_SOURCE_TO_TRIGGER``; unmapped sources become API_UPLOAD.
    3. Anything else is reported as MANUAL.
    """
    requested = headers.get("trigger_source")
    if requested == "scheduled":
        return PaperlessTask.TriggerSource.SCHEDULED
    if requested == "system":
        return PaperlessTask.TriggerSource.SYSTEM

    if task_type != PaperlessTask.TaskType.CONSUME_FILE:
        return PaperlessTask.TriggerSource.MANUAL

    # The document may arrive positionally or as the input_doc keyword.
    document = args[0] if args else task_kwargs.get("input_doc")
    if document is None:
        return PaperlessTask.TriggerSource.MANUAL

    return _DOCUMENT_SOURCE_TO_TRIGGER.get(
        document.source,
        PaperlessTask.TriggerSource.API_UPLOAD,
    )
|
||||
|
||||
|
||||
def _extract_owner_id(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
) -> int | None:
    """
    Return the owner id carried by the overrides of a consume_file call.

    The overrides object is the second positional argument or the
    ``overrides`` keyword argument. ``None`` is returned for every other task
    type, when no overrides were passed, or when they have no ``owner_id``
    attribute.
    """
    if task_type == PaperlessTask.TaskType.CONSUME_FILE:
        overrides = task_kwargs.get("overrides") if len(args) < 2 else args[1]
        if overrides:
            return getattr(overrides, "owner_id", None)
    return None
|
||||
|
||||
|
||||
def _parse_legacy_result(result: str) -> dict | None:
|
||||
import re as _re
|
||||
|
||||
if match := _re.search(r"New document id (\d+) created", result):
|
||||
return {"document_id": int(match.group(1))}
|
||||
if match := _re.search(r"It is a duplicate of .* \(#(\d+)\)", result):
|
||||
return {
|
||||
"duplicate_of": int(match.group(1)),
|
||||
"duplicate_in_trash": "existing document is in the trash" in result,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@before_task_publish.connect
def before_task_publish_handler(
    sender=None,
    headers=None,
    body=None,
    **kwargs,
) -> None:
    """
    Creates the PaperlessTask record when the task is published to broker.

    Only tasks whose name is a key of ``TRACKED_TASKS`` get a record; any
    other task, or a call without headers or body, is ignored. The record is
    created in the PENDING state together with its trigger source, input data
    and owner.

    ``body`` is unpacked as a three-item sequence whose first two items are
    the positional and keyword arguments of the task.

    Failures are logged and swallowed so that a problem with task tracking
    never prevents the task itself from being published.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#before-task-publish

    https://docs.celeryq.dev/en/stable/internals/protocol.html#version-2

    """
    if headers is None or body is None:
        return

    task_name = headers.get("task", "")
    task_type = TRACKED_TASKS.get(task_name)
    if task_type is None:
        return

    try:
        close_old_connections()
        args, task_kwargs, _ = body
        task_id = headers["id"]

        input_data = _extract_input_data(task_type, args, task_kwargs)
        trigger_source = _determine_trigger_source(
            task_type,
            args,
            task_kwargs,
            headers,
        )
        owner_id = _extract_owner_id(task_type, args, task_kwargs)

        PaperlessTask.objects.create(
            task_id=task_id,
            task_type=task_type,
            trigger_source=trigger_source,
            status=PaperlessTask.Status.PENDING,
            input_data=input_data,
            owner_id=owner_id,
        )
    except Exception:
        logger.exception("Creating PaperlessTask failed")
|
||||
|
||||
|
||||
@task_prerun.connect
def task_prerun_handler(sender=None, task_id=None, task=None, **kwargs) -> None:
    """
    Marks the task STARTED when execution begins on a worker.

    The status and start time are written with a single queryset update, so
    a task id without a matching PaperlessTask row is a no-op. Failures are
    logged and swallowed so that tracking never stops the task from running.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#task-prerun
    """
    if task_id is None:
        return
    try:
        close_old_connections()
        PaperlessTask.objects.filter(task_id=task_id).update(
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )
    except Exception:
        logger.exception("Setting PaperlessTask started failed")
|
||||
|
||||
|
||||
@@ -1070,33 +1178,53 @@ def task_postrun_handler(
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the result of the PaperlessTask.
|
||||
Records task completion and result data.
|
||||
|
||||
https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun
|
||||
"""
|
||||
if task_id is None:
|
||||
return
|
||||
try:
|
||||
close_old_connections()
|
||||
task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
|
||||
|
||||
if task_instance is not None:
|
||||
_CELERY_STATE_MAP = {
|
||||
states.SUCCESS: PaperlessTask.Status.SUCCESS,
|
||||
states.FAILURE: PaperlessTask.Status.FAILURE,
|
||||
states.REVOKED: PaperlessTask.Status.REVOKED,
|
||||
states.STARTED: PaperlessTask.Status.STARTED,
|
||||
states.PENDING: PaperlessTask.Status.PENDING,
|
||||
}
|
||||
task_instance.status = _CELERY_STATE_MAP.get(
|
||||
state,
|
||||
PaperlessTask.Status.FAILURE,
|
||||
)
|
||||
if isinstance(retval, str):
|
||||
task_instance.result_message = retval
|
||||
task_instance.date_done = timezone.now()
|
||||
task_instance.save()
|
||||
except Exception: # pragma: no cover
|
||||
# Don't let an exception in the signal handlers prevent
|
||||
# a document from being consumed.
|
||||
status_map = {
|
||||
"SUCCESS": PaperlessTask.Status.SUCCESS,
|
||||
"FAILURE": PaperlessTask.Status.FAILURE,
|
||||
"REVOKED": PaperlessTask.Status.REVOKED,
|
||||
}
|
||||
new_status = status_map.get(state, PaperlessTask.Status.FAILURE)
|
||||
|
||||
result_data: dict | None = None
|
||||
result_message: str | None = None
|
||||
if isinstance(retval, dict):
|
||||
result_data = retval
|
||||
elif isinstance(retval, str):
|
||||
result_message = retval
|
||||
result_data = _parse_legacy_result(retval)
|
||||
|
||||
now = timezone.now()
|
||||
task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
|
||||
if task_instance is None:
|
||||
return
|
||||
|
||||
duration_seconds: float | None = None
|
||||
wait_time_seconds: float | None = None
|
||||
if task_instance.date_started:
|
||||
duration_seconds = (now - task_instance.date_started).total_seconds()
|
||||
if task_instance.date_started and task_instance.date_created:
|
||||
wait_time_seconds = (
|
||||
task_instance.date_started - task_instance.date_created
|
||||
).total_seconds()
|
||||
|
||||
PaperlessTask.objects.filter(task_id=task_id).update(
|
||||
status=new_status,
|
||||
result_data=result_data,
|
||||
result_message=result_message,
|
||||
date_done=now,
|
||||
duration_seconds=duration_seconds,
|
||||
wait_time_seconds=wait_time_seconds,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
|
||||
|
||||
@@ -1110,21 +1238,33 @@ def task_failure_handler(
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the result of a failed PaperlessTask.
|
||||
Records failure details when a task raises an exception.
|
||||
|
||||
https://docs.celeryq.dev/en/stable/userguide/signals.html#task-failure
|
||||
"""
|
||||
if task_id is None:
|
||||
return
|
||||
try:
|
||||
close_old_connections()
|
||||
task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
|
||||
|
||||
if task_instance is not None and task_instance.result_message is None:
|
||||
task_instance.status = PaperlessTask.Status.FAILURE
|
||||
task_instance.result_message = str(traceback) if traceback else None
|
||||
task_instance.date_done = timezone.now()
|
||||
task_instance.save()
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Updating PaperlessTask failed")
|
||||
result_data: dict = {
|
||||
"error_type": type(exception).__name__ if exception else "Unknown",
|
||||
"error_message": str(exception) if exception else "Unknown error",
|
||||
}
|
||||
if traceback:
|
||||
import traceback as _tb
|
||||
|
||||
tb_str = "".join(_tb.format_tb(traceback))
|
||||
result_data["traceback"] = tb_str[:5000]
|
||||
|
||||
PaperlessTask.objects.filter(task_id=task_id).update(
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
result_data=result_data,
|
||||
result_message=str(exception) if exception else None,
|
||||
date_done=timezone.now(),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Updating PaperlessTask on failure failed")
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
|
||||
@@ -1,250 +1,266 @@
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import celery
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase
|
||||
import pytest
|
||||
|
||||
from documents.data_models import ConsumableDocument
|
||||
from documents.data_models import DocumentMetadataOverrides
|
||||
from documents.data_models import DocumentSource
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
from documents.signals.handlers import add_to_index
|
||||
from documents.signals.handlers import before_task_publish_handler
|
||||
from documents.signals.handlers import task_failure_handler
|
||||
from documents.signals.handlers import task_postrun_handler
|
||||
from documents.signals.handlers import task_prerun_handler
|
||||
from documents.tests.test_consumer import fake_magic_from_file
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
|
||||
class TestTaskSignalHandler(DirectoriesMixin, TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls) -> None:
|
||||
super().setUpTestData()
|
||||
cls.user = get_user_model().objects.create_user(username="testuser")
|
||||
@pytest.fixture
def consume_input_doc():
    """
    Provide a mocked ConsumableDocument for "invoice.pdf" coming from the web UI.
    """
    # original_file is a Path; give it its own mock so that .name works
    original_file = mock.MagicMock()
    original_file.name = "invoice.pdf"

    doc = mock.MagicMock(spec=ConsumableDocument)
    doc.original_file = original_file
    doc.configure_mock(
        original_path=None,
        mime_type="application/pdf",
        mailrule_id=None,
        source=DocumentSource.WebUI,
    )
    return doc
|
||||
|
||||
def util_call_before_task_publish_handler(
|
||||
|
||||
@pytest.fixture
def consume_overrides(django_user_model):
    """
    Provide mocked DocumentMetadataOverrides owned by a freshly created user.
    """
    owner = django_user_model.objects.create_user(username="testuser")
    return mock.MagicMock(spec=DocumentMetadataOverrides, owner_id=owner.id)
|
||||
|
||||
|
||||
def send_publish(
    task_name: str,
    args: tuple,
    kwargs: dict,
    headers: dict | None = None,
) -> str:
    """
    Call before_task_publish_handler as if ``task_name`` were being published.

    A random task id is generated and placed in the headers next to the task
    name; entries of ``headers`` are merged on top. The body is passed as
    ``(args, kwargs, {})``. Returns the generated task id.
    """
    from documents.signals.handlers import before_task_publish_handler

    task_id = str(uuid.uuid4())
    message_headers = {"task": task_name, "id": task_id}
    if headers:
        message_headers.update(headers)

    before_task_publish_handler(
        sender=task_name,
        headers=message_headers,
        body=(args, kwargs, {}),
    )
    return task_id
|
||||
|
||||
|
||||
@pytest.mark.django_db
class TestBeforeTaskPublishHandler:
    """Tests for the PaperlessTask record created when a task is published."""

    def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides):
        """A consume_file publish records type, status, source, filename and owner."""
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.CONSUME_FILE
        assert task.status == PaperlessTask.Status.PENDING
        assert task.trigger_source == PaperlessTask.TriggerSource.WEB_UI
        assert task.input_data["filename"] == "invoice.pdf"
        assert task.owner_id == consume_overrides.owner_id

    def test_creates_task_for_train_classifier(self):
        """Without a trigger header the source defaults to MANUAL."""
        task_id = send_publish("documents.tasks.train_classifier", (), {})
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.TRAIN_CLASSIFIER
        assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL

    def test_creates_task_for_sanity_check(self):
        task_id = send_publish("documents.tasks.sanity_check", (), {})
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.SANITY_CHECK

    def test_creates_task_for_process_mail_accounts(self):
        """Account ids passed by keyword end up in input_data."""
        task_id = send_publish(
            "paperless_mail.tasks.process_mail_accounts",
            (),
            {"account_ids": [1, 2]},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.MAIL_FETCH
        assert task.input_data["account_ids"] == [1, 2]

    def test_scheduled_header_sets_trigger_source(self):
        task_id = send_publish(
            "documents.tasks.train_classifier",
            (),
            {},
            headers={"trigger_source": "scheduled"},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.SCHEDULED

    def test_system_header_sets_trigger_source(self):
        task_id = send_publish(
            "documents.tasks.llmindex_index",
            (),
            {"rebuild": True},
            headers={"trigger_source": "system"},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.SYSTEM

    def test_ignores_untracked_task(self):
        """Tasks that are not in TRACKED_TASKS create no record."""
        send_publish("documents.tasks.bulk_update_documents", ([1, 2],), {})
        assert PaperlessTask.objects.count() == 0

    def test_ignores_none_headers(self):
        from documents.signals.handlers import before_task_publish_handler

        before_task_publish_handler(sender=None, headers=None, body=None)
        assert PaperlessTask.objects.count() == 0

    def test_consume_folder_source_maps_correctly(
        self,
        consume_input_doc,
        consume_overrides,
    ):
        consume_input_doc.source = DocumentSource.ConsumeFolder
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.FOLDER_CONSUME

    def test_email_source_maps_correctly(self, consume_input_doc, consume_overrides):
        consume_input_doc.source = DocumentSource.MailFetch
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )
        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.EMAIL_CONSUME
|
||||
|
||||
|
||||
@pytest.mark.django_db
class TestTaskPrerunHandler:
    """Tests for the transition to STARTED when a worker picks up a task."""

    def test_marks_task_started(self):
        """A pending task becomes STARTED and receives a start time."""
        from documents.signals.handlers import task_prerun_handler

        pending = PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.CONSUME_FILE,
            trigger_source=PaperlessTask.TriggerSource.MANUAL,
            status=PaperlessTask.Status.PENDING,
        )

        task_prerun_handler(task_id=pending.task_id)

        pending.refresh_from_db()
        assert pending.status == PaperlessTask.Status.STARTED
        assert pending.date_started is not None

    def test_ignores_unknown_task_id(self):
        from documents.signals.handlers import task_prerun_handler

        # must not raise
        task_prerun_handler(task_id="nonexistent-id")

    def test_ignores_none_task_id(self):
        from documents.signals.handlers import task_prerun_handler

        # must not raise
        task_prerun_handler(task_id=None)
|
||||
|
||||
|
||||
@pytest.mark.django_db
class TestTaskPostrunHandler:
    """Tests for the completion data recorded after a task has run."""

    def _started_task(self) -> PaperlessTask:
        """Create a train_classifier task that is already in the STARTED state."""
        from django.utils import timezone

        return PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER,
            trigger_source=PaperlessTask.TriggerSource.MANUAL,
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )

    def test_records_success_with_dict_result(self):
        """A dict return value is stored as result_data along with the timings."""
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval={"document_id": 42},
            state="SUCCESS",
        )
        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.SUCCESS
        assert task.result_data == {"document_id": 42}
        assert task.date_done is not None
        assert task.duration_seconds is not None
        assert task.wait_time_seconds is not None

    def test_records_failure_state(self):
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(task_id=task.task_id, retval="some error", state="FAILURE")
        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.FAILURE

    def test_parses_legacy_new_document_string(self):
        """A legacy "New document id ..." string is kept and also parsed."""
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval="New document id 42 created",
            state="SUCCESS",
        )
        task.refresh_from_db()
        assert task.result_data["document_id"] == 42
        assert task.result_message == "New document id 42 created"

    def test_parses_legacy_duplicate_string(self):
        """A legacy duplicate message yields the duplicate id and trash flag."""
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval="It is a duplicate of some document (#99).",
            state="FAILURE",
        )
        task.refresh_from_db()
        assert task.result_data["duplicate_of"] == 99
        assert task.result_data["duplicate_in_trash"] is False

    def test_ignores_unknown_task_id(self):
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id="nonexistent",
            retval=None,
            state="SUCCESS",
        )  # must not raise
|
||||
|
||||
|
||||
@pytest.mark.django_db
class TestTaskFailureHandler:
    """Tests for the error details recorded when a task raises."""

    def test_records_failure_with_exception(self):
        """The exception type and message are stored in result_data."""
        from django.utils import timezone

        task = PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.CONSUME_FILE,
            trigger_source=PaperlessTask.TriggerSource.WEB_UI,
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )
        from documents.signals.handlers import task_failure_handler

        task_failure_handler(
            task_id=task.task_id,
            exception=ValueError("PDF parse failed"),
            traceback=None,
        )
        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.FAILURE
        assert task.result_data["error_type"] == "ValueError"
        assert task.result_data["error_message"] == "PDF parse failed"
        assert task.date_done is not None

    def test_ignores_none_task_id(self):
        from documents.signals.handlers import task_failure_handler

        # must not raise
        task_failure_handler(task_id=None, exception=ValueError("x"), traceback=None)
|
||||
|
||||
Reference in New Issue
Block a user