mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-03 06:38:51 +00:00
Compare commits
6 Commits
feature-ta
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d365f19962 | ||
|
|
2703c12f1a | ||
|
|
e7c7978d67 | ||
|
|
83501757df | ||
|
|
dda05a7c00 | ||
|
|
376af81b9c |
@@ -237,8 +237,8 @@ RUN set -eux \
|
||||
&& echo "Adjusting all permissions" \
|
||||
&& chown --from root:root --changes --recursive paperless:paperless /usr/src/paperless \
|
||||
&& echo "Collecting static files" \
|
||||
&& s6-setuidgid paperless python3 manage.py collectstatic --clear --no-input --link \
|
||||
&& s6-setuidgid paperless python3 manage.py compilemessages \
|
||||
&& PAPERLESS_SECRET_KEY=build-time-dummy s6-setuidgid paperless python3 manage.py collectstatic --clear --no-input --link \
|
||||
&& PAPERLESS_SECRET_KEY=build-time-dummy s6-setuidgid paperless python3 manage.py compilemessages \
|
||||
&& /usr/local/bin/deduplicate.py --verbose /usr/src/paperless/static/
|
||||
|
||||
VOLUME ["/usr/src/paperless/data", \
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
# (if doing so please consider security measures such as reverse proxy)
|
||||
#PAPERLESS_URL=https://paperless.example.com
|
||||
|
||||
# Adjust this key if you plan to make paperless available publicly. It should
|
||||
# be a very long sequence of random characters. You don't need to remember it.
|
||||
#PAPERLESS_SECRET_KEY=change-me
|
||||
# Required. A unique secret key for session tokens and signing.
|
||||
# Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(64))"
|
||||
PAPERLESS_SECRET_KEY=change-me
|
||||
|
||||
# Use this variable to set a timezone for the Paperless Docker containers. Defaults to UTC.
|
||||
#PAPERLESS_TIME_ZONE=America/Los_Angeles
|
||||
|
||||
@@ -402,6 +402,12 @@ Defaults to `/usr/share/nltk_data`
|
||||
|
||||
: This is where paperless will store the classification model.
|
||||
|
||||
!!! warning
|
||||
|
||||
The classification model uses Python's pickle serialization format.
|
||||
Ensure this file is only writable by the paperless user, as a
|
||||
maliciously crafted model file could execute arbitrary code when loaded.
|
||||
|
||||
Defaults to `PAPERLESS_DATA_DIR/classification_model.pickle`.
|
||||
|
||||
## Logging
|
||||
@@ -422,14 +428,20 @@ Defaults to `/usr/share/nltk_data`
|
||||
|
||||
#### [`PAPERLESS_SECRET_KEY=<key>`](#PAPERLESS_SECRET_KEY) {#PAPERLESS_SECRET_KEY}
|
||||
|
||||
: Paperless uses this to make session tokens. If you expose paperless
|
||||
on the internet, you need to change this, since the default secret
|
||||
is well known.
|
||||
: **Required.** Paperless uses this to make session tokens and sign
|
||||
sensitive data. Paperless will refuse to start if this is not set.
|
||||
|
||||
Use any sequence of characters. The more, the better. You don't
|
||||
need to remember this. Just face-roll your keyboard.
|
||||
need to remember this. You can generate a suitable key with:
|
||||
|
||||
Default is listed in the file `src/paperless/settings.py`.
|
||||
python3 -c "import secrets; print(secrets.token_urlsafe(64))"
|
||||
|
||||
!!! warning
|
||||
|
||||
This setting has no default value. You **must** set it before
|
||||
starting Paperless. Existing installations that relied on the
|
||||
previous default value should set `PAPERLESS_SECRET_KEY` to
|
||||
that value to avoid invalidating existing sessions and tokens.
|
||||
|
||||
#### [`PAPERLESS_URL=<url>`](#PAPERLESS_URL) {#PAPERLESS_URL}
|
||||
|
||||
@@ -770,6 +782,14 @@ If both the [PAPERLESS_ACCOUNT_DEFAULT_GROUPS](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS
|
||||
|
||||
Defaults to 1209600 (2 weeks)
|
||||
|
||||
#### [`PAPERLESS_TOKEN_THROTTLE_RATE=<rate>`](#PAPERLESS_TOKEN_THROTTLE_RATE) {#PAPERLESS_TOKEN_THROTTLE_RATE}
|
||||
|
||||
: Rate limit for the API token authentication endpoint (`/api/token/`), used to mitigate brute-force login attempts.
|
||||
Uses Django REST Framework's [throttle rate format](https://www.django-rest-framework.org/api-guide/throttling/#setting-the-throttling-policy),
|
||||
e.g. `5/min`, `100/hour`, `1000/day`.
|
||||
|
||||
Defaults to `5/min`
|
||||
|
||||
## OCR settings {#ocr}
|
||||
|
||||
Paperless uses [OCRmyPDF](https://ocrmypdf.readthedocs.io/en/latest/)
|
||||
@@ -1420,6 +1440,14 @@ ports.
|
||||
|
||||
## Incoming Mail {#incoming_mail}
|
||||
|
||||
#### [`PAPERLESS_EMAIL_ALLOW_INTERNAL_HOSTS=<bool>`](#PAPERLESS_EMAIL_ALLOW_INTERNAL_HOSTS) {#PAPERLESS_EMAIL_ALLOW_INTERNAL_HOSTS}
|
||||
|
||||
: If set to false, incoming mail account connections are blocked when the
|
||||
configured IMAP hostname resolves to a non-public address (for example,
|
||||
localhost, link-local, or RFC1918 private ranges).
|
||||
|
||||
Defaults to true, which allows internal hosts.
|
||||
|
||||
### Email OAuth {#email_oauth}
|
||||
|
||||
#### [`PAPERLESS_OAUTH_CALLBACK_BASE_URL=<str>`](#PAPERLESS_OAUTH_CALLBACK_BASE_URL) {#PAPERLESS_OAUTH_CALLBACK_BASE_URL}
|
||||
|
||||
@@ -23,7 +23,8 @@
|
||||
|
||||
# Security and hosting
|
||||
|
||||
#PAPERLESS_SECRET_KEY=change-me
|
||||
# Required. Generate with: python3 -c "import secrets; print(secrets.token_urlsafe(64))"
|
||||
PAPERLESS_SECRET_KEY=change-me
|
||||
#PAPERLESS_URL=https://example.com
|
||||
#PAPERLESS_CSRF_TRUSTED_ORIGINS=https://example.com # can be set using PAPERLESS_URL
|
||||
#PAPERLESS_ALLOWED_HOSTS=example.com,www.example.com # can be set using PAPERLESS_URL
|
||||
|
||||
@@ -315,9 +315,12 @@ markers = [
|
||||
]
|
||||
|
||||
[tool.pytest_env]
|
||||
PAPERLESS_SECRET_KEY = "test-secret-key-do-not-use-in-production"
|
||||
PAPERLESS_DISABLE_DBHANDLER = "true"
|
||||
PAPERLESS_CACHE_BACKEND = "django.core.cache.backends.locmem.LocMemCache"
|
||||
PAPERLESS_CHANNELS_BACKEND = "channels.layers.InMemoryChannelLayer"
|
||||
# I don't think anything hits this, but just in case, basically infinite
|
||||
PAPERLESS_TOKEN_THROTTLE_RATE = "1000/min"
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_also = [
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import regex as regex_mod
|
||||
from django.conf import settings
|
||||
from pdf2image import convert_from_path
|
||||
from pikepdf import Page
|
||||
@@ -22,6 +23,8 @@ from documents.plugins.base import ConsumeTaskPlugin
|
||||
from documents.plugins.base import StopConsumeTaskError
|
||||
from documents.plugins.helpers import ProgressManager
|
||||
from documents.plugins.helpers import ProgressStatusOptions
|
||||
from documents.regex import safe_regex_match
|
||||
from documents.regex import safe_regex_sub
|
||||
from documents.utils import copy_basic_file_stats
|
||||
from documents.utils import copy_file_with_basic_stats
|
||||
from documents.utils import maybe_override_pixel_limit
|
||||
@@ -68,8 +71,8 @@ class Barcode:
|
||||
Note: This does NOT exclude ASN or separator barcodes - they can also be used
|
||||
as tags if they match a tag mapping pattern (e.g., {"ASN12.*": "JOHN"}).
|
||||
"""
|
||||
for regex in self.settings.barcode_tag_mapping:
|
||||
if re.match(regex, self.value, flags=re.IGNORECASE):
|
||||
for pattern in self.settings.barcode_tag_mapping:
|
||||
if safe_regex_match(pattern, self.value, flags=regex_mod.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -392,11 +395,16 @@ class BarcodePlugin(ConsumeTaskPlugin):
|
||||
for raw in tag_texts.split(","):
|
||||
try:
|
||||
tag_str: str | None = None
|
||||
for regex in self.settings.barcode_tag_mapping:
|
||||
if re.match(regex, raw, flags=re.IGNORECASE):
|
||||
sub = self.settings.barcode_tag_mapping[regex]
|
||||
for pattern in self.settings.barcode_tag_mapping:
|
||||
if safe_regex_match(pattern, raw, flags=regex_mod.IGNORECASE):
|
||||
sub = self.settings.barcode_tag_mapping[pattern]
|
||||
tag_str = (
|
||||
re.sub(regex, sub, raw, flags=re.IGNORECASE)
|
||||
safe_regex_sub(
|
||||
pattern,
|
||||
sub,
|
||||
raw,
|
||||
flags=regex_mod.IGNORECASE,
|
||||
)
|
||||
if sub
|
||||
else raw
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import pickle
|
||||
import re
|
||||
@@ -75,7 +76,7 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
|
||||
"Unrecoverable error while loading document "
|
||||
"classification model, deleting model file.",
|
||||
)
|
||||
Path(settings.MODEL_FILE).unlink
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
classifier = None
|
||||
if raise_exception:
|
||||
raise e
|
||||
@@ -97,7 +98,10 @@ class DocumentClassifier:
|
||||
# v7 - Updated scikit-learn package version
|
||||
# v8 - Added storage path classifier
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
FORMAT_VERSION = 9
|
||||
# v10 - HMAC-signed model file
|
||||
FORMAT_VERSION = 10
|
||||
|
||||
HMAC_SIZE = 32 # SHA-256 digest length
|
||||
|
||||
def __init__(self) -> None:
|
||||
# last time a document changed and therefore training might be required
|
||||
@@ -128,67 +132,89 @@ class DocumentClassifier:
|
||||
pickle.dumps(self.data_vectorizer),
|
||||
).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_hmac(data: bytes) -> bytes:
|
||||
return hmac.new(
|
||||
settings.SECRET_KEY.encode(),
|
||||
data,
|
||||
sha256,
|
||||
).digest()
|
||||
|
||||
def load(self) -> None:
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
raw = Path(settings.MODEL_FILE).read_bytes()
|
||||
|
||||
if len(raw) <= self.HMAC_SIZE:
|
||||
raise ClassifierModelCorruptError
|
||||
|
||||
signature = raw[: self.HMAC_SIZE]
|
||||
data = raw[self.HMAC_SIZE :]
|
||||
|
||||
if not hmac.compare_digest(signature, self._compute_hmac(data)):
|
||||
raise ClassifierModelCorruptError
|
||||
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
try:
|
||||
(
|
||||
schema_version,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
self.correspondent_classifier,
|
||||
self.document_type_classifier,
|
||||
self.storage_path_classifier,
|
||||
) = pickle.loads(data)
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.last_doc_change_time = pickle.load(f)
|
||||
self.last_auto_type_hash = pickle.load(f)
|
||||
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self._update_data_vectorizer_hash()
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
|
||||
self.tags_classifier = pickle.load(f)
|
||||
self.correspondent_classifier = pickle.load(f)
|
||||
self.document_type_classifier = pickle.load(f)
|
||||
self.storage_path_classifier = pickle.load(f)
|
||||
except Exception as err:
|
||||
raise ClassifierModelCorruptError from err
|
||||
|
||||
# Check for the warning about unpickling from differing versions
|
||||
# and consider it incompatible
|
||||
sk_learn_warning_url = (
|
||||
"https://scikit-learn.org/stable/"
|
||||
"model_persistence.html"
|
||||
"#security-maintainability-limitations"
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannot load classifier, incompatible versions.",
|
||||
)
|
||||
for warning in w:
|
||||
# The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
|
||||
if issubclass(warning.category, InconsistentVersionWarning) or (
|
||||
issubclass(warning.category, UserWarning)
|
||||
and sk_learn_warning_url in str(warning.message)
|
||||
):
|
||||
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||
|
||||
self._update_data_vectorizer_hash()
|
||||
|
||||
# Check for the warning about unpickling from differing versions
|
||||
# and consider it incompatible
|
||||
sk_learn_warning_url = (
|
||||
"https://scikit-learn.org/stable/"
|
||||
"model_persistence.html"
|
||||
"#security-maintainability-limitations"
|
||||
)
|
||||
for warning in w:
|
||||
# The warning is inconsistent, the MLPClassifier is a specific warning, others have not updated yet
|
||||
if issubclass(warning.category, InconsistentVersionWarning) or (
|
||||
issubclass(warning.category, UserWarning)
|
||||
and sk_learn_warning_url in str(warning.message)
|
||||
):
|
||||
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||
|
||||
def save(self) -> None:
|
||||
target_file: Path = settings.MODEL_FILE
|
||||
target_file_temp: Path = target_file.with_suffix(".pickle.part")
|
||||
|
||||
data = pickle.dumps(
|
||||
(
|
||||
self.FORMAT_VERSION,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
self.correspondent_classifier,
|
||||
self.document_type_classifier,
|
||||
self.storage_path_classifier,
|
||||
),
|
||||
)
|
||||
|
||||
signature = self._compute_hmac(data)
|
||||
|
||||
with target_file_temp.open("wb") as f:
|
||||
pickle.dump(self.FORMAT_VERSION, f)
|
||||
|
||||
pickle.dump(self.last_doc_change_time, f)
|
||||
pickle.dump(self.last_auto_type_hash, f)
|
||||
|
||||
pickle.dump(self.data_vectorizer, f)
|
||||
|
||||
pickle.dump(self.tags_binarizer, f)
|
||||
pickle.dump(self.tags_classifier, f)
|
||||
|
||||
pickle.dump(self.correspondent_classifier, f)
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
pickle.dump(self.storage_path_classifier, f)
|
||||
f.write(signature + data)
|
||||
|
||||
target_file_temp.rename(target_file)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import datetime
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from re import Match
|
||||
|
||||
import regex
|
||||
from regex import Match
|
||||
|
||||
from documents.plugins.date_parsing.base import DateParserPluginBase
|
||||
from documents.regex import safe_regex_finditer
|
||||
|
||||
|
||||
class RegexDateParserPlugin(DateParserPluginBase):
|
||||
@@ -14,7 +16,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
|
||||
passed to its constructor.
|
||||
"""
|
||||
|
||||
DATE_REGEX = re.compile(
|
||||
DATE_REGEX = regex.compile(
|
||||
r"(\b|(?!=([_-])))(\d{1,2})[\.\/-](\d{1,2})[\.\/-](\d{4}|\d{2})(\b|(?=([_-])))|"
|
||||
r"(\b|(?!=([_-])))(\d{4}|\d{2})[\.\/-](\d{1,2})[\.\/-](\d{1,2})(\b|(?=([_-])))|"
|
||||
r"(\b|(?!=([_-])))(\d{1,2}[\. ]+[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{4}|[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{1,2}, \d{4})(\b|(?=([_-])))|"
|
||||
@@ -22,7 +24,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
|
||||
r"(\b|(?!=([_-])))([^\W\d_]{3,9} \d{4})(\b|(?=([_-])))|"
|
||||
r"(\b|(?!=([_-])))(\d{1,2}[^ 0-9]{2}[\. ]+[^ ]{3,9}[ \.\/-]\d{4})(\b|(?=([_-])))|"
|
||||
r"(\b|(?!=([_-])))(\b\d{1,2}[ \.\/-][a-zéûäëčžúřěáíóńźçŞğü]{3}[ \.\/-]\d{4})(\b|(?=([_-])))",
|
||||
re.IGNORECASE,
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
def _process_match(
|
||||
@@ -45,7 +47,7 @@ class RegexDateParserPlugin(DateParserPluginBase):
|
||||
"""
|
||||
Finds all regex matches in content and yields valid dates.
|
||||
"""
|
||||
for m in re.finditer(self.DATE_REGEX, content):
|
||||
for m in safe_regex_finditer(self.DATE_REGEX, content):
|
||||
date = self._process_match(m, date_order)
|
||||
if date is not None:
|
||||
yield date
|
||||
|
||||
@@ -48,3 +48,73 @@ def safe_regex_search(pattern: str, text: str, *, flags: int = 0):
|
||||
textwrap.shorten(pattern, width=80, placeholder="…"),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def safe_regex_match(pattern: str, text: str, *, flags: int = 0):
|
||||
"""
|
||||
Run a regex match with a timeout. Returns a match object or None.
|
||||
Validation errors and timeouts are logged and treated as no match.
|
||||
"""
|
||||
|
||||
try:
|
||||
validate_regex_pattern(pattern)
|
||||
compiled = regex.compile(pattern, flags=flags)
|
||||
except (regex.error, ValueError) as exc:
|
||||
logger.error(
|
||||
"Error while processing regular expression %s: %s",
|
||||
textwrap.shorten(pattern, width=80, placeholder="…"),
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
return compiled.match(text, timeout=REGEX_TIMEOUT_SECONDS)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regular expression matching timed out for pattern %s",
|
||||
textwrap.shorten(pattern, width=80, placeholder="…"),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def safe_regex_sub(pattern: str, repl: str, text: str, *, flags: int = 0) -> str | None:
|
||||
"""
|
||||
Run a regex substitution with a timeout. Returns the substituted string,
|
||||
or None on error/timeout.
|
||||
"""
|
||||
|
||||
try:
|
||||
validate_regex_pattern(pattern)
|
||||
compiled = regex.compile(pattern, flags=flags)
|
||||
except (regex.error, ValueError) as exc:
|
||||
logger.error(
|
||||
"Error while processing regular expression %s: %s",
|
||||
textwrap.shorten(pattern, width=80, placeholder="…"),
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
return compiled.sub(repl, text, timeout=REGEX_TIMEOUT_SECONDS)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regular expression substitution timed out for pattern %s",
|
||||
textwrap.shorten(pattern, width=80, placeholder="…"),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def safe_regex_finditer(compiled_pattern: regex.Pattern, text: str):
|
||||
"""
|
||||
Run regex finditer with a timeout. Yields match objects.
|
||||
Stops iteration on timeout.
|
||||
"""
|
||||
|
||||
try:
|
||||
yield from compiled_pattern.finditer(text, timeout=REGEX_TIMEOUT_SECONDS)
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regular expression finditer timed out for pattern %s",
|
||||
textwrap.shorten(compiled_pattern.pattern, width=80, placeholder="…"),
|
||||
)
|
||||
return
|
||||
|
||||
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
@@ -366,8 +366,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
|
||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||
|
||||
@mock.patch("documents.classifier.pickle.load")
|
||||
def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock) -> None:
|
||||
def test_load_corrupt_file(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Corrupted classifier pickle file
|
||||
@@ -378,36 +377,116 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
"""
|
||||
self.generate_train_and_save()
|
||||
|
||||
# First load is the schema version,allow it
|
||||
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
|
||||
# Write garbage data (valid HMAC length but invalid content)
|
||||
Path(settings.MODEL_FILE).write_bytes(b"\x00" * 64)
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
patched_pickle_load.assert_called()
|
||||
|
||||
patched_pickle_load.reset_mock()
|
||||
patched_pickle_load.side_effect = [
|
||||
DocumentClassifier.FORMAT_VERSION,
|
||||
ClassifierModelCorruptError(),
|
||||
]
|
||||
|
||||
self.assertIsNone(load_classifier())
|
||||
patched_pickle_load.assert_called()
|
||||
|
||||
def test_load_corrupt_pickle_valid_hmac(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- A classifier file with valid HMAC but unparsable pickle data
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The ClassifierModelCorruptError is raised
|
||||
"""
|
||||
garbage_data = b"this is not valid pickle data"
|
||||
signature = DocumentClassifier._compute_hmac(garbage_data)
|
||||
Path(settings.MODEL_FILE).write_bytes(signature + garbage_data)
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_load_tampered_file(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- A classifier model file whose data has been modified
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The ClassifierModelCorruptError is raised due to HMAC mismatch
|
||||
"""
|
||||
self.generate_train_and_save()
|
||||
|
||||
raw = Path(settings.MODEL_FILE).read_bytes()
|
||||
# Flip a byte in the data portion (after the 32-byte HMAC)
|
||||
tampered = raw[:32] + bytes([raw[32] ^ 0xFF]) + raw[33:]
|
||||
Path(settings.MODEL_FILE).write_bytes(tampered)
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_load_wrong_secret_key(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- A classifier model file signed with a different SECRET_KEY
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The ClassifierModelCorruptError is raised due to HMAC mismatch
|
||||
"""
|
||||
self.generate_train_and_save()
|
||||
|
||||
with override_settings(SECRET_KEY="different-secret-key"):
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_load_truncated_file(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- A classifier model file that is too short to contain an HMAC
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The ClassifierModelCorruptError is raised
|
||||
"""
|
||||
Path(settings.MODEL_FILE).write_bytes(b"\x00" * 16)
|
||||
|
||||
with self.assertRaises(ClassifierModelCorruptError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_load_new_scikit_learn_version(self) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- classifier pickle file created with a different scikit-learn version
|
||||
- classifier pickle file triggers an InconsistentVersionWarning
|
||||
WHEN:
|
||||
- An attempt is made to load the classifier
|
||||
THEN:
|
||||
- The classifier reports the warning was captured and processed
|
||||
- IncompatibleClassifierVersionError is raised
|
||||
"""
|
||||
# TODO: This wasn't testing the warning anymore, as the schema changed
|
||||
# but as it was implemented, it would require installing an old version
|
||||
# rebuilding the file and committing that. Not developer friendly
|
||||
# Need to rethink how to pass the load through to a file with a single
|
||||
# old model?
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
self.generate_train_and_save()
|
||||
|
||||
fake_warning = warnings.WarningMessage(
|
||||
message=InconsistentVersionWarning(
|
||||
estimator_name="MLPClassifier",
|
||||
current_sklearn_version="1.0",
|
||||
original_sklearn_version="0.9",
|
||||
),
|
||||
category=InconsistentVersionWarning,
|
||||
filename="",
|
||||
lineno=0,
|
||||
)
|
||||
|
||||
real_catch_warnings = warnings.catch_warnings
|
||||
|
||||
class PatchedCatchWarnings(real_catch_warnings):
|
||||
def __enter__(self):
|
||||
w = super().__enter__()
|
||||
w.append(fake_warning)
|
||||
return w
|
||||
|
||||
with mock.patch(
|
||||
"documents.classifier.warnings.catch_warnings",
|
||||
PatchedCatchWarnings,
|
||||
):
|
||||
with self.assertRaises(IncompatibleClassifierVersionError):
|
||||
self.classifier.load()
|
||||
|
||||
def test_one_correspondent_predict(self) -> None:
|
||||
c1 = Correspondent.objects.create(
|
||||
@@ -685,17 +764,6 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
||||
self.assertIsNone(load_classifier())
|
||||
self.assertTrue(Path(settings.MODEL_FILE).exists())
|
||||
|
||||
def test_load_old_classifier_version(self) -> None:
|
||||
shutil.copy(
|
||||
Path(__file__).parent / "data" / "v1.17.4.model.pickle",
|
||||
self.dirs.scratch_dir,
|
||||
)
|
||||
with override_settings(
|
||||
MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle",
|
||||
):
|
||||
classifier = load_classifier()
|
||||
self.assertIsNone(classifier)
|
||||
|
||||
@mock.patch("documents.classifier.DocumentClassifier.load")
|
||||
def test_load_classifier_raise_exception(self, mock_load) -> None:
|
||||
Path(settings.MODEL_FILE).touch()
|
||||
|
||||
128
src/documents/tests/test_regex.py
Normal file
128
src/documents/tests/test_regex.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import pytest
|
||||
import regex
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from documents.regex import safe_regex_finditer
|
||||
from documents.regex import safe_regex_match
|
||||
from documents.regex import safe_regex_search
|
||||
from documents.regex import safe_regex_sub
|
||||
from documents.regex import validate_regex_pattern
|
||||
|
||||
|
||||
class TestValidateRegexPattern:
|
||||
def test_valid_pattern(self):
|
||||
validate_regex_pattern(r"\d+")
|
||||
|
||||
def test_invalid_pattern_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
validate_regex_pattern(r"[invalid")
|
||||
|
||||
|
||||
class TestSafeRegexSearchAndMatch:
|
||||
"""Tests for safe_regex_search and safe_regex_match (same contract)."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("func", "pattern", "text", "expected_group"),
|
||||
[
|
||||
pytest.param(
|
||||
safe_regex_search,
|
||||
r"\d+",
|
||||
"abc123def",
|
||||
"123",
|
||||
id="search-match-found",
|
||||
),
|
||||
pytest.param(
|
||||
safe_regex_match,
|
||||
r"\d+",
|
||||
"123abc",
|
||||
"123",
|
||||
id="match-match-found",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_found(self, func, pattern, text, expected_group):
|
||||
result = func(pattern, text)
|
||||
assert result is not None
|
||||
assert result.group() == expected_group
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("func", "pattern", "text"),
|
||||
[
|
||||
pytest.param(safe_regex_search, r"\d+", "abcdef", id="search-no-match"),
|
||||
pytest.param(safe_regex_match, r"\d+", "abc123", id="match-no-match"),
|
||||
],
|
||||
)
|
||||
def test_no_match(self, func, pattern, text):
|
||||
assert func(pattern, text) is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
pytest.param(safe_regex_search, id="search"),
|
||||
pytest.param(safe_regex_match, id="match"),
|
||||
],
|
||||
)
|
||||
def test_invalid_pattern_returns_none(self, func):
|
||||
assert func(r"[invalid", "test") is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
[
|
||||
pytest.param(safe_regex_search, id="search"),
|
||||
pytest.param(safe_regex_match, id="match"),
|
||||
],
|
||||
)
|
||||
def test_flags_respected(self, func):
|
||||
assert func(r"abc", "ABC", flags=regex.IGNORECASE) is not None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("func", "method_name"),
|
||||
[
|
||||
pytest.param(safe_regex_search, "search", id="search"),
|
||||
pytest.param(safe_regex_match, "match", id="match"),
|
||||
],
|
||||
)
|
||||
def test_timeout_returns_none(self, func, method_name, mocker: MockerFixture):
|
||||
mock_compile = mocker.patch("documents.regex.regex.compile")
|
||||
getattr(mock_compile.return_value, method_name).side_effect = TimeoutError
|
||||
assert func(r"\d+", "test") is None
|
||||
|
||||
|
||||
class TestSafeRegexSub:
|
||||
@pytest.mark.parametrize(
|
||||
("pattern", "repl", "text", "expected"),
|
||||
[
|
||||
pytest.param(r"\d+", "NUM", "abc123def456", "abcNUMdefNUM", id="basic-sub"),
|
||||
pytest.param(r"\d+", "NUM", "abcdef", "abcdef", id="no-match"),
|
||||
pytest.param(r"abc", "X", "ABC", "X", id="flags"),
|
||||
],
|
||||
)
|
||||
def test_substitution(self, pattern, repl, text, expected):
|
||||
flags = regex.IGNORECASE if pattern == r"abc" else 0
|
||||
result = safe_regex_sub(pattern, repl, text, flags=flags)
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_pattern_returns_none(self):
|
||||
assert safe_regex_sub(r"[invalid", "x", "test") is None
|
||||
|
||||
def test_timeout_returns_none(self, mocker: MockerFixture):
|
||||
mock_compile = mocker.patch("documents.regex.regex.compile")
|
||||
mock_compile.return_value.sub.side_effect = TimeoutError
|
||||
assert safe_regex_sub(r"\d+", "X", "test") is None
|
||||
|
||||
|
||||
class TestSafeRegexFinditer:
|
||||
def test_yields_matches(self):
|
||||
pattern = regex.compile(r"\d+")
|
||||
matches = list(safe_regex_finditer(pattern, "a1b22c333"))
|
||||
assert [m.group() for m in matches] == ["1", "22", "333"]
|
||||
|
||||
def test_no_matches(self):
|
||||
pattern = regex.compile(r"\d+")
|
||||
assert list(safe_regex_finditer(pattern, "abcdef")) == []
|
||||
|
||||
def test_timeout_stops_iteration(self, mocker: MockerFixture):
|
||||
mock_pattern = mocker.MagicMock()
|
||||
mock_pattern.finditer.side_effect = TimeoutError
|
||||
mock_pattern.pattern = r"\d+"
|
||||
assert list(safe_regex_finditer(mock_pattern, "test")) == []
|
||||
@@ -31,6 +31,11 @@ from paperless.models import ApplicationConfiguration
|
||||
|
||||
|
||||
class TestViews(DirectoriesMixin, TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls) -> None:
|
||||
super().setUpTestData()
|
||||
ApplicationConfiguration.objects.get_or_create()
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = User.objects.create_user("testuser")
|
||||
super().setUp()
|
||||
|
||||
@@ -2,7 +2,7 @@ msgid ""
|
||||
msgstr ""
|
||||
"Project-Id-Version: paperless-ngx\n"
|
||||
"Report-Msgid-Bugs-To: \n"
|
||||
"POT-Creation-Date: 2026-04-02 19:39+0000\n"
|
||||
"POT-Creation-Date: 2026-04-03 03:25+0000\n"
|
||||
"PO-Revision-Date: 2022-02-17 04:17\n"
|
||||
"Last-Translator: \n"
|
||||
"Language-Team: English\n"
|
||||
@@ -1866,151 +1866,151 @@ msgstr ""
|
||||
msgid "paperless application settings"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:524
|
||||
#: paperless/settings/__init__.py:532
|
||||
msgid "English (US)"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:525
|
||||
#: paperless/settings/__init__.py:533
|
||||
msgid "Arabic"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:526
|
||||
#: paperless/settings/__init__.py:534
|
||||
msgid "Afrikaans"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:527
|
||||
#: paperless/settings/__init__.py:535
|
||||
msgid "Belarusian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:528
|
||||
#: paperless/settings/__init__.py:536
|
||||
msgid "Bulgarian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:529
|
||||
#: paperless/settings/__init__.py:537
|
||||
msgid "Catalan"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:530
|
||||
#: paperless/settings/__init__.py:538
|
||||
msgid "Czech"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:531
|
||||
#: paperless/settings/__init__.py:539
|
||||
msgid "Danish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:532
|
||||
#: paperless/settings/__init__.py:540
|
||||
msgid "German"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:533
|
||||
#: paperless/settings/__init__.py:541
|
||||
msgid "Greek"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:534
|
||||
#: paperless/settings/__init__.py:542
|
||||
msgid "English (GB)"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:535
|
||||
#: paperless/settings/__init__.py:543
|
||||
msgid "Spanish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:536
|
||||
#: paperless/settings/__init__.py:544
|
||||
msgid "Persian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:537
|
||||
#: paperless/settings/__init__.py:545
|
||||
msgid "Finnish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:538
|
||||
#: paperless/settings/__init__.py:546
|
||||
msgid "French"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:539
|
||||
#: paperless/settings/__init__.py:547
|
||||
msgid "Hungarian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:540
|
||||
#: paperless/settings/__init__.py:548
|
||||
msgid "Indonesian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:541
|
||||
#: paperless/settings/__init__.py:549
|
||||
msgid "Italian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:542
|
||||
#: paperless/settings/__init__.py:550
|
||||
msgid "Japanese"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:543
|
||||
#: paperless/settings/__init__.py:551
|
||||
msgid "Korean"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:544
|
||||
#: paperless/settings/__init__.py:552
|
||||
msgid "Luxembourgish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:545
|
||||
#: paperless/settings/__init__.py:553
|
||||
msgid "Norwegian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:546
|
||||
#: paperless/settings/__init__.py:554
|
||||
msgid "Dutch"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:547
|
||||
#: paperless/settings/__init__.py:555
|
||||
msgid "Polish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:548
|
||||
#: paperless/settings/__init__.py:556
|
||||
msgid "Portuguese (Brazil)"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:549
|
||||
#: paperless/settings/__init__.py:557
|
||||
msgid "Portuguese"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:550
|
||||
#: paperless/settings/__init__.py:558
|
||||
msgid "Romanian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:551
|
||||
#: paperless/settings/__init__.py:559
|
||||
msgid "Russian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:552
|
||||
#: paperless/settings/__init__.py:560
|
||||
msgid "Slovak"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:553
|
||||
#: paperless/settings/__init__.py:561
|
||||
msgid "Slovenian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:554
|
||||
#: paperless/settings/__init__.py:562
|
||||
msgid "Serbian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:555
|
||||
#: paperless/settings/__init__.py:563
|
||||
msgid "Swedish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:556
|
||||
#: paperless/settings/__init__.py:564
|
||||
msgid "Turkish"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:557
|
||||
#: paperless/settings/__init__.py:565
|
||||
msgid "Ukrainian"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:558
|
||||
#: paperless/settings/__init__.py:566
|
||||
msgid "Vietnamese"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:559
|
||||
#: paperless/settings/__init__.py:567
|
||||
msgid "Chinese Simplified"
|
||||
msgstr ""
|
||||
|
||||
#: paperless/settings/__init__.py:560
|
||||
#: paperless/settings/__init__.py:568
|
||||
msgid "Chinese Traditional"
|
||||
msgstr ""
|
||||
|
||||
|
||||
@@ -1,11 +1,59 @@
|
||||
import hmac
|
||||
import os
|
||||
import pickle
|
||||
from hashlib import sha256
|
||||
|
||||
from celery import Celery
|
||||
from celery.signals import worker_process_init
|
||||
from kombu.serialization import register
|
||||
|
||||
# Set the default Django settings module for the 'celery' program.
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "paperless.settings")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signed-pickle serializer: pickle with HMAC-SHA256 integrity verification.
|
||||
#
|
||||
# Protects against malicious pickle injection via an exposed Redis broker.
|
||||
# Messages are signed on the producer side and verified before deserialization
|
||||
# on the worker side using Django's SECRET_KEY.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HMAC_SIZE = 32 # SHA-256 digest length
|
||||
|
||||
|
||||
def _get_signing_key() -> bytes:
|
||||
from django.conf import settings
|
||||
|
||||
return settings.SECRET_KEY.encode()
|
||||
|
||||
|
||||
def signed_pickle_dumps(obj: object) -> bytes:
|
||||
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
signature = hmac.new(_get_signing_key(), data, sha256).digest()
|
||||
return signature + data
|
||||
|
||||
|
||||
def signed_pickle_loads(payload: bytes) -> object:
|
||||
if len(payload) < HMAC_SIZE:
|
||||
msg = "Signed-pickle payload too short"
|
||||
raise ValueError(msg)
|
||||
signature = payload[:HMAC_SIZE]
|
||||
data = payload[HMAC_SIZE:]
|
||||
expected = hmac.new(_get_signing_key(), data, sha256).digest()
|
||||
if not hmac.compare_digest(signature, expected):
|
||||
msg = "Signed-pickle HMAC verification failed — message may have been tampered with"
|
||||
raise ValueError(msg)
|
||||
return pickle.loads(data)
|
||||
|
||||
|
||||
register(
|
||||
"signed-pickle",
|
||||
signed_pickle_dumps,
|
||||
signed_pickle_loads,
|
||||
content_type="application/x-signed-pickle",
|
||||
content_encoding="binary",
|
||||
)
|
||||
|
||||
app = Celery("paperless")
|
||||
|
||||
# Using a string here means the worker doesn't have to serialize
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Final
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from compression_middleware.middleware import CompressionMiddleware
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -161,6 +162,9 @@ REST_FRAMEWORK = {
|
||||
"ALLOWED_VERSIONS": ["9", "10"],
|
||||
# DRF Spectacular default schema
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
"DEFAULT_THROTTLE_RATES": {
|
||||
"login": os.getenv("PAPERLESS_TOKEN_THROTTLE_RATE", "5/min"),
|
||||
},
|
||||
}
|
||||
|
||||
if DEBUG:
|
||||
@@ -460,13 +464,13 @@ SECURE_PROXY_SSL_HEADER = (
|
||||
else None
|
||||
)
|
||||
|
||||
# The secret key has a default that should be fine so long as you're hosting
|
||||
# Paperless on a closed network. However, if you're putting this anywhere
|
||||
# public, you should change the key to something unique and verbose.
|
||||
SECRET_KEY = os.getenv(
|
||||
"PAPERLESS_SECRET_KEY",
|
||||
"e11fl1oa-*ytql8p)(06fbj4ukrlo+n7k&q5+$1md7i+mge=ee",
|
||||
)
|
||||
SECRET_KEY = os.getenv("PAPERLESS_SECRET_KEY", "")
|
||||
if not SECRET_KEY: # pragma: no cover
|
||||
raise ImproperlyConfigured(
|
||||
"PAPERLESS_SECRET_KEY is not set. "
|
||||
"A unique, secret key is required for secure operation. "
|
||||
'Generate one with: python3 -c "import secrets; print(secrets.token_urlsafe(64))"',
|
||||
)
|
||||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
@@ -497,6 +501,10 @@ SESSION_COOKIE_NAME = f"{COOKIE_PREFIX}sessionid"
|
||||
LANGUAGE_COOKIE_NAME = f"{COOKIE_PREFIX}django_language"
|
||||
|
||||
EMAIL_CERTIFICATE_FILE = get_path_from_env("PAPERLESS_EMAIL_CERTIFICATE_LOCATION")
|
||||
EMAIL_ALLOW_INTERNAL_HOSTS = get_bool_from_env(
|
||||
"PAPERLESS_EMAIL_ALLOW_INTERNAL_HOSTS",
|
||||
"true",
|
||||
)
|
||||
|
||||
|
||||
###############################################################################
|
||||
@@ -667,9 +675,11 @@ CELERY_RESULT_BACKEND = "django-db"
|
||||
CELERY_CACHE_BACKEND = "default"
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#task-serializer
|
||||
CELERY_TASK_SERIALIZER = "pickle"
|
||||
# Uses HMAC-signed pickle to prevent RCE via malicious messages on an exposed Redis broker.
|
||||
# The signed-pickle serializer is registered in paperless/celery.py.
|
||||
CELERY_TASK_SERIALIZER = "signed-pickle"
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-accept_content
|
||||
CELERY_ACCEPT_CONTENT = ["application/json", "application/x-python-serialize"]
|
||||
CELERY_ACCEPT_CONTENT = ["application/json", "application/x-signed-pickle"]
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#beat-schedule
|
||||
CELERY_BEAT_SCHEDULE = parse_beat_schedule()
|
||||
|
||||
69
src/paperless/tests/test_celery.py
Normal file
69
src/paperless/tests/test_celery.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import hmac
|
||||
import pickle
|
||||
from hashlib import sha256
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from paperless.celery import HMAC_SIZE
|
||||
from paperless.celery import signed_pickle_dumps
|
||||
from paperless.celery import signed_pickle_loads
|
||||
|
||||
|
||||
class TestSignedPickleSerializer:
|
||||
def test_roundtrip_simple_types(self):
|
||||
"""Signed pickle can round-trip basic JSON-like types."""
|
||||
for obj in [42, "hello", [1, 2, 3], {"key": "value"}, None, True]:
|
||||
assert signed_pickle_loads(signed_pickle_dumps(obj)) == obj
|
||||
|
||||
def test_roundtrip_complex_types(self):
|
||||
"""Signed pickle can round-trip types that JSON cannot."""
|
||||
from pathlib import Path
|
||||
|
||||
obj = {"path": Path("/tmp/test"), "data": {1, 2, 3}}
|
||||
result = signed_pickle_loads(signed_pickle_dumps(obj))
|
||||
assert result["path"] == Path("/tmp/test")
|
||||
assert result["data"] == {1, 2, 3}
|
||||
|
||||
def test_tampered_data_rejected(self):
|
||||
"""Flipping a byte in the data portion causes HMAC failure."""
|
||||
payload = signed_pickle_dumps({"task": "test"})
|
||||
tampered = bytearray(payload)
|
||||
tampered[-1] ^= 0xFF
|
||||
with pytest.raises(ValueError, match="HMAC verification failed"):
|
||||
signed_pickle_loads(bytes(tampered))
|
||||
|
||||
def test_tampered_signature_rejected(self):
|
||||
"""Flipping a byte in the signature portion causes HMAC failure."""
|
||||
payload = signed_pickle_dumps({"task": "test"})
|
||||
tampered = bytearray(payload)
|
||||
tampered[0] ^= 0xFF
|
||||
with pytest.raises(ValueError, match="HMAC verification failed"):
|
||||
signed_pickle_loads(bytes(tampered))
|
||||
|
||||
def test_truncated_payload_rejected(self):
|
||||
"""A payload shorter than HMAC_SIZE is rejected."""
|
||||
with pytest.raises(ValueError, match="too short"):
|
||||
signed_pickle_loads(b"\x00" * (HMAC_SIZE - 1))
|
||||
|
||||
def test_empty_payload_rejected(self):
|
||||
with pytest.raises(ValueError, match="too short"):
|
||||
signed_pickle_loads(b"")
|
||||
|
||||
@override_settings(SECRET_KEY="different-secret-key")
|
||||
def test_wrong_secret_key_rejected(self):
|
||||
"""A message signed with one key cannot be loaded with another."""
|
||||
original_key = b"test-secret-key-do-not-use-in-production"
|
||||
obj = {"task": "test"}
|
||||
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
signature = hmac.new(original_key, data, sha256).digest()
|
||||
payload = signature + data
|
||||
with pytest.raises(ValueError, match="HMAC verification failed"):
|
||||
signed_pickle_loads(payload)
|
||||
|
||||
def test_forged_pickle_rejected(self):
|
||||
"""A raw pickle payload (no signature) is rejected."""
|
||||
raw_pickle = pickle.dumps({"task": "test"})
|
||||
# Raw pickle won't have a valid HMAC prefix
|
||||
with pytest.raises(ValueError, match="HMAC verification failed"):
|
||||
signed_pickle_loads(b"\x00" * HMAC_SIZE + raw_pickle)
|
||||
@@ -34,6 +34,7 @@ from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.permissions import DjangoModelPermissions
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.throttling import ScopedRateThrottle
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from documents.permissions import PaperlessObjectPermissions
|
||||
@@ -51,6 +52,8 @@ from paperless_ai.indexing import vector_store_file_exists
|
||||
|
||||
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||
serializer_class = PaperlessAuthTokenSerializer
|
||||
throttle_classes = [ScopedRateThrottle]
|
||||
throttle_scope = "login"
|
||||
|
||||
|
||||
class StandardPagination(PageNumberPagination):
|
||||
|
||||
@@ -39,6 +39,8 @@ from documents.loggers import LoggingMixin
|
||||
from documents.models import Correspondent
|
||||
from documents.parsers import is_mime_type_supported
|
||||
from documents.tasks import consume_file
|
||||
from paperless.network import is_public_ip
|
||||
from paperless.network import resolve_hostname_ips
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.models import ProcessedMail
|
||||
@@ -412,6 +414,13 @@ def get_mailbox(server, port, security) -> MailBox:
|
||||
"""
|
||||
Returns the correct MailBox instance for the given configuration.
|
||||
"""
|
||||
if not settings.EMAIL_ALLOW_INTERNAL_HOSTS:
|
||||
for ip_str in resolve_hostname_ips(server):
|
||||
if not is_public_ip(ip_str):
|
||||
raise MailError(
|
||||
f"Connection blocked: {server} resolves to a non-public address",
|
||||
)
|
||||
|
||||
ssl_context = ssl.create_default_context()
|
||||
if settings.EMAIL_CERTIFICATE_FILE is not None: # pragma: no cover
|
||||
ssl_context.load_verify_locations(cafile=settings.EMAIL_CERTIFICATE_FILE)
|
||||
|
||||
@@ -13,6 +13,7 @@ from django.contrib.auth.models import User
|
||||
from django.core.management import call_command
|
||||
from django.db import DatabaseError
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from imap_tools import NOT
|
||||
from imap_tools import EmailAddress
|
||||
@@ -1846,6 +1847,25 @@ class TestMailAccountTestView(APITestCase):
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.content.decode(), "Unable to connect to server")
|
||||
|
||||
@override_settings(EMAIL_ALLOW_INTERNAL_HOSTS=False)
|
||||
@mock.patch("paperless_mail.mail.resolve_hostname_ips", return_value=["127.0.0.1"])
|
||||
def test_mail_account_test_view_blocks_internal_host_when_disabled(
|
||||
self,
|
||||
_mock_resolve_hostname_ips,
|
||||
) -> None:
|
||||
data = {
|
||||
"imap_server": "internal.example",
|
||||
"imap_port": 993,
|
||||
"imap_security": MailAccount.ImapSecurity.SSL,
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"account_type": MailAccount.MailAccountType.IMAP,
|
||||
"is_token": False,
|
||||
}
|
||||
response = self.client.post(self.url, data, format="json")
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertEqual(response.content.decode(), "Unable to connect to server")
|
||||
|
||||
@mock.patch(
|
||||
"paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token",
|
||||
)
|
||||
|
||||
@@ -120,12 +120,12 @@ class MailAccountViewSet(ModelViewSet, PassUserMixin):
|
||||
serializer.validated_data["expiration"] = existing_account.expiration
|
||||
|
||||
account = MailAccount(**serializer.validated_data)
|
||||
with get_mailbox(
|
||||
account.imap_server,
|
||||
account.imap_port,
|
||||
account.imap_security,
|
||||
) as M:
|
||||
try:
|
||||
try:
|
||||
with get_mailbox(
|
||||
account.imap_server,
|
||||
account.imap_port,
|
||||
account.imap_security,
|
||||
) as M:
|
||||
if (
|
||||
existing_account is not None
|
||||
and account.is_token
|
||||
@@ -145,11 +145,11 @@ class MailAccountViewSet(ModelViewSet, PassUserMixin):
|
||||
|
||||
mailbox_login(M, account)
|
||||
return Response({"success": True})
|
||||
except MailError:
|
||||
logger.error(
|
||||
"Mail account connectivity test failed",
|
||||
)
|
||||
return HttpResponseBadRequest("Unable to connect to server")
|
||||
except MailError:
|
||||
logger.error(
|
||||
"Mail account connectivity test failed",
|
||||
)
|
||||
return HttpResponseBadRequest("Unable to connect to server")
|
||||
|
||||
@action(methods=["post"], detail=True)
|
||||
def process(self, request, pk=None):
|
||||
|
||||
Reference in New Issue
Block a user