diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index f9c2f1e06..f07f5eacc 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -1,5 +1,6 @@ import json import logging +from typing import TYPE_CHECKING from django.conf import settings from django.contrib.auth.models import User @@ -11,6 +12,10 @@ from paperless_ai.client import AIClient from paperless_ai.db import db_connection_released from paperless_ai.indexing import query_similar_documents from paperless_ai.indexing import truncate_content +from paperless_ai.taxonomy import format_hints_for_prompt + +if TYPE_CHECKING: + from paperless_ai.taxonomy import TaxonomyHints logger = logging.getLogger("paperless_ai.rag_classifier") @@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str: def build_prompt_without_rag( document: Document, config: AIConfig, + hints: "TaxonomyHints | None" = None, ) -> str: filename = document.filename or "" content = truncate_content( @@ -34,10 +40,16 @@ def build_prompt_without_rag( context_size=config.llm_context_size, ) + hints_block = format_hints_for_prompt(hints) if hints else "" + # Splice the block (if any) immediately before the "Analyze ..." instruction. + # When there is no block this expands to nothing, so the prompt is identical + # to the pre-hints baseline. + hints_section = f"{hints_block}\n\n " if hints_block else "" + return f""" You are a document classification assistant. - Analyze the following document and extract the following information: + {hints_section}Analyze the following document and extract the following information: - A short descriptive title - Tags that reflect the content - Names of people or organizations mentioned @@ -57,8 +69,9 @@ def build_prompt_with_rag( document: Document, config: AIConfig, user: User | None = None, + hints: "TaxonomyHints | None" = None, ) -> str: - base_prompt = build_prompt_without_rag(document, config) + base_prompt = build_prompt_without_rag(document, config, hints=hints) context = truncate_content( get_context_for_document(document, user), chunk_size=config.llm_embedding_chunk_size, @@ -137,13 +150,14 @@ def get_ai_document_classification( document: Document, user: User | None = None, output_language: str | None = None, + hints: "TaxonomyHints | None" = None, ) -> dict: ai_config = AIConfig() prompt = ( - build_prompt_with_rag(document, ai_config, user) + build_prompt_with_rag(document, ai_config, user, hints=hints) if ai_config.llm_embedding_backend - else build_prompt_without_rag(document, ai_config) + else build_prompt_without_rag(document, ai_config, hints=hints) ) client = AIClient() diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 45822b14b..127c08097 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -1,8 +1,10 @@ import json +from types import SimpleNamespace from unittest.mock import MagicMock from unittest.mock import patch import pytest +import pytest_mock from django.test import override_settings from documents.models import Document @@ -261,3 +263,107 @@ def test_get_context_for_document_no_similar_docs(mock_document): with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]): result = get_context_for_document(mock_document) assert result == "" + + +@pytest.mark.django_db +class TestPromptHints: + @pytest.fixture + def config(self) -> AIConfig: + return AIConfig() + + def test_without_rag_includes_hints_block( + self, + mock_document: MagicMock, + config: AIConfig, + ) -> None: + hints = { + "tags": ["Bloodwork"], + "document_types": ["Invoice"], + "correspondents": [], + "storage_paths": [], + } + prompt = build_prompt_without_rag(mock_document, config, hints=hints) + assert "Available tags:" in prompt + assert "- Bloodwork" in prompt + assert "Prefer existing names from these lists verbatim" in prompt + + def test_without_rag_none_matches_baseline( + self, + mock_document: MagicMock, + config: AIConfig, + ) -> None: + baseline = build_prompt_without_rag(mock_document, config) + with_none = build_prompt_without_rag(mock_document, config, hints=None) + assert with_none == baseline + assert "Available tags:" not in with_none + + def test_with_rag_includes_context_and_hints( + self, + mock_document: MagicMock, + config: AIConfig, + mocker: pytest_mock.MockerFixture, + ) -> None: + mocker.patch( + "paperless_ai.ai_classifier.get_context_for_document", + return_value="TITLE: Neighbour\nsome context", + ) + hints = { + "tags": ["Bloodwork"], + "document_types": [], + "correspondents": [], + "storage_paths": [], + } + prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints) + assert "Additional context from similar documents" in prompt + assert "Available tags:" in prompt + + def test_classification_forwards_hints( + self, + mock_document: MagicMock, + mocker: pytest_mock.MockerFixture, + ) -> None: + mocker.patch( + "paperless_ai.ai_classifier.AIConfig", + return_value=SimpleNamespace( + llm_embedding_backend=None, + llm_embedding_chunk_size=1000, + llm_context_size=8000, + ), + ) + build = mocker.patch( + "paperless_ai.ai_classifier.build_prompt_without_rag", + return_value="PROMPT", + ) + mock_client = MagicMock() + mock_client.run_llm_query.return_value = { + "title": "t", + "tags": [], + "correspondents": [], + "document_types": [], + "storage_paths": [], + "dates": [], + } + mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client) + hints = { + "tags": ["Bloodwork"], + "document_types": [], + "correspondents": [], + "storage_paths": [], + } + + result = get_ai_document_classification( + mock_document, + user=None, + hints=hints, + ) + + _, build_kwargs = build.call_args + assert build_kwargs["hints"] == hints + assert set(result.keys()) == { + "title", + "tags", + "correspondents", + "document_types", + "storage_paths", + "dates", + }