From d998d3fbaf3aebdacb3b3eadd60c16435e4eff18 Mon Sep 17 00:00:00 2001
From: Trenton Holmes <797416+stumpylog@users.noreply.github.com>
Date: Fri, 3 Apr 2026 15:35:14 -0700
Subject: [PATCH] feat: delegate sorting to Tantivy and use page-only highlights in viewset

Co-Authored-By: Claude Opus 4.6
---
Review notes (text in this section is ignored by git am):
- Hunk headers and the diffstat below were recomputed for the edited hunks;
  the post-image blob ids on the "index" lines predate those edits.

 src/documents/tests/test_api_search.py | 81 ++++++++++++++++++++
 src/documents/views.py                 | 98 +++++++++++++++++-------
 2 files changed, 151 insertions(+), 28 deletions(-)

diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py
index 9e0879e89..54b960719 100644
--- a/src/documents/tests/test_api_search.py
+++ b/src/documents/tests/test_api_search.py
@@ -1503,6 +1503,87 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
             [d2.id, d1.id, d3.id],
         )
 
+    def test_search_with_tantivy_native_sort(self) -> None:
+        """When ordering by a Tantivy-sortable field, results must be correctly sorted."""
+        backend = get_backend()
+        for i, asn in enumerate([30, 10, 20]):
+            doc = Document.objects.create(
+                title=f"sortable doc {i}",
+                content="searchable content",
+                checksum=f"TNS{i}",
+                archive_serial_number=asn,
+            )
+            backend.add_or_update(doc)
+
+        response = self.client.get(
+            "/api/documents/?query=searchable&ordering=archive_serial_number",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        asns = [doc["archive_serial_number"] for doc in response.data["results"]]
+        self.assertEqual(asns, [10, 20, 30])
+
+        response = self.client.get(
+            "/api/documents/?query=searchable&ordering=-archive_serial_number",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        asns = [doc["archive_serial_number"] for doc in response.data["results"]]
+        self.assertEqual(asns, [30, 20, 10])
+
+    def test_search_page_2_returns_correct_slice(self) -> None:
+        """Page 2 must return the second slice, not overlap with page 1."""
+        backend = get_backend()
+        for i in range(10):
+            doc = Document.objects.create(
+                title=f"doc {i}",
+                content="paginated content",
+                checksum=f"PG2{i}",
+                archive_serial_number=i + 1,
+            )
+            backend.add_or_update(doc)
+
+        response = self.client.get(
+            "/api/documents/?query=paginated&ordering=archive_serial_number&page=1&page_size=3",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        page1_ids = [r["id"] for r in response.data["results"]]
+        page1_asns = [r["archive_serial_number"] for r in response.data["results"]]
+        self.assertEqual(len(page1_ids), 3)
+
+        response = self.client.get(
+            "/api/documents/?query=paginated&ordering=archive_serial_number&page=2&page_size=3",
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        page2_ids = [r["id"] for r in response.data["results"]]
+        page2_asns = [r["archive_serial_number"] for r in response.data["results"]]
+        self.assertEqual(len(page2_ids), 3)
+
+        # No overlap between pages
+        self.assertEqual(set(page1_ids) & set(page2_ids), set())
+        # Page 2 ASNs are higher than page 1
+        self.assertLess(max(page1_asns), min(page2_asns))
+
+    def test_search_all_field_contains_all_ids_when_paginated(self) -> None:
+        """The 'all' field must contain every matching ID, even when paginated."""
+        backend = get_backend()
+        doc_ids = []
+        for i in range(10):
+            doc = Document.objects.create(
+                title=f"all field doc {i}",
+                content="allfield content",
+                checksum=f"AF{i}",
+            )
+            backend.add_or_update(doc)
+            doc_ids.append(doc.pk)
+
+        response = self.client.get(
+            "/api/documents/?query=allfield&page=1&page_size=3",
+            headers={"Accept": "application/json; version=9"},
+        )
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(len(response.data["results"]), 3)
+        # "all" must contain ALL 10 matching IDs
+        self.assertCountEqual(response.data["all"], doc_ids)
+
     @mock.patch("documents.bulk_edit.bulk_update_documents")
     def test_global_search(self, m) -> None:
         """
diff --git a/src/documents/views.py b/src/documents/views.py
index 68d2b7961..6734155d5 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -2064,7 +2064,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
 
         try:
             backend = get_backend()
-            # ORM-filtered queryset: permissions + field filters + ordering (DRF backends applied)
             filtered_qs = self.filter_queryset(self.get_queryset())
 
             user = None if request.user.is_superuser else request.user
@@ -2079,6 +2078,37 @@ class UnifiedSearchViewSet(DocumentViewSet):
                     },
                 )
 
+            # Parse ordering param
+            ordering_param = request.query_params.get("ordering", "")
+            sort_reverse = ordering_param.startswith("-")
+            sort_field_name = ordering_param.lstrip("-") if ordering_param else None
+
+            # Fields Tantivy can sort natively — only numeric/date fast fields.
+            # Text-based sorts (title, correspondent__name, document_type__name)
+            # use a tokenized fast field whose ordering may differ from the ORM,
+            # so they fall back to the ORM sort path.
+            tantivy_sortable = {
+                "created",
+                "added",
+                "modified",
+                "archive_serial_number",
+                "page_count",
+                "num_notes",
+            }
+            use_tantivy_sort = (
+                sort_field_name in tantivy_sortable or sort_field_name is None
+            )
+
+            # Compute the DRF page so we can tell Tantivy which slice to highlight.
+            # A non-numeric or non-positive page falls back to 1. The page size is
+            # taken from the paginator so it matches what paginate_queryset() uses
+            # (default size, max_page_size cut-off, rejection of invalid values).
+            try:
+                requested_page = max(1, int(request.query_params.get("page", 1)))
+            except (TypeError, ValueError):
+                requested_page = 1
+            requested_page_size = self.paginator.get_page_size(request)
+
             if (
                 "text" in request.query_params
                 or "title_search" in request.query_params
@@ -2093,17 +2123,46 @@ class UnifiedSearchViewSet(DocumentViewSet):
                 else:
                     search_mode = SearchMode.QUERY
                     query_str = request.query_params["query"]
-                results = backend.search(
-                    query_str,
-                    user=user,
-                    page=1,
-                    page_size=10000,
-                    sort_field=None,
-                    sort_reverse=False,
-                    search_mode=search_mode,
-                )
+
+                # Only hand the ordering to Tantivy for its natively sortable
+                # fields; for any other ordering the ORM sorts the hits below.
+                results = backend.search(
+                    query_str,
+                    user=user,
+                    page=1,
+                    page_size=10000,
+                    sort_field=sort_field_name if use_tantivy_sort else None,
+                    sort_reverse=sort_reverse if use_tantivy_sort else False,
+                    search_mode=search_mode,
+                    highlight_page=requested_page,
+                    highlight_page_size=requested_page_size,
+                )
+
+                if use_tantivy_sort:
+                    # Fast path: keep Tantivy's order; intersect with ORM-visible
+                    # IDs (field filters).
+                    # NOTE(review): highlight_page presumably indexes Tantivy's
+                    # unfiltered result list; if the ORM filters drop hits, the
+                    # DRF page may not line up with the highlighted slice — confirm.
+                    orm_ids = set(filtered_qs.values_list("pk", flat=True))
+                    ordered_hits = [h for h in results.hits if h["id"] in orm_ids]
+                else:
+                    # Slow path: custom field ordering — ORM must sort
+                    # NOTE(review): the search above ran in relevance order, so the
+                    # highlighted slice presumably does not match the ORM-ordered
+                    # DRF page — confirm against backend.search().
+                    hits_by_id = {h["id"]: h for h in results.hits}
+                    hit_ids = set(hits_by_id)
+                    orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list(
+                        "pk",
+                        flat=True,
+                    )
+                    ordered_hits = [
+                        hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id
+                    ]
             else:
-                # more_like_id — validate permission on the seed document first
+                # more_like_id path — validate permission on the seed document first.
+                # The ordering param is not applied here: hits keep Tantivy's order.
                 try:
                     more_like_doc_id = int(request.query_params["more_like_id"])
                     more_like_doc = Document.objects.select_related("owner").get(
@@ -2125,25 +2184,8 @@ class UnifiedSearchViewSet(DocumentViewSet):
                     page=1,
                     page_size=10000,
                 )
-
-            hits_by_id = {h["id"]: h for h in results.hits}
-
-            # Determine sort order: no ordering param -> Tantivy relevance; otherwise -> ORM order
-            ordering_param = request.query_params.get("ordering", "").lstrip("-")
-            if not ordering_param:
-                # Preserve Tantivy relevance order; intersect with ORM-visible IDs
                 orm_ids = set(filtered_qs.values_list("pk", flat=True))
                 ordered_hits = [h for h in results.hits if h["id"] in orm_ids]
-            else:
-                # Use ORM ordering (already applied by DocumentsOrderingFilter)
-                hit_ids = set(hits_by_id.keys())
-                orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list(
-                    "pk",
-                    flat=True,
-                )
-                ordered_hits = [
-                    hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id
-                ]
 
             rl = TantivyRelevanceList(ordered_hits)
             page = self.paginate_queryset(rl)