mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-28 16:24:19 +00:00
Enhancement(beta): splice taxonomy hints into the AI classifier prompt
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
@@ -11,6 +12,10 @@ from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
from paperless_ai.taxonomy import format_hints_for_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paperless_ai.taxonomy import TaxonomyHints
|
||||
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
@@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str:
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(
|
||||
@@ -34,10 +40,16 @@ def build_prompt_without_rag(
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
hints_block = format_hints_for_prompt(hints) if hints else ""
|
||||
# Splice the block (if any) immediately before the "Analyze ..." instruction.
|
||||
# When there is no block this expands to nothing, so the prompt is identical
|
||||
# to the pre-hints baseline.
|
||||
hints_section = f"{hints_block}\n\n " if hints_block else ""
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
|
||||
Analyze the following document and extract the following information:
|
||||
{hints_section}Analyze the following document and extract the following information:
|
||||
- A short descriptive title
|
||||
- Tags that reflect the content
|
||||
- Names of people or organizations mentioned
|
||||
@@ -57,8 +69,9 @@ def build_prompt_with_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
user: User | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document, config)
|
||||
base_prompt = build_prompt_without_rag(document, config, hints=hints)
|
||||
context = truncate_content(
|
||||
get_context_for_document(document, user),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
@@ -137,13 +150,14 @@ def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
output_language: str | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, ai_config, user)
|
||||
build_prompt_with_rag(document, ai_config, user, hints=hints)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document, ai_config)
|
||||
else build_prompt_without_rag(document, ai_config, hints=hints)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
@@ -261,3 +263,107 @@ def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestPromptHints:
|
||||
@pytest.fixture
|
||||
def config(self) -> AIConfig:
|
||||
return AIConfig()
|
||||
|
||||
def test_without_rag_includes_hints_block(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
|
||||
assert "Available tags:" in prompt
|
||||
assert "- Bloodwork" in prompt
|
||||
assert "Prefer existing names from these lists verbatim" in prompt
|
||||
|
||||
def test_without_rag_none_matches_baseline(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
baseline = build_prompt_without_rag(mock_document, config)
|
||||
with_none = build_prompt_without_rag(mock_document, config, hints=None)
|
||||
assert with_none == baseline
|
||||
assert "Available tags:" not in with_none
|
||||
|
||||
def test_with_rag_includes_context_and_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="TITLE: Neighbour\nsome context",
|
||||
)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
assert "Available tags:" in prompt
|
||||
|
||||
def test_classification_forwards_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.AIConfig",
|
||||
return_value=SimpleNamespace(
|
||||
llm_embedding_backend=None,
|
||||
llm_embedding_chunk_size=1000,
|
||||
llm_context_size=8000,
|
||||
),
|
||||
)
|
||||
build = mocker.patch(
|
||||
"paperless_ai.ai_classifier.build_prompt_without_rag",
|
||||
return_value="PROMPT",
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.run_llm_query.return_value = {
|
||||
"title": "t",
|
||||
"tags": [],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
}
|
||||
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
result = get_ai_document_classification(
|
||||
mock_document,
|
||||
user=None,
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
_, build_kwargs = build.call_args
|
||||
assert build_kwargs["hints"] == hints
|
||||
assert set(result.keys()) == {
|
||||
"title",
|
||||
"tags",
|
||||
"correspondents",
|
||||
"document_types",
|
||||
"storage_paths",
|
||||
"dates",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user