This commit is contained in:
shamoon
2026-05-08 12:48:44 -07:00
parent 29a4cc045e
commit b643ec5f50
3 changed files with 45 additions and 2 deletions
+2
View File
@@ -291,6 +291,8 @@ class ApplicationConfigurationSerializer(
return value
validate_llm_embedding_endpoint = validate_llm_endpoint
class Meta:
model = ApplicationConfiguration
fields = "__all__"
+6 -2
View File
@@ -22,7 +22,7 @@ def get_embedding_model() -> "BaseEmbedding":
case LLMEmbeddingBackend.OPENAI_LIKE:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
endpoint = config.llm_endpoint or None
endpoint = config.llm_embedding_endpoint or config.llm_endpoint or None
if endpoint:
validate_outbound_http_url(
endpoint,
@@ -43,7 +43,11 @@ def get_embedding_model() -> "BaseEmbedding":
case LLMEmbeddingBackend.OLLAMA:
from llama_index.embeddings.ollama import OllamaEmbedding
endpoint = config.llm_endpoint or "http://localhost:11434"
endpoint = (
config.llm_embedding_endpoint
or config.llm_endpoint
or "http://localhost:11434"
)
validate_outbound_http_url(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
+37
View File
@@ -14,6 +14,7 @@ from paperless_ai.embedding import get_embedding_model
@pytest.fixture
def mock_ai_config():
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
MockAIConfig.return_value.llm_embedding_endpoint = None
MockAIConfig.return_value.llm_allow_internal_endpoints = True
yield MockAIConfig
@@ -71,6 +72,25 @@ def test_get_embedding_model_openai(mock_ai_config):
assert model == MockOpenAIEmbedding.return_value
def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
mock_ai_config.return_value.llm_api_key = "test_api_key"
mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
with patch(
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
) as MockOpenAIEmbedding:
model = get_embedding_model()
MockOpenAIEmbedding.assert_called_once_with(
model_name="text-embedding-3-small",
api_key="test_api_key",
api_base="http://embedding-url",
)
assert model == MockOpenAIEmbedding.return_value
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
mock_ai_config,
):
@@ -116,6 +136,23 @@ def test_get_embedding_model_ollama(mock_ai_config):
assert model == MockOllamaEmbedding.return_value
def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA
mock_ai_config.return_value.llm_embedding_model = "embeddinggemma"
mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
with patch(
"llama_index.embeddings.ollama.OllamaEmbedding",
) as MockOllamaEmbedding:
model = get_embedding_model()
MockOllamaEmbedding.assert_called_once_with(
model_name="embeddinggemma",
base_url="http://embedding-url",
)
assert model == MockOllamaEmbedding.return_value
def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
mock_ai_config,
):