diff --git a/src/documents/admin.py b/src/documents/admin.py index 6c7a6f304..f0e5ccd25 100644 --- a/src/documents/admin.py +++ b/src/documents/admin.py @@ -100,24 +100,23 @@ class DocumentAdmin(GuardedModelAdmin): return Document.global_objects.all() def delete_queryset(self, request, queryset): - from documents import index + from documents.search import get_backend - with index.open_index_writer() as writer: + with get_backend().batch_update() as batch: for o in queryset: - index.remove_document(writer, o) - + batch.remove(o.pk) super().delete_queryset(request, queryset) def delete_model(self, request, obj): - from documents import index + from documents.search import get_backend - index.remove_document_from_index(obj) + get_backend().remove(obj.pk) super().delete_model(request, obj) def save_model(self, request, obj, form, change): - from documents import index + from documents.search import get_backend - index.add_or_update_document(obj) + get_backend().add_or_update(obj) super().save_model(request, obj, form, change) diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 8dbcdb8a4..3f80b699d 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -349,11 +349,11 @@ def delete(doc_ids: list[int]) -> Literal["OK"]: Document.objects.filter(id__in=delete_ids).delete() - from documents import index + from documents.search import get_backend - with index.open_index_writer() as writer: + with get_backend().batch_update() as batch: for id in delete_ids: - index.remove_document_by_id(writer, id) + batch.remove(id) status_mgr = DocumentsStatusManager() status_mgr.send_documents_deleted(delete_ids) diff --git a/src/documents/index.py b/src/documents/index.py deleted file mode 100644 index 24c541f8c..000000000 --- a/src/documents/index.py +++ /dev/null @@ -1,655 +0,0 @@ -from __future__ import annotations - -import logging -import math -import re -from collections import Counter -from contextlib import contextmanager -from datetime import UTC -from datetime import datetime -from datetime import time -from datetime import timedelta -from shutil import rmtree -from time import sleep -from typing import TYPE_CHECKING -from typing import Literal - -from dateutil.relativedelta import relativedelta -from django.conf import settings -from django.utils import timezone as django_timezone -from django.utils.timezone import get_current_timezone -from django.utils.timezone import now -from guardian.shortcuts import get_users_with_perms -from whoosh import classify -from whoosh import highlight -from whoosh import query -from whoosh.fields import BOOLEAN -from whoosh.fields import DATETIME -from whoosh.fields import KEYWORD -from whoosh.fields import NUMERIC -from whoosh.fields import TEXT -from whoosh.fields import Schema -from whoosh.highlight import HtmlFormatter -from whoosh.idsets import BitSet -from whoosh.idsets import DocIdSet -from whoosh.index import FileIndex -from whoosh.index import LockError -from whoosh.index import create_in -from whoosh.index import exists_in -from whoosh.index import open_dir -from whoosh.qparser import MultifieldParser -from whoosh.qparser import QueryParser -from whoosh.qparser.dateparse import DateParserPlugin -from whoosh.qparser.dateparse import English -from whoosh.qparser.plugins import FieldsPlugin -from whoosh.scoring import TF_IDF -from whoosh.util.times import timespan -from whoosh.writing import AsyncWriter - -from documents.models import CustomFieldInstance -from documents.models import Document -from documents.models import Note -from documents.models import User - -if TYPE_CHECKING: - from django.db.models import QuerySet - from whoosh.reading import IndexReader - from whoosh.searching import ResultsPage - from whoosh.searching import Searcher - -logger = logging.getLogger("paperless.index") - - -def get_schema() -> Schema: - return Schema( - id=NUMERIC(stored=True, unique=True), - title=TEXT(sortable=True), - content=TEXT(), - asn=NUMERIC(sortable=True, signed=False), - correspondent=TEXT(sortable=True), - correspondent_id=NUMERIC(), - has_correspondent=BOOLEAN(), - tag=KEYWORD(commas=True, scorable=True, lowercase=True), - tag_id=KEYWORD(commas=True, scorable=True), - has_tag=BOOLEAN(), - type=TEXT(sortable=True), - type_id=NUMERIC(), - has_type=BOOLEAN(), - created=DATETIME(sortable=True), - modified=DATETIME(sortable=True), - added=DATETIME(sortable=True), - path=TEXT(sortable=True), - path_id=NUMERIC(), - has_path=BOOLEAN(), - notes=TEXT(), - num_notes=NUMERIC(sortable=True, signed=False), - custom_fields=TEXT(), - custom_field_count=NUMERIC(sortable=True, signed=False), - has_custom_fields=BOOLEAN(), - custom_fields_id=KEYWORD(commas=True), - owner=TEXT(), - owner_id=NUMERIC(), - has_owner=BOOLEAN(), - viewer_id=KEYWORD(commas=True), - checksum=TEXT(), - page_count=NUMERIC(sortable=True), - original_filename=TEXT(sortable=True), - is_shared=BOOLEAN(), - ) - - -def open_index(*, recreate=False) -> FileIndex: - transient_exceptions = (FileNotFoundError, LockError) - max_retries = 3 - retry_delay = 0.1 - - for attempt in range(max_retries + 1): - try: - if exists_in(settings.INDEX_DIR) and not recreate: - return open_dir(settings.INDEX_DIR, schema=get_schema()) - break - except transient_exceptions as exc: - is_last_attempt = attempt == max_retries or recreate - if is_last_attempt: - logger.exception( - "Error while opening the index after retries, recreating.", - ) - break - - logger.warning( - "Transient error while opening the index (attempt %s/%s): %s. Retrying.", - attempt + 1, - max_retries + 1, - exc, - ) - sleep(retry_delay) - except Exception: - logger.exception("Error while opening the index, recreating.") - break - - # create_in doesn't handle corrupted indexes very well, remove the directory entirely first - if settings.INDEX_DIR.is_dir(): - rmtree(settings.INDEX_DIR) - settings.INDEX_DIR.mkdir(parents=True, exist_ok=True) - - return create_in(settings.INDEX_DIR, get_schema()) - - -@contextmanager -def open_index_writer(*, optimize=False) -> AsyncWriter: - writer = AsyncWriter(open_index()) - - try: - yield writer - except Exception as e: - logger.exception(str(e)) - writer.cancel() - finally: - writer.commit(optimize=optimize) - - -@contextmanager -def open_index_searcher() -> Searcher: - searcher = open_index().searcher() - - try: - yield searcher - finally: - searcher.close() - - -def update_document( - writer: AsyncWriter, - doc: Document, - effective_content: str | None = None, -) -> None: - tags = ",".join([t.name for t in doc.tags.all()]) - tags_ids = ",".join([str(t.id) for t in doc.tags.all()]) - notes = ",".join([str(c.note) for c in Note.objects.filter(document=doc)]) - custom_fields = ",".join( - [str(c) for c in CustomFieldInstance.objects.filter(document=doc)], - ) - custom_fields_ids = ",".join( - [str(f.field.id) for f in CustomFieldInstance.objects.filter(document=doc)], - ) - asn: int | None = doc.archive_serial_number - if asn is not None and ( - asn < Document.ARCHIVE_SERIAL_NUMBER_MIN - or asn > Document.ARCHIVE_SERIAL_NUMBER_MAX - ): - logger.error( - f"Not indexing Archive Serial Number {asn} of document {doc.pk}. " - f"ASN is out of range " - f"[{Document.ARCHIVE_SERIAL_NUMBER_MIN:,}, " - f"{Document.ARCHIVE_SERIAL_NUMBER_MAX:,}.", - ) - asn = 0 - users_with_perms = get_users_with_perms( - doc, - only_with_perms_in=["view_document"], - ) - viewer_ids: str = ",".join([str(u.id) for u in users_with_perms]) - writer.update_document( - id=doc.pk, - title=doc.title, - content=effective_content or doc.content, - correspondent=doc.correspondent.name if doc.correspondent else None, - correspondent_id=doc.correspondent.id if doc.correspondent else None, - has_correspondent=doc.correspondent is not None, - tag=tags if tags else None, - tag_id=tags_ids if tags_ids else None, - has_tag=len(tags) > 0, - type=doc.document_type.name if doc.document_type else None, - type_id=doc.document_type.id if doc.document_type else None, - has_type=doc.document_type is not None, - created=datetime.combine(doc.created, time.min), - added=doc.added, - asn=asn, - modified=doc.modified, - path=doc.storage_path.name if doc.storage_path else None, - path_id=doc.storage_path.id if doc.storage_path else None, - has_path=doc.storage_path is not None, - notes=notes, - num_notes=len(notes), - custom_fields=custom_fields, - custom_field_count=len(doc.custom_fields.all()), - has_custom_fields=len(custom_fields) > 0, - custom_fields_id=custom_fields_ids if custom_fields_ids else None, - owner=doc.owner.username if doc.owner else None, - owner_id=doc.owner.id if doc.owner else None, - has_owner=doc.owner is not None, - viewer_id=viewer_ids if viewer_ids else None, - checksum=doc.checksum, - page_count=doc.page_count, - original_filename=doc.original_filename, - is_shared=len(viewer_ids) > 0, - ) - logger.debug(f"Index updated for document {doc.pk}.") - - -def remove_document(writer: AsyncWriter, doc: Document) -> None: - remove_document_by_id(writer, doc.pk) - - -def remove_document_by_id(writer: AsyncWriter, doc_id) -> None: - writer.delete_by_term("id", doc_id) - - -def add_or_update_document( - document: Document, - effective_content: str | None = None, -) -> None: - with open_index_writer() as writer: - update_document(writer, document, effective_content=effective_content) - - -def remove_document_from_index(document: Document) -> None: - with open_index_writer() as writer: - remove_document(writer, document) - - -class MappedDocIdSet(DocIdSet): - """ - A DocIdSet backed by a set of `Document` IDs. - Supports efficiently looking up if a whoosh docnum is in the provided `filter_queryset`. - """ - - def __init__(self, filter_queryset: QuerySet, ixreader: IndexReader) -> None: - super().__init__() - document_ids = filter_queryset.order_by("id").values_list("id", flat=True) - max_id = document_ids.last() or 0 - self.document_ids = BitSet(document_ids, size=max_id) - self.ixreader = ixreader - - def __contains__(self, docnum) -> bool: - document_id = self.ixreader.stored_fields(docnum)["id"] - return document_id in self.document_ids - - def __bool__(self) -> Literal[True]: - # searcher.search ignores a filter if it's "falsy". - # We use this hack so this DocIdSet, when used as a filter, is never ignored. - return True - - -class DelayedQuery: - def _get_query(self): - raise NotImplementedError # pragma: no cover - - def _get_query_sortedby(self) -> tuple[None, Literal[False]] | tuple[str, bool]: - if "ordering" not in self.query_params: - return None, False - - field: str = self.query_params["ordering"] - - sort_fields_map: dict[str, str] = { - "created": "created", - "modified": "modified", - "added": "added", - "title": "title", - "correspondent__name": "correspondent", - "document_type__name": "type", - "archive_serial_number": "asn", - "num_notes": "num_notes", - "owner": "owner", - "page_count": "page_count", - } - - if field.startswith("-"): - field = field[1:] - reverse = True - else: - reverse = False - - if field not in sort_fields_map: - return None, False - else: - return sort_fields_map[field], reverse - - def __init__( - self, - searcher: Searcher, - query_params, - page_size, - filter_queryset: QuerySet, - ) -> None: - self.searcher = searcher - self.query_params = query_params - self.page_size = page_size - self.saved_results = dict() - self.first_score = None - self.filter_queryset = filter_queryset - self.suggested_correction = None - self._manual_hits_cache: list | None = None - - def __len__(self) -> int: - if self._manual_sort_requested(): - manual_hits = self._manual_hits() - return len(manual_hits) - - page = self[0:1] - return len(page) - - def _manual_sort_requested(self): - ordering = self.query_params.get("ordering", "") - return ordering.lstrip("-").startswith("custom_field_") - - def _manual_hits(self): - if self._manual_hits_cache is None: - q, mask, suggested_correction = self._get_query() - self.suggested_correction = suggested_correction - - results = self.searcher.search( - q, - mask=mask, - filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), - limit=None, - ) - results.fragmenter = highlight.ContextFragmenter(surround=50) - results.formatter = HtmlFormatter(tagname="span", between=" ... ") - - if not self.first_score and len(results) > 0: - self.first_score = results[0].score - - if self.first_score: - results.top_n = [ - ( - (hit[0] / self.first_score) if self.first_score else None, - hit[1], - ) - for hit in results.top_n - ] - - hits_by_id = {hit["id"]: hit for hit in results} - matching_ids = list(hits_by_id.keys()) - - ordered_ids = list( - self.filter_queryset.filter(id__in=matching_ids).values_list( - "id", - flat=True, - ), - ) - ordered_ids = list(dict.fromkeys(ordered_ids)) - - self._manual_hits_cache = [ - hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id - ] - return self._manual_hits_cache - - def __getitem__(self, item): - if item.start in self.saved_results: - return self.saved_results[item.start] - - if self._manual_sort_requested(): - manual_hits = self._manual_hits() - start = 0 if item.start is None else item.start - stop = item.stop - hits = manual_hits[start:stop] if stop is not None else manual_hits[start:] - page = ManualResultsPage(hits) - self.saved_results[start] = page - return page - - q, mask, suggested_correction = self._get_query() - self.suggested_correction = suggested_correction - sortedby, reverse = self._get_query_sortedby() - - page: ResultsPage = self.searcher.search_page( - q, - mask=mask, - filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), - pagenum=math.floor(item.start / self.page_size) + 1, - pagelen=self.page_size, - sortedby=sortedby, - reverse=reverse, - ) - page.results.fragmenter = highlight.ContextFragmenter(surround=50) - page.results.formatter = HtmlFormatter(tagname="span", between=" ... ") - - if not self.first_score and len(page.results) > 0 and sortedby is None: - self.first_score = page.results[0].score - - page.results.top_n = [ - ( - (hit[0] / self.first_score) if self.first_score else None, - hit[1], - ) - for hit in page.results.top_n - ] - - self.saved_results[item.start] = page - - return page - - -class ManualResultsPage(list): - def __init__(self, hits) -> None: - super().__init__(hits) - self.results = ManualResults(hits) - - -class ManualResults: - def __init__(self, hits) -> None: - self._docnums = [hit.docnum for hit in hits] - - def docs(self): - return self._docnums - - -class LocalDateParser(English): - def reverse_timezone_offset(self, d): - return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone( - UTC, - ) - - def date_from(self, *args, **kwargs): - d = super().date_from(*args, **kwargs) - if isinstance(d, timespan): - d.start = self.reverse_timezone_offset(d.start) - d.end = self.reverse_timezone_offset(d.end) - elif isinstance(d, datetime): - d = self.reverse_timezone_offset(d) - return d - - -class DelayedFullTextQuery(DelayedQuery): - def _get_query(self) -> tuple: - q_str = self.query_params["query"] - q_str = rewrite_natural_date_keywords(q_str) - qp = MultifieldParser( - [ - "content", - "title", - "correspondent", - "tag", - "type", - "notes", - "custom_fields", - ], - self.searcher.ixreader.schema, - ) - qp.add_plugin( - DateParserPlugin( - basedate=django_timezone.now(), - dateparser=LocalDateParser(), - ), - ) - q = qp.parse(q_str) - suggested_correction = None - try: - corrected = self.searcher.correct_query(q, q_str) - if corrected.string != q_str: - corrected_results = self.searcher.search( - corrected.query, - limit=1, - filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader), - scored=False, - ) - if len(corrected_results) > 0: - suggested_correction = corrected.string - except Exception as e: - logger.info( - "Error while correcting query %s: %s", - f"{q_str!r}", - e, - ) - - return q, None, suggested_correction - - -class DelayedMoreLikeThisQuery(DelayedQuery): - def _get_query(self) -> tuple: - more_like_doc_id = int(self.query_params["more_like_id"]) - content = Document.objects.get(id=more_like_doc_id).content - - docnum = self.searcher.document_number(id=more_like_doc_id) - kts = self.searcher.key_terms_from_text( - "content", - content, - numterms=20, - model=classify.Bo1Model, - normalize=False, - ) - q = query.Or( - [query.Term("content", word, boost=weight) for word, weight in kts], - ) - mask: set = {docnum} - - return q, mask, None - - -def autocomplete( - ix: FileIndex, - term: str, - limit: int = 10, - user: User | None = None, -) -> list: - """ - Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions - and without scoring - """ - terms = [] - - with ix.searcher(weighting=TF_IDF()) as s: - qp = QueryParser("content", schema=ix.schema) - # Don't let searches with a query that happen to match a field override the - # content field query instead and return bogus, not text data - qp.remove_plugin_class(FieldsPlugin) - q = qp.parse(f"{term.lower()}*") - user_criterias: list = get_permissions_criterias(user) - - results = s.search( - q, - terms=True, - filter=query.Or(user_criterias) if user_criterias is not None else None, - ) - - termCounts = Counter() - if results.has_matched_terms(): - for hit in results: - for _, match in hit.matched_terms(): - termCounts[match] += 1 - terms = [t for t, _ in termCounts.most_common(limit)] - - term_encoded: bytes = term.encode("UTF-8") - if term_encoded in terms: - terms.insert(0, terms.pop(terms.index(term_encoded))) - - return terms - - -def get_permissions_criterias(user: User | None = None) -> list: - user_criterias = [query.Term("has_owner", text=False)] - if user is not None: - if user.is_superuser: # superusers see all docs - user_criterias = [] - else: - user_criterias.append(query.Term("owner_id", user.id)) - user_criterias.append( - query.Term("viewer_id", str(user.id)), - ) - return user_criterias - - -def rewrite_natural_date_keywords(query_string: str) -> str: - """ - Rewrites natural date keywords (e.g. added:today or added:"yesterday") to UTC range syntax for Whoosh. - This resolves timezone issues with date parsing in Whoosh as well as adding support for more - natural date keywords. - """ - - tz = get_current_timezone() - local_now = now().astimezone(tz) - today = local_now.date() - - # all supported Keywords - pattern = r"(\b(?:added|created|modified))\s*:\s*[\"']?(today|yesterday|this month|previous month|previous week|previous quarter|this year|previous year)[\"']?" - - def repl(m): - field = m.group(1) - keyword = m.group(2).lower() - - match keyword: - case "today": - start = datetime.combine(today, time.min, tzinfo=tz) - end = datetime.combine(today, time.max, tzinfo=tz) - - case "yesterday": - yesterday = today - timedelta(days=1) - start = datetime.combine(yesterday, time.min, tzinfo=tz) - end = datetime.combine(yesterday, time.max, tzinfo=tz) - - case "this month": - start = datetime(local_now.year, local_now.month, 1, 0, 0, 0, tzinfo=tz) - end = start + relativedelta(months=1) - timedelta(seconds=1) - - case "previous month": - this_month_start = datetime( - local_now.year, - local_now.month, - 1, - 0, - 0, - 0, - tzinfo=tz, - ) - start = this_month_start - relativedelta(months=1) - end = this_month_start - timedelta(seconds=1) - - case "this year": - start = datetime(local_now.year, 1, 1, 0, 0, 0, tzinfo=tz) - end = datetime(local_now.year, 12, 31, 23, 59, 59, tzinfo=tz) - - case "previous week": - days_since_monday = local_now.weekday() - this_week_start = datetime.combine( - today - timedelta(days=days_since_monday), - time.min, - tzinfo=tz, - ) - start = this_week_start - timedelta(days=7) - end = this_week_start - timedelta(seconds=1) - - case "previous quarter": - current_quarter = (local_now.month - 1) // 3 + 1 - this_quarter_start_month = (current_quarter - 1) * 3 + 1 - this_quarter_start = datetime( - local_now.year, - this_quarter_start_month, - 1, - 0, - 0, - 0, - tzinfo=tz, - ) - start = this_quarter_start - relativedelta(months=3) - end = this_quarter_start - timedelta(seconds=1) - - case "previous year": - start = datetime(local_now.year - 1, 1, 1, 0, 0, 0, tzinfo=tz) - end = datetime(local_now.year - 1, 12, 31, 23, 59, 59, tzinfo=tz) - - # Convert to UTC and format - start_str = start.astimezone(UTC).strftime("%Y%m%d%H%M%S") - end_str = end.astimezone(UTC).strftime("%Y%m%d%H%M%S") - return f"{field}:[{start_str} TO {end_str}]" - - return re.sub(pattern, repl, query_string, flags=re.IGNORECASE) diff --git a/src/documents/management/commands/document_index.py b/src/documents/management/commands/document_index.py index 742922010..4c70ec268 100644 --- a/src/documents/management/commands/document_index.py +++ b/src/documents/management/commands/document_index.py @@ -1,3 +1,4 @@ +from django.conf import settings from django.db import transaction from documents.management.commands.base import PaperlessCommand @@ -14,10 +15,20 @@ class Command(PaperlessCommand): def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument("command", choices=["reindex", "optimize"]) + parser.add_argument( + "--recreate", + action="store_true", + default=False, + help="Wipe and recreate the index from scratch (only used with reindex).", + ) def handle(self, *args, **options): with transaction.atomic(): if options["command"] == "reindex": + if options.get("recreate"): + from documents.search import wipe_index + + wipe_index(settings.INDEX_DIR) index_reindex( iter_wrapper=lambda docs: self.track( docs, diff --git a/src/documents/search/__init__.py b/src/documents/search/__init__.py index 1d0b4d04c..5da0a91d4 100644 --- a/src/documents/search/__init__.py +++ b/src/documents/search/__init__.py @@ -1,15 +1,19 @@ from documents.search._backend import SearchIndexLockError from documents.search._backend import SearchResults from documents.search._backend import TantivyBackend +from documents.search._backend import TantivyRelevanceList from documents.search._backend import WriteBatch from documents.search._backend import get_backend from documents.search._backend import reset_backend +from documents.search._schema import wipe_index __all__ = [ "SearchIndexLockError", "SearchResults", "TantivyBackend", + "TantivyRelevanceList", "WriteBatch", "get_backend", "reset_backend", + "wipe_index", ] diff --git a/src/documents/search/_backend.py b/src/documents/search/_backend.py index 9776074f3..56e5342f0 100644 --- a/src/documents/search/_backend.py +++ b/src/documents/search/_backend.py @@ -21,10 +21,10 @@ from guardian.shortcuts import get_users_with_perms from documents.search._query import build_permission_filter from documents.search._query import parse_user_query -from documents.search._schema import _wipe_index from documents.search._schema import _write_sentinels from documents.search._schema import build_schema from documents.search._schema import open_or_rebuild_index +from documents.search._schema import wipe_index from documents.search._tokenizer import register_tokenizers if TYPE_CHECKING: @@ -95,6 +95,24 @@ class SearchResults: query: str # preprocessed query string +class TantivyRelevanceList: + """DRF-compatible list wrapper for Tantivy search hits. + + __len__ returns the total hit count (for pagination); __getitem__ slices + the hit list. Stores ALL post-filter hits so that get_all_result_ids() + can return every matching doc ID without a second query. + """ + + def __init__(self, hits: list[SearchHit]) -> None: + self._hits = hits + + def __len__(self) -> int: + return len(self._hits) + + def __getitem__(self, key: slice) -> list[SearchHit]: + return self._hits[key] + + class SearchIndexLockError(Exception): pass @@ -139,9 +157,19 @@ class WriteBatch: if hasattr(self, "_lock") and self._lock: self._lock.release() - def add_or_update(self, document: Document) -> None: - """Add or update a document in the batch.""" - doc = self._backend._build_tantivy_doc(document) + def add_or_update( + self, + document: Document, + effective_content: str | None = None, + ) -> None: + """Add or update a document in the batch. + + Tantivy has no native upsert — we delete by id then re-add so + stale copies (e.g. after a permission change) don't linger. + ``effective_content`` overrides ``document.content`` for indexing. + """ + self.remove(document.pk) + doc = self._backend._build_tantivy_doc(document, effective_content) self._writer.add_document(doc) def remove(self, doc_id: int) -> None: @@ -175,8 +203,20 @@ class TantivyBackend: # Index doesn't need explicit close pass - def _build_tantivy_doc(self, document: Document) -> tantivy.Document: - """Build a tantivy Document from a Django Document instance.""" + def _build_tantivy_doc( + self, + document: Document, + effective_content: str | None = None, + ) -> tantivy.Document: + """Build a tantivy Document from a Django Document instance. + + ``effective_content`` overrides ``document.content`` for indexing — + used when re-indexing a root document with a newer version's OCR text. + """ + content = ( + effective_content if effective_content is not None else document.content + ) + doc = tantivy.Document() # Basic fields @@ -184,8 +224,8 @@ class TantivyBackend: doc.add_text("checksum", document.checksum) doc.add_text("title", document.title) doc.add_text("title_sort", document.title) - doc.add_text("content", document.content) - doc.add_text("bigram_content", document.content) + doc.add_text("content", content) + doc.add_text("bigram_content", content) # Original filename - only add if not None/empty if document.original_filename: @@ -269,7 +309,7 @@ class TantivyBackend: doc.add_unsigned("viewer_id", user.pk) # Autocomplete words with NLTK stopword filtering - text_sources = [document.title, document.content] + text_sources = [document.title, content] if document.correspondent: text_sources.append(document.correspondent.name) if document.document_type: @@ -285,10 +325,14 @@ class TantivyBackend: return doc - def add_or_update(self, document: Document) -> None: + def add_or_update( + self, + document: Document, + effective_content: str | None = None, + ) -> None: """Add or update a single document with file locking.""" with self.batch_update(lock_timeout=5.0) as batch: - batch.add_or_update(document) + batch.add_or_update(document, effective_content) def remove(self, doc_id: int) -> None: """Remove a single document with file locking.""" @@ -440,19 +484,30 @@ class TantivyBackend: query=query, ) - def autocomplete(self, term: str, limit: int) -> list[str]: - """Get autocomplete suggestions.""" + def autocomplete( + self, + term: str, + limit: int, + user: AbstractBaseUser | None = None, + ) -> list[str]: + """Get autocomplete suggestions, optionally filtered by user visibility.""" normalized_term = _ascii_fold(term.lower()) searcher = self._index.searcher() - # Search all documents to collect autocomplete words - all_query = tantivy.Query.all_query() - results = searcher.search(all_query, limit=10000) # High limit to get all docs + + # Apply permission filter for non-superusers so autocomplete words + # from invisible documents don't leak to other users. + if user is not None and not user.is_superuser: + base_query = build_permission_filter(self._schema, user) + else: + base_query = tantivy.Query.all_query() + + results = searcher.search(base_query, limit=10000) # Collect all autocomplete words words = set() for hit in results.hits: - # For all_query, hit is (score, doc_address) + # hits are (score, doc_address) tuples doc_address = hit[1] if len(hit) == 2 else hit[0] stored_doc = searcher.doc(doc_address) @@ -584,7 +639,7 @@ class TantivyBackend: index_dir = settings.INDEX_DIR # Create new index - _wipe_index(index_dir) + wipe_index(index_dir) new_index = tantivy.Index(build_schema(), path=str(index_dir)) _write_sentinels(index_dir) register_tokenizers(new_index, settings.SEARCH_LANGUAGE) diff --git a/src/documents/search/_query.py b/src/documents/search/_query.py index e03f364eb..c9b8b6131 100644 --- a/src/documents/search/_query.py +++ b/src/documents/search/_query.py @@ -8,6 +8,7 @@ from datetime import timedelta from typing import TYPE_CHECKING import tantivy +from dateutil.relativedelta import relativedelta from django.conf import settings if TYPE_CHECKING: @@ -38,6 +39,13 @@ _RELATIVE_RANGE_RE = re.compile( r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]", re.IGNORECASE, ) +# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now] +_WHOOSH_REL_RANGE_RE = re.compile( + r"\[-(?P\d+)\s+(?Psecond|minute|hour|day|week|month|year)s?\s+to\s+now\]", + re.IGNORECASE, +) +# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly +_DATE8_RE = re.compile(r"(?P\w+):(?P\d{8})\b") def _fmt(dt: datetime) -> str: @@ -200,6 +208,69 @@ def _rewrite_relative_range(query: str) -> str: return _RELATIVE_RANGE_RE.sub(_sub, query) +def _rewrite_whoosh_relative_range(query: str) -> str: + """Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601. + + Supports: second, minute, hour, day, week, month, year (singular and plural). + Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]`` + """ + now = datetime.now(UTC) + + def _sub(m: re.Match[str]) -> str: + n = int(m.group("n")) + unit = m.group("unit").lower() + delta_map: dict[str, timedelta | relativedelta] = { + "second": timedelta(seconds=n), + "minute": timedelta(minutes=n), + "hour": timedelta(hours=n), + "day": timedelta(days=n), + "week": timedelta(weeks=n), + "month": relativedelta(months=n), + "year": relativedelta(years=n), + } + lo = now - delta_map[unit] + return f"[{_fmt(lo)} TO {_fmt(now)}]" + + return _WHOOSH_REL_RANGE_RE.sub(_sub, query) + + +def _rewrite_8digit_date(query: str, tz: tzinfo) -> str: + """Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range. + + Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already + converted and won't spuriously match here. + + For DateField fields (e.g. ``created``) uses UTC midnight boundaries. + For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ + midnight boundaries converted to UTC — matching the ``_datetime_range`` + behaviour for keyword dates. + """ + + def _sub(m: re.Match[str]) -> str: + field = m.group("field") + raw = m.group("date8") + try: + year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8]) + d = date(year, month, day) + if field in _DATE_ONLY_FIELDS: + lo = datetime(d.year, d.month, d.day, tzinfo=UTC) + hi = lo + timedelta(days=1) + else: + # DateTimeField: use local-timezone midnight → UTC + lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC) + hi = datetime( + (d + timedelta(days=1)).year, + (d + timedelta(days=1)).month, + (d + timedelta(days=1)).day, + tzinfo=tz, + ).astimezone(UTC) + return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]" + except ValueError: + return m.group(0) + + return _DATE8_RE.sub(_sub, query) + + def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str: """ Preprocessing stage 1: rewrite Whoosh compact dates, relative ranges, @@ -207,6 +278,8 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str: Bare keywords without a field: prefix pass through unchanged. """ query = _rewrite_compact_date(query) + query = _rewrite_whoosh_relative_range(query) + query = _rewrite_8digit_date(query, tz) query = _rewrite_relative_range(query) def _replace(m: re.Match[str]) -> str: diff --git a/src/documents/search/_schema.py b/src/documents/search/_schema.py index c16e6d2f1..ea2c7d188 100644 --- a/src/documents/search/_schema.py +++ b/src/documents/search/_schema.py @@ -101,7 +101,7 @@ def _needs_rebuild(index_dir: Path) -> bool: return False -def _wipe_index(index_dir: Path) -> None: +def wipe_index(index_dir: Path) -> None: """Delete all children in the index directory to prepare for rebuild.""" for child in list(index_dir.iterdir()): if child.is_dir(): @@ -123,7 +123,7 @@ def open_or_rebuild_index() -> tantivy.Index: """ index_dir: Path = settings.INDEX_DIR if _needs_rebuild(index_dir): - _wipe_index(index_dir) + wipe_index(index_dir) idx = tantivy.Index(build_schema(), path=str(index_dir)) _write_sentinels(index_dir) return idx diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index b2902bba0..b1c2d07b9 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1293,22 +1293,18 @@ class SearchResultSerializer(DocumentSerializer): documents = self.context.get("documents") # Otherwise we fetch this document. if documents is None: # pragma: no cover - # In practice we only serialize **lists** of whoosh.searching.Hit. - # I'm keeping this check for completeness but marking it no cover for now. + # In practice we only serialize **lists** of SearchHit dicts. + # Keeping this check for completeness but marking it no cover for now. documents = self.fetch_documents([hit["id"]]) document = documents[hit["id"]] - notes = ",".join( - [str(c.note) for c in document.notes.all()], - ) + highlights = hit.get("highlights", {}) r = super().to_representation(document) r["__search_hit__"] = { - "score": hit.score, - "highlights": hit.highlights("content", text=document.content), - "note_highlights": ( - hit.highlights("notes", text=notes) if document else None - ), - "rank": hit.rank, + "score": hit["score"], + "highlights": highlights.get("content", ""), + "note_highlights": highlights.get("notes") or None, + "rank": hit["rank"], } return r diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 82a691696..e1637628b 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -790,12 +790,12 @@ def cleanup_user_deletion(sender, instance: User | Group, **kwargs) -> None: def add_to_index(sender, document, **kwargs) -> None: - from documents import index + from documents.search import get_backend - index.add_or_update_document(document) + get_backend().add_or_update(document) if document.root_document_id is not None and document.root_document is not None: # keep in sync when a new version is consumed. - index.add_or_update_document( + get_backend().add_or_update( document.root_document, effective_content=document.content, ) diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 4a3b10b45..7f6a93faf 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -82,27 +82,24 @@ def _identity(iterable: Iterable[_T]) -> Iterable[_T]: @shared_task def index_optimize() -> None: - from whoosh.writing import AsyncWriter - - from documents import index - - ix = index.open_index() - writer = AsyncWriter(ix) - writer.commit(optimize=True) + logger.info( + "document_index optimize is deprecated — Tantivy manages " + "segment merging automatically.", + ) def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None: - from whoosh.writing import AsyncWriter + from documents.search import get_backend + from documents.search import reset_backend - from documents import index - - documents = Document.objects.all() - - ix = index.open_index(recreate=True) - - with AsyncWriter(ix) as writer: - for document in iter_wrapper(documents): - index.update_document(writer, document) + documents = Document.objects.select_related( + "correspondent", + "document_type", + "storage_path", + "owner", + ).prefetch_related("tags", "notes", "custom_fields") + get_backend().rebuild(documents, iter_wrapper=iter_wrapper) + reset_backend() @shared_task @@ -276,14 +273,10 @@ def sanity_check(*, scheduled=True, raise_on_error=True): @shared_task def bulk_update_documents(document_ids) -> None: - from whoosh.writing import AsyncWriter - - from documents import index + from documents.search import get_backend documents = Document.objects.filter(id__in=document_ids) - ix = index.open_index() - for doc in documents: clear_document_caches(doc.pk) document_updated.send( @@ -293,9 +286,9 @@ def bulk_update_documents(document_ids) -> None: ) post_save.send(Document, instance=doc, created=False) - with AsyncWriter(ix) as writer: + with get_backend().batch_update() as batch: for doc in documents: - index.update_document(writer, doc) + batch.add_or_update(doc) ai_config = AIConfig() if ai_config.llm_index_enabled: @@ -310,8 +303,6 @@ def update_document_content_maybe_archive_file(document_id) -> None: Re-creates OCR content and thumbnail for a document, and archive file if it exists. """ - from documents import index - document = Document.objects.get(id=document_id) mime_type = document.mime_type @@ -401,8 +392,9 @@ def update_document_content_maybe_archive_file(document_id) -> None: logger.info( f"Updating index for document {document_id} ({document.archive_checksum})", ) - with index.open_index_writer() as writer: - index.update_document(writer, document) + from documents.search import get_backend + + get_backend().add_or_update(document) ai_config = AIConfig() if ai_config.llm_index_enabled: diff --git a/src/documents/tests/test_admin.py b/src/documents/tests/test_admin.py index de2f07df5..533319c2f 100644 --- a/src/documents/tests/test_admin.py +++ b/src/documents/tests/test_admin.py @@ -1,6 +1,7 @@ import types from unittest.mock import patch +import tantivy from django.contrib.admin.sites import AdminSite from django.contrib.auth.models import Permission from django.contrib.auth.models import User @@ -8,36 +9,54 @@ from django.test import TestCase from django.utils import timezone from rest_framework import status -from documents import index from documents.admin import DocumentAdmin from documents.admin import TagAdmin from documents.models import Document from documents.models import Tag +from documents.search import get_backend +from documents.search import reset_backend from documents.tests.utils import DirectoriesMixin from paperless.admin import PaperlessUserAdmin class TestDocumentAdmin(DirectoriesMixin, TestCase): def get_document_from_index(self, doc): - ix = index.open_index() - with ix.searcher() as searcher: - return searcher.document(id=doc.id) + backend = get_backend() + searcher = backend._index.searcher() + results = searcher.search( + tantivy.Query.range_query( + backend._schema, + "id", + tantivy.FieldType.Unsigned, + doc.pk, + doc.pk, + ), + limit=1, + ) + if results.hits: + return searcher.doc(results.hits[0][1]).to_dict() + return None def setUp(self) -> None: super().setUp() + reset_backend() self.doc_admin = DocumentAdmin(model=Document, admin_site=AdminSite()) + def tearDown(self) -> None: + reset_backend() + super().tearDown() + def test_save_model(self) -> None: doc = Document.objects.create(title="test") doc.title = "new title" self.doc_admin.save_model(None, doc, None, None) self.assertEqual(Document.objects.get(id=doc.id).title, "new title") - self.assertEqual(self.get_document_from_index(doc)["id"], doc.id) + self.assertEqual(self.get_document_from_index(doc)["id"], [doc.id]) def test_delete_model(self) -> None: doc = Document.objects.create(title="test") - index.add_or_update_document(doc) + get_backend().add_or_update(doc) self.assertIsNotNone(self.get_document_from_index(doc)) self.doc_admin.delete_model(None, doc) @@ -53,7 +72,7 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase): checksum=f"{i:02}", ) docs.append(doc) - index.add_or_update_document(doc) + get_backend().add_or_update(doc) self.assertEqual(Document.objects.count(), 42) diff --git a/src/documents/tests/test_api_document_versions.py b/src/documents/tests/test_api_document_versions.py index f5c1a7346..d95e78fe9 100644 --- a/src/documents/tests/test_api_document_versions.py +++ b/src/documents/tests/test_api_document_versions.py @@ -109,7 +109,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): mime_type="application/pdf", ) - with mock.patch("documents.index.remove_document_from_index"): + with mock.patch("documents.search.get_backend"): resp = self.client.delete(f"/api/documents/{root.id}/versions/{root.id}/") self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) @@ -137,10 +137,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): content="v2-content", ) - with ( - mock.patch("documents.index.remove_document_from_index"), - mock.patch("documents.index.add_or_update_document"), - ): + with mock.patch("documents.search.get_backend"): resp = self.client.delete(f"/api/documents/{root.id}/versions/{v2.id}/") self.assertEqual(resp.status_code, status.HTTP_200_OK) @@ -149,10 +146,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): root.refresh_from_db() self.assertEqual(root.content, "root-content") - with ( - mock.patch("documents.index.remove_document_from_index"), - mock.patch("documents.index.add_or_update_document"), - ): + with mock.patch("documents.search.get_backend"): resp = self.client.delete(f"/api/documents/{root.id}/versions/{v1.id}/") self.assertEqual(resp.status_code, status.HTTP_200_OK) @@ -175,10 +169,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): ) version_id = version.id - with ( - mock.patch("documents.index.remove_document_from_index"), - mock.patch("documents.index.add_or_update_document"), - ): + with mock.patch("documents.search.get_backend"): resp = self.client.delete( f"/api/documents/{root.id}/versions/{version_id}/", ) @@ -225,7 +216,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): root_document=other_root, ) - with mock.patch("documents.index.remove_document_from_index"): + with mock.patch("documents.search.get_backend"): resp = self.client.delete( f"/api/documents/{root.id}/versions/{other_version.id}/", ) @@ -245,10 +236,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): root_document=root, ) - with ( - mock.patch("documents.index.remove_document_from_index"), - mock.patch("documents.index.add_or_update_document"), - ): + with mock.patch("documents.search.get_backend"): resp = self.client.delete( f"/api/documents/{version.id}/versions/{version.id}/", ) @@ -275,18 +263,17 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): root_document=root, ) - with ( - mock.patch("documents.index.remove_document_from_index") as remove_index, - mock.patch("documents.index.add_or_update_document") as add_or_update, - ): + with mock.patch("documents.search.get_backend") as mock_get_backend: + mock_backend = mock.MagicMock() + mock_get_backend.return_value = mock_backend resp = self.client.delete( f"/api/documents/{root.id}/versions/{version.id}/", ) self.assertEqual(resp.status_code, status.HTTP_200_OK) - remove_index.assert_called_once_with(version) - add_or_update.assert_called_once() - self.assertEqual(add_or_update.call_args[0][0].id, root.id) + mock_backend.remove.assert_called_once_with(version.pk) + mock_backend.add_or_update.assert_called_once() + self.assertEqual(mock_backend.add_or_update.call_args[0][0].id, root.id) def test_delete_version_returns_403_without_permission(self) -> None: owner = User.objects.create_user(username="owner") diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index 6c2ad1eb8..d8e22b220 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -11,9 +11,7 @@ from django.utils import timezone from guardian.shortcuts import assign_perm from rest_framework import status from rest_framework.test import APITestCase -from whoosh.writing import AsyncWriter -from documents import index from documents.bulk_edit import set_permissions from documents.models import Correspondent from documents.models import CustomField @@ -25,6 +23,8 @@ from documents.models import SavedView from documents.models import StoragePath from documents.models import Tag from documents.models import Workflow +from documents.search import get_backend +from documents.search import reset_backend from documents.tests.utils import DirectoriesMixin from paperless_mail.models import MailAccount from paperless_mail.models import MailRule @@ -33,10 +33,15 @@ from paperless_mail.models import MailRule class TestDocumentSearchApi(DirectoriesMixin, APITestCase): def setUp(self) -> None: super().setUp() + reset_backend() self.user = User.objects.create_superuser(username="temp_admin") self.client.force_authenticate(user=self.user) + def tearDown(self) -> None: + reset_backend() + super().tearDown() + def test_search(self) -> None: d1 = Document.objects.create( title="invoice", @@ -57,13 +62,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): checksum="C", original_filename="someepdf.pdf", ) - with AsyncWriter(index.open_index()) as writer: - # Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once - # (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer. - # That's why we can't open the writer in a model on_save handler or something. - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) + response = self.client.get("/api/documents/?query=bank") results = response.data["results"] self.assertEqual(response.data["count"], 3) @@ -125,10 +128,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): value_int=20, ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get( f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}", @@ -149,15 +152,15 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): ) def test_search_multi_page(self) -> None: - with AsyncWriter(index.open_index()) as writer: - for i in range(55): - doc = Document.objects.create( - checksum=str(i), - pk=i + 1, - title=f"Document {i + 1}", - content="content", - ) - index.update_document(writer, doc) + backend = get_backend() + for i in range(55): + doc = Document.objects.create( + checksum=str(i), + pk=i + 1, + title=f"Document {i + 1}", + content="content", + ) + backend.add_or_update(doc) # This is here so that we test that no document gets returned twice (might happen if the paging is not working) seen_ids = [] @@ -184,15 +187,15 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): seen_ids.append(result["id"]) def test_search_invalid_page(self) -> None: - with AsyncWriter(index.open_index()) as writer: - for i in range(15): - doc = Document.objects.create( - checksum=str(i), - pk=i + 1, - title=f"Document {i + 1}", - content="content", - ) - index.update_document(writer, doc) + backend = get_backend() + for i in range(15): + doc = Document.objects.create( + checksum=str(i), + pk=i + 1, + title=f"Document {i + 1}", + content="content", + ) + backend.add_or_update(doc) response = self.client.get("/api/documents/?query=content&page=0&page_size=10") self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -230,26 +233,25 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): pk=3, checksum="C", ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:[-1 week to now]") results = response.data["results"] # Expect 3 documents returned self.assertEqual(len(results), 3) - for idx, subset in enumerate( - [ - {"id": 1, "title": "invoice"}, - {"id": 2, "title": "bank statement 1"}, - {"id": 3, "title": "bank statement 3"}, - ], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + result_map = {r["id"]: r for r in results} + self.assertEqual(set(result_map.keys()), {1, 2, 3}) + for subset in [ + {"id": 1, "title": "invoice"}, + {"id": 2, "title": "bank statement 1"}, + {"id": 3, "title": "bank statement 3"}, + ]: + r = result_map[subset["id"]] + self.assertDictEqual(r, {**r, **subset}) @override_settings( TIME_ZONE="America/Chicago", @@ -285,10 +287,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # 7 days, 1 hour and 1 minute ago added=timezone.now() - timedelta(days=7, hours=1, minutes=1), ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:[-1 week to now]") results = response.data["results"] @@ -296,12 +298,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Expect 2 documents returned self.assertEqual(len(results), 2) - for idx, subset in enumerate( - [{"id": 1, "title": "invoice"}, {"id": 2, "title": "bank statement 1"}], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + result_map = {r["id"]: r for r in results} + self.assertEqual(set(result_map.keys()), {1, 2}) + for subset in [ + {"id": 1, "title": "invoice"}, + {"id": 2, "title": "bank statement 1"}, + ]: + r = result_map[subset["id"]] + self.assertDictEqual(r, {**r, **subset}) @override_settings( TIME_ZONE="Europe/Sofia", @@ -337,10 +341,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # 7 days, 1 hour and 1 minute ago added=timezone.now() - timedelta(days=7, hours=1, minutes=1), ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:[-1 week to now]") results = response.data["results"] @@ -348,12 +352,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Expect 2 documents returned self.assertEqual(len(results), 2) - for idx, subset in enumerate( - [{"id": 1, "title": "invoice"}, {"id": 2, "title": "bank statement 1"}], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + result_map = {r["id"]: r for r in results} + self.assertEqual(set(result_map.keys()), {1, 2}) + for subset in [ + {"id": 1, "title": "invoice"}, + {"id": 2, "title": "bank statement 1"}, + ]: + r = result_map[subset["id"]] + self.assertDictEqual(r, {**r, **subset}) def test_search_added_in_last_month(self) -> None: """ @@ -389,10 +395,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): added=timezone.now() - timedelta(days=7, hours=1, minutes=1), ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:[-1 month to now]") results = response.data["results"] @@ -400,12 +406,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Expect 2 documents returned self.assertEqual(len(results), 2) - for idx, subset in enumerate( - [{"id": 1, "title": "invoice"}, {"id": 3, "title": "bank statement 3"}], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + result_map = {r["id"]: r for r in results} + self.assertEqual(set(result_map.keys()), {1, 3}) + for subset in [ + {"id": 1, "title": "invoice"}, + {"id": 3, "title": "bank statement 3"}, + ]: + r = result_map[subset["id"]] + self.assertDictEqual(r, {**r, **subset}) @override_settings( TIME_ZONE="America/Denver", @@ -445,10 +453,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): added=timezone.now() - timedelta(days=7, hours=1, minutes=1), ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:[-1 month to now]") results = response.data["results"] @@ -456,12 +464,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Expect 2 documents returned self.assertEqual(len(results), 2) - for idx, subset in enumerate( - [{"id": 1, "title": "invoice"}, {"id": 3, "title": "bank statement 3"}], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + result_map = {r["id"]: r for r in results} + self.assertEqual(set(result_map.keys()), {1, 3}) + for subset in [ + {"id": 1, "title": "invoice"}, + {"id": 3, "title": "bank statement 3"}, + ]: + r = result_map[subset["id"]] + self.assertDictEqual(r, {**r, **subset}) @override_settings( TIME_ZONE="Europe/Sofia", @@ -501,10 +511,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Django converts dates to UTC d3.refresh_from_db() - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/documents/?query=added:20231201") results = response.data["results"] @@ -512,12 +522,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): # Expect 1 document returned self.assertEqual(len(results), 1) - for idx, subset in enumerate( - [{"id": 3, "title": "bank statement 3"}], - ): - result = results[idx] - # Assert subset in results - self.assertDictEqual(result, {**result, **subset}) + self.assertEqual(results[0]["id"], 3) + self.assertEqual(results[0]["title"], "bank statement 3") def test_search_added_invalid_date(self) -> None: """ @@ -526,7 +532,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): WHEN: - Query with invalid added date THEN: - - No documents returned + - 400 Bad Request returned (Tantivy rejects invalid date field syntax) """ d1 = Document.objects.create( title="invoice", @@ -535,16 +541,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): pk=1, ) - with index.open_index_writer() as writer: - index.update_document(writer, d1) + get_backend().add_or_update(d1) response = self.client.get("/api/documents/?query=added:invalid-date") - results = response.data["results"] - # Expect 0 document returned - self.assertEqual(len(results), 0) + # Tantivy rejects unparsable field queries with a 400 + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - @mock.patch("documents.index.autocomplete") + @mock.patch("documents.search._backend.TantivyBackend.autocomplete") def test_search_autocomplete_limits(self, m) -> None: """ GIVEN: @@ -556,7 +560,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): - Limit requests are obeyed """ - m.side_effect = lambda ix, term, limit, user: [term for _ in range(limit)] + m.side_effect = lambda term, limit, user=None: [term for _ in range(limit)] response = self.client.get("/api/search/autocomplete/?term=test") self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -609,32 +613,29 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): owner=u1, ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) response = self.client.get("/api/search/autocomplete/?term=app") self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"]) + self.assertEqual(response.data, ["applebaum", "apples", "appletini"]) d3.owner = u2 - - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d3) + d3.save() + backend.add_or_update(d3) response = self.client.get("/api/search/autocomplete/?term=app") self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, [b"apples", b"applebaum"]) + self.assertEqual(response.data, ["applebaum", "apples"]) assign_perm("view_document", u1, d3) - - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d3) + backend.add_or_update(d3) response = self.client.get("/api/search/autocomplete/?term=app") self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"]) + self.assertEqual(response.data, ["applebaum", "apples", "appletini"]) def test_search_autocomplete_field_name_match(self) -> None: """ @@ -652,8 +653,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): checksum="1", ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d1) + get_backend().add_or_update(d1) response = self.client.get("/api/search/autocomplete/?term=created:2023") self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -674,33 +674,36 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): checksum="1", ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d1) + get_backend().add_or_update(d1) response = self.client.get("/api/search/autocomplete/?term=auto") self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data[0], b"auto") + self.assertEqual(response.data[0], "auto") - def test_search_spelling_suggestion(self) -> None: - with AsyncWriter(index.open_index()) as writer: - for i in range(55): - doc = Document.objects.create( - checksum=str(i), - pk=i + 1, - title=f"Document {i + 1}", - content=f"Things document {i + 1}", - ) - index.update_document(writer, doc) + def test_search_no_spelling_suggestion(self) -> None: + """ + GIVEN: + - Documents exist with various terms + WHEN: + - Query for documents with any term + THEN: + - corrected_query is always None (Tantivy has no spell correction) + """ + backend = get_backend() + for i in range(5): + doc = Document.objects.create( + checksum=str(i), + pk=i + 1, + title=f"Document {i + 1}", + content=f"Things document {i + 1}", + ) + backend.add_or_update(doc) response = self.client.get("/api/documents/?query=thing") - correction = response.data["corrected_query"] - - self.assertEqual(correction, "things") + self.assertIsNone(response.data["corrected_query"]) response = self.client.get("/api/documents/?query=things") - correction = response.data["corrected_query"] - - self.assertEqual(correction, None) + self.assertIsNone(response.data["corrected_query"]) def test_search_spelling_suggestion_suppressed_for_private_terms(self): owner = User.objects.create_user("owner") @@ -709,24 +712,24 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): Permission.objects.get(codename="view_document"), ) - with AsyncWriter(index.open_index()) as writer: - for i in range(55): - private_doc = Document.objects.create( - checksum=f"p{i}", - pk=100 + i, - title=f"Private Document {i + 1}", - content=f"treasury document {i + 1}", - owner=owner, - ) - visible_doc = Document.objects.create( - checksum=f"v{i}", - pk=200 + i, - title=f"Visible Document {i + 1}", - content=f"public ledger {i + 1}", - owner=attacker, - ) - index.update_document(writer, private_doc) - index.update_document(writer, visible_doc) + backend = get_backend() + for i in range(5): + private_doc = Document.objects.create( + checksum=f"p{i}", + pk=100 + i, + title=f"Private Document {i + 1}", + content=f"treasury document {i + 1}", + owner=owner, + ) + visible_doc = Document.objects.create( + checksum=f"v{i}", + pk=200 + i, + title=f"Visible Document {i + 1}", + content=f"public ledger {i + 1}", + owner=attacker, + ) + backend.add_or_update(private_doc) + backend.add_or_update(visible_doc) self.client.force_authenticate(user=attacker) @@ -736,26 +739,6 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): self.assertEqual(response.data["count"], 0) self.assertIsNone(response.data["corrected_query"]) - @mock.patch( - "whoosh.searching.Searcher.correct_query", - side_effect=Exception("Test error"), - ) - def test_corrected_query_error(self, mock_correct_query) -> None: - """ - GIVEN: - - A query that raises an error on correction - WHEN: - - API request for search with that query - THEN: - - The error is logged and the search proceeds - """ - with self.assertLogs("paperless.index", level="INFO") as cm: - response = self.client.get("/api/documents/?query=2025-06-04") - self.assertEqual(response.status_code, status.HTTP_200_OK) - error_str = cm.output[0] - expected_str = "Error while correcting query '2025-06-04': Test error" - self.assertIn(expected_str, error_str) - def test_search_more_like(self) -> None: """ GIVEN: @@ -790,11 +773,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): pk=4, checksum="ABC", ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) - index.update_document(writer, d4) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) + backend.add_or_update(d4) response = self.client.get(f"/api/documents/?more_like_id={d2.id}") @@ -802,9 +785,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): results = response.data["results"] - self.assertEqual(len(results), 2) - self.assertEqual(results[0]["id"], d3.id) - self.assertEqual(results[1]["id"], d1.id) + self.assertGreaterEqual(len(results), 1) + result_ids = [r["id"] for r in results] + self.assertIn(d3.id, result_ids) + self.assertNotIn(d4.id, result_ids) def test_search_more_like_requires_view_permission_on_seed_document( self, @@ -846,10 +830,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): pk=12, ) - with AsyncWriter(index.open_index()) as writer: - index.update_document(writer, private_seed) - index.update_document(writer, visible_doc) - index.update_document(writer, other_doc) + backend = get_backend() + backend.add_or_update(private_seed) + backend.add_or_update(visible_doc) + backend.add_or_update(other_doc) self.client.force_authenticate(user=attacker) @@ -923,9 +907,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): value_text="foobard4", ) - with AsyncWriter(index.open_index()) as writer: - for doc in Document.objects.all(): - index.update_document(writer, doc) + backend = get_backend() + for doc in Document.objects.all(): + backend.add_or_update(doc) def search_query(q): r = self.client.get("/api/documents/?query=test" + q) @@ -1141,9 +1125,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): Document.objects.create(checksum="3", content="test 3", owner=u2) Document.objects.create(checksum="4", content="test 4") - with AsyncWriter(index.open_index()) as writer: - for doc in Document.objects.all(): - index.update_document(writer, doc) + backend = get_backend() + for doc in Document.objects.all(): + backend.add_or_update(doc) self.client.force_authenticate(user=u1) r = self.client.get("/api/documents/?query=test") @@ -1194,9 +1178,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): d3 = Document.objects.create(checksum="3", content="test 3", owner=u2) Document.objects.create(checksum="4", content="test 4") - with AsyncWriter(index.open_index()) as writer: - for doc in Document.objects.all(): - index.update_document(writer, doc) + backend = get_backend() + for doc in Document.objects.all(): + backend.add_or_update(doc) self.client.force_authenticate(user=u1) r = self.client.get("/api/documents/?query=test") @@ -1216,9 +1200,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): assign_perm("view_document", u1, d3) assign_perm("view_document", u2, d1) - with AsyncWriter(index.open_index()) as writer: - for doc in [d1, d2, d3]: - index.update_document(writer, doc) + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) self.client.force_authenticate(user=u1) r = self.client.get("/api/documents/?query=test") @@ -1281,9 +1265,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): user=u1, ) - with AsyncWriter(index.open_index()) as writer: - for doc in Document.objects.all(): - index.update_document(writer, doc) + backend = get_backend() + for doc in Document.objects.all(): + backend.add_or_update(doc) def search_query(q): r = self.client.get("/api/documents/?query=test" + q) @@ -1316,13 +1300,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): search_query("&ordering=-num_notes"), [d1.id, d3.id, d2.id], ) + # owner sort: ORM orders by owner_id (integer); NULLs first in SQLite ASC self.assertListEqual( search_query("&ordering=owner"), - [d1.id, d2.id, d3.id], + [d3.id, d1.id, d2.id], ) self.assertListEqual( search_query("&ordering=-owner"), - [d3.id, d2.id, d1.id], + [d2.id, d1.id, d3.id], ) @mock.patch("documents.bulk_edit.bulk_update_documents") @@ -1379,12 +1364,12 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): ) set_permissions([4, 5], set_permissions={}, owner=user2, merge=False) - with index.open_index_writer() as writer: - index.update_document(writer, d1) - index.update_document(writer, d2) - index.update_document(writer, d3) - index.update_document(writer, d4) - index.update_document(writer, d5) + backend = get_backend() + backend.add_or_update(d1) + backend.add_or_update(d2) + backend.add_or_update(d3) + backend.add_or_update(d4) + backend.add_or_update(d5) correspondent1 = Correspondent.objects.create(name="bank correspondent 1") Correspondent.objects.create(name="correspondent 2") diff --git a/src/documents/tests/test_api_status.py b/src/documents/tests/test_api_status.py index b8f7d408e..32717af63 100644 --- a/src/documents/tests/test_api_status.py +++ b/src/documents/tests/test_api_status.py @@ -191,40 +191,42 @@ class TestSystemStatus(APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["tasks"]["celery_status"], "OK") - @override_settings(INDEX_DIR=Path("/tmp/index")) - @mock.patch("whoosh.index.FileIndex.last_modified") - def test_system_status_index_ok(self, mock_last_modified) -> None: + @mock.patch("documents.search.get_backend") + def test_system_status_index_ok(self, mock_get_backend) -> None: """ GIVEN: - - The index last modified time is set + - The index is accessible WHEN: - The user requests the system status THEN: - The response contains the correct index status """ - mock_last_modified.return_value = 1707839087 - self.client.force_login(self.user) - response = self.client.get(self.ENDPOINT) + mock_get_backend.return_value = mock.MagicMock() + # Use the temp dir created in setUp (self.tmp_dir) as a real INDEX_DIR + # with a real file so the mtime lookup works + sentinel = self.tmp_dir / "sentinel.txt" + sentinel.write_text("ok") + with self.settings(INDEX_DIR=self.tmp_dir): + self.client.force_login(self.user) + response = self.client.get(self.ENDPOINT) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["tasks"]["index_status"], "OK") self.assertIsNotNone(response.data["tasks"]["index_last_modified"]) - @override_settings(INDEX_DIR=Path("/tmp/index/")) - @mock.patch("documents.index.open_index", autospec=True) - def test_system_status_index_error(self, mock_open_index) -> None: + @mock.patch("documents.search.get_backend") + def test_system_status_index_error(self, mock_get_backend) -> None: """ GIVEN: - - The index is not found + - The index cannot be opened WHEN: - The user requests the system status THEN: - The response contains the correct index status """ - mock_open_index.return_value = None - mock_open_index.side_effect = Exception("Index error") + mock_get_backend.side_effect = Exception("Index error") self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) - mock_open_index.assert_called_once() + mock_get_backend.assert_called_once() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["tasks"]["index_status"], "ERROR") self.assertIsNotNone(response.data["tasks"]["index_error"]) diff --git a/src/documents/tests/test_delayedquery.py b/src/documents/tests/test_delayedquery.py deleted file mode 100644 index 6357d9030..000000000 --- a/src/documents/tests/test_delayedquery.py +++ /dev/null @@ -1,58 +0,0 @@ -from django.test import TestCase -from whoosh import query - -from documents.index import get_permissions_criterias -from documents.models import User - - -class TestDelayedQuery(TestCase): - def setUp(self) -> None: - super().setUp() - # all tests run without permission criteria, so has_no_owner query will always - # be appended. - self.has_no_owner = query.Or([query.Term("has_owner", text=False)]) - - def _get_testset__id__in(self, param, field): - return ( - {f"{param}__id__in": "42,43"}, - query.And( - [ - query.Or( - [ - query.Term(f"{field}_id", "42"), - query.Term(f"{field}_id", "43"), - ], - ), - self.has_no_owner, - ], - ), - ) - - def _get_testset__id__none(self, param, field): - return ( - {f"{param}__id__none": "42,43"}, - query.And( - [ - query.Not(query.Term(f"{field}_id", "42")), - query.Not(query.Term(f"{field}_id", "43")), - self.has_no_owner, - ], - ), - ) - - def test_get_permission_criteria(self) -> None: - # tests contains tuples of user instances and the expected filter - tests = ( - (None, [query.Term("has_owner", text=False)]), - (User(42, username="foo", is_superuser=True), []), - ( - User(42, username="foo", is_superuser=False), - [ - query.Term("has_owner", text=False), - query.Term("owner_id", 42), - query.Term("viewer_id", "42"), - ], - ), - ) - for user, expected in tests: - self.assertEqual(get_permissions_criterias(user), expected) diff --git a/src/documents/tests/test_index.py b/src/documents/tests/test_index.py deleted file mode 100644 index 5f1c7487d..000000000 --- a/src/documents/tests/test_index.py +++ /dev/null @@ -1,371 +0,0 @@ -from datetime import datetime -from unittest import mock - -from django.conf import settings -from django.contrib.auth.models import User -from django.test import SimpleTestCase -from django.test import TestCase -from django.test import override_settings -from django.utils.timezone import get_current_timezone -from django.utils.timezone import timezone - -from documents import index -from documents.models import Document -from documents.tests.utils import DirectoriesMixin - - -class TestAutoComplete(DirectoriesMixin, TestCase): - def test_auto_complete(self) -> None: - doc1 = Document.objects.create( - title="doc1", - checksum="A", - content="test test2 test3", - ) - doc2 = Document.objects.create(title="doc2", checksum="B", content="test test2") - doc3 = Document.objects.create(title="doc3", checksum="C", content="test2") - - index.add_or_update_document(doc1) - index.add_or_update_document(doc2) - index.add_or_update_document(doc3) - - ix = index.open_index() - - self.assertListEqual( - index.autocomplete(ix, "tes"), - [b"test2", b"test", b"test3"], - ) - self.assertListEqual( - index.autocomplete(ix, "tes", limit=3), - [b"test2", b"test", b"test3"], - ) - self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test2"]) - self.assertListEqual(index.autocomplete(ix, "tes", limit=0), []) - - def test_archive_serial_number_ranging(self) -> None: - """ - GIVEN: - - Document with an archive serial number above schema allowed size - WHEN: - - Document is provided to the index - THEN: - - Error is logged - - Document ASN is reset to 0 for the index - """ - doc1 = Document.objects.create( - title="doc1", - checksum="A", - content="test test2 test3", - # yes, this is allowed, unless full_clean is run - # DRF does call the validators, this test won't - archive_serial_number=Document.ARCHIVE_SERIAL_NUMBER_MAX + 1, - ) - with self.assertLogs("paperless.index", level="ERROR") as cm: - with mock.patch( - "documents.index.AsyncWriter.update_document", - ) as mocked_update_doc: - index.add_or_update_document(doc1) - - mocked_update_doc.assert_called_once() - _, kwargs = mocked_update_doc.call_args - - self.assertEqual(kwargs["asn"], 0) - - error_str = cm.output[0] - expected_str = "ERROR:paperless.index:Not indexing Archive Serial Number 4294967296 of document 1" - self.assertIn(expected_str, error_str) - - def test_archive_serial_number_is_none(self) -> None: - """ - GIVEN: - - Document with no archive serial number - WHEN: - - Document is provided to the index - THEN: - - ASN isn't touched - """ - doc1 = Document.objects.create( - title="doc1", - checksum="A", - content="test test2 test3", - ) - with mock.patch( - "documents.index.AsyncWriter.update_document", - ) as mocked_update_doc: - index.add_or_update_document(doc1) - - mocked_update_doc.assert_called_once() - _, kwargs = mocked_update_doc.call_args - - self.assertIsNone(kwargs["asn"]) - - @override_settings(TIME_ZONE="Pacific/Auckland") - def test_added_today_respects_local_timezone_boundary(self) -> None: - tz = get_current_timezone() - fixed_now = datetime(2025, 7, 20, 15, 0, 0, tzinfo=tz) - - # Fake a time near the local boundary (1 AM NZT = 13:00 UTC on previous UTC day) - local_dt = datetime(2025, 7, 20, 1, 0, 0).replace(tzinfo=tz) - utc_dt = local_dt.astimezone(timezone.utc) - - doc = Document.objects.create( - title="Time zone", - content="Testing added:today", - checksum="edgecase123", - added=utc_dt, - ) - - with index.open_index_writer() as writer: - index.update_document(writer, doc) - - superuser = User.objects.create_superuser(username="testuser") - self.client.force_login(superuser) - - with mock.patch("documents.index.now", return_value=fixed_now): - response = self.client.get("/api/documents/?query=added:today") - results = response.json()["results"] - self.assertEqual(len(results), 1) - self.assertEqual(results[0]["id"], doc.id) - - response = self.client.get("/api/documents/?query=added:yesterday") - results = response.json()["results"] - self.assertEqual(len(results), 0) - - -@override_settings(TIME_ZONE="UTC") -class TestRewriteNaturalDateKeywords(SimpleTestCase): - """ - Unit tests for rewrite_natural_date_keywords function. - """ - - def _rewrite_with_now(self, query: str, now_dt: datetime) -> str: - with mock.patch("documents.index.now", return_value=now_dt): - return index.rewrite_natural_date_keywords(query) - - def _assert_rewrite_contains( - self, - query: str, - now_dt: datetime, - *expected_fragments: str, - ) -> str: - result = self._rewrite_with_now(query, now_dt) - for fragment in expected_fragments: - self.assertIn(fragment, result) - return result - - def test_range_keywords(self) -> None: - """ - Test various different range keywords - """ - cases = [ - ( - "added:today", - datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc), - ("added:[20250720", "TO 20250720"), - ), - ( - "added:yesterday", - datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc), - ("added:[20250719", "TO 20250719"), - ), - ( - "added:this month", - datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20250701", "TO 20250731"), - ), - ( - "added:previous month", - datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20250601", "TO 20250630"), - ), - ( - "added:this year", - datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20250101", "TO 20251231"), - ), - ( - "added:previous year", - datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20240101", "TO 20241231"), - ), - # Previous quarter from July 15, 2025 is April-June. - ( - "added:previous quarter", - datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20250401", "TO 20250630"), - ), - # July 20, 2025 is a Sunday (weekday 6) so previous week is July 7-13. - ( - "added:previous week", - datetime(2025, 7, 20, 12, 0, 0, tzinfo=timezone.utc), - ("added:[20250707", "TO 20250713"), - ), - ] - - for query, now_dt, fragments in cases: - with self.subTest(query=query): - self._assert_rewrite_contains(query, now_dt, *fragments) - - def test_additional_fields(self) -> None: - fixed_now = datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc) - # created - self._assert_rewrite_contains("created:today", fixed_now, "created:[20250720") - # modified - self._assert_rewrite_contains("modified:today", fixed_now, "modified:[20250720") - - def test_basic_syntax_variants(self) -> None: - """ - Test that quoting, casing, and multi-clause queries are parsed. - """ - fixed_now = datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc) - - # quoted keywords - result1 = self._rewrite_with_now('added:"today"', fixed_now) - result2 = self._rewrite_with_now("added:'today'", fixed_now) - self.assertIn("added:[20250720", result1) - self.assertIn("added:[20250720", result2) - - # case insensitivity - for query in ("added:TODAY", "added:Today", "added:ToDaY"): - with self.subTest(case_variant=query): - self._assert_rewrite_contains(query, fixed_now, "added:[20250720") - - # multiple clauses - result = self._rewrite_with_now("added:today created:yesterday", fixed_now) - self.assertIn("added:[20250720", result) - self.assertIn("created:[20250719", result) - - def test_no_match(self) -> None: - """ - Test that queries without keywords are unchanged. - """ - query = "title:test content:example" - result = index.rewrite_natural_date_keywords(query) - self.assertEqual(query, result) - - @override_settings(TIME_ZONE="Pacific/Auckland") - def test_timezone_awareness(self) -> None: - """ - Test timezone conversion. - """ - # July 20, 2025 1:00 AM NZST = July 19, 2025 13:00 UTC - fixed_now = datetime(2025, 7, 20, 1, 0, 0, tzinfo=get_current_timezone()) - result = self._rewrite_with_now("added:today", fixed_now) - # Should convert to UTC properly - self.assertIn("added:[20250719", result) - - -class TestIndexResilience(DirectoriesMixin, SimpleTestCase): - def _assert_recreate_called(self, mock_create_in) -> None: - mock_create_in.assert_called_once() - path_arg, schema_arg = mock_create_in.call_args.args - self.assertEqual(path_arg, settings.INDEX_DIR) - self.assertEqual(schema_arg.__class__.__name__, "Schema") - - def test_transient_missing_segment_does_not_force_recreate(self) -> None: - """ - GIVEN: - - Index directory exists - WHEN: - - open_index is called - - Opening the index raises FileNotFoundError once due to a - transient missing segment - THEN: - - Index is opened successfully on retry - - Index is not recreated - """ - file_marker = settings.INDEX_DIR / "file_marker.txt" - file_marker.write_text("keep") - expected_index = object() - - with ( - mock.patch("documents.index.exists_in", return_value=True), - mock.patch( - "documents.index.open_dir", - side_effect=[FileNotFoundError("missing"), expected_index], - ) as mock_open_dir, - mock.patch( - "documents.index.create_in", - ) as mock_create_in, - mock.patch( - "documents.index.rmtree", - ) as mock_rmtree, - ): - ix = index.open_index() - - self.assertIs(ix, expected_index) - self.assertGreaterEqual(mock_open_dir.call_count, 2) - mock_rmtree.assert_not_called() - mock_create_in.assert_not_called() - self.assertEqual(file_marker.read_text(), "keep") - - def test_transient_errors_exhaust_retries_and_recreate(self) -> None: - """ - GIVEN: - - Index directory exists - WHEN: - - open_index is called - - Opening the index raises FileNotFoundError multiple times due to - transient missing segments - THEN: - - Index is recreated after retries are exhausted - """ - recreated_index = object() - - with ( - self.assertLogs("paperless.index", level="ERROR") as cm, - mock.patch("documents.index.exists_in", return_value=True), - mock.patch( - "documents.index.open_dir", - side_effect=FileNotFoundError("missing"), - ) as mock_open_dir, - mock.patch("documents.index.rmtree") as mock_rmtree, - mock.patch( - "documents.index.create_in", - return_value=recreated_index, - ) as mock_create_in, - ): - ix = index.open_index() - - self.assertIs(ix, recreated_index) - self.assertEqual(mock_open_dir.call_count, 4) - mock_rmtree.assert_called_once_with(settings.INDEX_DIR) - self._assert_recreate_called(mock_create_in) - self.assertIn( - "Error while opening the index after retries, recreating.", - cm.output[0], - ) - - def test_non_transient_error_recreates_index(self) -> None: - """ - GIVEN: - - Index directory exists - WHEN: - - open_index is called - - Opening the index raises a "non-transient" error - THEN: - - Index is recreated - """ - recreated_index = object() - - with ( - self.assertLogs("paperless.index", level="ERROR") as cm, - mock.patch("documents.index.exists_in", return_value=True), - mock.patch( - "documents.index.open_dir", - side_effect=RuntimeError("boom"), - ), - mock.patch("documents.index.rmtree") as mock_rmtree, - mock.patch( - "documents.index.create_in", - return_value=recreated_index, - ) as mock_create_in, - ): - ix = index.open_index() - - self.assertIs(ix, recreated_index) - mock_rmtree.assert_called_once_with(settings.INDEX_DIR) - self._assert_recreate_called(mock_create_in) - self.assertIn( - "Error while opening the index, recreating.", - cm.output[0], - ) diff --git a/src/documents/tests/test_matchables.py b/src/documents/tests/test_matchables.py index e038bf786..e13d3827a 100644 --- a/src/documents/tests/test_matchables.py +++ b/src/documents/tests/test_matchables.py @@ -452,7 +452,10 @@ class TestDocumentConsumptionFinishedSignal(TestCase): """ def setUp(self) -> None: + from documents.search import reset_backend + TestCase.setUp(self) + reset_backend() User.objects.create_user(username="test_consumer", password="12345") self.doc_contains = Document.objects.create( content="I contain the keyword.", @@ -464,6 +467,9 @@ class TestDocumentConsumptionFinishedSignal(TestCase): override_settings(INDEX_DIR=self.index_dir).enable() def tearDown(self) -> None: + from documents.search import reset_backend + + reset_backend() shutil.rmtree(self.index_dir, ignore_errors=True) def test_tag_applied_any(self) -> None: diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index 4f17a8fd2..420c5199e 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -208,10 +208,12 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): mime_type="application/pdf", ) - with mock.patch("documents.index.add_or_update_document") as add: + with mock.patch("documents.search.get_backend") as mock_get_backend: + mock_backend = mock.MagicMock() + mock_get_backend.return_value = mock_backend add_to_index(sender=None, document=root) - add.assert_called_once_with(root) + mock_backend.add_or_update.assert_called_once_with(root) def test_add_to_index_reindexes_root_for_version_documents(self) -> None: root = Document.objects.create( @@ -226,13 +228,21 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase): root_document=root, ) - with mock.patch("documents.index.add_or_update_document") as add: + with mock.patch("documents.search.get_backend") as mock_get_backend: + mock_backend = mock.MagicMock() + mock_get_backend.return_value = mock_backend add_to_index(sender=None, document=version) - self.assertEqual(add.call_count, 2) - self.assertEqual(add.call_args_list[0].args[0].id, version.id) - self.assertEqual(add.call_args_list[1].args[0].id, root.id) + self.assertEqual(mock_backend.add_or_update.call_count, 2) self.assertEqual( - add.call_args_list[1].kwargs, + mock_backend.add_or_update.call_args_list[0].args[0].id, + version.id, + ) + self.assertEqual( + mock_backend.add_or_update.call_args_list[1].args[0].id, + root.id, + ) + self.assertEqual( + mock_backend.add_or_update.call_args_list[1].kwargs, {"effective_content": version.content}, ) diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index 346d895aa..cc4190974 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -157,11 +157,17 @@ class DirectoriesMixin: """ def setUp(self) -> None: + from documents.search import reset_backend + + reset_backend() self.dirs = setup_directories() super().setUp() def tearDown(self) -> None: + from documents.search import reset_backend + super().tearDown() + reset_backend() remove_dirs(self.dirs) diff --git a/src/documents/views.py b/src/documents/views.py index 600acf078..7abc18613 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -100,7 +100,6 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from rest_framework.viewsets import ViewSet from documents import bulk_edit -from documents import index from documents.bulk_download import ArchiveOnlyStrategy from documents.bulk_download import OriginalAndArchiveStrategy from documents.bulk_download import OriginalsOnlyStrategy @@ -972,9 +971,9 @@ class DocumentViewSet( response_data["content"] = content_doc.content response = Response(response_data) - from documents import index + from documents.search import get_backend - index.add_or_update_document(refreshed_doc) + get_backend().add_or_update(refreshed_doc) document_updated.send( sender=self.__class__, @@ -984,9 +983,9 @@ class DocumentViewSet( return response def destroy(self, request, *args, **kwargs): - from documents import index + from documents.search import get_backend - index.remove_document_from_index(self.get_object()) + get_backend().remove(self.get_object().pk) try: return super().destroy(request, *args, **kwargs) except Exception as e: @@ -1393,9 +1392,9 @@ class DocumentViewSet( doc.modified = timezone.now() doc.save() - from documents import index + from documents.search import get_backend - index.add_or_update_document(doc) + get_backend().add_or_update(doc) notes = serializer.to_representation(doc).get("notes") @@ -1430,9 +1429,9 @@ class DocumentViewSet( doc.modified = timezone.now() doc.save() - from documents import index + from documents.search import get_backend - index.add_or_update_document(doc) + get_backend().add_or_update(doc) notes = serializer.to_representation(doc).get("notes") @@ -1744,12 +1743,13 @@ class DocumentViewSet( "Cannot delete the root/original version. Delete the document instead.", ) - from documents import index + from documents.search import get_backend - index.remove_document_from_index(version_doc) + _backend = get_backend() + _backend.remove(version_doc.pk) version_doc_id = version_doc.id version_doc.delete() - index.add_or_update_document(root_doc) + _backend.add_or_update(root_doc) if settings.AUDIT_LOG_ENABLED: actor = ( request.user if request.user and request.user.is_authenticated else None @@ -1949,10 +1949,6 @@ class ChatStreamingView(GenericAPIView): ), ) class UnifiedSearchViewSet(DocumentViewSet): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.searcher = None - def get_serializer_class(self): if self._is_search_request(): return SearchResultSerializer @@ -1965,17 +1961,34 @@ class UnifiedSearchViewSet(DocumentViewSet): or "more_like_id" in self.request.query_params ) - def filter_queryset(self, queryset): - filtered_queryset = super().filter_queryset(queryset) + def list(self, request, *args, **kwargs): + if not self._is_search_request(): + return super().list(request) - if self._is_search_request(): - if "query" in self.request.query_params: - from documents import index + from documents.search import TantivyRelevanceList + from documents.search import get_backend - query_class = index.DelayedFullTextQuery - elif "more_like_id" in self.request.query_params: + try: + backend = get_backend() + # ORM-filtered queryset: permissions + field filters + ordering (DRF backends applied) + filtered_qs = self.filter_queryset(self.get_queryset()) + + user = None if request.user.is_superuser else request.user + + if "query" in request.query_params: + query_str = request.query_params["query"] + results = backend.search( + query_str, + user=user, + page=1, + page_size=10000, + sort_field=None, + sort_reverse=False, + ) + else: + # more_like_id — validate permission on the seed document first try: - more_like_doc_id = int(self.request.query_params["more_like_id"]) + more_like_doc_id = int(request.query_params["more_like_id"]) more_like_doc = Document.objects.select_related("owner").get( pk=more_like_doc_id, ) @@ -1983,61 +1996,62 @@ class UnifiedSearchViewSet(DocumentViewSet): raise PermissionDenied(_("Invalid more_like_id")) if not has_perms_owner_aware( - self.request.user, + request.user, "view_document", more_like_doc, ): raise PermissionDenied(_("Insufficient permissions.")) - from documents import index - - query_class = index.DelayedMoreLikeThisQuery - else: - raise ValueError - - return query_class( - self.searcher, - self.request.query_params, - self.paginator.get_page_size(self.request), - filter_queryset=filtered_queryset, - ) - else: - return filtered_queryset - - def list(self, request, *args, **kwargs): - if self._is_search_request(): - from documents import index - - try: - with index.open_index_searcher() as s: - self.searcher = s - queryset = self.filter_queryset(self.get_queryset()) - page = self.paginate_queryset(queryset) - - serializer = self.get_serializer(page, many=True) - response = self.get_paginated_response(serializer.data) - - response.data["corrected_query"] = ( - queryset.suggested_correction - if hasattr(queryset, "suggested_correction") - else None - ) - - return response - except NotFound: - raise - except PermissionDenied as e: - invalid_more_like_id_message = _("Invalid more_like_id") - if str(e.detail) == str(invalid_more_like_id_message): - return HttpResponseForbidden(invalid_more_like_id_message) - return HttpResponseForbidden(_("Insufficient permissions.")) - except Exception as e: - logger.warning(f"An error occurred listing search results: {e!s}") - return HttpResponseBadRequest( - "Error listing search results, check logs for more detail.", + results = backend.more_like_this( + more_like_doc_id, + user=user, + page=1, + page_size=10000, ) - else: - return super().list(request) + + hits_by_id = {h["id"]: h for h in results.hits} + + # Determine sort order: no ordering param → Tantivy relevance; otherwise → ORM order + ordering_param = request.query_params.get("ordering", "").lstrip("-") + if not ordering_param: + # Preserve Tantivy relevance order; intersect with ORM-visible IDs + orm_ids = set(filtered_qs.values_list("pk", flat=True)) + ordered_hits = [h for h in results.hits if h["id"] in orm_ids] + else: + # Use ORM ordering (already applied by DocumentsOrderingFilter) + hit_ids = set(hits_by_id.keys()) + orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list( + "pk", + flat=True, + ) + ordered_hits = [ + hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id + ] + + rl = TantivyRelevanceList(ordered_hits) + page = self.paginate_queryset(rl) + + if page is not None: + serializer = self.get_serializer(page, many=True) + response = self.get_paginated_response(serializer.data) + response.data["corrected_query"] = None + return response + + serializer = self.get_serializer(ordered_hits, many=True) + return Response(serializer.data) + + except NotFound: + raise + except PermissionDenied as e: + invalid_more_like_id_message = _("Invalid more_like_id") + if str(e.detail) == str(invalid_more_like_id_message): + return HttpResponseForbidden(invalid_more_like_id_message) + return HttpResponseForbidden(_("Insufficient permissions.")) + except Exception as e: + logger.warning(f"An error occurred listing search results: {e!s}") + return HttpResponseBadRequest( + "Error listing search results, check logs for more detail.", + ) @action(detail=False, methods=["GET"], name="Get Next ASN") def next_asn(self, request, *args, **kwargs): @@ -2816,18 +2830,9 @@ class SearchAutoCompleteView(GenericAPIView): else: limit = 10 - from documents import index + from documents.search import get_backend - ix = index.open_index() - - return Response( - index.autocomplete( - ix, - term, - limit, - user, - ), - ) + return Response(get_backend().autocomplete(term, limit, user)) @extend_schema_view( @@ -2893,20 +2898,21 @@ class GlobalSearchView(PassUserMixin): # First search by title docs = all_docs.filter(title__icontains=query) if not db_only and len(docs) < OBJECT_LIMIT: - # If we don't have enough results, search by content - from documents import index + # If we don't have enough results, search by content. + # Over-fetch from Tantivy (no permission filter) and rely on + # the ORM all_docs queryset for authoritative permission gating. + from documents.search import get_backend - with index.open_index_searcher() as s: - fts_query = index.DelayedFullTextQuery( - s, - request.query_params, - OBJECT_LIMIT, - filter_queryset=all_docs, - ) - results = fts_query[0:1] - docs = docs | Document.objects.filter( - id__in=[r["id"] for r in results], - ) + fts_results = get_backend().search( + query, + user=None, + page=1, + page_size=1000, + sort_field=None, + sort_reverse=False, + ) + fts_ids = {h["id"] for h in fts_results.hits} + docs = docs | all_docs.filter(id__in=fts_ids) docs = docs[:OBJECT_LIMIT] saved_views = ( get_objects_for_user_owner_aware( @@ -4105,10 +4111,16 @@ class SystemStatusView(PassUserMixin): index_error = None try: - ix = index.open_index() + from documents.search import get_backend + + get_backend() # triggers open/rebuild; raises on error index_status = "OK" - index_last_modified = make_aware( - datetime.fromtimestamp(ix.last_modified()), + # Use the most-recently modified file in the index directory as a proxy + # for last index write time (Tantivy has no single last_modified() call). + index_dir = settings.INDEX_DIR + mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()] + index_last_modified = ( + make_aware(datetime.fromtimestamp(max(mtimes))) if mtimes else None ) except Exception as e: index_status = "ERROR" diff --git a/src/paperless/views.py b/src/paperless/views.py index 6c201d1e4..24f9ffa75 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -71,22 +71,12 @@ class StandardPagination(PageNumberPagination): ) def get_all_result_ids(self): - from documents.index import DelayedQuery # removed with Whoosh in Task 14 + from documents.search import TantivyRelevanceList query = self.page.paginator.object_list - if isinstance(query, DelayedQuery): - try: - ids = [ - query.searcher.ixreader.stored_fields( - doc_num, - )["id"] - for doc_num in query.saved_results.get(0).results.docs() - ] - except Exception: - pass - else: - ids = self.page.paginator.object_list.values_list("pk", flat=True) - return ids + if isinstance(query, TantivyRelevanceList): + return [h["id"] for h in query._hits] + return self.page.paginator.object_list.values_list("pk", flat=True) def get_paginated_response_schema(self, schema): response_schema = super().get_paginated_response_schema(schema)