Files
paperless-ngx/src/documents/tests/test_api_tasks.py
T
stumpylog bc63ca5806 test(tasks): rewrite API task tests for redesigned model and v9 compat
Replaces the old Django TestCase-based tests with pytest-style classes using
PaperlessTaskFactory. Covers v10 field names, v9 backwards-compat field
mapping, filtering, ordering, acknowledge, acknowledge_all, summary, active,
and run endpoints. Also adds PaperlessTaskFactory to factories.py and fixes
a redundant source= kwarg in TaskSerializerV10.related_document_ids.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-15 18:34:34 -07:00

661 lines
23 KiB
Python

"""Tests for the /api/tasks/ endpoint.
Covers:
- v10 serializer (new field names)
- v9 serializer (backwards-compatible field names)
- Filtering, ordering, acknowledge, acknowledge_all, summary, active, run
"""
import uuid
from unittest import mock
import pytest
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.test import APIClient
from documents.models import PaperlessTask
from documents.tests.factories import PaperlessTaskFactory
# Base URL of the task API; action routes such as "acknowledge/" are appended to it.
ENDPOINT = "/api/tasks/"
# Accept header values the client fixtures send as HTTP_ACCEPT, one per API version.
ACCEPT_V10 = "application/json; version=10"
ACCEPT_V9 = "application/json; version=9"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def superuser(db) -> User:
    """Create and return a superuser named ``admin`` (depends on the ``db`` fixture)."""
    admin = User.objects.create_superuser(username="admin", password="admin")
    return admin
@pytest.fixture()
def regular_user(db) -> User:
    """Create and return a non-superuser named ``regular`` (depends on the ``db`` fixture)."""
    plain_user = User.objects.create_user(username="regular", password="regular")
    return plain_user
@pytest.fixture()
def admin_client(superuser: User) -> APIClient:
    """Return an API client force-authenticated as the superuser.

    Every request made through it carries ``ACCEPT_V10`` as its Accept header.
    """
    api = APIClient()
    api.force_authenticate(user=superuser)
    api.credentials(HTTP_ACCEPT=ACCEPT_V10)
    return api
@pytest.fixture()
def v9_client(superuser: User) -> APIClient:
    """Return an API client force-authenticated as the superuser.

    Every request made through it carries ``ACCEPT_V9`` as its Accept header.
    """
    api = APIClient()
    api.force_authenticate(user=superuser)
    api.credentials(HTTP_ACCEPT=ACCEPT_V9)
    return api
@pytest.fixture()
def user_client(regular_user: User) -> APIClient:
    """Return an API client force-authenticated as the regular user.

    Every request made through it carries ``ACCEPT_V10`` as its Accept header.
    """
    api = APIClient()
    api.force_authenticate(user=regular_user)
    api.credentials(HTTP_ACCEPT=ACCEPT_V10)
    return api
