diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index 4ff3f1c87..d12506e56 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -422,8 +422,8 @@ class TestGetTasksV9: assert len(response.data) == 1 assert response.data[0]["task_name"] == "consume_file" - def test_filter_by_type_maps_to_trigger_source(self, v9_client: APIClient) -> None: - """?type=scheduled_task filter maps to trigger_source=scheduled for v9 compatibility.""" + def test_filter_by_type_scheduled_task(self, v9_client: APIClient) -> None: + """?type=scheduled_task matches trigger_source=scheduled only.""" PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SCHEDULED) PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.WEB_UI) @@ -433,6 +433,42 @@ class TestGetTasksV9: assert len(response.data) == 1 assert response.data[0]["type"] == "scheduled_task" + def test_filter_by_type_auto_task_includes_all_auto_sources( + self, + v9_client: APIClient, + ) -> None: + """?type=auto_task matches system, email_consume, and folder_consume tasks.""" + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.SYSTEM) + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME) + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.FOLDER_CONSUME) + PaperlessTaskFactory( + trigger_source=PaperlessTask.TriggerSource.MANUAL, + ) # excluded + + response = v9_client.get(ENDPOINT, {"type": "auto_task"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 3 + assert all(t["type"] == "auto_task" for t in response.data) + + def test_filter_by_type_manual_task_includes_all_manual_sources( + self, + v9_client: APIClient, + ) -> None: + """?type=manual_task matches manual, web_ui, and api_upload tasks.""" + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.MANUAL) + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.WEB_UI) + PaperlessTaskFactory(trigger_source=PaperlessTask.TriggerSource.API_UPLOAD) + PaperlessTaskFactory( + trigger_source=PaperlessTask.TriggerSource.SCHEDULED, + ) # excluded + + response = v9_client.get(ENDPOINT, {"type": "manual_task"}) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 3 + assert all(t["type"] == "manual_task" for t in response.data) + @pytest.mark.django_db() class TestAcknowledge: diff --git a/src/documents/views.py b/src/documents/views.py index d56ced452..740eed792 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -3794,11 +3794,20 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): "llmindex_update": PaperlessTask.TaskType.LLM_INDEX, } - # v9 backwards compat: maps old "type" query param values to new TriggerSource - _V9_TYPE_TO_TRIGGER_SOURCE = { - "auto_task": PaperlessTask.TriggerSource.SYSTEM, - "scheduled_task": PaperlessTask.TriggerSource.SCHEDULED, - "manual_task": PaperlessTask.TriggerSource.MANUAL, + # v9 backwards compat: maps old "type" query param values to new TriggerSource. + # Must match the reverse of TaskSerializerV9._TRIGGER_SOURCE_TO_V9_TYPE. + _V9_TYPE_TO_TRIGGER_SOURCES = { + "auto_task": [ + PaperlessTask.TriggerSource.SYSTEM, + PaperlessTask.TriggerSource.EMAIL_CONSUME, + PaperlessTask.TriggerSource.FOLDER_CONSUME, + ], + "scheduled_task": [PaperlessTask.TriggerSource.SCHEDULED], + "manual_task": [ + PaperlessTask.TriggerSource.MANUAL, + PaperlessTask.TriggerSource.WEB_UI, + PaperlessTask.TriggerSource.API_UPLOAD, + ], } _RUNNABLE_TASKS = { @@ -3834,9 +3843,9 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): queryset = queryset.filter(task_type=mapped) task_type_old = self.request.query_params.get("type") if task_type_old is not None: - new_source = self._V9_TYPE_TO_TRIGGER_SOURCE.get(task_type_old) - if new_source: - queryset = queryset.filter(trigger_source=new_source) + sources = self._V9_TYPE_TO_TRIGGER_SOURCES.get(task_type_old) + if sources: + queryset = queryset.filter(trigger_source__in=sources) # v10+: direct task_id param for backwards compat task_id = self.request.query_params.get("task_id") if task_id is not None: