Enhancement(beta): splice taxonomy hints into the AI classifier prompt

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-06-13 06:20:29 -07:00
parent 5202b0880e
commit 24b81c15f6
2 changed files with 124 additions and 4 deletions
+18 -4
View File
@@ -1,5 +1,6 @@
import json
import logging
from typing import TYPE_CHECKING
from django.conf import settings
from django.contrib.auth.models import User
@@ -11,6 +12,10 @@ from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
from paperless_ai.taxonomy import format_hints_for_prompt
if TYPE_CHECKING:
from paperless_ai.taxonomy import TaxonomyHints
logger = logging.getLogger("paperless_ai.rag_classifier")
@@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
hints: "TaxonomyHints | None" = None,
) -> str:
filename = document.filename or ""
content = truncate_content(
@@ -34,10 +40,16 @@ def build_prompt_without_rag(
context_size=config.llm_context_size,
)
hints_block = format_hints_for_prompt(hints) if hints else ""
# Splice the block (if any) immediately before the "Analyze ..." instruction.
# When there is no block this expands to nothing, so the prompt is identical
# to the pre-hints baseline.
hints_section = f"{hints_block}\n\n " if hints_block else ""
return f"""
You are a document classification assistant.
Analyze the following document and extract the following information:
{hints_section}Analyze the following document and extract the following information:
- A short descriptive title
- Tags that reflect the content
- Names of people or organizations mentioned
@@ -57,8 +69,9 @@ def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
hints: "TaxonomyHints | None" = None,
) -> str:
base_prompt = build_prompt_without_rag(document, config)
base_prompt = build_prompt_without_rag(document, config, hints=hints)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
@@ -137,13 +150,14 @@ def get_ai_document_classification(
document: Document,
user: User | None = None,
output_language: str | None = None,
hints: "TaxonomyHints | None" = None,
) -> dict:
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, ai_config, user)
build_prompt_with_rag(document, ai_config, user, hints=hints)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document, ai_config)
else build_prompt_without_rag(document, ai_config, hints=hints)
)
client = AIClient()
@@ -1,8 +1,10 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import pytest_mock
from django.test import override_settings
from documents.models import Document
@@ -261,3 +263,107 @@ def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""
@pytest.mark.django_db
class TestPromptHints:
@pytest.fixture
def config(self) -> AIConfig:
return AIConfig()
def test_without_rag_includes_hints_block(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
assert "Available tags:" in prompt
assert "- Bloodwork" in prompt
assert "Prefer existing names from these lists verbatim" in prompt
def test_without_rag_none_matches_baseline(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
baseline = build_prompt_without_rag(mock_document, config)
with_none = build_prompt_without_rag(mock_document, config, hints=None)
assert with_none == baseline
assert "Available tags:" not in with_none
def test_with_rag_includes_context_and_hints(
self,
mock_document: MagicMock,
config: AIConfig,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="TITLE: Neighbour\nsome context",
)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
assert "Additional context from similar documents" in prompt
assert "Available tags:" in prompt
def test_classification_forwards_hints(
self,
mock_document: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.AIConfig",
return_value=SimpleNamespace(
llm_embedding_backend=None,
llm_embedding_chunk_size=1000,
llm_context_size=8000,
),
)
build = mocker.patch(
"paperless_ai.ai_classifier.build_prompt_without_rag",
return_value="PROMPT",
)
mock_client = MagicMock()
mock_client.run_llm_query.return_value = {
"title": "t",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = get_ai_document_classification(
mock_document,
user=None,
hints=hints,
)
_, build_kwargs = build.call_args
assert build_kwargs["hints"] == hints
assert set(result.keys()) == {
"title",
"tags",
"correspondents",
"document_types",
"storage_paths",
"dates",
}