mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-07 14:19:45 +00:00
Construct fewer AiConfig objects and instead pass around as needed
This commit is contained in:
@@ -24,9 +24,14 @@ def get_language_name(language_code: str) -> str:
|
||||
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
content = truncate_content(
|
||||
document.content[:4000] or "",
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
@@ -49,10 +54,15 @@ def build_prompt_without_rag(
|
||||
|
||||
def build_prompt_with_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
user: User | None = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
base_prompt = build_prompt_without_rag(document, config)
|
||||
context = truncate_content(
|
||||
get_context_for_document(document, user),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
return f"""{base_prompt}
|
||||
|
||||
@@ -130,9 +140,9 @@ def get_ai_document_classification(
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, user)
|
||||
build_prompt_with_rag(document, ai_config, user)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document)
|
||||
else build_prompt_without_rag(document, ai_config)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
import sys
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import _document_id_filters
|
||||
from paperless_ai.indexing import get_rag_prompt_helper
|
||||
@@ -94,7 +95,8 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
index = load_or_build_index()
|
||||
config = AIConfig()
|
||||
index = load_or_build_index(config)
|
||||
filters = _document_id_filters(str(doc.pk) for doc in documents)
|
||||
|
||||
retriever = VectorIndexRetriever(
|
||||
@@ -116,7 +118,10 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(),
|
||||
prompt_helper=get_rag_prompt_helper(
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
@@ -20,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
|
||||
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
|
||||
|
||||
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
config = AIConfig()
|
||||
|
||||
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI_LIKE:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
@@ -99,9 +97,8 @@ _DEFAULT_MODEL_NAMES = {
|
||||
}
|
||||
|
||||
|
||||
def get_configured_model_name() -> str:
|
||||
def get_configured_model_name(config: AIConfig) -> str:
|
||||
"""Return the canonical name of the currently configured embedding model."""
|
||||
config = AIConfig()
|
||||
default = _DEFAULT_MODEL_NAMES.get(
|
||||
config.llm_embedding_backend,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
|
||||
@@ -139,12 +139,12 @@ def build_document_node(
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index():
|
||||
def load_or_build_index(config: AIConfig):
|
||||
"""Return a VectorStoreIndex backed by the vector store."""
|
||||
import llama_index.core.settings as llama_settings
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
embed_model = get_embedding_model(config)
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
vector_store = get_vector_store()
|
||||
return VectorStoreIndex.from_vector_store(
|
||||
@@ -223,7 +223,16 @@ def update_llm_index(
|
||||
rebuild=False,
|
||||
) -> str:
|
||||
"""Rebuild or incrementally update the LLM index."""
|
||||
model_name = get_configured_model_name()
|
||||
documents = Document.objects.all()
|
||||
no_documents = not documents.exists()
|
||||
|
||||
# Fast exit before touching config: nothing to index and no existing index.
|
||||
if no_documents and not rebuild and not llm_index_exists():
|
||||
logger.warning("No documents found to index.")
|
||||
return "No documents found to index."
|
||||
|
||||
config = AIConfig()
|
||||
model_name = get_configured_model_name(config)
|
||||
|
||||
if (
|
||||
not rebuild
|
||||
@@ -233,14 +242,11 @@ def update_llm_index(
|
||||
logger.warning("Embedding model changed; forcing LLM index rebuild.")
|
||||
rebuild = True
|
||||
|
||||
documents = Document.objects.all()
|
||||
if not documents.exists():
|
||||
if no_documents:
|
||||
logger.warning("No documents found to index.")
|
||||
if not rebuild and not llm_index_exists():
|
||||
return "No documents found to index."
|
||||
|
||||
chunk_size = AIConfig().llm_embedding_chunk_size
|
||||
embed_model = get_embedding_model()
|
||||
chunk_size = config.llm_embedding_chunk_size
|
||||
embed_model = get_embedding_model(config)
|
||||
|
||||
with write_store(embed_model_name=model_name) as store:
|
||||
if rebuild or not store.table_exists():
|
||||
@@ -277,11 +283,15 @@ def update_llm_index(
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
"""Add or atomically replace a document's chunks in the index."""
|
||||
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
|
||||
config = AIConfig()
|
||||
new_nodes = build_document_node(
|
||||
document,
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
)
|
||||
if new_nodes:
|
||||
_embed_nodes(new_nodes, get_embedding_model())
|
||||
_embed_nodes(new_nodes, get_embedding_model(config))
|
||||
|
||||
with write_store(embed_model_name=get_configured_model_name()) as store:
|
||||
with write_store(embed_model_name=get_configured_model_name(config)) as store:
|
||||
store.upsert_document(str(document.id), new_nodes)
|
||||
store.ensure_document_id_scalar_index()
|
||||
|
||||
@@ -352,9 +362,11 @@ def query_similar_documents(
|
||||
)
|
||||
return []
|
||||
|
||||
config = AIConfig()
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
index = load_or_build_index()
|
||||
index = load_or_build_index(config)
|
||||
|
||||
filters = (
|
||||
_document_id_filters(allowed_document_ids)
|
||||
@@ -368,7 +380,6 @@ def query_similar_documents(
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
config = AIConfig()
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.ai_classifier import build_localization_prompt
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
@@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document):
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
config = AIConfig()
|
||||
prompt = build_prompt_without_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" not in prompt
|
||||
assert "for generated" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
prompt = build_prompt_with_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
|
||||
prompt = build_localization_prompt(
|
||||
|
||||
@@ -185,6 +185,7 @@ def test_stream_chat_empty_document_list() -> None:
|
||||
|
||||
def test_stream_chat_no_matching_nodes() -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
@@ -204,6 +205,7 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
|
||||
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
@@ -87,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
@@ -108,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
@@ -120,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
cache_folder=str(settings.DATA_DIR / "hf_cache"),
|
||||
@@ -136,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.ollama.OllamaEmbedding",
|
||||
) as MockOllamaEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOllamaEmbedding.assert_called_once_with(
|
||||
model_name="embeddinggemma",
|
||||
base_url="http://test-url",
|
||||
@@ -154,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.ollama.OllamaEmbedding",
|
||||
) as MockOllamaEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOllamaEmbedding.assert_called_once_with(
|
||||
model_name="embeddinggemma",
|
||||
base_url="http://embedding-url",
|
||||
@@ -172,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
@@ -182,7 +182,7 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
ValueError,
|
||||
match="Unsupported embedding backend: INVALID_BACKEND",
|
||||
):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -199,18 +199,20 @@ def test_get_configured_model_name_falls_back_to_backend_default(
|
||||
expected_default,
|
||||
):
|
||||
"""When no model is explicitly configured, each backend has a distinct default."""
|
||||
mock_ai_config.return_value.llm_embedding_backend = backend
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
assert get_configured_model_name() == expected_default
|
||||
config = mock_ai_config.return_value
|
||||
config.llm_embedding_backend = backend
|
||||
config.llm_embedding_model = None
|
||||
assert get_configured_model_name(config) == expected_default
|
||||
|
||||
|
||||
def test_get_configured_model_name_explicit_overrides_default(mock_ai_config):
|
||||
"""An explicit model name overrides the backend default for all backends."""
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
mock_ai_config.return_value.llm_embedding_model = "my-custom-model"
|
||||
config = mock_ai_config.return_value
|
||||
config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
config.llm_embedding_model = "my-custom-model"
|
||||
# The backend default for OPENAI_LIKE is "text-embedding-3-small", so if
|
||||
# the explicit name was ignored we'd get the wrong result.
|
||||
assert get_configured_model_name() == "my-custom-model"
|
||||
assert get_configured_model_name(config) == "my-custom-model"
|
||||
|
||||
|
||||
def test_build_llm_index_text(mock_document):
|
||||
|
||||
Reference in New Issue
Block a user