# ---------------------------------------------------------------------------
# TestGetTasksV10
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestGetTasksV10:
    """List-endpoint checks for a client sending the v10 Accept header."""

    def test_list_returns_tasks(self, admin_client: APIClient) -> None:
        """Two stored tasks yield a two-element response."""
        PaperlessTaskFactory.create_batch(2)

        resp = admin_client.get(ENDPOINT)

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 2

    def test_response_has_v10_fields(self, admin_client: APIClient) -> None:
        """A serialized task contains every v10 field name."""
        PaperlessTaskFactory(
            input_data={"filename": "doc.pdf"},
            result_data={"document_id": 42},
            result_message="Done",
        )

        resp = admin_client.get(ENDPOINT)

        assert resp.status_code == status.HTTP_200_OK
        first = resp.data[0]
        v10_keys = (
            "task_type",
            "trigger_source",
            "input_data",
            "result_data",
            "result_message",
            "related_document_ids",
        )
        for key in v10_keys:
            assert key in first

    def test_related_document_ids_populated_from_result_data(
        self,
        admin_client: APIClient,
    ) -> None:
        """``result_data["document_id"]`` is surfaced as a one-element id list."""
        PaperlessTaskFactory(
            status=PaperlessTask.Status.SUCCESS,
            result_data={"document_id": 7},
        )

        resp = admin_client.get(ENDPOINT)

        assert resp.status_code == status.HTTP_200_OK
        first = resp.data[0]
        assert first["related_document_ids"] == [7]

    def test_filter_by_task_type(self, admin_client: APIClient) -> None:
        """``?task_type=`` returns only tasks of that type."""
        wanted = PaperlessTask.TaskType.TRAIN_CLASSIFIER
        PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE)
        PaperlessTaskFactory(task_type=wanted)

        query = {"task_type": wanted}
        resp = admin_client.get(ENDPOINT, query)

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 1
        assert resp.data[0]["task_type"] == wanted

    def test_filter_by_status(self, admin_client: APIClient) -> None:
        """``?status=`` returns only tasks in that status."""
        wanted = PaperlessTask.Status.SUCCESS
        PaperlessTaskFactory(status=PaperlessTask.Status.PENDING)
        PaperlessTaskFactory(status=wanted)

        query = {"status": wanted}
        resp = admin_client.get(ENDPOINT, query)

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 1
        assert resp.data[0]["status"] == wanted

    def test_filter_by_task_id(self, admin_client: APIClient) -> None:
        """``?task_id=`` returns only the task with that id."""
        target = PaperlessTaskFactory()
        PaperlessTaskFactory()  # decoy that must be filtered out

        resp = admin_client.get(ENDPOINT, {"task_id": target.task_id})

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 1
        assert resp.data[0]["task_id"] == target.task_id

    def test_filter_by_acknowledged(self, admin_client: APIClient) -> None:
        """``?acknowledged=false`` returns only unacknowledged tasks."""
        PaperlessTaskFactory(acknowledged=False)
        PaperlessTaskFactory(acknowledged=True)

        resp = admin_client.get(ENDPOINT, {"acknowledged": "false"})

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 1
        assert resp.data[0]["acknowledged"] is False

    def test_filter_is_complete_true(self, admin_client: APIClient) -> None:
        """``?is_complete=true`` returns the SUCCESS and FAILURE tasks, not the PENDING one."""
        PaperlessTaskFactory(status=PaperlessTask.Status.PENDING)
        PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS)
        PaperlessTaskFactory(status=PaperlessTask.Status.FAILURE)

        resp = admin_client.get(ENDPOINT, {"is_complete": "true"})

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 2
        expected = {
            PaperlessTask.Status.SUCCESS,
            PaperlessTask.Status.FAILURE,
        }
        seen = {entry["status"] for entry in resp.data}
        assert seen == expected

    def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
        """Without an ordering parameter the last-created task comes first."""
        oldest = PaperlessTaskFactory()
        PaperlessTaskFactory()  # middle task -- only the two ends are checked
        newest = PaperlessTaskFactory()

        resp = admin_client.get(ENDPOINT)

        assert resp.status_code == status.HTTP_200_OK
        ordered_ids = [entry["task_id"] for entry in resp.data]
        assert ordered_ids[0] == newest.task_id
        assert ordered_ids[-1] == oldest.task_id

    def test_no_v9_only_fields_present(self, admin_client: APIClient) -> None:
        """The v9-only field names are absent from a v10 response."""
        PaperlessTaskFactory()

        resp = admin_client.get(ENDPOINT)

        assert resp.status_code == status.HTTP_200_OK
        first = resp.data[0]
        for legacy_key in ("task_name", "task_file_name"):
            assert legacy_key not in first
