Security: Improve overall security in a few ways (#12501)

- Make sure we're always using regex with timeouts for user controlled data
- Adds rate limiting to the token endpoint (configurable)
- Signs the classifier pickle file with the SECRET_KEY and refuse to load one which doesn't verify.
- Require the user to set a secret key, instead of falling back to our old hard coded one
This commit is contained in:
Trenton H
2026-04-02 15:30:26 -07:00
committed by GitHub
parent 376af81b9c
commit dda05a7c00
14 changed files with 443 additions and 110 deletions
+14 -6
View File
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import regex as regex_mod
from django.conf import settings
from pdf2image import convert_from_path
from pikepdf import Page
@@ -22,6 +23,8 @@ from documents.plugins.base import ConsumeTaskPlugin
from documents.plugins.base import StopConsumeTaskError
from documents.plugins.helpers import ProgressManager
from documents.plugins.helpers import ProgressStatusOptions
from documents.regex import safe_regex_match
from documents.regex import safe_regex_sub
from documents.utils import copy_basic_file_stats
from documents.utils import copy_file_with_basic_stats
from documents.utils import maybe_override_pixel_limit
@@ -68,8 +71,8 @@ class Barcode:
Note: This does NOT exclude ASN or separator barcodes - they can also be used
as tags if they match a tag mapping pattern (e.g., {"ASN12.*": "JOHN"}).
"""
for regex in self.settings.barcode_tag_mapping:
if re.match(regex, self.value, flags=re.IGNORECASE):
for pattern in self.settings.barcode_tag_mapping:
if safe_regex_match(pattern, self.value, flags=regex_mod.IGNORECASE):
return True
return False
@@ -392,11 +395,16 @@ class BarcodePlugin(ConsumeTaskPlugin):
for raw in tag_texts.split(","):
try:
tag_str: str | None = None
for regex in self.settings.barcode_tag_mapping:
if re.match(regex, raw, flags=re.IGNORECASE):
sub = self.settings.barcode_tag_mapping[regex]
for pattern in self.settings.barcode_tag_mapping:
if safe_regex_match(pattern, raw, flags=regex_mod.IGNORECASE):
sub = self.settings.barcode_tag_mapping[pattern]
tag_str = (
re.sub(regex, sub, raw, flags=re.IGNORECASE)
safe_regex_sub(
pattern,
sub,
raw,
flags=regex_mod.IGNORECASE,
)
if sub
else raw
)
+76 -50
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import hmac
import logging
import pickle
import re
@@ -75,7 +76,7 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
"Unrecoverable error while loading document "
"classification model, deleting model file.",
)
Path(settings.MODEL_FILE).unlink
Path(settings.MODEL_FILE).unlink()
classifier = None
if raise_exception:
raise e
@@ -97,7 +98,10 @@ class DocumentClassifier:
# v7 - Updated scikit-learn package version
# v8 - Added storage path classifier
# v9 - Changed from hashing to time/ids for re-train check
FORMAT_VERSION = 9
# v10 - HMAC-signed model file
FORMAT_VERSION = 10
HMAC_SIZE = 32 # SHA-256 digest length
def __init__(self) -> None:
# last time a document changed and therefore training might be required
@@ -128,67 +132,89 @@ class DocumentClassifier:
pickle.dumps(self.data_vectorizer),
).hexdigest()
@staticmethod
def _compute_hmac(data: bytes) -> bytes:
return hmac.new(
settings.SECRET_KEY.encode(),
data,
sha256,
).digest()
def load(self) -> None:
from sklearn.exceptions import InconsistentVersionWarning
raw = Path(settings.MODEL_FILE).read_bytes()
if len(raw) <= self.HMAC_SIZE:
raise ClassifierModelCorruptError
signature = raw[: self.HMAC_SIZE]
data = raw[self.HMAC_SIZE :]
if not hmac.compare_digest(signature, self._compute_hmac(data)):
raise ClassifierModelCorruptError
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
with Path(settings.MODEL_FILE).open("rb") as f:
schema_version = pickle.load(f)
try:
(
schema_version,
self.last_doc_change_time,
self.last_auto_type_hash,
self.data_vectorizer,
self.tags_binarizer,
self.tags_classifier,
self.correspondent_classifier,
self.document_type_classifier,
self.storage_path_classifier,
) = pickle.loads(data)
except Exception as err:
raise ClassifierModelCorruptError from err
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.",
)
else:
try:
self.last_doc_change_time = pickle.load(f)
self.last_auto_type_hash = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self._update_data_vectorizer_hash()
self.tags_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f)
self.document_type_classifier = pickle.load(f)
self.storage_path_classifier = pickle.load(f)
except Exception as err:
raise ClassifierModelCorruptError from err
# Check for the warning about unpickling from differing versions
# and consider it incompatible
sk_learn_warning_url = (
"https://scikit-learn.org/stable/"
"model_persistence.html"
"#security-maintainability-limitations"
if schema_version != self.FORMAT_VERSION:
raise IncompatibleClassifierVersionError(
"Cannot load classifier, incompatible versions.",
)
for warning in w:
# The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
if issubclass(warning.category, InconsistentVersionWarning) or (
issubclass(warning.category, UserWarning)
and sk_learn_warning_url in str(warning.message)
):
raise IncompatibleClassifierVersionError("sklearn version update")
self._update_data_vectorizer_hash()
# Check for the warning about unpickling from differing versions
# and consider it incompatible
sk_learn_warning_url = (
"https://scikit-learn.org/stable/"
"model_persistence.html"
"#security-maintainability-limitations"
)
for warning in w:
# The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
if issubclass(warning.category, InconsistentVersionWarning) or (
issubclass(warning.category, UserWarning)
and sk_learn_warning_url in str(warning.message)
):
raise IncompatibleClassifierVersionError("sklearn version update")
def save(self) -> None:
target_file: Path = settings.MODEL_FILE
target_file_temp: Path = target_file.with_suffix(".pickle.part")
data = pickle.dumps(
(
self.FORMAT_VERSION,
self.last_doc_change_time,
self.last_auto_type_hash,
self.data_vectorizer,
self.tags_binarizer,
self.tags_classifier,
self.correspondent_classifier,
self.document_type_classifier,
self.storage_path_classifier,
),
)
signature = self._compute_hmac(data)
with target_file_temp.open("wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.last_doc_change_time, f)
pickle.dump(self.last_auto_type_hash, f)
pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f)
pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f)
pickle.dump(self.document_type_classifier, f)
pickle.dump(self.storage_path_classifier, f)
f.write(signature + data)
target_file_temp.rename(target_file)
@@ -1,9 +1,11 @@
import datetime
import re
from collections.abc import Iterator
from re import Match
import regex
from regex import Match
from documents.plugins.date_parsing.base import DateParserPluginBase
from documents.regex import safe_regex_finditer
class RegexDateParserPlugin(DateParserPluginBase):
@@ -14,7 +16,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
passed to its constructor.
"""
DATE_REGEX = re.compile(
DATE_REGEX = regex.compile(
r"(\b|(?!=([_-])))(\d{1,2})[\.\/-](\d{1,2})[\.\/-](\d{4}|\d{2})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\d{4}|\d{2})[\.\/-](\d{1,2})[\.\/-](\d{1,2})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\d{1,2}[\. ]+[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{4}|[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{1,2}, \d{4})(\b|(?=([_-])))|"
@@ -22,7 +24,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
r"(\b|(?!=([_-])))([^\W\d_]{3,9} \d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\d{1,2}[^ 0-9]{2}[\. ]+[^ ]{3,9}[ \.\/-]\d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\b\d{1,2}[ \.\/-][a-zéûäëčžúřěáíóńźçŞğü]{3}[ \.\/-]\d{4})(\b|(?=([_-])))",
re.IGNORECASE,
regex.IGNORECASE,
)
def _process_match(
@@ -45,7 +47,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
"""
Finds all regex matches in content and yields valid dates.
"""
for m in re.finditer(self.DATE_REGEX, content):
for m in safe_regex_finditer(self.DATE_REGEX, content):
date = self._process_match(m, date_order)
if date is not None:
yield date
+70
View File
@@ -48,3 +48,73 @@ def safe_regex_search(pattern: str, text: str, *, flags: int = 0):
textwrap.shorten(pattern, width=80, placeholder=""),
)
return None
def safe_regex_match(pattern: str, text: str, *, flags: int = 0):
"""
Run a regex match with a timeout. Returns a match object or None.
Validation errors and timeouts are logged and treated as no match.
"""
try:
validate_regex_pattern(pattern)
compiled = regex.compile(pattern, flags=flags)
except (regex.error, ValueError) as exc:
logger.error(
"Error while processing regular expression %s: %s",
textwrap.shorten(pattern, width=80, placeholder=""),
exc,
)
return None
try:
return compiled.match(text, timeout=REGEX_TIMEOUT_SECONDS)
except TimeoutError:
logger.warning(
"Regular expression matching timed out for pattern %s",
textwrap.shorten(pattern, width=80, placeholder=""),
)
return None
def safe_regex_sub(pattern: str, repl: str, text: str, *, flags: int = 0) -> str | None:
"""
Run a regex substitution with a timeout. Returns the substituted string,
or None on error/timeout.
"""
try:
validate_regex_pattern(pattern)
compiled = regex.compile(pattern, flags=flags)
except (regex.error, ValueError) as exc:
logger.error(
"Error while processing regular expression %s: %s",
textwrap.shorten(pattern, width=80, placeholder=""),
exc,
)
return None
try:
return compiled.sub(repl, text, timeout=REGEX_TIMEOUT_SECONDS)
except TimeoutError:
logger.warning(
"Regular expression substitution timed out for pattern %s",
textwrap.shorten(pattern, width=80, placeholder=""),
)
return None
def safe_regex_finditer(compiled_pattern: regex.Pattern, text: str):
"""
Run regex finditer with a timeout. Yields match objects.
Stops iteration on timeout.
"""
try:
yield from compiled_pattern.finditer(text, timeout=REGEX_TIMEOUT_SECONDS)
except TimeoutError:
logger.warning(
"Regular expression finditer timed out for pattern %s",
textwrap.shorten(compiled_pattern.pattern, width=80, placeholder=""),
)
return
Binary file not shown.
+99 -31
View File
@@ -1,5 +1,5 @@
import re
import shutil
import warnings
from pathlib import Path
from unittest import mock
@@ -366,8 +366,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock) -> None:
def test_load_corrupt_file(self) -> None:
"""
GIVEN:
- Corrupted classifier pickle file
@@ -378,36 +377,116 @@ class TestClassifier(DirectoriesMixin, TestCase):
"""
self.generate_train_and_save()
# First load is the schema version,allow it
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
# Write garbage data (valid HMAC length but invalid content)
Path(settings.MODEL_FILE).write_bytes(b"\x00" * 64)
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
patched_pickle_load.assert_called()
patched_pickle_load.reset_mock()
patched_pickle_load.side_effect = [
DocumentClassifier.FORMAT_VERSION,
ClassifierModelCorruptError(),
]
self.assertIsNone(load_classifier())
patched_pickle_load.assert_called()
def test_load_corrupt_pickle_valid_hmac(self) -> None:
"""
GIVEN:
- A classifier file with valid HMAC but unparsable pickle data
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised
"""
garbage_data = b"this is not valid pickle data"
signature = DocumentClassifier._compute_hmac(garbage_data)
Path(settings.MODEL_FILE).write_bytes(signature + garbage_data)
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
def test_load_tampered_file(self) -> None:
"""
GIVEN:
- A classifier model file whose data has been modified
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised due to HMAC mismatch
"""
self.generate_train_and_save()
raw = Path(settings.MODEL_FILE).read_bytes()
# Flip a byte in the data portion (after the 32-byte HMAC)
tampered = raw[:32] + bytes([raw[32] ^ 0xFF]) + raw[33:]
Path(settings.MODEL_FILE).write_bytes(tampered)
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
def test_load_wrong_secret_key(self) -> None:
"""
GIVEN:
- A classifier model file signed with a different SECRET_KEY
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised due to HMAC mismatch
"""
self.generate_train_and_save()
with override_settings(SECRET_KEY="different-secret-key"):
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
def test_load_truncated_file(self) -> None:
"""
GIVEN:
- A classifier model file that is too short to contain an HMAC
WHEN:
- An attempt is made to load the classifier
THEN:
- The ClassifierModelCorruptError is raised
"""
Path(settings.MODEL_FILE).write_bytes(b"\x00" * 16)
with self.assertRaises(ClassifierModelCorruptError):
self.classifier.load()
def test_load_new_scikit_learn_version(self) -> None:
"""
GIVEN:
- classifier pickle file created with a different scikit-learn version
- classifier pickle file triggers an InconsistentVersionWarning
WHEN:
- An attempt is made to load the classifier
THEN:
- The classifier reports the warning was captured and processed
- IncompatibleClassifierVersionError is raised
"""
# TODO: This wasn't testing the warning anymore, as the schema changed
# but as it was implemented, it would require installing an old version
# rebuilding the file and committing that. Not developer friendly
# Need to rethink how to pass the load through to a file with a single
# old model?
from sklearn.exceptions import InconsistentVersionWarning
self.generate_train_and_save()
fake_warning = warnings.WarningMessage(
message=InconsistentVersionWarning(
estimator_name="MLPClassifier",
current_sklearn_version="1.0",
original_sklearn_version="0.9",
),
category=InconsistentVersionWarning,
filename="",
lineno=0,
)
real_catch_warnings = warnings.catch_warnings
class PatchedCatchWarnings(real_catch_warnings):
def __enter__(self):
w = super().__enter__()
w.append(fake_warning)
return w
with mock.patch(
"documents.classifier.warnings.catch_warnings",
PatchedCatchWarnings,
):
with self.assertRaises(IncompatibleClassifierVersionError):
self.classifier.load()
def test_one_correspondent_predict(self) -> None:
c1 = Correspondent.objects.create(
@@ -685,17 +764,6 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertIsNone(load_classifier())
self.assertTrue(Path(settings.MODEL_FILE).exists())
def test_load_old_classifier_version(self) -> None:
shutil.copy(
Path(__file__).parent / "data" / "v1.17.4.model.pickle",
self.dirs.scratch_dir,
)
with override_settings(
MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle",
):
classifier = load_classifier()
self.assertIsNone(classifier)
@mock.patch("documents.classifier.DocumentClassifier.load")
def test_load_classifier_raise_exception(self, mock_load) -> None:
Path(settings.MODEL_FILE).touch()
+128
View File
@@ -0,0 +1,128 @@
import pytest
import regex
from pytest_mock import MockerFixture
from documents.regex import safe_regex_finditer
from documents.regex import safe_regex_match
from documents.regex import safe_regex_search
from documents.regex import safe_regex_sub
from documents.regex import validate_regex_pattern
class TestValidateRegexPattern:
def test_valid_pattern(self):
validate_regex_pattern(r"\d+")
def test_invalid_pattern_raises(self):
with pytest.raises(ValueError):
validate_regex_pattern(r"[invalid")
class TestSafeRegexSearchAndMatch:
"""Tests for safe_regex_search and safe_regex_match (same contract)."""
@pytest.mark.parametrize(
("func", "pattern", "text", "expected_group"),
[
pytest.param(
safe_regex_search,
r"\d+",
"abc123def",
"123",
id="search-match-found",
),
pytest.param(
safe_regex_match,
r"\d+",
"123abc",
"123",
id="match-match-found",
),
],
)
def test_match_found(self, func, pattern, text, expected_group):
result = func(pattern, text)
assert result is not None
assert result.group() == expected_group
@pytest.mark.parametrize(
("func", "pattern", "text"),
[
pytest.param(safe_regex_search, r"\d+", "abcdef", id="search-no-match"),
pytest.param(safe_regex_match, r"\d+", "abc123", id="match-no-match"),
],
)
def test_no_match(self, func, pattern, text):
assert func(pattern, text) is None
@pytest.mark.parametrize(
"func",
[
pytest.param(safe_regex_search, id="search"),
pytest.param(safe_regex_match, id="match"),
],
)
def test_invalid_pattern_returns_none(self, func):
assert func(r"[invalid", "test") is None
@pytest.mark.parametrize(
"func",
[
pytest.param(safe_regex_search, id="search"),
pytest.param(safe_regex_match, id="match"),
],
)
def test_flags_respected(self, func):
assert func(r"abc", "ABC", flags=regex.IGNORECASE) is not None
@pytest.mark.parametrize(
("func", "method_name"),
[
pytest.param(safe_regex_search, "search", id="search"),
pytest.param(safe_regex_match, "match", id="match"),
],
)
def test_timeout_returns_none(self, func, method_name, mocker: MockerFixture):
mock_compile = mocker.patch("documents.regex.regex.compile")
getattr(mock_compile.return_value, method_name).side_effect = TimeoutError
assert func(r"\d+", "test") is None
class TestSafeRegexSub:
@pytest.mark.parametrize(
("pattern", "repl", "text", "expected"),
[
pytest.param(r"\d+", "NUM", "abc123def456", "abcNUMdefNUM", id="basic-sub"),
pytest.param(r"\d+", "NUM", "abcdef", "abcdef", id="no-match"),
pytest.param(r"abc", "X", "ABC", "X", id="flags"),
],
)
def test_substitution(self, pattern, repl, text, expected):
flags = regex.IGNORECASE if pattern == r"abc" else 0
result = safe_regex_sub(pattern, repl, text, flags=flags)
assert result == expected
def test_invalid_pattern_returns_none(self):
assert safe_regex_sub(r"[invalid", "x", "test") is None
def test_timeout_returns_none(self, mocker: MockerFixture):
mock_compile = mocker.patch("documents.regex.regex.compile")
mock_compile.return_value.sub.side_effect = TimeoutError
assert safe_regex_sub(r"\d+", "X", "test") is None
class TestSafeRegexFinditer:
def test_yields_matches(self):
pattern = regex.compile(r"\d+")
matches = list(safe_regex_finditer(pattern, "a1b22c333"))
assert [m.group() for m in matches] == ["1", "22", "333"]
def test_no_matches(self):
pattern = regex.compile(r"\d+")
assert list(safe_regex_finditer(pattern, "abcdef")) == []
def test_timeout_stops_iteration(self, mocker: MockerFixture):
mock_pattern = mocker.MagicMock()
mock_pattern.finditer.side_effect = TimeoutError
mock_pattern.pattern = r"\d+"
assert list(safe_regex_finditer(mock_pattern, "test")) == []