Validate the AI backend settings earlier instead of crashing inside the AI module

This commit is contained in:
stumpylog
2026-06-03 10:43:25 -07:00
parent abdcdccf08
commit 6ede72cc44
4 changed files with 43 additions and 23 deletions
+7 -3
View File
@@ -1184,9 +1184,10 @@ REMOTE_OCR_ENDPOINT = os.getenv("PAPERLESS_REMOTE_OCR_ENDPOINT")
# AI Settings #
################################################################################
AI_ENABLED = get_bool_from_env("PAPERLESS_AI_ENABLED", "NO")
LLM_EMBEDDING_BACKEND = os.getenv(
LLM_EMBEDDING_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_EMBEDDING_BACKEND",
) # "huggingface", "openai-like", or "ollama"
{"huggingface", "openai-like", "ollama"},
)
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_MODEL")
LLM_EMBEDDING_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_ENDPOINT")
LLM_EMBEDDING_CHUNK_SIZE = get_int_from_env(
@@ -1198,7 +1199,10 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
if LLM_CONTEXT_SIZE < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai-like"
LLM_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_BACKEND",
{"ollama", "openai-like"},
)
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
+5 -6
View File
@@ -209,12 +209,11 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
Returns:
A databases dict suitable for Django DATABASES setting.
"""
try:
engine = get_choice_from_env(
"PAPERLESS_DBENGINE",
{"sqlite", "postgresql", "mariadb"},
)
except ValueError:
engine = get_choice_from_env(
"PAPERLESS_DBENGINE",
{"sqlite", "postgresql", "mariadb"},
)
if engine is None:
# MariaDB users already had to set PAPERLESS_DBENGINE, so it was picked up above
# SQLite users didn't need to set anything
engine = "postgresql" if "PAPERLESS_DBHOST" in os.environ else "sqlite"
+27 -7
View File
@@ -258,32 +258,52 @@ def get_list_from_env(
return []
@overload
def get_choice_from_env(
env_key: str,
choices: set[str] | frozenset[str],
) -> str | None: ...
@overload
def get_choice_from_env(
env_key: str,
choices: set[str] | frozenset[str],
default: None,
) -> str | None: ...
@overload
def get_choice_from_env(
env_key: str,
choices: set[str] | frozenset[str],
default: str,
) -> str: ...
def get_choice_from_env(
env_key: str,
choices: set[str] | frozenset[str],
default: str | None = None,
) -> str:
) -> str | None:
"""
Gets and validates an environment variable against a set of allowed choices.
Args:
env_key: The environment variable key to validate
choices: Set of valid choices for the environment variable
default: Optional default value if environment variable is not set
default: Default value if environment variable is not set; None means optional
Returns:
The validated environment variable value
The validated environment variable value, or None if not set and no default
Raises:
ValueError: If the environment variable value is not in choices
or if no default is provided and env var is missing
"""
value = os.environ.get(env_key, default)
if value is None:
raise ValueError(
f"Environment variable '{env_key}' is required but not set.",
)
return None
if value not in choices:
raise ValueError(
@@ -509,20 +509,17 @@ class TestGetEnvChoice:
assert result == "staging"
def test_raises_error_when_env_not_set_and_no_default(
def test_returns_none_when_env_not_set_and_no_default(
self,
mocker: MockerFixture,
valid_choices: set[str],
) -> None:
"""Test that function raises ValueError when env var is missing and no default."""
"""Test that function returns None when env var is missing and no default given."""
mocker.patch.dict("os.environ", {}, clear=True)
with pytest.raises(ValueError) as exc_info:
get_choice_from_env("TEST_ENV", valid_choices)
result = get_choice_from_env("TEST_ENV", valid_choices)
assert "Environment variable 'TEST_ENV' is required but not set" in str(
exc_info.value,
)
assert result is None
def test_raises_error_when_env_value_invalid(
self,