# ---------------------------------------------------------------------------
# TestGetTasksV9
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestGetTasksV9:
    """List-endpoint checks for a client sending the v9 Accept header."""

    @staticmethod
    def _first_task(client: APIClient):
        """GET the task list, require HTTP 200, and return the first entry."""
        resp = client.get(ENDPOINT)
        assert resp.status_code == status.HTTP_200_OK
        return resp.data[0]

    def test_response_has_v9_fields(self, v9_client: APIClient) -> None:
        """A serialized task contains every v9 field name."""
        PaperlessTaskFactory(input_data={"filename": "invoice.pdf"})

        entry = self._first_task(v9_client)

        v9_keys = (
            "task_name",
            "task_file_name",
            "type",
            "result",
            "related_document",
            "duplicate_documents",
        )
        for key in v9_keys:
            assert key in entry

    def test_no_v10_only_fields_present(self, v9_client: APIClient) -> None:
        """The v10-only field names are absent from a v9 response."""
        PaperlessTaskFactory()

        entry = self._first_task(v9_client)

        for v10_key in ("task_type", "trigger_source", "input_data"):
            assert v10_key not in entry

    def test_task_name_equals_task_type_value(self, v9_client: APIClient) -> None:
        """``task_name`` carries the task type's string value."""
        PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE)

        entry = self._first_task(v9_client)

        assert entry["task_name"] == "consume_file"

    def test_task_file_name_from_input_data(self, v9_client: APIClient) -> None:
        """``task_file_name`` is taken from ``input_data["filename"]``."""
        PaperlessTaskFactory(input_data={"filename": "report.pdf"})

        entry = self._first_task(v9_client)

        assert entry["task_file_name"] == "report.pdf"

    def test_task_file_name_none_when_no_filename_key(
        self,
        v9_client: APIClient,
    ) -> None:
        """``task_file_name`` is ``None`` when ``input_data`` has no filename."""
        PaperlessTaskFactory(input_data={})

        entry = self._first_task(v9_client)

        assert entry["task_file_name"] is None

    def test_type_scheduled_maps_to_scheduled_task(self, v9_client: APIClient) -> None:
        """Trigger source SCHEDULED is reported as ``SCHEDULED_TASK``."""
        PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SCHEDULED)

        entry = self._first_task(v9_client)

        assert entry["type"] == "SCHEDULED_TASK"

    def test_type_system_maps_to_auto_task(self, v9_client: APIClient) -> None:
        """Trigger source SYSTEM is reported as ``AUTO_TASK``."""
        PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SYSTEM)

        entry = self._first_task(v9_client)

        assert entry["type"] == "AUTO_TASK"

    def test_type_web_ui_maps_to_manual_task(self, v9_client: APIClient) -> None:
        """Trigger source WEB_UI is reported as ``MANUAL_TASK``."""
        PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.WEB_UI)

        entry = self._first_task(v9_client)

        assert entry["type"] == "MANUAL_TASK"

    def test_type_manual_maps_to_manual_task(self, v9_client: APIClient) -> None:
        """Trigger source MANUAL is reported as ``MANUAL_TASK``."""
        PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.MANUAL)

        entry = self._first_task(v9_client)

        assert entry["type"] == "MANUAL_TASK"

    def test_related_document_from_result_data_document_id(
        self,
        v9_client: APIClient,
    ) -> None:
        """``related_document`` is taken from ``result_data["document_id"]``."""
        PaperlessTaskFactory(
            status=PaperlessTask.Status.SUCCESS,
            result_data={"document_id": 99},
        )

        entry = self._first_task(v9_client)

        assert entry["related_document"] == 99

    def test_related_document_none_when_no_result_data(
        self,
        v9_client: APIClient,
    ) -> None:
        """``related_document`` is ``None`` when the task has no result data."""
        PaperlessTaskFactory(result_data=None)

        entry = self._first_task(v9_client)

        assert entry["related_document"] is None

    def test_duplicate_documents_from_result_data(self, v9_client: APIClient) -> None:
        """``result_data["duplicate_of"]`` is surfaced as a one-element list."""
        PaperlessTaskFactory(
            status=PaperlessTask.Status.SUCCESS,
            result_data={"duplicate_of": 55},
        )

        entry = self._first_task(v9_client)

        assert entry["duplicate_documents"] == [55]

    def test_duplicate_documents_empty_when_no_result_data(
        self,
        v9_client: APIClient,
    ) -> None:
        """``duplicate_documents`` is an empty list when there is no result data."""
        PaperlessTaskFactory(result_data=None)

        entry = self._first_task(v9_client)

        assert entry["duplicate_documents"] == []

    def test_filter_by_task_name_maps_to_task_type(self, v9_client: APIClient) -> None:
        """v9 ?task_name=consume_file filter maps to the task_type field."""
        PaperlessTaskFactory(task_type=PaperlessTask.TaskType.CONSUME_FILE)
        PaperlessTaskFactory(task_type=PaperlessTask.TaskType.TRAIN_CLASSIFIER)

        resp = v9_client.get(ENDPOINT, {"task_name": "consume_file"})

        assert resp.status_code == status.HTTP_200_OK
        matches = resp.data
        assert len(matches) == 1
        assert matches[0]["task_name"] == "consume_file"
