Files
paperless-ngx/src/paperless/network.py
T
2026-05-26 16:46:23 +00:00

197 lines
5.4 KiB
Python

import ipaddress
import socket
from collections.abc import Collection
from urllib.parse import ParseResult
from urllib.parse import urlparse
import httpx
def is_public_ip(ip: str | int) -> bool:
try:
obj = ipaddress.ip_address(ip)
return not (
obj.is_private
or obj.is_loopback
or obj.is_link_local
or obj.is_multicast
or obj.is_unspecified
)
except ValueError: # pragma: no cover
return False
def resolve_hostname_ips(hostname: str) -> list[str]:
try:
addr_info = socket.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise ValueError(f"Could not resolve hostname: {hostname}") from e
ips = [info[4][0] for info in addr_info if info and info[4]]
if not ips:
raise ValueError(f"Could not resolve hostname: {hostname}")
return ips
def format_host_for_url(host: str) -> str:
"""
Format IP address for URL use (wrap IPv6 in brackets).
"""
try:
ip_obj = ipaddress.ip_address(host)
if ip_obj.version == 6:
return f"[{host}]"
return host
except ValueError:
return host
def validate_outbound_http_url(
url: str,
*,
allowed_schemes: Collection[str] = ("http", "https"),
allowed_ports: Collection[int] | None = None,
allow_internal: bool = False,
) -> ParseResult:
parsed = urlparse(url)
scheme = parsed.scheme.lower()
if scheme not in allowed_schemes or not parsed.hostname:
raise ValueError("Invalid URL scheme or hostname.")
default_port = 443 if scheme == "https" else 80
try:
port = parsed.port or default_port
except ValueError as e:
raise ValueError("Invalid URL scheme or hostname.") from e
if allowed_ports and port not in allowed_ports:
raise ValueError("Destination port not permitted.")
if not allow_internal:
for ip_str in resolve_hostname_ips(parsed.hostname):
if not is_public_ip(ip_str):
raise ValueError(
f"Connection blocked: {parsed.hostname} resolves to a non-public address",
)
return parsed
def _rewrite_request_to_pinned_ip(
request: httpx.Request,
*,
allow_internal: bool,
) -> httpx.Request:
hostname = request.url.host
if not hostname:
raise httpx.ConnectError("No hostname in request URL")
try:
ips = resolve_hostname_ips(hostname)
except ValueError as e:
raise httpx.ConnectError(str(e)) from e
if not allow_internal:
for ip_str in ips:
if not is_public_ip(ip_str):
raise httpx.ConnectError(
f"Connection blocked: {hostname} resolves to a non-public address",
)
ip_str = ips[0]
formatted_ip = format_host_for_url(ip_str)
new_headers = httpx.Headers(request.headers)
if "host" in new_headers:
del new_headers["host"]
host_header = format_host_for_url(hostname)
default_port = 443 if request.url.scheme == "https" else 80
if request.url.port and request.url.port != default_port:
host_header = f"{host_header}:{request.url.port}"
new_headers["Host"] = host_header
new_url = request.url.copy_with(host=formatted_ip)
rewritten_request = httpx.Request(
method=request.method,
url=new_url,
headers=new_headers,
content=request.stream,
extensions=request.extensions,
)
rewritten_request.extensions["sni_hostname"] = hostname
return rewritten_request
class PinnedHostHTTPTransport(httpx.HTTPTransport):
"""
HTTP transport that resolves/validates hostnames per request and connects to
a vetted IP while preserving the original Host header and TLS SNI hostname.
"""
def __init__(
self,
*args,
allow_internal: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.allow_internal = allow_internal
def handle_request(self, request: httpx.Request) -> httpx.Response:
request = _rewrite_request_to_pinned_ip(
request,
allow_internal=self.allow_internal,
)
return super().handle_request(request)
class PinnedHostAsyncHTTPTransport(httpx.AsyncHTTPTransport):
"""
Async variant of PinnedHostHTTPTransport.
"""
def __init__(
self,
*args,
allow_internal: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.allow_internal = allow_internal
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
request = _rewrite_request_to_pinned_ip(
request,
allow_internal=self.allow_internal,
)
return await super().handle_async_request(request)
def create_pinned_httpx_client(
url: str,
*,
allow_internal: bool = False,
**kwargs,
) -> httpx.Client:
validate_outbound_http_url(url, allow_internal=allow_internal)
return httpx.Client(
transport=PinnedHostHTTPTransport(allow_internal=allow_internal),
**kwargs,
)
def create_pinned_async_httpx_client(
url: str,
*,
allow_internal: bool = False,
**kwargs,
) -> httpx.AsyncClient:
validate_outbound_http_url(url, allow_internal=allow_internal)
return httpx.AsyncClient(
transport=PinnedHostAsyncHTTPTransport(allow_internal=allow_internal),
**kwargs,
)