mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-10 23:59:43 +00:00
AD
This commit is contained in:
@@ -291,6 +291,8 @@ class ApplicationConfigurationSerializer(
|
||||
|
||||
return value
|
||||
|
||||
validate_llm_embedding_endpoint = validate_llm_endpoint
|
||||
|
||||
class Meta:
|
||||
model = ApplicationConfiguration
|
||||
fields = "__all__"
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
case LLMEmbeddingBackend.OPENAI_LIKE:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
|
||||
endpoint = config.llm_endpoint or None
|
||||
endpoint = config.llm_embedding_endpoint or config.llm_endpoint or None
|
||||
if endpoint:
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
@@ -43,7 +43,11 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
case LLMEmbeddingBackend.OLLAMA:
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
|
||||
endpoint = config.llm_endpoint or "http://localhost:11434"
|
||||
endpoint = (
|
||||
config.llm_embedding_endpoint
|
||||
or config.llm_endpoint
|
||||
or "http://localhost:11434"
|
||||
)
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
|
||||
@@ -14,6 +14,7 @@ from paperless_ai.embedding import get_embedding_model
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
MockAIConfig.return_value.llm_embedding_endpoint = None
|
||||
MockAIConfig.return_value.llm_allow_internal_endpoints = True
|
||||
yield MockAIConfig
|
||||
|
||||
@@ -71,6 +72,25 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
api_base="http://embedding-url",
|
||||
)
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config,
|
||||
):
|
||||
@@ -116,6 +136,23 @@ def test_get_embedding_model_ollama(mock_ai_config):
|
||||
assert model == MockOllamaEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA
|
||||
mock_ai_config.return_value.llm_embedding_model = "embeddinggemma"
|
||||
mock_ai_config.return_value.llm_embedding_endpoint = "http://embedding-url"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
with patch(
|
||||
"llama_index.embeddings.ollama.OllamaEmbedding",
|
||||
) as MockOllamaEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOllamaEmbedding.assert_called_once_with(
|
||||
model_name="embeddinggemma",
|
||||
base_url="http://embedding-url",
|
||||
)
|
||||
assert model == MockOllamaEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user