# ---------------------------------------------------------------------------
# TestAcknowledge
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestAcknowledge:
    """Tests for POST ``acknowledge/``, plus one owner-aware listing check."""

    def test_returns_count(self, admin_client: APIClient) -> None:
        """Acknowledging two tasks reports ``{"result": 2}``."""
        task1 = PaperlessTaskFactory()
        task2 = PaperlessTaskFactory()

        response = admin_client.post(
            ENDPOINT + "acknowledge/",
            {"tasks": [task1.id, task2.id]},
            format="json",
        )

        assert response.status_code == status.HTTP_200_OK
        assert response.data == {"result": 2}

    def test_acknowledged_tasks_excluded_from_unacked_filter(
        self,
        admin_client: APIClient,
    ) -> None:
        """An acknowledged task no longer matches ``?acknowledged=false``."""
        task = PaperlessTaskFactory()

        ack_response = admin_client.post(
            ENDPOINT + "acknowledge/",
            {"tasks": [task.id]},
            format="json",
        )
        # Check the POST itself so a broken acknowledge call is reported as
        # such, rather than as a confusing list-length mismatch below.
        assert ack_response.status_code == status.HTTP_200_OK

        response = admin_client.get(ENDPOINT, {"acknowledged": "false"})
        assert response.status_code == status.HTTP_200_OK
        assert len(response.data) == 0

    def test_requires_change_permission(self, user_client: APIClient) -> None:
        """A user with no task permissions receives 403 from acknowledge."""
        task = PaperlessTaskFactory()

        response = user_client.post(
            ENDPOINT + "acknowledge/",
            {"tasks": [task.id]},
            format="json",
        )

        assert response.status_code == status.HTTP_403_FORBIDDEN

    def test_succeeds_with_change_permission(self, regular_user: User) -> None:
        """Granting ``change_paperlesstask`` lets a regular user acknowledge."""
        # The related manager's add() writes the relation immediately, so no
        # save() on the user is needed afterwards.
        regular_user.user_permissions.add(
            Permission.objects.get(codename="change_paperlesstask"),
        )
        client = APIClient()
        client.force_authenticate(user=regular_user)
        client.credentials(HTTP_ACCEPT=ACCEPT_V10)
        task = PaperlessTaskFactory()

        response = client.post(
            ENDPOINT + "acknowledge/",
            {"tasks": [task.id]},
            format="json",
        )

        assert response.status_code == status.HTTP_200_OK

    def test_list_is_owner_aware(
        self,
        superuser: User,
        regular_user: User,
    ) -> None:
        """The task list only shows tasks the user owns or that are unowned."""
        regular_user.user_permissions.add(
            Permission.objects.get(codename="view_paperlesstask"),
        )
        client = APIClient()
        client.force_authenticate(user=regular_user)
        client.credentials(HTTP_ACCEPT=ACCEPT_V10)
        foreign_task = PaperlessTaskFactory(owner=superuser)
        shared_task = PaperlessTaskFactory()
        own_task = PaperlessTaskFactory(owner=regular_user)

        response = client.get(ENDPOINT)

        assert response.status_code == status.HTTP_200_OK
        assert len(response.data) == 2
        returned_task_ids = {t["task_id"] for t in response.data}
        assert shared_task.task_id in returned_task_ids
        assert own_task.task_id in returned_task_ids
        # State the exclusion explicitly instead of leaving it implied by the count.
        assert foreign_task.task_id not in returned_task_ids
