From 3aebdcf38c957f670b2ad2080e01471bbfdb630f Mon Sep 17 00:00:00 2001 From: Trenton Holmes <797416+stumpylog@users.noreply.github.com> Date: Sat, 6 Jun 2026 16:08:57 -0700 Subject: [PATCH] Construct fewer AiConfig objects and instead pass around as needed --- src/paperless_ai/ai_classifier.py | 20 +++++++--- src/paperless_ai/chat.py | 9 ++++- src/paperless_ai/embedding.py | 7 +--- src/paperless_ai/indexing.py | 39 +++++++++++++------- src/paperless_ai/tests/test_ai_classifier.py | 6 ++- src/paperless_ai/tests/test_chat.py | 2 + src/paperless_ai/tests/test_embedding.py | 30 ++++++++------- 7 files changed, 71 insertions(+), 42 deletions(-) diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index c3e27cd41..5420812eb 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -24,9 +24,14 @@ def get_language_name(language_code: str) -> str: def build_prompt_without_rag( document: Document, + config: AIConfig, ) -> str: filename = document.filename or "" - content = truncate_content(document.content[:4000] or "") + content = truncate_content( + document.content[:4000] or "", + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) return f""" You are a document classification assistant. @@ -49,10 +54,15 @@ def build_prompt_without_rag( def build_prompt_with_rag( document: Document, + config: AIConfig, user: User | None = None, ) -> str: - base_prompt = build_prompt_without_rag(document) - context = truncate_content(get_context_for_document(document, user)) + base_prompt = build_prompt_without_rag(document, config) + context = truncate_content( + get_context_for_document(document, user), + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) return f"""{base_prompt} @@ -130,9 +140,9 @@ def get_ai_document_classification( ai_config = AIConfig() prompt = ( - build_prompt_with_rag(document, user) + build_prompt_with_rag(document, ai_config, user) if ai_config.llm_embedding_backend - else build_prompt_without_rag(document) + else build_prompt_without_rag(document, ai_config) ) client = AIClient() diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index 7942321a6..123771c50 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -3,6 +3,7 @@ import logging import sys from documents.models import Document +from paperless.config import AIConfig from paperless_ai.client import AIClient from paperless_ai.indexing import _document_id_filters from paperless_ai.indexing import get_rag_prompt_helper @@ -94,7 +95,8 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): from llama_index.core.response_synthesizers import get_response_synthesizer from llama_index.core.retrievers import VectorIndexRetriever - index = load_or_build_index() + config = AIConfig() + index = load_or_build_index(config) filters = _document_id_filters(str(doc.pk) for doc in documents) retriever = VectorIndexRetriever( @@ -116,7 +118,10 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL) response_synthesizer = get_response_synthesizer( llm=client.llm, - prompt_helper=get_rag_prompt_helper(), + prompt_helper=get_rag_prompt_helper( + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ), text_qa_template=prompt_template, streaming=True, ) diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index 8480cb76d..88ea80293 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -20,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}") HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+") -def get_embedding_model() -> "BaseEmbedding": - config = AIConfig() - +def get_embedding_model(config: AIConfig) -> "BaseEmbedding": match config.llm_embedding_backend: case LLMEmbeddingBackend.OPENAI_LIKE: from llama_index.embeddings.openai_like import OpenAILikeEmbedding @@ -99,9 +97,8 @@ _DEFAULT_MODEL_NAMES = { } -def get_configured_model_name() -> str: +def get_configured_model_name(config: AIConfig) -> str: """Return the canonical name of the currently configured embedding model.""" - config = AIConfig() default = _DEFAULT_MODEL_NAMES.get( config.llm_embedding_backend, "sentence-transformers/all-MiniLM-L6-v2", diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index bfd4edd72..dd96106a6 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -139,12 +139,12 @@ def build_document_node( return parser.get_nodes_from_documents([doc]) -def load_or_build_index(): +def load_or_build_index(config: AIConfig): """Return a VectorStoreIndex backed by the vector store.""" import llama_index.core.settings as llama_settings from llama_index.core import VectorStoreIndex - embed_model = get_embedding_model() + embed_model = get_embedding_model(config) llama_settings.Settings.embed_model = embed_model vector_store = get_vector_store() return VectorStoreIndex.from_vector_store( @@ -223,7 +223,16 @@ def update_llm_index( rebuild=False, ) -> str: """Rebuild or incrementally update the LLM index.""" - model_name = get_configured_model_name() + documents = Document.objects.all() + no_documents = not documents.exists() + + # Fast exit before touching config: nothing to index and no existing index. + if no_documents and not rebuild and not llm_index_exists(): + logger.warning("No documents found to index.") + return "No documents found to index." + + config = AIConfig() + model_name = get_configured_model_name(config) if ( not rebuild @@ -233,14 +242,11 @@ def update_llm_index( logger.warning("Embedding model changed; forcing LLM index rebuild.") rebuild = True - documents = Document.objects.all() - if not documents.exists(): + if no_documents: logger.warning("No documents found to index.") - if not rebuild and not llm_index_exists(): - return "No documents found to index." - chunk_size = AIConfig().llm_embedding_chunk_size - embed_model = get_embedding_model() + chunk_size = config.llm_embedding_chunk_size + embed_model = get_embedding_model(config) with write_store(embed_model_name=model_name) as store: if rebuild or not store.table_exists(): @@ -277,11 +283,15 @@ def update_llm_index( def llm_index_add_or_update_document(document: Document): """Add or atomically replace a document's chunks in the index.""" - new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size()) + config = AIConfig() + new_nodes = build_document_node( + document, + chunk_size=config.llm_embedding_chunk_size, + ) if new_nodes: - _embed_nodes(new_nodes, get_embedding_model()) + _embed_nodes(new_nodes, get_embedding_model(config)) - with write_store(embed_model_name=get_configured_model_name()) as store: + with write_store(embed_model_name=get_configured_model_name(config)) as store: store.upsert_document(str(document.id), new_nodes) store.ensure_document_id_scalar_index() @@ -352,9 +362,11 @@ def query_similar_documents( ) return [] + config = AIConfig() + from llama_index.core.retrievers import VectorIndexRetriever - index = load_or_build_index() + index = load_or_build_index(config) filters = ( _document_id_filters(allowed_document_ids) @@ -368,7 +380,6 @@ def query_similar_documents( filters=filters, ) - config = AIConfig() query_text = truncate_content( (document.title or "") + "\n" + (document.content or ""), chunk_size=config.llm_embedding_chunk_size, diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 97e18eb47..45822b14b 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -6,6 +6,7 @@ import pytest from django.test import override_settings from documents.models import Document +from paperless.config import AIConfig from paperless_ai.ai_classifier import build_localization_prompt from paperless_ai.ai_classifier import build_prompt_with_rag from paperless_ai.ai_classifier import build_prompt_without_rag @@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document): "paperless_ai.ai_classifier.get_context_for_document", return_value="Context from similar documents", ): - prompt = build_prompt_without_rag(mock_document) + config = AIConfig() + prompt = build_prompt_without_rag(mock_document, config) assert "Additional context from similar documents" not in prompt assert "for generated" not in prompt - prompt = build_prompt_with_rag(mock_document) + prompt = build_prompt_with_rag(mock_document, config) assert "Additional context from similar documents" in prompt prompt = build_localization_prompt( diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index bb32c62ae..af34914bb 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -185,6 +185,7 @@ def test_stream_chat_empty_document_list() -> None: def test_stream_chat_no_matching_nodes() -> None: with ( + patch("paperless_ai.chat.AIConfig"), patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, ): @@ -204,6 +205,7 @@ def test_stream_chat_no_matching_nodes() -> None: def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None: with ( + patch("paperless_ai.chat.AIConfig"), patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, ): diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index d1d0754d1..251d3f90b 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -66,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config): with patch( "llama_index.embeddings.openai_like.OpenAILikeEmbedding", ) as MockOpenAIEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOpenAIEmbedding.assert_called_once_with( model_name="text-embedding-3-small", api_key="test_api_key", @@ -87,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config): with patch( "llama_index.embeddings.openai_like.OpenAILikeEmbedding", ) as MockOpenAIEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOpenAIEmbedding.assert_called_once_with( model_name="text-embedding-3-small", api_key="test_api_key", @@ -108,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed( mock_ai_config.return_value.llm_allow_internal_endpoints = False with pytest.raises(ValueError, match="non-public address"): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) def test_get_embedding_model_huggingface(mock_ai_config): @@ -120,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config): with patch( "llama_index.embeddings.huggingface.HuggingFaceEmbedding", ) as MockHuggingFaceEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockHuggingFaceEmbedding.assert_called_once_with( model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=str(settings.DATA_DIR / "hf_cache"), @@ -136,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config): with patch( "llama_index.embeddings.ollama.OllamaEmbedding", ) as MockOllamaEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://test-url", @@ -154,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config): with patch( "llama_index.embeddings.ollama.OllamaEmbedding", ) as MockOllamaEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://embedding-url", @@ -172,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed( mock_ai_config.return_value.llm_allow_internal_endpoints = False with pytest.raises(ValueError, match="non-public address"): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) def test_get_embedding_model_invalid_backend(mock_ai_config): @@ -182,7 +182,7 @@ def test_get_embedding_model_invalid_backend(mock_ai_config): ValueError, match="Unsupported embedding backend: INVALID_BACKEND", ): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) @pytest.mark.parametrize( @@ -199,18 +199,20 @@ def test_get_configured_model_name_falls_back_to_backend_default( expected_default, ): """When no model is explicitly configured, each backend has a distinct default.""" - mock_ai_config.return_value.llm_embedding_backend = backend - mock_ai_config.return_value.llm_embedding_model = None - assert get_configured_model_name() == expected_default + config = mock_ai_config.return_value + config.llm_embedding_backend = backend + config.llm_embedding_model = None + assert get_configured_model_name(config) == expected_default def test_get_configured_model_name_explicit_overrides_default(mock_ai_config): """An explicit model name overrides the backend default for all backends.""" - mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE - mock_ai_config.return_value.llm_embedding_model = "my-custom-model" + config = mock_ai_config.return_value + config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE + config.llm_embedding_model = "my-custom-model" # The backend default for OPENAI_LIKE is "text-embedding-3-small", so if # the explicit name was ignored we'd get the wrong result. - assert get_configured_model_name() == "my-custom-model" + assert get_configured_model_name(config) == "my-custom-model" def test_build_llm_index_text(mock_document):