diff --git a/src/documents/tests/test_views.py b/src/documents/tests/test_views.py index b8ab42256..90736849f 100644 --- a/src/documents/tests/test_views.py +++ b/src/documents/tests/test_views.py @@ -25,6 +25,7 @@ from documents.models import DocumentType from documents.models import ShareLink from documents.models import StoragePath from documents.models import Tag +from documents.models import UiSettings from documents.signals.handlers import update_llm_suggestions_cache from documents.tests.utils import DirectoriesMixin from documents.tests.utils import read_streaming_response @@ -319,6 +320,10 @@ class TestAISuggestions(DirectoriesMixin, TestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]}) + mock_get_cache.assert_called_once_with( + self.document.pk, + backend="mock_backend", + ) mock_refresh_cache.assert_called_once_with(self.document.pk) @patch("documents.views.get_ai_document_classification") @@ -359,6 +364,49 @@ class TestAISuggestions(DirectoriesMixin, TestCase): "dates": ["2023-01-01"], }, ) + mock_get_ai_classification.assert_called_once_with( + self.document, + self.user, + None, + ) + + @patch("documents.views.get_ai_document_classification") + @override_settings( + AI_ENABLED=True, + LLM_BACKEND="mock_backend", + ) + def test_ai_suggestions_uses_user_display_language( + self, + mock_get_ai_classification, + ) -> None: + UiSettings.objects.create(user=self.user, settings={"language": "de-de"}) + mock_get_ai_classification.return_value = { + "title": "KI Title", + "tags": [], + "correspondents": [], + "document_types": [], + "storage_paths": [], + "dates": [], + } + + self.client.force_login(user=self.user) + response = self.client.get( + f"/api/documents/{self.document.pk}/ai_suggestions/", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_get_ai_classification.assert_called_once_with( + self.document, + self.user, + "de-de", + ) + self.assertEqual( + get_llm_suggestion_cache( + self.document.pk, + backend="mock_backend:de-de", + ).suggestions["title"], + "KI Title", + ) @patch("documents.views.get_ai_document_classification") @override_settings( diff --git a/src/documents/views.py b/src/documents/views.py index b2f0d0994..e92ca2d1d 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1469,9 +1469,21 @@ class DocumentViewSet( if not ai_config.ai_enabled: return HttpResponseBadRequest("AI is required for this feature") + output_language = None + if hasattr(request.user, "ui_settings") and isinstance( + request.user.ui_settings.settings, + dict, + ): + output_language = request.user.ui_settings.settings.get("language") or None + llm_cache_backend = ( + f"{ai_config.llm_backend}:{output_language}" + if output_language + else ai_config.llm_backend + ) + cached_llm_suggestions = get_llm_suggestion_cache( doc.pk, - backend=ai_config.llm_backend, + backend=llm_cache_backend, ) if cached_llm_suggestions: @@ -1479,7 +1491,11 @@ class DocumentViewSet( return Response(cached_llm_suggestions.suggestions) try: - llm_suggestions = get_ai_document_classification(doc, request.user) + llm_suggestions = get_ai_document_classification( + doc, + request.user, + output_language, + ) except ValueError as exc: logger.exception( "Invalid AI configuration while generating suggestions for " @@ -1532,7 +1548,7 @@ class DocumentViewSet( "dates": llm_suggestions.get("dates", []), } - set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend) + set_llm_suggestions_cache(doc.pk, resp_data, backend=llm_cache_backend) return Response(resp_data) diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index ce6ca2f7a..c3e27cd41 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -1,5 +1,7 @@ +import json import logging +from django.conf import settings from django.contrib.auth.models import User from documents.models import Document @@ -12,7 +14,17 @@ from paperless_ai.indexing import truncate_content logger = logging.getLogger("paperless_ai.rag_classifier") -def build_prompt_without_rag(document: Document) -> str: +def get_language_name(language_code: str) -> str: + normalized_language_code = language_code.lower() + for code, name in settings.LANGUAGES: + if code.lower() == normalized_language_code: + return str(name) + return language_code + + +def build_prompt_without_rag( + document: Document, +) -> str: filename = document.filename or "" content = truncate_content(document.content[:4000] or "") @@ -35,7 +47,10 @@ def build_prompt_without_rag(document: Document) -> str: """.strip() -def build_prompt_with_rag(document: Document, user: User | None = None) -> str: +def build_prompt_with_rag( + document: Document, + user: User | None = None, +) -> str: base_prompt = build_prompt_without_rag(document) context = truncate_content(get_context_for_document(document, user)) @@ -46,6 +61,25 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str: """.strip() +def build_localization_prompt(suggestions: dict, output_language: str) -> str: + language_name = get_language_name(output_language) + return f""" + You are localizing document classification suggestions for display in Paperless-ngx. + + Rewrite only these generated fields in {language_name}: title, tags, + document_types, storage_paths. + + Do not translate correspondents or dates. + Preserve proper nouns, organization names, product names, and exact official + document names. Translate generic category words when a {language_name} + equivalent exists. + Return the same JSON schema with all fields present. + + Suggestions: + {json.dumps(suggestions)} + """.strip() + + def get_context_for_document( doc: Document, user: User | None = None, @@ -91,6 +125,7 @@ def parse_ai_response(raw: dict) -> dict: def get_ai_document_classification( document: Document, user: User | None = None, + output_language: str | None = None, ) -> dict: ai_config = AIConfig() @@ -102,4 +137,19 @@ def get_ai_document_classification( client = AIClient() result = client.run_llm_query(prompt) - return parse_ai_response(result) + suggestions = parse_ai_response(result) + if output_language: + localized = client.run_llm_query( + build_localization_prompt(suggestions, output_language), + ) + localized_suggestions = parse_ai_response(localized) + suggestions = { + **suggestions, + "title": localized_suggestions["title"] or suggestions["title"], + "tags": localized_suggestions["tags"] or suggestions["tags"], + "document_types": localized_suggestions["document_types"] + or suggestions["document_types"], + "storage_paths": localized_suggestions["storage_paths"] + or suggestions["storage_paths"], + } + return suggestions diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 3a6535b57..97e18eb47 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -6,10 +6,12 @@ import pytest from django.test import override_settings from documents.models import Document +from paperless_ai.ai_classifier import build_localization_prompt from paperless_ai.ai_classifier import build_prompt_with_rag from paperless_ai.ai_classifier import build_prompt_without_rag from paperless_ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_context_for_document +from paperless_ai.ai_classifier import get_language_name @pytest.fixture @@ -74,16 +76,70 @@ def mock_similar_documents(): LLM_MODEL="some_model", ) def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): - mock_run_llm_query.return_value = { - "title": "Test Title", - "tags": ["test", "document"], - "correspondents": ["John Doe"], - "document_types": ["report"], - "storage_paths": ["Reports"], - "dates": ["2023-01-01"], - } + mock_run_llm_query.side_effect = [ + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + { + "title": "Testtitel", + "tags": ["Test", "Document"], + "correspondents": ["Jane Doe"], + "document_types": ["Bericht"], + "storage_paths": ["Berichte"], + "dates": ["2024-01-01"], + }, + ] - result = get_ai_document_classification(mock_document) + result = get_ai_document_classification(mock_document, output_language="de-de") + + assert result["title"] == "Testtitel" + assert result["tags"] == ["Test", "Document"] + assert result["correspondents"] == ["John Doe"] + assert result["document_types"] == ["Bericht"] + assert result["storage_paths"] == ["Berichte"] + assert result["dates"] == ["2023-01-01"] + classification_prompt = mock_run_llm_query.call_args_list[0].args[0] + localization_prompt = mock_run_llm_query.call_args_list[1].args[0] + assert "Write suggested titles" not in classification_prompt + assert "Rewrite only these generated fields in German" in localization_prompt + assert "Do not translate correspondents or dates" in localization_prompt + + +@pytest.mark.django_db +@patch("paperless_ai.client.AIClient.run_llm_query") +@override_settings( + LLM_BACKEND="ollama", + LLM_MODEL="some_model", +) +def test_get_ai_document_classification_keeps_originals_when_localization_empty( + mock_run_llm_query, + mock_document, +): + mock_run_llm_query.side_effect = [ + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + { + "title": "", + "tags": [], + "correspondents": [], + "document_types": [], + "storage_paths": [], + "dates": [], + }, + ] + + result = get_ai_document_classification(mock_document, output_language="de-de") assert result["title"] == "Test Title" assert result["tags"] == ["test", "document"] @@ -157,10 +213,29 @@ def test_prompt_with_without_rag(mock_document): ): prompt = build_prompt_without_rag(mock_document) assert "Additional context from similar documents" not in prompt + assert "for generated" not in prompt prompt = build_prompt_with_rag(mock_document) assert "Additional context from similar documents" in prompt + prompt = build_localization_prompt( + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + output_language="de-de", + ) + assert "Rewrite only these generated fields in German" in prompt + assert "Do not translate correspondents or dates" in prompt + + +def test_get_language_name_falls_back_to_language_code(): + assert get_language_name("zz-zz") == "zz-zz" + @patch("paperless_ai.ai_classifier.query_similar_documents") def test_get_context_for_document(