From f6c865bf47b876dc168b18d9400f895ef512db2e Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 1 Jun 2026 10:56:21 -0700 Subject: [PATCH] Enhancement: AI LLM chunk size and context window config (#12891) --- docs/configuration.md | 16 ++++ src-ui/src/app/data/paperless-config.ts | 16 ++++ src/documents/tests/test_api_app_config.py | 89 ++++++++++++++++++- src/paperless/config.py | 6 ++ ...nconfiguration_llm_embedding_chunk_size.py | 32 +++++++ src/paperless/models.py | 12 +++ src/paperless/settings/__init__.py | 9 ++ src/paperless/views.py | 53 ++++++++--- src/paperless_ai/client.py | 1 + src/paperless_ai/embedding.py | 1 + src/paperless_ai/indexing.py | 76 ++++++++++++---- src/paperless_ai/tests/test_ai_indexing.py | 36 ++++++-- src/paperless_ai/tests/test_chat.py | 2 + src/paperless_ai/tests/test_client.py | 2 + src/paperless_ai/tests/test_embedding.py | 3 + 15 files changed, 318 insertions(+), 36 deletions(-) create mode 100644 src/paperless/migrations/0011_applicationconfiguration_llm_embedding_chunk_size.py diff --git a/docs/configuration.md b/docs/configuration.md index 43fa6b704..66470792d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2052,6 +2052,22 @@ models supported by the current embedding backend. If not supplied, defaults to Defaults to None. +#### [`PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE=`](#PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE) {#PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE} + +: The chunk size to use when splitting document text for RAG embeddings. Lower this value if your +embedding backend or model rejects larger inputs, or silently truncates inputs in a way that harms +retrieval quality. + + Defaults to 1024. + +#### [`PAPERLESS_AI_LLM_CONTEXT_SIZE=`](#PAPERLESS_AI_LLM_CONTEXT_SIZE) {#PAPERLESS_AI_LLM_CONTEXT_SIZE} + +: The context size to use for AI prompts and RAG retrieval. For Ollama backends, this is also sent +as `num_ctx` so models with very large native context windows are not loaded at their maximum +context by default. + + Defaults to 8192. + #### [`PAPERLESS_AI_LLM_BACKEND=`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND} : The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI diff --git a/src-ui/src/app/data/paperless-config.ts b/src-ui/src/app/data/paperless-config.ts index 52061dd18..f7b654bf4 100644 --- a/src-ui/src/app/data/paperless-config.ts +++ b/src-ui/src/app/data/paperless-config.ts @@ -309,6 +309,20 @@ export const PaperlessConfigOptions: ConfigOption[] = [ config_key: 'PAPERLESS_AI_LLM_EMBEDDING_ENDPOINT', category: ConfigCategory.AI, }, + { + key: 'llm_embedding_chunk_size', + title: $localize`LLM Embedding Chunk Size`, + type: ConfigOptionType.Number, + config_key: 'PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE', + category: ConfigCategory.AI, + }, + { + key: 'llm_context_size', + title: $localize`LLM Context Size`, + type: ConfigOptionType.Number, + config_key: 'PAPERLESS_AI_LLM_CONTEXT_SIZE', + category: ConfigCategory.AI, + }, { key: 'llm_backend', title: $localize`LLM Backend`, @@ -372,6 +386,8 @@ export interface PaperlessConfig extends ObjectWithId { llm_embedding_backend: string llm_embedding_model: string llm_embedding_endpoint: string + llm_embedding_chunk_size: number + llm_context_size: number llm_backend: string llm_model: string llm_api_key: string diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py index 3372a16eb..e0441f17c 100644 --- a/src/documents/tests/test_api_app_config.py +++ b/src/documents/tests/test_api_app_config.py @@ -75,6 +75,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): "llm_embedding_backend": None, "llm_embedding_model": None, "llm_embedding_endpoint": None, + "llm_embedding_chunk_size": None, + "llm_context_size": None, "llm_backend": None, "llm_model": None, "llm_api_key": None, @@ -841,7 +843,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): with ( patch("documents.tasks.llmindex_index.apply_async") as mock_update, - patch("paperless_ai.indexing.vector_store_file_exists") as mock_exists, + patch("paperless.views.vector_store_file_exists") as mock_exists, ): mock_exists.return_value = False self.client.patch( @@ -856,6 +858,91 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): ) mock_update.assert_called_once() + def test_update_llm_embedding_chunk_size_triggers_rebuild(self) -> None: + config = ApplicationConfiguration.objects.first() + assert config is not None + config.ai_enabled = True + config.llm_embedding_backend = "openai-like" + config.llm_embedding_chunk_size = 1024 + config.save() + + with ( + patch("documents.tasks.llmindex_index.apply_async") as mock_update, + patch("paperless.views.vector_store_file_exists") as mock_exists, + ): + mock_exists.return_value = True + self.client.patch( + f"{self.ENDPOINT}1/", + json.dumps({"llm_embedding_chunk_size": 512}), + content_type="application/json", + ) + mock_update.assert_called_once() + self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True}) + + def test_update_llm_context_size_triggers_rebuild(self) -> None: + config = ApplicationConfiguration.objects.first() + assert config is not None + config.ai_enabled = True + config.llm_embedding_backend = "openai-like" + config.llm_context_size = 8192 + config.save() + + with ( + patch("documents.tasks.llmindex_index.apply_async") as mock_update, + patch("paperless.views.vector_store_file_exists") as mock_exists, + ): + mock_exists.return_value = True + self.client.patch( + f"{self.ENDPOINT}1/", + json.dumps({"llm_context_size": 4096}), + content_type="application/json", + ) + mock_update.assert_called_once() + self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True}) + + def test_update_llm_embedding_model_triggers_rebuild(self) -> None: + config = ApplicationConfiguration.objects.first() + assert config is not None + config.ai_enabled = True + config.llm_embedding_backend = "openai-like" + config.llm_embedding_model = "text-embedding-3-small" + config.save() + + with patch("documents.tasks.llmindex_index.apply_async") as mock_update: + self.client.patch( + f"{self.ENDPOINT}1/", + json.dumps({"llm_embedding_model": "text-embedding-3-large"}), + content_type="application/json", + ) + mock_update.assert_called_once() + self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True}) + + def test_enable_ai_index_with_config_change_triggers_rebuild(self) -> None: + config = ApplicationConfiguration.objects.first() + assert config is not None + config.ai_enabled = False + config.llm_embedding_backend = "openai-like" + config.llm_embedding_model = "text-embedding-3-small" + config.save() + + with ( + patch("documents.tasks.llmindex_index.apply_async") as mock_update, + patch("paperless.views.vector_store_file_exists") as mock_exists, + ): + mock_exists.return_value = True + self.client.patch( + f"{self.ENDPOINT}1/", + json.dumps( + { + "ai_enabled": True, + "llm_embedding_model": "text-embedding-3-large", + }, + ), + content_type="application/json", + ) + mock_update.assert_called_once() + self.assertEqual(mock_update.call_args.kwargs["kwargs"], {"rebuild": True}) + @override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False) def test_update_llm_endpoint_blocks_internal_endpoint_when_disallowed(self) -> None: response = self.client.patch( diff --git a/src/paperless/config.py b/src/paperless/config.py index f4ade5ffe..8c9c7b3ca 100644 --- a/src/paperless/config.py +++ b/src/paperless/config.py @@ -195,6 +195,8 @@ class AIConfig(BaseConfig): llm_embedding_backend: str = dataclasses.field(init=False) llm_embedding_model: str = dataclasses.field(init=False) llm_embedding_endpoint: str = dataclasses.field(init=False) + llm_embedding_chunk_size: int = dataclasses.field(init=False) + llm_context_size: int = dataclasses.field(init=False) llm_backend: str = dataclasses.field(init=False) llm_model: str = dataclasses.field(init=False) llm_api_key: str = dataclasses.field(init=False) @@ -214,6 +216,10 @@ class AIConfig(BaseConfig): self.llm_embedding_endpoint = ( app_config.llm_embedding_endpoint or settings.LLM_EMBEDDING_ENDPOINT ) + self.llm_embedding_chunk_size = ( + app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE + ) + self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND self.llm_model = app_config.llm_model or settings.LLM_MODEL self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY diff --git a/src/paperless/migrations/0011_applicationconfiguration_llm_embedding_chunk_size.py b/src/paperless/migrations/0011_applicationconfiguration_llm_embedding_chunk_size.py new file mode 100644 index 000000000..7d464ec51 --- /dev/null +++ b/src/paperless/migrations/0011_applicationconfiguration_llm_embedding_chunk_size.py @@ -0,0 +1,32 @@ +# Generated by Django 5.2.6 on 2026-05-31 + +from django.core.validators import MinValueValidator +from django.db import migrations +from django.db import models + + +class Migration(migrations.Migration): + dependencies = [ + ("paperless", "0010_alter_applicationconfiguration_llm_embedding_backend"), + ] + + operations = [ + migrations.AddField( + model_name="applicationconfiguration", + name="llm_embedding_chunk_size", + field=models.PositiveSmallIntegerField( + null=True, + validators=[MinValueValidator(1)], + verbose_name="Sets the LLM embedding chunk size", + ), + ), + migrations.AddField( + model_name="applicationconfiguration", + name="llm_context_size", + field=models.PositiveIntegerField( + null=True, + validators=[MinValueValidator(1)], + verbose_name="Sets the LLM context size", + ), + ), + ] diff --git a/src/paperless/models.py b/src/paperless/models.py index 95e52426e..7c562b811 100644 --- a/src/paperless/models.py +++ b/src/paperless/models.py @@ -318,6 +318,18 @@ class ApplicationConfiguration(AbstractSingletonModel): max_length=256, ) + llm_embedding_chunk_size = models.PositiveSmallIntegerField( + verbose_name=_("Sets the LLM embedding chunk size"), + null=True, + validators=[MinValueValidator(1)], + ) + + llm_context_size = models.PositiveIntegerField( + verbose_name=_("Sets the LLM context size"), + null=True, + validators=[MinValueValidator(1)], + ) + llm_backend = models.CharField( verbose_name=_("Sets the LLM backend"), blank=True, diff --git a/src/paperless/settings/__init__.py b/src/paperless/settings/__init__.py index 5d208c9f3..df3011eb2 100644 --- a/src/paperless/settings/__init__.py +++ b/src/paperless/settings/__init__.py @@ -1187,6 +1187,15 @@ LLM_EMBEDDING_BACKEND = os.getenv( ) # "huggingface", "openai-like", or "ollama" LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_MODEL") LLM_EMBEDDING_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_ENDPOINT") +LLM_EMBEDDING_CHUNK_SIZE = get_int_from_env( + "PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE", + 1024, +) +if LLM_EMBEDDING_CHUNK_SIZE < 1: + raise ImproperlyConfigured("PAPERLESS_AI_LLM_EMBEDDING_CHUNK_SIZE must be >= 1") +LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192) +if LLM_CONTEXT_SIZE < 1: + raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1") LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai-like" LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL") LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY") diff --git a/src/paperless/views.py b/src/paperless/views.py index 022d7f217..9ed4a2a87 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -423,21 +423,54 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]): def perform_update(self, serializer): old_instance = ApplicationConfiguration.objects.all().first() - old_ai_index_enabled = ( - old_instance.ai_enabled and old_instance.llm_embedding_backend + old_llm_embedding_backend = ( + old_instance.llm_embedding_backend or settings.LLM_EMBEDDING_BACKEND + ) + old_llm_embedding_chunk_size = ( + old_instance.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE + ) + old_llm_embedding_endpoint = ( + old_instance.llm_embedding_endpoint or settings.LLM_EMBEDDING_ENDPOINT + ) + old_llm_embedding_model = ( + old_instance.llm_embedding_model or settings.LLM_EMBEDDING_MODEL + ) + old_llm_context_size = ( + old_instance.llm_context_size or settings.LLM_CONTEXT_SIZE ) new_instance: ApplicationConfiguration = serializer.save() - new_ai_index_enabled = ( - new_instance.ai_enabled and new_instance.llm_embedding_backend + new_llm_embedding_backend = ( + new_instance.llm_embedding_backend or settings.LLM_EMBEDDING_BACKEND + ) + new_ai_index_enabled = bool( + new_instance.ai_enabled and new_llm_embedding_backend, + ) + new_llm_embedding_chunk_size = ( + new_instance.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE + ) + new_llm_embedding_endpoint = ( + new_instance.llm_embedding_endpoint or settings.LLM_EMBEDDING_ENDPOINT + ) + new_llm_embedding_model = ( + new_instance.llm_embedding_model or settings.LLM_EMBEDDING_MODEL + ) + new_llm_context_size = ( + new_instance.llm_context_size or settings.LLM_CONTEXT_SIZE ) - if ( - not old_ai_index_enabled - and new_ai_index_enabled - and not vector_store_file_exists() - ): - # AI index was just enabled and vector store file does not exist + embedding_config_changed = ( + old_llm_embedding_backend != new_llm_embedding_backend + or old_llm_embedding_chunk_size != new_llm_embedding_chunk_size + or old_llm_embedding_endpoint != new_llm_embedding_endpoint + or old_llm_embedding_model != new_llm_embedding_model + or old_llm_context_size != new_llm_context_size + ) + rebuild_needed = new_ai_index_enabled and ( + not vector_store_file_exists() or embedding_config_changed + ) + + if rebuild_needed: llmindex_index.apply_async( kwargs={"rebuild": True}, headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, diff --git a/src/paperless_ai/client.py b/src/paperless_ai/client.py index 771f89f55..64d4276ac 100644 --- a/src/paperless_ai/client.py +++ b/src/paperless_ai/client.py @@ -59,6 +59,7 @@ class AIClient: return Ollama( model=self.settings.llm_model or "llama3.1", base_url=endpoint, + context_window=self.settings.llm_context_size, request_timeout=120, system_prompt=LLM_SYSTEM_PROMPT, client=Client( diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index 407dd4c0e..2695e9fb3 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -74,6 +74,7 @@ def get_embedding_model() -> "BaseEmbedding": embedding = OllamaEmbedding( model_name=config.llm_embedding_model or "embeddinggemma", base_url=endpoint, + ollama_additional_kwargs={"num_ctx": config.llm_context_size}, ) embedding._client = Client( host=endpoint, diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 28be2c94b..5e2c5e369 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -12,6 +12,7 @@ from documents.models import Document from documents.models import PaperlessTask from documents.utils import IterWrapper from documents.utils import identity +from paperless.config import AIConfig from paperless_ai.embedding import build_llm_index_text from paperless_ai.embedding import get_embedding_dim from paperless_ai.embedding import get_embedding_model @@ -23,9 +24,7 @@ if TYPE_CHECKING: logger = logging.getLogger("paperless_ai.indexing") -RAG_CONTEXT_WINDOW = 8192 RAG_NUM_OUTPUT = 512 -RAG_CHUNK_SIZE = 1024 RAG_CHUNK_OVERLAP = 200 @@ -95,7 +94,11 @@ def get_or_create_storage_context(*, rebuild=False): ) -def build_document_node(document: Document) -> list["BaseNode"]: +def build_document_node( + document: Document, + *, + chunk_size: int | None = None, +) -> list["BaseNode"]: """ Given a Document, returns parsed Nodes ready for indexing. """ @@ -126,9 +129,10 @@ def build_document_node(document: Document) -> list["BaseNode"]: metadata=metadata, excluded_embed_metadata_keys=list(metadata.keys()), ) + chunk_size = chunk_size or get_rag_chunk_size() parser = SimpleNodeParser( - chunk_size=RAG_CHUNK_SIZE, - chunk_overlap=get_rag_chunk_overlap(), + chunk_size=chunk_size, + chunk_overlap=get_rag_chunk_overlap(chunk_size), ) return parser.get_nodes_from_documents([doc]) @@ -186,18 +190,36 @@ def vector_store_file_exists(): return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists() -def get_rag_chunk_overlap() -> int: - return min(RAG_CHUNK_OVERLAP, RAG_CHUNK_SIZE - 1) +def get_rag_chunk_size() -> int: + return AIConfig().llm_embedding_chunk_size -def get_rag_prompt_helper(): +def get_rag_context_size() -> int: + return AIConfig().llm_context_size + + +def get_rag_chunk_overlap(chunk_size: int | None = None) -> int: + chunk_size = chunk_size or get_rag_chunk_size() + return min(RAG_CHUNK_OVERLAP, chunk_size - 1) + + +def get_rag_prompt_helper( + *, + chunk_size: int | None = None, + context_size: int | None = None, +): from llama_index.core.indices.prompt_helper import PromptHelper + if chunk_size is None or context_size is None: + config = AIConfig() + chunk_size = chunk_size or config.llm_embedding_chunk_size + context_size = context_size or config.llm_context_size + return PromptHelper( - context_window=RAG_CONTEXT_WINDOW, + context_window=context_size, num_output=RAG_NUM_OUTPUT, chunk_overlap_ratio=0.1, - chunk_size_limit=RAG_CHUNK_SIZE, + chunk_size_limit=chunk_size, ) @@ -219,6 +241,9 @@ def update_llm_index( logger.warning(msg) return msg + config = AIConfig() + chunk_size = config.llm_embedding_chunk_size + if rebuild or not vector_store_file_exists(): # remove meta.json to force re-detection of embedding dim (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True) @@ -230,7 +255,7 @@ def update_llm_index( llama_settings.Settings.embed_model = embed_model storage_context = get_or_create_storage_context(rebuild=True) for document in iter_wrapper(documents): - document_nodes = build_document_node(document) + document_nodes = build_document_node(document, chunk_size=chunk_size) nodes.extend(document_nodes) index = VectorStoreIndex( @@ -262,10 +287,10 @@ def update_llm_index( # Again, delete from docstore, FAISS IndexFlatL2 are append-only index.docstore.delete_document(node.node_id) - nodes.extend(build_document_node(document)) + nodes.extend(build_document_node(document, chunk_size=chunk_size)) else: # New document, add it - nodes.extend(build_document_node(document)) + nodes.extend(build_document_node(document, chunk_size=chunk_size)) if nodes: msg = "LLM index updated successfully." @@ -287,7 +312,7 @@ def llm_index_add_or_update_document(document: Document): Adds or updates a document in the LLM index. If the document already exists, it will be replaced. """ - new_nodes = build_document_node(document) + new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size()) index = load_or_build_index(nodes=new_nodes) @@ -309,15 +334,27 @@ def llm_index_remove_document(document: Document): index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR) -def truncate_content(content: str) -> str: +def truncate_content( + content: str, + *, + chunk_size: int | None = None, + context_size: int | None = None, +) -> str: from llama_index.core.prompts import PromptTemplate from llama_index.core.text_splitter import TokenTextSplitter - prompt_helper = get_rag_prompt_helper() + if chunk_size is None or context_size is None: + config = AIConfig() + chunk_size = chunk_size or config.llm_embedding_chunk_size + context_size = context_size or config.llm_context_size + prompt_helper = get_rag_prompt_helper( + chunk_size=chunk_size, + context_size=context_size, + ) splitter = TokenTextSplitter( separator=" ", - chunk_size=RAG_CHUNK_SIZE, - chunk_overlap=get_rag_chunk_overlap(), + chunk_size=chunk_size, + chunk_overlap=get_rag_chunk_overlap(chunk_size), ) content_chunks = splitter.split_text(content) truncated_chunks = prompt_helper.truncate( @@ -376,8 +413,11 @@ def query_similar_documents( doc_ids=doc_node_ids, ) + config = AIConfig() query_text = truncate_content( (document.title or "") + "\n" + (document.content or ""), + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, ) results = retriever.retrieve(query_text) diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index 356619549..f0f66fb72 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -11,6 +11,7 @@ from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import PaperlessTask from documents.tests.factories import PaperlessTaskFactory +from paperless.models import ApplicationConfiguration from paperless_ai import indexing @@ -81,20 +82,32 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) -> @pytest.mark.django_db def test_build_document_node_uses_rag_chunk_settings(real_document) -> None: + app_config, _ = ApplicationConfiguration.objects.get_or_create() + app_config.llm_embedding_chunk_size = 512 + app_config.save() + with patch("llama_index.core.node_parser.SimpleNodeParser") as mock_parser: mock_parser.return_value.get_nodes_from_documents.return_value = [] indexing.build_document_node(real_document) - mock_parser.assert_called_once_with(chunk_size=1024, chunk_overlap=200) + mock_parser.assert_called_once_with(chunk_size=512, chunk_overlap=200) def test_get_rag_chunk_overlap_clamps_to_chunk_size() -> None: - with ( - patch("paperless_ai.indexing.RAG_CHUNK_SIZE", 64), - patch("paperless_ai.indexing.RAG_CHUNK_OVERLAP", 128), - ): - assert indexing.get_rag_chunk_overlap() == 63 + with patch("paperless_ai.indexing.RAG_CHUNK_OVERLAP", 128): + assert indexing.get_rag_chunk_overlap(64) == 63 + + +@pytest.mark.django_db +def test_get_rag_prompt_helper_uses_context_setting() -> None: + app_config, _ = ApplicationConfiguration.objects.get_or_create() + app_config.llm_context_size = 4096 + app_config.save() + + prompt_helper = indexing.get_rag_prompt_helper() + + assert prompt_helper.context_window == 4096 @pytest.mark.django_db @@ -103,13 +116,22 @@ def test_update_llm_index( real_document, mock_embed_model, ) -> None: - with patch("documents.models.Document.objects.all") as mock_all: + mock_config = MagicMock() + mock_config.llm_embedding_chunk_size = 512 + with ( + patch("documents.models.Document.objects.all") as mock_all, + patch("paperless_ai.indexing.AIConfig", return_value=mock_config) as ai_config, + patch("paperless_ai.indexing.build_document_node") as build_document_node, + ): mock_queryset = MagicMock() mock_queryset.exists.return_value = True mock_queryset.__iter__.return_value = iter([real_document]) mock_all.return_value = mock_queryset + build_document_node.return_value = [] indexing.update_llm_index(rebuild=True) + ai_config.assert_called_once() + build_document_node.assert_called_once_with(real_document, chunk_size=512) assert any(temp_llm_index_dir.glob("*.json")) diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 77f65db7d..d72b22f32 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -143,6 +143,7 @@ def test_document_filtered_retriever_handles_empty_faiss_index() -> None: mock_index.vector_store.query.assert_not_called() +@pytest.mark.django_db def test_stream_chat_with_one_document_retrieval( mock_document, patch_embed_nodes, @@ -186,6 +187,7 @@ def test_stream_chat_with_one_document_retrieval( ) +@pytest.mark.django_db def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None: with ( patch("paperless_ai.chat.AIClient") as mock_client_cls, diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py index 74a09d40c..b2fbf68ff 100644 --- a/src/paperless_ai/tests/test_client.py +++ b/src/paperless_ai/tests/test_client.py @@ -15,6 +15,7 @@ def mock_ai_config(): with patch("paperless_ai.client.AIConfig") as MockAIConfig: mock_config = MagicMock() mock_config.llm_allow_internal_endpoints = True + mock_config.llm_context_size = 8192 MockAIConfig.return_value = mock_config yield mock_config @@ -41,6 +42,7 @@ def test_get_llm_ollama(mock_ai_config, mock_ollama_llm): mock_ollama_llm.assert_called_once_with( model="test_model", base_url="http://test-url", + context_window=8192, request_timeout=120, system_prompt=LLM_SYSTEM_PROMPT, client=ANY, diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index d3eff080e..1dbd0ab99 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -19,6 +19,7 @@ def mock_ai_config(): with patch("paperless_ai.embedding.AIConfig") as MockAIConfig: MockAIConfig.return_value.llm_embedding_endpoint = None MockAIConfig.return_value.llm_allow_internal_endpoints = True + MockAIConfig.return_value.llm_context_size = 8192 yield MockAIConfig @@ -140,6 +141,7 @@ def test_get_embedding_model_ollama(mock_ai_config): MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://test-url", + ollama_additional_kwargs={"num_ctx": 8192}, ) assert model == MockOllamaEmbedding.return_value @@ -157,6 +159,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config): MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://embedding-url", + ollama_additional_kwargs={"num_ctx": 8192}, ) assert model == MockOllamaEmbedding.return_value