# ---------------------------------------------------------------------------
# TestAcknowledgeAll
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestAcknowledgeAll:
    """Tests for POST ``acknowledge_all/``; each asserts the reported count."""

    def test_marks_only_completed_tasks(self, admin_client: APIClient) -> None:
        """Of SUCCESS, FAILURE and PENDING tasks, two are counted."""
        for task_status in (
            PaperlessTask.Status.SUCCESS,
            PaperlessTask.Status.FAILURE,
            PaperlessTask.Status.PENDING,
        ):
            PaperlessTaskFactory(status=task_status, acknowledged=False)

        resp = admin_client.post(ENDPOINT + "acknowledge_all/")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"result": 2}

    def test_skips_already_acknowledged(self, admin_client: APIClient) -> None:
        """A task that is already acknowledged is not counted again."""
        for already_acked in (True, False):
            PaperlessTaskFactory(
                status=PaperlessTask.Status.SUCCESS,
                acknowledged=already_acked,
            )

        resp = admin_client.post(ENDPOINT + "acknowledge_all/")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"result": 1}

    def test_skips_pending_and_started(self, admin_client: APIClient) -> None:
        """PENDING and STARTED tasks yield a count of zero."""
        PaperlessTaskFactory(status=PaperlessTask.Status.PENDING)
        PaperlessTaskFactory(status=PaperlessTask.Status.STARTED)

        resp = admin_client.post(ENDPOINT + "acknowledge_all/")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"result": 0}

    def test_includes_revoked(self, admin_client: APIClient) -> None:
        """A REVOKED, unacknowledged task is counted."""
        PaperlessTaskFactory(status=PaperlessTask.Status.REVOKED, acknowledged=False)

        resp = admin_client.post(ENDPOINT + "acknowledge_all/")

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"result": 1}
# ---------------------------------------------------------------------------
# TestSummary
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestSummary:
    """Tests for GET ``summary/``."""

    def test_returns_per_type_totals(self, admin_client: APIClient) -> None:
        """Counts are grouped per task type and split by outcome."""
        seeded = (
            (PaperlessTask.TaskType.CONSUME_FILE, PaperlessTask.Status.SUCCESS),
            (PaperlessTask.TaskType.CONSUME_FILE, PaperlessTask.Status.FAILURE),
            (PaperlessTask.TaskType.TRAIN_CLASSIFIER, PaperlessTask.Status.SUCCESS),
        )
        for task_type, task_status in seeded:
            PaperlessTaskFactory(task_type=task_type, status=task_status)

        resp = admin_client.get(ENDPOINT + "summary/")

        assert resp.status_code == status.HTTP_200_OK
        # Index the rows by type so the assertions do not depend on row order.
        rows = {row["task_type"]: row for row in resp.data}
        consume_row = rows["consume_file"]
        assert consume_row["total_count"] == 2
        assert consume_row["success_count"] == 1
        assert consume_row["failure_count"] == 1
        assert rows["train_classifier"]["total_count"] == 1

    def test_contains_expected_fields(self, admin_client: APIClient) -> None:
        """The first summary row exposes every expected key."""
        PaperlessTaskFactory(status=PaperlessTask.Status.SUCCESS)

        resp = admin_client.get(ENDPOINT + "summary/")

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) >= 1
        first_row = resp.data[0]
        required = (
            "task_type",
            "total_count",
            "pending_count",
            "success_count",
            "failure_count",
            "avg_duration_seconds",
            "avg_wait_time_seconds",
            "last_run",
            "last_success",
            "last_failure",
        )
        for field in required:
            assert field in first_row, f"Missing field: {field}"
