From 57565fe406a67c25c8c522df6dd493dcc2de9be2 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 10 Mar 2026 18:57:58 -0700 Subject: [PATCH] Tests for include_selection_data --- src/documents/tests/test_api_documents.py | 50 +++++++++++++++++++++++ src/documents/tests/test_api_search.py | 40 ++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 2dda91e98..538fc6dd3 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -1144,6 +1144,56 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(len(response.data["all"]), 50) self.assertCountEqual(response.data["all"], [d.id for d in docs]) + def test_list_with_include_selection_data(self) -> None: + correspondent = Correspondent.objects.create(name="c1") + doc_type = DocumentType.objects.create(name="dt1") + storage_path = StoragePath.objects.create(name="sp1") + tag = Tag.objects.create(name="tag") + + matching_doc = Document.objects.create( + checksum="A", + correspondent=correspondent, + document_type=doc_type, + storage_path=storage_path, + ) + matching_doc.tags.add(tag) + + non_matching_doc = Document.objects.create(checksum="B") + non_matching_doc.tags.add(Tag.objects.create(name="other")) + + response = self.client.get( + f"/api/documents/?tags__id__in={tag.id}&include_selection_data=true", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("selection_data", response.data) + + selected_correspondent = next( + item + for item in response.data["selection_data"]["selected_correspondents"] + if item["id"] == correspondent.id + ) + selected_tag = next( + item + for item in response.data["selection_data"]["selected_tags"] + if item["id"] == tag.id + ) + selected_type = next( + item + for item in response.data["selection_data"]["selected_document_types"] + if item["id"] == doc_type.id + ) + selected_storage_path = next( + item + for item in response.data["selection_data"]["selected_storage_paths"] + if item["id"] == storage_path.id + ) + + self.assertEqual(selected_correspondent["document_count"], 1) + self.assertEqual(selected_tag["document_count"], 1) + self.assertEqual(selected_type["document_count"], 1) + self.assertEqual(selected_storage_path["document_count"], 1) + def test_statistics(self) -> None: doc1 = Document.objects.create( title="none1", diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index 6c2ad1eb8..bd70e60c7 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -89,6 +89,46 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): self.assertEqual(len(results), 0) self.assertCountEqual(response.data["all"], []) + def test_search_with_include_selection_data(self) -> None: + correspondent = Correspondent.objects.create(name="c1") + doc_type = DocumentType.objects.create(name="dt1") + storage_path = StoragePath.objects.create(name="sp1") + tag = Tag.objects.create(name="tag") + + matching_doc = Document.objects.create( + title="bank statement", + content="bank content", + checksum="A", + correspondent=correspondent, + document_type=doc_type, + storage_path=storage_path, + ) + matching_doc.tags.add(tag) + + with AsyncWriter(index.open_index()) as writer: + index.update_document(writer, matching_doc) + + response = self.client.get( + "/api/documents/?query=bank&include_selection_data=true", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("selection_data", response.data) + + selected_correspondent = next( + item + for item in response.data["selection_data"]["selected_correspondents"] + if item["id"] == correspondent.id + ) + selected_tag = next( + item + for item in response.data["selection_data"]["selected_tags"] + if item["id"] == tag.id + ) + + self.assertEqual(selected_correspondent["document_count"], 1) + self.assertEqual(selected_tag["document_count"], 1) + def test_search_custom_field_ordering(self) -> None: custom_field = CustomField.objects.create( name="Sortable field",