diff --git a/src/documents/tests/test_api_bulk_download.py b/src/documents/tests/test_api_bulk_download.py index 865f57247..eae03a3ed 100644 --- a/src/documents/tests/test_api_bulk_download.py +++ b/src/documents/tests/test_api_bulk_download.py @@ -15,6 +15,7 @@ from documents.models import Document from documents.models import DocumentType from documents.tests.utils import DirectoriesMixin from documents.tests.utils import SampleDirMixin +from documents.tests.utils import read_streaming_response class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): @@ -68,7 +69,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 2) self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) self.assertIn("2020-03-21 document B.jpg", zipf.namelist()) @@ -89,7 +90,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 2) self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) self.assertIn("2020-03-21 document B.pdf", zipf.namelist()) @@ -110,7 +111,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 3) self.assertIn("originals/2021-01-01 document A.pdf", zipf.namelist()) self.assertIn("archive/2020-03-21 document B.pdf", zipf.namelist()) @@ -144,7 +145,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 2) self.assertIn("2021-01-01 document A.pdf", zipf.namelist()) @@ -157,13 +158,14 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(f.read(), zipf.read("2021-01-01 document A_01.pdf")) def test_compression(self) -> None: - self.client.post( + response = self.client.post( self.ENDPOINT, json.dumps( {"documents": [self.doc2.id, self.doc2b.id], "compression": "lzma"}, ), content_type="application/json", ) + response.close() @override_settings(FILENAME_FORMAT="{correspondent}/{title}") def test_formatted_download_originals(self) -> None: @@ -203,7 +205,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 2) self.assertIn("a space name/Title 2 - Doc 3.jpg", zipf.namelist()) self.assertIn("test/This is Doc 2.pdf", zipf.namelist()) @@ -249,7 +251,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 2) self.assertIn("somewhere/This is Doc 2.pdf", zipf.namelist()) self.assertIn("somewhere/Title 2 - Doc 3.pdf", zipf.namelist()) @@ -298,7 +300,7 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response["Content-Type"], "application/zip") - with zipfile.ZipFile(io.BytesIO(response.content)) as zipf: + with zipfile.ZipFile(io.BytesIO(read_streaming_response(response))) as zipf: self.assertEqual(len(zipf.filelist), 3) self.assertIn("originals/bill/This is Doc 2.pdf", zipf.namelist()) self.assertIn("archive/statement/Title 2 - Doc 3.pdf", zipf.namelist()) diff --git a/src/documents/tests/test_api_document_versions.py b/src/documents/tests/test_api_document_versions.py index bde13354a..848c6ec21 100644 --- a/src/documents/tests/test_api_document_versions.py +++ b/src/documents/tests/test_api_document_versions.py @@ -18,6 +18,7 @@ from documents.filters import EffectiveContentFilter from documents.filters import TitleContentFilter from documents.models import Document from documents.tests.utils import DirectoriesMixin +from documents.tests.utils import read_streaming_response if TYPE_CHECKING: from pathlib import Path @@ -449,19 +450,19 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): f"/api/documents/{root.id}/download/?version={version.id}", ) self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp.content, b"version") + self.assertEqual(read_streaming_response(resp), b"version") resp = self.client.get( f"/api/documents/{root.id}/preview/?version={version.id}", ) self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp.content, b"version") + self.assertEqual(read_streaming_response(resp), b"version") resp = self.client.get( f"/api/documents/{root.id}/thumb/?version={version.id}", ) self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp.content, b"thumb") + self.assertEqual(read_streaming_response(resp), b"thumb") def test_metadata_version_param_uses_version(self) -> None: root = Document.objects.create( diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 24165c087..844dc2af5 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -49,6 +49,7 @@ from documents.models import WorkflowTrigger from documents.signals.handlers import run_workflows from documents.tests.utils import ConsumeTaskMixin from documents.tests.utils import DirectoriesMixin +from documents.tests.utils import read_streaming_response class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): @@ -323,19 +324,16 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): f.write(content_thumbnail) response = self.client.get(f"/api/documents/{doc.pk}/download/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content) + self.assertEqual(read_streaming_response(response), content) response = self.client.get(f"/api/documents/{doc.pk}/preview/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content) + self.assertEqual(read_streaming_response(response), content) response = self.client.get(f"/api/documents/{doc.pk}/thumb/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content_thumbnail) + self.assertEqual(read_streaming_response(response), content_thumbnail) def test_document_actions_with_perms(self) -> None: """ @@ -386,12 +384,15 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): response = self.client.get(f"/api/documents/{doc.pk}/download/") self.assertEqual(response.status_code, status.HTTP_200_OK) + response.close() response = self.client.get(f"/api/documents/{doc.pk}/preview/") self.assertEqual(response.status_code, status.HTTP_200_OK) + response.close() response = self.client.get(f"/api/documents/{doc.pk}/thumb/") self.assertEqual(response.status_code, status.HTTP_200_OK) + response.close() @override_settings(FILENAME_FORMAT="") def test_download_with_archive(self) -> None: @@ -412,28 +413,24 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): f.write(content_archive) response = self.client.get(f"/api/documents/{doc.pk}/download/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content_archive) + self.assertEqual(read_streaming_response(response), content_archive) response = self.client.get( f"/api/documents/{doc.pk}/download/?original=true", ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content) + self.assertEqual(read_streaming_response(response), content) response = self.client.get(f"/api/documents/{doc.pk}/preview/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content_archive) + self.assertEqual(read_streaming_response(response), content_archive) response = self.client.get( f"/api/documents/{doc.pk}/preview/?original=true", ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content) + self.assertEqual(read_streaming_response(response), content) @override_settings(FILENAME_FORMAT="") def test_download_follow_formatting(self) -> None: @@ -456,18 +453,21 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): # Without follow_formatting, should use public filename response = self.client.get(f"/api/documents/{doc.pk}/download/") self.assertIn("none.pdf", response["Content-Disposition"]) + response.close() # With follow_formatting, should use actual filename on disk response = self.client.get( f"/api/documents/{doc.pk}/download/?follow_formatting=true", ) self.assertIn("archived.pdf", response["Content-Disposition"]) + response.close() # With follow_formatting and original, should use source filename response = self.client.get( f"/api/documents/{doc.pk}/download/?original=true&follow_formatting=true", ) self.assertIn("my_document.pdf", response["Content-Disposition"]) + response.close() def test_document_actions_not_existing_file(self) -> None: doc = Document.objects.create( diff --git a/src/documents/tests/test_api_schema.py b/src/documents/tests/test_api_schema.py index 876722be0..909d1829f 100644 --- a/src/documents/tests/test_api_schema.py +++ b/src/documents/tests/test_api_schema.py @@ -253,3 +253,47 @@ class TestShareLinkBundleRebuildSchema: else: props = resp_400.get("properties", {}) assert "detail" in props, "rebuild 400 response must have a 'detail' field" + + +class TestBulkDownloadSchema: + """bulk_download_create: POST accepts BulkDownloadSerializer, returns application/zip, documents 403.""" + + def test_bulk_download_path_exists(self, api_schema: SchemaGenerator) -> None: + assert "/api/documents/bulk_download/" in api_schema["paths"] + + def test_bulk_download_operation_id(self, api_schema: SchemaGenerator) -> None: + op = api_schema["paths"]["/api/documents/bulk_download/"]["post"] + assert op["operationId"] == "bulk_download" + + def test_bulk_download_request_body_is_json( + self, + api_schema: SchemaGenerator, + ) -> None: + op = api_schema["paths"]["/api/documents/bulk_download/"]["post"] + assert "requestBody" in op + assert "application/json" in op["requestBody"]["content"] + + def test_bulk_download_request_references_serializer( + self, + api_schema: SchemaGenerator, + ) -> None: + op = api_schema["paths"]["/api/documents/bulk_download/"]["post"] + schema_ref = ( + op["requestBody"]["content"]["application/json"] + .get("schema", {}) + .get("$ref", "") + ) + component_name = schema_ref.split("/")[-1] + assert component_name == "BulkDownloadRequest" + + def test_bulk_download_response_200_is_zip( + self, + api_schema: SchemaGenerator, + ) -> None: + op = api_schema["paths"]["/api/documents/bulk_download/"]["post"] + assert "200" in op["responses"] + assert "application/zip" in op["responses"]["200"]["content"] + + def test_bulk_download_response_403(self, api_schema: SchemaGenerator) -> None: + op = api_schema["paths"]["/api/documents/bulk_download/"]["post"] + assert "403" in op["responses"] diff --git a/src/documents/tests/test_views.py b/src/documents/tests/test_views.py index 314636045..9b1724a16 100644 --- a/src/documents/tests/test_views.py +++ b/src/documents/tests/test_views.py @@ -27,6 +27,7 @@ from documents.models import StoragePath from documents.models import Tag from documents.signals.handlers import update_llm_suggestions_cache from documents.tests.utils import DirectoriesMixin +from documents.tests.utils import read_streaming_response from paperless.models import ApplicationConfiguration @@ -157,7 +158,7 @@ class TestViews(DirectoriesMixin, TestCase): # Valid response = self.client.get(f"/share/{sl1.slug}") self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.content, content) + self.assertEqual(read_streaming_response(response), content) # Invalid response = self.client.get("/share/123notaslug", follow=True) diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index 530f588e8..38e7c927a 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -17,6 +17,7 @@ import pytest from django.apps import apps from django.db import connection from django.db.migrations.executor import MigrationExecutor +from django.http import StreamingHttpResponse from django.test import TransactionTestCase from django.test import override_settings @@ -150,6 +151,13 @@ def util_call_with_backoff( return succeeded, result +def read_streaming_response(response: StreamingHttpResponse) -> bytes: + """Consume a StreamingHttpResponse/FileResponse and close it.""" + content = b"".join(response.streaming_content) + response.close() + return content + + class DirectoriesMixin: """ Creates and overrides settings for all folders and paths, then ensures diff --git a/src/documents/views.py b/src/documents/views.py index cbf9259b3..d4c2ccd11 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -281,8 +281,7 @@ class IndexView(TemplateView): first = lang[: lang.index("-")] second = lang[lang.index("-") + 1 :] return f"{first}-{second.upper()}" - else: - return lang + return lang def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) @@ -427,9 +426,7 @@ class BulkPermissionMixin: class PermissionsAwareDocumentCountMixin(BulkPermissionMixin, PassUserMixin): - """ - Mixin to add document count to queryset, permissions-aware if needed - """ + """Mixin to add document count to queryset, permissions-aware if needed""" # Default is simple relation path, override for through-table/count specialization. document_count_through: type[Model] | None = None @@ -1231,8 +1228,7 @@ class DocumentViewSet( def get_filesize(self, filename): if Path(filename).is_file(): return Path(filename).stat().st_size - else: - return None + return None @action(methods=["get"], detail=True, filter_backends=[]) @method_decorator(cache_control(no_cache=True)) @@ -1449,7 +1445,7 @@ class DocumentViewSet( file_doc = self._get_effective_file_doc(request_doc, root_doc, request) handle = file_doc.thumbnail_file - return HttpResponse(handle, content_type="image/webp") + return FileResponse(handle, content_type="image/webp") except FileNotFoundError: raise Http404 @@ -2107,8 +2103,7 @@ class UnifiedSearchViewSet(DocumentViewSet): def get_serializer_class(self): if self._is_search_request(): return SearchResultSerializer - else: - return DocumentSerializer + return DocumentSerializer def _get_active_search_params(self, request: Request | None = None) -> list[str]: request = request or self.request @@ -3226,7 +3221,7 @@ class GlobalSearchView(PassUserMixin): query = request.query_params.get("query", None) if query is None: return HttpResponseBadRequest("Query required") - elif len(query) < 3: + if len(query) < 3: return HttpResponseBadRequest("Query must be at least 3 characters") db_only = request.query_params.get("db_only", False) @@ -3521,7 +3516,7 @@ class StatisticsView(GenericAPIView[Any]): "inbox_tag": ( inbox_tag_pks[0] if inbox_tag_pks else None ), # backwards compatibility - "inbox_tags": (inbox_tag_pks if inbox_tag_pks else None), + "inbox_tags": (inbox_tag_pks or None), "document_file_type_counts": document_file_type_counts, "character_count": character_count, "tag_count": len(tags), @@ -3533,6 +3528,16 @@ class StatisticsView(GenericAPIView[Any]): ) +@extend_schema_view( + post=extend_schema( + operation_id="bulk_download", + description="Download multiple documents as a ZIP archive.", + responses={ + (HTTPStatus.OK, "application/zip"): OpenApiTypes.BINARY, + HTTPStatus.FORBIDDEN: None, + }, + ), +) class BulkDownloadView(DocumentSelectionMixin, GenericAPIView[Any]): permission_classes = (IsAuthenticated,) serializer_class = BulkDownloadSerializer @@ -3555,13 +3560,6 @@ class BulkDownloadView(DocumentSelectionMixin, GenericAPIView[Any]): if not has_perms_owner_aware(request.user, "change_document", document): return HttpResponseForbidden("Insufficient permissions") - settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True) - temp = tempfile.NamedTemporaryFile( # noqa: SIM115 - dir=settings.SCRATCH_DIR, - suffix="-compressed-archive", - delete=False, - ) - if content == "both": strategy_class = OriginalAndArchiveStrategy elif content == "originals": @@ -3569,20 +3567,35 @@ class BulkDownloadView(DocumentSelectionMixin, GenericAPIView[Any]): else: strategy_class = ArchiveOnlyStrategy - with zipfile.ZipFile(temp.name, "w", compression) as zipf: - strategy = strategy_class(zipf, follow_formatting=follow_filename_format) - for document in documents: - strategy.add_document(document) + settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True) + fd, temp_name = tempfile.mkstemp( + dir=settings.SCRATCH_DIR, + suffix="-compressed-archive", + ) + os.close(fd) + temp_path = Path(temp_name) - # TODO(stumpylog): Investigate using FileResponse here - with Path(temp.name).open("rb") as f: - response = HttpResponse(f, content_type="application/zip") - response["Content-Disposition"] = '{}; filename="{}"'.format( - "attachment", - "documents.zip", - ) + try: + with zipfile.ZipFile(temp_path, "w", compression) as zipf: + strategy = strategy_class( + zipf, + follow_formatting=follow_filename_format, + ) + for document in documents: + strategy.add_document(document) - return response + f = temp_path.open("rb") + temp_path.unlink() + except Exception: + temp_path.unlink(missing_ok=True) + raise + + return FileResponse( + f, + as_attachment=True, + filename="documents.zip", + content_type="application/zip", + ) @extend_schema_view( @@ -4290,7 +4303,7 @@ def serve_file( use_archive: bool, disposition: str, follow_formatting: bool = False, -) -> HttpResponse: +) -> FileResponse: if use_archive: if TYPE_CHECKING: assert doc.archive_filename @@ -4313,7 +4326,7 @@ def serve_file( if mime_type in {"application/csv", "text/csv"} and disposition == "inline": mime_type = "text/plain" - response = HttpResponse(file_handle, content_type=mime_type) + response = FileResponse(file_handle, content_type=mime_type) # Firefox is not able to handle unicode characters in filename field # RFC 5987 addresses this issue # see https://datatracker.ietf.org/doc/html/rfc5987#section-4.2