mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-15 05:31:23 +00:00
Compare commits
2 Commits
dev
...
fix-ai-end
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df09980e44 | ||
|
|
39e5400f68 |
@@ -1947,6 +1947,12 @@ current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "ll
|
|||||||
|
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
|
#### [`PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS=<bool>`](#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS) {#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS}
|
||||||
|
|
||||||
|
: If set to false, Paperless blocks AI endpoint URLs that resolve to non-public addresses (e.g., localhost, etc).
|
||||||
|
|
||||||
|
Defaults to true, which allows internal endpoints.
|
||||||
|
|
||||||
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
|
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
|
||||||
|
|
||||||
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
|
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||||
|
from django.test import override_settings
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
@@ -693,3 +694,17 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
)
|
)
|
||||||
mock_update.assert_called_once()
|
mock_update.assert_called_once()
|
||||||
|
|
||||||
|
@override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False)
|
||||||
|
def test_update_llm_endpoint_blocks_internal_endpoint_when_disallowed(self) -> None:
|
||||||
|
response = self.client.patch(
|
||||||
|
f"{self.ENDPOINT}1/",
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"llm_endpoint": "http://127.0.0.1:11434",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
|
self.assertIn("non-public address", str(response.data).lower())
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import ipaddress
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
|
from paperless.network import format_host_for_url
|
||||||
|
from paperless.network import is_public_ip
|
||||||
|
from paperless.network import resolve_hostname_ips
|
||||||
|
from paperless.network import validate_outbound_http_url
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.workflows.webhooks")
|
logger = logging.getLogger("paperless.workflows.webhooks")
|
||||||
|
|
||||||
|
|
||||||
@@ -34,23 +36,19 @@ class WebhookTransport(httpx.HTTPTransport):
|
|||||||
raise httpx.ConnectError("No hostname in request URL")
|
raise httpx.ConnectError("No hostname in request URL")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
addr_info = socket.getaddrinfo(hostname, None)
|
ips = resolve_hostname_ips(hostname)
|
||||||
except socket.gaierror as e:
|
except ValueError as e:
|
||||||
raise httpx.ConnectError(f"Could not resolve hostname: {hostname}") from e
|
raise httpx.ConnectError(str(e)) from e
|
||||||
|
|
||||||
ips = [info[4][0] for info in addr_info if info and info[4]]
|
|
||||||
if not ips:
|
|
||||||
raise httpx.ConnectError(f"Could not resolve hostname: {hostname}")
|
|
||||||
|
|
||||||
if not self.allow_internal:
|
if not self.allow_internal:
|
||||||
for ip_str in ips:
|
for ip_str in ips:
|
||||||
if not WebhookTransport.is_public_ip(ip_str):
|
if not is_public_ip(ip_str):
|
||||||
raise httpx.ConnectError(
|
raise httpx.ConnectError(
|
||||||
f"Connection blocked: {hostname} resolves to a non-public address",
|
f"Connection blocked: {hostname} resolves to a non-public address",
|
||||||
)
|
)
|
||||||
|
|
||||||
ip_str = ips[0]
|
ip_str = ips[0]
|
||||||
formatted_ip = self._format_ip_for_url(ip_str)
|
formatted_ip = format_host_for_url(ip_str)
|
||||||
|
|
||||||
new_headers = httpx.Headers(request.headers)
|
new_headers = httpx.Headers(request.headers)
|
||||||
if "host" in new_headers:
|
if "host" in new_headers:
|
||||||
@@ -69,40 +67,6 @@ class WebhookTransport(httpx.HTTPTransport):
|
|||||||
|
|
||||||
return super().handle_request(request)
|
return super().handle_request(request)
|
||||||
|
|
||||||
def _format_ip_for_url(self, ip: str) -> str:
|
|
||||||
"""
|
|
||||||
Format IP address for use in URL (wrap IPv6 in brackets)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
ip_obj = ipaddress.ip_address(ip)
|
|
||||||
if ip_obj.version == 6:
|
|
||||||
return f"[{ip}]"
|
|
||||||
return ip
|
|
||||||
except ValueError:
|
|
||||||
return ip
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_public_ip(ip: str | int) -> bool:
|
|
||||||
try:
|
|
||||||
obj = ipaddress.ip_address(ip)
|
|
||||||
return not (
|
|
||||||
obj.is_private
|
|
||||||
or obj.is_loopback
|
|
||||||
or obj.is_link_local
|
|
||||||
or obj.is_multicast
|
|
||||||
or obj.is_unspecified
|
|
||||||
)
|
|
||||||
except ValueError: # pragma: no cover
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def resolve_first_ip(host: str) -> str | None:
|
|
||||||
try:
|
|
||||||
info = socket.getaddrinfo(host, None)
|
|
||||||
return info[0][4][0] if info else None
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task(
|
@shared_task(
|
||||||
retry_backoff=True,
|
retry_backoff=True,
|
||||||
@@ -118,21 +82,24 @@ def send_webhook(
|
|||||||
*,
|
*,
|
||||||
as_json: bool = False,
|
as_json: bool = False,
|
||||||
):
|
):
|
||||||
p = urlparse(url)
|
try:
|
||||||
if p.scheme.lower() not in settings.WEBHOOKS_ALLOWED_SCHEMES or not p.hostname:
|
parsed = validate_outbound_http_url(
|
||||||
logger.warning("Webhook blocked: invalid scheme/hostname")
|
url,
|
||||||
|
allowed_schemes=settings.WEBHOOKS_ALLOWED_SCHEMES,
|
||||||
|
allowed_ports=settings.WEBHOOKS_ALLOWED_PORTS,
|
||||||
|
# Internal-address checks happen in transport to preserve ConnectError behavior.
|
||||||
|
allow_internal=True,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("Webhook blocked: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if hostname is None: # pragma: no cover
|
||||||
raise ValueError("Invalid URL scheme or hostname.")
|
raise ValueError("Invalid URL scheme or hostname.")
|
||||||
|
|
||||||
port = p.port or (443 if p.scheme == "https" else 80)
|
|
||||||
if (
|
|
||||||
len(settings.WEBHOOKS_ALLOWED_PORTS) > 0
|
|
||||||
and port not in settings.WEBHOOKS_ALLOWED_PORTS
|
|
||||||
):
|
|
||||||
logger.warning("Webhook blocked: port not permitted")
|
|
||||||
raise ValueError("Destination port not permitted.")
|
|
||||||
|
|
||||||
transport = WebhookTransport(
|
transport = WebhookTransport(
|
||||||
hostname=p.hostname,
|
hostname=hostname,
|
||||||
allow_internal=settings.WEBHOOKS_ALLOW_INTERNAL_REQUESTS,
|
allow_internal=settings.WEBHOOKS_ALLOW_INTERNAL_REQUESTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ class AIConfig(BaseConfig):
|
|||||||
llm_model: str = dataclasses.field(init=False)
|
llm_model: str = dataclasses.field(init=False)
|
||||||
llm_api_key: str = dataclasses.field(init=False)
|
llm_api_key: str = dataclasses.field(init=False)
|
||||||
llm_endpoint: str = dataclasses.field(init=False)
|
llm_endpoint: str = dataclasses.field(init=False)
|
||||||
|
llm_allow_internal_endpoints: bool = dataclasses.field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
app_config = self._get_config_instance()
|
app_config = self._get_config_instance()
|
||||||
@@ -203,6 +204,7 @@ class AIConfig(BaseConfig):
|
|||||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||||
self.llm_endpoint = app_config.llm_endpoint or settings.LLM_ENDPOINT
|
self.llm_endpoint = app_config.llm_endpoint or settings.LLM_ENDPOINT
|
||||||
|
self.llm_allow_internal_endpoints = settings.LLM_ALLOW_INTERNAL_ENDPOINTS
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llm_index_enabled(self) -> bool:
|
def llm_index_enabled(self) -> bool:
|
||||||
|
|||||||
76
src/paperless/network.py
Normal file
76
src/paperless/network.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
from collections.abc import Collection
|
||||||
|
from urllib.parse import ParseResult
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
def is_public_ip(ip: str | int) -> bool:
|
||||||
|
try:
|
||||||
|
obj = ipaddress.ip_address(ip)
|
||||||
|
return not (
|
||||||
|
obj.is_private
|
||||||
|
or obj.is_loopback
|
||||||
|
or obj.is_link_local
|
||||||
|
or obj.is_multicast
|
||||||
|
or obj.is_unspecified
|
||||||
|
)
|
||||||
|
except ValueError: # pragma: no cover
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hostname_ips(hostname: str) -> list[str]:
|
||||||
|
try:
|
||||||
|
addr_info = socket.getaddrinfo(hostname, None)
|
||||||
|
except socket.gaierror as e:
|
||||||
|
raise ValueError(f"Could not resolve hostname: {hostname}") from e
|
||||||
|
|
||||||
|
ips = [info[4][0] for info in addr_info if info and info[4]]
|
||||||
|
if not ips:
|
||||||
|
raise ValueError(f"Could not resolve hostname: {hostname}")
|
||||||
|
return ips
|
||||||
|
|
||||||
|
|
||||||
|
def format_host_for_url(host: str) -> str:
|
||||||
|
"""
|
||||||
|
Format IP address for URL use (wrap IPv6 in brackets).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ip_obj = ipaddress.ip_address(host)
|
||||||
|
if ip_obj.version == 6:
|
||||||
|
return f"[{host}]"
|
||||||
|
return host
|
||||||
|
except ValueError:
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
|
def validate_outbound_http_url(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
allowed_schemes: Collection[str] = ("http", "https"),
|
||||||
|
allowed_ports: Collection[int] | None = None,
|
||||||
|
allow_internal: bool = False,
|
||||||
|
) -> ParseResult:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
scheme = parsed.scheme.lower()
|
||||||
|
|
||||||
|
if scheme not in allowed_schemes or not parsed.hostname:
|
||||||
|
raise ValueError("Invalid URL scheme or hostname.")
|
||||||
|
|
||||||
|
default_port = 443 if scheme == "https" else 80
|
||||||
|
try:
|
||||||
|
port = parsed.port or default_port
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError("Invalid URL scheme or hostname.") from e
|
||||||
|
|
||||||
|
if allowed_ports and port not in allowed_ports:
|
||||||
|
raise ValueError("Destination port not permitted.")
|
||||||
|
|
||||||
|
if not allow_internal:
|
||||||
|
for ip_str in resolve_hostname_ips(parsed.hostname):
|
||||||
|
if not is_public_ip(ip_str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Connection blocked: {parsed.hostname} resolves to a non-public address",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed
|
||||||
@@ -6,6 +6,7 @@ from allauth.mfa.models import Authenticator
|
|||||||
from allauth.mfa.totp.internal.auth import TOTP
|
from allauth.mfa.totp.internal.auth import TOTP
|
||||||
from allauth.socialaccount.models import SocialAccount
|
from allauth.socialaccount.models import SocialAccount
|
||||||
from allauth.socialaccount.models import SocialApp
|
from allauth.socialaccount.models import SocialApp
|
||||||
|
from django.conf import settings
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import Permission
|
from django.contrib.auth.models import Permission
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
@@ -15,6 +16,7 @@ from rest_framework import serializers
|
|||||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||||
|
|
||||||
from paperless.models import ApplicationConfiguration
|
from paperless.models import ApplicationConfiguration
|
||||||
|
from paperless.network import validate_outbound_http_url
|
||||||
from paperless.validators import reject_dangerous_svg
|
from paperless.validators import reject_dangerous_svg
|
||||||
from paperless_mail.serialisers import ObfuscatedPasswordField
|
from paperless_mail.serialisers import ObfuscatedPasswordField
|
||||||
|
|
||||||
@@ -236,6 +238,20 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer):
|
|||||||
reject_dangerous_svg(file)
|
reject_dangerous_svg(file)
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
def validate_llm_endpoint(self, value: str | None) -> str | None:
|
||||||
|
if not value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_outbound_http_url(
|
||||||
|
value,
|
||||||
|
allow_internal=settings.LLM_ALLOW_INTERNAL_ENDPOINTS,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise serializers.ValidationError(str(e)) from e
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = ApplicationConfiguration
|
model = ApplicationConfiguration
|
||||||
fields = "__all__"
|
fields = "__all__"
|
||||||
|
|||||||
@@ -1112,3 +1112,7 @@ LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai"
|
|||||||
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
|
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
|
||||||
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
|
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
|
||||||
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
|
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
|
||||||
|
LLM_ALLOW_INTERNAL_ENDPOINTS = get_bool_from_env(
|
||||||
|
"PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS",
|
||||||
|
"true",
|
||||||
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
|||||||
from llama_index.llms.openai import OpenAI
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
|
from paperless.network import validate_outbound_http_url
|
||||||
from paperless_ai.base_model import DocumentClassifierSchema
|
from paperless_ai.base_model import DocumentClassifierSchema
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.client")
|
logger = logging.getLogger("paperless_ai.client")
|
||||||
@@ -25,17 +26,28 @@ class AIClient:
|
|||||||
if self.settings.llm_backend == "ollama":
|
if self.settings.llm_backend == "ollama":
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
|
|
||||||
|
endpoint = self.settings.llm_endpoint or "http://localhost:11434"
|
||||||
|
validate_outbound_http_url(
|
||||||
|
endpoint,
|
||||||
|
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||||
|
)
|
||||||
return Ollama(
|
return Ollama(
|
||||||
model=self.settings.llm_model or "llama3.1",
|
model=self.settings.llm_model or "llama3.1",
|
||||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
base_url=endpoint,
|
||||||
request_timeout=120,
|
request_timeout=120,
|
||||||
)
|
)
|
||||||
elif self.settings.llm_backend == "openai":
|
elif self.settings.llm_backend == "openai":
|
||||||
from llama_index.llms.openai import OpenAI
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
|
endpoint = self.settings.llm_endpoint or None
|
||||||
|
if endpoint:
|
||||||
|
validate_outbound_http_url(
|
||||||
|
endpoint,
|
||||||
|
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||||
|
)
|
||||||
return OpenAI(
|
return OpenAI(
|
||||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||||
api_base=self.settings.llm_endpoint or None,
|
api_base=endpoint,
|
||||||
api_key=self.settings.llm_api_key,
|
api_key=self.settings.llm_api_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from documents.models import Document
|
|||||||
from documents.models import Note
|
from documents.models import Note
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
from paperless.models import LLMEmbeddingBackend
|
from paperless.models import LLMEmbeddingBackend
|
||||||
|
from paperless.network import validate_outbound_http_url
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model() -> "BaseEmbedding":
|
def get_embedding_model() -> "BaseEmbedding":
|
||||||
@@ -21,10 +22,16 @@ def get_embedding_model() -> "BaseEmbedding":
|
|||||||
case LLMEmbeddingBackend.OPENAI:
|
case LLMEmbeddingBackend.OPENAI:
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
|
||||||
|
endpoint = config.llm_endpoint or None
|
||||||
|
if endpoint:
|
||||||
|
validate_outbound_http_url(
|
||||||
|
endpoint,
|
||||||
|
allow_internal=config.llm_allow_internal_endpoints,
|
||||||
|
)
|
||||||
return OpenAIEmbedding(
|
return OpenAIEmbedding(
|
||||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||||
api_key=config.llm_api_key,
|
api_key=config.llm_api_key,
|
||||||
api_base=config.llm_endpoint or None,
|
api_base=endpoint,
|
||||||
)
|
)
|
||||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from paperless_ai.client import AIClient
|
|||||||
def mock_ai_config():
|
def mock_ai_config():
|
||||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
|
mock_config.llm_allow_internal_endpoints = True
|
||||||
MockAIConfig.return_value = mock_config
|
MockAIConfig.return_value = mock_config
|
||||||
yield mock_config
|
yield mock_config
|
||||||
|
|
||||||
@@ -59,6 +60,17 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
|||||||
assert client.llm == mock_openai_llm.return_value
|
assert client.llm == mock_openai_llm.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(mock_ai_config):
|
||||||
|
mock_ai_config.llm_backend = "openai"
|
||||||
|
mock_ai_config.llm_model = "test_model"
|
||||||
|
mock_ai_config.llm_api_key = "test_api_key"
|
||||||
|
mock_ai_config.llm_endpoint = "http://127.0.0.1:1234"
|
||||||
|
mock_ai_config.llm_allow_internal_endpoints = False
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="non-public address"):
|
||||||
|
AIClient()
|
||||||
|
|
||||||
|
|
||||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||||
mock_ai_config.llm_backend = "unsupported"
|
mock_ai_config.llm_backend = "unsupported"
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from paperless_ai.embedding import get_embedding_model
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ai_config():
|
def mock_ai_config():
|
||||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||||
|
MockAIConfig.return_value.llm_allow_internal_endpoints = True
|
||||||
yield MockAIConfig
|
yield MockAIConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -77,6 +78,19 @@ def test_get_embedding_model_openai(mock_ai_config):
|
|||||||
assert model == MockOpenAIEmbedding.return_value
|
assert model == MockOpenAIEmbedding.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||||
|
mock_ai_config,
|
||||||
|
):
|
||||||
|
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||||
|
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||||
|
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||||
|
mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:11434"
|
||||||
|
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="non-public address"):
|
||||||
|
get_embedding_model()
|
||||||
|
|
||||||
|
|
||||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||||
mock_ai_config.return_value.llm_embedding_model = (
|
mock_ai_config.return_value.llm_embedding_model = (
|
||||||
|
|||||||
Reference in New Issue
Block a user