Merge branch 'beta' into dev

This commit is contained in:
shamoon
2026-06-02 08:32:54 -07:00
81 changed files with 3722 additions and 634 deletions
+15
View File
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING
import filelock
import pytest
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from guardian.shortcuts import clear_ct_cache
from pytest_django.fixtures import SettingsWrapper
from rest_framework.test import APIClient
@@ -158,6 +160,19 @@ def user_client(rest_api_client: APIClient, regular_user: UserModelT) -> APIClie
return rest_api_client
@pytest.fixture(autouse=True)
def _clear_content_type_caches() -> None:
"""Clear Django's ContentType cache and guardian's lru_cache before each test.
Tests that delete and reinsert ContentType/Permission rows (e.g. the
importer) corrupt both caches. Without this fixture a subsequent test on
the same xdist worker sees stale ContentType objects and guardian raises
MixedContentTypeError.
"""
ContentType.objects.clear_cache()
clear_ct_cache()
@pytest.fixture(scope="session", autouse=True)
def faker_session_locale():
"""Set Faker locale for reproducibility."""
+206
View File
@@ -1,5 +1,6 @@
import pytest
from django.contrib.auth.models import User
from pytest_mock import MockerFixture
from documents.models import CustomField
from documents.models import CustomFieldInstance
@@ -7,8 +8,13 @@ from documents.models import Document
from documents.models import Note
from documents.search._backend import SearchMode
from documents.search._backend import TantivyBackend
from documents.search._backend import WriteBatch
from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import DocumentTypeFactory
from documents.tests.factories import TagFactory
pytestmark = [pytest.mark.search, pytest.mark.django_db]
@@ -36,6 +42,47 @@ class TestWriteBatch:
ids = backend.search_ids("should survive", user=None)
assert len(ids) == 1
def test_writer_released_when_commit_fails(
self,
backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""A commit failure must still dispose the writer (released in finally).
Otherwise the Tantivy IndexWriter lingers holding its internal lock and
the next batch fails with LockBusy. The real writer is created in
__enter__; here commit() is forced to raise via a mocked _writer.
"""
doc = Document.objects.create(
title="Commit Fail",
content="indexable text",
checksum="WBCF1",
pk=42,
)
failing = mocker.MagicMock()
failing.commit.side_effect = RuntimeError("simulated commit failure")
mocker.patch.object(
WriteBatch,
"_writer",
new_callable=mocker.PropertyMock,
return_value=failing,
)
batch = backend.batch_update()
with pytest.raises(RuntimeError, match="simulated commit failure"):
with batch as b:
b.add_or_update(doc)
# Writer disposed despite the commit failure.
assert batch._raw_writer is None
# Drop the patch so a real writer can be created; a fresh batch must
# succeed (would raise LockBusy if the previous writer had leaked).
mocker.stopall()
backend.add_or_update(doc)
assert len(backend.search_ids("indexable", user=None)) == 1
class TestSearch:
"""Test search query parsing and matching via search_ids."""
@@ -214,6 +261,153 @@ class TestSearch:
== 1
)
@pytest.mark.parametrize(
("mode", "title", "content", "hits", "misses"),
[
pytest.param(
SearchMode.QUERY,
"CJK document",
"東京都の人口は約1400万人です",
["東京", "人口"],
["大阪"],
id="query_mode_cjk_content",
),
pytest.param(
SearchMode.TEXT,
"CJK document",
"東京都の人口は約1400万人です",
["東京"],
["大阪"],
id="text_mode_cjk_content",
),
pytest.param(
SearchMode.TITLE,
"東京都の報告書",
"This document is about Tokyo.",
["東京", "報告"],
["大阪"],
id="title_mode_cjk_title",
),
],
)
def test_cjk_search_finds_matching_documents(
self,
backend: TantivyBackend,
mode: SearchMode,
title: str,
content: str,
hits: list[str],
misses: list[str],
) -> None:
"""CJK queries must match documents via bigram fields in all three search modes."""
doc = DocumentFactory(title=title, content=content)
backend.add_or_update(doc)
for query in hits:
assert len(backend.search_ids(query, user=None, search_mode=mode)) == 1, (
f"Expected {query!r} to match in {mode} mode"
)
for query in misses:
assert len(backend.search_ids(query, user=None, search_mode=mode)) == 0, (
f"Expected {query!r} not to match in {mode} mode"
)
def test_title_mode_cjk_does_not_match_content_only(
self,
backend: TantivyBackend,
) -> None:
"""Title-only CJK search must not return docs where CJK appears only in content."""
doc = DocumentFactory(
title="Tokyo report",
content="東京都の人口は約1400万人です",
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("東京", user=None, search_mode=SearchMode.TITLE))
== 0
)
@pytest.mark.parametrize(
("field", "query", "miss"),
[
pytest.param("correspondent", "東京", "大阪", id="cjk_correspondent"),
pytest.param("document_type", "請求書", "領収書", id="cjk_document_type"),
pytest.param("tag", "重要", "普通", id="cjk_tag"),
],
)
def test_cjk_metadata_search_via_query_mode(
self,
backend: TantivyBackend,
field: str,
query: str,
miss: str,
) -> None:
"""CJK in correspondent/document_type/tag names must be searchable via global search."""
if field == "correspondent":
doc = DocumentFactory(correspondent=CorrespondentFactory(name=query))
elif field == "document_type":
doc = DocumentFactory(document_type=DocumentTypeFactory(name=query))
else:
tag = TagFactory(name=query)
doc = DocumentFactory()
doc.tags.add(tag)
backend.add_or_update(doc)
assert (
len(backend.search_ids(query, user=None, search_mode=SearchMode.QUERY)) == 1
), f"Expected CJK {field} name {query!r} to match"
assert (
len(backend.search_ids(miss, user=None, search_mode=SearchMode.QUERY)) == 0
), f"Expected {miss!r} not to match"
def test_cjk_text_mode_does_not_leak_field_query_semantics(
self,
backend: TantivyBackend,
) -> None:
"""TEXT mode is plain-text over content: a 'field:CJK' input must not be
parsed as a structured query against that field. A doc tagged 重要 with
no 重要 in its content must NOT match the TEXT-mode query 'tag:重要'."""
tag = TagFactory(name="重要")
doc = DocumentFactory(title="report", content="just english content")
doc.tags.add(tag)
backend.add_or_update(doc)
assert (
len(backend.search_ids("tag:重要", user=None, search_mode=SearchMode.TEXT))
== 0
)
# Sanity: the CJK run still matches when it is actually in the content.
doc2 = DocumentFactory(title="report2", content="本文に重要な情報")
backend.add_or_update(doc2)
assert (
len(backend.search_ids("tag:重要", user=None, search_mode=SearchMode.TEXT))
== 1
)
@pytest.mark.parametrize(
"query",
[
pytest.param("Straße", id="eszett"),
pytest.param("Ærøskøbing", id="ae_and_oslash"),
pytest.param("strasse", id="ascii_fold_form"),
],
)
def test_simple_search_folds_special_letters_like_index(
self,
backend: TantivyBackend,
query: str,
) -> None:
"""Query-side folding must match index-side folding for non-decomposable
letters (ß→ss, ø→o, ...). Searching the accented form must find the doc.
A naive NFD fold deletes these letters and silently fails to match."""
doc = DocumentFactory(title="report", content="Straße Ærøskøbing")
backend.add_or_update(doc)
assert (
len(backend.search_ids(query, user=None, search_mode=SearchMode.TEXT)) == 1
)
def test_sort_field_ascending(self, backend: TantivyBackend) -> None:
"""Searching with sort_reverse=False must return results in ascending ASN order."""
for asn in [30, 10, 20]:
@@ -393,6 +587,18 @@ class TestAutocomplete:
results = backend.autocomplete("pay", limit=10)
assert results.index("payment") < results.index("payslip")
def test_folds_special_letters_consistently(
self,
backend: TantivyBackend,
) -> None:
"""Autocomplete words must fold the same way as content (ß→ss), so a
prefix of the folded form finds them. A naive NFD fold would store the
word as 'strae' and the prefix 'stras' would never match it."""
doc = DocumentFactory(title="Straße", content="details")
backend.add_or_update(doc)
assert "strasse" in backend.autocomplete("stras", limit=10)
class TestMoreLikeThis:
"""Test more like this functionality."""
@@ -0,0 +1,248 @@
"""Tests for search index lock backoff, retry logic, and self-healing deferred tasks."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import filelock
import pytest
from documents.search._backend import _LOCK_BACKOFF_CAP
from documents.search._backend import _LOCK_RETRY_ATTEMPTS
from documents.search._backend import _LOCK_TIMEOUT_SECONDS
from documents.search._backend import SearchIndexLockError
from documents.search._backend import TantivyBackend
from documents.tasks import index_document
from documents.tasks import remove_document_from_index
from documents.tests.factories import DocumentFactory
if TYPE_CHECKING:
from collections.abc import Generator
from pathlib import Path
from pytest_mock import MockerFixture
pytestmark = pytest.mark.search
@pytest.fixture
def disk_backend(tmp_path: Path) -> Generator[TantivyBackend, None, None]:
"""On-disk TantivyBackend so the file-lock code path is exercised."""
b = TantivyBackend(path=tmp_path)
b.open()
try:
yield b
finally:
b.close()
class TestWriteBatchLockRetry:
"""Test WriteBatch retry loop with backoff + full jitter."""
@pytest.mark.django_db
def test_lock_retries_then_succeeds(
self,
disk_backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""Timeout on first 3 attempts then success on 4th — document must be indexed."""
doc = DocumentFactory()
acquire_calls = 0
def flaky_acquire(timeout: float) -> None:
nonlocal acquire_calls
acquire_calls += 1
# Raise Timeout for first _LOCK_RETRY_ATTEMPTS - 1 calls, succeed on last
if acquire_calls < _LOCK_RETRY_ATTEMPTS:
raise filelock.Timeout("")
sleep_values: list[float] = []
mocker.patch(
"documents.search._backend.filelock.FileLock.acquire",
side_effect=flaky_acquire,
)
mock_sleep = mocker.patch(
"documents.search._backend.time.sleep",
side_effect=lambda s: sleep_values.append(s),
)
# Should not raise — 4th attempt succeeds
with disk_backend.batch_update(lock_timeout=_LOCK_TIMEOUT_SECONDS) as batch:
batch.add_or_update(doc)
# sleep called exactly _LOCK_RETRY_ATTEMPTS - 1 times (once per failed attempt)
assert mock_sleep.call_count == _LOCK_RETRY_ATTEMPTS - 1
# All sleep values must be in [0, _LOCK_BACKOFF_CAP]
for s in sleep_values:
assert 0 <= s <= _LOCK_BACKOFF_CAP, (
f"Sleep value {s} outside [0, {_LOCK_BACKOFF_CAP}]"
)
def test_lock_exhaustion_raises_search_index_lock_error(
self,
disk_backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""All acquire attempts raise Timeout — WriteBatch must raise SearchIndexLockError."""
mocker.patch(
"documents.search._backend.filelock.FileLock.acquire",
side_effect=filelock.Timeout(""),
)
mocker.patch("documents.search._backend.time.sleep")
with pytest.raises(SearchIndexLockError):
with disk_backend.batch_update(lock_timeout=_LOCK_TIMEOUT_SECONDS):
pass
def test_jitter_values_in_range(
self,
disk_backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""Sleep values must always lie in [0, _LOCK_BACKOFF_CAP] across many samples."""
mocker.patch(
"documents.search._backend.filelock.FileLock.acquire",
side_effect=filelock.Timeout(""),
)
sleep_values: list[float] = []
mocker.patch(
"documents.search._backend.time.sleep",
side_effect=lambda s: sleep_values.append(s),
)
for _ in range(50):
sleep_values.clear()
with pytest.raises(SearchIndexLockError):
with disk_backend.batch_update(lock_timeout=_LOCK_TIMEOUT_SECONDS):
pass
for s in sleep_values:
assert 0 <= s <= _LOCK_BACKOFF_CAP, (
f"Jitter {s} exceeds cap {_LOCK_BACKOFF_CAP}"
)
class TestAddOrUpdateDeferredScheduling:
"""Test that add_or_update() and remove() defer to Celery on lock exhaustion."""
@pytest.mark.django_db
def test_lock_exhaustion_schedules_deferred_task(
self,
disk_backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""Lock exhaustion in add_or_update must schedule index_document task, not raise."""
doc = DocumentFactory()
mocker.patch(
"documents.search._backend.filelock.FileLock.acquire",
side_effect=filelock.Timeout(""),
)
mocker.patch("documents.search._backend.time.sleep")
mock_apply = mocker.patch("documents.tasks.index_document.apply_async")
# Must NOT raise
disk_backend.add_or_update(doc)
mock_apply.assert_called_once_with(args=[doc.pk], countdown=60)
def test_remove_exhaustion_schedules_deferred_task(
self,
disk_backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""Lock exhaustion in remove() must schedule remove_document_from_index task, not raise."""
doc_id = 503
mocker.patch(
"documents.search._backend.filelock.FileLock.acquire",
side_effect=filelock.Timeout(""),
)
mocker.patch("documents.search._backend.time.sleep")
mock_apply = mocker.patch(
"documents.tasks.remove_document_from_index.apply_async",
)
# Must NOT raise
disk_backend.remove(doc_id)
mock_apply.assert_called_once_with(args=[doc_id], countdown=60)
@pytest.mark.django_db
class TestIndexDocumentTask:
"""Test the deferred index_document and remove_document_from_index Celery tasks."""
def test_index_document_task_skips_deleted_document(
self,
caplog: pytest.LogCaptureFixture,
) -> None:
"""index_document with a non-existent doc_id must return cleanly and log INFO."""
nonexistent_id = 999999
with caplog.at_level(logging.INFO, logger="paperless.tasks"):
index_document(nonexistent_id)
assert any("no longer exists" in record.message for record in caplog.records), (
"Expected INFO log about missing document"
)
def test_index_document_task_indexes_existing_document(
self,
backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""index_document task must add the document to the index via batch_update."""
doc = DocumentFactory(content="via deferred task")
# get_backend is imported lazily inside the task: `from documents.search import get_backend`
mocker.patch(
"documents.search.get_backend",
return_value=backend,
)
index_document(doc.pk)
ids = backend.search_ids("deferred task", user=None)
assert doc.pk in ids
def test_remove_document_from_index_task_removes_existing_document(
self,
backend: TantivyBackend,
mocker: MockerFixture,
) -> None:
"""remove_document_from_index task must remove the document from the index."""
doc = DocumentFactory(content="will be removed by deferred task")
backend.add_or_update(doc)
assert doc.pk in backend.search_ids("removed", user=None)
mocker.patch("documents.search.get_backend", return_value=backend)
remove_document_from_index(doc.pk)
assert doc.pk not in backend.search_ids("removed", user=None)
def test_task_does_not_swallow_lock_error(
self,
mocker: MockerFixture,
) -> None:
"""Verifies the task body propagates SearchIndexLockError so Celery's
autoretry_for can catch it (rather than the task swallowing the error
and silently succeeding)."""
doc = DocumentFactory()
mock_batch = mocker.MagicMock()
mock_batch.__enter__ = mocker.MagicMock(
side_effect=SearchIndexLockError("exhausted"),
)
mock_batch.__exit__ = mocker.MagicMock(return_value=False)
mock_backend = mocker.MagicMock()
mock_backend.batch_update.return_value = mock_batch
# get_backend is imported lazily inside the task: `from documents.search import get_backend`
mocker.patch("documents.search.get_backend", return_value=mock_backend)
with pytest.raises(SearchIndexLockError):
index_document(doc.pk)
+243 -1
View File
@@ -16,6 +16,7 @@ from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
@@ -443,6 +444,149 @@ class TestParseUserQuery:
q = parse_user_query(query_index, "created:today", UTC)
assert isinstance(q, tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
pytest.param("h52.1 - kurzsichtigkeit", id="icd_code_dash_description"),
pytest.param("H52.1 - asd", id="icd_code_uppercase"),
pytest.param("h52.1 -", id="trailing_minus"),
pytest.param(". -", id="dot_trailing_minus"),
pytest.param("h52. -", id="partial_code_trailing_minus"),
pytest.param(".12 -", id="dot_number_trailing_minus"),
pytest.param("h52.1 - ku", id="partial_word_after_dash"),
],
)
def test_spaced_dash_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@pytest.mark.parametrize(
("query", "field", "expected_lo", "expected_hi"),
[
pytest.param(
"created:[2020 TO 2020]",
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
id="single_year_created",
),
pytest.param(
"created:[2018 TO 2021]",
"created",
"2018-01-01T00:00:00Z",
"2022-01-01T00:00:00Z",
id="multi_year_range_created",
),
pytest.param(
"added:[2022 TO 2023]",
"added",
"2022-01-01T00:00:00Z",
"2024-01-01T00:00:00Z",
id="added_field",
),
pytest.param(
"modified:[2021 TO 2021]",
"modified",
"2021-01-01T00:00:00Z",
"2022-01-01T00:00:00Z",
id="modified_field",
),
pytest.param(
"created:[2020 to 2020]",
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
id="lowercase_to_keyword",
),
],
)
def test_year_range_rewritten(
self,
query: str,
field: str,
expected_lo: str,
expected_hi: str,
) -> None:
result = rewrite_natural_date_keywords(query, UTC)
lo, hi = _range(result, field)
assert lo == expected_lo
assert hi == expected_hi
def test_reversed_year_range_is_swapped(self) -> None:
# A reversed range must not yield lo > hi, which Tantivy treats as an
# empty range (silently zero results). The bounds are swapped instead.
result = rewrite_natural_date_keywords("created:[2025 TO 2020]", UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2026-01-01T00:00:00Z"
def test_year_range_in_complex_boolean_query(self) -> None:
query = "tag:steuer AND (title:2020 OR (NOT title:2019 AND NOT title:2018 AND created:[2020 TO 2020]))"
result = rewrite_natural_date_keywords(query, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
assert "title:2020" in result
assert "title:2019" in result
assert "title:2018" in result
def test_already_iso_date_range_passes_through_unchanged(self) -> None:
original = "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
assert rewrite_natural_date_keywords(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
class TestNonDateFieldsNotRewritten:
"""Date rewriters must only fire on the date fields (created/modified/added).
Integer fields like asn/id/page_count and unknown fields would otherwise be
rewritten into date ranges and rejected by Tantivy as type mismatches.
"""
@pytest.mark.parametrize(
"query",
[
pytest.param("asn:20240101", id="asn_8digit"),
pytest.param("id:20240101", id="id_8digit"),
pytest.param("page_count:12345678", id="page_count_8digit"),
pytest.param("num_notes:20231201", id="num_notes_8digit"),
],
)
def test_8digit_on_integer_field_passes_through_unchanged(self, query: str) -> None:
assert rewrite_natural_date_keywords(query, EASTERN) == query
@pytest.mark.parametrize(
"query",
[
pytest.param("asn:[2000 TO 2024]", id="asn_year_range"),
pytest.param("id:[2000 TO 2024]", id="id_year_range"),
pytest.param("page_count:[2000 TO 2024]", id="page_count_year_range"),
],
)
def test_year_range_on_integer_field_passes_through_unchanged(
self,
query: str,
) -> None:
assert rewrite_natural_date_keywords(query, UTC) == query
def test_unknown_field_keyword_passes_through_unchanged(self) -> None:
# foobar is not a date field: 'foobar:today' must not become a date range,
# which Tantivy would otherwise reject as an unknown/typed field.
assert rewrite_natural_date_keywords("foobar:today", UTC) == "foobar:today"
class TestPassthrough:
"""Queries without field prefixes or unrelated content pass through unchanged."""
@@ -471,10 +615,108 @@ class TestNormalizeQuery:
def test_normalize_no_commas_unchanged(self) -> None:
assert normalize_query("bank statement") == "bank statement"
@pytest.mark.parametrize(
("raw", "expected"),
[
pytest.param(
"h52.1 - kurzsichtigkeit",
"h52.1 kurzsichtigkeit",
id="icd_code_dash_description",
),
pytest.param(
"H52.1 - asd",
"H52.1 asd",
id="icd_code_uppercase_dash",
),
pytest.param(
"h52.1 -",
"h52.1",
id="trailing_minus",
),
pytest.param(
". -",
".",
id="dot_trailing_minus",
),
pytest.param(
"h52. -",
"h52.",
id="partial_code_trailing_minus",
),
pytest.param(
"foo - bar - baz",
"foo bar baz",
id="multiple_dashes",
),
pytest.param(
"foo + bar",
"foo bar",
id="spaced_plus_operator",
),
],
)
def test_normalize_strips_dangling_operators(self, raw: str, expected: str) -> None:
assert normalize_query(raw) == expected
@pytest.mark.parametrize(
"query",
[
pytest.param("term -other", id="adjacent_not_operator"),
pytest.param("-term", id="leading_not_operator"),
pytest.param("+term", id="leading_must_operator"),
pytest.param("foo -bar +baz", id="mixed_adjacent_operators"),
],
)
def test_normalize_preserves_valid_operators(self, query: str) -> None:
assert normalize_query(query) == query
class TestParseSimpleTextHighlightQuery:
"""parse_simple_text_highlight_query must not raise on natural-language queries."""
@pytest.fixture
def query_index(self) -> tantivy.Index:
schema = build_schema()
idx = tantivy.Index(schema, path=None)
register_tokenizers(idx, "")
return idx
@pytest.mark.parametrize(
"raw_query",
[
pytest.param("h52.1 - kurzsichtigkeit", id="icd_code_dash_description"),
pytest.param("H52.1 - asd", id="icd_code_uppercase"),
pytest.param("h52.1 -", id="trailing_minus"),
pytest.param(". -", id="dot_trailing_minus"),
pytest.param(".12 -", id="dot_number_trailing_minus"),
pytest.param("f84.0 - v.a. autismusspektrumstorung", id="complex_icd_dash"),
],
)
def test_spaced_dash_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
assert isinstance(
parse_simple_text_highlight_query(query_index, raw_query),
tantivy.Query,
)
def test_empty_query_returns_empty_query(self, query_index: tantivy.Index) -> None:
result = parse_simple_text_highlight_query(query_index, "")
assert isinstance(result, tantivy.Query)
def test_all_operators_returns_empty_query(
self,
query_index: tantivy.Index,
) -> None:
result = parse_simple_text_highlight_query(query_index, "- +")
assert isinstance(result, tantivy.Query)
class TestPermissionFilter:
"""
build_permission_filter tests use an in-memory index no DB access needed.
build_permission_filter tests use an in-memory index - no DB access needed.
Users are constructed as unsaved model instances (django_user_model(pk=N))
so no database round-trip occurs; only .pk is read by build_permission_filter.
+105 -1
View File
@@ -74,6 +74,9 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"ai_enabled": False,
"llm_embedding_backend": None,
"llm_embedding_model": None,
"llm_embedding_endpoint": None,
"llm_embedding_chunk_size": None,
"llm_context_size": None,
"llm_backend": None,
"llm_model": None,
"llm_api_key": None,
@@ -840,7 +843,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
@@ -855,6 +858,91 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
mock_update.assert_called_once()
def test_update_llm_embedding_chunk_size_triggers_rebuild(self) -> None:
config = ApplicationConfiguration.objects.first()
assert config is not None
config.ai_enabled = True
config.llm_embedding_backend = "openai-like"
config.llm_embedding_chunk_size = 1024
config.save()
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps({"llm_embedding_chunk_size": 512}),
content_type="application/json",
)
mock_update.assert_called_once()
self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True})
def test_update_llm_context_size_triggers_rebuild(self) -> None:
config = ApplicationConfiguration.objects.first()
assert config is not None
config.ai_enabled = True
config.llm_embedding_backend = "openai-like"
config.llm_context_size = 8192
config.save()
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps({"llm_context_size": 4096}),
content_type="application/json",
)
mock_update.assert_called_once()
self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True})
def test_update_llm_embedding_model_triggers_rebuild(self) -> None:
config = ApplicationConfiguration.objects.first()
assert config is not None
config.ai_enabled = True
config.llm_embedding_backend = "openai-like"
config.llm_embedding_model = "text-embedding-3-small"
config.save()
with patch("documents.tasks.llmindex_index.apply_async") as mock_update:
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps({"llm_embedding_model": "text-embedding-3-large"}),
content_type="application/json",
)
mock_update.assert_called_once()
self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True})
def test_enable_ai_index_with_config_change_triggers_rebuild(self) -> None:
config = ApplicationConfiguration.objects.first()
assert config is not None
config.ai_enabled = False
config.llm_embedding_backend = "openai-like"
config.llm_embedding_model = "text-embedding-3-small"
config.save()
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"ai_enabled": True,
"llm_embedding_model": "text-embedding-3-large",
},
),
content_type="application/json",
)
mock_update.assert_called_once()
self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True})
@override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False)
def test_update_llm_endpoint_blocks_internal_endpoint_when_disallowed(self) -> None:
response = self.client.patch(
@@ -868,3 +956,19 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("non-public address", str(response.data).lower())
@override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False)
def test_update_llm_embedding_endpoint_blocks_internal_endpoint_when_disallowed(
self,
) -> None:
response = self.client.patch(
f"{self.ENDPOINT}1/",
json.dumps(
{
"llm_embedding_endpoint": "http://127.0.0.1:11434",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("non-public address", str(response.data).lower())
+44
View File
@@ -0,0 +1,44 @@
from __future__ import annotations
from unittest import mock
from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.test import APITestCase
class TestChatStreamingViewInputValidation(APITestCase):
def setUp(self) -> None:
super().setUp()
self.user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=self.user)
def _mock_ai_enabled(self) -> mock.MagicMock:
"""Return a mock AIConfig instance with ai_enabled=True."""
m = mock.MagicMock()
m.ai_enabled = True
return m
def test_oversized_question_is_rejected(self) -> None:
with mock.patch(
"documents.views.AIConfig",
return_value=self._mock_ai_enabled(),
):
resp = self.client.post(
"/api/documents/chat/",
{"q": "x" * 4001},
format="json",
)
assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_missing_question_is_rejected(self) -> None:
with mock.patch(
"documents.views.AIConfig",
return_value=self._mock_ai_enabled(),
):
resp = self.client.post(
"/api/documents/chat/",
{},
format="json",
)
assert resp.status_code == status.HTTP_400_BAD_REQUEST
@@ -464,6 +464,40 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(read_streaming_response(resp), b"thumb")
def test_thumb_etag_changes_when_latest_version_is_deleted(self) -> None:
root = self._create_pdf(title="root", checksum="root")
v1 = self._create_pdf(
title="v1",
checksum="v1",
root_document=root,
)
v2 = self._create_pdf(
title="v2",
checksum="v2",
root_document=root,
)
self._write_file(v1.thumbnail_path, b"thumb-v1")
self._write_file(v2.thumbnail_path, b"thumb-v2")
resp = self.client.get(f"/api/documents/{root.id}/thumb/")
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(read_streaming_response(resp), b"thumb-v2")
self.assertEqual(resp.headers["ETag"], '"v2"')
with mock.patch("documents.search.get_backend"):
delete_resp = self.client.delete(
f"/api/documents/{root.id}/versions/{v2.id}/",
)
self.assertEqual(delete_resp.status_code, status.HTTP_200_OK)
resp = self.client.get(
f"/api/documents/{root.id}/thumb/",
HTTP_IF_NONE_MATCH='"v2"',
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.headers["ETag"], '"v1"')
self.assertEqual(read_streaming_response(resp), b"thumb-v1")
def test_metadata_version_param_uses_version(self) -> None:
root = Document.objects.create(
title="root",
+105
View File
@@ -485,6 +485,42 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
response = self.client.get(f"/api/documents/{doc.pk}/thumb/")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_document_actions_trashed_document(self) -> None:
"""
GIVEN:
- Document with files exists
WHEN:
- Document is soft-deleted (moved to trash)
- Preview and thumb endpoints are requested
THEN:
- HTTP 200 OK for both (trashed documents remain previewable)
"""
_, filename = tempfile.mkstemp(dir=self.dirs.originals_dir)
content = b"This is a test"
content_thumbnail = b"thumbnail content"
with Path(filename).open("wb") as f:
f.write(content)
doc = Document.objects.create(
title="none",
filename=Path(filename).name,
mime_type="application/pdf",
)
with (self.dirs.thumbnail_dir / f"{doc.pk:07d}.webp").open("wb") as f:
f.write(content_thumbnail)
doc.delete()
response = self.client.get(f"/api/documents/{doc.pk}/preview/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(read_streaming_response(response), content)
response = self.client.get(f"/api/documents/{doc.pk}/thumb/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(read_streaming_response(response), content_thumbnail)
def test_document_history_action(self) -> None:
"""
GIVEN:
@@ -1305,6 +1341,35 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
self.assertEqual(response.data["document_type_count"], 1)
self.assertEqual(response.data["storage_path_count"], 2)
def test_statistics_excludes_document_versions(self) -> None:
root = Document.objects.create(
title="root",
checksum="A",
mime_type="application/pdf",
content="root",
)
version = Document.objects.create(
title="version",
checksum="B",
mime_type="application/pdf",
content="version",
root_document=root,
version_index=1,
)
tag_inbox = Tag.objects.create(name="t1", is_inbox_tag=True)
version.tags.add(tag_inbox)
response = self.client.get("/api/statistics/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["documents_total"], 1)
self.assertEqual(response.data["documents_inbox"], 0)
self.assertEqual(response.data["character_count"], 4)
self.assertEqual(
response.data["document_file_type_counts"][0]["mime_type_count"],
1,
)
def test_statistics_no_inbox_tag(self) -> None:
Document.objects.create(title="none1", checksum="A")
@@ -3047,6 +3112,46 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
# modified was updated to today
self.assertEqual(doc.modified.day, timezone.now().day)
def test_create_note_only_saves_document_modified_field(self) -> None:
"""
GIVEN:
- Existing document with a created date
WHEN:
- API request is made to add a note
THEN:
- Only the document modified field is persisted by the note endpoint
- Other document fields are not rewritten by the note endpoint
"""
doc = Document.objects.create(
title="test",
mime_type="application/pdf",
content="this is a document which will have notes added",
created=datetime.date(2026, 3, 31),
)
original_save = Document.save
with mock.patch.object(
Document,
"save",
autospec=True,
side_effect=original_save,
) as save_mock:
resp = self.client.post(
f"/api/documents/{doc.pk}/notes/",
data={"note": "this is a posted note"},
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
doc.refresh_from_db()
self.assertEqual(doc.created, datetime.date(2026, 3, 31))
self.assertTrue(
any(
call.kwargs.get("update_fields") == ["modified"]
for call in save_mock.call_args_list
if call.args and call.args[0].pk == doc.pk
),
)
def test_notes_permissions_aware(self) -> None:
"""
GIVEN:
+15 -12
View File
@@ -987,29 +987,32 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
THEN:
- The similar documents are returned from the API request
"""
d1 = Document.objects.create(
# Distinct created/added dates: documents created at the same instant
# share a timestamp term, and more_like_this (which cannot be scoped to
# content fields) would then match on it, surfacing unrelated documents.
d1 = DocumentFactory(
title="invoice",
content="the thing i bought at a shop and paid with bank account",
checksum="A",
pk=1,
created=datetime.date(2018, 1, 1),
added=timezone.make_aware(datetime.datetime(2018, 1, 1)),
)
d2 = Document.objects.create(
d2 = DocumentFactory(
title="bank statement 1",
content="things i paid for in august",
pk=2,
checksum="B",
created=datetime.date(2019, 3, 4),
added=timezone.make_aware(datetime.datetime(2019, 3, 4)),
)
d3 = Document.objects.create(
d3 = DocumentFactory(
title="bank statement 3",
content="things i paid for in september",
pk=3,
checksum="C",
created=datetime.date(2020, 7, 9),
added=timezone.make_aware(datetime.datetime(2020, 7, 9)),
)
d4 = Document.objects.create(
d4 = DocumentFactory(
title="Quarterly Report",
content="quarterly revenue profit margin earnings growth",
pk=4,
checksum="ABC",
created=datetime.date(2021, 11, 30),
added=timezone.make_aware(datetime.datetime(2021, 11, 30)),
)
backend = get_backend()
backend.add_or_update(d1)
+52
View File
@@ -945,6 +945,10 @@ class TestPDFActions(DirectoriesMixin, TestCase):
pages = [[1, 2], [3]]
self.doc2.archive_serial_number = 200
self.doc2.save()
errback = bulk_edit.restore_archive_serial_numbers_task.s(
{self.doc2.id: 200},
)
mock_chord.return_value.on_error.return_value = mock_chord.return_value
result = bulk_edit.split(doc_ids, pages, delete_originals=True)
self.assertEqual(result, "OK")
@@ -957,6 +961,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_delete_documents.assert_called()
mock_chord.assert_called_once()
mock_chord.return_value.on_error.assert_called_once_with(errback)
mock_chord.return_value.apply_async.assert_called_once_with()
delete_documents_args, _ = mock_delete_documents.call_args
self.assertEqual(
@@ -991,6 +997,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.doc2.save()
sig = mock.Mock()
sig.on_error.return_value = sig
sig.apply_async.side_effect = Exception("boom")
mock_chord.return_value = sig
@@ -1256,10 +1263,16 @@ class TestPDFActions(DirectoriesMixin, TestCase):
operations = [{"page": 1}, {"page": 2}]
self.doc2.archive_serial_number = 250
self.doc2.save()
errback = bulk_edit.restore_archive_serial_numbers_task.s(
{self.doc2.id: 250},
)
mock_chord.return_value.on_error.return_value = mock_chord.return_value
result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.assertEqual(result, "OK")
mock_chord.assert_called_once()
mock_chord.return_value.on_error.assert_called_once_with(errback)
mock_chord.return_value.apply_async.assert_called_once_with()
self.assertEqual(mock_consume_file.call_args.kwargs["overrides"].asn, 250)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@@ -1288,6 +1301,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.doc2.save()
sig = mock.Mock()
sig.on_error.return_value = sig
sig.apply_async.side_effect = Exception("boom")
mock_chord.return_value = sig
@@ -1480,6 +1494,44 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_update_document_uses_source_paths(
self,
mock_open,
mock_mkdtemp,
mock_consume_delay,
mock_update_document,
) -> None:
doc = self.doc1
source_file = self.dirs.scratch_dir / "consumption-source.pdf"
source_file.write_bytes(b"protected pdf content")
temp_dir = self.dirs.scratch_dir / "remove-password-source-file"
temp_dir.mkdir(parents=True, exist_ok=True)
mock_mkdtemp.return_value = str(temp_dir)
fake_pdf = mock.MagicMock()
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=True,
source_paths_by_id={doc.id: source_file},
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(source_file, password="secret")
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.open")
+17 -9
View File
@@ -1120,12 +1120,14 @@ class TestConsumer(
self.assertEqual(command[1], "--replace-input")
@mock.patch("paperless_mail.models.MailRule.objects.get")
@mock.patch("paperless.parsers.mail.MailDocumentParser.get_thumbnail")
@mock.patch("paperless.parsers.mail.MailDocumentParser.parse")
@mock.patch("documents.consumer.get_parser_registry")
def test_mail_parser_receives_mailrule(
self,
mock_get_parser_registry: mock.Mock,
mock_mail_parser_parse: mock.Mock,
mock_get_thumbnail: mock.Mock,
mock_mailrule_get: mock.Mock,
) -> None:
"""
@@ -1136,6 +1138,7 @@ class TestConsumer(
THEN:
- The mail parser should receive the mail rule
"""
from documents.parsers import ParseError
from paperless.parsers.mail import MailDocumentParser
mock_get_parser_registry.return_value.get_parser_for_file.return_value = (
@@ -1144,19 +1147,24 @@ class TestConsumer(
mock_mailrule_get.return_value = mock.Mock(
pdf_layout=MailRule.PdfLayout.HTML_ONLY,
)
mock_get_thumbnail.side_effect = ParseError("no thumbnail")
src = (
Path(__file__).parent.parent.parent
/ Path("paperless")
/ Path("tests")
/ Path("samples")
/ Path("mail")
/ "html.eml"
)
dst = self.dirs.scratch_dir / "html.eml"
shutil.copy(src, dst)
with self.get_consumer(
filepath=(
Path(__file__).parent.parent.parent
/ Path("paperless")
/ Path("tests")
/ Path("samples")
/ Path("mail")
).resolve()
/ "html.eml",
filepath=dst,
source=DocumentSource.MailFetch,
mailrule_id=1,
) as consumer:
# fails because no gotenberg
with self.assertRaises(
ConsumerError,
):
+23 -10
View File
@@ -124,7 +124,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
self.assertIn("document_ids", response.data)
def test_download_ready_bundle_streams_file(self) -> None:
bundle_file = Path(self.dirs.media_dir) / "bundles" / "ready.zip"
bundle_file = settings.SHARE_LINK_BUNDLE_DIR / "bundles" / "ready.zip"
bundle_file.parent.mkdir(parents=True, exist_ok=True)
bundle_file.write_bytes(b"binary-zip-content")
@@ -132,7 +132,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
slug="readyslug",
file_version=ShareLink.FileVersion.ARCHIVE,
status=ShareLinkBundle.Status.READY,
file_path=str(bundle_file),
file_path=str(bundle_file.relative_to(settings.SHARE_LINK_BUNDLE_DIR)),
)
bundle.documents.set([self.document])
@@ -199,11 +199,11 @@ class ShareLinkBundleTaskTests(DirectoriesMixin, APITestCase):
self.document = DocumentFactory.create()
def test_cleanup_expired_share_link_bundles(self) -> None:
expired_path = Path(self.dirs.media_dir) / "expired.zip"
expired_path = settings.SHARE_LINK_BUNDLE_DIR / "expired.zip"
expired_path.parent.mkdir(parents=True, exist_ok=True)
expired_path.write_bytes(b"expired")
active_path = Path(self.dirs.media_dir) / "active.zip"
active_path = settings.SHARE_LINK_BUNDLE_DIR / "active.zip"
active_path.write_bytes(b"active")
expired_bundle = ShareLinkBundle.objects.create(
@@ -211,7 +211,7 @@ class ShareLinkBundleTaskTests(DirectoriesMixin, APITestCase):
file_version=ShareLink.FileVersion.ARCHIVE,
status=ShareLinkBundle.Status.READY,
expiration=timezone.now() - timedelta(days=1),
file_path=str(expired_path),
file_path=expired_path.name,
)
expired_bundle.documents.set([self.document])
@@ -220,7 +220,7 @@ class ShareLinkBundleTaskTests(DirectoriesMixin, APITestCase):
file_version=ShareLink.FileVersion.ARCHIVE,
status=ShareLinkBundle.Status.READY,
expiration=timezone.now() + timedelta(days=1),
file_path=str(active_path),
file_path=active_path.name,
)
active_bundle.documents.set([self.document])
@@ -424,7 +424,7 @@ class ShareLinkBundleFilterSetTests(DirectoriesMixin, APITestCase):
class ShareLinkBundleModelTests(DirectoriesMixin, APITestCase):
def test_absolute_file_path_handles_relative_and_absolute(self) -> None:
def test_absolute_file_path_handles_relative_path(self) -> None:
relative_path = Path("relative.zip")
bundle = ShareLinkBundle.objects.create(
slug="relative-bundle",
@@ -437,10 +437,23 @@ class ShareLinkBundleModelTests(DirectoriesMixin, APITestCase):
(settings.SHARE_LINK_BUNDLE_DIR / relative_path).resolve(),
)
absolute_path = Path(self.dirs.media_dir) / "absolute.zip"
bundle.file_path = str(absolute_path)
def test_absolute_file_path_rejects_absolute_path(self) -> None:
bundle = ShareLinkBundle.objects.create(
slug="absolute-bundle",
file_version=ShareLink.FileVersion.ORIGINAL,
file_path=str(Path(self.dirs.media_dir) / "absolute.zip"),
)
self.assertEqual(bundle.absolute_file_path.resolve(), absolute_path.resolve())
self.assertIsNone(bundle.absolute_file_path)
def test_absolute_file_path_rejects_traversal_outside_bundle_dir(self) -> None:
bundle = ShareLinkBundle.objects.create(
slug="traversal-bundle",
file_version=ShareLink.FileVersion.ORIGINAL,
file_path="../escaped.zip",
)
self.assertIsNone(bundle.absolute_file_path)
def test_str_returns_translated_slug(self) -> None:
bundle = ShareLinkBundle.objects.create(
@@ -5,6 +5,7 @@ from django.test import TestCase
from documents.conditionals import metadata_etag
from documents.conditionals import preview_etag
from documents.conditionals import thumbnail_etag
from documents.conditionals import thumbnail_last_modified
from documents.models import Document
from documents.tests.utils import DirectoriesMixin
@@ -30,6 +31,7 @@ class TestConditionals(DirectoriesMixin, TestCase):
self.assertEqual(metadata_etag(request, root.id), latest.checksum)
self.assertEqual(preview_etag(request, root.id), latest.archive_checksum)
self.assertEqual(thumbnail_etag(request, root.id), latest.checksum)
def test_resolve_effective_doc_returns_none_for_invalid_or_unrelated_version(
self,
+86
View File
@@ -25,6 +25,7 @@ from documents.models import DocumentType
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.models import UiSettings
from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
@@ -319,6 +320,10 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
mock_get_cache.assert_called_once_with(
self.document.pk,
backend="mock_backend",
)
mock_refresh_cache.assert_called_once_with(self.document.pk)
@patch("documents.views.get_ai_document_classification")
@@ -359,6 +364,49 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
"dates": ["2023-01-01"],
},
)
mock_get_ai_classification.assert_called_once_with(
self.document,
self.user,
None,
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_ai_suggestions_uses_user_display_language(
self,
mock_get_ai_classification,
) -> None:
UiSettings.objects.create(user=self.user, settings={"language": "de-de"})
mock_get_ai_classification.return_value = {
"title": "KI Title",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_get_ai_classification.assert_called_once_with(
self.document,
self.user,
"de-de",
)
self.assertEqual(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend:de-de",
).suggestions["title"],
"KI Title",
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
@@ -437,8 +485,14 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
)
super().setUp()
def grant_view_document_permission(self) -> None:
self.user.user_permissions.add(
*Permission.objects.filter(codename="view_document"),
)
@override_settings(AI_ENABLED=False)
def test_post_ai_disabled(self) -> None:
self.grant_view_document_permission()
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
@@ -451,6 +505,7 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
@patch("documents.views.get_objects_for_user_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_no_document_id(self, mock_get_objects, mock_stream_chat) -> None:
self.grant_view_document_permission()
mock_get_objects.return_value = [self.document]
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
@@ -464,6 +519,7 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id(self, mock_stream_chat) -> None:
self.grant_view_document_permission()
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
@@ -475,6 +531,7 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
@override_settings(AI_ENABLED=True)
def test_post_with_invalid_document_id(self) -> None:
self.grant_view_document_permission()
response = self.client.post(
self.ENDPOINT,
data='{"q": "question", "document_id": 999999}',
@@ -486,6 +543,7 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
@patch("documents.views.has_perms_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id_no_permission(self, mock_has_perms) -> None:
self.grant_view_document_permission()
mock_has_perms.return_value = False
response = self.client.post(
self.ENDPOINT,
@@ -494,3 +552,31 @@ class TestAIChatStreamingView(DirectoriesMixin, TestCase):
)
self.assertEqual(response.status_code, 403)
self.assertIn(b"Insufficient permissions", response.content)
@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_no_document_id_requires_view_document_permission(
self,
mock_stream_chat,
) -> None:
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 403)
mock_stream_chat.assert_not_called()
@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id_requires_view_document_permission(
self,
mock_stream_chat,
) -> None:
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 403)
mock_stream_chat.assert_not_called()
+55 -3
View File
@@ -4164,7 +4164,7 @@ class TestWorkflows(
)
action = WorkflowAction.objects.create(
type=WorkflowAction.WorkflowActionType.PASSWORD_REMOVAL,
passwords="wrong, right\n extra ",
passwords=["wrong", "right", "extra"],
)
workflow = Workflow.objects.create(name="Password workflow")
workflow.triggers.add(trigger)
@@ -4185,12 +4185,14 @@ class TestWorkflows(
password="wrong",
update_document=True,
user=doc.owner,
source_paths_by_id=None,
),
mock.call(
[doc.id],
password="right",
update_document=True,
user=doc.owner,
source_paths_by_id=None,
),
],
)
@@ -4218,7 +4220,7 @@ class TestWorkflows(
)
action = WorkflowAction.objects.create(
type=WorkflowAction.WorkflowActionType.PASSWORD_REMOVAL,
passwords=" \n , ",
passwords=[" ", " "],
)
workflow = Workflow.objects.create(name="Password workflow missing passwords")
workflow.triggers.add(trigger)
@@ -4276,7 +4278,7 @@ class TestWorkflows(
"""
action = WorkflowAction.objects.create(
type=WorkflowAction.WorkflowActionType.PASSWORD_REMOVAL,
passwords="first, second",
passwords=["first", "second"],
)
temp_dir = Path(tempfile.mkdtemp())
@@ -4304,6 +4306,7 @@ class TestWorkflows(
document_consumption_finished.send(
sender=self.__class__,
document=doc,
original_file=original_file,
)
assert mock_remove_password.call_count == 2
@@ -4314,12 +4317,14 @@ class TestWorkflows(
password="first",
update_document=True,
user=doc.owner,
source_paths_by_id={doc.id: original_file},
),
mock.call(
[doc.id],
password="second",
update_document=True,
user=doc.owner,
source_paths_by_id={doc.id: original_file},
),
],
)
@@ -4331,6 +4336,53 @@ class TestWorkflows(
)
assert mock_remove_password.call_count == 2
@mock.patch("documents.bulk_edit.remove_password")
def test_password_removal_document_added_uses_original_file(
self,
mock_remove_password,
) -> None:
"""
GIVEN:
- Workflow password removal action on a DOCUMENT_ADDED trigger
- run_workflows called with an explicit original_file (staged file
from the consumer, before the source path is populated)
WHEN:
- The workflow runs
THEN:
- remove_password is called with source_paths_by_id pointing at the
staged file rather than the not-yet-existing source_path
"""
doc = Document.objects.create(
title="Protected",
checksum="pw-checksum-added",
)
trigger = WorkflowTrigger.objects.create(
type=WorkflowTrigger.WorkflowTriggerType.DOCUMENT_ADDED,
)
action = WorkflowAction.objects.create(
type=WorkflowAction.WorkflowActionType.PASSWORD_REMOVAL,
passwords=["secret"],
)
workflow = Workflow.objects.create(name="Password workflow added")
workflow.triggers.add(trigger)
workflow.actions.add(action)
mock_remove_password.return_value = "OK"
temp_dir = Path(tempfile.mkdtemp())
original_file = temp_dir / "staged.pdf"
original_file.write_bytes(b"pdf content")
run_workflows(trigger.type, doc, original_file=original_file)
mock_remove_password.assert_called_once_with(
[doc.id],
password="secret",
update_document=True,
user=doc.owner,
source_paths_by_id={doc.id: original_file},
)
def test_workflow_trash_action_soft_delete(self) -> None:
"""
GIVEN: