Compare commits

...

10 Commits

Author SHA1 Message Date
stumpylog 1f4a871b8f Refactor(beta): extract visible_document_ids_for_user helper
The owner-aware "resolve user to visible document pks" block was duplicated
verbatim between get_context_for_document and get_taxonomy_hints_for_document.
Extract it into indexing.visible_document_ids_for_user, next to its sibling
normalize_document_ids, and call it from both paths.

No behavior change: the helper returns None when user is None (unfiltered
retrieval) and the same pk list otherwise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 29f9475818 Test(beta): use documents factories for taxonomy hint test fixtures
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog d06f66b618 Test(beta): use pytest-django fixtures and drop needless DB markers in taxonomy hint tests
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog f3f55e3866 Enhancement(beta): feed taxonomy hints into AI document suggestions
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 24b81c15f6 Enhancement(beta): splice taxonomy hints into the AI classifier prompt
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 5202b0880e Enhancement(beta): let name matching short-circuit on taxonomy hints
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 7ed58f9664 Enhancement(beta): gate and assemble taxonomy hints for a document
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 43eb3295ce Enhancement(beta): format taxonomy hints into prompt blocks
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog e0ba4cfada Enhancement(beta): add taxonomy hint builder from RAG node metadata
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 73062bd5ab Refactor(beta): extract retrieve_similar_nodes from query_similar_documents
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
11 changed files with 786 additions and 35 deletions
+3
View File
@@ -368,6 +368,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document, self.document,
self.user, self.user,
None, None,
hints=None,
) )
@patch("documents.views.get_ai_document_classification") @patch("documents.views.get_ai_document_classification")
@@ -399,6 +400,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document, self.document,
self.user, self.user,
"de-de", "de-de",
hints=None,
) )
self.assertEqual( self.assertEqual(
get_llm_suggestion_cache( get_llm_suggestion_cache(
@@ -438,6 +440,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document, self.document,
self.user, self.user,
"fr-fr", "fr-fr",
hints=None,
) )
self.assertEqual( self.assertEqual(
get_llm_suggestion_cache( get_llm_suggestion_cache(
+8
View File
@@ -245,6 +245,7 @@ from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name from paperless_ai.matching import match_tags_by_name
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -1494,11 +1495,14 @@ class DocumentViewSet(
refresh_suggestions_cache(doc.pk) refresh_suggestions_cache(doc.pk)
return Response(cached_llm_suggestions.suggestions) return Response(cached_llm_suggestions.suggestions)
hints = get_taxonomy_hints_for_document(doc, request.user)
try: try:
llm_suggestions = get_ai_document_classification( llm_suggestions = get_ai_document_classification(
doc, doc,
request.user, request.user,
output_language, output_language,
hints=hints,
) )
except ValueError as exc: except ValueError as exc:
logger.exception( logger.exception(
@@ -1513,18 +1517,22 @@ class DocumentViewSet(
matched_tags = match_tags_by_name( matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []), llm_suggestions.get("tags", []),
request.user, request.user,
hinted_names=set(hints["tags"]) if hints else None,
) )
matched_correspondents = match_correspondents_by_name( matched_correspondents = match_correspondents_by_name(
llm_suggestions.get("correspondents", []), llm_suggestions.get("correspondents", []),
request.user, request.user,
hinted_names=set(hints["correspondents"]) if hints else None,
) )
matched_types = match_document_types_by_name( matched_types = match_document_types_by_name(
llm_suggestions.get("document_types", []), llm_suggestions.get("document_types", []),
request.user, request.user,
hinted_names=set(hints["document_types"]) if hints else None,
) )
matched_paths = match_storage_paths_by_name( matched_paths = match_storage_paths_by_name(
llm_suggestions.get("storage_paths", []), llm_suggestions.get("storage_paths", []),
request.user, request.user,
hinted_names=set(hints["storage_paths"]) if hints else None,
) )
resp_data = { resp_data = {
+20 -19
View File
@@ -1,16 +1,21 @@
import json import json
import logging import logging
from typing import TYPE_CHECKING
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from documents.models import Document from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig from paperless.config import AIConfig
from paperless_ai.client import AIClient from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content from paperless_ai.indexing import truncate_content
from paperless_ai.indexing import visible_document_ids_for_user
from paperless_ai.taxonomy import format_hints_for_prompt
if TYPE_CHECKING:
from paperless_ai.taxonomy import TaxonomyHints
logger = logging.getLogger("paperless_ai.rag_classifier") logger = logging.getLogger("paperless_ai.rag_classifier")
@@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag( def build_prompt_without_rag(
document: Document, document: Document,
config: AIConfig, config: AIConfig,
hints: "TaxonomyHints | None" = None,
) -> str: ) -> str:
filename = document.filename or "" filename = document.filename or ""
content = truncate_content( content = truncate_content(
@@ -34,10 +40,16 @@ def build_prompt_without_rag(
context_size=config.llm_context_size, context_size=config.llm_context_size,
) )
hints_block = format_hints_for_prompt(hints) if hints else ""
# Splice the block (if any) immediately before the "Analyze ..." instruction.
# When there is no block this expands to nothing, so the prompt is identical
# to the pre-hints baseline.
hints_section = f"{hints_block}\n\n " if hints_block else ""
return f""" return f"""
You are a document classification assistant. You are a document classification assistant.
Analyze the following document and extract the following information: {hints_section}Analyze the following document and extract the following information:
- A short descriptive title - A short descriptive title
- Tags that reflect the content - Tags that reflect the content
- Names of people or organizations mentioned - Names of people or organizations mentioned
@@ -57,8 +69,9 @@ def build_prompt_with_rag(
document: Document, document: Document,
config: AIConfig, config: AIConfig,
user: User | None = None, user: User | None = None,
hints: "TaxonomyHints | None" = None,
) -> str: ) -> str:
base_prompt = build_prompt_without_rag(document, config) base_prompt = build_prompt_without_rag(document, config, hints=hints)
context = truncate_content( context = truncate_content(
get_context_for_document(document, user), get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size, chunk_size=config.llm_embedding_chunk_size,
@@ -96,20 +109,7 @@ def get_context_for_document(
user: User | None = None, user: User | None = None,
max_docs: int = 5, max_docs: int = 5,
) -> str: ) -> str:
visible_documents = ( visible_document_ids = visible_document_ids_for_user(user)
get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
if user
else None
)
visible_document_ids = (
list(visible_documents.values_list("pk", flat=True))
if visible_documents is not None
else None
)
similar_docs = query_similar_documents( similar_docs = query_similar_documents(
document=doc, document=doc,
document_ids=visible_document_ids, document_ids=visible_document_ids,
@@ -137,13 +137,14 @@ def get_ai_document_classification(
document: Document, document: Document,
user: User | None = None, user: User | None = None,
output_language: str | None = None, output_language: str | None = None,
hints: "TaxonomyHints | None" = None,
) -> dict: ) -> dict:
ai_config = AIConfig() ai_config = AIConfig()
prompt = ( prompt = (
build_prompt_with_rag(document, ai_config, user) build_prompt_with_rag(document, ai_config, user, hints=hints)
if ai_config.llm_embedding_backend if ai_config.llm_embedding_backend
else build_prompt_without_rag(document, ai_config) else build_prompt_without_rag(document, ai_config, hints=hints)
) )
client = AIClient() client = AIClient()
+46 -5
View File
@@ -5,6 +5,7 @@ from datetime import timedelta
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User
from django.utils import timezone from django.utils import timezone
from filelock import FileLock from filelock import FileLock
from filelock import ReadWriteLock from filelock import ReadWriteLock
@@ -12,6 +13,7 @@ from filelock import Timeout
from documents.models import Document from documents.models import Document
from documents.models import PaperlessTask from documents.models import PaperlessTask
from documents.permissions import get_objects_for_user_owner_aware
from documents.utils import IterWrapper from documents.utils import IterWrapper
from documents.utils import identity from documents.utils import identity
from paperless.config import AIConfig from paperless.config import AIConfig
@@ -22,6 +24,7 @@ from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING: if TYPE_CHECKING:
from llama_index.core.schema import BaseNode from llama_index.core.schema import BaseNode
from llama_index.core.schema import NodeWithScore
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@@ -449,12 +452,36 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
return {str(document_id) for document_id in document_ids} return {str(document_id) for document_id in document_ids}
def query_similar_documents( def visible_document_ids_for_user(user: User | None) -> list[int] | None:
"""Return the pks of documents ``user`` may view, or ``None`` for no filter.
Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used
by both the similarity-context and taxonomy-hints paths to scope RAG
neighbours to documents the requesting user is allowed to see.
"""
if user is None:
return None
visible_documents = get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
return list(visible_documents.values_list("pk", flat=True))
def retrieve_similar_nodes(
document: Document, document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None, document_ids: Iterable[int | str] | None = None,
) -> list[Document]: top_k: int = 5,
"""Return up to ``top_k`` Documents most similar to ``document``.""" ) -> list["NodeWithScore"]:
"""Run ANN retrieval and return the raw NodeWithScore results.
Returns ``[]`` when the allow-list normalizes to empty, or when no index
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
embedding request, so it runs inside ``db_connection_released()`` to avoid
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
and the taxonomy-hints path go through here, so they share that behavior.
"""
allowed_document_ids = normalize_document_ids(document_ids) allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids: if allowed_document_ids is not None and not allowed_document_ids:
return [] return []
@@ -494,7 +521,21 @@ def query_similar_documents(
filters=filters, filters=filters,
) )
with db_connection_released(): with db_connection_released():
results = retriever.retrieve(query_text) return retriever.retrieve(query_text)
def query_similar_documents(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
results = retrieve_similar_nodes(
document=document,
document_ids=allowed_document_ids,
top_k=top_k,
)
retrieved_document_ids: list[int] = [] retrieved_document_ids: list[int] = []
for node in results: for node in results:
+38 -11
View File
@@ -15,40 +15,56 @@ MATCH_THRESHOLD = 0.8
logger = logging.getLogger("paperless_ai.matching") logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(names: list[str], user: User) -> list[Tag]: def match_tags_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Tag]:
queryset = get_objects_for_user_owner_aware( queryset = get_objects_for_user_owner_aware(
user, user,
["view_tag"], ["view_tag"],
Tag, Tag,
) )
return _match_names_to_queryset(names, queryset, "name") return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]: def match_correspondents_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Correspondent]:
queryset = get_objects_for_user_owner_aware( queryset = get_objects_for_user_owner_aware(
user, user,
["view_correspondent"], ["view_correspondent"],
Correspondent, Correspondent,
) )
return _match_names_to_queryset(names, queryset, "name") return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]: def match_document_types_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[DocumentType]:
queryset = get_objects_for_user_owner_aware( queryset = get_objects_for_user_owner_aware(
user, user,
["view_documenttype"], ["view_documenttype"],
DocumentType, DocumentType,
) )
return _match_names_to_queryset(names, queryset, "name") return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]: def match_storage_paths_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[StoragePath]:
queryset = get_objects_for_user_owner_aware( queryset = get_objects_for_user_owner_aware(
user, user,
["view_storagepath"], ["view_storagepath"],
StoragePath, StoragePath,
) )
return _match_names_to_queryset(names, queryset, "name") return _match_names_to_queryset(names, queryset, "name", hinted_names)
def _normalize(s: str) -> str: def _normalize(s: str) -> str:
@@ -58,10 +74,18 @@ def _normalize(s: str) -> str:
return s return s
def _match_names_to_queryset(names: list[str], queryset, attr: str): def _match_names_to_queryset(
names: list[str],
queryset,
attr: str,
hinted_names: set[str] | None = None,
):
results = [] results = []
objects = list(queryset) objects = list(queryset)
object_names = [_normalize(getattr(obj, attr)) for obj in objects] object_names = [_normalize(getattr(obj, attr)) for obj in objects]
normalized_hints = (
{_normalize(name) for name in hinted_names} if hinted_names else set()
)
for name in names: for name in names:
if not name: if not name:
@@ -76,6 +100,11 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
results.append(matched) results.append(matched)
continue continue
# A hinted name that didn't exact-match came from existing taxonomy
# verbatim; do not fuzzy-map it onto a different object.
if target in normalized_hints:
continue
# Fuzzy match fallback # Fuzzy match fallback
matches = difflib.get_close_matches( matches = difflib.get_close_matches(
target, target,
@@ -88,8 +117,6 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
matched = objects.pop(index) matched = objects.pop(index)
object_names.pop(index) object_names.pop(index)
results.append(matched) results.append(matched)
else:
pass
return results return results
+115
View File
@@ -0,0 +1,115 @@
from typing import TYPE_CHECKING
from typing import TypedDict
from django.contrib.auth.models import User
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.indexing import retrieve_similar_nodes
from paperless_ai.indexing import visible_document_ids_for_user
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
class TaxonomyHints(TypedDict):
tags: list[str]
document_types: list[str]
correspondents: list[str]
storage_paths: list[str]
def build_taxonomy_hints_from_nodes(
nodes: list["NodeWithScore"],
) -> TaxonomyHints:
"""Collect the unique, sorted taxonomy names carried on retrieved nodes.
Reads ``tags`` (a list), ``document_type``, ``correspondent``, and
``storage_path`` from each node's metadata. Empty / ``None`` values and
missing keys are skipped. The result is naturally bounded by the retrieval
``top_k``, so no cap is applied.
"""
tags: set[str] = set()
document_types: set[str] = set()
correspondents: set[str] = set()
storage_paths: set[str] = set()
for node in nodes:
metadata = node.metadata or {}
for tag in metadata.get("tags") or []:
if tag:
tags.add(tag)
document_type = metadata.get("document_type")
if document_type:
document_types.add(document_type)
correspondent = metadata.get("correspondent")
if correspondent:
correspondents.add(correspondent)
storage_path = metadata.get("storage_path")
if storage_path:
storage_paths.add(storage_path)
return TaxonomyHints(
tags=sorted(tags),
document_types=sorted(document_types),
correspondents=sorted(correspondents),
storage_paths=sorted(storage_paths),
)
_HINT_INSTRUCTION = (
"Prefer existing names from these lists verbatim. Only propose a new value "
"if none of the existing names fits."
)
def format_hints_for_prompt(hints: TaxonomyHints) -> str:
"""Render non-empty hint categories as labelled blocks plus one instruction.
Returns "" when every category is empty, so callers can treat the result
the same as no hints at all.
"""
# Literal-key access keeps this TypedDict-safe for mypy; the order here is
# the order the blocks appear in the prompt.
labelled_values: list[tuple[str, list[str]]] = [
("Available tags", hints["tags"]),
("Available document types", hints["document_types"]),
("Available correspondents", hints["correspondents"]),
("Available storage paths", hints["storage_paths"]),
]
blocks: list[str] = []
for label, values in labelled_values:
if values:
listing = "\n".join(f"- {value}" for value in values)
blocks.append(f"{label}:\n{listing}")
if not blocks:
return ""
return "\n\n".join([*blocks, _HINT_INSTRUCTION])
def get_taxonomy_hints_for_document(
document: Document,
user: User | None,
) -> TaxonomyHints | None:
"""Build taxonomy hints from a document's RAG neighbours.
Returns ``None`` when no embedding backend is configured (the gate) so the
caller's prompt and matching are identical to today. Otherwise returns a
``TaxonomyHints`` -- possibly all-empty when no similar documents exist.
Applies the same owner-aware visible-document filter as
``get_context_for_document``.
"""
if not AIConfig().llm_embedding_backend:
return None
nodes = retrieve_similar_nodes(
document=document,
document_ids=visible_document_ids_for_user(user),
)
return build_taxonomy_hints_from_nodes(nodes)
@@ -1,8 +1,11 @@
import json import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import pytest_mock
from django.test import override_settings from django.test import override_settings
from documents.models import Document from documents.models import Document
@@ -261,3 +264,111 @@ def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]): with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document) result = get_context_for_document(mock_document)
assert result == "" assert result == ""
class TestPromptHints:
@pytest.fixture
def config(self) -> AIConfig:
# build_prompt_* only read these two numeric settings off config;
# a stand-in avoids constructing a DB-backed AIConfig.
return cast(
"AIConfig",
SimpleNamespace(llm_embedding_chunk_size=1000, llm_context_size=8000),
)
def test_without_rag_includes_hints_block(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
assert "Available tags:" in prompt
assert "- Bloodwork" in prompt
assert "Prefer existing names from these lists verbatim" in prompt
def test_without_rag_none_matches_baseline(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
baseline = build_prompt_without_rag(mock_document, config)
with_none = build_prompt_without_rag(mock_document, config, hints=None)
assert with_none == baseline
assert "Available tags:" not in with_none
def test_with_rag_includes_context_and_hints(
self,
mock_document: MagicMock,
config: AIConfig,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="TITLE: Neighbour\nsome context",
)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
assert "Additional context from similar documents" in prompt
assert "Available tags:" in prompt
def test_classification_forwards_hints(
self,
mock_document: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.AIConfig",
return_value=SimpleNamespace(
llm_embedding_backend=None,
llm_embedding_chunk_size=1000,
llm_context_size=8000,
),
)
build = mocker.patch(
"paperless_ai.ai_classifier.build_prompt_without_rag",
return_value="PROMPT",
)
mock_client = MagicMock()
mock_client.run_llm_query.return_value = {
"title": "t",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = get_ai_document_classification(
mock_document,
user=None,
hints=hints,
)
_, build_kwargs = build.call_args
assert build_kwargs["hints"] == hints
assert set(result.keys()) == {
"title",
"tags",
"correspondents",
"document_types",
"storage_paths",
"dates",
}
@@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
from unittest.mock import patch from unittest.mock import patch
@@ -726,3 +727,58 @@ class TestQuerySimilarDocuments:
results = indexing.query_similar_documents(a, document_ids=[b.id]) results = indexing.query_similar_documents(a, document_ids=[b.id])
assert all(doc.id == b.id for doc in results) assert all(doc.id == b.id for doc in results)
class TestRetrieveSimilarNodes:
@pytest.mark.django_db
def test_returns_raw_nodes_from_retriever(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
mocker.patch("paperless_ai.indexing.load_or_build_index")
node1 = SimpleNamespace(metadata={"document_id": "1"})
node2 = SimpleNamespace(metadata={"document_id": "2"})
retriever = mocker.MagicMock()
retriever.retrieve.return_value = [node1, node2]
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=retriever,
)
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
assert result == [node1, node2]
@pytest.mark.django_db
def test_empty_allow_list_fails_closed(
self,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
assert result == []
load.assert_not_called()
@pytest.mark.django_db
def test_queues_update_when_index_missing(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
assert result == []
queue.assert_called_once_with(
rebuild=False,
reason="LLM index not found for similarity query.",
)
+92
View File
@@ -1,12 +1,15 @@
import difflib
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import pytest_mock
from django.test import TestCase from django.test import TestCase
from documents.models import Correspondent from documents.models import Correspondent
from documents.models import DocumentType from documents.models import DocumentType
from documents.models import StoragePath from documents.models import StoragePath
from documents.models import Tag from documents.models import Tag
from documents.tests.factories import TagFactory
from paperless_ai.matching import extract_unmatched_names from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name from paperless_ai.matching import match_document_types_by_name
@@ -87,6 +90,95 @@ class TestAIMatching(TestCase):
self.assertEqual(result[1].name, "Test Tag 2") self.assertEqual(result[1].name, "Test Tag 2")
class TestHintedMatching:
def test_hinted_verbatim_skips_fuzzy(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork"],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_unhinted_name_still_fuzzy_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
# "Bloodwrok" is a typo not in hints -> fuzzy still maps it to Bloodwork.
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Taxes"},
)
assert [t.name for t in result] == ["Bloodwork"]
def test_hinted_name_with_whitespace_exact_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork "],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_hinted_name_absent_from_queryset_is_skipped_not_fuzzed(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
# A hint with no exact object must not fall through to fuzzy.
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Bloodwrok"},
)
assert result == []
def test_backward_compatible_without_kwarg(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Test Tag 1")],
)
result = match_tags_by_name(["Test Tag 1", "Nonexistent"], user=None)
assert [t.name for t in result] == ["Test Tag 1"]
@pytest.mark.django_db @pytest.mark.django_db
class TestExtractUnmatchedNamesNormalization: class TestExtractUnmatchedNamesNormalization:
def test_punctuated_name_already_matched_is_not_returned_as_unmatched( def test_punctuated_name_already_matched_is_not_returned_as_unmatched(
+220
View File
@@ -0,0 +1,220 @@
from types import SimpleNamespace
import pytest_mock
from documents.tests.factories import DocumentFactory
from paperless_ai.taxonomy import TaxonomyHints
from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes
from paperless_ai.taxonomy import format_hints_for_prompt
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
def make_node(**metadata: object) -> SimpleNamespace:
"""A stand-in for NodeWithScore: only ``.metadata`` is accessed."""
return SimpleNamespace(metadata=metadata)
class TestBuildTaxonomyHintsFromNodes:
def test_returns_all_four_keys(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert set(hints.keys()) == {
"tags",
"document_types",
"correspondents",
"storage_paths",
}
def test_collects_and_sorts_values(self) -> None:
nodes = [
make_node(
tags=["Taxes", "Bloodwork"],
document_type="Invoice",
correspondent="IRS",
storage_path="Financial",
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Bloodwork", "Taxes"]
assert hints["document_types"] == ["Invoice"]
assert hints["correspondents"] == ["IRS"]
assert hints["storage_paths"] == ["Financial"]
def test_deduplicates_across_nodes(self) -> None:
nodes = [
make_node(tags=["Taxes"], document_type="Invoice"),
make_node(tags=["Taxes", "Medical"], document_type="Invoice"),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Medical", "Taxes"]
assert hints["document_types"] == ["Invoice"]
def test_none_values_skipped(self) -> None:
nodes = [
make_node(
tags=["Taxes", None, ""],
document_type=None,
correspondent=None,
storage_path=None,
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Taxes"]
assert hints["document_types"] == []
assert hints["correspondents"] == []
assert hints["storage_paths"] == []
def test_missing_storage_path_key_handled(self) -> None:
# Pre-enrichment nodes have no storage_path key at all.
nodes = [make_node(tags=["Taxes"], document_type="Invoice")]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["storage_paths"] == []
def test_empty_node_list_all_empty(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert hints == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
def test_output_stable_across_calls(self) -> None:
nodes = [make_node(tags=["b", "a", "c"])]
assert build_taxonomy_hints_from_nodes(
nodes,
) == build_taxonomy_hints_from_nodes(nodes)
class TestFormatHintsForPrompt:
def test_all_blocks_present_when_all_categories_nonempty(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": ["IRS"],
"storage_paths": ["Financial"],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" in result
assert "Available correspondents:" in result
assert "Available storage paths:" in result
assert "- Bloodwork" in result
def test_empty_category_produces_no_block(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" not in result
assert "Available correspondents:" not in result
assert "Available storage paths:" not in result
def test_all_empty_produces_empty_string(self) -> None:
hints: TaxonomyHints = {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
assert format_hints_for_prompt(hints) == ""
def test_instruction_line_appears_once(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert result.count("Prefer existing names from these lists verbatim") == 1
class TestGetTaxonomyHintsForDocument:
def test_returns_none_when_embedding_backend_off(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend=None),
)
retrieve = mocker.patch("paperless_ai.taxonomy.retrieve_similar_nodes")
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result is None
retrieve.assert_not_called()
def test_passes_owner_aware_ids_when_user_present(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.visible_document_ids_for_user",
return_value=[1, 2, 3],
)
retrieve = mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
document = DocumentFactory.build()
user = mocker.MagicMock()
get_taxonomy_hints_for_document(document, user=user)
retrieve.assert_called_once_with(
document=document,
document_ids=[1, 2, 3],
)
def test_returns_populated_hints_when_nodes_found(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[make_node(tags=["Taxes"], document_type="Invoice")],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": ["Taxes"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
def test_returns_empty_hints_not_none_when_no_nodes(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
@@ -0,0 +1,77 @@
from types import SimpleNamespace
import pytest
import pytest_mock
from django.contrib.auth.models import User
from rest_framework.test import APIClient
from documents.models import Document
from documents.tests.factories import DocumentFactory
@pytest.mark.django_db
class TestSuggestionsHintWiring:
@pytest.fixture
def document(self) -> Document:
return DocumentFactory() # type: ignore[return-value]
@pytest.fixture
def api_client(self, admin_user: User) -> APIClient:
client = APIClient()
client.force_authenticate(user=admin_user)
return client
def test_hints_passed_to_classifier_and_matchers(
self,
api_client: APIClient,
document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
mocker.patch(
"documents.views.get_taxonomy_hints_for_document",
return_value=hints,
)
mocker.patch(
"documents.views.AIConfig",
return_value=SimpleNamespace(
ai_enabled=True,
llm_backend="ollama",
llm_output_language=None,
),
)
# No cached suggestion -> the view reaches the classifier path.
mocker.patch(
"documents.views.get_llm_suggestion_cache",
return_value=None,
)
mocker.patch("documents.views.set_llm_suggestions_cache")
classify = mocker.patch(
"documents.views.get_ai_document_classification",
return_value={
"title": "Doc",
"tags": ["Bloodwork"],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
},
)
match_tags = mocker.patch(
"documents.views.match_tags_by_name",
return_value=[],
)
mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
mocker.patch("documents.views.match_document_types_by_name", return_value=[])
mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
assert response.status_code == 200
assert classify.call_args.kwargs["hints"] == hints
assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}