mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-30 21:02:45 +00:00
Use the shared stuff with LLM endpoint
This commit is contained in:
@@ -1947,6 +1947,12 @@ current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "ll
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS=<bool>`](#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS) {#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS}
|
||||
|
||||
: If set to false, Paperless blocks AI endpoint URLs that resolve to non-public addresses (e.g., localhost, etc).
|
||||
|
||||
Defaults to true, which allows internal endpoints.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
|
||||
|
||||
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.test import override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@@ -693,3 +694,17 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
content_type="application/json",
|
||||
)
|
||||
mock_update.assert_called_once()
|
||||
|
||||
@override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False)
|
||||
def test_update_llm_endpoint_blocks_internal_endpoint_when_disallowed(self) -> None:
|
||||
response = self.client.patch(
|
||||
f"{self.ENDPOINT}1/",
|
||||
json.dumps(
|
||||
{
|
||||
"llm_endpoint": "http://127.0.0.1:11434",
|
||||
},
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("non-public address", str(response.data).lower())
|
||||
|
||||
@@ -1112,3 +1112,7 @@ LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai"
|
||||
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
|
||||
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
|
||||
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
|
||||
LLM_ALLOW_INTERNAL_ENDPOINTS = get_bool_from_env(
|
||||
"PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS",
|
||||
"true",
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless.network import validate_outbound_http_url
|
||||
from paperless_ai.base_model import DocumentClassifierSchema
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
@@ -25,17 +26,28 @@ class AIClient:
|
||||
if self.settings.llm_backend == "ollama":
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
endpoint = self.settings.llm_endpoint or "http://localhost:11434"
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
)
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3.1",
|
||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
||||
base_url=endpoint,
|
||||
request_timeout=120,
|
||||
)
|
||||
elif self.settings.llm_backend == "openai":
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
endpoint = self.settings.llm_endpoint or None
|
||||
if endpoint:
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
)
|
||||
return OpenAI(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_base=self.settings.llm_endpoint or None,
|
||||
api_base=endpoint,
|
||||
api_key=self.settings.llm_api_key,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@ from documents.models import Document
|
||||
from documents.models import Note
|
||||
from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless.network import validate_outbound_http_url
|
||||
|
||||
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
@@ -21,10 +22,16 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
case LLMEmbeddingBackend.OPENAI:
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
endpoint = config.llm_endpoint or None
|
||||
if endpoint:
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
)
|
||||
return OpenAIEmbedding(
|
||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
api_base=config.llm_endpoint or None,
|
||||
api_base=endpoint,
|
||||
)
|
||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
@@ -12,6 +12,7 @@ from paperless_ai.client import AIClient
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.llm_allow_internal_endpoints = True
|
||||
MockAIConfig.return_value = mock_config
|
||||
yield mock_config
|
||||
|
||||
@@ -59,6 +60,17 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
assert client.llm == mock_openai_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "openai"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
mock_ai_config.llm_endpoint = "http://127.0.0.1:1234"
|
||||
mock_ai_config.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "unsupported"
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from paperless_ai.embedding import get_embedding_model
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
MockAIConfig.return_value.llm_allow_internal_endpoints = True
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@@ -77,6 +78,19 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config,
|
||||
):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:11434"
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||
mock_ai_config.return_value.llm_embedding_model = (
|
||||
|
||||
Reference in New Issue
Block a user