diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 1ff3798db..7726141be 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -2632,11 +2632,17 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]): class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]): tasks = serializers.ListField( - required=True, + required=False, label="Tasks", write_only=True, child=serializers.IntegerField(), ) + all = serializers.BooleanField( + required=False, + default=False, + label="All", + write_only=True, + ) def _validate_task_id_list(self, tasks, name="tasks") -> None: if not isinstance(tasks, list): @@ -2653,6 +2659,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]): self._validate_task_id_list(tasks) return tasks + def validate(self, attrs): + acknowledge_all = attrs.get("all", False) + task_ids = attrs.get("tasks") + + if acknowledge_all and task_ids is not None: + raise serializers.ValidationError( + "Set either all or tasks, not both.", + ) + if not acknowledge_all and task_ids is None: + raise serializers.ValidationError( + "Either all must be true or tasks must be provided.", + ) + + return attrs + class ShareLinkSerializer(OwnedObjectSerializer): class Meta: diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index f9b6c4538..42ccbab5c 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -522,6 +522,27 @@ class TestAcknowledge: assert response.status_code == status.HTTP_200_OK assert response.data == {"result": 2} + def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None: + """POST acknowledge/ with all=true acknowledges all unacknowledged tasks.""" + unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False) + unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False) + acknowledged_task = PaperlessTaskFactory(acknowledged=True) + + response = admin_client.post( + ENDPOINT + "acknowledge/", + {"all": True}, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 2} + unacknowledged_task1.refresh_from_db() + unacknowledged_task2.refresh_from_db() + acknowledged_task.refresh_from_db() + assert unacknowledged_task1.acknowledged + assert unacknowledged_task2.acknowledged + assert acknowledged_task.acknowledged + def test_acknowledged_tasks_excluded_from_unacked_filter( self, admin_client: APIClient, diff --git a/src/documents/views.py b/src/documents/views.py index 4fc0b3f51..f2e611ae6 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -4033,7 +4033,7 @@ class _TasksViewSetSchema(AutoSchema): ), acknowledge=extend_schema( operation_id="acknowledge_tasks", - description="Acknowledge a list of tasks", + description="Acknowledge a list of tasks, or all visible unacknowledged tasks", request=AcknowledgeTasksViewSerializer, responses={ (200, "application/json"): inline_serializer( @@ -4172,8 +4172,11 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): def acknowledge(self, request): serializer = AcknowledgeTasksViewSerializer(data=request.data) serializer.is_valid(raise_exception=True) - task_ids = serializer.validated_data.get("tasks") - tasks = self.get_queryset().filter(id__in=task_ids) + if serializer.validated_data.get("all", False): + tasks = self.get_queryset().filter(acknowledged=False) + else: + task_ids = serializer.validated_data.get("tasks") + tasks = self.get_queryset().filter(id__in=task_ids) count = tasks.update(acknowledged=True) return Response({"result": count})