mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-12 20:21:23 +00:00
Compare commits
2 Commits
dependabot
...
fix-ai-end
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df09980e44 | ||
|
|
39e5400f68 |
@@ -1947,6 +1947,12 @@ current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "ll
|
||||
|
||||
Defaults to None.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS=<bool>`](#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS) {#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS}
|
||||
|
||||
: If set to false, Paperless blocks AI endpoint URLs that resolve to non-public addresses (e.g., localhost, etc).
|
||||
|
||||
Defaults to true, which allows internal endpoints.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
|
||||
|
||||
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.test import override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
@@ -693,3 +694,17 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
content_type="application/json",
|
||||
)
|
||||
mock_update.assert_called_once()
|
||||
|
||||
@override_settings(LLM_ALLOW_INTERNAL_ENDPOINTS=False)
|
||||
def test_update_llm_endpoint_blocks_internal_endpoint_when_disallowed(self) -> None:
|
||||
response = self.client.patch(
|
||||
f"{self.ENDPOINT}1/",
|
||||
json.dumps(
|
||||
{
|
||||
"llm_endpoint": "http://127.0.0.1:11434",
|
||||
},
|
||||
),
|
||||
content_type="application/json",
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("non-public address", str(response.data).lower())
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from django.conf import settings
|
||||
|
||||
from paperless.network import format_host_for_url
|
||||
from paperless.network import is_public_ip
|
||||
from paperless.network import resolve_hostname_ips
|
||||
from paperless.network import validate_outbound_http_url
|
||||
|
||||
logger = logging.getLogger("paperless.workflows.webhooks")
|
||||
|
||||
|
||||
@@ -34,23 +36,19 @@ class WebhookTransport(httpx.HTTPTransport):
|
||||
raise httpx.ConnectError("No hostname in request URL")
|
||||
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None)
|
||||
except socket.gaierror as e:
|
||||
raise httpx.ConnectError(f"Could not resolve hostname: {hostname}") from e
|
||||
|
||||
ips = [info[4][0] for info in addr_info if info and info[4]]
|
||||
if not ips:
|
||||
raise httpx.ConnectError(f"Could not resolve hostname: {hostname}")
|
||||
ips = resolve_hostname_ips(hostname)
|
||||
except ValueError as e:
|
||||
raise httpx.ConnectError(str(e)) from e
|
||||
|
||||
if not self.allow_internal:
|
||||
for ip_str in ips:
|
||||
if not WebhookTransport.is_public_ip(ip_str):
|
||||
if not is_public_ip(ip_str):
|
||||
raise httpx.ConnectError(
|
||||
f"Connection blocked: {hostname} resolves to a non-public address",
|
||||
)
|
||||
|
||||
ip_str = ips[0]
|
||||
formatted_ip = self._format_ip_for_url(ip_str)
|
||||
formatted_ip = format_host_for_url(ip_str)
|
||||
|
||||
new_headers = httpx.Headers(request.headers)
|
||||
if "host" in new_headers:
|
||||
@@ -69,40 +67,6 @@ class WebhookTransport(httpx.HTTPTransport):
|
||||
|
||||
return super().handle_request(request)
|
||||
|
||||
def _format_ip_for_url(self, ip: str) -> str:
|
||||
"""
|
||||
Format IP address for use in URL (wrap IPv6 in brackets)
|
||||
"""
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(ip)
|
||||
if ip_obj.version == 6:
|
||||
return f"[{ip}]"
|
||||
return ip
|
||||
except ValueError:
|
||||
return ip
|
||||
|
||||
@staticmethod
|
||||
def is_public_ip(ip: str | int) -> bool:
|
||||
try:
|
||||
obj = ipaddress.ip_address(ip)
|
||||
return not (
|
||||
obj.is_private
|
||||
or obj.is_loopback
|
||||
or obj.is_link_local
|
||||
or obj.is_multicast
|
||||
or obj.is_unspecified
|
||||
)
|
||||
except ValueError: # pragma: no cover
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def resolve_first_ip(host: str) -> str | None:
|
||||
try:
|
||||
info = socket.getaddrinfo(host, None)
|
||||
return info[0][4][0] if info else None
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
retry_backoff=True,
|
||||
@@ -118,21 +82,24 @@ def send_webhook(
|
||||
*,
|
||||
as_json: bool = False,
|
||||
):
|
||||
p = urlparse(url)
|
||||
if p.scheme.lower() not in settings.WEBHOOKS_ALLOWED_SCHEMES or not p.hostname:
|
||||
logger.warning("Webhook blocked: invalid scheme/hostname")
|
||||
try:
|
||||
parsed = validate_outbound_http_url(
|
||||
url,
|
||||
allowed_schemes=settings.WEBHOOKS_ALLOWED_SCHEMES,
|
||||
allowed_ports=settings.WEBHOOKS_ALLOWED_PORTS,
|
||||
# Internal-address checks happen in transport to preserve ConnectError behavior.
|
||||
allow_internal=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning("Webhook blocked: %s", e)
|
||||
raise
|
||||
|
||||
hostname = parsed.hostname
|
||||
if hostname is None: # pragma: no cover
|
||||
raise ValueError("Invalid URL scheme or hostname.")
|
||||
|
||||
port = p.port or (443 if p.scheme == "https" else 80)
|
||||
if (
|
||||
len(settings.WEBHOOKS_ALLOWED_PORTS) > 0
|
||||
and port not in settings.WEBHOOKS_ALLOWED_PORTS
|
||||
):
|
||||
logger.warning("Webhook blocked: port not permitted")
|
||||
raise ValueError("Destination port not permitted.")
|
||||
|
||||
transport = WebhookTransport(
|
||||
hostname=p.hostname,
|
||||
hostname=hostname,
|
||||
allow_internal=settings.WEBHOOKS_ALLOW_INTERNAL_REQUESTS,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ class AIConfig(BaseConfig):
|
||||
llm_model: str = dataclasses.field(init=False)
|
||||
llm_api_key: str = dataclasses.field(init=False)
|
||||
llm_endpoint: str = dataclasses.field(init=False)
|
||||
llm_allow_internal_endpoints: bool = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
app_config = self._get_config_instance()
|
||||
@@ -203,6 +204,7 @@ class AIConfig(BaseConfig):
|
||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||
self.llm_endpoint = app_config.llm_endpoint or settings.LLM_ENDPOINT
|
||||
self.llm_allow_internal_endpoints = settings.LLM_ALLOW_INTERNAL_ENDPOINTS
|
||||
|
||||
@property
|
||||
def llm_index_enabled(self) -> bool:
|
||||
|
||||
76
src/paperless/network.py
Normal file
76
src/paperless/network.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import ipaddress
|
||||
import socket
|
||||
from collections.abc import Collection
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def is_public_ip(ip: str | int) -> bool:
|
||||
try:
|
||||
obj = ipaddress.ip_address(ip)
|
||||
return not (
|
||||
obj.is_private
|
||||
or obj.is_loopback
|
||||
or obj.is_link_local
|
||||
or obj.is_multicast
|
||||
or obj.is_unspecified
|
||||
)
|
||||
except ValueError: # pragma: no cover
|
||||
return False
|
||||
|
||||
|
||||
def resolve_hostname_ips(hostname: str) -> list[str]:
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None)
|
||||
except socket.gaierror as e:
|
||||
raise ValueError(f"Could not resolve hostname: {hostname}") from e
|
||||
|
||||
ips = [info[4][0] for info in addr_info if info and info[4]]
|
||||
if not ips:
|
||||
raise ValueError(f"Could not resolve hostname: {hostname}")
|
||||
return ips
|
||||
|
||||
|
||||
def format_host_for_url(host: str) -> str:
|
||||
"""
|
||||
Format IP address for URL use (wrap IPv6 in brackets).
|
||||
"""
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(host)
|
||||
if ip_obj.version == 6:
|
||||
return f"[{host}]"
|
||||
return host
|
||||
except ValueError:
|
||||
return host
|
||||
|
||||
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allowed_schemes: Collection[str] = ("http", "https"),
|
||||
allowed_ports: Collection[int] | None = None,
|
||||
allow_internal: bool = False,
|
||||
) -> ParseResult:
|
||||
parsed = urlparse(url)
|
||||
scheme = parsed.scheme.lower()
|
||||
|
||||
if scheme not in allowed_schemes or not parsed.hostname:
|
||||
raise ValueError("Invalid URL scheme or hostname.")
|
||||
|
||||
default_port = 443 if scheme == "https" else 80
|
||||
try:
|
||||
port = parsed.port or default_port
|
||||
except ValueError as e:
|
||||
raise ValueError("Invalid URL scheme or hostname.") from e
|
||||
|
||||
if allowed_ports and port not in allowed_ports:
|
||||
raise ValueError("Destination port not permitted.")
|
||||
|
||||
if not allow_internal:
|
||||
for ip_str in resolve_hostname_ips(parsed.hostname):
|
||||
if not is_public_ip(ip_str):
|
||||
raise ValueError(
|
||||
f"Connection blocked: {parsed.hostname} resolves to a non-public address",
|
||||
)
|
||||
|
||||
return parsed
|
||||
@@ -6,6 +6,7 @@ from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.totp.internal.auth import TOTP
|
||||
from allauth.socialaccount.models import SocialAccount
|
||||
from allauth.socialaccount.models import SocialApp
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.auth.models import User
|
||||
@@ -15,6 +16,7 @@ from rest_framework import serializers
|
||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||
|
||||
from paperless.models import ApplicationConfiguration
|
||||
from paperless.network import validate_outbound_http_url
|
||||
from paperless.validators import reject_dangerous_svg
|
||||
from paperless_mail.serialisers import ObfuscatedPasswordField
|
||||
|
||||
@@ -236,6 +238,20 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer):
|
||||
reject_dangerous_svg(file)
|
||||
return file
|
||||
|
||||
def validate_llm_endpoint(self, value: str | None) -> str | None:
|
||||
if not value:
|
||||
return value
|
||||
|
||||
try:
|
||||
validate_outbound_http_url(
|
||||
value,
|
||||
allow_internal=settings.LLM_ALLOW_INTERNAL_ENDPOINTS,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise serializers.ValidationError(str(e)) from e
|
||||
|
||||
return value
|
||||
|
||||
class Meta:
|
||||
model = ApplicationConfiguration
|
||||
fields = "__all__"
|
||||
|
||||
@@ -1112,3 +1112,7 @@ LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai"
|
||||
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
|
||||
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
|
||||
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
|
||||
LLM_ALLOW_INTERNAL_ENDPOINTS = get_bool_from_env(
|
||||
"PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS",
|
||||
"true",
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless.network import validate_outbound_http_url
|
||||
from paperless_ai.base_model import DocumentClassifierSchema
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
@@ -25,17 +26,28 @@ class AIClient:
|
||||
if self.settings.llm_backend == "ollama":
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
endpoint = self.settings.llm_endpoint or "http://localhost:11434"
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
)
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3.1",
|
||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
||||
base_url=endpoint,
|
||||
request_timeout=120,
|
||||
)
|
||||
elif self.settings.llm_backend == "openai":
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
endpoint = self.settings.llm_endpoint or None
|
||||
if endpoint:
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
)
|
||||
return OpenAI(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_base=self.settings.llm_endpoint or None,
|
||||
api_base=endpoint,
|
||||
api_key=self.settings.llm_api_key,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -12,6 +12,7 @@ from documents.models import Document
|
||||
from documents.models import Note
|
||||
from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless.network import validate_outbound_http_url
|
||||
|
||||
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
@@ -21,10 +22,16 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
case LLMEmbeddingBackend.OPENAI:
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
endpoint = config.llm_endpoint or None
|
||||
if endpoint:
|
||||
validate_outbound_http_url(
|
||||
endpoint,
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
)
|
||||
return OpenAIEmbedding(
|
||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
api_base=config.llm_endpoint or None,
|
||||
api_base=endpoint,
|
||||
)
|
||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
@@ -12,6 +12,7 @@ from paperless_ai.client import AIClient
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.llm_allow_internal_endpoints = True
|
||||
MockAIConfig.return_value = mock_config
|
||||
yield mock_config
|
||||
|
||||
@@ -59,6 +60,17 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
assert client.llm == mock_openai_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "openai"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
mock_ai_config.llm_endpoint = "http://127.0.0.1:1234"
|
||||
mock_ai_config.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "unsupported"
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from paperless_ai.embedding import get_embedding_model
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
MockAIConfig.return_value.llm_allow_internal_endpoints = True
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@@ -77,6 +78,19 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config,
|
||||
):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:11434"
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||
mock_ai_config.return_value.llm_embedding_model = (
|
||||
|
||||
Reference in New Issue
Block a user