Enhancement: try to respect language for AI suggestions (#12894)

This commit is contained in:
shamoon
2026-06-01 12:11:46 -07:00
committed by GitHub
parent f6c865bf47
commit 27426c04b0
4 changed files with 204 additions and 15 deletions
+48
View File
@@ -25,6 +25,7 @@ from documents.models import DocumentType
from documents.models import ShareLink
from documents.models import StoragePath
from documents.models import Tag
from documents.models import UiSettings
from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
@@ -319,6 +320,10 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
mock_get_cache.assert_called_once_with(
self.document.pk,
backend="mock_backend",
)
mock_refresh_cache.assert_called_once_with(self.document.pk)
@patch("documents.views.get_ai_document_classification")
@@ -359,6 +364,49 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
"dates": ["2023-01-01"],
},
)
mock_get_ai_classification.assert_called_once_with(
self.document,
self.user,
None,
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="mock_backend",
)
def test_ai_suggestions_uses_user_display_language(
self,
mock_get_ai_classification,
) -> None:
UiSettings.objects.create(user=self.user, settings={"language": "de-de"})
mock_get_ai_classification.return_value = {
"title": "KI Title",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_get_ai_classification.assert_called_once_with(
self.document,
self.user,
"de-de",
)
self.assertEqual(
get_llm_suggestion_cache(
self.document.pk,
backend="mock_backend:de-de",
).suggestions["title"],
"KI Title",
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
+19 -3
View File
@@ -1469,9 +1469,21 @@ class DocumentViewSet(
if not ai_config.ai_enabled:
return HttpResponseBadRequest("AI is required for this feature")
output_language = None
if hasattr(request.user, "ui_settings") and isinstance(
request.user.ui_settings.settings,
dict,
):
output_language = request.user.ui_settings.settings.get("language") or None
llm_cache_backend = (
f"{ai_config.llm_backend}:{output_language}"
if output_language
else ai_config.llm_backend
)
cached_llm_suggestions = get_llm_suggestion_cache(
doc.pk,
backend=ai_config.llm_backend,
backend=llm_cache_backend,
)
if cached_llm_suggestions:
@@ -1479,7 +1491,11 @@ class DocumentViewSet(
return Response(cached_llm_suggestions.suggestions)
try:
llm_suggestions = get_ai_document_classification(doc, request.user)
llm_suggestions = get_ai_document_classification(
doc,
request.user,
output_language,
)
except ValueError as exc:
logger.exception(
"Invalid AI configuration while generating suggestions for "
@@ -1532,7 +1548,7 @@ class DocumentViewSet(
"dates": llm_suggestions.get("dates", []),
}
set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
set_llm_suggestions_cache(doc.pk, resp_data, backend=llm_cache_backend)
return Response(resp_data)
+53 -3
View File
@@ -1,5 +1,7 @@
import json
import logging
from django.conf import settings
from django.contrib.auth.models import User
from documents.models import Document
@@ -12,7 +14,17 @@ from paperless_ai.indexing import truncate_content
logger = logging.getLogger("paperless_ai.rag_classifier")
def build_prompt_without_rag(document: Document) -> str:
def get_language_name(language_code: str) -> str:
normalized_language_code = language_code.lower()
for code, name in settings.LANGUAGES:
if code.lower() == normalized_language_code:
return str(name)
return language_code
def build_prompt_without_rag(
document: Document,
) -> str:
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
@@ -35,7 +47,10 @@ def build_prompt_without_rag(document: Document) -> str:
""".strip()
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
def build_prompt_with_rag(
document: Document,
user: User | None = None,
) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
@@ -46,6 +61,25 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
""".strip()
def build_localization_prompt(suggestions: dict, output_language: str) -> str:
language_name = get_language_name(output_language)
return f"""
You are localizing document classification suggestions for display in Paperless-ngx.
Rewrite only these generated fields in {language_name}: title, tags,
document_types, storage_paths.
Do not translate correspondents or dates.
Preserve proper nouns, organization names, product names, and exact official
document names. Translate generic category words when a {language_name}
equivalent exists.
Return the same JSON schema with all fields present.
Suggestions:
{json.dumps(suggestions)}
""".strip()
def get_context_for_document(
doc: Document,
user: User | None = None,
@@ -91,6 +125,7 @@ def parse_ai_response(raw: dict) -> dict:
def get_ai_document_classification(
document: Document,
user: User | None = None,
output_language: str | None = None,
) -> dict:
ai_config = AIConfig()
@@ -102,4 +137,19 @@ def get_ai_document_classification(
client = AIClient()
result = client.run_llm_query(prompt)
return parse_ai_response(result)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
return suggestions
+84 -9
View File
@@ -6,10 +6,12 @@ import pytest
from django.test import override_settings
from documents.models import Document
from paperless_ai.ai_classifier import build_localization_prompt
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import get_language_name
@pytest.fixture
@@ -74,16 +76,70 @@ def mock_similar_documents():
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value = {
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
}
mock_run_llm_query.side_effect = [
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
{
"title": "Testtitel",
"tags": ["Test", "Document"],
"correspondents": ["Jane Doe"],
"document_types": ["Bericht"],
"storage_paths": ["Berichte"],
"dates": ["2024-01-01"],
},
]
result = get_ai_document_classification(mock_document)
result = get_ai_document_classification(mock_document, output_language="de-de")
assert result["title"] == "Testtitel"
assert result["tags"] == ["Test", "Document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["Bericht"]
assert result["storage_paths"] == ["Berichte"]
assert result["dates"] == ["2023-01-01"]
classification_prompt = mock_run_llm_query.call_args_list[0].args[0]
localization_prompt = mock_run_llm_query.call_args_list[1].args[0]
assert "Write suggested titles" not in classification_prompt
assert "Rewrite only these generated fields in German" in localization_prompt
assert "Do not translate correspondents or dates" in localization_prompt
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_keeps_originals_when_localization_empty(
mock_run_llm_query,
mock_document,
):
mock_run_llm_query.side_effect = [
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
{
"title": "",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
},
]
result = get_ai_document_classification(mock_document, output_language="de-de")
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
@@ -157,10 +213,29 @@ def test_prompt_with_without_rag(mock_document):
):
prompt = build_prompt_without_rag(mock_document)
assert "Additional context from similar documents" not in prompt
assert "for generated" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "Additional context from similar documents" in prompt
prompt = build_localization_prompt(
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
output_language="de-de",
)
assert "Rewrite only these generated fields in German" in prompt
assert "Do not translate correspondents or dates" in prompt
def test_get_language_name_falls_back_to_language_code():
assert get_language_name("zz-zz") == "zz-zz"
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(