mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-07-02 02:04:19 +00:00
Performance: support passing selection data with filtered document requests (#12300)
This commit is contained in:
@@ -1144,6 +1144,56 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
|
||||
self.assertEqual(len(response.data["all"]), 50)
|
||||
self.assertCountEqual(response.data["all"], [d.id for d in docs])
|
||||
|
||||
def test_list_with_include_selection_data(self) -> None:
    """
    Filtered document listing with ``include_selection_data=true``.

    Two documents are created: one carrying a correspondent, document
    type, storage path and tag, and one carrying only a different tag.
    The list is filtered by the first tag, so the selection data must
    count the first document only: every object attached to it reports
    1, and the tag used solely by the filtered-out document reports 0.
    """
    correspondent = Correspondent.objects.create(name="c1")
    doc_type = DocumentType.objects.create(name="dt1")
    storage_path = StoragePath.objects.create(name="sp1")
    tag = Tag.objects.create(name="tag")

    matching_doc = Document.objects.create(
        checksum="A",
        correspondent=correspondent,
        document_type=doc_type,
        storage_path=storage_path,
    )
    matching_doc.tags.add(tag)

    # Keep a handle on the second tag so its count can be asserted below;
    # without that the test would pass even if the filter were ignored.
    other_tag = Tag.objects.create(name="other")
    non_matching_doc = Document.objects.create(checksum="B")
    non_matching_doc.tags.add(other_tag)

    response = self.client.get(
        f"/api/documents/?tags__id__in={tag.id}&include_selection_data=true",
    )

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertIn("selection_data", response.data)

    def document_counts(key: str) -> dict:
        # Map object id -> document_count for one selection_data list.
        # Using .get() on this map below turns a missing entry into a
        # regular assertion failure instead of a StopIteration error.
        return {
            item["id"]: item["document_count"]
            for item in response.data["selection_data"][key]
        }

    tag_counts = document_counts("selected_tags")

    self.assertEqual(
        document_counts("selected_correspondents").get(correspondent.id),
        1,
    )
    self.assertEqual(tag_counts.get(tag.id), 1)
    self.assertEqual(
        document_counts("selected_document_types").get(doc_type.id),
        1,
    )
    self.assertEqual(
        document_counts("selected_storage_paths").get(storage_path.id),
        1,
    )
    # The document carrying this tag was filtered out of the listing.
    self.assertEqual(tag_counts.get(other_tag.id), 0)
|
||||
|
||||
def test_statistics(self) -> None:
|
||||
doc1 = Document.objects.create(
|
||||
title="none1",
|
||||
|
||||
@@ -89,6 +89,46 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(len(results), 0)
|
||||
self.assertCountEqual(response.data["all"], [])
|
||||
|
||||
def test_search_with_include_selection_data(self) -> None:
    """
    Full-text search with ``include_selection_data=true``.

    A single indexed document carrying a correspondent, document type,
    storage path and tag is searched for; the response must contain
    selection data in which each of those objects reports a document
    count of 1.
    """
    correspondent = Correspondent.objects.create(name="c1")
    doc_type = DocumentType.objects.create(name="dt1")
    storage_path = StoragePath.objects.create(name="sp1")
    tag = Tag.objects.create(name="tag")

    matching_doc = Document.objects.create(
        title="bank statement",
        content="bank content",
        checksum="A",
        correspondent=correspondent,
        document_type=doc_type,
        storage_path=storage_path,
    )
    matching_doc.tags.add(tag)

    # The document has to be written to the search index to be found.
    with AsyncWriter(index.open_index()) as writer:
        index.update_document(writer, matching_doc)

    response = self.client.get(
        "/api/documents/?query=bank&include_selection_data=true",
    )

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    self.assertIn("selection_data", response.data)

    def document_counts(key: str) -> dict:
        # Map object id -> document_count for one selection_data list.
        # Using .get() on this map below turns a missing entry into a
        # regular assertion failure instead of a StopIteration error.
        return {
            item["id"]: item["document_count"]
            for item in response.data["selection_data"][key]
        }

    self.assertEqual(
        document_counts("selected_correspondents").get(correspondent.id),
        1,
    )
    self.assertEqual(document_counts("selected_tags").get(tag.id), 1)
    # The fixture also sets a document type and storage path; check them
    # too so this test covers the same keys as the plain list test.
    self.assertEqual(
        document_counts("selected_document_types").get(doc_type.id),
        1,
    )
    self.assertEqual(
        document_counts("selected_storage_paths").get(storage_path.id),
        1,
    )
|
||||
|
||||
def test_search_custom_field_ordering(self) -> None:
|
||||
custom_field = CustomField.objects.create(
|
||||
name="Sortable field",
|
||||
|
||||
@@ -836,6 +836,61 @@ class DocumentViewSet(
|
||||
"custom_field_",
|
||||
)
|
||||
|
||||
def _get_selection_data_for_queryset(self, queryset):
    """
    Build the ``selection_data`` payload for a set of documents.

    For every correspondent, tag, document type, storage path and custom
    field, count how many documents of ``queryset`` reference it.

    Args:
        queryset: Document queryset the counts are restricted to.

    Returns:
        A dict with the keys ``selected_correspondents``,
        ``selected_tags``, ``selected_document_types``,
        ``selected_storage_paths`` and ``selected_custom_fields``; each
        value is a list of ``{"id": ..., "document_count": ...}`` dicts
        with one entry per object of the respective model.
    """
    # (response key, model, lookup path from that model to its documents)
    sources = (
        ("selected_correspondents", Correspondent, "documents"),
        ("selected_tags", Tag, "documents"),
        ("selected_document_types", DocumentType, "documents"),
        ("selected_storage_paths", StoragePath, "documents"),
        # Custom fields reach documents through their field instances.
        ("selected_custom_fields", CustomField, "fields__document"),
    )
    return {
        key: self._count_documents_per_object(model, relation, queryset)
        for key, model, relation in sources
    }

def _count_documents_per_object(self, model, relation, queryset):
    """
    Return ``[{"id", "document_count"}, ...]`` for all objects of ``model``.

    ``relation`` is the lookup path from ``model`` to its documents; only
    documents contained in ``queryset`` are counted, and each document is
    counted once per object (``distinct=True``).
    """
    annotated = model.objects.annotate(
        document_count=Count(
            relation,
            filter=Q(**{f"{relation}__in": queryset}),
            distinct=True,
        ),
    )
    return [
        {"id": obj.id, "document_count": obj.document_count} for obj in annotated
    ]
|
||||
|
||||
def get_queryset(self):
|
||||
latest_version_content = Subquery(
|
||||
Document.objects.filter(root_document=OuterRef("pk"))
|
||||
@@ -983,6 +1038,25 @@ class DocumentViewSet(
|
||||
|
||||
return response
|
||||
|
||||
def list(self, request, *args, **kwargs):
    """
    List documents, optionally attaching selection data to the response.

    Without the ``include_selection_data`` query parameter (or when
    ``get_boolean`` evaluates it as false) this defers entirely to the
    parent ``list``. Otherwise the filtered queryset is used both for the
    regular listing and for ``_get_selection_data_for_queryset``, and the
    result of the latter is returned under the ``selection_data`` key.
    """
    raw_flag = request.query_params.get("include_selection_data", "false")
    wants_selection_data = get_boolean(str(raw_flag))
    if not wants_selection_data:
        return super().list(request, *args, **kwargs)

    documents = self.filter_queryset(self.get_queryset())
    # Counts cover the whole filtered set, not just the requested page.
    selection_data = self._get_selection_data_for_queryset(documents)

    page = self.paginate_queryset(documents)
    if page is None:
        # No pagination in effect: serialize the full filtered queryset.
        unpaginated = self.get_serializer(documents, many=True)
        return Response(
            {"results": unpaginated.data, "selection_data": selection_data},
        )

    page_serializer = self.get_serializer(page, many=True)
    paginated = self.get_paginated_response(page_serializer.data)
    paginated.data["selection_data"] = selection_data
    return paginated
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
from documents import index
|
||||
|
||||
@@ -2023,6 +2097,21 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
||||
else None
|
||||
)
|
||||
|
||||
if get_boolean(
|
||||
str(
|
||||
request.query_params.get(
|
||||
"include_selection_data",
|
||||
"false",
|
||||
),
|
||||
),
|
||||
):
|
||||
result_ids = response.data.get("all", [])
|
||||
response.data["selection_data"] = (
|
||||
self._get_selection_data_for_queryset(
|
||||
Document.objects.filter(pk__in=result_ids),
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except NotFound:
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user