From 79784ac407f9753072602ef40e2b4dfef07eedbb Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:27:02 -0700 Subject: [PATCH] Signs the classifier so we have additional protections against tampering + pickle --- docs/configuration.md | 6 ++ src/documents/classifier.py | 126 +++++++++++++++---------- src/documents/tests/test_classifier.py | 105 +++++++++++++++------ 3 files changed, 156 insertions(+), 81 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 3ab1903fc..33f41c993 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -402,6 +402,12 @@ Defaults to `/usr/share/nltk_data` : This is where paperless will store the classification model. + !!! warning + + The classification model uses Python's pickle serialization format. + Ensure this file is only writable by the paperless user, as a + maliciously crafted model file could execute arbitrary code when loaded. + Defaults to `PAPERLESS_DATA_DIR/classification_model.pickle`. ## Logging diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 87934ab52..519e1eac5 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hmac import logging import pickle import re @@ -75,7 +76,7 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No "Unrecoverable error while loading document " "classification model, deleting model file.", ) - Path(settings.MODEL_FILE).unlink + Path(settings.MODEL_FILE).unlink() classifier = None if raise_exception: raise e @@ -97,7 +98,10 @@ class DocumentClassifier: # v7 - Updated scikit-learn package version # v8 - Added storage path classifier # v9 - Changed from hashing to time/ids for re-train check - FORMAT_VERSION = 9 + # v10 - HMAC-signed model file + FORMAT_VERSION = 10 + + HMAC_SIZE = 32 # SHA-256 digest length def __init__(self) -> None: # last time a document changed and therefore training might be required @@ -128,67 +132,89 @@ class DocumentClassifier: pickle.dumps(self.data_vectorizer), ).hexdigest() + @staticmethod + def _compute_hmac(data: bytes) -> bytes: + return hmac.new( + settings.SECRET_KEY.encode(), + data, + sha256, + ).digest() + def load(self) -> None: from sklearn.exceptions import InconsistentVersionWarning + raw = Path(settings.MODEL_FILE).read_bytes() + + if len(raw) <= self.HMAC_SIZE: + raise ClassifierModelCorruptError + + signature = raw[: self.HMAC_SIZE] + data = raw[self.HMAC_SIZE :] + + if not hmac.compare_digest(signature, self._compute_hmac(data)): + raise ClassifierModelCorruptError + # Catch warnings for processing with warnings.catch_warnings(record=True) as w: - with Path(settings.MODEL_FILE).open("rb") as f: - schema_version = pickle.load(f) + try: + ( + schema_version, + self.last_doc_change_time, + self.last_auto_type_hash, + self.data_vectorizer, + self.tags_binarizer, + self.tags_classifier, + self.correspondent_classifier, + self.document_type_classifier, + self.storage_path_classifier, + ) = pickle.loads(data) + except Exception as err: + raise ClassifierModelCorruptError from err - if schema_version != self.FORMAT_VERSION: - raise IncompatibleClassifierVersionError( - "Cannot load classifier, incompatible versions.", - ) - else: - try: - self.last_doc_change_time = pickle.load(f) - self.last_auto_type_hash = pickle.load(f) - - self.data_vectorizer = pickle.load(f) - self._update_data_vectorizer_hash() - self.tags_binarizer = pickle.load(f) - - self.tags_classifier = pickle.load(f) - self.correspondent_classifier = pickle.load(f) - self.document_type_classifier = pickle.load(f) - self.storage_path_classifier = pickle.load(f) - except Exception as err: - raise ClassifierModelCorruptError from err - - # Check for the warning about unpickling from differing versions - # and consider it incompatible - sk_learn_warning_url = ( - "https://scikit-learn.org/stable/" - "model_persistence.html" - "#security-maintainability-limitations" + if schema_version != self.FORMAT_VERSION: + raise IncompatibleClassifierVersionError( + "Cannot load classifier, incompatible versions.", ) - for warning in w: - # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet - if issubclass(warning.category, InconsistentVersionWarning) or ( - issubclass(warning.category, UserWarning) - and sk_learn_warning_url in str(warning.message) - ): - raise IncompatibleClassifierVersionError("sklearn version update") + + self._update_data_vectorizer_hash() + + # Check for the warning about unpickling from differing versions + # and consider it incompatible + sk_learn_warning_url = ( + "https://scikit-learn.org/stable/" + "model_persistence.html" + "#security-maintainability-limitations" + ) + for warning in w: + # The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet + if issubclass(warning.category, InconsistentVersionWarning) or ( + issubclass(warning.category, UserWarning) + and sk_learn_warning_url in str(warning.message) + ): + raise IncompatibleClassifierVersionError("sklearn version update") def save(self) -> None: target_file: Path = settings.MODEL_FILE target_file_temp: Path = target_file.with_suffix(".pickle.part") + data = pickle.dumps( + ( + self.FORMAT_VERSION, + self.last_doc_change_time, + self.last_auto_type_hash, + self.data_vectorizer, + self.tags_binarizer, + self.tags_classifier, + self.correspondent_classifier, + self.document_type_classifier, + self.storage_path_classifier, + ), + ) + + signature = self._compute_hmac(data) + with target_file_temp.open("wb") as f: - pickle.dump(self.FORMAT_VERSION, f) - - pickle.dump(self.last_doc_change_time, f) - pickle.dump(self.last_auto_type_hash, f) - - pickle.dump(self.data_vectorizer, f) - - pickle.dump(self.tags_binarizer, f) - pickle.dump(self.tags_classifier, f) - - pickle.dump(self.correspondent_classifier, f) - pickle.dump(self.document_type_classifier, f) - pickle.dump(self.storage_path_classifier, f) + f.write(signature + data) target_file_temp.rename(target_file) diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index f04152ae0..a1bb9baa5 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,5 +1,6 @@ +import pickle import re -import shutil +import warnings from pathlib import Path from unittest import mock @@ -366,8 +367,7 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) - @mock.patch("documents.classifier.pickle.load") - def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock) -> None: + def test_load_corrupt_file(self) -> None: """ GIVEN: - Corrupted classifier pickle file @@ -378,36 +378,90 @@ class TestClassifier(DirectoriesMixin, TestCase): """ self.generate_train_and_save() - # First load is the schema version,allow it - patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] + # Write garbage data (valid HMAC length but invalid content) + Path(settings.MODEL_FILE).write_bytes(b"\x00" * 64) with self.assertRaises(ClassifierModelCorruptError): self.classifier.load() - patched_pickle_load.assert_called() - - patched_pickle_load.reset_mock() - patched_pickle_load.side_effect = [ - DocumentClassifier.FORMAT_VERSION, - ClassifierModelCorruptError(), - ] self.assertIsNone(load_classifier()) - patched_pickle_load.assert_called() + + def test_load_tampered_file(self) -> None: + """ + GIVEN: + - A classifier model file whose data has been modified + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised due to HMAC mismatch + """ + self.generate_train_and_save() + + raw = Path(settings.MODEL_FILE).read_bytes() + # Flip a byte in the data portion (after the 32-byte HMAC) + tampered = raw[:32] + bytes([raw[32] ^ 0xFF]) + raw[33:] + Path(settings.MODEL_FILE).write_bytes(tampered) + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + def test_load_wrong_secret_key(self) -> None: + """ + GIVEN: + - A classifier model file signed with a different SECRET_KEY + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised due to HMAC mismatch + """ + self.generate_train_and_save() + + with override_settings(SECRET_KEY="different-secret-key"): + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() + + def test_load_truncated_file(self) -> None: + """ + GIVEN: + - A classifier model file that is too short to contain an HMAC + WHEN: + - An attempt is made to load the classifier + THEN: + - The ClassifierModelCorruptError is raised + """ + Path(settings.MODEL_FILE).write_bytes(b"\x00" * 16) + + with self.assertRaises(ClassifierModelCorruptError): + self.classifier.load() def test_load_new_scikit_learn_version(self) -> None: """ GIVEN: - - classifier pickle file created with a different scikit-learn version + - classifier pickle file triggers an InconsistentVersionWarning WHEN: - An attempt is made to load the classifier THEN: - - The classifier reports the warning was captured and processed + - IncompatibleClassifierVersionError is raised """ - # TODO: This wasn't testing the warning anymore, as the schema changed - # but as it was implemented, it would require installing an old version - # rebuilding the file and committing that. Not developer friendly - # Need to rethink how to pass the load through to a file with a single - # old model? + from sklearn.exceptions import InconsistentVersionWarning + + self.generate_train_and_save() + + real_loads = pickle.loads + + def loads_with_warning(data): + warnings.warn( + "Trying to unpickle estimator from version 0.0 when using version 1.0.", + InconsistentVersionWarning, + ) + return real_loads(data) + + with mock.patch( + "documents.classifier.pickle.loads", + side_effect=loads_with_warning, + ): + with self.assertRaises(IncompatibleClassifierVersionError): + self.classifier.load() def test_one_correspondent_predict(self) -> None: c1 = Correspondent.objects.create( @@ -685,17 +739,6 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertIsNone(load_classifier()) self.assertTrue(Path(settings.MODEL_FILE).exists()) - def test_load_old_classifier_version(self) -> None: - shutil.copy( - Path(__file__).parent / "data" / "v1.17.4.model.pickle", - self.dirs.scratch_dir, - ) - with override_settings( - MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle", - ): - classifier = load_classifier() - self.assertIsNone(classifier) - @mock.patch("documents.classifier.DocumentClassifier.load") def test_load_classifier_raise_exception(self, mock_load) -> None: Path(settings.MODEL_FILE).touch()