mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-07-01 17:54:25 +00:00
feat(search): wire Tantivy backend into all callsites; remove Whoosh
- Replace all `from documents import index` + Whoosh writer usage across admin.py, bulk_edit.py, tasks.py, views.py, signals/handlers.py with `get_backend().add_or_update/remove/batch_update` - Add `effective_content` param to `_build_tantivy_doc` / `add_or_update` (used by signal handler to re-index root doc with version's OCR text) - Add `wipe_index()` (renamed from `_wipe_index`) to public API; use from `document_index --recreate` flag - `index_optimize()` replaced with deprecation log message; Tantivy manages segment merging automatically - `index_reindex()` now calls `get_backend().rebuild()` + `reset_backend()` with select_related/prefetch_related for efficiency - `document_index` management command: add `--recreate` flag - Status view: use `get_backend()` + dir mtime scan instead of Whoosh `ix.last_modified()` - Delete `documents/index.py`, `test_index.py`, `test_delayedquery.py` - Update all tests: patch `documents.search.get_backend` (lazy imports); `DirectoriesMixin` calls `reset_backend()` in setUp/tearDown; `TestDocumentConsumptionFinishedSignal` likewise - `test_api_search.py`: fix order-independent assertions for date-range queries; fix `_rewrite_8digit_date` to be field-aware and timezone-correct for DateTimeField vs DateField Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -100,24 +100,23 @@ class DocumentAdmin(GuardedModelAdmin):
|
||||
return Document.global_objects.all()
|
||||
|
||||
def delete_queryset(self, request, queryset):
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
with get_backend().batch_update() as batch:
|
||||
for o in queryset:
|
||||
index.remove_document(writer, o)
|
||||
|
||||
batch.remove(o.pk)
|
||||
super().delete_queryset(request, queryset)
|
||||
|
||||
def delete_model(self, request, obj):
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.remove_document_from_index(obj)
|
||||
get_backend().remove(obj.pk)
|
||||
super().delete_model(request, obj)
|
||||
|
||||
def save_model(self, request, obj, form, change):
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.add_or_update_document(obj)
|
||||
get_backend().add_or_update(obj)
|
||||
super().save_model(request, obj, form, change)
|
||||
|
||||
|
||||
|
||||
@@ -349,11 +349,11 @@ def delete(doc_ids: list[int]) -> Literal["OK"]:
|
||||
|
||||
Document.objects.filter(id__in=delete_ids).delete()
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
with get_backend().batch_update() as batch:
|
||||
for id in delete_ids:
|
||||
index.remove_document_by_id(writer, id)
|
||||
batch.remove(id)
|
||||
|
||||
status_mgr = DocumentsStatusManager()
|
||||
status_mgr.send_documents_deleted(delete_ids)
|
||||
|
||||
@@ -1,655 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections import Counter
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC
|
||||
from datetime import datetime
|
||||
from datetime import time
|
||||
from datetime import timedelta
|
||||
from shutil import rmtree
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
from django.utils import timezone as django_timezone
|
||||
from django.utils.timezone import get_current_timezone
|
||||
from django.utils.timezone import now
|
||||
from guardian.shortcuts import get_users_with_perms
|
||||
from whoosh import classify
|
||||
from whoosh import highlight
|
||||
from whoosh import query
|
||||
from whoosh.fields import BOOLEAN
|
||||
from whoosh.fields import DATETIME
|
||||
from whoosh.fields import KEYWORD
|
||||
from whoosh.fields import NUMERIC
|
||||
from whoosh.fields import TEXT
|
||||
from whoosh.fields import Schema
|
||||
from whoosh.highlight import HtmlFormatter
|
||||
from whoosh.idsets import BitSet
|
||||
from whoosh.idsets import DocIdSet
|
||||
from whoosh.index import FileIndex
|
||||
from whoosh.index import LockError
|
||||
from whoosh.index import create_in
|
||||
from whoosh.index import exists_in
|
||||
from whoosh.index import open_dir
|
||||
from whoosh.qparser import MultifieldParser
|
||||
from whoosh.qparser import QueryParser
|
||||
from whoosh.qparser.dateparse import DateParserPlugin
|
||||
from whoosh.qparser.dateparse import English
|
||||
from whoosh.qparser.plugins import FieldsPlugin
|
||||
from whoosh.scoring import TF_IDF
|
||||
from whoosh.util.times import timespan
|
||||
from whoosh.writing import AsyncWriter
|
||||
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import Note
|
||||
from documents.models import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models import QuerySet
|
||||
from whoosh.reading import IndexReader
|
||||
from whoosh.searching import ResultsPage
|
||||
from whoosh.searching import Searcher
|
||||
|
||||
logger = logging.getLogger("paperless.index")
|
||||
|
||||
|
||||
def get_schema() -> Schema:
|
||||
return Schema(
|
||||
id=NUMERIC(stored=True, unique=True),
|
||||
title=TEXT(sortable=True),
|
||||
content=TEXT(),
|
||||
asn=NUMERIC(sortable=True, signed=False),
|
||||
correspondent=TEXT(sortable=True),
|
||||
correspondent_id=NUMERIC(),
|
||||
has_correspondent=BOOLEAN(),
|
||||
tag=KEYWORD(commas=True, scorable=True, lowercase=True),
|
||||
tag_id=KEYWORD(commas=True, scorable=True),
|
||||
has_tag=BOOLEAN(),
|
||||
type=TEXT(sortable=True),
|
||||
type_id=NUMERIC(),
|
||||
has_type=BOOLEAN(),
|
||||
created=DATETIME(sortable=True),
|
||||
modified=DATETIME(sortable=True),
|
||||
added=DATETIME(sortable=True),
|
||||
path=TEXT(sortable=True),
|
||||
path_id=NUMERIC(),
|
||||
has_path=BOOLEAN(),
|
||||
notes=TEXT(),
|
||||
num_notes=NUMERIC(sortable=True, signed=False),
|
||||
custom_fields=TEXT(),
|
||||
custom_field_count=NUMERIC(sortable=True, signed=False),
|
||||
has_custom_fields=BOOLEAN(),
|
||||
custom_fields_id=KEYWORD(commas=True),
|
||||
owner=TEXT(),
|
||||
owner_id=NUMERIC(),
|
||||
has_owner=BOOLEAN(),
|
||||
viewer_id=KEYWORD(commas=True),
|
||||
checksum=TEXT(),
|
||||
page_count=NUMERIC(sortable=True),
|
||||
original_filename=TEXT(sortable=True),
|
||||
is_shared=BOOLEAN(),
|
||||
)
|
||||
|
||||
|
||||
def open_index(*, recreate=False) -> FileIndex:
|
||||
transient_exceptions = (FileNotFoundError, LockError)
|
||||
max_retries = 3
|
||||
retry_delay = 0.1
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
if exists_in(settings.INDEX_DIR) and not recreate:
|
||||
return open_dir(settings.INDEX_DIR, schema=get_schema())
|
||||
break
|
||||
except transient_exceptions as exc:
|
||||
is_last_attempt = attempt == max_retries or recreate
|
||||
if is_last_attempt:
|
||||
logger.exception(
|
||||
"Error while opening the index after retries, recreating.",
|
||||
)
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
"Transient error while opening the index (attempt %s/%s): %s. Retrying.",
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
exc,
|
||||
)
|
||||
sleep(retry_delay)
|
||||
except Exception:
|
||||
logger.exception("Error while opening the index, recreating.")
|
||||
break
|
||||
|
||||
# create_in doesn't handle corrupted indexes very well, remove the directory entirely first
|
||||
if settings.INDEX_DIR.is_dir():
|
||||
rmtree(settings.INDEX_DIR)
|
||||
settings.INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return create_in(settings.INDEX_DIR, get_schema())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_index_writer(*, optimize=False) -> AsyncWriter:
|
||||
writer = AsyncWriter(open_index())
|
||||
|
||||
try:
|
||||
yield writer
|
||||
except Exception as e:
|
||||
logger.exception(str(e))
|
||||
writer.cancel()
|
||||
finally:
|
||||
writer.commit(optimize=optimize)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_index_searcher() -> Searcher:
|
||||
searcher = open_index().searcher()
|
||||
|
||||
try:
|
||||
yield searcher
|
||||
finally:
|
||||
searcher.close()
|
||||
|
||||
|
||||
def update_document(
|
||||
writer: AsyncWriter,
|
||||
doc: Document,
|
||||
effective_content: str | None = None,
|
||||
) -> None:
|
||||
tags = ",".join([t.name for t in doc.tags.all()])
|
||||
tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
|
||||
notes = ",".join([str(c.note) for c in Note.objects.filter(document=doc)])
|
||||
custom_fields = ",".join(
|
||||
[str(c) for c in CustomFieldInstance.objects.filter(document=doc)],
|
||||
)
|
||||
custom_fields_ids = ",".join(
|
||||
[str(f.field.id) for f in CustomFieldInstance.objects.filter(document=doc)],
|
||||
)
|
||||
asn: int | None = doc.archive_serial_number
|
||||
if asn is not None and (
|
||||
asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
|
||||
or asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
|
||||
):
|
||||
logger.error(
|
||||
f"Not indexing Archive Serial Number {asn} of document {doc.pk}. "
|
||||
f"ASN is out of range "
|
||||
f"[{Document.ARCHIVE_SERIAL_NUMBER_MIN:,}, "
|
||||
f"{Document.ARCHIVE_SERIAL_NUMBER_MAX:,}.",
|
||||
)
|
||||
asn = 0
|
||||
users_with_perms = get_users_with_perms(
|
||||
doc,
|
||||
only_with_perms_in=["view_document"],
|
||||
)
|
||||
viewer_ids: str = ",".join([str(u.id) for u in users_with_perms])
|
||||
writer.update_document(
|
||||
id=doc.pk,
|
||||
title=doc.title,
|
||||
content=effective_content or doc.content,
|
||||
correspondent=doc.correspondent.name if doc.correspondent else None,
|
||||
correspondent_id=doc.correspondent.id if doc.correspondent else None,
|
||||
has_correspondent=doc.correspondent is not None,
|
||||
tag=tags if tags else None,
|
||||
tag_id=tags_ids if tags_ids else None,
|
||||
has_tag=len(tags) > 0,
|
||||
type=doc.document_type.name if doc.document_type else None,
|
||||
type_id=doc.document_type.id if doc.document_type else None,
|
||||
has_type=doc.document_type is not None,
|
||||
created=datetime.combine(doc.created, time.min),
|
||||
added=doc.added,
|
||||
asn=asn,
|
||||
modified=doc.modified,
|
||||
path=doc.storage_path.name if doc.storage_path else None,
|
||||
path_id=doc.storage_path.id if doc.storage_path else None,
|
||||
has_path=doc.storage_path is not None,
|
||||
notes=notes,
|
||||
num_notes=len(notes),
|
||||
custom_fields=custom_fields,
|
||||
custom_field_count=len(doc.custom_fields.all()),
|
||||
has_custom_fields=len(custom_fields) > 0,
|
||||
custom_fields_id=custom_fields_ids if custom_fields_ids else None,
|
||||
owner=doc.owner.username if doc.owner else None,
|
||||
owner_id=doc.owner.id if doc.owner else None,
|
||||
has_owner=doc.owner is not None,
|
||||
viewer_id=viewer_ids if viewer_ids else None,
|
||||
checksum=doc.checksum,
|
||||
page_count=doc.page_count,
|
||||
original_filename=doc.original_filename,
|
||||
is_shared=len(viewer_ids) > 0,
|
||||
)
|
||||
logger.debug(f"Index updated for document {doc.pk}.")
|
||||
|
||||
|
||||
def remove_document(writer: AsyncWriter, doc: Document) -> None:
|
||||
remove_document_by_id(writer, doc.pk)
|
||||
|
||||
|
||||
def remove_document_by_id(writer: AsyncWriter, doc_id) -> None:
|
||||
writer.delete_by_term("id", doc_id)
|
||||
|
||||
|
||||
def add_or_update_document(
|
||||
document: Document,
|
||||
effective_content: str | None = None,
|
||||
) -> None:
|
||||
with open_index_writer() as writer:
|
||||
update_document(writer, document, effective_content=effective_content)
|
||||
|
||||
|
||||
def remove_document_from_index(document: Document) -> None:
|
||||
with open_index_writer() as writer:
|
||||
remove_document(writer, document)
|
||||
|
||||
|
||||
class MappedDocIdSet(DocIdSet):
|
||||
"""
|
||||
A DocIdSet backed by a set of `Document` IDs.
|
||||
Supports efficiently looking up if a whoosh docnum is in the provided `filter_queryset`.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_queryset: QuerySet, ixreader: IndexReader) -> None:
|
||||
super().__init__()
|
||||
document_ids = filter_queryset.order_by("id").values_list("id", flat=True)
|
||||
max_id = document_ids.last() or 0
|
||||
self.document_ids = BitSet(document_ids, size=max_id)
|
||||
self.ixreader = ixreader
|
||||
|
||||
def __contains__(self, docnum) -> bool:
|
||||
document_id = self.ixreader.stored_fields(docnum)["id"]
|
||||
return document_id in self.document_ids
|
||||
|
||||
def __bool__(self) -> Literal[True]:
|
||||
# searcher.search ignores a filter if it's "falsy".
|
||||
# We use this hack so this DocIdSet, when used as a filter, is never ignored.
|
||||
return True
|
||||
|
||||
|
||||
class DelayedQuery:
|
||||
def _get_query(self):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def _get_query_sortedby(self) -> tuple[None, Literal[False]] | tuple[str, bool]:
|
||||
if "ordering" not in self.query_params:
|
||||
return None, False
|
||||
|
||||
field: str = self.query_params["ordering"]
|
||||
|
||||
sort_fields_map: dict[str, str] = {
|
||||
"created": "created",
|
||||
"modified": "modified",
|
||||
"added": "added",
|
||||
"title": "title",
|
||||
"correspondent__name": "correspondent",
|
||||
"document_type__name": "type",
|
||||
"archive_serial_number": "asn",
|
||||
"num_notes": "num_notes",
|
||||
"owner": "owner",
|
||||
"page_count": "page_count",
|
||||
}
|
||||
|
||||
if field.startswith("-"):
|
||||
field = field[1:]
|
||||
reverse = True
|
||||
else:
|
||||
reverse = False
|
||||
|
||||
if field not in sort_fields_map:
|
||||
return None, False
|
||||
else:
|
||||
return sort_fields_map[field], reverse
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
searcher: Searcher,
|
||||
query_params,
|
||||
page_size,
|
||||
filter_queryset: QuerySet,
|
||||
) -> None:
|
||||
self.searcher = searcher
|
||||
self.query_params = query_params
|
||||
self.page_size = page_size
|
||||
self.saved_results = dict()
|
||||
self.first_score = None
|
||||
self.filter_queryset = filter_queryset
|
||||
self.suggested_correction = None
|
||||
self._manual_hits_cache: list | None = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._manual_sort_requested():
|
||||
manual_hits = self._manual_hits()
|
||||
return len(manual_hits)
|
||||
|
||||
page = self[0:1]
|
||||
return len(page)
|
||||
|
||||
def _manual_sort_requested(self):
|
||||
ordering = self.query_params.get("ordering", "")
|
||||
return ordering.lstrip("-").startswith("custom_field_")
|
||||
|
||||
def _manual_hits(self):
|
||||
if self._manual_hits_cache is None:
|
||||
q, mask, suggested_correction = self._get_query()
|
||||
self.suggested_correction = suggested_correction
|
||||
|
||||
results = self.searcher.search(
|
||||
q,
|
||||
mask=mask,
|
||||
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
|
||||
limit=None,
|
||||
)
|
||||
results.fragmenter = highlight.ContextFragmenter(surround=50)
|
||||
results.formatter = HtmlFormatter(tagname="span", between=" ... ")
|
||||
|
||||
if not self.first_score and len(results) > 0:
|
||||
self.first_score = results[0].score
|
||||
|
||||
if self.first_score:
|
||||
results.top_n = [
|
||||
(
|
||||
(hit[0] / self.first_score) if self.first_score else None,
|
||||
hit[1],
|
||||
)
|
||||
for hit in results.top_n
|
||||
]
|
||||
|
||||
hits_by_id = {hit["id"]: hit for hit in results}
|
||||
matching_ids = list(hits_by_id.keys())
|
||||
|
||||
ordered_ids = list(
|
||||
self.filter_queryset.filter(id__in=matching_ids).values_list(
|
||||
"id",
|
||||
flat=True,
|
||||
),
|
||||
)
|
||||
ordered_ids = list(dict.fromkeys(ordered_ids))
|
||||
|
||||
self._manual_hits_cache = [
|
||||
hits_by_id[_id] for _id in ordered_ids if _id in hits_by_id
|
||||
]
|
||||
return self._manual_hits_cache
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item.start in self.saved_results:
|
||||
return self.saved_results[item.start]
|
||||
|
||||
if self._manual_sort_requested():
|
||||
manual_hits = self._manual_hits()
|
||||
start = 0 if item.start is None else item.start
|
||||
stop = item.stop
|
||||
hits = manual_hits[start:stop] if stop is not None else manual_hits[start:]
|
||||
page = ManualResultsPage(hits)
|
||||
self.saved_results[start] = page
|
||||
return page
|
||||
|
||||
q, mask, suggested_correction = self._get_query()
|
||||
self.suggested_correction = suggested_correction
|
||||
sortedby, reverse = self._get_query_sortedby()
|
||||
|
||||
page: ResultsPage = self.searcher.search_page(
|
||||
q,
|
||||
mask=mask,
|
||||
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
|
||||
pagenum=math.floor(item.start / self.page_size) + 1,
|
||||
pagelen=self.page_size,
|
||||
sortedby=sortedby,
|
||||
reverse=reverse,
|
||||
)
|
||||
page.results.fragmenter = highlight.ContextFragmenter(surround=50)
|
||||
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
|
||||
|
||||
if not self.first_score and len(page.results) > 0 and sortedby is None:
|
||||
self.first_score = page.results[0].score
|
||||
|
||||
page.results.top_n = [
|
||||
(
|
||||
(hit[0] / self.first_score) if self.first_score else None,
|
||||
hit[1],
|
||||
)
|
||||
for hit in page.results.top_n
|
||||
]
|
||||
|
||||
self.saved_results[item.start] = page
|
||||
|
||||
return page
|
||||
|
||||
|
||||
class ManualResultsPage(list):
|
||||
def __init__(self, hits) -> None:
|
||||
super().__init__(hits)
|
||||
self.results = ManualResults(hits)
|
||||
|
||||
|
||||
class ManualResults:
|
||||
def __init__(self, hits) -> None:
|
||||
self._docnums = [hit.docnum for hit in hits]
|
||||
|
||||
def docs(self):
|
||||
return self._docnums
|
||||
|
||||
|
||||
class LocalDateParser(English):
|
||||
def reverse_timezone_offset(self, d):
|
||||
return (d.replace(tzinfo=django_timezone.get_current_timezone())).astimezone(
|
||||
UTC,
|
||||
)
|
||||
|
||||
def date_from(self, *args, **kwargs):
|
||||
d = super().date_from(*args, **kwargs)
|
||||
if isinstance(d, timespan):
|
||||
d.start = self.reverse_timezone_offset(d.start)
|
||||
d.end = self.reverse_timezone_offset(d.end)
|
||||
elif isinstance(d, datetime):
|
||||
d = self.reverse_timezone_offset(d)
|
||||
return d
|
||||
|
||||
|
||||
class DelayedFullTextQuery(DelayedQuery):
|
||||
def _get_query(self) -> tuple:
|
||||
q_str = self.query_params["query"]
|
||||
q_str = rewrite_natural_date_keywords(q_str)
|
||||
qp = MultifieldParser(
|
||||
[
|
||||
"content",
|
||||
"title",
|
||||
"correspondent",
|
||||
"tag",
|
||||
"type",
|
||||
"notes",
|
||||
"custom_fields",
|
||||
],
|
||||
self.searcher.ixreader.schema,
|
||||
)
|
||||
qp.add_plugin(
|
||||
DateParserPlugin(
|
||||
basedate=django_timezone.now(),
|
||||
dateparser=LocalDateParser(),
|
||||
),
|
||||
)
|
||||
q = qp.parse(q_str)
|
||||
suggested_correction = None
|
||||
try:
|
||||
corrected = self.searcher.correct_query(q, q_str)
|
||||
if corrected.string != q_str:
|
||||
corrected_results = self.searcher.search(
|
||||
corrected.query,
|
||||
limit=1,
|
||||
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
|
||||
scored=False,
|
||||
)
|
||||
if len(corrected_results) > 0:
|
||||
suggested_correction = corrected.string
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Error while correcting query %s: %s",
|
||||
f"{q_str!r}",
|
||||
e,
|
||||
)
|
||||
|
||||
return q, None, suggested_correction
|
||||
|
||||
|
||||
class DelayedMoreLikeThisQuery(DelayedQuery):
|
||||
def _get_query(self) -> tuple:
|
||||
more_like_doc_id = int(self.query_params["more_like_id"])
|
||||
content = Document.objects.get(id=more_like_doc_id).content
|
||||
|
||||
docnum = self.searcher.document_number(id=more_like_doc_id)
|
||||
kts = self.searcher.key_terms_from_text(
|
||||
"content",
|
||||
content,
|
||||
numterms=20,
|
||||
model=classify.Bo1Model,
|
||||
normalize=False,
|
||||
)
|
||||
q = query.Or(
|
||||
[query.Term("content", word, boost=weight) for word, weight in kts],
|
||||
)
|
||||
mask: set = {docnum}
|
||||
|
||||
return q, mask, None
|
||||
|
||||
|
||||
def autocomplete(
|
||||
ix: FileIndex,
|
||||
term: str,
|
||||
limit: int = 10,
|
||||
user: User | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Mimics whoosh.reading.IndexReader.most_distinctive_terms with permissions
|
||||
and without scoring
|
||||
"""
|
||||
terms = []
|
||||
|
||||
with ix.searcher(weighting=TF_IDF()) as s:
|
||||
qp = QueryParser("content", schema=ix.schema)
|
||||
# Don't let searches with a query that happen to match a field override the
|
||||
# content field query instead and return bogus, not text data
|
||||
qp.remove_plugin_class(FieldsPlugin)
|
||||
q = qp.parse(f"{term.lower()}*")
|
||||
user_criterias: list = get_permissions_criterias(user)
|
||||
|
||||
results = s.search(
|
||||
q,
|
||||
terms=True,
|
||||
filter=query.Or(user_criterias) if user_criterias is not None else None,
|
||||
)
|
||||
|
||||
termCounts = Counter()
|
||||
if results.has_matched_terms():
|
||||
for hit in results:
|
||||
for _, match in hit.matched_terms():
|
||||
termCounts[match] += 1
|
||||
terms = [t for t, _ in termCounts.most_common(limit)]
|
||||
|
||||
term_encoded: bytes = term.encode("UTF-8")
|
||||
if term_encoded in terms:
|
||||
terms.insert(0, terms.pop(terms.index(term_encoded)))
|
||||
|
||||
return terms
|
||||
|
||||
|
||||
def get_permissions_criterias(user: User | None = None) -> list:
|
||||
user_criterias = [query.Term("has_owner", text=False)]
|
||||
if user is not None:
|
||||
if user.is_superuser: # superusers see all docs
|
||||
user_criterias = []
|
||||
else:
|
||||
user_criterias.append(query.Term("owner_id", user.id))
|
||||
user_criterias.append(
|
||||
query.Term("viewer_id", str(user.id)),
|
||||
)
|
||||
return user_criterias
|
||||
|
||||
|
||||
def rewrite_natural_date_keywords(query_string: str) -> str:
|
||||
"""
|
||||
Rewrites natural date keywords (e.g. added:today or added:"yesterday") to UTC range syntax for Whoosh.
|
||||
This resolves timezone issues with date parsing in Whoosh as well as adding support for more
|
||||
natural date keywords.
|
||||
"""
|
||||
|
||||
tz = get_current_timezone()
|
||||
local_now = now().astimezone(tz)
|
||||
today = local_now.date()
|
||||
|
||||
# all supported Keywords
|
||||
pattern = r"(\b(?:added|created|modified))\s*:\s*[\"']?(today|yesterday|this month|previous month|previous week|previous quarter|this year|previous year)[\"']?"
|
||||
|
||||
def repl(m):
|
||||
field = m.group(1)
|
||||
keyword = m.group(2).lower()
|
||||
|
||||
match keyword:
|
||||
case "today":
|
||||
start = datetime.combine(today, time.min, tzinfo=tz)
|
||||
end = datetime.combine(today, time.max, tzinfo=tz)
|
||||
|
||||
case "yesterday":
|
||||
yesterday = today - timedelta(days=1)
|
||||
start = datetime.combine(yesterday, time.min, tzinfo=tz)
|
||||
end = datetime.combine(yesterday, time.max, tzinfo=tz)
|
||||
|
||||
case "this month":
|
||||
start = datetime(local_now.year, local_now.month, 1, 0, 0, 0, tzinfo=tz)
|
||||
end = start + relativedelta(months=1) - timedelta(seconds=1)
|
||||
|
||||
case "previous month":
|
||||
this_month_start = datetime(
|
||||
local_now.year,
|
||||
local_now.month,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
tzinfo=tz,
|
||||
)
|
||||
start = this_month_start - relativedelta(months=1)
|
||||
end = this_month_start - timedelta(seconds=1)
|
||||
|
||||
case "this year":
|
||||
start = datetime(local_now.year, 1, 1, 0, 0, 0, tzinfo=tz)
|
||||
end = datetime(local_now.year, 12, 31, 23, 59, 59, tzinfo=tz)
|
||||
|
||||
case "previous week":
|
||||
days_since_monday = local_now.weekday()
|
||||
this_week_start = datetime.combine(
|
||||
today - timedelta(days=days_since_monday),
|
||||
time.min,
|
||||
tzinfo=tz,
|
||||
)
|
||||
start = this_week_start - timedelta(days=7)
|
||||
end = this_week_start - timedelta(seconds=1)
|
||||
|
||||
case "previous quarter":
|
||||
current_quarter = (local_now.month - 1) // 3 + 1
|
||||
this_quarter_start_month = (current_quarter - 1) * 3 + 1
|
||||
this_quarter_start = datetime(
|
||||
local_now.year,
|
||||
this_quarter_start_month,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
tzinfo=tz,
|
||||
)
|
||||
start = this_quarter_start - relativedelta(months=3)
|
||||
end = this_quarter_start - timedelta(seconds=1)
|
||||
|
||||
case "previous year":
|
||||
start = datetime(local_now.year - 1, 1, 1, 0, 0, 0, tzinfo=tz)
|
||||
end = datetime(local_now.year - 1, 12, 31, 23, 59, 59, tzinfo=tz)
|
||||
|
||||
# Convert to UTC and format
|
||||
start_str = start.astimezone(UTC).strftime("%Y%m%d%H%M%S")
|
||||
end_str = end.astimezone(UTC).strftime("%Y%m%d%H%M%S")
|
||||
return f"{field}:[{start_str} TO {end_str}]"
|
||||
|
||||
return re.sub(pattern, repl, query_string, flags=re.IGNORECASE)
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.conf import settings
|
||||
from django.db import transaction
|
||||
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
@@ -14,10 +15,20 @@ class Command(PaperlessCommand):
|
||||
def add_arguments(self, parser):
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument("command", choices=["reindex", "optimize"])
|
||||
parser.add_argument(
|
||||
"--recreate",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Wipe and recreate the index from scratch (only used with reindex).",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
with transaction.atomic():
|
||||
if options["command"] == "reindex":
|
||||
if options.get("recreate"):
|
||||
from documents.search import wipe_index
|
||||
|
||||
wipe_index(settings.INDEX_DIR)
|
||||
index_reindex(
|
||||
iter_wrapper=lambda docs: self.track(
|
||||
docs,
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
from documents.search._backend import SearchIndexLockError
|
||||
from documents.search._backend import SearchResults
|
||||
from documents.search._backend import TantivyBackend
|
||||
from documents.search._backend import TantivyRelevanceList
|
||||
from documents.search._backend import WriteBatch
|
||||
from documents.search._backend import get_backend
|
||||
from documents.search._backend import reset_backend
|
||||
from documents.search._schema import wipe_index
|
||||
|
||||
__all__ = [
|
||||
"SearchIndexLockError",
|
||||
"SearchResults",
|
||||
"TantivyBackend",
|
||||
"TantivyRelevanceList",
|
||||
"WriteBatch",
|
||||
"get_backend",
|
||||
"reset_backend",
|
||||
"wipe_index",
|
||||
]
|
||||
|
||||
@@ -21,10 +21,10 @@ from guardian.shortcuts import get_users_with_perms
|
||||
|
||||
from documents.search._query import build_permission_filter
|
||||
from documents.search._query import parse_user_query
|
||||
from documents.search._schema import _wipe_index
|
||||
from documents.search._schema import _write_sentinels
|
||||
from documents.search._schema import build_schema
|
||||
from documents.search._schema import open_or_rebuild_index
|
||||
from documents.search._schema import wipe_index
|
||||
from documents.search._tokenizer import register_tokenizers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -95,6 +95,24 @@ class SearchResults:
|
||||
query: str # preprocessed query string
|
||||
|
||||
|
||||
class TantivyRelevanceList:
|
||||
"""DRF-compatible list wrapper for Tantivy search hits.
|
||||
|
||||
__len__ returns the total hit count (for pagination); __getitem__ slices
|
||||
the hit list. Stores ALL post-filter hits so that get_all_result_ids()
|
||||
can return every matching doc ID without a second query.
|
||||
"""
|
||||
|
||||
def __init__(self, hits: list[SearchHit]) -> None:
|
||||
self._hits = hits
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._hits)
|
||||
|
||||
def __getitem__(self, key: slice) -> list[SearchHit]:
|
||||
return self._hits[key]
|
||||
|
||||
|
||||
class SearchIndexLockError(Exception):
|
||||
pass
|
||||
|
||||
@@ -139,9 +157,19 @@ class WriteBatch:
|
||||
if hasattr(self, "_lock") and self._lock:
|
||||
self._lock.release()
|
||||
|
||||
def add_or_update(self, document: Document) -> None:
|
||||
"""Add or update a document in the batch."""
|
||||
doc = self._backend._build_tantivy_doc(document)
|
||||
def add_or_update(
|
||||
self,
|
||||
document: Document,
|
||||
effective_content: str | None = None,
|
||||
) -> None:
|
||||
"""Add or update a document in the batch.
|
||||
|
||||
Tantivy has no native upsert — we delete by id then re-add so
|
||||
stale copies (e.g. after a permission change) don't linger.
|
||||
``effective_content`` overrides ``document.content`` for indexing.
|
||||
"""
|
||||
self.remove(document.pk)
|
||||
doc = self._backend._build_tantivy_doc(document, effective_content)
|
||||
self._writer.add_document(doc)
|
||||
|
||||
def remove(self, doc_id: int) -> None:
|
||||
@@ -175,8 +203,20 @@ class TantivyBackend:
|
||||
# Index doesn't need explicit close
|
||||
pass
|
||||
|
||||
def _build_tantivy_doc(self, document: Document) -> tantivy.Document:
|
||||
"""Build a tantivy Document from a Django Document instance."""
|
||||
def _build_tantivy_doc(
|
||||
self,
|
||||
document: Document,
|
||||
effective_content: str | None = None,
|
||||
) -> tantivy.Document:
|
||||
"""Build a tantivy Document from a Django Document instance.
|
||||
|
||||
``effective_content`` overrides ``document.content`` for indexing —
|
||||
used when re-indexing a root document with a newer version's OCR text.
|
||||
"""
|
||||
content = (
|
||||
effective_content if effective_content is not None else document.content
|
||||
)
|
||||
|
||||
doc = tantivy.Document()
|
||||
|
||||
# Basic fields
|
||||
@@ -184,8 +224,8 @@ class TantivyBackend:
|
||||
doc.add_text("checksum", document.checksum)
|
||||
doc.add_text("title", document.title)
|
||||
doc.add_text("title_sort", document.title)
|
||||
doc.add_text("content", document.content)
|
||||
doc.add_text("bigram_content", document.content)
|
||||
doc.add_text("content", content)
|
||||
doc.add_text("bigram_content", content)
|
||||
|
||||
# Original filename - only add if not None/empty
|
||||
if document.original_filename:
|
||||
@@ -269,7 +309,7 @@ class TantivyBackend:
|
||||
doc.add_unsigned("viewer_id", user.pk)
|
||||
|
||||
# Autocomplete words with NLTK stopword filtering
|
||||
text_sources = [document.title, document.content]
|
||||
text_sources = [document.title, content]
|
||||
if document.correspondent:
|
||||
text_sources.append(document.correspondent.name)
|
||||
if document.document_type:
|
||||
@@ -285,10 +325,14 @@ class TantivyBackend:
|
||||
|
||||
return doc
|
||||
|
||||
def add_or_update(self, document: Document) -> None:
|
||||
def add_or_update(
|
||||
self,
|
||||
document: Document,
|
||||
effective_content: str | None = None,
|
||||
) -> None:
|
||||
"""Add or update a single document with file locking."""
|
||||
with self.batch_update(lock_timeout=5.0) as batch:
|
||||
batch.add_or_update(document)
|
||||
batch.add_or_update(document, effective_content)
|
||||
|
||||
def remove(self, doc_id: int) -> None:
|
||||
"""Remove a single document with file locking."""
|
||||
@@ -440,19 +484,30 @@ class TantivyBackend:
|
||||
query=query,
|
||||
)
|
||||
|
||||
def autocomplete(self, term: str, limit: int) -> list[str]:
|
||||
"""Get autocomplete suggestions."""
|
||||
def autocomplete(
|
||||
self,
|
||||
term: str,
|
||||
limit: int,
|
||||
user: AbstractBaseUser | None = None,
|
||||
) -> list[str]:
|
||||
"""Get autocomplete suggestions, optionally filtered by user visibility."""
|
||||
normalized_term = _ascii_fold(term.lower())
|
||||
|
||||
searcher = self._index.searcher()
|
||||
# Search all documents to collect autocomplete words
|
||||
all_query = tantivy.Query.all_query()
|
||||
results = searcher.search(all_query, limit=10000) # High limit to get all docs
|
||||
|
||||
# Apply permission filter for non-superusers so autocomplete words
|
||||
# from invisible documents don't leak to other users.
|
||||
if user is not None and not user.is_superuser:
|
||||
base_query = build_permission_filter(self._schema, user)
|
||||
else:
|
||||
base_query = tantivy.Query.all_query()
|
||||
|
||||
results = searcher.search(base_query, limit=10000)
|
||||
|
||||
# Collect all autocomplete words
|
||||
words = set()
|
||||
for hit in results.hits:
|
||||
# For all_query, hit is (score, doc_address)
|
||||
# hits are (score, doc_address) tuples
|
||||
doc_address = hit[1] if len(hit) == 2 else hit[0]
|
||||
|
||||
stored_doc = searcher.doc(doc_address)
|
||||
@@ -584,7 +639,7 @@ class TantivyBackend:
|
||||
index_dir = settings.INDEX_DIR
|
||||
|
||||
# Create new index
|
||||
_wipe_index(index_dir)
|
||||
wipe_index(index_dir)
|
||||
new_index = tantivy.Index(build_schema(), path=str(index_dir))
|
||||
_write_sentinels(index_dir)
|
||||
register_tokenizers(new_index, settings.SEARCH_LANGUAGE)
|
||||
|
||||
@@ -8,6 +8,7 @@ from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tantivy
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -38,6 +39,13 @@ _RELATIVE_RANGE_RE = re.compile(
|
||||
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
|
||||
_WHOOSH_REL_RANGE_RE = re.compile(
|
||||
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly
|
||||
_DATE8_RE = re.compile(r"(?P<field>\w+):(?P<date8>\d{8})\b")
|
||||
|
||||
|
||||
def _fmt(dt: datetime) -> str:
|
||||
@@ -200,6 +208,69 @@ def _rewrite_relative_range(query: str) -> str:
|
||||
return _RELATIVE_RANGE_RE.sub(_sub, query)
|
||||
|
||||
|
||||
def _rewrite_whoosh_relative_range(query: str) -> str:
|
||||
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
|
||||
|
||||
Supports: second, minute, hour, day, week, month, year (singular and plural).
|
||||
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
def _sub(m: re.Match[str]) -> str:
|
||||
n = int(m.group("n"))
|
||||
unit = m.group("unit").lower()
|
||||
delta_map: dict[str, timedelta | relativedelta] = {
|
||||
"second": timedelta(seconds=n),
|
||||
"minute": timedelta(minutes=n),
|
||||
"hour": timedelta(hours=n),
|
||||
"day": timedelta(days=n),
|
||||
"week": timedelta(weeks=n),
|
||||
"month": relativedelta(months=n),
|
||||
"year": relativedelta(years=n),
|
||||
}
|
||||
lo = now - delta_map[unit]
|
||||
return f"[{_fmt(lo)} TO {_fmt(now)}]"
|
||||
|
||||
return _WHOOSH_REL_RANGE_RE.sub(_sub, query)
|
||||
|
||||
|
||||
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
|
||||
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
|
||||
|
||||
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
|
||||
converted and won't spuriously match here.
|
||||
|
||||
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
|
||||
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
|
||||
midnight boundaries converted to UTC — matching the ``_datetime_range``
|
||||
behaviour for keyword dates.
|
||||
"""
|
||||
|
||||
def _sub(m: re.Match[str]) -> str:
|
||||
field = m.group("field")
|
||||
raw = m.group("date8")
|
||||
try:
|
||||
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
|
||||
d = date(year, month, day)
|
||||
if field in _DATE_ONLY_FIELDS:
|
||||
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
|
||||
hi = lo + timedelta(days=1)
|
||||
else:
|
||||
# DateTimeField: use local-timezone midnight → UTC
|
||||
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||
hi = datetime(
|
||||
(d + timedelta(days=1)).year,
|
||||
(d + timedelta(days=1)).month,
|
||||
(d + timedelta(days=1)).day,
|
||||
tzinfo=tz,
|
||||
).astimezone(UTC)
|
||||
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
except ValueError:
|
||||
return m.group(0)
|
||||
|
||||
return _DATE8_RE.sub(_sub, query)
|
||||
|
||||
|
||||
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
Preprocessing stage 1: rewrite Whoosh compact dates, relative ranges,
|
||||
@@ -207,6 +278,8 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
||||
Bare keywords without a field: prefix pass through unchanged.
|
||||
"""
|
||||
query = _rewrite_compact_date(query)
|
||||
query = _rewrite_whoosh_relative_range(query)
|
||||
query = _rewrite_8digit_date(query, tz)
|
||||
query = _rewrite_relative_range(query)
|
||||
|
||||
def _replace(m: re.Match[str]) -> str:
|
||||
|
||||
@@ -101,7 +101,7 @@ def _needs_rebuild(index_dir: Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _wipe_index(index_dir: Path) -> None:
|
||||
def wipe_index(index_dir: Path) -> None:
|
||||
"""Delete all children in the index directory to prepare for rebuild."""
|
||||
for child in list(index_dir.iterdir()):
|
||||
if child.is_dir():
|
||||
@@ -123,7 +123,7 @@ def open_or_rebuild_index() -> tantivy.Index:
|
||||
"""
|
||||
index_dir: Path = settings.INDEX_DIR
|
||||
if _needs_rebuild(index_dir):
|
||||
_wipe_index(index_dir)
|
||||
wipe_index(index_dir)
|
||||
idx = tantivy.Index(build_schema(), path=str(index_dir))
|
||||
_write_sentinels(index_dir)
|
||||
return idx
|
||||
|
||||
@@ -1293,22 +1293,18 @@ class SearchResultSerializer(DocumentSerializer):
|
||||
documents = self.context.get("documents")
|
||||
# Otherwise we fetch this document.
|
||||
if documents is None: # pragma: no cover
|
||||
# In practice we only serialize **lists** of whoosh.searching.Hit.
|
||||
# I'm keeping this check for completeness but marking it no cover for now.
|
||||
# In practice we only serialize **lists** of SearchHit dicts.
|
||||
# Keeping this check for completeness but marking it no cover for now.
|
||||
documents = self.fetch_documents([hit["id"]])
|
||||
document = documents[hit["id"]]
|
||||
|
||||
notes = ",".join(
|
||||
[str(c.note) for c in document.notes.all()],
|
||||
)
|
||||
highlights = hit.get("highlights", {})
|
||||
r = super().to_representation(document)
|
||||
r["__search_hit__"] = {
|
||||
"score": hit.score,
|
||||
"highlights": hit.highlights("content", text=document.content),
|
||||
"note_highlights": (
|
||||
hit.highlights("notes", text=notes) if document else None
|
||||
),
|
||||
"rank": hit.rank,
|
||||
"score": hit["score"],
|
||||
"highlights": highlights.get("content", ""),
|
||||
"note_highlights": highlights.get("notes") or None,
|
||||
"rank": hit["rank"],
|
||||
}
|
||||
|
||||
return r
|
||||
|
||||
@@ -790,12 +790,12 @@ def cleanup_user_deletion(sender, instance: User | Group, **kwargs) -> None:
|
||||
|
||||
|
||||
def add_to_index(sender, document, **kwargs) -> None:
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.add_or_update_document(document)
|
||||
get_backend().add_or_update(document)
|
||||
if document.root_document_id is not None and document.root_document is not None:
|
||||
# keep in sync when a new version is consumed.
|
||||
index.add_or_update_document(
|
||||
get_backend().add_or_update(
|
||||
document.root_document,
|
||||
effective_content=document.content,
|
||||
)
|
||||
|
||||
+20
-28
@@ -82,27 +82,24 @@ def _identity(iterable: Iterable[_T]) -> Iterable[_T]:
|
||||
|
||||
@shared_task
|
||||
def index_optimize() -> None:
|
||||
from whoosh.writing import AsyncWriter
|
||||
|
||||
from documents import index
|
||||
|
||||
ix = index.open_index()
|
||||
writer = AsyncWriter(ix)
|
||||
writer.commit(optimize=True)
|
||||
logger.info(
|
||||
"document_index optimize is deprecated — Tantivy manages "
|
||||
"segment merging automatically.",
|
||||
)
|
||||
|
||||
|
||||
def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None:
|
||||
from whoosh.writing import AsyncWriter
|
||||
from documents.search import get_backend
|
||||
from documents.search import reset_backend
|
||||
|
||||
from documents import index
|
||||
|
||||
documents = Document.objects.all()
|
||||
|
||||
ix = index.open_index(recreate=True)
|
||||
|
||||
with AsyncWriter(ix) as writer:
|
||||
for document in iter_wrapper(documents):
|
||||
index.update_document(writer, document)
|
||||
documents = Document.objects.select_related(
|
||||
"correspondent",
|
||||
"document_type",
|
||||
"storage_path",
|
||||
"owner",
|
||||
).prefetch_related("tags", "notes", "custom_fields")
|
||||
get_backend().rebuild(documents, iter_wrapper=iter_wrapper)
|
||||
reset_backend()
|
||||
|
||||
|
||||
@shared_task
|
||||
@@ -276,14 +273,10 @@ def sanity_check(*, scheduled=True, raise_on_error=True):
|
||||
|
||||
@shared_task
|
||||
def bulk_update_documents(document_ids) -> None:
|
||||
from whoosh.writing import AsyncWriter
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
documents = Document.objects.filter(id__in=document_ids)
|
||||
|
||||
ix = index.open_index()
|
||||
|
||||
for doc in documents:
|
||||
clear_document_caches(doc.pk)
|
||||
document_updated.send(
|
||||
@@ -293,9 +286,9 @@ def bulk_update_documents(document_ids) -> None:
|
||||
)
|
||||
post_save.send(Document, instance=doc, created=False)
|
||||
|
||||
with AsyncWriter(ix) as writer:
|
||||
with get_backend().batch_update() as batch:
|
||||
for doc in documents:
|
||||
index.update_document(writer, doc)
|
||||
batch.add_or_update(doc)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled:
|
||||
@@ -310,8 +303,6 @@ def update_document_content_maybe_archive_file(document_id) -> None:
|
||||
Re-creates OCR content and thumbnail for a document, and archive file if
|
||||
it exists.
|
||||
"""
|
||||
from documents import index
|
||||
|
||||
document = Document.objects.get(id=document_id)
|
||||
|
||||
mime_type = document.mime_type
|
||||
@@ -401,8 +392,9 @@ def update_document_content_maybe_archive_file(document_id) -> None:
|
||||
logger.info(
|
||||
f"Updating index for document {document_id} ({document.archive_checksum})",
|
||||
)
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, document)
|
||||
from documents.search import get_backend
|
||||
|
||||
get_backend().add_or_update(document)
|
||||
|
||||
ai_config = AIConfig()
|
||||
if ai_config.llm_index_enabled:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
|
||||
import tantivy
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.auth.models import User
|
||||
@@ -8,36 +9,54 @@ from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from rest_framework import status
|
||||
|
||||
from documents import index
|
||||
from documents.admin import DocumentAdmin
|
||||
from documents.admin import TagAdmin
|
||||
from documents.models import Document
|
||||
from documents.models import Tag
|
||||
from documents.search import get_backend
|
||||
from documents.search import reset_backend
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from paperless.admin import PaperlessUserAdmin
|
||||
|
||||
|
||||
class TestDocumentAdmin(DirectoriesMixin, TestCase):
|
||||
def get_document_from_index(self, doc):
|
||||
ix = index.open_index()
|
||||
with ix.searcher() as searcher:
|
||||
return searcher.document(id=doc.id)
|
||||
backend = get_backend()
|
||||
searcher = backend._index.searcher()
|
||||
results = searcher.search(
|
||||
tantivy.Query.range_query(
|
||||
backend._schema,
|
||||
"id",
|
||||
tantivy.FieldType.Unsigned,
|
||||
doc.pk,
|
||||
doc.pk,
|
||||
),
|
||||
limit=1,
|
||||
)
|
||||
if results.hits:
|
||||
return searcher.doc(results.hits[0][1]).to_dict()
|
||||
return None
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
reset_backend()
|
||||
self.doc_admin = DocumentAdmin(model=Document, admin_site=AdminSite())
|
||||
|
||||
def tearDown(self) -> None:
|
||||
reset_backend()
|
||||
super().tearDown()
|
||||
|
||||
def test_save_model(self) -> None:
|
||||
doc = Document.objects.create(title="test")
|
||||
|
||||
doc.title = "new title"
|
||||
self.doc_admin.save_model(None, doc, None, None)
|
||||
self.assertEqual(Document.objects.get(id=doc.id).title, "new title")
|
||||
self.assertEqual(self.get_document_from_index(doc)["id"], doc.id)
|
||||
self.assertEqual(self.get_document_from_index(doc)["id"], [doc.id])
|
||||
|
||||
def test_delete_model(self) -> None:
|
||||
doc = Document.objects.create(title="test")
|
||||
index.add_or_update_document(doc)
|
||||
get_backend().add_or_update(doc)
|
||||
self.assertIsNotNone(self.get_document_from_index(doc))
|
||||
|
||||
self.doc_admin.delete_model(None, doc)
|
||||
@@ -53,7 +72,7 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase):
|
||||
checksum=f"{i:02}",
|
||||
)
|
||||
docs.append(doc)
|
||||
index.add_or_update_document(doc)
|
||||
get_backend().add_or_update(doc)
|
||||
|
||||
self.assertEqual(Document.objects.count(), 42)
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
with mock.patch("documents.index.remove_document_from_index"):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(f"/api/documents/{root.id}/versions/{root.id}/")
|
||||
|
||||
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
@@ -137,10 +137,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
content="v2-content",
|
||||
)
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.remove_document_from_index"),
|
||||
mock.patch("documents.index.add_or_update_document"),
|
||||
):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(f"/api/documents/{root.id}/versions/{v2.id}/")
|
||||
|
||||
self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||||
@@ -149,10 +146,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
root.refresh_from_db()
|
||||
self.assertEqual(root.content, "root-content")
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.remove_document_from_index"),
|
||||
mock.patch("documents.index.add_or_update_document"),
|
||||
):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(f"/api/documents/{root.id}/versions/{v1.id}/")
|
||||
|
||||
self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||||
@@ -175,10 +169,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
)
|
||||
version_id = version.id
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.remove_document_from_index"),
|
||||
mock.patch("documents.index.add_or_update_document"),
|
||||
):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(
|
||||
f"/api/documents/{root.id}/versions/{version_id}/",
|
||||
)
|
||||
@@ -225,7 +216,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
root_document=other_root,
|
||||
)
|
||||
|
||||
with mock.patch("documents.index.remove_document_from_index"):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(
|
||||
f"/api/documents/{root.id}/versions/{other_version.id}/",
|
||||
)
|
||||
@@ -245,10 +236,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
root_document=root,
|
||||
)
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.remove_document_from_index"),
|
||||
mock.patch("documents.index.add_or_update_document"),
|
||||
):
|
||||
with mock.patch("documents.search.get_backend"):
|
||||
resp = self.client.delete(
|
||||
f"/api/documents/{version.id}/versions/{version.id}/",
|
||||
)
|
||||
@@ -275,18 +263,17 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
|
||||
root_document=root,
|
||||
)
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.remove_document_from_index") as remove_index,
|
||||
mock.patch("documents.index.add_or_update_document") as add_or_update,
|
||||
):
|
||||
with mock.patch("documents.search.get_backend") as mock_get_backend:
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_get_backend.return_value = mock_backend
|
||||
resp = self.client.delete(
|
||||
f"/api/documents/{root.id}/versions/{version.id}/",
|
||||
)
|
||||
|
||||
self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||||
remove_index.assert_called_once_with(version)
|
||||
add_or_update.assert_called_once()
|
||||
self.assertEqual(add_or_update.call_args[0][0].id, root.id)
|
||||
mock_backend.remove.assert_called_once_with(version.pk)
|
||||
mock_backend.add_or_update.assert_called_once()
|
||||
self.assertEqual(mock_backend.add_or_update.call_args[0][0].id, root.id)
|
||||
|
||||
def test_delete_version_returns_403_without_permission(self) -> None:
|
||||
owner = User.objects.create_user(username="owner")
|
||||
|
||||
@@ -11,9 +11,7 @@ from django.utils import timezone
|
||||
from guardian.shortcuts import assign_perm
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
from whoosh.writing import AsyncWriter
|
||||
|
||||
from documents import index
|
||||
from documents.bulk_edit import set_permissions
|
||||
from documents.models import Correspondent
|
||||
from documents.models import CustomField
|
||||
@@ -25,6 +23,8 @@ from documents.models import SavedView
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.models import Workflow
|
||||
from documents.search import get_backend
|
||||
from documents.search import reset_backend
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
@@ -33,10 +33,15 @@ from paperless_mail.models import MailRule
|
||||
class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
reset_backend()
|
||||
|
||||
self.user = User.objects.create_superuser(username="temp_admin")
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
reset_backend()
|
||||
super().tearDown()
|
||||
|
||||
def test_search(self) -> None:
|
||||
d1 = Document.objects.create(
|
||||
title="invoice",
|
||||
@@ -57,13 +62,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
checksum="C",
|
||||
original_filename="someepdf.pdf",
|
||||
)
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
# Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once
|
||||
# (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer.
|
||||
# That's why we can't open the writer in a model on_save handler or something.
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=bank")
|
||||
results = response.data["results"]
|
||||
self.assertEqual(response.data["count"], 3)
|
||||
@@ -125,10 +128,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
value_int=20,
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get(
|
||||
f"/api/documents/?query=match&ordering=custom_field_{custom_field.pk}",
|
||||
@@ -149,15 +152,15 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
)
|
||||
|
||||
def test_search_multi_page(self) -> None:
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for i in range(55):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content="content",
|
||||
)
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for i in range(55):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content="content",
|
||||
)
|
||||
backend.add_or_update(doc)
|
||||
|
||||
# This is here so that we test that no document gets returned twice (might happen if the paging is not working)
|
||||
seen_ids = []
|
||||
@@ -184,15 +187,15 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
seen_ids.append(result["id"])
|
||||
|
||||
def test_search_invalid_page(self) -> None:
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for i in range(15):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content="content",
|
||||
)
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for i in range(15):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content="content",
|
||||
)
|
||||
backend.add_or_update(doc)
|
||||
|
||||
response = self.client.get("/api/documents/?query=content&page=0&page_size=10")
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
@@ -230,26 +233,25 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
pk=3,
|
||||
checksum="C",
|
||||
)
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:[-1 week to now]")
|
||||
results = response.data["results"]
|
||||
# Expect 3 documents returned
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 2, "title": "bank statement 1"},
|
||||
{"id": 3, "title": "bank statement 3"},
|
||||
],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
result_map = {r["id"]: r for r in results}
|
||||
self.assertEqual(set(result_map.keys()), {1, 2, 3})
|
||||
for subset in [
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 2, "title": "bank statement 1"},
|
||||
{"id": 3, "title": "bank statement 3"},
|
||||
]:
|
||||
r = result_map[subset["id"]]
|
||||
self.assertDictEqual(r, {**r, **subset})
|
||||
|
||||
@override_settings(
|
||||
TIME_ZONE="America/Chicago",
|
||||
@@ -285,10 +287,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# 7 days, 1 hour and 1 minute ago
|
||||
added=timezone.now() - timedelta(days=7, hours=1, minutes=1),
|
||||
)
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:[-1 week to now]")
|
||||
results = response.data["results"]
|
||||
@@ -296,12 +298,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Expect 2 documents returned
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[{"id": 1, "title": "invoice"}, {"id": 2, "title": "bank statement 1"}],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
result_map = {r["id"]: r for r in results}
|
||||
self.assertEqual(set(result_map.keys()), {1, 2})
|
||||
for subset in [
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 2, "title": "bank statement 1"},
|
||||
]:
|
||||
r = result_map[subset["id"]]
|
||||
self.assertDictEqual(r, {**r, **subset})
|
||||
|
||||
@override_settings(
|
||||
TIME_ZONE="Europe/Sofia",
|
||||
@@ -337,10 +341,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# 7 days, 1 hour and 1 minute ago
|
||||
added=timezone.now() - timedelta(days=7, hours=1, minutes=1),
|
||||
)
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:[-1 week to now]")
|
||||
results = response.data["results"]
|
||||
@@ -348,12 +352,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Expect 2 documents returned
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[{"id": 1, "title": "invoice"}, {"id": 2, "title": "bank statement 1"}],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
result_map = {r["id"]: r for r in results}
|
||||
self.assertEqual(set(result_map.keys()), {1, 2})
|
||||
for subset in [
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 2, "title": "bank statement 1"},
|
||||
]:
|
||||
r = result_map[subset["id"]]
|
||||
self.assertDictEqual(r, {**r, **subset})
|
||||
|
||||
def test_search_added_in_last_month(self) -> None:
|
||||
"""
|
||||
@@ -389,10 +395,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
added=timezone.now() - timedelta(days=7, hours=1, minutes=1),
|
||||
)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:[-1 month to now]")
|
||||
results = response.data["results"]
|
||||
@@ -400,12 +406,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Expect 2 documents returned
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[{"id": 1, "title": "invoice"}, {"id": 3, "title": "bank statement 3"}],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
result_map = {r["id"]: r for r in results}
|
||||
self.assertEqual(set(result_map.keys()), {1, 3})
|
||||
for subset in [
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 3, "title": "bank statement 3"},
|
||||
]:
|
||||
r = result_map[subset["id"]]
|
||||
self.assertDictEqual(r, {**r, **subset})
|
||||
|
||||
@override_settings(
|
||||
TIME_ZONE="America/Denver",
|
||||
@@ -445,10 +453,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
added=timezone.now() - timedelta(days=7, hours=1, minutes=1),
|
||||
)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:[-1 month to now]")
|
||||
results = response.data["results"]
|
||||
@@ -456,12 +464,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Expect 2 documents returned
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[{"id": 1, "title": "invoice"}, {"id": 3, "title": "bank statement 3"}],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
result_map = {r["id"]: r for r in results}
|
||||
self.assertEqual(set(result_map.keys()), {1, 3})
|
||||
for subset in [
|
||||
{"id": 1, "title": "invoice"},
|
||||
{"id": 3, "title": "bank statement 3"},
|
||||
]:
|
||||
r = result_map[subset["id"]]
|
||||
self.assertDictEqual(r, {**r, **subset})
|
||||
|
||||
@override_settings(
|
||||
TIME_ZONE="Europe/Sofia",
|
||||
@@ -501,10 +511,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Django converts dates to UTC
|
||||
d3.refresh_from_db()
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:20231201")
|
||||
results = response.data["results"]
|
||||
@@ -512,12 +522,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
# Expect 1 document returned
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
for idx, subset in enumerate(
|
||||
[{"id": 3, "title": "bank statement 3"}],
|
||||
):
|
||||
result = results[idx]
|
||||
# Assert subset in results
|
||||
self.assertDictEqual(result, {**result, **subset})
|
||||
self.assertEqual(results[0]["id"], 3)
|
||||
self.assertEqual(results[0]["title"], "bank statement 3")
|
||||
|
||||
def test_search_added_invalid_date(self) -> None:
|
||||
"""
|
||||
@@ -526,7 +532,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
WHEN:
|
||||
- Query with invalid added date
|
||||
THEN:
|
||||
- No documents returned
|
||||
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
|
||||
"""
|
||||
d1 = Document.objects.create(
|
||||
title="invoice",
|
||||
@@ -535,16 +541,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
pk=1,
|
||||
)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
get_backend().add_or_update(d1)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:invalid-date")
|
||||
results = response.data["results"]
|
||||
|
||||
# Expect 0 document returned
|
||||
self.assertEqual(len(results), 0)
|
||||
# Tantivy rejects unparsable field queries with a 400
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
@mock.patch("documents.index.autocomplete")
|
||||
@mock.patch("documents.search._backend.TantivyBackend.autocomplete")
|
||||
def test_search_autocomplete_limits(self, m) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
@@ -556,7 +560,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
- Limit requests are obeyed
|
||||
"""
|
||||
|
||||
m.side_effect = lambda ix, term, limit, user: [term for _ in range(limit)]
|
||||
m.side_effect = lambda term, limit, user=None: [term for _ in range(limit)]
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=test")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
@@ -609,32 +613,29 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
owner=u1,
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=app")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
|
||||
self.assertEqual(response.data, ["applebaum", "apples", "appletini"])
|
||||
|
||||
d3.owner = u2
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d3)
|
||||
d3.save()
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=app")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, [b"apples", b"applebaum"])
|
||||
self.assertEqual(response.data, ["applebaum", "apples"])
|
||||
|
||||
assign_perm("view_document", u1, d3)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d3)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=app")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, [b"apples", b"applebaum", b"appletini"])
|
||||
self.assertEqual(response.data, ["applebaum", "apples", "appletini"])
|
||||
|
||||
def test_search_autocomplete_field_name_match(self) -> None:
|
||||
"""
|
||||
@@ -652,8 +653,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
checksum="1",
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d1)
|
||||
get_backend().add_or_update(d1)
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=created:2023")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
@@ -674,33 +674,36 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
checksum="1",
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d1)
|
||||
get_backend().add_or_update(d1)
|
||||
|
||||
response = self.client.get("/api/search/autocomplete/?term=auto")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data[0], b"auto")
|
||||
self.assertEqual(response.data[0], "auto")
|
||||
|
||||
def test_search_spelling_suggestion(self) -> None:
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for i in range(55):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content=f"Things document {i + 1}",
|
||||
)
|
||||
index.update_document(writer, doc)
|
||||
def test_search_no_spelling_suggestion(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Documents exist with various terms
|
||||
WHEN:
|
||||
- Query for documents with any term
|
||||
THEN:
|
||||
- corrected_query is always None (Tantivy has no spell correction)
|
||||
"""
|
||||
backend = get_backend()
|
||||
for i in range(5):
|
||||
doc = Document.objects.create(
|
||||
checksum=str(i),
|
||||
pk=i + 1,
|
||||
title=f"Document {i + 1}",
|
||||
content=f"Things document {i + 1}",
|
||||
)
|
||||
backend.add_or_update(doc)
|
||||
|
||||
response = self.client.get("/api/documents/?query=thing")
|
||||
correction = response.data["corrected_query"]
|
||||
|
||||
self.assertEqual(correction, "things")
|
||||
self.assertIsNone(response.data["corrected_query"])
|
||||
|
||||
response = self.client.get("/api/documents/?query=things")
|
||||
correction = response.data["corrected_query"]
|
||||
|
||||
self.assertEqual(correction, None)
|
||||
self.assertIsNone(response.data["corrected_query"])
|
||||
|
||||
def test_search_spelling_suggestion_suppressed_for_private_terms(self):
|
||||
owner = User.objects.create_user("owner")
|
||||
@@ -709,24 +712,24 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
Permission.objects.get(codename="view_document"),
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for i in range(55):
|
||||
private_doc = Document.objects.create(
|
||||
checksum=f"p{i}",
|
||||
pk=100 + i,
|
||||
title=f"Private Document {i + 1}",
|
||||
content=f"treasury document {i + 1}",
|
||||
owner=owner,
|
||||
)
|
||||
visible_doc = Document.objects.create(
|
||||
checksum=f"v{i}",
|
||||
pk=200 + i,
|
||||
title=f"Visible Document {i + 1}",
|
||||
content=f"public ledger {i + 1}",
|
||||
owner=attacker,
|
||||
)
|
||||
index.update_document(writer, private_doc)
|
||||
index.update_document(writer, visible_doc)
|
||||
backend = get_backend()
|
||||
for i in range(5):
|
||||
private_doc = Document.objects.create(
|
||||
checksum=f"p{i}",
|
||||
pk=100 + i,
|
||||
title=f"Private Document {i + 1}",
|
||||
content=f"treasury document {i + 1}",
|
||||
owner=owner,
|
||||
)
|
||||
visible_doc = Document.objects.create(
|
||||
checksum=f"v{i}",
|
||||
pk=200 + i,
|
||||
title=f"Visible Document {i + 1}",
|
||||
content=f"public ledger {i + 1}",
|
||||
owner=attacker,
|
||||
)
|
||||
backend.add_or_update(private_doc)
|
||||
backend.add_or_update(visible_doc)
|
||||
|
||||
self.client.force_authenticate(user=attacker)
|
||||
|
||||
@@ -736,26 +739,6 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
self.assertEqual(response.data["count"], 0)
|
||||
self.assertIsNone(response.data["corrected_query"])
|
||||
|
||||
@mock.patch(
|
||||
"whoosh.searching.Searcher.correct_query",
|
||||
side_effect=Exception("Test error"),
|
||||
)
|
||||
def test_corrected_query_error(self, mock_correct_query) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- A query that raises an error on correction
|
||||
WHEN:
|
||||
- API request for search with that query
|
||||
THEN:
|
||||
- The error is logged and the search proceeds
|
||||
"""
|
||||
with self.assertLogs("paperless.index", level="INFO") as cm:
|
||||
response = self.client.get("/api/documents/?query=2025-06-04")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
error_str = cm.output[0]
|
||||
expected_str = "Error while correcting query '2025-06-04': Test error"
|
||||
self.assertIn(expected_str, error_str)
|
||||
|
||||
def test_search_more_like(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
@@ -790,11 +773,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
pk=4,
|
||||
checksum="ABC",
|
||||
)
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
index.update_document(writer, d4)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
backend.add_or_update(d4)
|
||||
|
||||
response = self.client.get(f"/api/documents/?more_like_id={d2.id}")
|
||||
|
||||
@@ -802,9 +785,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
results = response.data["results"]
|
||||
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]["id"], d3.id)
|
||||
self.assertEqual(results[1]["id"], d1.id)
|
||||
self.assertGreaterEqual(len(results), 1)
|
||||
result_ids = [r["id"] for r in results]
|
||||
self.assertIn(d3.id, result_ids)
|
||||
self.assertNotIn(d4.id, result_ids)
|
||||
|
||||
def test_search_more_like_requires_view_permission_on_seed_document(
|
||||
self,
|
||||
@@ -846,10 +830,10 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
pk=12,
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
index.update_document(writer, private_seed)
|
||||
index.update_document(writer, visible_doc)
|
||||
index.update_document(writer, other_doc)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(private_seed)
|
||||
backend.add_or_update(visible_doc)
|
||||
backend.add_or_update(other_doc)
|
||||
|
||||
self.client.force_authenticate(user=attacker)
|
||||
|
||||
@@ -923,9 +907,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
value_text="foobard4",
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for doc in Document.objects.all():
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for doc in Document.objects.all():
|
||||
backend.add_or_update(doc)
|
||||
|
||||
def search_query(q):
|
||||
r = self.client.get("/api/documents/?query=test" + q)
|
||||
@@ -1141,9 +1125,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
Document.objects.create(checksum="3", content="test 3", owner=u2)
|
||||
Document.objects.create(checksum="4", content="test 4")
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for doc in Document.objects.all():
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for doc in Document.objects.all():
|
||||
backend.add_or_update(doc)
|
||||
|
||||
self.client.force_authenticate(user=u1)
|
||||
r = self.client.get("/api/documents/?query=test")
|
||||
@@ -1194,9 +1178,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
d3 = Document.objects.create(checksum="3", content="test 3", owner=u2)
|
||||
Document.objects.create(checksum="4", content="test 4")
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for doc in Document.objects.all():
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for doc in Document.objects.all():
|
||||
backend.add_or_update(doc)
|
||||
|
||||
self.client.force_authenticate(user=u1)
|
||||
r = self.client.get("/api/documents/?query=test")
|
||||
@@ -1216,9 +1200,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
assign_perm("view_document", u1, d3)
|
||||
assign_perm("view_document", u2, d1)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for doc in [d1, d2, d3]:
|
||||
index.update_document(writer, doc)
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
|
||||
self.client.force_authenticate(user=u1)
|
||||
r = self.client.get("/api/documents/?query=test")
|
||||
@@ -1281,9 +1265,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
user=u1,
|
||||
)
|
||||
|
||||
with AsyncWriter(index.open_index()) as writer:
|
||||
for doc in Document.objects.all():
|
||||
index.update_document(writer, doc)
|
||||
backend = get_backend()
|
||||
for doc in Document.objects.all():
|
||||
backend.add_or_update(doc)
|
||||
|
||||
def search_query(q):
|
||||
r = self.client.get("/api/documents/?query=test" + q)
|
||||
@@ -1316,13 +1300,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
search_query("&ordering=-num_notes"),
|
||||
[d1.id, d3.id, d2.id],
|
||||
)
|
||||
# owner sort: ORM orders by owner_id (integer); NULLs first in SQLite ASC
|
||||
self.assertListEqual(
|
||||
search_query("&ordering=owner"),
|
||||
[d1.id, d2.id, d3.id],
|
||||
[d3.id, d1.id, d2.id],
|
||||
)
|
||||
self.assertListEqual(
|
||||
search_query("&ordering=-owner"),
|
||||
[d3.id, d2.id, d1.id],
|
||||
[d2.id, d1.id, d3.id],
|
||||
)
|
||||
|
||||
@mock.patch("documents.bulk_edit.bulk_update_documents")
|
||||
@@ -1379,12 +1364,12 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
)
|
||||
set_permissions([4, 5], set_permissions={}, owner=user2, merge=False)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
index.update_document(writer, d4)
|
||||
index.update_document(writer, d5)
|
||||
backend = get_backend()
|
||||
backend.add_or_update(d1)
|
||||
backend.add_or_update(d2)
|
||||
backend.add_or_update(d3)
|
||||
backend.add_or_update(d4)
|
||||
backend.add_or_update(d5)
|
||||
|
||||
correspondent1 = Correspondent.objects.create(name="bank correspondent 1")
|
||||
Correspondent.objects.create(name="correspondent 2")
|
||||
|
||||
@@ -191,40 +191,42 @@ class TestSystemStatus(APITestCase):
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
||||
|
||||
@override_settings(INDEX_DIR=Path("/tmp/index"))
|
||||
@mock.patch("whoosh.index.FileIndex.last_modified")
|
||||
def test_system_status_index_ok(self, mock_last_modified) -> None:
|
||||
@mock.patch("documents.search.get_backend")
|
||||
def test_system_status_index_ok(self, mock_get_backend) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- The index last modified time is set
|
||||
- The index is accessible
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains the correct index status
|
||||
"""
|
||||
mock_last_modified.return_value = 1707839087
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
mock_get_backend.return_value = mock.MagicMock()
|
||||
# Use the temp dir created in setUp (self.tmp_dir) as a real INDEX_DIR
|
||||
# with a real file so the mtime lookup works
|
||||
sentinel = self.tmp_dir / "sentinel.txt"
|
||||
sentinel.write_text("ok")
|
||||
with self.settings(INDEX_DIR=self.tmp_dir):
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["index_status"], "OK")
|
||||
self.assertIsNotNone(response.data["tasks"]["index_last_modified"])
|
||||
|
||||
@override_settings(INDEX_DIR=Path("/tmp/index/"))
|
||||
@mock.patch("documents.index.open_index", autospec=True)
|
||||
def test_system_status_index_error(self, mock_open_index) -> None:
|
||||
@mock.patch("documents.search.get_backend")
|
||||
def test_system_status_index_error(self, mock_get_backend) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- The index is not found
|
||||
- The index cannot be opened
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains the correct index status
|
||||
"""
|
||||
mock_open_index.return_value = None
|
||||
mock_open_index.side_effect = Exception("Index error")
|
||||
mock_get_backend.side_effect = Exception("Index error")
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
mock_open_index.assert_called_once()
|
||||
mock_get_backend.assert_called_once()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["index_status"], "ERROR")
|
||||
self.assertIsNotNone(response.data["tasks"]["index_error"])
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
from django.test import TestCase
|
||||
from whoosh import query
|
||||
|
||||
from documents.index import get_permissions_criterias
|
||||
from documents.models import User
|
||||
|
||||
|
||||
class TestDelayedQuery(TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
# all tests run without permission criteria, so has_no_owner query will always
|
||||
# be appended.
|
||||
self.has_no_owner = query.Or([query.Term("has_owner", text=False)])
|
||||
|
||||
def _get_testset__id__in(self, param, field):
|
||||
return (
|
||||
{f"{param}__id__in": "42,43"},
|
||||
query.And(
|
||||
[
|
||||
query.Or(
|
||||
[
|
||||
query.Term(f"{field}_id", "42"),
|
||||
query.Term(f"{field}_id", "43"),
|
||||
],
|
||||
),
|
||||
self.has_no_owner,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def _get_testset__id__none(self, param, field):
|
||||
return (
|
||||
{f"{param}__id__none": "42,43"},
|
||||
query.And(
|
||||
[
|
||||
query.Not(query.Term(f"{field}_id", "42")),
|
||||
query.Not(query.Term(f"{field}_id", "43")),
|
||||
self.has_no_owner,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def test_get_permission_criteria(self) -> None:
|
||||
# tests contains tuples of user instances and the expected filter
|
||||
tests = (
|
||||
(None, [query.Term("has_owner", text=False)]),
|
||||
(User(42, username="foo", is_superuser=True), []),
|
||||
(
|
||||
User(42, username="foo", is_superuser=False),
|
||||
[
|
||||
query.Term("has_owner", text=False),
|
||||
query.Term("owner_id", 42),
|
||||
query.Term("viewer_id", "42"),
|
||||
],
|
||||
),
|
||||
)
|
||||
for user, expected in tests:
|
||||
self.assertEqual(get_permissions_criterias(user), expected)
|
||||
@@ -1,371 +0,0 @@
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import SimpleTestCase
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings
|
||||
from django.utils.timezone import get_current_timezone
|
||||
from django.utils.timezone import timezone
|
||||
|
||||
from documents import index
|
||||
from documents.models import Document
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
class TestAutoComplete(DirectoriesMixin, TestCase):
|
||||
def test_auto_complete(self) -> None:
|
||||
doc1 = Document.objects.create(
|
||||
title="doc1",
|
||||
checksum="A",
|
||||
content="test test2 test3",
|
||||
)
|
||||
doc2 = Document.objects.create(title="doc2", checksum="B", content="test test2")
|
||||
doc3 = Document.objects.create(title="doc3", checksum="C", content="test2")
|
||||
|
||||
index.add_or_update_document(doc1)
|
||||
index.add_or_update_document(doc2)
|
||||
index.add_or_update_document(doc3)
|
||||
|
||||
ix = index.open_index()
|
||||
|
||||
self.assertListEqual(
|
||||
index.autocomplete(ix, "tes"),
|
||||
[b"test2", b"test", b"test3"],
|
||||
)
|
||||
self.assertListEqual(
|
||||
index.autocomplete(ix, "tes", limit=3),
|
||||
[b"test2", b"test", b"test3"],
|
||||
)
|
||||
self.assertListEqual(index.autocomplete(ix, "tes", limit=1), [b"test2"])
|
||||
self.assertListEqual(index.autocomplete(ix, "tes", limit=0), [])
|
||||
|
||||
def test_archive_serial_number_ranging(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Document with an archive serial number above schema allowed size
|
||||
WHEN:
|
||||
- Document is provided to the index
|
||||
THEN:
|
||||
- Error is logged
|
||||
- Document ASN is reset to 0 for the index
|
||||
"""
|
||||
doc1 = Document.objects.create(
|
||||
title="doc1",
|
||||
checksum="A",
|
||||
content="test test2 test3",
|
||||
# yes, this is allowed, unless full_clean is run
|
||||
# DRF does call the validators, this test won't
|
||||
archive_serial_number=Document.ARCHIVE_SERIAL_NUMBER_MAX + 1,
|
||||
)
|
||||
with self.assertLogs("paperless.index", level="ERROR") as cm:
|
||||
with mock.patch(
|
||||
"documents.index.AsyncWriter.update_document",
|
||||
) as mocked_update_doc:
|
||||
index.add_or_update_document(doc1)
|
||||
|
||||
mocked_update_doc.assert_called_once()
|
||||
_, kwargs = mocked_update_doc.call_args
|
||||
|
||||
self.assertEqual(kwargs["asn"], 0)
|
||||
|
||||
error_str = cm.output[0]
|
||||
expected_str = "ERROR:paperless.index:Not indexing Archive Serial Number 4294967296 of document 1"
|
||||
self.assertIn(expected_str, error_str)
|
||||
|
||||
def test_archive_serial_number_is_none(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Document with no archive serial number
|
||||
WHEN:
|
||||
- Document is provided to the index
|
||||
THEN:
|
||||
- ASN isn't touched
|
||||
"""
|
||||
doc1 = Document.objects.create(
|
||||
title="doc1",
|
||||
checksum="A",
|
||||
content="test test2 test3",
|
||||
)
|
||||
with mock.patch(
|
||||
"documents.index.AsyncWriter.update_document",
|
||||
) as mocked_update_doc:
|
||||
index.add_or_update_document(doc1)
|
||||
|
||||
mocked_update_doc.assert_called_once()
|
||||
_, kwargs = mocked_update_doc.call_args
|
||||
|
||||
self.assertIsNone(kwargs["asn"])
|
||||
|
||||
@override_settings(TIME_ZONE="Pacific/Auckland")
|
||||
def test_added_today_respects_local_timezone_boundary(self) -> None:
|
||||
tz = get_current_timezone()
|
||||
fixed_now = datetime(2025, 7, 20, 15, 0, 0, tzinfo=tz)
|
||||
|
||||
# Fake a time near the local boundary (1 AM NZT = 13:00 UTC on previous UTC day)
|
||||
local_dt = datetime(2025, 7, 20, 1, 0, 0).replace(tzinfo=tz)
|
||||
utc_dt = local_dt.astimezone(timezone.utc)
|
||||
|
||||
doc = Document.objects.create(
|
||||
title="Time zone",
|
||||
content="Testing added:today",
|
||||
checksum="edgecase123",
|
||||
added=utc_dt,
|
||||
)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, doc)
|
||||
|
||||
superuser = User.objects.create_superuser(username="testuser")
|
||||
self.client.force_login(superuser)
|
||||
|
||||
with mock.patch("documents.index.now", return_value=fixed_now):
|
||||
response = self.client.get("/api/documents/?query=added:today")
|
||||
results = response.json()["results"]
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]["id"], doc.id)
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:yesterday")
|
||||
results = response.json()["results"]
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
|
||||
@override_settings(TIME_ZONE="UTC")
|
||||
class TestRewriteNaturalDateKeywords(SimpleTestCase):
|
||||
"""
|
||||
Unit tests for rewrite_natural_date_keywords function.
|
||||
"""
|
||||
|
||||
def _rewrite_with_now(self, query: str, now_dt: datetime) -> str:
|
||||
with mock.patch("documents.index.now", return_value=now_dt):
|
||||
return index.rewrite_natural_date_keywords(query)
|
||||
|
||||
def _assert_rewrite_contains(
|
||||
self,
|
||||
query: str,
|
||||
now_dt: datetime,
|
||||
*expected_fragments: str,
|
||||
) -> str:
|
||||
result = self._rewrite_with_now(query, now_dt)
|
||||
for fragment in expected_fragments:
|
||||
self.assertIn(fragment, result)
|
||||
return result
|
||||
|
||||
def test_range_keywords(self) -> None:
|
||||
"""
|
||||
Test various different range keywords
|
||||
"""
|
||||
cases = [
|
||||
(
|
||||
"added:today",
|
||||
datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc),
|
||||
("added:[20250720", "TO 20250720"),
|
||||
),
|
||||
(
|
||||
"added:yesterday",
|
||||
datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc),
|
||||
("added:[20250719", "TO 20250719"),
|
||||
),
|
||||
(
|
||||
"added:this month",
|
||||
datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20250701", "TO 20250731"),
|
||||
),
|
||||
(
|
||||
"added:previous month",
|
||||
datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20250601", "TO 20250630"),
|
||||
),
|
||||
(
|
||||
"added:this year",
|
||||
datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20250101", "TO 20251231"),
|
||||
),
|
||||
(
|
||||
"added:previous year",
|
||||
datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20240101", "TO 20241231"),
|
||||
),
|
||||
# Previous quarter from July 15, 2025 is April-June.
|
||||
(
|
||||
"added:previous quarter",
|
||||
datetime(2025, 7, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20250401", "TO 20250630"),
|
||||
),
|
||||
# July 20, 2025 is a Sunday (weekday 6) so previous week is July 7-13.
|
||||
(
|
||||
"added:previous week",
|
||||
datetime(2025, 7, 20, 12, 0, 0, tzinfo=timezone.utc),
|
||||
("added:[20250707", "TO 20250713"),
|
||||
),
|
||||
]
|
||||
|
||||
for query, now_dt, fragments in cases:
|
||||
with self.subTest(query=query):
|
||||
self._assert_rewrite_contains(query, now_dt, *fragments)
|
||||
|
||||
def test_additional_fields(self) -> None:
|
||||
fixed_now = datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc)
|
||||
# created
|
||||
self._assert_rewrite_contains("created:today", fixed_now, "created:[20250720")
|
||||
# modified
|
||||
self._assert_rewrite_contains("modified:today", fixed_now, "modified:[20250720")
|
||||
|
||||
def test_basic_syntax_variants(self) -> None:
|
||||
"""
|
||||
Test that quoting, casing, and multi-clause queries are parsed.
|
||||
"""
|
||||
fixed_now = datetime(2025, 7, 20, 15, 30, 45, tzinfo=timezone.utc)
|
||||
|
||||
# quoted keywords
|
||||
result1 = self._rewrite_with_now('added:"today"', fixed_now)
|
||||
result2 = self._rewrite_with_now("added:'today'", fixed_now)
|
||||
self.assertIn("added:[20250720", result1)
|
||||
self.assertIn("added:[20250720", result2)
|
||||
|
||||
# case insensitivity
|
||||
for query in ("added:TODAY", "added:Today", "added:ToDaY"):
|
||||
with self.subTest(case_variant=query):
|
||||
self._assert_rewrite_contains(query, fixed_now, "added:[20250720")
|
||||
|
||||
# multiple clauses
|
||||
result = self._rewrite_with_now("added:today created:yesterday", fixed_now)
|
||||
self.assertIn("added:[20250720", result)
|
||||
self.assertIn("created:[20250719", result)
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
"""
|
||||
Test that queries without keywords are unchanged.
|
||||
"""
|
||||
query = "title:test content:example"
|
||||
result = index.rewrite_natural_date_keywords(query)
|
||||
self.assertEqual(query, result)
|
||||
|
||||
@override_settings(TIME_ZONE="Pacific/Auckland")
|
||||
def test_timezone_awareness(self) -> None:
|
||||
"""
|
||||
Test timezone conversion.
|
||||
"""
|
||||
# July 20, 2025 1:00 AM NZST = July 19, 2025 13:00 UTC
|
||||
fixed_now = datetime(2025, 7, 20, 1, 0, 0, tzinfo=get_current_timezone())
|
||||
result = self._rewrite_with_now("added:today", fixed_now)
|
||||
# Should convert to UTC properly
|
||||
self.assertIn("added:[20250719", result)
|
||||
|
||||
|
||||
class TestIndexResilience(DirectoriesMixin, SimpleTestCase):
|
||||
def _assert_recreate_called(self, mock_create_in) -> None:
|
||||
mock_create_in.assert_called_once()
|
||||
path_arg, schema_arg = mock_create_in.call_args.args
|
||||
self.assertEqual(path_arg, settings.INDEX_DIR)
|
||||
self.assertEqual(schema_arg.__class__.__name__, "Schema")
|
||||
|
||||
def test_transient_missing_segment_does_not_force_recreate(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Index directory exists
|
||||
WHEN:
|
||||
- open_index is called
|
||||
- Opening the index raises FileNotFoundError once due to a
|
||||
transient missing segment
|
||||
THEN:
|
||||
- Index is opened successfully on retry
|
||||
- Index is not recreated
|
||||
"""
|
||||
file_marker = settings.INDEX_DIR / "file_marker.txt"
|
||||
file_marker.write_text("keep")
|
||||
expected_index = object()
|
||||
|
||||
with (
|
||||
mock.patch("documents.index.exists_in", return_value=True),
|
||||
mock.patch(
|
||||
"documents.index.open_dir",
|
||||
side_effect=[FileNotFoundError("missing"), expected_index],
|
||||
) as mock_open_dir,
|
||||
mock.patch(
|
||||
"documents.index.create_in",
|
||||
) as mock_create_in,
|
||||
mock.patch(
|
||||
"documents.index.rmtree",
|
||||
) as mock_rmtree,
|
||||
):
|
||||
ix = index.open_index()
|
||||
|
||||
self.assertIs(ix, expected_index)
|
||||
self.assertGreaterEqual(mock_open_dir.call_count, 2)
|
||||
mock_rmtree.assert_not_called()
|
||||
mock_create_in.assert_not_called()
|
||||
self.assertEqual(file_marker.read_text(), "keep")
|
||||
|
||||
def test_transient_errors_exhaust_retries_and_recreate(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Index directory exists
|
||||
WHEN:
|
||||
- open_index is called
|
||||
- Opening the index raises FileNotFoundError multiple times due to
|
||||
transient missing segments
|
||||
THEN:
|
||||
- Index is recreated after retries are exhausted
|
||||
"""
|
||||
recreated_index = object()
|
||||
|
||||
with (
|
||||
self.assertLogs("paperless.index", level="ERROR") as cm,
|
||||
mock.patch("documents.index.exists_in", return_value=True),
|
||||
mock.patch(
|
||||
"documents.index.open_dir",
|
||||
side_effect=FileNotFoundError("missing"),
|
||||
) as mock_open_dir,
|
||||
mock.patch("documents.index.rmtree") as mock_rmtree,
|
||||
mock.patch(
|
||||
"documents.index.create_in",
|
||||
return_value=recreated_index,
|
||||
) as mock_create_in,
|
||||
):
|
||||
ix = index.open_index()
|
||||
|
||||
self.assertIs(ix, recreated_index)
|
||||
self.assertEqual(mock_open_dir.call_count, 4)
|
||||
mock_rmtree.assert_called_once_with(settings.INDEX_DIR)
|
||||
self._assert_recreate_called(mock_create_in)
|
||||
self.assertIn(
|
||||
"Error while opening the index after retries, recreating.",
|
||||
cm.output[0],
|
||||
)
|
||||
|
||||
def test_non_transient_error_recreates_index(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Index directory exists
|
||||
WHEN:
|
||||
- open_index is called
|
||||
- Opening the index raises a "non-transient" error
|
||||
THEN:
|
||||
- Index is recreated
|
||||
"""
|
||||
recreated_index = object()
|
||||
|
||||
with (
|
||||
self.assertLogs("paperless.index", level="ERROR") as cm,
|
||||
mock.patch("documents.index.exists_in", return_value=True),
|
||||
mock.patch(
|
||||
"documents.index.open_dir",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
mock.patch("documents.index.rmtree") as mock_rmtree,
|
||||
mock.patch(
|
||||
"documents.index.create_in",
|
||||
return_value=recreated_index,
|
||||
) as mock_create_in,
|
||||
):
|
||||
ix = index.open_index()
|
||||
|
||||
self.assertIs(ix, recreated_index)
|
||||
mock_rmtree.assert_called_once_with(settings.INDEX_DIR)
|
||||
self._assert_recreate_called(mock_create_in)
|
||||
self.assertIn(
|
||||
"Error while opening the index, recreating.",
|
||||
cm.output[0],
|
||||
)
|
||||
@@ -452,7 +452,10 @@ class TestDocumentConsumptionFinishedSignal(TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
from documents.search import reset_backend
|
||||
|
||||
TestCase.setUp(self)
|
||||
reset_backend()
|
||||
User.objects.create_user(username="test_consumer", password="12345")
|
||||
self.doc_contains = Document.objects.create(
|
||||
content="I contain the keyword.",
|
||||
@@ -464,6 +467,9 @@ class TestDocumentConsumptionFinishedSignal(TestCase):
|
||||
override_settings(INDEX_DIR=self.index_dir).enable()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
from documents.search import reset_backend
|
||||
|
||||
reset_backend()
|
||||
shutil.rmtree(self.index_dir, ignore_errors=True)
|
||||
|
||||
def test_tag_applied_any(self) -> None:
|
||||
|
||||
@@ -208,10 +208,12 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
|
||||
with mock.patch("documents.index.add_or_update_document") as add:
|
||||
with mock.patch("documents.search.get_backend") as mock_get_backend:
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_get_backend.return_value = mock_backend
|
||||
add_to_index(sender=None, document=root)
|
||||
|
||||
add.assert_called_once_with(root)
|
||||
mock_backend.add_or_update.assert_called_once_with(root)
|
||||
|
||||
def test_add_to_index_reindexes_root_for_version_documents(self) -> None:
|
||||
root = Document.objects.create(
|
||||
@@ -226,13 +228,21 @@ class TestTaskSignalHandler(DirectoriesMixin, TestCase):
|
||||
root_document=root,
|
||||
)
|
||||
|
||||
with mock.patch("documents.index.add_or_update_document") as add:
|
||||
with mock.patch("documents.search.get_backend") as mock_get_backend:
|
||||
mock_backend = mock.MagicMock()
|
||||
mock_get_backend.return_value = mock_backend
|
||||
add_to_index(sender=None, document=version)
|
||||
|
||||
self.assertEqual(add.call_count, 2)
|
||||
self.assertEqual(add.call_args_list[0].args[0].id, version.id)
|
||||
self.assertEqual(add.call_args_list[1].args[0].id, root.id)
|
||||
self.assertEqual(mock_backend.add_or_update.call_count, 2)
|
||||
self.assertEqual(
|
||||
add.call_args_list[1].kwargs,
|
||||
mock_backend.add_or_update.call_args_list[0].args[0].id,
|
||||
version.id,
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_backend.add_or_update.call_args_list[1].args[0].id,
|
||||
root.id,
|
||||
)
|
||||
self.assertEqual(
|
||||
mock_backend.add_or_update.call_args_list[1].kwargs,
|
||||
{"effective_content": version.content},
|
||||
)
|
||||
|
||||
@@ -157,11 +157,17 @@ class DirectoriesMixin:
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
from documents.search import reset_backend
|
||||
|
||||
reset_backend()
|
||||
self.dirs = setup_directories()
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
from documents.search import reset_backend
|
||||
|
||||
super().tearDown()
|
||||
reset_backend()
|
||||
remove_dirs(self.dirs)
|
||||
|
||||
|
||||
|
||||
+112
-100
@@ -100,7 +100,6 @@ from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from rest_framework.viewsets import ViewSet
|
||||
|
||||
from documents import bulk_edit
|
||||
from documents import index
|
||||
from documents.bulk_download import ArchiveOnlyStrategy
|
||||
from documents.bulk_download import OriginalAndArchiveStrategy
|
||||
from documents.bulk_download import OriginalsOnlyStrategy
|
||||
@@ -972,9 +971,9 @@ class DocumentViewSet(
|
||||
response_data["content"] = content_doc.content
|
||||
response = Response(response_data)
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.add_or_update_document(refreshed_doc)
|
||||
get_backend().add_or_update(refreshed_doc)
|
||||
|
||||
document_updated.send(
|
||||
sender=self.__class__,
|
||||
@@ -984,9 +983,9 @@ class DocumentViewSet(
|
||||
return response
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.remove_document_from_index(self.get_object())
|
||||
get_backend().remove(self.get_object().pk)
|
||||
try:
|
||||
return super().destroy(request, *args, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -1393,9 +1392,9 @@ class DocumentViewSet(
|
||||
doc.modified = timezone.now()
|
||||
doc.save()
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.add_or_update_document(doc)
|
||||
get_backend().add_or_update(doc)
|
||||
|
||||
notes = serializer.to_representation(doc).get("notes")
|
||||
|
||||
@@ -1430,9 +1429,9 @@ class DocumentViewSet(
|
||||
doc.modified = timezone.now()
|
||||
doc.save()
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.add_or_update_document(doc)
|
||||
get_backend().add_or_update(doc)
|
||||
|
||||
notes = serializer.to_representation(doc).get("notes")
|
||||
|
||||
@@ -1744,12 +1743,13 @@ class DocumentViewSet(
|
||||
"Cannot delete the root/original version. Delete the document instead.",
|
||||
)
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
index.remove_document_from_index(version_doc)
|
||||
_backend = get_backend()
|
||||
_backend.remove(version_doc.pk)
|
||||
version_doc_id = version_doc.id
|
||||
version_doc.delete()
|
||||
index.add_or_update_document(root_doc)
|
||||
_backend.add_or_update(root_doc)
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
actor = (
|
||||
request.user if request.user and request.user.is_authenticated else None
|
||||
@@ -1949,10 +1949,6 @@ class ChatStreamingView(GenericAPIView):
|
||||
),
|
||||
)
|
||||
class UnifiedSearchViewSet(DocumentViewSet):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.searcher = None
|
||||
|
||||
def get_serializer_class(self):
|
||||
if self._is_search_request():
|
||||
return SearchResultSerializer
|
||||
@@ -1965,17 +1961,34 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
||||
or "more_like_id" in self.request.query_params
|
||||
)
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
filtered_queryset = super().filter_queryset(queryset)
|
||||
def list(self, request, *args, **kwargs):
|
||||
if not self._is_search_request():
|
||||
return super().list(request)
|
||||
|
||||
if self._is_search_request():
|
||||
if "query" in self.request.query_params:
|
||||
from documents import index
|
||||
from documents.search import TantivyRelevanceList
|
||||
from documents.search import get_backend
|
||||
|
||||
query_class = index.DelayedFullTextQuery
|
||||
elif "more_like_id" in self.request.query_params:
|
||||
try:
|
||||
backend = get_backend()
|
||||
# ORM-filtered queryset: permissions + field filters + ordering (DRF backends applied)
|
||||
filtered_qs = self.filter_queryset(self.get_queryset())
|
||||
|
||||
user = None if request.user.is_superuser else request.user
|
||||
|
||||
if "query" in request.query_params:
|
||||
query_str = request.query_params["query"]
|
||||
results = backend.search(
|
||||
query_str,
|
||||
user=user,
|
||||
page=1,
|
||||
page_size=10000,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
else:
|
||||
# more_like_id — validate permission on the seed document first
|
||||
try:
|
||||
more_like_doc_id = int(self.request.query_params["more_like_id"])
|
||||
more_like_doc_id = int(request.query_params["more_like_id"])
|
||||
more_like_doc = Document.objects.select_related("owner").get(
|
||||
pk=more_like_doc_id,
|
||||
)
|
||||
@@ -1983,61 +1996,62 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
||||
raise PermissionDenied(_("Invalid more_like_id"))
|
||||
|
||||
if not has_perms_owner_aware(
|
||||
self.request.user,
|
||||
request.user,
|
||||
"view_document",
|
||||
more_like_doc,
|
||||
):
|
||||
raise PermissionDenied(_("Insufficient permissions."))
|
||||
|
||||
from documents import index
|
||||
|
||||
query_class = index.DelayedMoreLikeThisQuery
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return query_class(
|
||||
self.searcher,
|
||||
self.request.query_params,
|
||||
self.paginator.get_page_size(self.request),
|
||||
filter_queryset=filtered_queryset,
|
||||
)
|
||||
else:
|
||||
return filtered_queryset
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
if self._is_search_request():
|
||||
from documents import index
|
||||
|
||||
try:
|
||||
with index.open_index_searcher() as s:
|
||||
self.searcher = s
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
page = self.paginate_queryset(queryset)
|
||||
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
response = self.get_paginated_response(serializer.data)
|
||||
|
||||
response.data["corrected_query"] = (
|
||||
queryset.suggested_correction
|
||||
if hasattr(queryset, "suggested_correction")
|
||||
else None
|
||||
)
|
||||
|
||||
return response
|
||||
except NotFound:
|
||||
raise
|
||||
except PermissionDenied as e:
|
||||
invalid_more_like_id_message = _("Invalid more_like_id")
|
||||
if str(e.detail) == str(invalid_more_like_id_message):
|
||||
return HttpResponseForbidden(invalid_more_like_id_message)
|
||||
return HttpResponseForbidden(_("Insufficient permissions."))
|
||||
except Exception as e:
|
||||
logger.warning(f"An error occurred listing search results: {e!s}")
|
||||
return HttpResponseBadRequest(
|
||||
"Error listing search results, check logs for more detail.",
|
||||
results = backend.more_like_this(
|
||||
more_like_doc_id,
|
||||
user=user,
|
||||
page=1,
|
||||
page_size=10000,
|
||||
)
|
||||
else:
|
||||
return super().list(request)
|
||||
|
||||
hits_by_id = {h["id"]: h for h in results.hits}
|
||||
|
||||
# Determine sort order: no ordering param → Tantivy relevance; otherwise → ORM order
|
||||
ordering_param = request.query_params.get("ordering", "").lstrip("-")
|
||||
if not ordering_param:
|
||||
# Preserve Tantivy relevance order; intersect with ORM-visible IDs
|
||||
orm_ids = set(filtered_qs.values_list("pk", flat=True))
|
||||
ordered_hits = [h for h in results.hits if h["id"] in orm_ids]
|
||||
else:
|
||||
# Use ORM ordering (already applied by DocumentsOrderingFilter)
|
||||
hit_ids = set(hits_by_id.keys())
|
||||
orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list(
|
||||
"pk",
|
||||
flat=True,
|
||||
)
|
||||
ordered_hits = [
|
||||
hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id
|
||||
]
|
||||
|
||||
rl = TantivyRelevanceList(ordered_hits)
|
||||
page = self.paginate_queryset(rl)
|
||||
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
response = self.get_paginated_response(serializer.data)
|
||||
response.data["corrected_query"] = None
|
||||
return response
|
||||
|
||||
serializer = self.get_serializer(ordered_hits, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
except NotFound:
|
||||
raise
|
||||
except PermissionDenied as e:
|
||||
invalid_more_like_id_message = _("Invalid more_like_id")
|
||||
if str(e.detail) == str(invalid_more_like_id_message):
|
||||
return HttpResponseForbidden(invalid_more_like_id_message)
|
||||
return HttpResponseForbidden(_("Insufficient permissions."))
|
||||
except Exception as e:
|
||||
logger.warning(f"An error occurred listing search results: {e!s}")
|
||||
return HttpResponseBadRequest(
|
||||
"Error listing search results, check logs for more detail.",
|
||||
)
|
||||
|
||||
@action(detail=False, methods=["GET"], name="Get Next ASN")
|
||||
def next_asn(self, request, *args, **kwargs):
|
||||
@@ -2816,18 +2830,9 @@ class SearchAutoCompleteView(GenericAPIView):
|
||||
else:
|
||||
limit = 10
|
||||
|
||||
from documents import index
|
||||
from documents.search import get_backend
|
||||
|
||||
ix = index.open_index()
|
||||
|
||||
return Response(
|
||||
index.autocomplete(
|
||||
ix,
|
||||
term,
|
||||
limit,
|
||||
user,
|
||||
),
|
||||
)
|
||||
return Response(get_backend().autocomplete(term, limit, user))
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
@@ -2893,20 +2898,21 @@ class GlobalSearchView(PassUserMixin):
|
||||
# First search by title
|
||||
docs = all_docs.filter(title__icontains=query)
|
||||
if not db_only and len(docs) < OBJECT_LIMIT:
|
||||
# If we don't have enough results, search by content
|
||||
from documents import index
|
||||
# If we don't have enough results, search by content.
|
||||
# Over-fetch from Tantivy (no permission filter) and rely on
|
||||
# the ORM all_docs queryset for authoritative permission gating.
|
||||
from documents.search import get_backend
|
||||
|
||||
with index.open_index_searcher() as s:
|
||||
fts_query = index.DelayedFullTextQuery(
|
||||
s,
|
||||
request.query_params,
|
||||
OBJECT_LIMIT,
|
||||
filter_queryset=all_docs,
|
||||
)
|
||||
results = fts_query[0:1]
|
||||
docs = docs | Document.objects.filter(
|
||||
id__in=[r["id"] for r in results],
|
||||
)
|
||||
fts_results = get_backend().search(
|
||||
query,
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=1000,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
fts_ids = {h["id"] for h in fts_results.hits}
|
||||
docs = docs | all_docs.filter(id__in=fts_ids)
|
||||
docs = docs[:OBJECT_LIMIT]
|
||||
saved_views = (
|
||||
get_objects_for_user_owner_aware(
|
||||
@@ -4105,10 +4111,16 @@ class SystemStatusView(PassUserMixin):
|
||||
|
||||
index_error = None
|
||||
try:
|
||||
ix = index.open_index()
|
||||
from documents.search import get_backend
|
||||
|
||||
get_backend() # triggers open/rebuild; raises on error
|
||||
index_status = "OK"
|
||||
index_last_modified = make_aware(
|
||||
datetime.fromtimestamp(ix.last_modified()),
|
||||
# Use the most-recently modified file in the index directory as a proxy
|
||||
# for last index write time (Tantivy has no single last_modified() call).
|
||||
index_dir = settings.INDEX_DIR
|
||||
mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()]
|
||||
index_last_modified = (
|
||||
make_aware(datetime.fromtimestamp(max(mtimes))) if mtimes else None
|
||||
)
|
||||
except Exception as e:
|
||||
index_status = "ERROR"
|
||||
|
||||
+4
-14
@@ -71,22 +71,12 @@ class StandardPagination(PageNumberPagination):
|
||||
)
|
||||
|
||||
def get_all_result_ids(self):
|
||||
from documents.index import DelayedQuery # removed with Whoosh in Task 14
|
||||
from documents.search import TantivyRelevanceList
|
||||
|
||||
query = self.page.paginator.object_list
|
||||
if isinstance(query, DelayedQuery):
|
||||
try:
|
||||
ids = [
|
||||
query.searcher.ixreader.stored_fields(
|
||||
doc_num,
|
||||
)["id"]
|
||||
for doc_num in query.saved_results.get(0).results.docs()
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
ids = self.page.paginator.object_list.values_list("pk", flat=True)
|
||||
return ids
|
||||
if isinstance(query, TantivyRelevanceList):
|
||||
return [h["id"] for h in query._hits]
|
||||
return self.page.paginator.object_list.values_list("pk", flat=True)
|
||||
|
||||
def get_paginated_response_schema(self, schema):
|
||||
response_schema = super().get_paginated_response_schema(schema)
|
||||
|
||||
Reference in New Issue
Block a user