Wire taxonomy hints into the AI suggestion path.

* src/documents/views.py: call get_taxonomy_hints_for_document(doc, request.user)
  after the cached-suggestion early return, pass the result to
  get_ai_document_classification(hints=...), and pass the per-category names
  (as a set, or None when hints is falsy) to the four match_*_by_name helpers
  via hinted_names=.
* src/documents/tests/test_views.py: existing call expectations now include
  hints=None.
* src/paperless_ai/tests/test_views_suggestions.py: new test asserting that the
  hints reach the classifier and the tag matcher.

NOTE(review): this patch was restored from a whitespace-collapsed copy. The
indentation of the added file follows from Python syntax, but the absolute
indentation of unchanged context lines in the two existing files was inferred;
verify against the target tree, or apply with `git apply --ignore-whitespace`.

diff --git a/src/documents/tests/test_views.py b/src/documents/tests/test_views.py
index a67590b81..7c8774506 100644
--- a/src/documents/tests/test_views.py
+++ b/src/documents/tests/test_views.py
@@ -368,6 +368,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
             self.document,
             self.user,
             None,
+            hints=None,
         )
 
     @patch("documents.views.get_ai_document_classification")
@@ -399,6 +400,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
             self.document,
             self.user,
             "de-de",
+            hints=None,
         )
         self.assertEqual(
             get_llm_suggestion_cache(
@@ -438,6 +440,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
             self.document,
             self.user,
             "fr-fr",
+            hints=None,
         )
         self.assertEqual(
             get_llm_suggestion_cache(
diff --git a/src/documents/views.py b/src/documents/views.py
index 5ed6fdaf5..51cbd63ba 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -245,6 +245,7 @@ from paperless_ai.matching import match_correspondents_by_name
 from paperless_ai.matching import match_document_types_by_name
 from paperless_ai.matching import match_storage_paths_by_name
 from paperless_ai.matching import match_tags_by_name
+from paperless_ai.taxonomy import get_taxonomy_hints_for_document
 from paperless_mail.models import MailAccount
 from paperless_mail.models import MailRule
 from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -1494,11 +1495,14 @@ class DocumentViewSet(
             refresh_suggestions_cache(doc.pk)
             return Response(cached_llm_suggestions.suggestions)
 
+        hints = get_taxonomy_hints_for_document(doc, request.user)
+
         try:
             llm_suggestions = get_ai_document_classification(
                 doc,
                 request.user,
                 output_language,
+                hints=hints,
             )
         except ValueError as exc:
             logger.exception(
@@ -1513,18 +1517,22 @@ class DocumentViewSet(
         matched_tags = match_tags_by_name(
             llm_suggestions.get("tags", []),
             request.user,
+            hinted_names=set(hints["tags"]) if hints else None,
         )
         matched_correspondents = match_correspondents_by_name(
             llm_suggestions.get("correspondents", []),
             request.user,
+            hinted_names=set(hints["correspondents"]) if hints else None,
         )
         matched_types = match_document_types_by_name(
             llm_suggestions.get("document_types", []),
             request.user,
+            hinted_names=set(hints["document_types"]) if hints else None,
         )
         matched_paths = match_storage_paths_by_name(
             llm_suggestions.get("storage_paths", []),
             request.user,
+            hinted_names=set(hints["storage_paths"]) if hints else None,
         )
 
         resp_data = {
diff --git a/src/paperless_ai/tests/test_views_suggestions.py b/src/paperless_ai/tests/test_views_suggestions.py
new file mode 100644
index 000000000..e13c6cba1
--- /dev/null
+++ b/src/paperless_ai/tests/test_views_suggestions.py
@@ -0,0 +1,85 @@
+from types import SimpleNamespace
+
+import pytest
+import pytest_mock
+from django.contrib.auth.models import User
+from rest_framework.test import APIClient
+
+from documents.models import Document
+
+
+@pytest.mark.django_db
+class TestSuggestionsHintWiring:
+    @pytest.fixture
+    def user(self) -> User:
+        return User.objects.create_superuser(username="admin", password="pw")
+
+    @pytest.fixture
+    def document(self, user: User) -> Document:
+        return Document.objects.create(
+            title="Doc",
+            content="content",
+            checksum="abc123",
+            mime_type="application/pdf",
+        )
+
+    @pytest.fixture
+    def api_client(self, user: User) -> APIClient:
+        client = APIClient()
+        client.force_authenticate(user=user)
+        return client
+
+    def test_hints_passed_to_classifier_and_matchers(
+        self,
+        api_client: APIClient,
+        document: Document,
+        mocker: pytest_mock.MockerFixture,
+    ) -> None:
+        hints = {
+            "tags": ["Bloodwork"],
+            "document_types": [],
+            "correspondents": [],
+            "storage_paths": [],
+        }
+        mocker.patch(
+            "documents.views.get_taxonomy_hints_for_document",
+            return_value=hints,
+        )
+        mocker.patch(
+            "documents.views.AIConfig",
+            return_value=SimpleNamespace(
+                ai_enabled=True,
+                llm_backend="ollama",
+                llm_output_language=None,
+            ),
+        )
+        # No cached suggestion -> the view reaches the classifier path.
+        mocker.patch(
+            "documents.views.get_llm_suggestion_cache",
+            return_value=None,
+        )
+        mocker.patch("documents.views.set_llm_suggestions_cache")
+        classify = mocker.patch(
+            "documents.views.get_ai_document_classification",
+            return_value={
+                "title": "Doc",
+                "tags": ["Bloodwork"],
+                "correspondents": [],
+                "document_types": [],
+                "storage_paths": [],
+                "dates": [],
+            },
+        )
+        match_tags = mocker.patch(
+            "documents.views.match_tags_by_name",
+            return_value=[],
+        )
+        mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
+        mocker.patch("documents.views.match_document_types_by_name", return_value=[])
+        mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
+
+        response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
+
+        assert response.status_code == 200
+        assert classify.call_args.kwargs["hints"] == hints
+        assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}