From dda05a7c00eb91e76e044aa4007a548b4acb4446 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:30:26 -0700 Subject: [PATCH] Security: Improve overall security in a few ways (#12501) - Make sure we're always using regex with timeouts for user controlled data - Adds rate limiting to the token endpoint (configurable) - Signs the classifier pickle file with the SECRET_KEY and refuse to load one which doesn't verify. - Require the user to set a secret key, instead of falling back to our old hard coded one --- Dockerfile | 4 +- docker/compose/docker-compose.env | 6 +- docs/configuration.md | 30 +++- paperless.conf.example | 3 +- pyproject.toml | 3 + src/documents/barcodes.py | 20 ++- src/documents/classifier.py | 126 ++++++++++------- .../plugins/date_parsing/regex_parser.py | 12 +- src/documents/regex.py | 70 ++++++++++ src/documents/tests/data/v1.17.4.model.pickle | Bin 714 -> 0 bytes src/documents/tests/test_classifier.py | 130 +++++++++++++----- src/documents/tests/test_regex.py | 128 +++++++++++++++++ src/paperless/settings/__init__.py | 18 ++- src/paperless/views.py | 3 + 14 files changed, 443 insertions(+), 110 deletions(-) delete mode 100644 src/documents/tests/data/v1.17.4.model.pickle create mode 100644 src/documents/tests/test_regex.py diff --git a/Dockerfile b/Dockerfile index ac6143162..0b8886c61 100644 --- a/Dockerfile +++ b/Dockerfile @@ -237,8 +237,8 @@ RUN set -eux \ && echo "Adjusting all permissions" \ && chown --from root:root --changes --recursive paperless:paperless /usr/src/paperless \ && echo "Collecting static files" \ - && s6-setuidgid paperless python3 manage.py collectstatic --clear --no-input --link \ - && s6-setuidgid paperless python3 manage.py compilemessages \ + && PAPERLESS_SECRET_KEY=build-time-dummy s6-setuidgid paperless python3 manage.py collectstatic --clear --no-input --link \ + && PAPERLESS_SECRET_KEY=build-time-dummy s6-setuidgid paperless python3 manage.py compilemessages \ && /usr/local/bin/deduplicate.py --verbose /usr/src/paperless/static/ VOLUME ["/usr/src/paperless/data", \ diff --git a/docker/compose/docker-compose.env b/docker/compose/docker-compose.env index 75eeeed09..af6a6e8fe 100644 --- a/docker/compose/docker-compose.env +++ b/docker/compose/docker-compose.env @@ -17,9 +17,9 @@ # (if doing so please consider security measures such as reverse proxy) #PAPERLESS_URL=https://paperless.example.com -# Adjust this key if you plan to make paperless available publicly. It should -# be a very long sequence of random characters. You don't need to remember it. -#PAPERLESS_SECRET_KEY=change-me +# Required. A unique secret key for session tokens and signing. +# Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(64))" +PAPERLESS_SECRET_KEY=change-me # Use this variable to set a timezone for the Paperless Docker containers. Defaults to UTC. #PAPERLESS_TIME_ZONE=America/Los_Angeles diff --git a/docs/configuration.md b/docs/configuration.md index a22171ce9..fa0d32c51 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -402,6 +402,12 @@ Defaults to `/usr/share/nltk_data` : This is where paperless will store the classification model. + !!! warning + + The classification model uses Python's pickle serialization format. + Ensure this file is only writable by the paperless user, as a + maliciously crafted model file could execute arbitrary code when loaded. + Defaults to `PAPERLESS_DATA_DIR/classification_model.pickle`. ## Logging @@ -422,14 +428,20 @@ Defaults to `/usr/share/nltk_data` #### [`PAPERLESS_SECRET_KEY=`](#PAPERLESS_SECRET_KEY) {#PAPERLESS_SECRET_KEY} -: Paperless uses this to make session tokens. If you expose paperless -on the internet, you need to change this, since the default secret -is well known. +: **Required.** Paperless uses this to make session tokens and sign +sensitive data. Paperless will refuse to start if this is not set. Use any sequence of characters. The more, the better. You don't - need to remember this. Just face-roll your keyboard. + need to remember this. You can generate a suitable key with: - Default is listed in the file `src/paperless/settings.py`. + python3 -c "import secrets; print(secrets.token_urlsafe(64))" + + !!! warning + + This setting has no default value. You **must** set it before + starting Paperless. Existing installations that relied on the + previous default value should set `PAPERLESS_SECRET_KEY` to + that value to avoid invalidating existing sessions and tokens. #### [`PAPERLESS_URL=`](#PAPERLESS_URL) {#PAPERLESS_URL} @@ -770,6 +782,14 @@ If both the [PAPERLESS_ACCOUNT_DEFAULT_GROUPS](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS Defaults to 1209600 (2 weeks) +#### [`PAPERLESS_TOKEN_THROTTLE_RATE=`](#PAPERLESS_TOKEN_THROTTLE_RATE) {#PAPERLESS_TOKEN_THROTTLE_RATE} + +: Rate limit for the API token authentication endpoint (`/api/token/`), used to mitigate brute-force login attempts. +Uses Django REST Framework's [throttle rate format](https://www.django-rest-framework.org/api-guide/throttling/#setting-the-throttling-policy), +e.g. `5/min`, `100/hour`, `1000/day`. + + Defaults to `5/min` + ## OCR settings {#ocr} Paperless uses [OCRmyPDF](https://ocrmypdf.readthedocs.io/en/latest/) diff --git a/paperless.conf.example b/paperless.conf.example index 9974aeab6..a0c406f82 100644 --- a/paperless.conf.example +++ b/paperless.conf.example @@ -23,7 +23,8 @@ # Security and hosting -#PAPERLESS_SECRET_KEY=change-me +# Required. Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(64))" +PAPERLESS_SECRET_KEY=change-me #PAPERLESS_URL=https://example.com #PAPERLESS_CSRF_TRUSTED_ORIGINS=https://example.com # can be set using PAPERLESS_URL #PAPERLESS_ALLOWED_HOSTS=example.com,www.example.com # can be set using PAPERLESS_URL diff --git a/pyproject.toml b/pyproject.toml index 5af886f0c..7bb160956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -315,9 +315,12 @@ markers = [ ] [tool.pytest_env] +PAPERLESS_SECRET_KEY = "test-secret-key-do-not-use-in-production" PAPERLESS_DISABLE_DBHANDLER = "true" PAPERLESS_CACHE_BACKEND = "django.core.cache.backends.locmem.LocMemCache" PAPERLESS_CHANNELS_BACKEND = "channels.layers.InMemoryChannelLayer" +# I don't think anything hits this, but just in case, basically infinite +PAPERLESS_TOKEN_THROTTLE_RATE = "1000/min" [tool.coverage.report] exclude_also = [ diff --git a/src/documents/barcodes.py b/src/documents/barcodes.py index 31ef052c4..38a28081a 100644 --- a/src/documents/barcodes.py +++ b/src/documents/barcodes.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING +import regex as regex_mod from django.conf import settings from pdf2image import convert_from_path from pikepdf import Page @@ -22,6 +23,8 @@ from documents.plugins.base import ConsumeTaskPlugin from documents.plugins.base import StopConsumeTaskError from documents.plugins.helpers import ProgressManager from documents.plugins.helpers import ProgressStatusOptions +from documents.regex import safe_regex_match +from documents.regex import safe_regex_sub from documents.utils import copy_basic_file_stats from documents.utils import copy_file_with_basic_stats from documents.utils import maybe_override_pixel_limit @@ -68,8 +71,8 @@ class Barcode: Note: This does NOT exclude ASN or separator barcodes - they can also be used as tags if they match a tag mapping pattern (e.g., {"ASN12.*": "JOHN"}). """ - for regex in self.settings.barcode_tag_mapping: - if re.match(regex, self.value, flags=re.IGNORECASE): + for pattern in self.settings.barcode_tag_mapping: + if safe_regex_match(pattern, self.value, flags=regex_mod.IGNORECASE): return True return False @@ -392,11 +395,16 @@ class BarcodePlugin(ConsumeTaskPlugin): for raw in tag_texts.split(","): try: tag_str: str | None = None - for regex in self.settings.barcode_tag_mapping: - if re.match(regex, raw, flags=re.IGNORECASE): - sub = self.settings.barcode_tag_mapping[regex] + for pattern in self.settings.barcode_tag_mapping: + if safe_regex_match(pattern, raw, flags=regex_mod.IGNORECASE): + sub = self.settings.barcode_tag_mapping[pattern] tag_str = ( - re.sub(regex, sub, raw, flags=re.IGNORECASE) + safe_regex_sub( + pattern, + sub, + raw, + flags=regex_mod.IGNORECASE, + ) if sub else raw ) diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 87934ab52..519e1eac5 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hmac import logging import pickle import re @@ -75,7 +76,7 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No "Unrecoverable error while loading document " "classification model, deleting model file.", ) - Path(settings.MODEL_FILE).unlink + Path(settings.MODEL_FILE).unlink() classifier = None if raise_exception: raise e @@ -97,7 +98,10 @@ class DocumentClassifier: # v7 - Updated scikit-learn package version # v8 - Added storage path classifier # v9 - Changed from hashing to time/ids for re-train check - FORMAT_VERSION = 9 + # v10 - HMAC-signed model file + FORMAT_VERSION = 10 + + HMAC_SIZE = 32 # SHA-256 digest length def __init__(self) -> None: # last time a document changed and therefore training might be required @@ -128,67 +132,89 @@ class DocumentClassifier: pickle.dumps(self.data_vectorizer), ).hexdigest() + @staticmethod + def _compute_hmac(data: bytes) -> bytes: + return hmac.new( + settings.SECRET_KEY.encode(), + data, + sha256, + ).digest() + def load(self) -> None: from sklearn.exceptions import InconsistentVersionWarning + raw = Path(settings.MODEL_FILE).read_bytes() + + if len(raw) <= self.HMAC_SIZE: + raise ClassifierModelCorruptError + + signature = raw[: self.HMAC_SIZE] + data = raw[self.HMAC_SIZE :] + + if not hmac.compare_digest(signature, self._compute_hmac(data)): + raise ClassifierModelCorruptError + # Catch warnings for processing with warnings.catch_warnings(record=True) as w: - with Path(settings.MODEL_FILE).open("rb") as f: - schema_version = pickle.load(f) + try: + ( + schema_version, + self.last_doc_change_time, + self.last_auto_type_hash, + self.data_vectorizer, + self.tags_binarizer, + self.tags_classifier, + self.correspondent_classifier, + self.document_type_classifier, + self.storage_path_classifier, + ) = pickle.loads(data) + except Exception as err: + raise ClassifierModelCorruptError from err - if schema_version != self.FORMAT_VERSION: - raise IncompatibleClassifierVersionError( - "Cannot load classifier, incompatible versions.", - ) - else: - try: - self.last_doc_change_time = pickle.load(f) - self.last_auto_type_hash = pickle.load(f) - - self.data_vectorizer = pickle.load(f) - self._update_data_vectorizer_hash() - self.tags_binarizer = pickle.load(f) - - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) - self.storage_path_classifier = pickle.load(f) - except Exception as err: - raise ClassifierModelCorruptError from err - - # Check for the warning about unpickling from differing versions - # and consider it incompatible - sk_learn_warning_url = ( - "https://scikit-learn.org/stable/" - "model_persistence.html" - "#security-maintainability-limitations" + if schema_version != self.FORMAT_VERSION: + raise IncompatibleClassifierVersionError( + "Cannot load classifier, incompatible versions.", ) - for warning in w: - # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet - if issubclass(warning.category, InconsistentVersionWarning) or ( - issubclass(warning.category, UserWarning) - and sk_learn_warning_url in str(warning.message) - ): - raise IncompatibleClassifierVersionError("sklearn version update") + + self._update_data_vectorizer_hash() + + # Check for the warning about unpickling from differing versions + # and consider it incompatible + sk_learn_warning_url = ( + "https://scikit-learn.org/stable/" + "model_persistence.html" + "#security-maintainability-limitations" + ) + for warning in w: + # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet + if issubclass(warning.category, InconsistentVersionWarning) or ( + issubclass(warning.category, UserWarning) + and sk_learn_warning_url in str(warning.message) + ): + raise IncompatibleClassifierVersionError("sklearn version update") def save(self) -> None: target_file: Path = settings.MODEL_FILE target_file_temp: Path = target_file.with_suffix(".pickle.part") + data = pickle.dumps( + ( + self.FORMAT_VERSION, + self.last_doc_change_time, + self.last_auto_type_hash, + self.data_vectorizer, + self.tags_binarizer, + self.tags_classifier, + self.correspondent_classifier, + self.document_type_classifier, + self.storage_path_classifier, + ), + ) + + signature = self._compute_hmac(data) + with target_file_temp.open("wb") as f: - pickle.dump(self.FORMAT_VERSION, f) - - pickle.dump(self.last_doc_change_time, f) - pickle.dump(self.last_auto_type_hash, f) - - pickle.dump(self.data_vectorizer, f) - - pickle.dump(self.tags_binarizer, f) - pickle.dump(self.tags_classifier, f) - - pickle.dump(self.correspondent_classifier, f) - pickle.dump(self.document_type_classifier, f) - pickle.dump(self.storage_path_classifier, f) + f.write(signature + data) target_file_temp.rename(target_file) diff --git a/src/documents/plugins/date_parsing/regex_parser.py b/src/documents/plugins/date_parsing/regex_parser.py index 2df8f9295..07a9e24f0 100644 --- a/src/documents/plugins/date_parsing/regex_parser.py +++ b/src/documents/plugins/date_parsing/regex_parser.py @@ -1,9 +1,11 @@ import datetime -import re from collections.abc import Iterator -from re import Match + +import regex +from regex import Match from documents.plugins.date_parsing.base import DateParserPluginBase +from documents.regex import safe_regex_finditer class RegexDateParserPlugin(DateParserPluginBase): @@ -14,7 +16,7 @@ class RegexDateParserPlugin(DateParserPluginBase): passed to its constructor. """ - DATE_REGEX = re.compile( + DATE_REGEX = regex.compile( r"(\b|(?!=([_-])))(\d{1,2})[\.\/-](\d{1,2})[\.\/-](\d{4}|\d{2})(\b|(?=([_-])))|" r"(\b|(?!=([_-])))(\d{4}|\d{2})[\.\/-](\d{1,2})[\.\/-](\d{1,2})(\b|(?=([_-])))|" r"(\b|(?!=([_-])))(\d{1,2}[\. ]+[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{4}|[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{1,2}, \d{4})(\b|(?=([_-])))|" @@ -22,7 +24,7 @@ class RegexDateParserPlugin(DateParserPluginBase): r"(\b|(?!=([_-])))([^\W\d_]{3,9} \d{4})(\b|(?=([_-])))|" r"(\b|(?!=([_-])))(\d{1,2}[^ 0-9]{2}[\. ]+[^ ]{3,9}[ \.\/-]\d{4})(\b|(?=([_-])))|" r"(\b|(?!=([_-])))(\b\d{1,2}[ \.\/-][a-zéûäëčžúřěáíóńźçŞğü]{3}[ \.\/-]\d{4})(\b|(?=([_-])))", - re.IGNORECASE, + regex.IGNORECASE, ) def _process_match( @@ -45,7 +47,7 @@ class RegexDateParserPlugin(DateParserPluginBase): """ Finds all regex matches in content and yields valid dates. """ - for m in re.finditer(self.DATE_REGEX, content): + for m in safe_regex_finditer(self.DATE_REGEX, content): date = self._process_match(m, date_order) if date is not None: yield date diff --git a/src/documents/regex.py b/src/documents/regex.py index 35acc5af0..849d417d8 100644 --- a/src/documents/regex.py +++ b/src/documents/regex.py @@ -48,3 +48,73 @@ def safe_regex_search(pattern: str, text: str, *, flags: int = 0): textwrap.shorten(pattern, width=80, placeholder="…"), ) return None + + +def safe_regex_match(pattern: str, text: str, *, flags: int = 0): + """ + Run a regex match with a timeout. Returns a match object or None. + Validation errors and timeouts are logged and treated as no match. + """ + + try: + validate_regex_pattern(pattern) + compiled = regex.compile(pattern, flags=flags) + except (regex.error, ValueError) as exc: + logger.error( + "Error while processing regular expression %s: %s", + textwrap.shorten(pattern, width=80, placeholder="…"), + exc, + ) + return None + + try: + return compiled.match(text, timeout=REGEX_TIMEOUT_SECONDS) + except TimeoutError: + logger.warning( + "Regular expression matching timed out for pattern %s", + textwrap.shorten(pattern, width=80, placeholder="…"), + ) + return None + + +def safe_regex_sub(pattern: str, repl: str, text: str, *, flags: int = 0) -> str | None: + """ + Run a regex substitution with a timeout. Returns the substituted string, + or None on error/timeout. + """ + + try: + validate_regex_pattern(pattern) + compiled = regex.compile(pattern, flags=flags) + except (regex.error, ValueError) as exc: + logger.error( + "Error while processing regular expression %s: %s", + textwrap.shorten(pattern, width=80, placeholder="…"), + exc, + ) + return None + + try: + return compiled.sub(repl, text, timeout=REGEX_TIMEOUT_SECONDS) + except TimeoutError: + logger.warning( + "Regular expression substitution timed out for pattern %s", + textwrap.shorten(pattern, width=80, placeholder="…"), + ) + return None + + +def safe_regex_finditer(compiled_pattern: regex.Pattern, text: str): + """ + Run regex finditer with a timeout. Yields match objects. + Stops iteration on timeout. + """ + + try: + yield from compiled_pattern.finditer(text, timeout=REGEX_TIMEOUT_SECONDS) + except TimeoutError: + logger.warning( + "Regular expression finditer timed out for pattern %s", + textwrap.shorten(compiled_pattern.pattern, width=80, placeholder="…"), + ) + return diff --git a/src/documents/tests/data/v1.17.4.model.pickle b/src/documents/tests/data/v1.17.4.model.pickle deleted file mode 100644 index 4b2734607f8453ad7ecf7d972eea7eafb4deea4d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 714 zcmZWn&ubGw6i(A5O`5b;(X$aeG!$dS(t7J5NRU7dJt#}>Nxq!lz) zc_%vxLiOa?lL!9<4<5vOp2s&32&ycX`bB-hA)9Z{EyqrctQwW}e+j9eQ=6UnIyJv=rcyXsGMg#PTgpfZF#FamSl0z#bMhN0`#&-p<{{ab%yzh# zbf}fbd!KH6`&fVb`|`VkquYO8y??s#_ru}ymtVgBymPqo>q1OFmrewo54U{aN>n?* zb)I%&bPj|clxa=g45q+MWsAHaMp5)4}4_;i~!Qw>RN|;QG z=f4TQFG&q%>1e)d8q+wvw!BJD1dH)rIf None: + def test_load_corrupt_file(self) -> None: """ GIVEN: - Corrupted classifier pickle file @@ -378,36 +377,116 @@ class TestClassifier(DirectoriesMixin, TestCase): """ self.generate_train_and_save() - # First load is the schema version,allow it - patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] + # Write garbage data (valid HMAC length but invalid content) + Path(settings.MODEL_FILE).write_bytes(b"\x00" * 64) with self.assertRaises(ClassifierModelCorruptError): self.classifier.load() - patched_pickle_load.assert_called() - - patched_pickle_load.reset_mock() - patched_pickle_load.side_effect = [ - DocumentClassifier.FORMAT_VERSION, - ClassifierModelCorruptError(), - ] self.assertIsNone(load_classifier()) - patched_pickle_load.assert_called() + + def test_load_corrupt_pickle_valid_hmac(self) -> None: + """ + GIVEN: + - A classifier file with valid HMAC but unparsable pickle data + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised + """ + garbage_data = b"this is not valid pickle data" + signature = DocumentClassifier._compute_hmac(garbage_data) + Path(settings.MODEL_FILE).write_bytes(signature + garbage_data) + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + def test_load_tampered_file(self) -> None: + """ + GIVEN: + - A classifier model file whose data has been modified + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised due to HMAC mismatch + """ + self.generate_train_and_save() + + raw = Path(settings.MODEL_FILE).read_bytes() + # Flip a byte in the data portion (after the 32-byte HMAC) + tampered = raw[:32] + bytes([raw[32] ^ 0xFF]) + raw[33:] + Path(settings.MODEL_FILE).write_bytes(tampered) + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + def test_load_wrong_secret_key(self) -> None: + """ + GIVEN: + - A classifier model file signed with a different SECRET_KEY + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised due to HMAC mismatch + """ + self.generate_train_and_save() + + with override_settings(SECRET_KEY="different-secret-key"): + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + def test_load_truncated_file(self) -> None: + """ + GIVEN: + - A classifier model file that is too short to contain an HMAC + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised + """ + Path(settings.MODEL_FILE).write_bytes(b"\x00" * 16) + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() def test_load_new_scikit_learn_version(self) -> None: """ GIVEN: - - classifier pickle file created with a different scikit-learn version + - classifier pickle file triggers an InconsistentVersionWarning WHEN: - An attempt is made to load the classifier THEN: - - The classifier reports the warning was captured and processed + - IncompatibleClassifierVersionError is raised """ - # TODO: This wasn't testing the warning anymore, as the schema changed - # but as it was implemented, it would require installing an old version - # rebuilding the file and committing that. Not developer friendly - # Need to rethink how to pass the load through to a file with a single - # old model? + from sklearn.exceptions import InconsistentVersionWarning + + self.generate_train_and_save() + + fake_warning = warnings.WarningMessage( + message=InconsistentVersionWarning( + estimator_name="MLPClassifier", + current_sklearn_version="1.0", + original_sklearn_version="0.9", + ), + category=InconsistentVersionWarning, + filename="", + lineno=0, + ) + + real_catch_warnings = warnings.catch_warnings + + class PatchedCatchWarnings(real_catch_warnings): + def __enter__(self): + w = super().__enter__() + w.append(fake_warning) + return w + + with mock.patch( + "documents.classifier.warnings.catch_warnings", + PatchedCatchWarnings, + ): + with self.assertRaises(IncompatibleClassifierVersionError): + self.classifier.load() def test_one_correspondent_predict(self) -> None: c1 = Correspondent.objects.create( @@ -685,17 +764,6 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertIsNone(load_classifier()) self.assertTrue(Path(settings.MODEL_FILE).exists()) - def test_load_old_classifier_version(self) -> None: - shutil.copy( - Path(__file__).parent / "data" / "v1.17.4.model.pickle", - self.dirs.scratch_dir, - ) - with override_settings( - MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle", - ): - classifier = load_classifier() - self.assertIsNone(classifier) - @mock.patch("documents.classifier.DocumentClassifier.load") def test_load_classifier_raise_exception(self, mock_load) -> None: Path(settings.MODEL_FILE).touch() diff --git a/src/documents/tests/test_regex.py b/src/documents/tests/test_regex.py new file mode 100644 index 000000000..a55f29c3c --- /dev/null +++ b/src/documents/tests/test_regex.py @@ -0,0 +1,128 @@ +import pytest +import regex +from pytest_mock import MockerFixture + +from documents.regex import safe_regex_finditer +from documents.regex import safe_regex_match +from documents.regex import safe_regex_search +from documents.regex import safe_regex_sub +from documents.regex import validate_regex_pattern + + +class TestValidateRegexPattern: + def test_valid_pattern(self): + validate_regex_pattern(r"\d+") + + def test_invalid_pattern_raises(self): + with pytest.raises(ValueError): + validate_regex_pattern(r"[invalid") + + +class TestSafeRegexSearchAndMatch: + """Tests for safe_regex_search and safe_regex_match (same contract).""" + + @pytest.mark.parametrize( + ("func", "pattern", "text", "expected_group"), + [ + pytest.param( + safe_regex_search, + r"\d+", + "abc123def", + "123", + id="search-match-found", + ), + pytest.param( + safe_regex_match, + r"\d+", + "123abc", + "123", + id="match-match-found", + ), + ], + ) + def test_match_found(self, func, pattern, text, expected_group): + result = func(pattern, text) + assert result is not None + assert result.group() == expected_group + + @pytest.mark.parametrize( + ("func", "pattern", "text"), + [ + pytest.param(safe_regex_search, r"\d+", "abcdef", id="search-no-match"), + pytest.param(safe_regex_match, r"\d+", "abc123", id="match-no-match"), + ], + ) + def test_no_match(self, func, pattern, text): + assert func(pattern, text) is None + + @pytest.mark.parametrize( + "func", + [ + pytest.param(safe_regex_search, id="search"), + pytest.param(safe_regex_match, id="match"), + ], + ) + def test_invalid_pattern_returns_none(self, func): + assert func(r"[invalid", "test") is None + + @pytest.mark.parametrize( + "func", + [ + pytest.param(safe_regex_search, id="search"), + pytest.param(safe_regex_match, id="match"), + ], + ) + def test_flags_respected(self, func): + assert func(r"abc", "ABC", flags=regex.IGNORECASE) is not None + + @pytest.mark.parametrize( + ("func", "method_name"), + [ + pytest.param(safe_regex_search, "search", id="search"), + pytest.param(safe_regex_match, "match", id="match"), + ], + ) + def test_timeout_returns_none(self, func, method_name, mocker: MockerFixture): + mock_compile = mocker.patch("documents.regex.regex.compile") + getattr(mock_compile.return_value, method_name).side_effect = TimeoutError + assert func(r"\d+", "test") is None + + +class TestSafeRegexSub: + @pytest.mark.parametrize( + ("pattern", "repl", "text", "expected"), + [ + pytest.param(r"\d+", "NUM", "abc123def456", "abcNUMdefNUM", id="basic-sub"), + pytest.param(r"\d+", "NUM", "abcdef", "abcdef", id="no-match"), + pytest.param(r"abc", "X", "ABC", "X", id="flags"), + ], + ) + def test_substitution(self, pattern, repl, text, expected): + flags = regex.IGNORECASE if pattern == r"abc" else 0 + result = safe_regex_sub(pattern, repl, text, flags=flags) + assert result == expected + + def test_invalid_pattern_returns_none(self): + assert safe_regex_sub(r"[invalid", "x", "test") is None + + def test_timeout_returns_none(self, mocker: MockerFixture): + mock_compile = mocker.patch("documents.regex.regex.compile") + mock_compile.return_value.sub.side_effect = TimeoutError + assert safe_regex_sub(r"\d+", "X", "test") is None + + +class TestSafeRegexFinditer: + def test_yields_matches(self): + pattern = regex.compile(r"\d+") + matches = list(safe_regex_finditer(pattern, "a1b22c333")) + assert [m.group() for m in matches] == ["1", "22", "333"] + + def test_no_matches(self): + pattern = regex.compile(r"\d+") + assert list(safe_regex_finditer(pattern, "abcdef")) == [] + + def test_timeout_stops_iteration(self, mocker: MockerFixture): + mock_pattern = mocker.MagicMock() + mock_pattern.finditer.side_effect = TimeoutError + mock_pattern.pattern = r"\d+" + assert list(safe_regex_finditer(mock_pattern, "test")) == [] diff --git a/src/paperless/settings/__init__.py b/src/paperless/settings/__init__.py index 3522b3187..a76c6ce75 100644 --- a/src/paperless/settings/__init__.py +++ b/src/paperless/settings/__init__.py @@ -11,6 +11,7 @@ from typing import Final from urllib.parse import urlparse from compression_middleware.middleware import CompressionMiddleware +from django.core.exceptions import ImproperlyConfigured from django.utils.translation import gettext_lazy as _ from dotenv import load_dotenv @@ -161,6 +162,9 @@ REST_FRAMEWORK = { "ALLOWED_VERSIONS": ["9", "10"], # DRF Spectacular default schema "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_THROTTLE_RATES": { + "login": os.getenv("PAPERLESS_TOKEN_THROTTLE_RATE", "5/min"), + }, } if DEBUG: @@ -460,13 +464,13 @@ SECURE_PROXY_SSL_HEADER = ( else None ) -# The secret key has a default that should be fine so long as you're hosting -# Paperless on a closed network. However, if you're putting this anywhere -# public, you should change the key to something unique and verbose. -SECRET_KEY = os.getenv( - "PAPERLESS_SECRET_KEY", - "e11fl1oa-*ytql8p)(06fbj4ukrlo+n7k&q5+$1md7i+mge=ee", -) +SECRET_KEY = os.getenv("PAPERLESS_SECRET_KEY", "") +if not SECRET_KEY: # pragma: no cover + raise ImproperlyConfigured( + "PAPERLESS_SECRET_KEY is not set. " + "A unique, secret key is required for secure operation. " + 'Generate one with: python3 -c "import secrets; print(secrets.token_urlsafe(64))"', + ) AUTH_PASSWORD_VALIDATORS = [ { diff --git a/src/paperless/views.py b/src/paperless/views.py index a3b965f3f..e4db40bb4 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -34,6 +34,7 @@ from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import DjangoModelPermissions from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response +from rest_framework.throttling import ScopedRateThrottle from rest_framework.viewsets import ModelViewSet from documents.permissions import PaperlessObjectPermissions @@ -51,6 +52,8 @@ from paperless_ai.indexing import vector_store_file_exists class PaperlessObtainAuthTokenView(ObtainAuthToken): serializer_class = PaperlessAuthTokenSerializer + throttle_classes = [ScopedRateThrottle] + throttle_scope = "login" class StandardPagination(PageNumberPagination):