From bedb965b849ec2a6f2011eb1b7b387b4ca314c1d Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:07:38 -0700 Subject: [PATCH] test(tasks): rewrite API task tests for redesigned model and v9 compat Replaces the old Django TestCase-based tests with pytest-style classes using PaperlessTaskFactory. Covers v10 field names, v9 backwards-compat field mapping, filtering, ordering, acknowledge, acknowledge_all, summary, active, and run endpoints. Also adds PaperlessTaskFactory to factories.py and fixes a redundant source= kwarg in TaskSerializerV10.related_document_ids. Co-Authored-By: Claude Sonnet 4.6 --- src/documents/serialisers.py | 1 - src/documents/tests/factories.py | 15 + src/documents/tests/test_api_tasks.py | 943 ++++++++++++++++---------- 3 files changed, 604 insertions(+), 355 deletions(-) diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 480a6e9c4..278f3e936 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -2434,7 +2434,6 @@ class TaskSerializerV10(OwnedObjectSerializer): related_document_ids = serializers.ListField( child=serializers.IntegerField(), read_only=True, - source="related_document_ids", ) task_type_display = serializers.CharField( source="get_task_type_display", diff --git a/src/documents/tests/factories.py b/src/documents/tests/factories.py index b0fd68428..0d59f6571 100644 --- a/src/documents/tests/factories.py +++ b/src/documents/tests/factories.py @@ -11,6 +11,7 @@ from documents.models import Correspondent from documents.models import Document from documents.models import DocumentType from documents.models import MatchingModel +from documents.models import PaperlessTask from documents.models import StoragePath from documents.models import Tag @@ -65,3 +66,17 @@ class DocumentFactory(DjangoModelFactory): correspondent = None document_type = None storage_path = None + + +class PaperlessTaskFactory(DjangoModelFactory): + class Meta: + model = PaperlessTask + + task_id = factory.LazyFunction(lambda: str(__import__("uuid").uuid4())) + task_type = PaperlessTask.TaskType.CONSUME_FILE + trigger_source = PaperlessTask.TriggerSource.WEB_UI + status = PaperlessTask.Status.PENDING + input_data = factory.LazyFunction(dict) + result_data = None + result_message = None + acknowledged = False diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index 5dd003565..862044888 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -1,425 +1,660 @@ +"""Tests for the /api/tasks/ endpoint. + +Covers: +- v10 serializer (new field names) +- v9 serializer (backwards-compatible field names) +- Filtering, ordering, acknowledge, acknowledge_all, summary, active, run +""" + import uuid from unittest import mock -import celery +import pytest from django.contrib.auth.models import Permission from django.contrib.auth.models import User from rest_framework import status -from rest_framework.test import APITestCase +from rest_framework.test import APIClient -from documents.models import Document from documents.models import PaperlessTask -from documents.tests.utils import DirectoriesMixin -from documents.views import TasksViewSet +from documents.tests.factories import PaperlessTaskFactory + +ENDPOINT = "/api/tasks/" +ACCEPT_V10 = "application/json; version=10" +ACCEPT_V9 = "application/json; version=9" -class TestTasks(DirectoriesMixin, APITestCase): - ENDPOINT = "/api/tasks/" +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- - def setUp(self) -> None: - super().setUp() - self.user = User.objects.create_superuser(username="temp_admin") - self.client.force_authenticate(user=self.user) +@pytest.fixture() +def superuser(db) -> User: + return User.objects.create_superuser(username="admin", password="admin") - def test_get_tasks(self) -> None: - """ - GIVEN: - - Attempted celery tasks - WHEN: - - API call is made to get tasks - THEN: - - Attempting and pending tasks are serialized and provided - """ - task1 = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", +@pytest.fixture() +def regular_user(db) -> User: + return User.objects.create_user(username="regular", password="regular") + + +@pytest.fixture() +def admin_client(superuser: User) -> APIClient: + """Authenticated admin client sending v10 Accept header.""" + client = APIClient() + client.force_authenticate(user=superuser) + client.credentials(HTTP_ACCEPT=ACCEPT_V10) + return client + + +@pytest.fixture() +def v9_client(superuser: User) -> APIClient: + """Authenticated admin client sending v9 Accept header.""" + client = APIClient() + client.force_authenticate(user=superuser) + client.credentials(HTTP_ACCEPT=ACCEPT_V9) + return client + + +@pytest.fixture() +def user_client(regular_user: User) -> APIClient: + """Authenticated regular-user client sending v10 Accept header.""" + client = APIClient() + client.force_authenticate(user=regular_user) + client.credentials(HTTP_ACCEPT=ACCEPT_V10) + return client + + +# --------------------------------------------------------------------------- +# TestGetTasksV10 +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestGetTasksV10: + def test_list_returns_tasks(self, admin_client: APIClient) -> None: + PaperlessTaskFactory.create_batch(2) + + response = admin_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 2 + + def test_response_has_v10_fields(self, admin_client: APIClient) -> None: + PaperlessTaskFactory( + input_data={"filename": "doc.pdf"}, + result_data={"document_id": 42}, + result_message="Done", ) - task2 = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_two.pdf", + response = admin_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + task_data = response.data[0] + assert "task_type" in task_data + assert "trigger_source" in task_data + assert "input_data" in task_data + assert "result_data" in task_data + assert "result_message" in task_data + assert "related_document_ids" in task_data + + def test_related_document_ids_populated_from_result_data( + self, + admin_client: APIClient, + ) -> None: + PaperlessTaskFactory( + status=PaperlessTask.Status.SUCCESS, + result_data={"document_id": 7}, ) - response = self.client.get(self.ENDPOINT) + response = admin_client.get(ENDPOINT) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 2) - returned_task1 = response.data[1] - returned_task2 = response.data[0] + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["related_document_ids"] == [7] - self.assertEqual(returned_task1["task_id"], task1.task_id) - self.assertEqual(returned_task1["status"], celery.states.PENDING) - self.assertEqual(returned_task1["task_file_name"], task1.task_file_name) + def test_filter_by_task_type(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE) + PaperlessTaskFactory(task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER) - self.assertEqual(returned_task2["task_id"], task2.task_id) - self.assertEqual(returned_task2["status"], celery.states.PENDING) - self.assertEqual(returned_task2["task_file_name"], task2.task_file_name) - - def test_get_single_task_status(self) -> None: - """ - GIVEN - - Query parameter for a valid task ID - WHEN: - - API call is made to get task status - THEN: - - Single task data is returned - """ - - id1 = str(uuid.uuid4()) - task1 = PaperlessTask.objects.create( - task_id=id1, - task_file_name="task_one.pdf", + response = admin_client.get( + ENDPOINT, + {"task_type": PaperlessTask.TaskType.TRAIN_CLASSIFIER}, ) - _ = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_two.pdf", + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data[0]["task_type"] == PaperlessTask.TaskType.TRAIN_CLASSIFIER + + def test_filter_by_status(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.PENDING) + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS) + + response = admin_client.get( + ENDPOINT, + {"status": PaperlessTask.Status.SUCCESS}, ) - response = self.client.get(self.ENDPOINT + f"?task_id={id1}") + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data[0]["status"] == PaperlessTask.Status.SUCCESS - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - returned_task1 = response.data[0] + def test_filter_by_task_id(self, admin_client: APIClient) -> None: + task = PaperlessTaskFactory() + PaperlessTaskFactory() # another task that should not appear - self.assertEqual(returned_task1["task_id"], task1.task_id) + response = admin_client.get(ENDPOINT, {"task_id": task.task_id}) - def test_get_single_task_status_not_valid(self) -> None: - """ - GIVEN - - Query parameter for a non-existent task ID - WHEN: - - API call is made to get task status - THEN: - - No task data is returned - """ - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data[0]["task_id"] == task.task_id + + def test_filter_by_acknowledged(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(acknowledged=False) + PaperlessTaskFactory(acknowledged=True) + + response = admin_client.get(ENDPOINT, {"acknowledged": "false"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data[0]["acknowledged"] is False + + def test_filter_is_complete_true(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.PENDING) + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS) + PaperlessTaskFactory(status=PaperlessTask.Status.FAILURE) + + response = admin_client.get(ENDPOINT, {"is_complete": "true"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 2 + returned_statuses = {t["status"] for t in response.data} + assert returned_statuses == { + PaperlessTask.Status.SUCCESS, + PaperlessTask.Status.FAILURE, + } + + def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None: + t1 = PaperlessTaskFactory() + PaperlessTaskFactory() # middle task -- not checked directly + t3 = PaperlessTaskFactory() + + response = admin_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + ids = [t["task_id"] for t in response.data] + assert ids[0] == t3.task_id + assert ids[-1] == t1.task_id + + def test_no_v9_only_fields_present(self, admin_client: APIClient) -> None: + PaperlessTaskFactory() + + response = admin_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + task_data = response.data[0] + assert "task_name" not in task_data + assert "task_file_name" not in task_data + + +# --------------------------------------------------------------------------- +# TestGetTasksV9 +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestGetTasksV9: + def test_response_has_v9_fields(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(input_data={"filename": "invoice.pdf"}) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + task_data = response.data[0] + assert "task_name" in task_data + assert "task_file_name" in task_data + assert "type" in task_data + assert "result" in task_data + assert "related_document" in task_data + assert "duplicate_documents" in task_data + + def test_no_v10_only_fields_present(self, v9_client: APIClient) -> None: + PaperlessTaskFactory() + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + task_data = response.data[0] + assert "task_type" not in task_data + assert "trigger_source" not in task_data + assert "input_data" not in task_data + + def test_task_name_equals_task_type_value(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["task_name"] == "consume_file" + + def test_task_file_name_from_input_data(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(input_data={"filename": "report.pdf"}) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["task_file_name"] == "report.pdf" + + def test_task_file_name_none_when_no_filename_key( + self, + v9_client: APIClient, + ) -> None: + PaperlessTaskFactory(input_data={}) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["task_file_name"] is None + + def test_type_scheduled_maps_to_scheduled_task(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SCHEDULED) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["type"] == "SCHEDULED_TASK" + + def test_type_system_maps_to_auto_task(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SYSTEM) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["type"] == "AUTO_TASK" + + def test_type_web_ui_maps_to_manual_task(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.WEB_UI) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["type"] == "MANUAL_TASK" + + def test_type_manual_maps_to_manual_task(self, v9_client: APIClient) -> None: + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.MANUAL) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["type"] == "MANUAL_TASK" + + def test_related_document_from_result_data_document_id( + self, + v9_client: APIClient, + ) -> None: + PaperlessTaskFactory( + status=PaperlessTask.Status.SUCCESS, + result_data={"document_id": 99}, ) - _ = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_two.pdf", + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["related_document"] == 99 + + def test_related_document_none_when_no_result_data( + self, + v9_client: APIClient, + ) -> None: + PaperlessTaskFactory(result_data=None) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["related_document"] is None + + def test_duplicate_documents_from_result_data(self, v9_client: APIClient) -> None: + PaperlessTaskFactory( + status=PaperlessTask.Status.SUCCESS, + result_data={"duplicate_of": 55}, ) - response = self.client.get(self.ENDPOINT + "?task_id=bad-task-id") + response = v9_client.get(ENDPOINT) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 0) + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["duplicate_documents"] == [55] - def test_acknowledge_tasks(self) -> None: - """ - GIVEN: - - Attempted celery tasks - WHEN: - - API call is made to get mark task as acknowledged - THEN: - - Task is marked as acknowledged - """ - task = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", + def test_duplicate_documents_empty_when_no_result_data( + self, + v9_client: APIClient, + ) -> None: + PaperlessTaskFactory(result_data=None) + + response = v9_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["duplicate_documents"] == [] + + def test_filter_by_task_name_maps_to_task_type(self, v9_client: APIClient) -> None: + """v9 ?task_name=consume_file filter maps to the task_type field.""" + PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE) + PaperlessTaskFactory(task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER) + + response = v9_client.get(ENDPOINT, {"task_name": "consume_file"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data[0]["task_name"] == "consume_file" + + +# --------------------------------------------------------------------------- +# TestAcknowledge +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestAcknowledge: + def test_returns_count(self, admin_client: APIClient) -> None: + task1 = PaperlessTaskFactory() + task2 = PaperlessTaskFactory() + + response = admin_client.post( + ENDPOINT + "acknowledge/", + {"tasks": [task1.id, task2.id]}, + format="json", ) - response = self.client.get(self.ENDPOINT) - self.assertEqual(len(response.data), 1) + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 2} - response = self.client.post( - self.ENDPOINT + "acknowledge/", + def test_acknowledged_tasks_excluded_from_unacked_filter( + self, + admin_client: APIClient, + ) -> None: + task = PaperlessTaskFactory() + admin_client.post( + ENDPOINT + "acknowledge/", {"tasks": [task.id]}, - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - response = self.client.get(self.ENDPOINT + "?acknowledged=false") - self.assertEqual(len(response.data), 0) - - def test_acknowledge_tasks_requires_change_permission(self) -> None: - """ - GIVEN: - - A regular user initially without change permissions - - A regular user with change permissions - WHEN: - - API call is made to acknowledge tasks - THEN: - - The first user is forbidden from acknowledging tasks - - The second user is allowed to acknowledge tasks - """ - regular_user = User.objects.create_user(username="test") - self.client.force_authenticate(user=regular_user) - - task = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", + format="json", ) - response = self.client.post( - self.ENDPOINT + "acknowledge/", + response = admin_client.get(ENDPOINT, {"acknowledged": "false"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 0 + + def test_requires_change_permission(self, user_client: APIClient) -> None: + task = PaperlessTaskFactory() + + response = user_client.post( + ENDPOINT + "acknowledge/", {"tasks": [task.id]}, + format="json", ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - regular_user2 = User.objects.create_user(username="test2") - regular_user2.user_permissions.add( + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_succeeds_with_change_permission(self, regular_user: User) -> None: + regular_user.user_permissions.add( Permission.objects.get(codename="change_paperlesstask"), ) - regular_user2.save() - self.client.force_authenticate(user=regular_user2) + regular_user.save() - response = self.client.post( - self.ENDPOINT + "acknowledge/", + client = APIClient() + client.force_authenticate(user=regular_user) + client.credentials(HTTP_ACCEPT=ACCEPT_V10) + + task = PaperlessTaskFactory() + response = client.post( + ENDPOINT + "acknowledge/", {"tasks": [task.id]}, - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - def test_tasks_owner_aware(self) -> None: - """ - GIVEN: - - Existing PaperlessTasks with owner and with no owner - WHEN: - - API call is made to get tasks - THEN: - - Only tasks with no owner or request user are returned - """ - - regular_user = User.objects.create_user(username="test") - regular_user.user_permissions.add(*Permission.objects.all()) - self.client.logout() - self.client.force_authenticate(user=regular_user) - - task1 = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", - owner=self.user, + format="json", ) - task2 = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_two.pdf", + assert response.status_code == status.HTTP_200_OK + + def test_list_is_owner_aware( + self, + superuser: User, + regular_user: User, + ) -> None: + """The task list only shows tasks the user owns or that are unowned.""" + regular_user.user_permissions.add( + Permission.objects.get(codename="view_paperlesstask"), ) - task3 = PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_three.pdf", - owner=regular_user, + client = APIClient() + client.force_authenticate(user=regular_user) + client.credentials(HTTP_ACCEPT=ACCEPT_V10) + + PaperlessTaskFactory(owner=superuser) + shared_task = PaperlessTaskFactory() + own_task = PaperlessTaskFactory(owner=regular_user) + + response = client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 2 + returned_task_ids = {t["task_id"] for t in response.data} + assert shared_task.task_id in returned_task_ids + assert own_task.task_id in returned_task_ids + + +# --------------------------------------------------------------------------- +# TestAcknowledgeAll +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestAcknowledgeAll: + def test_marks_only_completed_tasks(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS, acknowledged=False) + PaperlessTaskFactory(status=PaperlessTask.Status.FAILURE, acknowledged=False) + PaperlessTaskFactory(status=PaperlessTask.Status.PENDING, acknowledged=False) + + response = admin_client.post(ENDPOINT + "acknowledge_all/") + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 2} + + def test_skips_already_acknowledged(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS, acknowledged=True) + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS, acknowledged=False) + + response = admin_client.post(ENDPOINT + "acknowledge_all/") + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 1} + + def test_skips_pending_and_started(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.PENDING) + PaperlessTaskFactory(status=PaperlessTask.Status.STARTED) + + response = admin_client.post(ENDPOINT + "acknowledge_all/") + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 0} + + def test_includes_revoked(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.REVOKED, acknowledged=False) + + response = admin_client.post(ENDPOINT + "acknowledge_all/") + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 1} + + +# --------------------------------------------------------------------------- +# TestSummary +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestSummary: + def test_returns_per_type_totals(self, admin_client: APIClient) -> None: + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.CONSUME_FILE, + status=PaperlessTask.Status.SUCCESS, + ) + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.CONSUME_FILE, + status=PaperlessTask.Status.FAILURE, + ) + PaperlessTaskFactory( + task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER, + status=PaperlessTask.Status.SUCCESS, ) - response = self.client.get(self.ENDPOINT) + response = admin_client.get(ENDPOINT + "summary/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 2) - self.assertEqual(response.data[0]["task_id"], task3.task_id) - self.assertEqual(response.data[1]["task_id"], task2.task_id) + assert response.status_code == status.HTTP_200_OK + by_type = {item["task_type"]: item for item in response.data} + assert by_type["consume_file"]["total_count"] == 2 + assert by_type["consume_file"]["success_count"] == 1 + assert by_type["consume_file"]["failure_count"] == 1 + assert by_type["train_classifier"]["total_count"] == 1 - acknowledge_response = self.client.post( - self.ENDPOINT + "acknowledge/", - {"tasks": [task1.id, task2.id, task3.id]}, - ) - self.assertEqual(acknowledge_response.status_code, status.HTTP_200_OK) - self.assertEqual(acknowledge_response.data, {"result": 2}) + def test_contains_expected_fields(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS) - def test_task_result_no_error(self) -> None: - """ - GIVEN: - - A celery task completed without error - WHEN: - - API call is made to get tasks - THEN: - - The returned data includes the task result - """ - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", - status=celery.states.SUCCESS, - result="Success. New document id 1 created", - ) + response = admin_client.get(ENDPOINT + "summary/") - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert len(response.data) >= 1 + item = response.data[0] + for field in ( + "task_type", + "total_count", + "pending_count", + "success_count", + "failure_count", + "avg_duration_seconds", + "avg_wait_time_seconds", + "last_run", + "last_success", + "last_failure", + ): + assert field in item, f"Missing field: {field}" - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - returned_data = response.data[0] +# --------------------------------------------------------------------------- +# TestActive +# --------------------------------------------------------------------------- - self.assertEqual(returned_data["result"], "Success. New document id 1 created") - self.assertEqual(returned_data["related_document"], "1") - def test_task_result_with_error(self) -> None: - """ - GIVEN: - - A celery task completed with an exception - WHEN: - - API call is made to get tasks - THEN: - - The returned result is the exception info - """ - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", - status=celery.states.FAILURE, - result="test.pdf: Unexpected error during ingestion.", - ) +@pytest.mark.django_db() +class TestActive: + def test_returns_pending_and_started_only(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.PENDING) + PaperlessTaskFactory(status=PaperlessTask.Status.STARTED) + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS) + PaperlessTaskFactory(status=PaperlessTask.Status.FAILURE) - response = self.client.get(self.ENDPOINT) + response = admin_client.get(ENDPOINT + "active/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - - returned_data = response.data[0] - - self.assertEqual( - returned_data["result"], - "test.pdf: Unexpected error during ingestion.", - ) - - def test_task_name_webui(self) -> None: - """ - GIVEN: - - Attempted celery task - - Task was created through the webui - WHEN: - - API call is made to get tasks - THEN: - - Returned data include the filename - """ - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="test.pdf", - task_name=PaperlessTask.TaskName.CONSUME_FILE, - status=celery.states.SUCCESS, - ) - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - - returned_data = response.data[0] - - self.assertEqual(returned_data["task_file_name"], "test.pdf") - - def test_task_name_consume_folder(self) -> None: - """ - GIVEN: - - Attempted celery task - - Task was created through the consume folder - WHEN: - - API call is made to get tasks - THEN: - - Returned data include the filename - """ - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="anothertest.pdf", - task_name=PaperlessTask.TaskName.CONSUME_FILE, - status=celery.states.SUCCESS, - ) - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - - returned_data = response.data[0] - - self.assertEqual(returned_data["task_file_name"], "anothertest.pdf") - - def test_task_result_duplicate_warning_includes_count(self) -> None: - """ - GIVEN: - - A celery task succeeds, but a duplicate exists - WHEN: - - API call is made to get tasks - THEN: - - The returned data includes duplicate warning metadata - """ - checksum = "duplicate-checksum" - Document.objects.create( - title="Existing", - content="", - mime_type="application/pdf", - checksum=checksum, - ) - created_doc = Document.objects.create( - title="Created", - content="", - mime_type="application/pdf", - checksum=checksum, - archive_checksum="another-checksum", - ) - PaperlessTask.objects.create( - task_id=str(uuid.uuid4()), - task_file_name="task_one.pdf", - status=celery.states.SUCCESS, - result=f"Success. New document id {created_doc.pk} created", - ) - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data), 1) - - returned_data = response.data[0] - - self.assertEqual(returned_data["related_document"], str(created_doc.pk)) - - def test_run_train_classifier_task(self) -> None: - """ - GIVEN: - - A superuser - WHEN: - - API call is made to run the train classifier task - THEN: - - The task is run - """ - mock_train_classifier = mock.Mock(return_value="Task started") - TasksViewSet.TASK_AND_ARGS_BY_NAME = { - PaperlessTask.TaskName.TRAIN_CLASSIFIER: ( - mock_train_classifier, - {"scheduled": False}, - ), + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 2 + active_statuses = {t["status"] for t in response.data} + assert active_statuses == { + PaperlessTask.Status.PENDING, + PaperlessTask.Status.STARTED, } - response = self.client.post( - self.ENDPOINT + "run/", - {"task_name": PaperlessTask.TaskName.TRAIN_CLASSIFIER}, + + def test_excludes_completed_tasks(self, admin_client: APIClient) -> None: + PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS) + PaperlessTaskFactory(status=PaperlessTask.Status.FAILURE) + PaperlessTaskFactory(status=PaperlessTask.Status.REVOKED) + + response = admin_client.get(ENDPOINT + "active/") + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 0 + + +# --------------------------------------------------------------------------- +# TestRun +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestRun: + def test_forbidden_for_regular_user(self, user_client: APIClient) -> None: + response = user_client.post( + ENDPOINT + "run/", + {"task_type": PaperlessTask.TaskType.TRAIN_CLASSIFIER}, + format="json", ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {"result": "Task started"}) - mock_train_classifier.assert_called_once_with(scheduled=False) + assert response.status_code == status.HTTP_403_FORBIDDEN - # mock error - mock_train_classifier.reset_mock() - mock_train_classifier.side_effect = Exception("Error") - response = self.client.post( - self.ENDPOINT + "run/", - {"task_name": PaperlessTask.TaskName.TRAIN_CLASSIFIER}, + def test_dispatches_via_apply_async_with_manual_trigger_header( + self, + admin_client: APIClient, + ) -> None: + fake_task_id = str(uuid.uuid4()) + mock_async_result = mock.Mock() + mock_async_result.id = fake_task_id + + mock_apply_async = mock.Mock(return_value=mock_async_result) + + with mock.patch( + "documents.views.train_classifier.apply_async", + mock_apply_async, + ): + response = admin_client.post( + ENDPOINT + "run/", + {"task_type": PaperlessTask.TaskType.TRAIN_CLASSIFIER}, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"task_id": fake_task_id} + mock_apply_async.assert_called_once_with( + kwargs={}, + headers={"trigger_source": "manual"}, ) - self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - mock_train_classifier.assert_called_once_with(scheduled=False) - - @mock.patch("documents.tasks.sanity_check") - def test_run_task_requires_superuser(self, mock_check_sanity) -> None: - """ - GIVEN: - - A regular user - WHEN: - - API call is made to run a task - THEN: - - The task is not run - """ - regular_user = User.objects.create_user(username="test") - regular_user.user_permissions.add(*Permission.objects.all()) - self.client.logout() - self.client.force_authenticate(user=regular_user) - - response = self.client.post( - self.ENDPOINT + "run/", - {"task_name": PaperlessTask.TaskName.CHECK_SANITY}, + def test_returns_400_for_consume_file(self, admin_client: APIClient) -> None: + """consume_file cannot be manually triggered via the run endpoint.""" + response = admin_client.post( + ENDPOINT + "run/", + {"task_type": PaperlessTask.TaskType.CONSUME_FILE}, + format="json", ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - mock_check_sanity.assert_not_called() + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_returns_400_for_invalid_task_type(self, admin_client: APIClient) -> None: + response = admin_client.post( + ENDPOINT + "run/", + {"task_type": "not_a_real_type"}, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_sanity_check_dispatched_with_correct_kwargs( + self, + admin_client: APIClient, + ) -> None: + fake_task_id = str(uuid.uuid4()) + mock_async_result = mock.Mock() + mock_async_result.id = fake_task_id + + mock_apply_async = mock.Mock(return_value=mock_async_result) + + with mock.patch( + "documents.views.sanity_check.apply_async", + mock_apply_async, + ): + response = admin_client.post( + ENDPOINT + "run/", + {"task_type": PaperlessTask.TaskType.SANITY_CHECK}, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"task_id": fake_task_id} + mock_apply_async.assert_called_once_with( + kwargs={"raise_on_error": False}, + headers={"trigger_source": "manual"}, + )