Construct fewer AIConfig objects and instead pass a single instance around as needed

This commit is contained in:
Trenton Holmes
2026-06-06 16:08:57 -07:00
parent 7f5053cbe3
commit 3aebdcf38c
7 changed files with 71 additions and 42 deletions
+15 -5
View File
@@ -24,9 +24,14 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
) -> str:
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
content = truncate_content(
document.content[:4000] or "",
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""
You are a document classification assistant.
@@ -49,10 +54,15 @@ def build_prompt_without_rag(
def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
base_prompt = build_prompt_without_rag(document, config)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""{base_prompt}
@@ -130,9 +140,9 @@ def get_ai_document_classification(
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, user)
build_prompt_with_rag(document, ai_config, user)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document)
else build_prompt_without_rag(document, ai_config)
)
client = AIClient()
+7 -2
View File
@@ -3,6 +3,7 @@ import logging
import sys
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
@@ -94,7 +95,8 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
index = load_or_build_index()
config = AIConfig()
index = load_or_build_index(config)
filters = _document_id_filters(str(doc.pk) for doc in documents)
retriever = VectorIndexRetriever(
@@ -116,7 +118,10 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(),
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
+2 -5
View File
@@ -20,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
def get_embedding_model() -> "BaseEmbedding":
config = AIConfig()
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
match config.llm_embedding_backend:
case LLMEmbeddingBackend.OPENAI_LIKE:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
@@ -99,9 +97,8 @@ _DEFAULT_MODEL_NAMES = {
}
def get_configured_model_name() -> str:
def get_configured_model_name(config: AIConfig) -> str:
"""Return the canonical name of the currently configured embedding model."""
config = AIConfig()
default = _DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
"sentence-transformers/all-MiniLM-L6-v2",
+25 -14
View File
@@ -139,12 +139,12 @@ def build_document_node(
return parser.get_nodes_from_documents([doc])
def load_or_build_index():
def load_or_build_index(config: AIConfig):
"""Return a VectorStoreIndex backed by the vector store."""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
embed_model = get_embedding_model()
embed_model = get_embedding_model(config)
llama_settings.Settings.embed_model = embed_model
vector_store = get_vector_store()
return VectorStoreIndex.from_vector_store(
@@ -223,7 +223,16 @@ def update_llm_index(
rebuild=False,
) -> str:
"""Rebuild or incrementally update the LLM index."""
model_name = get_configured_model_name()
documents = Document.objects.all()
no_documents = not documents.exists()
# Fast exit before touching config: nothing to index and no existing index.
if no_documents and not rebuild and not llm_index_exists():
logger.warning("No documents found to index.")
return "No documents found to index."
config = AIConfig()
model_name = get_configured_model_name(config)
if (
not rebuild
@@ -233,14 +242,11 @@ def update_llm_index(
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
documents = Document.objects.all()
if not documents.exists():
if no_documents:
logger.warning("No documents found to index.")
if not rebuild and not llm_index_exists():
return "No documents found to index."
chunk_size = AIConfig().llm_embedding_chunk_size
embed_model = get_embedding_model()
chunk_size = config.llm_embedding_chunk_size
embed_model = get_embedding_model(config)
with write_store(embed_model_name=model_name) as store:
if rebuild or not store.table_exists():
@@ -277,11 +283,15 @@ def update_llm_index(
def llm_index_add_or_update_document(document: Document):
"""Add or atomically replace a document's chunks in the index."""
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
config = AIConfig()
new_nodes = build_document_node(
document,
chunk_size=config.llm_embedding_chunk_size,
)
if new_nodes:
_embed_nodes(new_nodes, get_embedding_model())
_embed_nodes(new_nodes, get_embedding_model(config))
with write_store(embed_model_name=get_configured_model_name()) as store:
with write_store(embed_model_name=get_configured_model_name(config)) as store:
store.upsert_document(str(document.id), new_nodes)
store.ensure_document_id_scalar_index()
@@ -352,9 +362,11 @@ def query_similar_documents(
)
return []
config = AIConfig()
from llama_index.core.retrievers import VectorIndexRetriever
index = load_or_build_index()
index = load_or_build_index(config)
filters = (
_document_id_filters(allowed_document_ids)
@@ -368,7 +380,6 @@ def query_similar_documents(
filters=filters,
)
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
+4 -2
View File
@@ -6,6 +6,7 @@ import pytest
from django.test import override_settings
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.ai_classifier import build_localization_prompt
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
@@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document):
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
config = AIConfig()
prompt = build_prompt_without_rag(mock_document, config)
assert "Additional context from similar documents" not in prompt
assert "for generated" not in prompt
prompt = build_prompt_with_rag(mock_document)
prompt = build_prompt_with_rag(mock_document, config)
assert "Additional context from similar documents" in prompt
prompt = build_localization_prompt(
+2
View File
@@ -185,6 +185,7 @@ def test_stream_chat_empty_document_list() -> None:
def test_stream_chat_no_matching_nodes() -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
@@ -204,6 +205,7 @@ def test_stream_chat_no_matching_nodes() -> None:
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
+16 -14
View File
@@ -66,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config):
with patch(
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
) as MockOpenAIEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOpenAIEmbedding.assert_called_once_with(
model_name="text-embedding-3-small",
api_key="test_api_key",
@@ -87,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
with patch(
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
) as MockOpenAIEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOpenAIEmbedding.assert_called_once_with(
model_name="text-embedding-3-small",
api_key="test_api_key",
@@ -108,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
mock_ai_config.return_value.llm_allow_internal_endpoints = False
with pytest.raises(ValueError, match="non-public address"):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
def test_get_embedding_model_huggingface(mock_ai_config):
@@ -120,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
with patch(
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockHuggingFaceEmbedding.assert_called_once_with(
model_name="sentence-transformers/all-MiniLM-L6-v2",
cache_folder=str(settings.DATA_DIR / "hf_cache"),
@@ -136,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config):
with patch(
"llama_index.embeddings.ollama.OllamaEmbedding",
) as MockOllamaEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOllamaEmbedding.assert_called_once_with(
model_name="embeddinggemma",
base_url="http://test-url",
@@ -154,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
with patch(
"llama_index.embeddings.ollama.OllamaEmbedding",
) as MockOllamaEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOllamaEmbedding.assert_called_once_with(
model_name="embeddinggemma",
base_url="http://embedding-url",
@@ -172,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
mock_ai_config.return_value.llm_allow_internal_endpoints = False
with pytest.raises(ValueError, match="non-public address"):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
def test_get_embedding_model_invalid_backend(mock_ai_config):
@@ -182,7 +182,7 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
ValueError,
match="Unsupported embedding backend: INVALID_BACKEND",
):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
@pytest.mark.parametrize(
@@ -199,18 +199,20 @@ def test_get_configured_model_name_falls_back_to_backend_default(
expected_default,
):
"""When no model is explicitly configured, each backend has a distinct default."""
mock_ai_config.return_value.llm_embedding_backend = backend
mock_ai_config.return_value.llm_embedding_model = None
assert get_configured_model_name() == expected_default
config = mock_ai_config.return_value
config.llm_embedding_backend = backend
config.llm_embedding_model = None
assert get_configured_model_name(config) == expected_default
def test_get_configured_model_name_explicit_overrides_default(mock_ai_config):
"""An explicit model name overrides the backend default for all backends."""
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
mock_ai_config.return_value.llm_embedding_model = "my-custom-model"
config = mock_ai_config.return_value
config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
config.llm_embedding_model = "my-custom-model"
# The backend default for OPENAI_LIKE is "text-embedding-3-small", so if
# the explicit name was ignored we'd get the wrong result.
assert get_configured_model_name() == "my-custom-model"
assert get_configured_model_name(config) == "my-custom-model"
def test_build_llm_index_text(mock_document):