mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-03 12:19:45 +00:00
Enhancement: try to respect language for AI suggestions (#12894)
This commit is contained in:
@@ -25,6 +25,7 @@ from documents.models import DocumentType
|
||||
from documents.models import ShareLink
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.models import UiSettings
|
||||
from documents.signals.handlers import update_llm_suggestions_cache
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from documents.tests.utils import read_streaming_response
|
||||
@@ -319,6 +320,10 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.json(), {"tags": ["tag1", "tag2"]})
|
||||
mock_get_cache.assert_called_once_with(
|
||||
self.document.pk,
|
||||
backend="mock_backend",
|
||||
)
|
||||
mock_refresh_cache.assert_called_once_with(self.document.pk)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@@ -359,6 +364,49 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
mock_get_ai_classification.assert_called_once_with(
|
||||
self.document,
|
||||
self.user,
|
||||
None,
|
||||
)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_BACKEND="mock_backend",
|
||||
)
|
||||
def test_ai_suggestions_uses_user_display_language(
|
||||
self,
|
||||
mock_get_ai_classification,
|
||||
) -> None:
|
||||
UiSettings.objects.create(user=self.user, settings={"language": "de-de"})
|
||||
mock_get_ai_classification.return_value = {
|
||||
"title": "KI Title",
|
||||
"tags": [],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
}
|
||||
|
||||
self.client.force_login(user=self.user)
|
||||
response = self.client.get(
|
||||
f"/api/documents/{self.document.pk}/ai_suggestions/",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
mock_get_ai_classification.assert_called_once_with(
|
||||
self.document,
|
||||
self.user,
|
||||
"de-de",
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
self.document.pk,
|
||||
backend="mock_backend:de-de",
|
||||
).suggestions["title"],
|
||||
"KI Title",
|
||||
)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@override_settings(
|
||||
|
||||
+19
-3
@@ -1469,9 +1469,21 @@ class DocumentViewSet(
|
||||
if not ai_config.ai_enabled:
|
||||
return HttpResponseBadRequest("AI is required for this feature")
|
||||
|
||||
output_language = None
|
||||
if hasattr(request.user, "ui_settings") and isinstance(
|
||||
request.user.ui_settings.settings,
|
||||
dict,
|
||||
):
|
||||
output_language = request.user.ui_settings.settings.get("language") or None
|
||||
llm_cache_backend = (
|
||||
f"{ai_config.llm_backend}:{output_language}"
|
||||
if output_language
|
||||
else ai_config.llm_backend
|
||||
)
|
||||
|
||||
cached_llm_suggestions = get_llm_suggestion_cache(
|
||||
doc.pk,
|
||||
backend=ai_config.llm_backend,
|
||||
backend=llm_cache_backend,
|
||||
)
|
||||
|
||||
if cached_llm_suggestions:
|
||||
@@ -1479,7 +1491,11 @@ class DocumentViewSet(
|
||||
return Response(cached_llm_suggestions.suggestions)
|
||||
|
||||
try:
|
||||
llm_suggestions = get_ai_document_classification(doc, request.user)
|
||||
llm_suggestions = get_ai_document_classification(
|
||||
doc,
|
||||
request.user,
|
||||
output_language,
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.exception(
|
||||
"Invalid AI configuration while generating suggestions for "
|
||||
@@ -1532,7 +1548,7 @@ class DocumentViewSet(
|
||||
"dates": llm_suggestions.get("dates", []),
|
||||
}
|
||||
|
||||
set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
|
||||
set_llm_suggestions_cache(doc.pk, resp_data, backend=llm_cache_backend)
|
||||
|
||||
return Response(resp_data)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
@@ -12,7 +14,17 @@ from paperless_ai.indexing import truncate_content
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
|
||||
def build_prompt_without_rag(document: Document) -> str:
|
||||
def get_language_name(language_code: str) -> str:
|
||||
normalized_language_code = language_code.lower()
|
||||
for code, name in settings.LANGUAGES:
|
||||
if code.lower() == normalized_language_code:
|
||||
return str(name)
|
||||
return language_code
|
||||
|
||||
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
|
||||
@@ -35,7 +47,10 @@ def build_prompt_without_rag(document: Document) -> str:
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||
def build_prompt_with_rag(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
|
||||
@@ -46,6 +61,25 @@ def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_localization_prompt(suggestions: dict, output_language: str) -> str:
|
||||
language_name = get_language_name(output_language)
|
||||
return f"""
|
||||
You are localizing document classification suggestions for display in Paperless-ngx.
|
||||
|
||||
Rewrite only these generated fields in {language_name}: title, tags,
|
||||
document_types, storage_paths.
|
||||
|
||||
Do not translate correspondents or dates.
|
||||
Preserve proper nouns, organization names, product names, and exact official
|
||||
document names. Translate generic category words when a {language_name}
|
||||
equivalent exists.
|
||||
Return the same JSON schema with all fields present.
|
||||
|
||||
Suggestions:
|
||||
{json.dumps(suggestions)}
|
||||
""".strip()
|
||||
|
||||
|
||||
def get_context_for_document(
|
||||
doc: Document,
|
||||
user: User | None = None,
|
||||
@@ -91,6 +125,7 @@ def parse_ai_response(raw: dict) -> dict:
|
||||
def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
output_language: str | None = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
@@ -102,4 +137,19 @@ def get_ai_document_classification(
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query(prompt)
|
||||
return parse_ai_response(result)
|
||||
suggestions = parse_ai_response(result)
|
||||
if output_language:
|
||||
localized = client.run_llm_query(
|
||||
build_localization_prompt(suggestions, output_language),
|
||||
)
|
||||
localized_suggestions = parse_ai_response(localized)
|
||||
suggestions = {
|
||||
**suggestions,
|
||||
"title": localized_suggestions["title"] or suggestions["title"],
|
||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||
"document_types": localized_suggestions["document_types"]
|
||||
or suggestions["document_types"],
|
||||
"storage_paths": localized_suggestions["storage_paths"]
|
||||
or suggestions["storage_paths"],
|
||||
}
|
||||
return suggestions
|
||||
|
||||
@@ -6,10 +6,12 @@ import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.ai_classifier import build_localization_prompt
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
from paperless_ai.ai_classifier import get_language_name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -74,16 +76,70 @@ def mock_similar_documents():
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value = {
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
mock_run_llm_query.side_effect = [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
{
|
||||
"title": "Testtitel",
|
||||
"tags": ["Test", "Document"],
|
||||
"correspondents": ["Jane Doe"],
|
||||
"document_types": ["Bericht"],
|
||||
"storage_paths": ["Berichte"],
|
||||
"dates": ["2024-01-01"],
|
||||
},
|
||||
]
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
result = get_ai_document_classification(mock_document, output_language="de-de")
|
||||
|
||||
assert result["title"] == "Testtitel"
|
||||
assert result["tags"] == ["Test", "Document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["Bericht"]
|
||||
assert result["storage_paths"] == ["Berichte"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
classification_prompt = mock_run_llm_query.call_args_list[0].args[0]
|
||||
localization_prompt = mock_run_llm_query.call_args_list[1].args[0]
|
||||
assert "Write suggested titles" not in classification_prompt
|
||||
assert "Rewrite only these generated fields in German" in localization_prompt
|
||||
assert "Do not translate correspondents or dates" in localization_prompt
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_keeps_originals_when_localization_empty(
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_run_llm_query.side_effect = [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
{
|
||||
"title": "",
|
||||
"tags": [],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
},
|
||||
]
|
||||
|
||||
result = get_ai_document_classification(mock_document, output_language="de-de")
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
@@ -157,10 +213,29 @@ def test_prompt_with_without_rag(mock_document):
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "Additional context from similar documents" not in prompt
|
||||
assert "for generated" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
|
||||
prompt = build_localization_prompt(
|
||||
{
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
output_language="de-de",
|
||||
)
|
||||
assert "Rewrite only these generated fields in German" in prompt
|
||||
assert "Do not translate correspondents or dates" in prompt
|
||||
|
||||
|
||||
def test_get_language_name_falls_back_to_language_code():
|
||||
assert get_language_name("zz-zz") == "zz-zz"
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
|
||||
Reference in New Issue
Block a user