diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 7cb83ccc8..3494439d6 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING from typing import Any from celery import shared_task -from celery import states from celery.signals import before_task_publish from celery.signals import task_failure from celery.signals import task_postrun @@ -31,6 +30,7 @@ from documents import matching from documents.caching import clear_document_caches from documents.caching import invalidate_llm_suggestions_cache from documents.data_models import ConsumableDocument +from documents.data_models import DocumentSource from documents.file_handling import create_source_path_directory from documents.file_handling import delete_empty_directories from documents.file_handling import generate_filename @@ -996,67 +996,175 @@ def run_workflows( return overrides, "\n".join(messages) +# --------------------------------------------------------------------------- +# Task tracking -- Celery signal handlers +# --------------------------------------------------------------------------- + +TRACKED_TASKS: dict[str, PaperlessTask.TaskType] = { + "documents.tasks.consume_file": PaperlessTask.TaskType.CONSUME_FILE, + "documents.tasks.train_classifier": PaperlessTask.TaskType.TRAIN_CLASSIFIER, + "documents.tasks.sanity_check": PaperlessTask.TaskType.SANITY_CHECK, + "documents.tasks.index_optimize": PaperlessTask.TaskType.INDEX_OPTIMIZE, + "documents.tasks.llmindex_index": PaperlessTask.TaskType.LLM_INDEX, + "paperless_mail.tasks.process_mail_accounts": PaperlessTask.TaskType.MAIL_FETCH, +} + +_DOCUMENT_SOURCE_TO_TRIGGER: dict[Any, PaperlessTask.TriggerSource] = { + DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME, + DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD, + DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME, + DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI, +} + + +def _extract_input_data( + task_type: PaperlessTask.TaskType, + args: tuple, + task_kwargs: dict, +) -> dict: + if task_type == PaperlessTask.TaskType.CONSUME_FILE: + input_doc = args[0] if args else task_kwargs.get("input_doc") + overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides") + if input_doc is None: + return {} + data: dict = { + "filename": input_doc.original_file.name, + "mime_type": input_doc.mime_type, + } + if input_doc.original_path: + data["source_path"] = str(input_doc.original_path) + if input_doc.mailrule_id: + data["mailrule_id"] = input_doc.mailrule_id + if overrides: + override_dict = { + k: v + for k, v in vars(overrides).items() + if v is not None and not k.startswith("_") + } + if override_dict: + data["overrides"] = override_dict + return data + + if task_type == PaperlessTask.TaskType.MAIL_FETCH: + account_ids = args[0] if args else task_kwargs.get("account_ids") + return {"account_ids": account_ids} + + return {} + + +def _determine_trigger_source( + task_type: PaperlessTask.TaskType, + args: tuple, + task_kwargs: dict, + headers: dict, +) -> PaperlessTask.TriggerSource: + # Explicit header takes priority -- covers beat ("scheduled") and system auto-runs ("system") + header_source = headers.get("trigger_source") + if header_source == "scheduled": + return PaperlessTask.TriggerSource.SCHEDULED + if header_source == "system": + return PaperlessTask.TriggerSource.SYSTEM + + if task_type == PaperlessTask.TaskType.CONSUME_FILE: + input_doc = args[0] if args else task_kwargs.get("input_doc") + if input_doc is not None: + return _DOCUMENT_SOURCE_TO_TRIGGER.get( + input_doc.source, + PaperlessTask.TriggerSource.API_UPLOAD, + ) + + return PaperlessTask.TriggerSource.MANUAL + + +def _extract_owner_id( + task_type: PaperlessTask.TaskType, + args: tuple, + task_kwargs: dict, +) -> int | None: + if task_type != PaperlessTask.TaskType.CONSUME_FILE: + return None + overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides") + if overrides and hasattr(overrides, "owner_id"): + return overrides.owner_id + return None + + +def _parse_legacy_result(result: str) -> dict | None: + import re as _re + + if match := _re.search(r"New document id (\d+) created", result): + return {"document_id": int(match.group(1))} + if match := _re.search(r"It is a duplicate of .* \(#(\d+)\)", result): + return { + "duplicate_of": int(match.group(1)), + "duplicate_in_trash": "existing document is in the trash" in result, + } + return None + + @before_task_publish.connect -def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs) -> None: +def before_task_publish_handler( + sender=None, + headers=None, + body=None, + **kwargs, +) -> None: """ - Creates the PaperlessTask object in a pending state. This is sent before - the task reaches the broker, but before it begins executing on a worker. + Creates the PaperlessTask record when the task is published to broker. https://docs.celeryq.dev/en/stable/userguide/signals.html#before-task-publish - https://docs.celeryq.dev/en/stable/internals/protocol.html#version-2 - """ - if "task" not in headers or headers["task"] != "documents.tasks.consume_file": - # Assumption: this is only ever a v2 message + if headers is None or body is None: + return + + task_name = headers.get("task", "") + task_type = TRACKED_TASKS.get(task_name) + if task_type is None: return try: close_old_connections() + args, task_kwargs, _ = body + task_id = headers["id"] - task_args = body[0] - input_doc, overrides = task_args - - task_file_name = input_doc.original_file.name - user_id = overrides.owner_id if overrides else None + input_data = _extract_input_data(task_type, args, task_kwargs) + trigger_source = _determine_trigger_source( + task_type, + args, + task_kwargs, + headers, + ) + owner_id = _extract_owner_id(task_type, args, task_kwargs) PaperlessTask.objects.create( - trigger_source=PaperlessTask.TriggerSource.FOLDER_CONSUME, - task_id=headers["id"], + task_id=task_id, + task_type=task_type, + trigger_source=trigger_source, status=PaperlessTask.Status.PENDING, - input_data={"filename": task_file_name}, - task_type=PaperlessTask.TaskType.CONSUME_FILE, - date_created=timezone.now(), - date_started=None, - date_done=None, - owner_id=user_id, + input_data=input_data, + owner_id=owner_id, ) - except Exception: # pragma: no cover - # Don't let an exception in the signal handlers prevent - # a document from being consumed. + except Exception: logger.exception("Creating PaperlessTask failed") @task_prerun.connect def task_prerun_handler(sender=None, task_id=None, task=None, **kwargs) -> None: """ - - Updates the PaperlessTask to be started. Sent before the task begins execution - on a worker. + Marks the task STARTED when execution begins on a worker. https://docs.celeryq.dev/en/stable/userguide/signals.html#task-prerun """ + if task_id is None: + return try: close_old_connections() - task_instance = PaperlessTask.objects.filter(task_id=task_id).first() - - if task_instance is not None: - task_instance.status = PaperlessTask.Status.STARTED - task_instance.date_started = timezone.now() - task_instance.save() - except Exception: # pragma: no cover - # Don't let an exception in the signal handlers prevent - # a document from being consumed. + PaperlessTask.objects.filter(task_id=task_id).update( + status=PaperlessTask.Status.STARTED, + date_started=timezone.now(), + ) + except Exception: logger.exception("Setting PaperlessTask started failed") @@ -1070,33 +1178,53 @@ def task_postrun_handler( **kwargs, ) -> None: """ - Updates the result of the PaperlessTask. + Records task completion and result data. https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun """ + if task_id is None: + return try: close_old_connections() - task_instance = PaperlessTask.objects.filter(task_id=task_id).first() - if task_instance is not None: - _CELERY_STATE_MAP = { - states.SUCCESS: PaperlessTask.Status.SUCCESS, - states.FAILURE: PaperlessTask.Status.FAILURE, - states.REVOKED: PaperlessTask.Status.REVOKED, - states.STARTED: PaperlessTask.Status.STARTED, - states.PENDING: PaperlessTask.Status.PENDING, - } - task_instance.status = _CELERY_STATE_MAP.get( - state, - PaperlessTask.Status.FAILURE, - ) - if isinstance(retval, str): - task_instance.result_message = retval - task_instance.date_done = timezone.now() - task_instance.save() - except Exception: # pragma: no cover - # Don't let an exception in the signal handlers prevent - # a document from being consumed. + status_map = { + "SUCCESS": PaperlessTask.Status.SUCCESS, + "FAILURE": PaperlessTask.Status.FAILURE, + "REVOKED": PaperlessTask.Status.REVOKED, + } + new_status = status_map.get(state, PaperlessTask.Status.FAILURE) + + result_data: dict | None = None + result_message: str | None = None + if isinstance(retval, dict): + result_data = retval + elif isinstance(retval, str): + result_message = retval + result_data = _parse_legacy_result(retval) + + now = timezone.now() + task_instance = PaperlessTask.objects.filter(task_id=task_id).first() + if task_instance is None: + return + + duration_seconds: float | None = None + wait_time_seconds: float | None = None + if task_instance.date_started: + duration_seconds = (now - task_instance.date_started).total_seconds() + if task_instance.date_started and task_instance.date_created: + wait_time_seconds = ( + task_instance.date_started - task_instance.date_created + ).total_seconds() + + PaperlessTask.objects.filter(task_id=task_id).update( + status=new_status, + result_data=result_data, + result_message=result_message, + date_done=now, + duration_seconds=duration_seconds, + wait_time_seconds=wait_time_seconds, + ) + except Exception: logger.exception("Updating PaperlessTask failed") @@ -1110,21 +1238,33 @@ def task_failure_handler( **kwargs, ) -> None: """ - Updates the result of a failed PaperlessTask. + Records failure details when a task raises an exception. https://docs.celeryq.dev/en/stable/userguide/signals.html#task-failure """ + if task_id is None: + return try: close_old_connections() - task_instance = PaperlessTask.objects.filter(task_id=task_id).first() - if task_instance is not None and task_instance.result_message is None: - task_instance.status = PaperlessTask.Status.FAILURE - task_instance.result_message = str(traceback) if traceback else None - task_instance.date_done = timezone.now() - task_instance.save() - except Exception: # pragma: no cover - logger.exception("Updating PaperlessTask failed") + result_data: dict = { + "error_type": type(exception).__name__ if exception else "Unknown", + "error_message": str(exception) if exception else "Unknown error", + } + if traceback: + import traceback as _tb + + tb_str = "".join(_tb.format_tb(traceback)) + result_data["traceback"] = tb_str[:5000] + + PaperlessTask.objects.filter(task_id=task_id).update( + status=PaperlessTask.Status.FAILURE, + result_data=result_data, + result_message=str(exception) if exception else None, + date_done=timezone.now(), + ) + except Exception: + logger.exception("Updating PaperlessTask on failure failed") @worker_process_init.connect diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index 3dcbbeaff..e2335af16 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -1,250 +1,266 @@ import uuid from unittest import mock -import celery -from django.contrib.auth import get_user_model -from django.test import TestCase +import pytest from documents.data_models import ConsumableDocument from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentSource -from documents.models import Document from documents.models import PaperlessTask -from documents.signals.handlers import add_to_index -from documents.signals.handlers import before_task_publish_handler -from documents.signals.handlers import task_failure_handler -from documents.signals.handlers import task_postrun_handler -from documents.signals.handlers import task_prerun_handler -from documents.tests.test_consumer import fake_magic_from_file -from documents.tests.utils import DirectoriesMixin -@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file) -class TestTaskSignalHandler(DirectoriesMixin, TestCase): - @classmethod - def setUpTestData(cls) -> None: - super().setUpTestData() - cls.user = get_user_model().objects.create_user(username="testuser") +@pytest.fixture +def consume_input_doc(): + doc = mock.MagicMock(spec=ConsumableDocument) + # original_file is a Path; configure the nested mock so .name works + doc.original_file = mock.MagicMock() + doc.original_file.name = "invoice.pdf" + doc.original_path = None + doc.mime_type = "application/pdf" + doc.mailrule_id = None + doc.source = DocumentSource.WebUI + return doc - def util_call_before_task_publish_handler( + +@pytest.fixture +def consume_overrides(django_user_model): + user = django_user_model.objects.create_user(username="testuser") + overrides = mock.MagicMock(spec=DocumentMetadataOverrides) + overrides.owner_id = user.id + return overrides + + +def send_publish( + task_name: str, + args: tuple, + kwargs: dict, + headers: dict | None = None, +) -> str: + from documents.signals.handlers import before_task_publish_handler + + task_id = str(uuid.uuid4()) + hdrs = {"task": task_name, "id": task_id, **(headers or {})} + before_task_publish_handler(sender=task_name, headers=hdrs, body=(args, kwargs, {})) + return task_id + + +@pytest.mark.django_db +class TestBeforeTaskPublishHandler: + def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides): + task_id = send_publish( + "documents.tasks.consume_file", + (consume_input_doc, consume_overrides), + {}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.task_type == PaperlessTask.TaskType.CONSUME_FILE + assert task.status == PaperlessTask.Status.PENDING + assert task.trigger_source == PaperlessTask.TriggerSource.WEB_UI + assert task.input_data["filename"] == "invoice.pdf" + assert task.owner_id == consume_overrides.owner_id + + def test_creates_task_for_train_classifier(self): + task_id = send_publish("documents.tasks.train_classifier", (), {}) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.task_type == PaperlessTask.TaskType.TRAIN_CLASSIFIER + assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL + + def test_creates_task_for_sanity_check(self): + task_id = send_publish("documents.tasks.sanity_check", (), {}) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.task_type == PaperlessTask.TaskType.SANITY_CHECK + + def test_creates_task_for_process_mail_accounts(self): + task_id = send_publish( + "paperless_mail.tasks.process_mail_accounts", + (), + {"account_ids": [1, 2]}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.task_type == PaperlessTask.TaskType.MAIL_FETCH + assert task.input_data["account_ids"] == [1, 2] + + def test_scheduled_header_sets_trigger_source(self): + task_id = send_publish( + "documents.tasks.train_classifier", + (), + {}, + headers={"trigger_source": "scheduled"}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.trigger_source == PaperlessTask.TriggerSource.SCHEDULED + + def test_system_header_sets_trigger_source(self): + task_id = send_publish( + "documents.tasks.llmindex_index", + (), + {"rebuild": True}, + headers={"trigger_source": "system"}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.trigger_source == PaperlessTask.TriggerSource.SYSTEM + + def test_ignores_untracked_task(self): + send_publish("documents.tasks.bulk_update_documents", ([1, 2],), {}) + assert PaperlessTask.objects.count() == 0 + + def test_ignores_none_headers(self): + from documents.signals.handlers import before_task_publish_handler + + before_task_publish_handler(sender=None, headers=None, body=None) + assert PaperlessTask.objects.count() == 0 + + def test_consume_folder_source_maps_correctly( self, - headers_to_use, - body_to_use, - ) -> None: - """ - Simple utility to call the pre-run handle and ensure it created a single task - instance - """ - self.assertEqual(PaperlessTask.objects.all().count(), 0) - - before_task_publish_handler(headers=headers_to_use, body=body_to_use) - - self.assertEqual(PaperlessTask.objects.all().count(), 1) - - def test_before_task_publish_handler_consume(self) -> None: - """ - GIVEN: - - A celery task is started via the consume folder - WHEN: - - Task before publish handler is called - THEN: - - The task is created and marked as pending - """ - headers = { - "id": str(uuid.uuid4()), - "task": "documents.tasks.consume_file", - } - body = ( - # args - ( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file="/consume/hello-999.pdf", - ), - DocumentMetadataOverrides( - title="Hello world", - owner_id=self.user.id, - ), - ), - # kwargs + consume_input_doc, + consume_overrides, + ): + consume_input_doc.source = DocumentSource.ConsumeFolder + task_id = send_publish( + "documents.tasks.consume_file", + (consume_input_doc, consume_overrides), {}, - # celery stuff - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, - ) - self.util_call_before_task_publish_handler( - headers_to_use=headers, - body_to_use=body, ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.trigger_source == PaperlessTask.TriggerSource.FOLDER_CONSUME - task = PaperlessTask.objects.get() - self.assertIsNotNone(task) - self.assertEqual(headers["id"], task.task_id) - self.assertEqual("hello-999.pdf", task.task_file_name) - self.assertEqual(PaperlessTask.TaskName.CONSUME_FILE, task.task_name) - self.assertEqual(self.user.id, task.owner_id) - self.assertEqual(celery.states.PENDING, task.status) - - def test_task_prerun_handler(self) -> None: - """ - GIVEN: - - A celery task is started via the consume folder - WHEN: - - Task starts execution - THEN: - - The task is marked as started - """ - - headers = { - "id": str(uuid.uuid4()), - "task": "documents.tasks.consume_file", - } - body = ( - # args - ( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file="/consume/hello-99.pdf", - ), - None, - ), - # kwargs + def test_email_source_maps_correctly(self, consume_input_doc, consume_overrides): + consume_input_doc.source = DocumentSource.MailFetch + task_id = send_publish( + "documents.tasks.consume_file", + (consume_input_doc, consume_overrides), {}, - # celery stuff - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.trigger_source == PaperlessTask.TriggerSource.EMAIL_CONSUME + + +@pytest.mark.django_db +class TestTaskPrerunHandler: + def test_marks_task_started(self): + task = PaperlessTask.objects.create( + task_id=str(uuid.uuid4()), + task_type=PaperlessTask.TaskType.CONSUME_FILE, + trigger_source=PaperlessTask.TriggerSource.MANUAL, + status=PaperlessTask.Status.PENDING, + ) + from documents.signals.handlers import task_prerun_handler + + task_prerun_handler(task_id=task.task_id) + task.refresh_from_db() + assert task.status == PaperlessTask.Status.STARTED + assert task.date_started is not None + + def test_ignores_unknown_task_id(self): + from documents.signals.handlers import task_prerun_handler + + task_prerun_handler(task_id="nonexistent-id") # must not raise + + def test_ignores_none_task_id(self): + from documents.signals.handlers import task_prerun_handler + + task_prerun_handler(task_id=None) # must not raise + + +@pytest.mark.django_db +class TestTaskPostrunHandler: + def _started_task(self) -> PaperlessTask: + from django.utils import timezone + + return PaperlessTask.objects.create( + task_id=str(uuid.uuid4()), + task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER, + trigger_source=PaperlessTask.TriggerSource.MANUAL, + status=PaperlessTask.Status.STARTED, + date_started=timezone.now(), ) - self.util_call_before_task_publish_handler( - headers_to_use=headers, - body_to_use=body, - ) - - task_prerun_handler(task_id=headers["id"]) - - task = PaperlessTask.objects.get() - - self.assertEqual(celery.states.STARTED, task.status) - - def test_task_postrun_handler(self) -> None: - """ - GIVEN: - - A celery task is started via the consume folder - WHEN: - - Task finished execution - THEN: - - The task is marked as started - """ - headers = { - "id": str(uuid.uuid4()), - "task": "documents.tasks.consume_file", - } - body = ( - # args - ( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file="/consume/hello-9.pdf", - ), - None, - ), - # kwargs - {}, - # celery stuff - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, - ) - self.util_call_before_task_publish_handler( - headers_to_use=headers, - body_to_use=body, - ) + def test_records_success_with_dict_result(self): + task = self._started_task() + from documents.signals.handlers import task_postrun_handler task_postrun_handler( - task_id=headers["id"], - retval="Success. New document id 1 created", - state=celery.states.SUCCESS, + task_id=task.task_id, + retval={"document_id": 42}, + state="SUCCESS", ) + task.refresh_from_db() + assert task.status == PaperlessTask.Status.SUCCESS + assert task.result_data == {"document_id": 42} + assert task.date_done is not None + assert task.duration_seconds is not None + assert task.wait_time_seconds is not None - task = PaperlessTask.objects.get() + def test_records_failure_state(self): + task = self._started_task() + from documents.signals.handlers import task_postrun_handler - self.assertEqual(celery.states.SUCCESS, task.status) + task_postrun_handler(task_id=task.task_id, retval="some error", state="FAILURE") + task.refresh_from_db() + assert task.status == PaperlessTask.Status.FAILURE - def test_task_failure_handler(self) -> None: - """ - GIVEN: - - A celery task is started via the consume folder - WHEN: - - Task failed execution - THEN: - - The task is marked as failed - """ - headers = { - "id": str(uuid.uuid4()), - "task": "documents.tasks.consume_file", - } - body = ( - # args - ( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file="/consume/hello-9.pdf", - ), - None, - ), - # kwargs - {}, - # celery stuff - {"callbacks": None, "errbacks": None, "chain": None, "chord": None}, + def test_parses_legacy_new_document_string(self): + task = self._started_task() + from documents.signals.handlers import task_postrun_handler + + task_postrun_handler( + task_id=task.task_id, + retval="New document id 42 created", + state="SUCCESS", ) - self.util_call_before_task_publish_handler( - headers_to_use=headers, - body_to_use=body, + task.refresh_from_db() + assert task.result_data["document_id"] == 42 + assert task.result_message == "New document id 42 created" + + def test_parses_legacy_duplicate_string(self): + task = self._started_task() + from documents.signals.handlers import task_postrun_handler + + task_postrun_handler( + task_id=task.task_id, + retval="It is a duplicate of some document (#99).", + state="FAILURE", ) + task.refresh_from_db() + assert task.result_data["duplicate_of"] == 99 + assert task.result_data["duplicate_in_trash"] is False + + def test_ignores_unknown_task_id(self): + from documents.signals.handlers import task_postrun_handler + + task_postrun_handler( + task_id="nonexistent", + retval=None, + state="SUCCESS", + ) # must not raise + + +@pytest.mark.django_db +class TestTaskFailureHandler: + def test_records_failure_with_exception(self): + from django.utils import timezone + + task = PaperlessTask.objects.create( + task_id=str(uuid.uuid4()), + task_type=PaperlessTask.TaskType.CONSUME_FILE, + trigger_source=PaperlessTask.TriggerSource.WEB_UI, + status=PaperlessTask.Status.STARTED, + date_started=timezone.now(), + ) + from documents.signals.handlers import task_failure_handler task_failure_handler( - task_id=headers["id"], - exception="Example failure", + task_id=task.task_id, + exception=ValueError("PDF parse failed"), + traceback=None, ) + task.refresh_from_db() + assert task.status == PaperlessTask.Status.FAILURE + assert task.result_data["error_type"] == "ValueError" + assert task.result_data["error_message"] == "PDF parse failed" + assert task.date_done is not None - task = PaperlessTask.objects.get() + def test_ignores_none_task_id(self): + from documents.signals.handlers import task_failure_handler - self.assertEqual(celery.states.FAILURE, task.status) - - def test_add_to_index_indexes_root_once_for_root_documents(self) -> None: - root = Document.objects.create( - title="root", - checksum="root", - mime_type="application/pdf", - ) - - with mock.patch("documents.search.get_backend") as mock_get_backend: - mock_backend = mock.MagicMock() - mock_get_backend.return_value = mock_backend - add_to_index(sender=None, document=root) - - mock_backend.add_or_update.assert_called_once_with(root, effective_content="") - - def test_add_to_index_reindexes_root_for_version_documents(self) -> None: - root = Document.objects.create( - title="root", - checksum="root", - mime_type="application/pdf", - ) - version = Document.objects.create( - title="version", - checksum="version", - mime_type="application/pdf", - root_document=root, - ) - - with mock.patch("documents.search.get_backend") as mock_get_backend: - mock_backend = mock.MagicMock() - mock_get_backend.return_value = mock_backend - add_to_index(sender=None, document=version) - - self.assertEqual(mock_backend.add_or_update.call_count, 1) - self.assertEqual( - mock_backend.add_or_update.call_args_list[0].args[0].id, - version.id, - ) - self.assertEqual( - mock_backend.add_or_update.call_args_list[0].kwargs, - {"effective_content": version.content}, - ) + task_failure_handler(task_id=None, exception=ValueError("x"), traceback=None)