feat(tasks): rewrite signal handlers to track all task types

Replace the old consume_file-only handler with a full rewrite that tracks
6 task types (consume_file, train_classifier, sanity_check, index_optimize,
llm_index, mail_fetch) with proper trigger source detection, input data
extraction, legacy result string parsing, duration/wait time recording,
and structured error capture on failure.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-04-15 13:59:01 -07:00
parent ef4e3d31ef
commit ad47e96df1
2 changed files with 441 additions and 285 deletions

View File

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING
from typing import Any
from celery import shared_task
from celery import states
from celery.signals import before_task_publish
from celery.signals import task_failure
from celery.signals import task_postrun
@@ -31,6 +30,7 @@ from documents import matching
from documents.caching import clear_document_caches
from documents.caching import invalidate_llm_suggestions_cache
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentSource
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_filename
@@ -996,67 +996,175 @@ def run_workflows(
return overrides, "\n".join(messages)
# ---------------------------------------------------------------------------
# Task tracking -- Celery signal handlers
# ---------------------------------------------------------------------------
# Celery task name -> task type recorded on the PaperlessTask row.
# Publishes for any task name NOT in this mapping are ignored by the
# signal handlers below.
TRACKED_TASKS: dict[str, PaperlessTask.TaskType] = {
    "documents.tasks.consume_file": PaperlessTask.TaskType.CONSUME_FILE,
    "documents.tasks.train_classifier": PaperlessTask.TaskType.TRAIN_CLASSIFIER,
    "documents.tasks.sanity_check": PaperlessTask.TaskType.SANITY_CHECK,
    "documents.tasks.index_optimize": PaperlessTask.TaskType.INDEX_OPTIMIZE,
    "documents.tasks.llmindex_index": PaperlessTask.TaskType.LLM_INDEX,
    "paperless_mail.tasks.process_mail_accounts": PaperlessTask.TaskType.MAIL_FETCH,
}
# Where a consumed document came from -> the trigger source stored on the
# task record.  Used only for consume_file tasks.
_DOCUMENT_SOURCE_TO_TRIGGER: dict[Any, PaperlessTask.TriggerSource] = {
    DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME,
    DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD,
    DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME,
    DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI,
}
def _extract_input_data(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
) -> dict:
    """
    Build the ``input_data`` payload stored on the task record.

    Only consume_file and mail_fetch carry interesting inputs; every other
    task type yields an empty dict.
    """
    if task_type == PaperlessTask.TaskType.MAIL_FETCH:
        account_ids = args[0] if args else task_kwargs.get("account_ids")
        return {"account_ids": account_ids}
    if task_type != PaperlessTask.TaskType.CONSUME_FILE:
        return {}
    input_doc = args[0] if args else task_kwargs.get("input_doc")
    if input_doc is None:
        return {}
    overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides")
    payload: dict = {
        "filename": input_doc.original_file.name,
        "mime_type": input_doc.mime_type,
    }
    if input_doc.original_path:
        payload["source_path"] = str(input_doc.original_path)
    if input_doc.mailrule_id:
        payload["mailrule_id"] = input_doc.mailrule_id
    if overrides:
        # Only record overrides the caller actually set; private attrs are
        # implementation detail and skipped.
        set_overrides = {
            key: value
            for key, value in vars(overrides).items()
            if value is not None and not key.startswith("_")
        }
        if set_overrides:
            payload["overrides"] = set_overrides
    return payload
def _determine_trigger_source(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
    headers: dict,
) -> PaperlessTask.TriggerSource:
    """
    Work out why this task was queued.

    An explicit ``trigger_source`` message header always wins: celery beat
    sends "scheduled", internal auto-runs send "system".  For consume_file
    the document's origin decides; everything else counts as manual.
    """
    explicit = {
        "scheduled": PaperlessTask.TriggerSource.SCHEDULED,
        "system": PaperlessTask.TriggerSource.SYSTEM,
    }.get(headers.get("trigger_source"))
    if explicit is not None:
        return explicit
    if task_type == PaperlessTask.TaskType.CONSUME_FILE:
        input_doc = args[0] if args else task_kwargs.get("input_doc")
        if input_doc is not None:
            # Unknown document sources fall back to API_UPLOAD
            return _DOCUMENT_SOURCE_TO_TRIGGER.get(
                input_doc.source,
                PaperlessTask.TriggerSource.API_UPLOAD,
            )
    return PaperlessTask.TriggerSource.MANUAL
def _extract_owner_id(
    task_type: PaperlessTask.TaskType,
    args: tuple,
    task_kwargs: dict,
) -> int | None:
    """
    Pull the owning user id from a consume_file task's overrides.

    Every other task type has no owner.
    """
    if task_type == PaperlessTask.TaskType.CONSUME_FILE:
        overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides")
        if overrides and hasattr(overrides, "owner_id"):
            return overrides.owner_id
    return None
def _parse_legacy_result(result: str) -> dict | None:
import re as _re
if match := _re.search(r"New document id (\d+) created", result):
return {"document_id": int(match.group(1))}
if match := _re.search(r"It is a duplicate of .* \(#(\d+)\)", result):
return {
"duplicate_of": int(match.group(1)),
"duplicate_in_trash": "existing document is in the trash" in result,
}
return None
@before_task_publish.connect
def before_task_publish_handler(
    sender=None,
    headers=None,
    body=None,
    **kwargs,
) -> None:
    """
    Creates the PaperlessTask record when the task is published to broker.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#before-task-publish
    https://docs.celeryq.dev/en/stable/internals/protocol.html#version-2
    """
    # Assumption: this is only ever a v2 message
    if headers is None or body is None:
        return
    task_name = headers.get("task", "")
    task_type = TRACKED_TASKS.get(task_name)
    if task_type is None:
        # Not a task we track; nothing to record
        return
    try:
        close_old_connections()
        # v2 protocol body is (args, kwargs, embed)
        args, task_kwargs, _ = body
        task_id = headers["id"]
        input_data = _extract_input_data(task_type, args, task_kwargs)
        trigger_source = _determine_trigger_source(
            task_type,
            args,
            task_kwargs,
            headers,
        )
        owner_id = _extract_owner_id(task_type, args, task_kwargs)
        PaperlessTask.objects.create(
            task_id=task_id,
            task_type=task_type,
            trigger_source=trigger_source,
            status=PaperlessTask.Status.PENDING,
            date_created=timezone.now(),
            date_started=None,
            date_done=None,
            input_data=input_data,
            owner_id=owner_id,
        )
    except Exception:
        # Never let tracking failures break task publishing itself
        logger.exception("Creating PaperlessTask failed")
@task_prerun.connect
def task_prerun_handler(sender=None, task_id=None, task=None, **kwargs) -> None:
    """
    Marks the task STARTED when execution begins on a worker.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#task-prerun
    """
    if task_id is None:
        return
    try:
        close_old_connections()
        # Single UPDATE; a missing row (untracked task) is simply a no-op
        PaperlessTask.objects.filter(task_id=task_id).update(
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )
    except Exception:
        # Never let tracking failures break task execution itself
        logger.exception("Setting PaperlessTask started failed")
@task_postrun.connect
def task_postrun_handler(
    sender=None,
    task_id=None,
    task=None,
    retval=None,
    state=None,
    **kwargs,
) -> None:
    """
    Records task completion and result data.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun
    """
    if task_id is None:
        return
    try:
        close_old_connections()
        status_map = {
            "SUCCESS": PaperlessTask.Status.SUCCESS,
            "FAILURE": PaperlessTask.Status.FAILURE,
            "REVOKED": PaperlessTask.Status.REVOKED,
        }
        # Any unrecognized terminal state is treated as a failure
        new_status = status_map.get(state, PaperlessTask.Status.FAILURE)
        result_data: dict | None = None
        result_message: str | None = None
        if isinstance(retval, dict):
            result_data = retval
        elif isinstance(retval, str):
            result_message = retval
            # Older tasks return human-readable strings; mine them for data
            result_data = _parse_legacy_result(retval)
        now = timezone.now()
        task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
        if task_instance is None:
            # Untracked task; nothing to update
            return
        duration_seconds: float | None = None
        wait_time_seconds: float | None = None
        if task_instance.date_started:
            duration_seconds = (now - task_instance.date_started).total_seconds()
            if task_instance.date_created:
                wait_time_seconds = (
                    task_instance.date_started - task_instance.date_created
                ).total_seconds()
        PaperlessTask.objects.filter(task_id=task_id).update(
            status=new_status,
            result_data=result_data,
            result_message=result_message,
            date_done=now,
            duration_seconds=duration_seconds,
            wait_time_seconds=wait_time_seconds,
        )
    except Exception:
        logger.exception("Updating PaperlessTask failed")
@task_failure.connect
def task_failure_handler(
    sender=None,
    task_id=None,
    exception=None,
    traceback=None,
    **kwargs,
) -> None:
    """
    Records failure details when a task raises an exception.

    https://docs.celeryq.dev/en/stable/userguide/signals.html#task-failure
    """
    if task_id is None:
        return
    try:
        close_old_connections()
        result_data: dict = {
            "error_type": type(exception).__name__ if exception else "Unknown",
            "error_message": str(exception) if exception else "Unknown error",
        }
        if traceback:
            import traceback as _tb

            # Cap the stored traceback so a huge stack can't bloat the row
            result_data["traceback"] = "".join(_tb.format_tb(traceback))[:5000]
        PaperlessTask.objects.filter(task_id=task_id).update(
            status=PaperlessTask.Status.FAILURE,
            result_data=result_data,
            result_message=str(exception) if exception else None,
            date_done=timezone.now(),
        )
    except Exception:
        logger.exception("Updating PaperlessTask on failure failed")
@worker_process_init.connect

View File

@@ -1,250 +1,266 @@
import uuid
from unittest import mock
import celery
from django.contrib.auth import get_user_model
from django.test import TestCase
import pytest
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.models import Document
from documents.models import PaperlessTask
from documents.signals.handlers import add_to_index
from documents.signals.handlers import before_task_publish_handler
from documents.signals.handlers import task_failure_handler
from documents.signals.handlers import task_postrun_handler
from documents.signals.handlers import task_prerun_handler
from documents.tests.test_consumer import fake_magic_from_file
from documents.tests.utils import DirectoriesMixin
@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
class TestTaskSignalHandler(DirectoriesMixin, TestCase):
@classmethod
def setUpTestData(cls) -> None:
super().setUpTestData()
cls.user = get_user_model().objects.create_user(username="testuser")
@pytest.fixture
def consume_input_doc():
    """A mocked ConsumableDocument resembling a web-UI PDF upload."""
    document = mock.MagicMock(spec=ConsumableDocument)
    # original_file is a Path in production; nest a mock so .name resolves
    document.original_file = mock.MagicMock()
    document.original_file.name = "invoice.pdf"
    document.mime_type = "application/pdf"
    document.original_path = None
    document.mailrule_id = None
    document.source = DocumentSource.WebUI
    return document
def util_call_before_task_publish_handler(
@pytest.fixture
def consume_overrides(django_user_model):
    """Mocked DocumentMetadataOverrides owned by a freshly created user."""
    owner = django_user_model.objects.create_user(username="testuser")
    mocked = mock.MagicMock(spec=DocumentMetadataOverrides)
    mocked.owner_id = owner.id
    return mocked
def send_publish(
    task_name: str,
    args: tuple,
    kwargs: dict,
    headers: dict | None = None,
) -> str:
    """Invoke the publish handler as celery would and return the task id."""
    from documents.signals.handlers import before_task_publish_handler

    task_id = str(uuid.uuid4())
    message_headers = {"task": task_name, "id": task_id}
    if headers:
        message_headers.update(headers)
    before_task_publish_handler(
        sender=task_name,
        headers=message_headers,
        body=(args, kwargs, {}),
    )
    return task_id
@pytest.mark.django_db
class TestBeforeTaskPublishHandler:
    """Tests for PaperlessTask creation on the before_task_publish signal."""

    def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides):
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.CONSUME_FILE
        assert task.status == PaperlessTask.Status.PENDING
        assert task.trigger_source == PaperlessTask.TriggerSource.WEB_UI
        assert task.input_data["filename"] == "invoice.pdf"
        assert task.owner_id == consume_overrides.owner_id

    def test_creates_task_for_train_classifier(self):
        task_id = send_publish("documents.tasks.train_classifier", (), {})

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.TRAIN_CLASSIFIER
        assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL

    def test_creates_task_for_sanity_check(self):
        task_id = send_publish("documents.tasks.sanity_check", (), {})

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.SANITY_CHECK

    def test_creates_task_for_process_mail_accounts(self):
        task_id = send_publish(
            "paperless_mail.tasks.process_mail_accounts",
            (),
            {"account_ids": [1, 2]},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.task_type == PaperlessTask.TaskType.MAIL_FETCH
        assert task.input_data["account_ids"] == [1, 2]

    def test_scheduled_header_sets_trigger_source(self):
        task_id = send_publish(
            "documents.tasks.train_classifier",
            (),
            {},
            headers={"trigger_source": "scheduled"},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.SCHEDULED

    def test_system_header_sets_trigger_source(self):
        task_id = send_publish(
            "documents.tasks.llmindex_index",
            (),
            {"rebuild": True},
            headers={"trigger_source": "system"},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.SYSTEM

    def test_ignores_untracked_task(self):
        send_publish("documents.tasks.bulk_update_documents", ([1, 2],), {})

        assert PaperlessTask.objects.count() == 0

    def test_ignores_none_headers(self):
        from documents.signals.handlers import before_task_publish_handler

        before_task_publish_handler(sender=None, headers=None, body=None)

        assert PaperlessTask.objects.count() == 0

    def test_consume_folder_source_maps_correctly(
        self,
        consume_input_doc,
        consume_overrides,
    ):
        consume_input_doc.source = DocumentSource.ConsumeFolder
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.FOLDER_CONSUME

    def test_email_source_maps_correctly(self, consume_input_doc, consume_overrides):
        consume_input_doc.source = DocumentSource.MailFetch
        task_id = send_publish(
            "documents.tasks.consume_file",
            (consume_input_doc, consume_overrides),
            {},
        )

        task = PaperlessTask.objects.get(task_id=task_id)
        assert task.trigger_source == PaperlessTask.TriggerSource.EMAIL_CONSUME
@pytest.mark.django_db
class TestTaskPrerunHandler:
    """Tests for the task_prerun signal handler."""

    def test_marks_task_started(self):
        pending = PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.CONSUME_FILE,
            trigger_source=PaperlessTask.TriggerSource.MANUAL,
            status=PaperlessTask.Status.PENDING,
        )
        from documents.signals.handlers import task_prerun_handler

        task_prerun_handler(task_id=pending.task_id)

        pending.refresh_from_db()
        assert pending.status == PaperlessTask.Status.STARTED
        assert pending.date_started is not None

    def test_ignores_unknown_task_id(self):
        from documents.signals.handlers import task_prerun_handler

        # Must simply return without raising
        task_prerun_handler(task_id="nonexistent-id")

    def test_ignores_none_task_id(self):
        from documents.signals.handlers import task_prerun_handler

        # Must simply return without raising
        task_prerun_handler(task_id=None)
@pytest.mark.django_db
class TestTaskPostrunHandler:
    """Tests for result, duration and wait-time recording on task_postrun."""

    def _started_task(self) -> PaperlessTask:
        # Helper: a task already marked STARTED with a start timestamp
        from django.utils import timezone

        return PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER,
            trigger_source=PaperlessTask.TriggerSource.MANUAL,
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )

    def test_records_success_with_dict_result(self):
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval={"document_id": 42},
            state="SUCCESS",
        )

        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.SUCCESS
        assert task.result_data == {"document_id": 42}
        assert task.date_done is not None
        assert task.duration_seconds is not None
        assert task.wait_time_seconds is not None

    def test_records_failure_state(self):
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(task_id=task.task_id, retval="some error", state="FAILURE")

        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.FAILURE

    def test_parses_legacy_new_document_string(self):
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval="New document id 42 created",
            state="SUCCESS",
        )

        task.refresh_from_db()
        assert task.result_data["document_id"] == 42
        assert task.result_message == "New document id 42 created"

    def test_parses_legacy_duplicate_string(self):
        task = self._started_task()
        from documents.signals.handlers import task_postrun_handler

        task_postrun_handler(
            task_id=task.task_id,
            retval="It is a duplicate of some document (#99).",
            state="FAILURE",
        )

        task.refresh_from_db()
        assert task.result_data["duplicate_of"] == 99
        assert task.result_data["duplicate_in_trash"] is False

    def test_ignores_unknown_task_id(self):
        from documents.signals.handlers import task_postrun_handler

        # Must not raise for an id that has no PaperlessTask row
        task_postrun_handler(
            task_id="nonexistent",
            retval=None,
            state="SUCCESS",
        )
