mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-28 16:24:19 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f4a871b8f | |||
| 29f9475818 | |||
| d06f66b618 | |||
| f3f55e3866 | |||
| 24b81c15f6 | |||
| 5202b0880e | |||
| 7ed58f9664 | |||
| 43eb3295ce | |||
| e0ba4cfada | |||
| 73062bd5ab |
@@ -368,6 +368,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
None,
|
||||
hints=None,
|
||||
)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@@ -399,6 +400,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
"de-de",
|
||||
hints=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
@@ -438,6 +440,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
"fr-fr",
|
||||
hints=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
|
||||
@@ -245,6 +245,7 @@ from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||
@@ -1494,11 +1495,14 @@ class DocumentViewSet(
|
||||
refresh_suggestions_cache(doc.pk)
|
||||
return Response(cached_llm_suggestions.suggestions)
|
||||
|
||||
hints = get_taxonomy_hints_for_document(doc, request.user)
|
||||
|
||||
try:
|
||||
llm_suggestions = get_ai_document_classification(
|
||||
doc,
|
||||
request.user,
|
||||
output_language,
|
||||
hints=hints,
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.exception(
|
||||
@@ -1513,18 +1517,22 @@ class DocumentViewSet(
|
||||
matched_tags = match_tags_by_name(
|
||||
llm_suggestions.get("tags", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["tags"]) if hints else None,
|
||||
)
|
||||
matched_correspondents = match_correspondents_by_name(
|
||||
llm_suggestions.get("correspondents", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["correspondents"]) if hints else None,
|
||||
)
|
||||
matched_types = match_document_types_by_name(
|
||||
llm_suggestions.get("document_types", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["document_types"]) if hints else None,
|
||||
)
|
||||
matched_paths = match_storage_paths_by_name(
|
||||
llm_suggestions.get("storage_paths", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["storage_paths"]) if hints else None,
|
||||
)
|
||||
|
||||
resp_data = {
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
from paperless_ai.indexing import visible_document_ids_for_user
|
||||
from paperless_ai.taxonomy import format_hints_for_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paperless_ai.taxonomy import TaxonomyHints
|
||||
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
@@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str:
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(
|
||||
@@ -34,10 +40,16 @@ def build_prompt_without_rag(
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
hints_block = format_hints_for_prompt(hints) if hints else ""
|
||||
# Splice the block (if any) immediately before the "Analyze ..." instruction.
|
||||
# When there is no block this expands to nothing, so the prompt is identical
|
||||
# to the pre-hints baseline.
|
||||
hints_section = f"{hints_block}\n\n " if hints_block else ""
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
|
||||
Analyze the following document and extract the following information:
|
||||
{hints_section}Analyze the following document and extract the following information:
|
||||
- A short descriptive title
|
||||
- Tags that reflect the content
|
||||
- Names of people or organizations mentioned
|
||||
@@ -57,8 +69,9 @@ def build_prompt_with_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
user: User | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document, config)
|
||||
base_prompt = build_prompt_without_rag(document, config, hints=hints)
|
||||
context = truncate_content(
|
||||
get_context_for_document(document, user),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
@@ -96,20 +109,7 @@ def get_context_for_document(
|
||||
user: User | None = None,
|
||||
max_docs: int = 5,
|
||||
) -> str:
|
||||
visible_documents = (
|
||||
get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
if user
|
||||
else None
|
||||
)
|
||||
visible_document_ids = (
|
||||
list(visible_documents.values_list("pk", flat=True))
|
||||
if visible_documents is not None
|
||||
else None
|
||||
)
|
||||
visible_document_ids = visible_document_ids_for_user(user)
|
||||
similar_docs = query_similar_documents(
|
||||
document=doc,
|
||||
document_ids=visible_document_ids,
|
||||
@@ -137,13 +137,14 @@ def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
output_language: str | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, ai_config, user)
|
||||
build_prompt_with_rag(document, ai_config, user, hints=hints)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document, ai_config)
|
||||
else build_prompt_without_rag(document, ai_config, hints=hints)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils import timezone
|
||||
from filelock import FileLock
|
||||
from filelock import ReadWriteLock
|
||||
@@ -12,6 +13,7 @@ from filelock import Timeout
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.utils import IterWrapper
|
||||
from documents.utils import identity
|
||||
from paperless.config import AIConfig
|
||||
@@ -22,6 +24,7 @@ from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
@@ -449,12 +452,36 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
|
||||
return {str(document_id) for document_id in document_ids}
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
def visible_document_ids_for_user(user: User | None) -> list[int] | None:
|
||||
"""Return the pks of documents ``user`` may view, or ``None`` for no filter.
|
||||
|
||||
Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used
|
||||
by both the similarity-context and taxonomy-hints paths to scope RAG
|
||||
neighbours to documents the requesting user is allowed to see.
|
||||
"""
|
||||
if user is None:
|
||||
return None
|
||||
visible_documents = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
return list(visible_documents.values_list("pk", flat=True))
|
||||
|
||||
|
||||
def retrieve_similar_nodes(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
top_k: int = 5,
|
||||
) -> list["NodeWithScore"]:
|
||||
"""Run ANN retrieval and return the raw NodeWithScore results.
|
||||
|
||||
Returns ``[]`` when the allow-list normalizes to empty, or when no index
|
||||
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
|
||||
embedding request, so it runs inside ``db_connection_released()`` to avoid
|
||||
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
|
||||
and the taxonomy-hints path go through here, so they share that behavior.
|
||||
"""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
if allowed_document_ids is not None and not allowed_document_ids:
|
||||
return []
|
||||
@@ -494,7 +521,21 @@ def query_similar_documents(
|
||||
filters=filters,
|
||||
)
|
||||
with db_connection_released():
|
||||
results = retriever.retrieve(query_text)
|
||||
return retriever.retrieve(query_text)
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
results = retrieve_similar_nodes(
|
||||
document=document,
|
||||
document_ids=allowed_document_ids,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
retrieved_document_ids: list[int] = []
|
||||
for node in results:
|
||||
|
||||
@@ -15,40 +15,56 @@ MATCH_THRESHOLD = 0.8
|
||||
logger = logging.getLogger("paperless_ai.matching")
|
||||
|
||||
|
||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
||||
def match_tags_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[Tag]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_tag"],
|
||||
Tag,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
|
||||
def match_correspondents_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[Correspondent]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_correspondent"],
|
||||
Correspondent,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
|
||||
def match_document_types_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[DocumentType]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_documenttype"],
|
||||
DocumentType,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
|
||||
def match_storage_paths_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[StoragePath]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_storagepath"],
|
||||
StoragePath,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
@@ -58,10 +74,18 @@ def _normalize(s: str) -> str:
|
||||
return s
|
||||
|
||||
|
||||
def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
def _match_names_to_queryset(
|
||||
names: list[str],
|
||||
queryset,
|
||||
attr: str,
|
||||
hinted_names: set[str] | None = None,
|
||||
):
|
||||
results = []
|
||||
objects = list(queryset)
|
||||
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
|
||||
normalized_hints = (
|
||||
{_normalize(name) for name in hinted_names} if hinted_names else set()
|
||||
)
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
@@ -76,6 +100,11 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
results.append(matched)
|
||||
continue
|
||||
|
||||
# A hinted name that didn't exact-match came from existing taxonomy
|
||||
# verbatim; do not fuzzy-map it onto a different object.
|
||||
if target in normalized_hints:
|
||||
continue
|
||||
|
||||
# Fuzzy match fallback
|
||||
matches = difflib.get_close_matches(
|
||||
target,
|
||||
@@ -88,8 +117,6 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
matched = objects.pop(index)
|
||||
object_names.pop(index)
|
||||
results.append(matched)
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypedDict
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.indexing import retrieve_similar_nodes
|
||||
from paperless_ai.indexing import visible_document_ids_for_user
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
|
||||
class TaxonomyHints(TypedDict):
|
||||
tags: list[str]
|
||||
document_types: list[str]
|
||||
correspondents: list[str]
|
||||
storage_paths: list[str]
|
||||
|
||||
|
||||
def build_taxonomy_hints_from_nodes(
|
||||
nodes: list["NodeWithScore"],
|
||||
) -> TaxonomyHints:
|
||||
"""Collect the unique, sorted taxonomy names carried on retrieved nodes.
|
||||
|
||||
Reads ``tags`` (a list), ``document_type``, ``correspondent``, and
|
||||
``storage_path`` from each node's metadata. Empty / ``None`` values and
|
||||
missing keys are skipped. The result is naturally bounded by the retrieval
|
||||
``top_k``, so no cap is applied.
|
||||
"""
|
||||
tags: set[str] = set()
|
||||
document_types: set[str] = set()
|
||||
correspondents: set[str] = set()
|
||||
storage_paths: set[str] = set()
|
||||
|
||||
for node in nodes:
|
||||
metadata = node.metadata or {}
|
||||
|
||||
for tag in metadata.get("tags") or []:
|
||||
if tag:
|
||||
tags.add(tag)
|
||||
|
||||
document_type = metadata.get("document_type")
|
||||
if document_type:
|
||||
document_types.add(document_type)
|
||||
|
||||
correspondent = metadata.get("correspondent")
|
||||
if correspondent:
|
||||
correspondents.add(correspondent)
|
||||
|
||||
storage_path = metadata.get("storage_path")
|
||||
if storage_path:
|
||||
storage_paths.add(storage_path)
|
||||
|
||||
return TaxonomyHints(
|
||||
tags=sorted(tags),
|
||||
document_types=sorted(document_types),
|
||||
correspondents=sorted(correspondents),
|
||||
storage_paths=sorted(storage_paths),
|
||||
)
|
||||
|
||||
|
||||
_HINT_INSTRUCTION = (
|
||||
"Prefer existing names from these lists verbatim. Only propose a new value "
|
||||
"if none of the existing names fits."
|
||||
)
|
||||
|
||||
|
||||
def format_hints_for_prompt(hints: TaxonomyHints) -> str:
|
||||
"""Render non-empty hint categories as labelled blocks plus one instruction.
|
||||
|
||||
Returns "" when every category is empty, so callers can treat the result
|
||||
the same as no hints at all.
|
||||
"""
|
||||
# Literal-key access keeps this TypedDict-safe for mypy; the order here is
|
||||
# the order the blocks appear in the prompt.
|
||||
labelled_values: list[tuple[str, list[str]]] = [
|
||||
("Available tags", hints["tags"]),
|
||||
("Available document types", hints["document_types"]),
|
||||
("Available correspondents", hints["correspondents"]),
|
||||
("Available storage paths", hints["storage_paths"]),
|
||||
]
|
||||
blocks: list[str] = []
|
||||
for label, values in labelled_values:
|
||||
if values:
|
||||
listing = "\n".join(f"- {value}" for value in values)
|
||||
blocks.append(f"{label}:\n{listing}")
|
||||
|
||||
if not blocks:
|
||||
return ""
|
||||
|
||||
return "\n\n".join([*blocks, _HINT_INSTRUCTION])
|
||||
|
||||
|
||||
def get_taxonomy_hints_for_document(
|
||||
document: Document,
|
||||
user: User | None,
|
||||
) -> TaxonomyHints | None:
|
||||
"""Build taxonomy hints from a document's RAG neighbours.
|
||||
|
||||
Returns ``None`` when no embedding backend is configured (the gate) so the
|
||||
caller's prompt and matching are identical to today. Otherwise returns a
|
||||
``TaxonomyHints`` -- possibly all-empty when no similar documents exist.
|
||||
Applies the same owner-aware visible-document filter as
|
||||
``get_context_for_document``.
|
||||
"""
|
||||
if not AIConfig().llm_embedding_backend:
|
||||
return None
|
||||
|
||||
nodes = retrieve_similar_nodes(
|
||||
document=document,
|
||||
document_ids=visible_document_ids_for_user(user),
|
||||
)
|
||||
return build_taxonomy_hints_from_nodes(nodes)
|
||||
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
@@ -261,3 +264,111 @@ def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestPromptHints:
|
||||
@pytest.fixture
|
||||
def config(self) -> AIConfig:
|
||||
# build_prompt_* only read these two numeric settings off config;
|
||||
# a stand-in avoids constructing a DB-backed AIConfig.
|
||||
return cast(
|
||||
"AIConfig",
|
||||
SimpleNamespace(llm_embedding_chunk_size=1000, llm_context_size=8000),
|
||||
)
|
||||
|
||||
def test_without_rag_includes_hints_block(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
|
||||
assert "Available tags:" in prompt
|
||||
assert "- Bloodwork" in prompt
|
||||
assert "Prefer existing names from these lists verbatim" in prompt
|
||||
|
||||
def test_without_rag_none_matches_baseline(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
baseline = build_prompt_without_rag(mock_document, config)
|
||||
with_none = build_prompt_without_rag(mock_document, config, hints=None)
|
||||
assert with_none == baseline
|
||||
assert "Available tags:" not in with_none
|
||||
|
||||
def test_with_rag_includes_context_and_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="TITLE: Neighbour\nsome context",
|
||||
)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
assert "Available tags:" in prompt
|
||||
|
||||
def test_classification_forwards_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.AIConfig",
|
||||
return_value=SimpleNamespace(
|
||||
llm_embedding_backend=None,
|
||||
llm_embedding_chunk_size=1000,
|
||||
llm_context_size=8000,
|
||||
),
|
||||
)
|
||||
build = mocker.patch(
|
||||
"paperless_ai.ai_classifier.build_prompt_without_rag",
|
||||
return_value="PROMPT",
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.run_llm_query.return_value = {
|
||||
"title": "t",
|
||||
"tags": [],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
}
|
||||
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
result = get_ai_document_classification(
|
||||
mock_document,
|
||||
user=None,
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
_, build_kwargs = build.call_args
|
||||
assert build_kwargs["hints"] == hints
|
||||
assert set(result.keys()) == {
|
||||
"title",
|
||||
"tags",
|
||||
"correspondents",
|
||||
"document_types",
|
||||
"storage_paths",
|
||||
"dates",
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -726,3 +727,58 @@ class TestQuerySimilarDocuments:
|
||||
results = indexing.query_similar_documents(a, document_ids=[b.id])
|
||||
|
||||
assert all(doc.id == b.id for doc in results)
|
||||
|
||||
|
||||
class TestRetrieveSimilarNodes:
|
||||
@pytest.mark.django_db
|
||||
def test_returns_raw_nodes_from_retriever(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
|
||||
mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
node1 = SimpleNamespace(metadata={"document_id": "1"})
|
||||
node2 = SimpleNamespace(metadata={"document_id": "2"})
|
||||
retriever = mocker.MagicMock()
|
||||
retriever.retrieve.return_value = [node1, node2]
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=retriever,
|
||||
)
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
|
||||
|
||||
assert result == [node1, node2]
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_empty_allow_list_fails_closed(
|
||||
self,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
|
||||
|
||||
assert result == []
|
||||
load.assert_not_called()
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queues_update_when_index_missing(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
|
||||
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
|
||||
|
||||
assert result == []
|
||||
queue.assert_called_once_with(
|
||||
rebuild=False,
|
||||
reason="LLM index not found for similarity query.",
|
||||
)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import difflib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.tests.factories import TagFactory
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
@@ -87,6 +90,95 @@ class TestAIMatching(TestCase):
|
||||
self.assertEqual(result[1].name, "Test Tag 2")
|
||||
|
||||
|
||||
class TestHintedMatching:
|
||||
def test_hinted_verbatim_skips_fuzzy(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
spy = mocker.spy(difflib, "get_close_matches")
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwork"],
|
||||
user=None,
|
||||
hinted_names={"Bloodwork"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
spy.assert_not_called()
|
||||
|
||||
def test_unhinted_name_still_fuzzy_matches(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
|
||||
# "Bloodwrok" is a typo not in hints -> fuzzy still maps it to Bloodwork.
|
||||
result = match_tags_by_name(
|
||||
["Bloodwrok"],
|
||||
user=None,
|
||||
hinted_names={"Taxes"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
|
||||
def test_hinted_name_with_whitespace_exact_matches(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
spy = mocker.spy(difflib, "get_close_matches")
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwork "],
|
||||
user=None,
|
||||
hinted_names={"Bloodwork"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
spy.assert_not_called()
|
||||
|
||||
def test_hinted_name_absent_from_queryset_is_skipped_not_fuzzed(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
# A hint with no exact object must not fall through to fuzzy.
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwrok"],
|
||||
user=None,
|
||||
hinted_names={"Bloodwrok"},
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_backward_compatible_without_kwarg(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Test Tag 1")],
|
||||
)
|
||||
|
||||
result = match_tags_by_name(["Test Tag 1", "Nonexistent"], user=None)
|
||||
|
||||
assert [t.name for t in result] == ["Test Tag 1"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestExtractUnmatchedNamesNormalization:
|
||||
def test_punctuated_name_already_matched_is_not_returned_as_unmatched(
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest_mock
|
||||
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from paperless_ai.taxonomy import TaxonomyHints
|
||||
from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes
|
||||
from paperless_ai.taxonomy import format_hints_for_prompt
|
||||
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
|
||||
|
||||
|
||||
def make_node(**metadata: object) -> SimpleNamespace:
|
||||
"""A stand-in for NodeWithScore: only ``.metadata`` is accessed."""
|
||||
return SimpleNamespace(metadata=metadata)
|
||||
|
||||
|
||||
class TestBuildTaxonomyHintsFromNodes:
|
||||
def test_returns_all_four_keys(self) -> None:
|
||||
hints = build_taxonomy_hints_from_nodes([])
|
||||
assert set(hints.keys()) == {
|
||||
"tags",
|
||||
"document_types",
|
||||
"correspondents",
|
||||
"storage_paths",
|
||||
}
|
||||
|
||||
def test_collects_and_sorts_values(self) -> None:
|
||||
nodes = [
|
||||
make_node(
|
||||
tags=["Taxes", "Bloodwork"],
|
||||
document_type="Invoice",
|
||||
correspondent="IRS",
|
||||
storage_path="Financial",
|
||||
),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Bloodwork", "Taxes"]
|
||||
assert hints["document_types"] == ["Invoice"]
|
||||
assert hints["correspondents"] == ["IRS"]
|
||||
assert hints["storage_paths"] == ["Financial"]
|
||||
|
||||
def test_deduplicates_across_nodes(self) -> None:
|
||||
nodes = [
|
||||
make_node(tags=["Taxes"], document_type="Invoice"),
|
||||
make_node(tags=["Taxes", "Medical"], document_type="Invoice"),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Medical", "Taxes"]
|
||||
assert hints["document_types"] == ["Invoice"]
|
||||
|
||||
def test_none_values_skipped(self) -> None:
|
||||
nodes = [
|
||||
make_node(
|
||||
tags=["Taxes", None, ""],
|
||||
document_type=None,
|
||||
correspondent=None,
|
||||
storage_path=None,
|
||||
),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Taxes"]
|
||||
assert hints["document_types"] == []
|
||||
assert hints["correspondents"] == []
|
||||
assert hints["storage_paths"] == []
|
||||
|
||||
def test_missing_storage_path_key_handled(self) -> None:
|
||||
# Pre-enrichment nodes have no storage_path key at all.
|
||||
nodes = [make_node(tags=["Taxes"], document_type="Invoice")]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["storage_paths"] == []
|
||||
|
||||
def test_empty_node_list_all_empty(self) -> None:
|
||||
hints = build_taxonomy_hints_from_nodes([])
|
||||
assert hints == {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
def test_output_stable_across_calls(self) -> None:
|
||||
nodes = [make_node(tags=["b", "a", "c"])]
|
||||
assert build_taxonomy_hints_from_nodes(
|
||||
nodes,
|
||||
) == build_taxonomy_hints_from_nodes(nodes)
|
||||
|
||||
|
||||
class TestFormatHintsForPrompt:
|
||||
def test_all_blocks_present_when_all_categories_nonempty(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": ["IRS"],
|
||||
"storage_paths": ["Financial"],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert "Available tags:" in result
|
||||
assert "Available document types:" in result
|
||||
assert "Available correspondents:" in result
|
||||
assert "Available storage paths:" in result
|
||||
assert "- Bloodwork" in result
|
||||
|
||||
def test_empty_category_produces_no_block(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert "Available tags:" in result
|
||||
assert "Available document types:" not in result
|
||||
assert "Available correspondents:" not in result
|
||||
assert "Available storage paths:" not in result
|
||||
|
||||
def test_all_empty_produces_empty_string(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
assert format_hints_for_prompt(hints) == ""
|
||||
|
||||
def test_instruction_line_appears_once(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert result.count("Prefer existing names from these lists verbatim") == 1
|
||||
|
||||
|
||||
class TestGetTaxonomyHintsForDocument:
|
||||
def test_returns_none_when_embedding_backend_off(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend=None),
|
||||
)
|
||||
retrieve = mocker.patch("paperless_ai.taxonomy.retrieve_similar_nodes")
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result is None
|
||||
retrieve.assert_not_called()
|
||||
|
||||
def test_passes_owner_aware_ids_when_user_present(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.visible_document_ids_for_user",
|
||||
return_value=[1, 2, 3],
|
||||
)
|
||||
retrieve = mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[],
|
||||
)
|
||||
document = DocumentFactory.build()
|
||||
user = mocker.MagicMock()
|
||||
|
||||
get_taxonomy_hints_for_document(document, user=user)
|
||||
|
||||
retrieve.assert_called_once_with(
|
||||
document=document,
|
||||
document_ids=[1, 2, 3],
|
||||
)
|
||||
|
||||
def test_returns_populated_hints_when_nodes_found(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[make_node(tags=["Taxes"], document_type="Invoice")],
|
||||
)
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result == {
|
||||
"tags": ["Taxes"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
def test_returns_empty_hints_not_none_when_no_nodes(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result == {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from documents.models import Document
|
||||
from documents.tests.factories import DocumentFactory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestSuggestionsHintWiring:
|
||||
@pytest.fixture
|
||||
def document(self) -> Document:
|
||||
return DocumentFactory() # type: ignore[return-value]
|
||||
|
||||
@pytest.fixture
|
||||
def api_client(self, admin_user: User) -> APIClient:
|
||||
client = APIClient()
|
||||
client.force_authenticate(user=admin_user)
|
||||
return client
|
||||
|
||||
def test_hints_passed_to_classifier_and_matchers(
|
||||
self,
|
||||
api_client: APIClient,
|
||||
document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
mocker.patch(
|
||||
"documents.views.get_taxonomy_hints_for_document",
|
||||
return_value=hints,
|
||||
)
|
||||
mocker.patch(
|
||||
"documents.views.AIConfig",
|
||||
return_value=SimpleNamespace(
|
||||
ai_enabled=True,
|
||||
llm_backend="ollama",
|
||||
llm_output_language=None,
|
||||
),
|
||||
)
|
||||
# No cached suggestion -> the view reaches the classifier path.
|
||||
mocker.patch(
|
||||
"documents.views.get_llm_suggestion_cache",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch("documents.views.set_llm_suggestions_cache")
|
||||
classify = mocker.patch(
|
||||
"documents.views.get_ai_document_classification",
|
||||
return_value={
|
||||
"title": "Doc",
|
||||
"tags": ["Bloodwork"],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
},
|
||||
)
|
||||
match_tags = mocker.patch(
|
||||
"documents.views.match_tags_by_name",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
|
||||
mocker.patch("documents.views.match_document_types_by_name", return_value=[])
|
||||
mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
|
||||
|
||||
response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert classify.call_args.kwargs["hints"] == hints
|
||||
assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}
|
||||
Reference in New Issue
Block a user