# ---------------------------------------------------------------------------
# TestActive
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestActive:
    """Tests for GET ``active/``."""

    def test_returns_pending_and_started_only(self, admin_client: APIClient) -> None:
        """Of four tasks, only the PENDING and STARTED ones are returned."""
        for task_status in (
            PaperlessTask.Status.PENDING,
            PaperlessTask.Status.STARTED,
            PaperlessTask.Status.SUCCESS,
            PaperlessTask.Status.FAILURE,
        ):
            PaperlessTaskFactory(status=task_status)

        resp = admin_client.get(ENDPOINT + "active/")

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 2
        expected = {
            PaperlessTask.Status.PENDING,
            PaperlessTask.Status.STARTED,
        }
        seen = {entry["status"] for entry in resp.data}
        assert seen == expected

    def test_excludes_completed_tasks(self, admin_client: APIClient) -> None:
        """SUCCESS, FAILURE and REVOKED tasks produce an empty response."""
        for task_status in (
            PaperlessTask.Status.SUCCESS,
            PaperlessTask.Status.FAILURE,
            PaperlessTask.Status.REVOKED,
        ):
            PaperlessTaskFactory(status=task_status)

        resp = admin_client.get(ENDPOINT + "active/")

        assert resp.status_code == status.HTTP_200_OK
        assert len(resp.data) == 0
# ---------------------------------------------------------------------------
# TestRun
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestRun:
    """Tests for POST ``run/``."""

    @staticmethod
    def _post_run(client: APIClient, task_type):
        """POST ``task_type`` as JSON to the run endpoint and return the response."""
        return client.post(
            ENDPOINT + "run/",
            {"task_type": task_type},
            format="json",
        )

    @staticmethod
    def _stub_apply_async():
        """Return ``(task_id, stub)``; calling ``stub`` yields an object whose ``id`` is ``task_id``."""
        task_id = str(uuid.uuid4())
        async_result = mock.Mock()
        async_result.id = task_id
        return task_id, mock.Mock(return_value=async_result)

    def test_forbidden_for_regular_user(self, user_client: APIClient) -> None:
        """A regular user receives 403 from the run endpoint."""
        resp = self._post_run(user_client, PaperlessTask.TaskType.TRAIN_CLASSIFIER)

        assert resp.status_code == status.HTTP_403_FORBIDDEN

    def test_dispatches_via_apply_async_with_manual_trigger_header(
        self,
        admin_client: APIClient,
    ) -> None:
        """TRAIN_CLASSIFIER is dispatched once with empty kwargs and the manual header."""
        expected_id, apply_async_stub = self._stub_apply_async()

        with mock.patch(
            "documents.views.train_classifier.apply_async",
            apply_async_stub,
        ):
            resp = self._post_run(
                admin_client,
                PaperlessTask.TaskType.TRAIN_CLASSIFIER,
            )

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"task_id": expected_id}
        apply_async_stub.assert_called_once_with(
            kwargs={},
            headers={"trigger_source": "manual"},
        )

    def test_returns_400_for_consume_file(self, admin_client: APIClient) -> None:
        """consume_file cannot be manually triggered via the run endpoint."""
        resp = self._post_run(admin_client, PaperlessTask.TaskType.CONSUME_FILE)

        assert resp.status_code == status.HTTP_400_BAD_REQUEST

    def test_returns_400_for_invalid_task_type(self, admin_client: APIClient) -> None:
        """An unknown task type string is rejected with 400."""
        resp = self._post_run(admin_client, "not_a_real_type")

        assert resp.status_code == status.HTTP_400_BAD_REQUEST

    def test_sanity_check_dispatched_with_correct_kwargs(
        self,
        admin_client: APIClient,
    ) -> None:
        """SANITY_CHECK is dispatched once with ``raise_on_error=False`` and the manual header."""
        expected_id, apply_async_stub = self._stub_apply_async()

        with mock.patch(
            "documents.views.sanity_check.apply_async",
            apply_async_stub,
        ):
            resp = self._post_run(
                admin_client,
                PaperlessTask.TaskType.SANITY_CHECK,
            )

        assert resp.status_code == status.HTTP_200_OK
        assert resp.data == {"task_id": expected_id}
        apply_async_stub.assert_called_once_with(
            kwargs={"raise_on_error": False},
            headers={"trigger_source": "manual"},
        )