@pytest.mark.django_db
class TestTaskFailureHandler:
    """Tests for structured error capture on task_failure."""

    def test_records_failure_with_exception(self):
        from django.utils import timezone

        task = PaperlessTask.objects.create(
            task_id=str(uuid.uuid4()),
            task_type=PaperlessTask.TaskType.CONSUME_FILE,
            trigger_source=PaperlessTask.TriggerSource.WEB_UI,
            status=PaperlessTask.Status.STARTED,
            date_started=timezone.now(),
        )
        from documents.signals.handlers import task_failure_handler

        task_failure_handler(
            task_id=task.task_id,
            exception=ValueError("PDF parse failed"),
            traceback=None,
        )

        task.refresh_from_db()
        assert task.status == PaperlessTask.Status.FAILURE
        assert task.result_data["error_type"] == "ValueError"
        assert task.result_data["error_message"] == "PDF parse failed"
        assert task.date_done is not None

    def test_ignores_none_task_id(self):
        from documents.signals.handlers import task_failure_handler

        # Must simply return without raising
        task_failure_handler(task_id=None, exception=ValueError("x"), traceback=None)
def test_add_to_index_indexes_root_once_for_root_documents(self) -> None:
    """A root document is sent to the search backend exactly once."""
    root = Document.objects.create(
        title="root",
        checksum="root",
        mime_type="application/pdf",
    )
    with mock.patch("documents.search.get_backend") as patched_get_backend:
        backend = mock.MagicMock()
        patched_get_backend.return_value = backend
        add_to_index(sender=None, document=root)
        backend.add_or_update.assert_called_once_with(root, effective_content="")
def test_add_to_index_reindexes_root_for_version_documents(self) -> None:
    """A version document triggers a single reindex with its own content."""
    root = Document.objects.create(
        title="root",
        checksum="root",
        mime_type="application/pdf",
    )
    version = Document.objects.create(
        title="version",
        checksum="version",
        mime_type="application/pdf",
        root_document=root,
    )
    with mock.patch("documents.search.get_backend") as patched_get_backend:
        backend = mock.MagicMock()
        patched_get_backend.return_value = backend
        add_to_index(sender=None, document=version)
        self.assertEqual(backend.add_or_update.call_count, 1)
        only_call = backend.add_or_update.call_args_list[0]
        self.assertEqual(only_call.args[0].id, version.id)
        self.assertEqual(
            only_call.kwargs,
            {"effective_content": version.content},
        )
task_failure_handler(task_id=None, exception=ValueError("x"), traceback=None)