From b643ec5f50a6926ce7e3b4098160dd2e640771ab Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 8 May 2026 12:48:44 -0700 Subject: [PATCH] AD --- src/paperless/serialisers.py | 2 ++ src/paperless_ai/embedding.py | 8 +++-- src/paperless_ai/tests/test_embedding.py | 37 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/paperless/serialisers.py b/src/paperless/serialisers.py index 92676df4e..d1597ab13 100644 --- a/src/paperless/serialisers.py +++ b/src/paperless/serialisers.py @@ -291,6 +291,8 @@ class ApplicationConfigurationSerializer( return value + validate_llm_embedding_endpoint = validate_llm_endpoint + class Meta: model = ApplicationConfiguration fields = "__all__" diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index cf7626984..2a32b9dba 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -22,7 +22,7 @@ def get_embedding_model() -> "BaseEmbedding": case LLMEmbeddingBackend.OPENAI_LIKE: from llama_index.embeddings.openai_like import OpenAILikeEmbedding - endpoint = config.llm_endpoint or None + endpoint = config.llm_embedding_endpoint or config.llm_endpoint or None if endpoint: validate_outbound_http_url( endpoint, @@ -43,7 +43,11 @@ def get_embedding_model() -> "BaseEmbedding": case LLMEmbeddingBackend.OLLAMA: from llama_index.embeddings.ollama import OllamaEmbedding - endpoint = config.llm_endpoint or "http://localhost:11434" + endpoint = ( + config.llm_embedding_endpoint + or config.llm_endpoint + or "http://localhost:11434" + ) validate_outbound_http_url( endpoint, allow_internal=config.llm_allow_internal_endpoints, diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index b595b68cd..3fb5c39ce 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -14,6 +14,7 @@ from paperless_ai.embedding import get_embedding_model @pytest.fixture def mock_ai_config(): with patch("paperless_ai.embedding.AIConfig") as MockAIConfig: + MockAIConfig.return_value.llm_embedding_endpoint = None MockAIConfig.return_value.llm_allow_internal_endpoints = True yield MockAIConfig @@ -71,6 +72,25 @@ def test_get_embedding_model_openai(mock_ai_config): assert model == MockOpenAIEmbedding.return_value +def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config): + mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE + mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small" + mock_ai_config.return_value.llm_api_key = "test_api_key" + mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url" + mock_ai_config.return_value.llm_endpoint = "http://test-url" + + with patch( + "llama_index.embeddings.openai_like.OpenAILikeEmbedding", + ) as MockOpenAIEmbedding: + model = get_embedding_model() + MockOpenAIEmbedding.assert_called_once_with( + model_name="text-embedding-3-small", + api_key="test_api_key", + api_base="http://embedding-url", + ) + assert model == MockOpenAIEmbedding.return_value + + def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed( mock_ai_config, ): @@ -116,6 +136,23 @@ def test_get_embedding_model_ollama(mock_ai_config): assert model == MockOllamaEmbedding.return_value +def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config): + mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA + mock_ai_config.return_value.llm_embedding_model = "embeddinggemma" + mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url" + mock_ai_config.return_value.llm_endpoint = "http://test-url" + + with patch( + "llama_index.embeddings.ollama.OllamaEmbedding", + ) as MockOllamaEmbedding: + model = get_embedding_model() + MockOllamaEmbedding.assert_called_once_with( + model_name="embeddinggemma", + base_url="http://embedding-url", + ) + assert model == MockOllamaEmbedding.return_value + + def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed( mock_ai_config, ):