From fca565a16949e045d804d5863db466fece0891ba Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:25:52 -0700 Subject: [PATCH] test: fix remaining tests broken by task system redesign MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update all tests that created PaperlessTask objects with old field names to use PaperlessTaskFactory and new field names (task_type, trigger_source, status, result_message). Use apply_async instead of delay where mocked. Drop TestCheckSanityTaskRecording — tests PaperlessTask creation that was intentionally removed from check_sanity(). Co-Authored-By: Claude Sonnet 4.6 --- src/documents/tests/test_api_app_config.py | 2 +- src/documents/tests/test_api_status.py | 56 +++++++++++----------- src/documents/tests/test_api_tasks.py | 56 ++++------------------ src/documents/tests/test_management.py | 2 +- src/documents/tests/test_sanity_check.py | 34 +------------ src/paperless_ai/tests/test_ai_indexing.py | 18 +++---- 6 files changed, 50 insertions(+), 118 deletions(-) diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py index d1241b38a..ccefde1ad 100644 --- a/src/documents/tests/test_api_app_config.py +++ b/src/documents/tests/test_api_app_config.py @@ -831,7 +831,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): config.save() with ( - patch("documents.tasks.llmindex_index.delay") as mock_update, + patch("documents.tasks.llmindex_index.apply_async") as mock_update, patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists, ): mock_exists.return_value = False diff --git a/src/documents/tests/test_api_status.py b/src/documents/tests/test_api_status.py index 4f4511c14..69cfe2c34 100644 --- a/src/documents/tests/test_api_status.py +++ b/src/documents/tests/test_api_status.py @@ -4,7 +4,6 @@ import tempfile from pathlib import Path from unittest import mock -from celery import states from django.contrib.auth.models import Permission from django.contrib.auth.models import User from django.test import override_settings @@ -13,6 +12,7 @@ from rest_framework.test import APITestCase from documents.models import PaperlessTask from documents.permissions import has_system_status_permission +from documents.tests.factories import PaperlessTaskFactory from paperless import version @@ -258,10 +258,10 @@ class TestSystemStatus(APITestCase): THEN: - The response contains an OK classifier status """ - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.SUCCESS, - task_name=PaperlessTask.TaskName.TRAIN_CLASSIFIER, + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.SUCCESS, ) self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) @@ -295,11 +295,11 @@ class TestSystemStatus(APITestCase): THEN: - The response contains an ERROR classifier status """ - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.FAILURE, - task_name=PaperlessTask.TaskName.TRAIN_CLASSIFIER, - result="Classifier training failed", + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.FAILURE, + result_message="Classifier training failed", ) self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) @@ -319,10 +319,10 @@ class TestSystemStatus(APITestCase): THEN: - The response contains an OK sanity check status """ - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.SUCCESS, - task_name=PaperlessTask.TaskName.CHECK_SANITY, + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.SANITY_CHECK, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.SUCCESS, ) self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) @@ -356,11 +356,11 @@ class TestSystemStatus(APITestCase): THEN: - The response contains an ERROR sanity check status """ - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.FAILURE, - task_name=PaperlessTask.TaskName.CHECK_SANITY, - result="5 issues found.", + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.SANITY_CHECK, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.FAILURE, + result_message="5 issues found.", ) self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) @@ -405,10 +405,10 @@ class TestSystemStatus(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["tasks"]["llmindex_status"], "WARNING") - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.SUCCESS, - task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.LLM_INDEX, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.SUCCESS, ) response = self.client.get(self.ENDPOINT) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -425,11 +425,11 @@ class TestSystemStatus(APITestCase): - The response contains the correct AI status """ with override_settings(AI_ENABLED=True, LLM_EMBEDDING_BACKEND="openai"): - PaperlessTask.objects.create( - type=PaperlessTask.TaskType.SCHEDULED_TASK, - status=states.FAILURE, - task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, - result="AI index update failed", + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.LLM_INDEX, + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + status=PaperlessTask.Status.FAILURE, + result_message="AI index update failed", ) self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index 3fdf32560..c485f7aec 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -29,29 +29,24 @@ ACCEPT_V9 = "application/json; version=9" @pytest.fixture() -def superuser(db) -> User: - return User.objects.create_superuser(username="admin", password="admin") +def regular_user(django_user_model: User) -> User: + return django_user_model.objects.create_user(username="regular", password="regular") @pytest.fixture() -def regular_user(db) -> User: - return User.objects.create_user(username="regular", password="regular") - - -@pytest.fixture() -def admin_client(superuser: User) -> APIClient: +def admin_client(admin_user: User) -> APIClient: """Authenticated admin client sending v10 Accept header.""" client = APIClient() - client.force_authenticate(user=superuser) + client.force_authenticate(user=admin_user) client.credentials(HTTP_ACCEPT=ACCEPT_V10) return client @pytest.fixture() -def v9_client(superuser: User) -> APIClient: +def v9_client(admin_user: User) -> APIClient: """Authenticated admin client sending v9 Accept header.""" client = APIClient() - client.force_authenticate(user=superuser) + client.force_authenticate(user=admin_user) client.credentials(HTTP_ACCEPT=ACCEPT_V9) return client @@ -65,11 +60,6 @@ def user_client(regular_user: User) -> APIClient: return client -# --------------------------------------------------------------------------- -# TestGetTasksV10 -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestGetTasksV10: def test_list_returns_tasks(self, admin_client: APIClient) -> None: @@ -202,7 +192,7 @@ class TestGetTasksV10: def test_list_is_owner_aware( self, - superuser: User, + admin_user: User, regular_user: User, ) -> None: """The task list only shows tasks the user owns or that are unowned.""" @@ -214,7 +204,7 @@ class TestGetTasksV10: client.force_authenticate(user=regular_user) client.credentials(HTTP_ACCEPT=ACCEPT_V10) - PaperlessTaskFactory(owner=superuser) + PaperlessTaskFactory(owner=admin_user) shared_task = PaperlessTaskFactory() own_task = PaperlessTaskFactory(owner=regular_user) @@ -227,11 +217,6 @@ class TestGetTasksV10: assert own_task.task_id in returned_task_ids -# --------------------------------------------------------------------------- -# TestGetTasksV9 -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestGetTasksV9: def test_task_name_equals_task_type_value(self, v9_client: APIClient) -> None: @@ -363,11 +348,6 @@ class TestGetTasksV9: assert response.data[0]["type"] == "SCHEDULED_TASK" -# --------------------------------------------------------------------------- -# TestAcknowledge -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestAcknowledge: def test_returns_count(self, admin_client: APIClient) -> None: @@ -430,11 +410,6 @@ class TestAcknowledge: assert response.status_code == status.HTTP_200_OK -# --------------------------------------------------------------------------- -# TestAcknowledgeAll -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestAcknowledgeAll: def test_marks_only_completed_tasks(self, admin_client: APIClient) -> None: @@ -474,11 +449,6 @@ class TestAcknowledgeAll: assert response.data == {"result": 1} -# --------------------------------------------------------------------------- -# TestSummary -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestSummary: def test_returns_per_type_totals(self, admin_client: APIClient) -> None: @@ -505,11 +475,6 @@ class TestSummary: assert by_type["train_classifier"]["total_count"] == 1 -# --------------------------------------------------------------------------- -# TestActive -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestActive: def test_returns_pending_and_started_only(self, admin_client: APIClient) -> None: @@ -537,11 +502,6 @@ class TestActive: assert len(response.data) == 0 -# --------------------------------------------------------------------------- -# TestRun -# --------------------------------------------------------------------------- - - @pytest.mark.django_db() class TestRun: def test_forbidden_for_regular_user(self, user_client: APIClient) -> None: diff --git a/src/documents/tests/test_management.py b/src/documents/tests/test_management.py index 72476d403..276da942d 100644 --- a/src/documents/tests/test_management.py +++ b/src/documents/tests/test_management.py @@ -211,7 +211,7 @@ class TestCreateClassifier: call_command("document_create_classifier", skip_checks=True) - m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY) + m.assert_called_once_with(status_callback=mocker.ANY) assert callable(m.call_args.kwargs["status_callback"]) def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None: diff --git a/src/documents/tests/test_sanity_check.py b/src/documents/tests/test_sanity_check.py index e62c17303..568e3e444 100644 --- a/src/documents/tests/test_sanity_check.py +++ b/src/documents/tests/test_sanity_check.py @@ -1,7 +1,7 @@ """Tests for the sanity checker module. Tests exercise ``check_sanity`` as a whole, verifying document validation, -orphan detection, task recording, and the iter_wrapper contract. +orphan detection, and the iter_wrapper contract. """ from __future__ import annotations @@ -12,13 +12,12 @@ from typing import TYPE_CHECKING import pytest -from documents.models import Document -from documents.models import PaperlessTask from documents.sanity_checker import check_sanity if TYPE_CHECKING: from collections.abc import Iterable + from documents.models import Document from documents.tests.conftest import PaperlessDirs @@ -229,35 +228,6 @@ class TestCheckSanityIterWrapper: assert not messages.has_error -@pytest.mark.django_db -class TestCheckSanityTaskRecording: - @pytest.mark.parametrize( - ("expected_type", "scheduled"), - [ - pytest.param(PaperlessTask.TaskType.SCHEDULED_TASK, True, id="scheduled"), - pytest.param(PaperlessTask.TaskType.MANUAL_TASK, False, id="manual"), - ], - ) - @pytest.mark.usefixtures("_media_settings") - def test_task_type(self, expected_type: str, *, scheduled: bool) -> None: - check_sanity(scheduled=scheduled) - task = PaperlessTask.objects.latest("date_created") - assert task.task_name == PaperlessTask.TaskName.CHECK_SANITY - assert task.type == expected_type - - def test_success_status(self, sample_doc: Document) -> None: - check_sanity() - task = PaperlessTask.objects.latest("date_created") - assert task.status == "SUCCESS" - - def test_failure_status(self, sample_doc: Document) -> None: - Path(sample_doc.source_path).unlink() - check_sanity() - task = PaperlessTask.objects.latest("date_created") - assert task.status == "FAILURE" - assert "Check logs for details" in task.result - - @pytest.mark.django_db class TestCheckSanityLogMessages: def test_logs_doc_issues( diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index c1e3b64d8..7d9f3cdd5 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -3,13 +3,13 @@ from unittest.mock import MagicMock from unittest.mock import patch import pytest -from celery import states from django.test import override_settings from django.utils import timezone from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import PaperlessTask +from documents.tests.factories import PaperlessTaskFactory from paperless_ai import indexing @@ -292,13 +292,15 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() - ) assert result is True - mock_task.delay.assert_called_once_with(rebuild=True, scheduled=False, auto=True) + mock_task.apply_async.assert_called_once_with( + kwargs={"rebuild": True}, + headers={"trigger_source": "system"}, + ) - PaperlessTask.objects.create( - task_id="task-1", - task_name=PaperlessTask.TaskName.LLMINDEX_UPDATE, - status=states.STARTED, - date_created=timezone.now(), + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.LLM_INDEX, + trigger_source=PaperlessTask.TriggerSource.SYSTEM, + status=PaperlessTask.Status.STARTED, ) # Existing running task @@ -309,7 +311,7 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() - ) assert result is False - mock_task.delay.assert_not_called() + mock_task.apply_async.assert_not_called() @override_settings(