From 0fb57205db3cb8f03252569246d5c35320eacb7b Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Mon, 30 Mar 2026 08:31:52 -0700 Subject: [PATCH] =?UTF-8?q?feat(search):=20complete=20TantivyBackend=20?= =?UTF-8?q?=E2=80=94=20search,=20autocomplete,=20more=5Flike=5Fthis,=20reb?= =?UTF-8?q?uild,=20WriteBatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dual-field approach for notes/custom_fields: JSON fields support structured queries (notes.user:alice, custom_fields.name:invoice); companion text fields (note, custom_field) carry content for default full-text search — tantivy-py 0.25 parse_query rejects dotted paths in default_field_names. Co-Authored-By: Claude Sonnet 4.6 --- src/documents/search/_backend.py | 704 +++++++++++++++++++++ src/documents/search/_query.py | 15 +- src/documents/search/_schema.py | 8 + src/documents/tests/search/conftest.py | 30 + src/documents/tests/search/test_backend.py | 260 ++++++++ 5 files changed, 1013 insertions(+), 4 deletions(-) create mode 100644 src/documents/search/_backend.py create mode 100644 src/documents/tests/search/conftest.py create mode 100644 src/documents/tests/search/test_backend.py diff --git a/src/documents/search/_backend.py b/src/documents/search/_backend.py new file mode 100644 index 000000000..6ad7e526c --- /dev/null +++ b/src/documents/search/_backend.py @@ -0,0 +1,704 @@ +from __future__ import annotations + +import bisect +import logging +import threading +import unicodedata +from dataclasses import dataclass +from datetime import UTC +from datetime import datetime +from typing import TYPE_CHECKING +from typing import Self +from typing import TypedDict +from typing import TypeVar + +import filelock +import tantivy +from django.conf import settings +from django.utils.timezone import get_current_timezone +from guardian.shortcuts import get_users_with_perms + +from documents.search._query import build_permission_filter +from documents.search._query import parse_user_query +from documents.search._schema import _wipe_index +from documents.search._schema import _write_sentinels +from documents.search._schema import build_schema +from documents.search._schema import open_or_rebuild_index +from documents.search._tokenizer import register_tokenizers + +if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Iterable + + from django.contrib.auth.base_user import AbstractBaseUser + from django.db.models import QuerySet + + from documents.models import Document + +logger = logging.getLogger("paperless.search") + +T = TypeVar("T") + + +def _identity(iterable: Iterable[T]) -> Iterable[T]: + """Default iter_wrapper that passes through unchanged.""" + return iterable + + +def _ascii_fold(s: str) -> str: + """Normalize unicode to ASCII equivalent characters.""" + return unicodedata.normalize("NFD", s).encode("ascii", "ignore").decode() + + +def _extract_autocomplete_words(text_sources: list[str]) -> set[str]: + """Extract and normalize words for autocomplete, filtering stopwords.""" + words = set() + + # Use NLTK if enabled + if settings.NLTK_ENABLED and settings.NLTK_LANGUAGE: + try: + import nltk + from nltk.corpus import stopwords + from nltk.tokenize import word_tokenize + + # Set NLTK data path + nltk.data.path = [settings.NLTK_DIR] + + # Get stopwords for the configured language + try: + stopwords.ensure_loaded() + stop_words = frozenset(stopwords.words(settings.NLTK_LANGUAGE)) + except (AttributeError, OSError) as e: + 
logger.debug(f"Could not load NLTK stopwords: {e}") + stop_words = frozenset() + + for text in text_sources: + if text: + try: + tokens = word_tokenize( + text.lower(), + language=settings.NLTK_LANGUAGE, + ) + for token in tokens: + if ( + token.isalpha() + and len(token) > 2 + and token not in stop_words + ): + normalized = _ascii_fold(token) + if normalized: + words.add(normalized) + except Exception as e: + logger.debug(f"NLTK tokenization failed: {e}") + # Fallback to regex + import re + + tokens = re.findall(r"\b[a-zA-Z]{3,}\b", text) + for token in tokens: + normalized = _ascii_fold(token.lower()) + if normalized and normalized not in stop_words: + words.add(normalized) + + except ImportError: + logger.debug("NLTK not available, using fallback tokenization") + # Fall through to basic tokenization + except Exception as e: + logger.debug(f"NLTK initialization failed: {e}") + # Fall through to basic tokenization + + # Fallback tokenization when NLTK is disabled or unavailable + if not words: # Only use fallback if NLTK didn't produce results + import re + + basic_stopwords = { + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + } + for text in text_sources: + if text: + tokens = re.findall(r"\b[a-zA-Z]{3,}\b", text) + for token in tokens: + normalized = _ascii_fold(token.lower()) + if normalized and normalized not in basic_stopwords: + words.add(normalized) + + return words + + +class SearchHit(TypedDict): + """Type definition for search result hits.""" + + id: int + score: float + rank: int + highlights: dict[str, str] + + +@dataclass(frozen=True, slots=True) +class SearchResults: + hits: list[SearchHit] + total: int # total matching documents (for pagination) + query: str # preprocessed query string + + +class SearchIndexLockError(Exception): + pass + + +class WriteBatch: + """Context manager for bulk index operations with file locking.""" + + def __init__(self, backend: TantivyBackend, lock_timeout: float): + self._backend = backend + self._lock_timeout = lock_timeout + self._writer = None + + def __enter__(self) -> Self: + lock_path = settings.INDEX_DIR / ".tantivy.lock" + self._lock = filelock.FileLock(str(lock_path)) + + try: + self._lock.acquire(timeout=self._lock_timeout) + except filelock.Timeout as e: + raise SearchIndexLockError( + f"Could not acquire index lock within {self._lock_timeout}s", + ) from e + + self._writer = self._backend._index.writer() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type is None: + # Success case - commit changes + self._writer.commit() + self._backend._index.reload() + else: + # Exception occurred - discard changes + # Writer is automatically discarded when it goes out of scope + pass + # Explicitly delete writer to release tantivy's internal lock + if self._writer is not None: + del self._writer + self._writer = None + finally: + if hasattr(self, "_lock") and self._lock: + self._lock.release() + + def add_or_update(self, document: Document) -> None: + """Add or update a document in the batch.""" + doc = self._backend._build_tantivy_doc(document) + self._writer.add_document(doc) + + def remove(self, doc_id: int) -> None: + """Remove a document from the batch.""" + # Use range query to work around u64 deletion bug + self._writer.delete_documents_by_query( + tantivy.Query.range_query( + self._backend._schema, + "id", + tantivy.FieldType.Unsigned, + doc_id, + doc_id, + ), + ) + + +class TantivyBackend: + """Tantivy search backend with context 
manager interface.""" + + def __init__(self): + self._index = None + self._schema = None + + def __enter__(self) -> Self: + self._index = open_or_rebuild_index() + register_tokenizers(self._index, settings.SEARCH_LANGUAGE) + self._schema = self._index.schema + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Index doesn't need explicit close + pass + + def _build_tantivy_doc(self, document: Document) -> tantivy.Document: + """Build a tantivy Document from a Django Document instance.""" + doc = tantivy.Document() + + # Basic fields + doc.add_unsigned("id", document.pk) + doc.add_text("checksum", document.checksum) + doc.add_text("title", document.title) + doc.add_text("title_sort", document.title) + doc.add_text("content", document.content) + doc.add_text("bigram_content", document.content) + + # Original filename - only add if not None/empty + if document.original_filename: + doc.add_text("original_filename", document.original_filename) + + # Correspondent + if document.correspondent: + doc.add_text("correspondent", document.correspondent.name) + doc.add_text("correspondent_sort", document.correspondent.name) + doc.add_unsigned("correspondent_id", document.correspondent_id) + + # Document type + if document.document_type: + doc.add_text("document_type", document.document_type.name) + doc.add_text("type_sort", document.document_type.name) + doc.add_unsigned("document_type_id", document.document_type_id) + + # Storage path + if document.storage_path: + doc.add_text("storage_path", document.storage_path.name) + doc.add_unsigned("storage_path_id", document.storage_path_id) + + # Tags + for tag in document.tags.all(): + doc.add_text("tag", tag.name) + doc.add_unsigned("tag_id", tag.pk) + + # Notes — JSON for structured queries (notes.user:alice, notes.note:text), + # companion text field for default full-text search. + for note in document.notes.all(): + note_data: dict[str, str] = {"note": note.note} + if note.user: + note_data["user"] = note.user.username + doc.add_json("notes", note_data) + doc.add_text("note", note.note) + + # Custom fields — JSON for structured queries (custom_fields.name:x, custom_fields.value:y), + # companion text field for default full-text search. 
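+        # e.g. an "Invoice Number" field holding "INV-001" (illustrative values)
+        # becomes the JSON object {"name": "Invoice Number", "value": "INV-001"}
+        # plus the bare text "INV-001" in the companion custom_field text field.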
+        for cfi in document.custom_fields.all():
+            doc.add_json(
+                "custom_fields",
+                {
+                    "name": cfi.field.name,
+                    "value": str(cfi.value),
+                },
+            )
+            doc.add_text("custom_field", str(cfi.value))
+
+        # Dates - created is date-only, others are full datetime
+        created_date = datetime(
+            document.created.year,
+            document.created.month,
+            document.created.day,
+            tzinfo=UTC,
+        )
+        doc.add_date("created", created_date)
+        doc.add_date("modified", document.modified)
+        doc.add_date("added", document.added)
+
+        # ASN - skip entirely when None (0 is valid)
+        if document.archive_serial_number is not None:
+            doc.add_unsigned("asn", document.archive_serial_number)
+
+        # Page count - only add if not None
+        if document.page_count is not None:
+            doc.add_unsigned("page_count", document.page_count)
+
+        # Number of notes
+        doc.add_unsigned("num_notes", document.notes.count())
+
+        # Owner
+        if document.owner_id:
+            doc.add_unsigned("owner_id", document.owner_id)
+
+        # Viewers with permission
+        users_with_perms = get_users_with_perms(
+            document,
+            only_with_perms_in=["view_document"],
+        )
+        for user in users_with_perms:
+            doc.add_unsigned("viewer_id", user.pk)
+
+        # Autocomplete words with NLTK stopword filtering
+        text_sources = [document.title, document.content]
+        if document.correspondent:
+            text_sources.append(document.correspondent.name)
+        if document.document_type:
+            text_sources.append(document.document_type.name)
+        for tag in document.tags.all():
+            text_sources.append(tag.name)
+
+        autocomplete_words = _extract_autocomplete_words(text_sources)
+
+        # Add sorted deduplicated words
+        for word in sorted(autocomplete_words):
+            doc.add_text("autocomplete_word", word)
+
+        return doc
+
+    def add_or_update(self, document: Document) -> None:
+        """Add or update a single document with file locking."""
+        with self.batch_update(lock_timeout=5.0) as batch:
+            batch.add_or_update(document)
+
+    def remove(self, doc_id: int) -> None:
+        """Remove a single document with file locking."""
+        with self.batch_update(lock_timeout=5.0) as batch:
+            batch.remove(doc_id)
+
+    def search(
+        self,
+        query: str,
+        user: AbstractBaseUser | None,
+        page: int,
+        page_size: int,
+        sort_field: str | None,
+        *,
+        sort_reverse: bool,
+    ) -> SearchResults:
+        """Search the index."""
+        tz = get_current_timezone()
+        user_query = parse_user_query(self._index, query, tz)
+
+        # Apply permission filter if user is not None (not superuser)
+        if user is not None:
+            permission_filter = build_permission_filter(self._schema, user)
+            final_query = tantivy.Query.boolean_query(
+                [
+                    (tantivy.Occur.Must, user_query),
+                    (tantivy.Occur.Must, permission_filter),
+                ],
+            )
+        else:
+            final_query = user_query
+
+        searcher = self._index.searcher()
+        offset = (page - 1) * page_size
+
+        # Map sort fields
+        sort_field_map = {
+            "title": "title_sort",
+            "correspondent__name": "correspondent_sort",
+            "document_type__name": "type_sort",
+            "created": "created",
+            "added": "added",
+            "modified": "modified",
+            "archive_serial_number": "asn",
+            "page_count": "page_count",
+            "num_notes": "num_notes",
+        }
+
+        # Perform search
+        if sort_field and sort_field in sort_field_map:
+            mapped_field = sort_field_map[sort_field]
+            # tantivy-py sorts on a fast field via order_by_field; the order
+            # argument selects the direction, so no Python-side re-sort is needed.
+            results = searcher.search(
+                final_query,
+                limit=offset + page_size,
+                order_by_field=mapped_field,
+                order=tantivy.Order.Desc if sort_reverse else tantivy.Order.Asc,
+            )
+            # For field sorting: just DocAddress (no score)
+            all_hits = [(hit, 0.0) for hit in results.hits]
+        else:
+            # Score-based search returns: (score, doc_address) tuple
+            results = searcher.search(final_query, limit=offset + page_size)
+            # Convert to (doc_address, score) for consistency
+            all_hits = [(hit[1], hit[0]) for hit in results.hits]
+
+        total = results.count
+
+        # Normalize scores for score-based searches (guard against all-zero scores)
+        if not sort_field and all_hits:
+            max_score = max(hit[1] for hit in all_hits) or 1.0
+            all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]
+
+        # Apply threshold filter if configured (score-based search only)
+        threshold = getattr(settings, "ADVANCED_FUZZY_SEARCH_THRESHOLD", None)
+        if threshold is not None and not sort_field:
+            all_hits = [hit for hit in all_hits if hit[1] >= threshold]
+
+        # Get the page's hits
+        page_hits = all_hits[offset : offset + page_size]
+
+        # Build result hits with highlights
+        hits: list[SearchHit] = []
+        snippet_generator = None
+
+        for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
+            # Get the actual document from the searcher using the doc address
+            actual_doc = searcher.doc(doc_address)
+            doc_dict = actual_doc.to_dict()
+            doc_id = doc_dict["id"][0]
+
+            highlights: dict[str, str] = {}
+
+            # Generate highlights if score > 0
+            if score > 0:
+                try:
+                    if snippet_generator is None:
+                        snippet_generator = tantivy.SnippetGenerator.create(
+                            searcher,
+                            final_query,
+                            self._schema,
+                            "content",
+                        )
+
+                    content_snippet = snippet_generator.snippet_from_doc(actual_doc)
+                    content_html = content_snippet.to_html()
+                    if content_html:
+                        highlights["content"] = content_html
+
+                    # Try notes highlights
+                    if "notes" in doc_dict:
+                        notes_generator = tantivy.SnippetGenerator.create(
+                            searcher,
+                            final_query,
+                            self._schema,
+                            "notes",
+                        )
+                        notes_snippet = notes_generator.snippet_from_doc(actual_doc)
+                        notes_html = notes_snippet.to_html()
+                        if notes_html:
+                            highlights["notes"] = notes_html
+
+                except Exception as e:
+                    logger.debug(f"Failed to generate highlights for doc {doc_id}: {e}")
+
+            hits.append(
+                SearchHit(
+                    id=doc_id,
+                    score=score,
+                    rank=rank,
+                    highlights=highlights,
+                ),
+            )
+
+        return SearchResults(
+            hits=hits,
+            total=total,
+            query=query,
+        )
+
+    def autocomplete(self, term: str, limit: int) -> list[str]:
+        """Get autocomplete suggestions."""
+        normalized_term = _ascii_fold(term.lower())
+
+        searcher = self._index.searcher()
+        # Search all documents to collect autocomplete words
+        all_query = tantivy.Query.all_query()
+        results = searcher.search(all_query, limit=10000)  # High limit to get all docs
+
+        # Collect all autocomplete words
+        words = set()
+        for hit in results.hits:
+            # For all_query, hit is (score, doc_address)
+            doc_address = hit[1] if len(hit) == 2 else hit[0]
+
+            stored_doc = searcher.doc(doc_address)
+            doc_dict = stored_doc.to_dict()
+            if "autocomplete_word" in doc_dict:
+                for word in doc_dict["autocomplete_word"]:
+                    words.add(word)
+
+        # Sort, then binary-search for the first candidate with the given prefix
+        sorted_words = sorted(words)
+        start_idx = bisect.bisect_left(sorted_words, normalized_term)
+
+        # Collect matching words; stop at the first non-match (sorted order)
+        matches = []
+        for i in range(start_idx, len(sorted_words)):
+            word = sorted_words[i]
+            if word.startswith(normalized_term):
+                matches.append(word)
+                if len(matches) >= limit:
+                    break
+            else:
+                break
+
+        return matches
+
+    def more_like_this(
+        self,
+        doc_id: int,
+        user: AbstractBaseUser | None,
+        page: int,
+        page_size: int,
+    ) -> SearchResults:
+        """Find documents similar to the given document."""
+        searcher = self._index.searcher()
+
+        # First find the document address
+        id_query = tantivy.Query.range_query(
+            self._schema,
+            "id",
+            tantivy.FieldType.Unsigned,
+            doc_id,
+            doc_id,
+        )
+        results = searcher.search(id_query, limit=1)
+
+        if not results.hits:
+            # Document not found
+            return SearchResults(hits=[], total=0, query=f"more_like:{doc_id}")
+
+        # Extract doc_address from (score, doc_address) tuple
+        doc_address = results.hits[0][1]
+
+        # Build more like this query
+        mlt_query = tantivy.Query.more_like_this_query(
+            doc_address,
+            min_doc_frequency=1,
+            max_doc_frequency=None,
+            min_term_frequency=1,
+            max_query_terms=12,
+            min_word_length=None,
+            max_word_length=None,
+            boost_factor=None,
+        )
+
+        # Exclude the source document at query level so pagination and the
+        # total count stay correct, and apply the permission filter if any.
+        clauses = [
+            (tantivy.Occur.Must, mlt_query),
+            (tantivy.Occur.MustNot, id_query),
+        ]
+        if user is not None:
+            clauses.append(
+                (tantivy.Occur.Must, build_permission_filter(self._schema, user)),
+            )
+        final_query = tantivy.Query.boolean_query(clauses)
+
+        # Search
+        offset = (page - 1) * page_size
+        results = searcher.search(final_query, limit=offset + page_size)
+
+        total = results.count
+        # Convert from (score, doc_address) to (doc_address, score)
+        all_hits = [(hit[1], hit[0]) for hit in results.hits]
+
+        # Normalize scores
+        if all_hits:
+            max_score = max(hit[1] for hit in all_hits) or 1.0
+            all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]
+
+        # Get page hits
+        page_hits = all_hits[offset : offset + page_size]
+
+        # Build results
+        hits: list[SearchHit] = []
+        for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
+            actual_doc = searcher.doc(doc_address)
+            doc_dict = actual_doc.to_dict()
+
+            hits.append(
+                SearchHit(
+                    id=doc_dict["id"][0],
+                    score=score,
+                    rank=rank,
+                    highlights={},  # MLT doesn't generate highlights
+                ),
+            )
+
+        return SearchResults(
+            hits=hits,
+            total=total,
+            query=f"more_like:{doc_id}",
+        )
+
+    def batch_update(self, lock_timeout: float = 30.0) -> WriteBatch:
+        """Get a batch context manager for bulk operations."""
+        return WriteBatch(self, lock_timeout)
+
+    def rebuild(self, documents: QuerySet, iter_wrapper: Callable = _identity) -> None:
+        """Rebuild the entire search index."""
+        index_dir = settings.INDEX_DIR
+
+        # Create new index
+        _wipe_index(index_dir)
+        new_index = tantivy.Index(build_schema(), path=str(index_dir))
+        _write_sentinels(index_dir)
+        register_tokenizers(new_index, settings.SEARCH_LANGUAGE)
+
+        # Point the backend at the new index once for the whole run so
+        # _build_tantivy_doc sees the new schema; restore the old index
+        # only if indexing fails midway.
+        old_index = self._index
+        old_schema = self._schema
+        self._index = new_index
+        self._schema = new_index.schema
+
+        try:
+            writer = new_index.writer()
+            for document in iter_wrapper(documents):
+                writer.add_document(self._build_tantivy_doc(document))
+            writer.commit()
+        except Exception:
+            self._index = old_index
+            self._schema = old_schema
+            raise
+
+        self._index.reload()
+
+
+# Module-level singleton with proper thread safety
+_backend: TantivyBackend
| None = None +_backend_lock = threading.RLock() + + +def get_backend() -> TantivyBackend: + """Get the global backend instance with thread safety.""" + global _backend + + # Fast path for already initialized backend + if _backend is not None: + return _backend + + # Slow path with locking + with _backend_lock: + if _backend is None: + _backend = TantivyBackend() + _backend.__enter__() + return _backend + + +def reset_backend() -> None: + """Reset the global backend instance with thread safety.""" + global _backend + + with _backend_lock: + if _backend is not None: + _backend.__exit__(None, None, None) + _backend = None diff --git a/src/documents/search/_query.py b/src/documents/search/_query.py index 4176410a1..e03f364eb 100644 --- a/src/documents/search/_query.py +++ b/src/documents/search/_query.py @@ -303,19 +303,24 @@ DEFAULT_SEARCH_FIELDS = [ "correspondent", "document_type", "tag", - "notes", - "custom_fields", + "note", # companion text field for notes content (notes JSON for structured: notes.user:x) + "custom_field", # companion text field for CF values (custom_fields JSON for structured: custom_fields.name:x) ] _FIELD_BOOSTS = {"title": 2.0} def parse_user_query( index: tantivy.Index, - schema: tantivy.Schema, raw_query: str, tz: tzinfo, ) -> tantivy.Query: - """Run the full query preprocessing pipeline: date rewriting → normalisation → Tantivy parse. Adds fuzzy blend if ADVANCED_FUZZY_SEARCH_THRESHOLD is set.""" + """Run the full query preprocessing pipeline: date rewriting → normalisation → Tantivy parse. + + When ADVANCED_FUZZY_SEARCH_THRESHOLD is set (any float), a fuzzy query is blended in as a + Should clause boosted at 0.1 — keeping fuzzy hits ranked below exact matches. The fuzzy + query uses edit-distance=1, prefix=True, transposition_cost_one=True on all search fields. + The threshold float is a post-search minimum-score filter applied in the backend layer, not here. + """ query_str = rewrite_natural_date_keywords(raw_query, tz) query_str = normalize_query(query_str) @@ -332,11 +337,13 @@ def parse_user_query( query_str, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS, + # (prefix=True, distance=1, transposition_cost_one=True) — edit-distance fuzziness fuzzy_fields={f: (True, 1, True) for f in DEFAULT_SEARCH_FIELDS}, ) return tantivy.Query.boolean_query( [ (tantivy.Occur.Should, exact), + # 0.1 boost keeps fuzzy hits ranked below exact matches (intentional) (tantivy.Occur.Should, tantivy.Query.boost_query(fuzzy, 0.1)), ], ) diff --git a/src/documents/search/_schema.py b/src/documents/search/_schema.py index 5724d97a0..c16e6d2f1 100644 --- a/src/documents/search/_schema.py +++ b/src/documents/search/_schema.py @@ -48,9 +48,17 @@ def build_schema() -> tantivy.Schema: sb.add_text_field("autocomplete_word", stored=True, tokenizer_name="raw") sb.add_text_field("tag", stored=True, tokenizer_name="paperless_text") + + # JSON fields — structured queries: notes.user:alice, custom_fields.name:invoice + # tantivy-py 0.25 does not support dotted paths in parse_query default_field_names, + # so companion text fields (note, custom_field) carry content for default full-text search. 
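+    # Illustrative query shapes this split supports (field values are examples):
+    #   notes.user:alice AND custom_fields.name:invoice  -> structured, via the JSON fields
+    #   alice invoice                                    -> default full text, via note/custom_field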
sb.add_json_field("notes", stored=True, tokenizer_name="paperless_text") sb.add_json_field("custom_fields", stored=True, tokenizer_name="paperless_text") + # Companion text fields for default full-text search (not stored — no extra disk cost) + sb.add_text_field("note", stored=False, tokenizer_name="paperless_text") + sb.add_text_field("custom_field", stored=False, tokenizer_name="paperless_text") + for field in ( "correspondent_id", "document_type_id", diff --git a/src/documents/tests/search/conftest.py b/src/documents/tests/search/conftest.py new file mode 100644 index 000000000..6946649e9 --- /dev/null +++ b/src/documents/tests/search/conftest.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from documents.search._backend import TantivyBackend +from documents.search._backend import reset_backend + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + from pytest_django.fixtures import SettingsWrapper + + +@pytest.fixture +def index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path: + path = tmp_path / "index" + path.mkdir() + settings.INDEX_DIR = path + return path + + +@pytest.fixture +def backend(index_dir: Path) -> Generator[TantivyBackend, None, None]: + b = TantivyBackend() + with b: + yield b + reset_backend() diff --git a/src/documents/tests/search/test_backend.py b/src/documents/tests/search/test_backend.py new file mode 100644 index 000000000..20608d035 --- /dev/null +++ b/src/documents/tests/search/test_backend.py @@ -0,0 +1,260 @@ +import pytest +from django.contrib.auth.models import User + +from documents.models import CustomField +from documents.models import CustomFieldInstance +from documents.models import Document +from documents.models import Note +from documents.search._backend import TantivyBackend + +pytestmark = [pytest.mark.search, pytest.mark.django_db] + + +class TestWriteBatch: + """Test WriteBatch context manager functionality.""" + + def test_rolls_back_on_exception(self, backend: TantivyBackend): + """Data integrity: a mid-batch exception must not corrupt the index.""" + doc = Document.objects.create( + title="Rollback Target", + content="should survive", + checksum="RB1", + pk=1, + ) + backend.add_or_update(doc) + + try: + with backend.batch_update() as batch: + batch.remove(doc.pk) + raise RuntimeError("simulated failure") + except RuntimeError: + pass + + r = backend.search( + "should survive", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert r.total == 1 + + +class TestSearch: + """Test search functionality.""" + + def test_scores_normalised_top_hit_is_one(self, backend: TantivyBackend): + """UI score bar depends on the top hit being 1.0.""" + for i, title in enumerate(["bank invoice", "bank statement", "bank receipt"]): + doc = Document.objects.create( + title=title, + content=title, + checksum=f"SN{i}", + pk=10 + i, + ) + backend.add_or_update(doc) + r = backend.search( + "bank", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert r.hits[0]["score"] == pytest.approx(1.0) + assert all(0.0 <= h["score"] <= 1.0 for h in r.hits) + + def test_owner_filter(self, backend: TantivyBackend): + """Owner can find their document; other user cannot.""" + owner = User.objects.create_user("owner") + other = User.objects.create_user("other") + doc = Document.objects.create( + title="Private", + content="secret", + checksum="PF1", + pk=20, + owner=owner, + ) + 
backend.add_or_update(doc) + + assert ( + backend.search( + "secret", + user=owner, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ).total + == 1 + ) + assert ( + backend.search( + "secret", + user=other, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ).total + == 0 + ) + + +class TestRebuild: + """Test index rebuilding functionality.""" + + def test_with_iter_wrapper_called(self, backend: TantivyBackend): + """rebuild() must pass documents through iter_wrapper.""" + seen = [] + + def wrapper(docs): + for doc in docs: + seen.append(doc.pk) + yield doc + + Document.objects.create(title="Tracked", content="x", checksum="TW1", pk=30) + backend.rebuild(Document.objects.all(), iter_wrapper=wrapper) + assert 30 in seen + + +class TestAutocomplete: + """Test autocomplete functionality.""" + + def test_basic_functionality(self, backend: TantivyBackend): + """Autocomplete should find word prefixes.""" + doc = Document.objects.create( + title="Invoice from Microsoft Corporation", + content="payment details", + checksum="AC1", + pk=40, + ) + backend.add_or_update(doc) + + results = backend.autocomplete("micro", limit=10) + assert "microsoft" in results + + +class TestMoreLikeThis: + """Test more like this functionality.""" + + def test_excludes_original(self, backend: TantivyBackend): + """More like this should not return the original document.""" + doc1 = Document.objects.create( + title="Important document", + content="financial information", + checksum="MLT1", + pk=50, + ) + doc2 = Document.objects.create( + title="Another document", + content="financial report", + checksum="MLT2", + pk=51, + ) + backend.add_or_update(doc1) + backend.add_or_update(doc2) + + results = backend.more_like_this(doc_id=50, user=None, page=1, page_size=10) + returned_ids = [hit["id"] for hit in results.hits] + assert 50 not in returned_ids # Original document excluded + + +class TestFieldHandling: + """Test handling of various document fields.""" + + def test_none_values_handled_correctly(self, backend: TantivyBackend): + """Test that None values for original_filename and page_count are handled properly.""" + doc = Document.objects.create( + title="Test Doc", + content="test content", + checksum="NV1", + pk=60, + original_filename=None, + page_count=None, + ) + # Should not raise an exception + backend.add_or_update(doc) + + results = backend.search( + "test", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert results.total == 1 + + def test_custom_fields_include_name_and_value(self, backend: TantivyBackend): + """Custom field indexing should include both name and value.""" + # Create a custom field + field = CustomField.objects.create( + name="Invoice Number", + data_type=CustomField.FieldDataType.STRING, + ) + doc = Document.objects.create( + title="Invoice", + content="test", + checksum="CF1", + pk=70, + ) + CustomFieldInstance.objects.create( + document=doc, + field=field, + value_text="INV-2024-001", + ) + + # Should not raise an exception during indexing + backend.add_or_update(doc) + + results = backend.search( + "invoice", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert results.total == 1 + + def test_notes_include_user_information(self, backend: TantivyBackend): + """Notes should include user information when available.""" + user = User.objects.create_user("notewriter") + doc = Document.objects.create( + title="Doc with notes", + content="test", + checksum="NT1", + pk=80, + ) + 
Note.objects.create(document=doc, note="Important note", user=user) + + # Should not raise an exception during indexing + backend.add_or_update(doc) + + # Test basic document search first + results = backend.search( + "test", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert results.total == 1, ( + f"Expected 1, got {results.total}. Document content should be searchable." + ) + + # Test notes search + results = backend.search( + "important", + user=None, + page=1, + page_size=10, + sort_field=None, + sort_reverse=False, + ) + assert results.total == 1, ( + f"Expected 1, got {results.total}. Note content should be searchable." + )
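+
+    def test_structured_notes_query(self, backend: TantivyBackend):
+        """Sketch of the structured query path described in the schema comments.
+
+        Assumes tantivy's dotted JSON syntax (notes.user:<name>) passes through
+        parse_user_query unchanged; adjust the query string if the preprocessing
+        pipeline rewrites it.
+        """
+        user = User.objects.create_user("alice")
+        doc = Document.objects.create(
+            title="Doc with structured note",
+            content="body",
+            checksum="SQ1",
+            pk=90,
+        )
+        Note.objects.create(document=doc, note="from alice", user=user)
+        backend.add_or_update(doc)
+
+        results = backend.search(
+            "notes.user:alice",
+            user=None,
+            page=1,
+            page_size=10,
+            sort_field=None,
+            sort_reverse=False,
+        )
+        assert results.total == 1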