diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 04c332393..b1c2239d8 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1540,6 +1540,41 @@ class DocumentListSerializer(serializers.Serializer): return documents +class DocumentSelectionSerializer(DocumentListSerializer): + documents = serializers.ListField( + required=False, + label="Documents", + write_only=True, + child=serializers.IntegerField(), + ) + + all = serializers.BooleanField( + default=False, + required=False, + write_only=True, + ) + + filters = serializers.DictField( + required=False, + allow_empty=True, + write_only=True, + ) + + def validate(self, attrs): + if attrs.get("all", False): + attrs.setdefault("documents", []) + return attrs + + if "documents" not in attrs: + raise serializers.ValidationError( + "documents is required unless all is true.", + ) + + documents = attrs["documents"] + self._validate_document_id_list(documents) + return attrs + + class SourceModeValidationMixin: def validate_source_mode(self, source_mode: str) -> str: if source_mode not in bulk_edit.SourceModeChoices.__dict__.values(): @@ -1547,7 +1582,7 @@ class SourceModeValidationMixin: return source_mode -class RotateDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin): +class RotateDocumentsSerializer(DocumentSelectionSerializer, SourceModeValidationMixin): degrees = serializers.IntegerField(required=True) source_mode = serializers.CharField( required=False, @@ -1630,17 +1665,17 @@ class RemovePasswordDocumentsSerializer( ) -class DeleteDocumentsSerializer(DocumentListSerializer): +class DeleteDocumentsSerializer(DocumentSelectionSerializer): pass -class ReprocessDocumentsSerializer(DocumentListSerializer): +class ReprocessDocumentsSerializer(DocumentSelectionSerializer): pass class BulkEditSerializer( SerializerWithPerms, - DocumentListSerializer, + DocumentSelectionSerializer, SetPermissionsMixin, SourceModeValidationMixin, ): @@ -2212,7 +2247,7 @@ class DocumentVersionLabelSerializer(serializers.Serializer): return normalized or None -class BulkDownloadSerializer(DocumentListSerializer): +class BulkDownloadSerializer(DocumentSelectionSerializer): content = serializers.ChoiceField( choices=["archive", "originals", "both"], default="archive", diff --git a/src/documents/views.py b/src/documents/views.py index e26edf69c..bf15a022d 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -2215,7 +2215,36 @@ class SavedViewViewSet(BulkPermissionMixin, PassUserMixin, ModelViewSet): ordering_fields = ("name",) -class DocumentOperationPermissionMixin(PassUserMixin): +class DocumentSelectionMixin: + def _resolve_document_ids( + self, + *, + user: User, + validated_data: dict[str, Any], + permission_codename: str = "view_document", + ) -> list[int]: + if not validated_data.get("all", False): + # if all is not true, just pass through the provided document ids + return validated_data["documents"] + + # otherwise, reconstruct the document list based on the provided filters + filters = validated_data.get("filters") or {} + permitted_documents = get_objects_for_user_owner_aware( + user, + permission_codename, + Document, + ) + return list( + DocumentFilterSet( + data=filters, + queryset=permitted_documents, + ) + .qs.distinct() + .values_list("pk", flat=True), + ) + + +class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin): permission_classes = (IsAuthenticated,) parser_classes = (parsers.JSONParser,) METHOD_NAMES_REQUIRING_USER = { @@ -2309,8 +2338,15 @@ class DocumentOperationPermissionMixin(PassUserMixin): validated_data: dict[str, Any], operation_label: str, ): - documents = validated_data["documents"] - parameters = {k: v for k, v in validated_data.items() if k != "documents"} + documents = self._resolve_document_ids( + user=self.request.user, + validated_data=validated_data, + ) + parameters = { + k: v + for k, v in validated_data.items() + if k not in {"documents", "all", "filters"} + } user = self.request.user if method.__name__ in self.METHOD_NAMES_REQUIRING_USER: @@ -2398,7 +2434,10 @@ class BulkEditView(DocumentOperationPermissionMixin): user = self.request.user method = serializer.validated_data.get("method") parameters = serializer.validated_data.get("parameters") - documents = serializer.validated_data.get("documents") + documents = self._resolve_document_ids( + user=user, + validated_data=serializer.validated_data, + ) if method.__name__ in self.METHOD_NAMES_REQUIRING_USER: parameters["user"] = user if not self._has_document_permissions( @@ -3242,7 +3281,7 @@ class StatisticsView(GenericAPIView): ) -class BulkDownloadView(GenericAPIView): +class BulkDownloadView(DocumentSelectionMixin, GenericAPIView): permission_classes = (IsAuthenticated,) serializer_class = BulkDownloadSerializer parser_classes = (parsers.JSONParser,) @@ -3251,7 +3290,10 @@ class BulkDownloadView(GenericAPIView): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - ids = serializer.validated_data.get("documents") + ids = self._resolve_document_ids( + user=request.user, + validated_data=serializer.validated_data, + ) documents = Document.objects.filter(pk__in=ids) compression = serializer.validated_data.get("compression") content = serializer.validated_data.get("content")