mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-08 02:01:22 +00:00
Compare commits
6 Commits
dependabot
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2cdb1424ef | ||
|
|
f5c0c21922 | ||
|
|
91ddda9256 | ||
|
|
9d5e618de8 | ||
|
|
50ae49c7da | ||
|
|
ba023ef332 |
12
.github/workflows/ci-docker.yml
vendored
12
.github/workflows/ci-docker.yml
vendored
@@ -149,15 +149,16 @@ jobs:
|
|||||||
mkdir -p /tmp/digests
|
mkdir -p /tmp/digests
|
||||||
digest="${{ steps.build.outputs.digest }}"
|
digest="${{ steps.build.outputs.digest }}"
|
||||||
echo "digest=${digest}"
|
echo "digest=${digest}"
|
||||||
touch "/tmp/digests/${digest#sha256:}"
|
echo "${digest}" > "/tmp/digests/digest-${{ matrix.arch }}.txt"
|
||||||
- name: Upload digest
|
- name: Upload digest
|
||||||
if: steps.check-push.outputs.should-push == 'true'
|
if: steps.check-push.outputs.should-push == 'true'
|
||||||
uses: actions/upload-artifact@v7.0.0
|
uses: actions/upload-artifact@v7.0.0
|
||||||
with:
|
with:
|
||||||
name: digests-${{ matrix.arch }}
|
name: digests-${{ matrix.arch }}
|
||||||
path: /tmp/digests/*
|
path: /tmp/digests/digest-${{ matrix.arch }}.txt
|
||||||
if-no-files-found: error
|
if-no-files-found: error
|
||||||
retention-days: 1
|
retention-days: 1
|
||||||
|
archive: false
|
||||||
merge-and-push:
|
merge-and-push:
|
||||||
name: Merge and Push Manifest
|
name: Merge and Push Manifest
|
||||||
runs-on: ubuntu-24.04
|
runs-on: ubuntu-24.04
|
||||||
@@ -171,7 +172,7 @@ jobs:
|
|||||||
uses: actions/download-artifact@v8.0.0
|
uses: actions/download-artifact@v8.0.0
|
||||||
with:
|
with:
|
||||||
path: /tmp/digests
|
path: /tmp/digests
|
||||||
pattern: digests-*
|
pattern: digest-*.txt
|
||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
- name: List digests
|
- name: List digests
|
||||||
run: |
|
run: |
|
||||||
@@ -217,8 +218,9 @@ jobs:
|
|||||||
tags=$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "${DOCKER_METADATA_OUTPUT_JSON}")
|
tags=$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "${DOCKER_METADATA_OUTPUT_JSON}")
|
||||||
|
|
||||||
digests=""
|
digests=""
|
||||||
for digest in *; do
|
for digest_file in digest-*.txt; do
|
||||||
digests+="${{ env.REGISTRY }}/${REPOSITORY}@sha256:${digest} "
|
digest=$(cat "${digest_file}")
|
||||||
|
digests+="${{ env.REGISTRY }}/${REPOSITORY}@${digest} "
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "Creating manifest with tags: ${tags}"
|
echo "Creating manifest with tags: ${tags}"
|
||||||
|
|||||||
17
.github/workflows/pr-bot.yml
vendored
17
.github/workflows/pr-bot.yml
vendored
@@ -2,13 +2,24 @@ name: PR Bot
|
|||||||
on:
|
on:
|
||||||
pull_request_target:
|
pull_request_target:
|
||||||
types: [opened]
|
types: [opened]
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
pull-requests: write
|
|
||||||
jobs:
|
jobs:
|
||||||
|
anti-slop:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: read
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: peakoss/anti-slop@v0.2.1
|
||||||
|
with:
|
||||||
|
max-failures: 4
|
||||||
|
failure-add-pr-labels: 'ai'
|
||||||
pr-bot:
|
pr-bot:
|
||||||
name: Automated PR Bot
|
name: Automated PR Bot
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
steps:
|
steps:
|
||||||
- name: Label PR by file path or branch name
|
- name: Label PR by file path or branch name
|
||||||
# see .github/labeler.yml for the labeler config
|
# see .github/labeler.yml for the labeler config
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from itertools import chain
|
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@@ -81,6 +80,87 @@ def serialize_queryset_batched(
|
|||||||
yield serializers.serialize("python", chunk)
|
yield serializers.serialize("python", chunk)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingManifestWriter:
    """Incrementally write a JSON array to a file, one record at a time.

    Records are streamed to ``<target>.tmp``. On ``close()`` the temporary
    file replaces the target, unless ``compare_json`` is set, the target is a
    known member of ``files_in_export_dir`` and both files have the same
    BLAKE2b digest; in that case the temporary file is dropped and the
    existing target is left untouched.

    ``discard()`` deletes the temporary file and leaves the target intact.
    Used as a context manager, the writer calls ``discard()`` when the
    ``with`` body raises and ``close()`` otherwise.
    """

    # Chunk size used while hashing, so large manifests are never loaded
    # into memory in one piece.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        path: Path,
        *,
        compare_json: bool = False,
        files_in_export_dir: "set[Path] | None" = None,
    ) -> None:
        """Prepare a writer for *path*; nothing is created until ``open()``.

        Args:
            path: Final location of the JSON file (resolved to an absolute path).
            compare_json: Skip the replacement when the content is unchanged.
            files_in_export_dir: Set of files already present in the export
                directory. It is mutated in place: *path* is removed from it
                when the manifest is finalized.
        """
        self._path = path.resolve()
        self._tmp_path = self._path.with_suffix(self._path.suffix + ".tmp")
        self._compare_json = compare_json
        # Keep the caller's set object (even an empty one) so removals are
        # visible to the caller.
        self._files_in_export_dir: set[Path] = (
            files_in_export_dir if files_in_export_dir is not None else set()
        )
        self._file = None
        self._first = True

    def open(self) -> None:
        """Create the parent directory and start the JSON array in the tmp file."""
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self._tmp_path.open("w", encoding="utf-8")
        self._file.write("[")
        self._first = True

    def write_record(self, record: dict) -> None:
        """Append one record to the array.

        Raises:
            RuntimeError: If the writer has not been opened (or was closed).
        """
        if self._file is None:
            raise RuntimeError("StreamingManifestWriter is not open")
        if self._first:
            self._first = False
        else:
            # Separator goes before every record except the first one.
            self._file.write(",\n")
        self._file.write(
            json.dumps(record, cls=DjangoJSONEncoder, indent=2, ensure_ascii=False),
        )

    def write_batch(self, records: list[dict]) -> None:
        """Append every record of *records*, in order."""
        for record in records:
            self.write_record(record)

    def close(self) -> None:
        """Terminate the array, close the tmp file and finalize the target.

        Calling ``close()`` on a writer that is not open is a no-op.
        """
        if self._file is None:
            return
        try:
            self._file.write("\n]")
        finally:
            # Release the handle even when the final write fails.
            self._file.close()
            self._file = None
        self._finalize()

    def discard(self) -> None:
        """Close the tmp file if needed and delete it; the target is untouched."""
        if self._file is not None:
            self._file.close()
            self._file = None
        self._tmp_path.unlink(missing_ok=True)

    @classmethod
    def _digest(cls, path: Path) -> str:
        """Return the BLAKE2b hex digest of *path*, read in chunks."""
        hasher = hashlib.blake2b()
        with path.open("rb") as handle:
            while chunk := handle.read(cls._HASH_CHUNK_SIZE):
                hasher.update(chunk)
        return hasher.hexdigest()

    def _finalize(self) -> None:
        """Compare with existing file (if --compare-json) then replace or discard tmp."""
        if self._path in self._files_in_export_dir:
            self._files_in_export_dir.remove(self._path)
            if self._compare_json and self._digest(self._path) == self._digest(
                self._tmp_path,
            ):
                self._tmp_path.unlink()
                return
        # replace() overwrites an existing target on every platform, whereas
        # rename() raises FileExistsError on Windows when the target exists.
        self._tmp_path.replace(self._path)

    def __enter__(self) -> "StreamingManifestWriter":
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type is not None:
            self.discard()
        else:
            self.close()
|
||||||
|
|
||||||
|
|
||||||
class Command(CryptMixin, BaseCommand):
|
class Command(CryptMixin, BaseCommand):
|
||||||
help = (
|
help = (
|
||||||
"Decrypt and rename all files in our collection into a given target "
|
"Decrypt and rename all files in our collection into a given target "
|
||||||
@@ -322,95 +402,83 @@ class Command(CryptMixin, BaseCommand):
|
|||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
manifest_key_to_object_query["log_entries"] = LogEntry.objects.all()
|
manifest_key_to_object_query["log_entries"] = LogEntry.objects.all()
|
||||||
|
|
||||||
with transaction.atomic():
|
# Crypto setup before streaming begins
|
||||||
manifest_dict = {}
|
if self.passphrase:
|
||||||
|
self.setup_crypto(passphrase=self.passphrase)
|
||||||
# Build an overall manifest
|
elif MailAccount.objects.count() > 0 or SocialToken.objects.count() > 0:
|
||||||
for key, object_query in manifest_key_to_object_query.items():
|
self.stdout.write(
|
||||||
manifest_dict[key] = list(
|
self.style.NOTICE(
|
||||||
chain.from_iterable(
|
"No passphrase was given, sensitive fields will be in plaintext",
|
||||||
serialize_queryset_batched(
|
),
|
||||||
object_query,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.encrypt_secret_fields(manifest_dict)
|
|
||||||
|
|
||||||
# These are treated specially and included in the per-document manifest
|
|
||||||
# if that setting is enabled. Otherwise, they are just exported to the bulk
|
|
||||||
# manifest
|
|
||||||
document_map: dict[int, Document] = {
|
|
||||||
d.pk: d for d in manifest_key_to_object_query["documents"]
|
|
||||||
}
|
|
||||||
document_manifest = manifest_dict["documents"]
|
|
||||||
|
|
||||||
# 3. Export files from each document
|
|
||||||
for index, document_dict in tqdm.tqdm(
|
|
||||||
enumerate(document_manifest),
|
|
||||||
total=len(document_manifest),
|
|
||||||
disable=self.no_progress_bar,
|
|
||||||
):
|
|
||||||
document = document_map[document_dict["pk"]]
|
|
||||||
|
|
||||||
# 3.1. generate a unique filename
|
|
||||||
base_name = self.generate_base_name(document)
|
|
||||||
|
|
||||||
# 3.2. write filenames into manifest
|
|
||||||
original_target, thumbnail_target, archive_target = (
|
|
||||||
self.generate_document_targets(document, base_name, document_dict)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3.3. write files to target folder
|
document_manifest: list[dict] = []
|
||||||
if not self.data_only:
|
|
||||||
self.copy_document_files(
|
|
||||||
document,
|
|
||||||
original_target,
|
|
||||||
thumbnail_target,
|
|
||||||
archive_target,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.split_manifest:
|
|
||||||
manifest_name = base_name.with_name(f"{base_name.stem}-manifest.json")
|
|
||||||
if self.use_folder_prefix:
|
|
||||||
manifest_name = Path("json") / manifest_name
|
|
||||||
manifest_name = (self.target / manifest_name).resolve()
|
|
||||||
manifest_name.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
content = [document_manifest[index]]
|
|
||||||
content += list(
|
|
||||||
filter(
|
|
||||||
lambda d: d["fields"]["document"] == document_dict["pk"],
|
|
||||||
manifest_dict["notes"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
content += list(
|
|
||||||
filter(
|
|
||||||
lambda d: d["fields"]["document"] == document_dict["pk"],
|
|
||||||
manifest_dict["custom_field_instances"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.check_and_write_json(
|
|
||||||
content,
|
|
||||||
manifest_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
# These were exported already
|
|
||||||
if self.split_manifest:
|
|
||||||
del manifest_dict["documents"]
|
|
||||||
del manifest_dict["notes"]
|
|
||||||
del manifest_dict["custom_field_instances"]
|
|
||||||
|
|
||||||
# 4.1 write primary manifest to target folder
|
|
||||||
manifest = []
|
|
||||||
for key, item in manifest_dict.items():
|
|
||||||
manifest.extend(item)
|
|
||||||
manifest_path = (self.target / "manifest.json").resolve()
|
manifest_path = (self.target / "manifest.json").resolve()
|
||||||
self.check_and_write_json(
|
|
||||||
manifest,
|
with StreamingManifestWriter(
|
||||||
manifest_path,
|
manifest_path,
|
||||||
)
|
compare_json=self.compare_json,
|
||||||
|
files_in_export_dir=self.files_in_export_dir,
|
||||||
|
) as writer:
|
||||||
|
with transaction.atomic():
|
||||||
|
for key, qs in manifest_key_to_object_query.items():
|
||||||
|
if key == "documents":
|
||||||
|
# Accumulate for file-copy loop; written to manifest after
|
||||||
|
for batch in serialize_queryset_batched(
|
||||||
|
qs,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
):
|
||||||
|
for record in batch:
|
||||||
|
self._encrypt_record_inline(record)
|
||||||
|
document_manifest.extend(batch)
|
||||||
|
elif self.split_manifest and key in (
|
||||||
|
"notes",
|
||||||
|
"custom_field_instances",
|
||||||
|
):
|
||||||
|
# Written per-document in _write_split_manifest
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for batch in serialize_queryset_batched(
|
||||||
|
qs,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
):
|
||||||
|
for record in batch:
|
||||||
|
self._encrypt_record_inline(record)
|
||||||
|
writer.write_batch(batch)
|
||||||
|
|
||||||
|
document_map: dict[int, Document] = {
|
||||||
|
d.pk: d for d in Document.objects.order_by("id")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. Export files from each document
|
||||||
|
for document_dict in tqdm.tqdm(
|
||||||
|
document_manifest,
|
||||||
|
total=len(document_manifest),
|
||||||
|
disable=self.no_progress_bar,
|
||||||
|
):
|
||||||
|
document = document_map[document_dict["pk"]]
|
||||||
|
|
||||||
|
# 3.1. generate a unique filename
|
||||||
|
base_name = self.generate_base_name(document)
|
||||||
|
|
||||||
|
# 3.2. write filenames into manifest
|
||||||
|
original_target, thumbnail_target, archive_target = (
|
||||||
|
self.generate_document_targets(document, base_name, document_dict)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3.3. write files to target folder
|
||||||
|
if not self.data_only:
|
||||||
|
self.copy_document_files(
|
||||||
|
document,
|
||||||
|
original_target,
|
||||||
|
thumbnail_target,
|
||||||
|
archive_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.split_manifest:
|
||||||
|
self._write_split_manifest(document_dict, document, base_name)
|
||||||
|
else:
|
||||||
|
writer.write_record(document_dict)
|
||||||
|
|
||||||
# 4.2 write version information to target folder
|
# 4.2 write version information to target folder
|
||||||
extra_metadata_path = (self.target / "metadata.json").resolve()
|
extra_metadata_path = (self.target / "metadata.json").resolve()
|
||||||
@@ -532,6 +600,42 @@ class Command(CryptMixin, BaseCommand):
|
|||||||
archive_target,
|
archive_target,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _encrypt_record_inline(self, record: dict) -> None:
    """Encrypt sensitive fields in a single record, if passphrase is set.

    The record is modified in place. The fields to encrypt are looked up in
    ``self.CRYPT_FIELDS_BY_MODEL`` under the record's ``"model"`` value;
    records whose model has no entry are left unchanged, as are fields that
    are missing or falsy.

    Args:
        record: A serialized object with a ``"fields"`` mapping and,
            normally, a ``"model"`` key.
    """
    if not self.passphrase:
        return
    crypt_fields = self.CRYPT_FIELDS_BY_MODEL.get(record.get("model", ""))
    if not crypt_fields:
        return
    values = record["fields"]
    for field in crypt_fields:
        # Empty or absent values are skipped rather than encrypted.
        if values.get(field):
            values[field] = self.encrypt_string(value=values[field])
|
||||||
|
|
||||||
|
def _write_split_manifest(
    self,
    document_dict: dict,
    document: Document,
    base_name: Path,
) -> None:
    """Write per-document manifest file for --split-manifest mode.

    The file contains the document's own record followed by its serialized
    notes and custom field instances, each queried for this document only.

    Args:
        document_dict: Serialized record of the document.
        document: The document the notes and custom field instances are
            filtered by.
        base_name: Base file name of the exported document; the manifest is
            named ``<stem>-manifest.json`` next to it.
    """
    content = [document_dict]
    content.extend(
        serializers.serialize("python", Note.objects.filter(document=document)),
    )
    content.extend(
        serializers.serialize(
            "python",
            CustomFieldInstance.objects.filter(document=document),
        ),
    )

    manifest_name = base_name.with_name(f"{base_name.stem}-manifest.json")
    if self.use_folder_prefix:
        # With folder prefixes enabled, split manifests go under "json/".
        manifest_name = Path("json") / manifest_name
    manifest_name = (self.target / manifest_name).resolve()
    manifest_name.parent.mkdir(parents=True, exist_ok=True)
    self.check_and_write_json(content, manifest_name)
|
||||||
|
|
||||||
def check_and_write_json(
|
def check_and_write_json(
|
||||||
self,
|
self,
|
||||||
content: list[dict] | dict,
|
content: list[dict] | dict,
|
||||||
@@ -549,14 +653,14 @@ class Command(CryptMixin, BaseCommand):
|
|||||||
if target in self.files_in_export_dir:
|
if target in self.files_in_export_dir:
|
||||||
self.files_in_export_dir.remove(target)
|
self.files_in_export_dir.remove(target)
|
||||||
if self.compare_json:
|
if self.compare_json:
|
||||||
target_checksum = hashlib.md5(target.read_bytes()).hexdigest()
|
target_checksum = hashlib.blake2b(target.read_bytes()).hexdigest()
|
||||||
src_str = json.dumps(
|
src_str = json.dumps(
|
||||||
content,
|
content,
|
||||||
cls=DjangoJSONEncoder,
|
cls=DjangoJSONEncoder,
|
||||||
indent=2,
|
indent=2,
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
src_checksum = hashlib.md5(src_str.encode("utf-8")).hexdigest()
|
src_checksum = hashlib.blake2b(src_str.encode("utf-8")).hexdigest()
|
||||||
if src_checksum == target_checksum:
|
if src_checksum == target_checksum:
|
||||||
perform_write = False
|
perform_write = False
|
||||||
|
|
||||||
@@ -606,28 +710,3 @@ class Command(CryptMixin, BaseCommand):
|
|||||||
if perform_copy:
|
if perform_copy:
|
||||||
target.parent.mkdir(parents=True, exist_ok=True)
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
copy_file_with_basic_stats(source, target)
|
copy_file_with_basic_stats(source, target)
|
||||||
|
|
||||||
def encrypt_secret_fields(self, manifest: dict) -> None:
|
|
||||||
"""
|
|
||||||
Encrypts certain fields in the export. Currently limited to the mail account password
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.passphrase:
|
|
||||||
self.setup_crypto(passphrase=self.passphrase)
|
|
||||||
|
|
||||||
for crypt_config in self.CRYPT_FIELDS:
|
|
||||||
exporter_key = crypt_config["exporter_key"]
|
|
||||||
crypt_fields = crypt_config["fields"]
|
|
||||||
for manifest_record in manifest[exporter_key]:
|
|
||||||
for field in crypt_fields:
|
|
||||||
if manifest_record["fields"][field]:
|
|
||||||
manifest_record["fields"][field] = self.encrypt_string(
|
|
||||||
value=manifest_record["fields"][field],
|
|
||||||
)
|
|
||||||
|
|
||||||
elif MailAccount.objects.count() > 0 or SocialToken.objects.count() > 0:
|
|
||||||
self.stdout.write(
|
|
||||||
self.style.NOTICE(
|
|
||||||
"No passphrase was given, sensitive fields will be in plaintext",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class CryptMixin:
|
|||||||
key_size = 32
|
key_size = 32
|
||||||
kdf_algorithm = "pbkdf2_sha256"
|
kdf_algorithm = "pbkdf2_sha256"
|
||||||
|
|
||||||
CRYPT_FIELDS: CryptFields = [
|
CRYPT_FIELDS: list[CryptFields] = [
|
||||||
{
|
{
|
||||||
"exporter_key": "mail_accounts",
|
"exporter_key": "mail_accounts",
|
||||||
"model_name": "paperless_mail.mailaccount",
|
"model_name": "paperless_mail.mailaccount",
|
||||||
@@ -89,6 +89,10 @@ class CryptMixin:
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
# O(1) lookup for per-record encryption; derived from CRYPT_FIELDS at class definition time
|
||||||
|
CRYPT_FIELDS_BY_MODEL: dict[str, list[str]] = {
|
||||||
|
cfg["model_name"]: cfg["fields"] for cfg in CRYPT_FIELDS
|
||||||
|
}
|
||||||
|
|
||||||
def get_crypt_params(self) -> dict[str, dict[str, str | int]]:
|
def get_crypt_params(self) -> dict[str, dict[str, str | int]]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -753,6 +753,31 @@ class TestExportImport(
|
|||||||
call_command("document_importer", "--no-progress-bar", self.target)
|
call_command("document_importer", "--no-progress-bar", self.target)
|
||||||
self.assertEqual(Document.objects.count(), 4)
|
self.assertEqual(Document.objects.count(), 4)
|
||||||
|
|
||||||
|
def test_folder_prefix_with_split(self) -> None:
    """
    GIVEN:
        - Request to export documents to directory
    WHEN:
        - Option use_folder_prefix is used
        - Option split manifest is used
    THEN:
        - Documents can be imported again
    """
    # Start from the pristine sample documents.
    shutil.rmtree(Path(self.dirs.media_dir) / "documents")
    shutil.copytree(
        Path(__file__).parent / "samples" / "documents",
        Path(self.dirs.media_dir) / "documents",
    )

    self._do_export(use_folder_prefix=True, split_manifest=True)

    with paperless_environment():
        self.assertEqual(Document.objects.count(), 4)
        Document.objects.all().delete()
        self.assertEqual(Document.objects.count(), 0)
        # Round trip: the export must be importable again.
        call_command("document_importer", "--no-progress-bar", self.target)
        self.assertEqual(Document.objects.count(), 4)
|
||||||
|
|
||||||
def test_import_db_transaction_failed(self) -> None:
|
def test_import_db_transaction_failed(self) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
|
|||||||
@@ -1,107 +1,100 @@
|
|||||||
from unittest import mock
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
from allauth.account.adapter import get_adapter
|
from allauth.account.adapter import get_adapter
|
||||||
from allauth.core import context
|
from allauth.core import context
|
||||||
from allauth.socialaccount.adapter import get_adapter as get_social_adapter
|
from allauth.socialaccount.adapter import get_adapter as get_social_adapter
|
||||||
from django.conf import settings
|
|
||||||
from django.contrib.auth.models import AnonymousUser
|
from django.contrib.auth.models import AnonymousUser
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.forms import ValidationError
|
from django.forms import ValidationError
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.test import TestCase
|
|
||||||
from django.test import override_settings
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
from pytest_django.fixtures import SettingsWrapper
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
|
|
||||||
from paperless.adapter import DrfTokenStrategy
|
from paperless.adapter import DrfTokenStrategy
|
||||||
|
|
||||||
|
|
||||||
class TestCustomAccountAdapter(TestCase):
|
@pytest.mark.django_db
|
||||||
def test_is_open_for_signup(self) -> None:
|
class TestCustomAccountAdapter:
|
||||||
|
def test_is_open_for_signup(self, settings: SettingsWrapper) -> None:
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
|
|
||||||
# With no accounts, signups should be allowed
|
# With no accounts, signups should be allowed
|
||||||
self.assertTrue(adapter.is_open_for_signup(None))
|
assert adapter.is_open_for_signup(None)
|
||||||
|
|
||||||
User.objects.create_user("testuser")
|
User.objects.create_user("testuser")
|
||||||
|
|
||||||
# Test when ACCOUNT_ALLOW_SIGNUPS is True
|
|
||||||
settings.ACCOUNT_ALLOW_SIGNUPS = True
|
settings.ACCOUNT_ALLOW_SIGNUPS = True
|
||||||
self.assertTrue(adapter.is_open_for_signup(None))
|
assert adapter.is_open_for_signup(None)
|
||||||
|
|
||||||
# Test when ACCOUNT_ALLOW_SIGNUPS is False
|
|
||||||
settings.ACCOUNT_ALLOW_SIGNUPS = False
|
settings.ACCOUNT_ALLOW_SIGNUPS = False
|
||||||
self.assertFalse(adapter.is_open_for_signup(None))
|
assert not adapter.is_open_for_signup(None)
|
||||||
|
|
||||||
def test_is_safe_url(self) -> None:
|
def test_is_safe_url(self, settings: SettingsWrapper) -> None:
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
request.get_host = mock.Mock(return_value="example.com")
|
request.get_host = lambda: "example.com"
|
||||||
with context.request_context(request):
|
with context.request_context(request):
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
with override_settings(ALLOWED_HOSTS=["*"]):
|
|
||||||
# True because request host is same
|
|
||||||
url = "https://example.com"
|
|
||||||
self.assertTrue(adapter.is_safe_url(url))
|
|
||||||
|
|
||||||
url = "https://evil.com"
|
settings.ALLOWED_HOSTS = ["*"]
|
||||||
|
# True because request host is same
|
||||||
|
assert adapter.is_safe_url("https://example.com")
|
||||||
# False despite wildcard because request host is different
|
# False despite wildcard because request host is different
|
||||||
self.assertFalse(adapter.is_safe_url(url))
|
assert not adapter.is_safe_url("https://evil.com")
|
||||||
|
|
||||||
settings.ALLOWED_HOSTS = ["example.com"]
|
settings.ALLOWED_HOSTS = ["example.com"]
|
||||||
url = "https://example.com"
|
|
||||||
# True because request host is same
|
# True because request host is same
|
||||||
self.assertTrue(adapter.is_safe_url(url))
|
assert adapter.is_safe_url("https://example.com")
|
||||||
|
|
||||||
settings.ALLOWED_HOSTS = ["*", "example.com"]
|
settings.ALLOWED_HOSTS = ["*", "example.com"]
|
||||||
url = "//evil.com"
|
|
||||||
# False because request host is not in allowed hosts
|
# False because request host is not in allowed hosts
|
||||||
self.assertFalse(adapter.is_safe_url(url))
|
assert not adapter.is_safe_url("//evil.com")
|
||||||
|
|
||||||
@mock.patch("allauth.core.internal.ratelimit.consume", return_value=True)
|
def test_pre_authenticate(
|
||||||
def test_pre_authenticate(self, mock_consume) -> None:
|
self,
|
||||||
|
settings: SettingsWrapper,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mocker.patch("allauth.core.internal.ratelimit.consume", return_value=True)
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
request.get_host = mock.Mock(return_value="example.com")
|
request.get_host = lambda: "example.com"
|
||||||
|
|
||||||
settings.DISABLE_REGULAR_LOGIN = False
|
settings.DISABLE_REGULAR_LOGIN = False
|
||||||
adapter.pre_authenticate(request)
|
adapter.pre_authenticate(request)
|
||||||
|
|
||||||
settings.DISABLE_REGULAR_LOGIN = True
|
settings.DISABLE_REGULAR_LOGIN = True
|
||||||
with self.assertRaises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
adapter.pre_authenticate(request)
|
adapter.pre_authenticate(request)
|
||||||
|
|
||||||
def test_get_reset_password_from_key_url(self) -> None:
|
def test_get_reset_password_from_key_url(self, settings: SettingsWrapper) -> None:
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
request.get_host = mock.Mock(return_value="foo.org")
|
request.get_host = lambda: "foo.org"
|
||||||
with context.request_context(request):
|
with context.request_context(request):
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
|
|
||||||
# Test when PAPERLESS_URL is None
|
settings.PAPERLESS_URL = None
|
||||||
with override_settings(
|
settings.ACCOUNT_DEFAULT_HTTP_PROTOCOL = "https"
|
||||||
PAPERLESS_URL=None,
|
expected_url = f"https://foo.org{reverse('account_reset_password_from_key', kwargs={'uidb36': 'UID', 'key': 'KEY'})}"
|
||||||
ACCOUNT_DEFAULT_HTTP_PROTOCOL="https",
|
assert adapter.get_reset_password_from_key_url("UID-KEY") == expected_url
|
||||||
):
|
|
||||||
expected_url = f"https://foo.org{reverse('account_reset_password_from_key', kwargs={'uidb36': 'UID', 'key': 'KEY'})}"
|
|
||||||
self.assertEqual(
|
|
||||||
adapter.get_reset_password_from_key_url("UID-KEY"),
|
|
||||||
expected_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test when PAPERLESS_URL is not None
|
settings.PAPERLESS_URL = "https://bar.com"
|
||||||
with override_settings(PAPERLESS_URL="https://bar.com"):
|
expected_url = f"https://bar.com{reverse('account_reset_password_from_key', kwargs={'uidb36': 'UID', 'key': 'KEY'})}"
|
||||||
expected_url = f"https://bar.com{reverse('account_reset_password_from_key', kwargs={'uidb36': 'UID', 'key': 'KEY'})}"
|
assert adapter.get_reset_password_from_key_url("UID-KEY") == expected_url
|
||||||
self.assertEqual(
|
|
||||||
adapter.get_reset_password_from_key_url("UID-KEY"),
|
|
||||||
expected_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
@override_settings(ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
|
def test_save_user_adds_groups(
|
||||||
def test_save_user_adds_groups(self) -> None:
|
self,
|
||||||
|
settings: SettingsWrapper,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
settings.ACCOUNT_DEFAULT_GROUPS = ["group1", "group2"]
|
||||||
Group.objects.create(name="group1")
|
Group.objects.create(name="group1")
|
||||||
user = User.objects.create_user("testuser")
|
user = User.objects.create_user("testuser")
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
form = mock.Mock(
|
form = mocker.MagicMock(
|
||||||
cleaned_data={
|
cleaned_data={
|
||||||
"username": "testuser",
|
"username": "testuser",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
@@ -110,88 +103,81 @@ class TestCustomAccountAdapter(TestCase):
|
|||||||
|
|
||||||
user = adapter.save_user(HttpRequest(), user, form, commit=True)
|
user = adapter.save_user(HttpRequest(), user, form, commit=True)
|
||||||
|
|
||||||
self.assertEqual(user.groups.count(), 1)
|
assert user.groups.count() == 1
|
||||||
self.assertTrue(user.groups.filter(name="group1").exists())
|
assert user.groups.filter(name="group1").exists()
|
||||||
self.assertFalse(user.groups.filter(name="group2").exists())
|
assert not user.groups.filter(name="group2").exists()
|
||||||
|
|
||||||
def test_fresh_install_save_creates_superuser(self) -> None:
|
def test_fresh_install_save_creates_superuser(self, mocker: MockerFixture) -> None:
|
||||||
adapter = get_adapter()
|
adapter = get_adapter()
|
||||||
form = mock.Mock(
|
form = mocker.MagicMock(
|
||||||
cleaned_data={
|
cleaned_data={
|
||||||
"username": "testuser",
|
"username": "testuser",
|
||||||
"email": "user@paperless-ngx.com",
|
"email": "user@paperless-ngx.com",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
user = adapter.save_user(HttpRequest(), User(), form, commit=True)
|
user = adapter.save_user(HttpRequest(), User(), form, commit=True)
|
||||||
self.assertTrue(user.is_superuser)
|
assert user.is_superuser
|
||||||
|
|
||||||
# Next time, it should not create a superuser
|
form = mocker.MagicMock(
|
||||||
form = mock.Mock(
|
|
||||||
cleaned_data={
|
cleaned_data={
|
||||||
"username": "testuser2",
|
"username": "testuser2",
|
||||||
"email": "user2@paperless-ngx.com",
|
"email": "user2@paperless-ngx.com",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
user2 = adapter.save_user(HttpRequest(), User(), form, commit=True)
|
user2 = adapter.save_user(HttpRequest(), User(), form, commit=True)
|
||||||
self.assertFalse(user2.is_superuser)
|
assert not user2.is_superuser
|
||||||
|
|
||||||
|
|
||||||
class TestCustomSocialAccountAdapter(TestCase):
|
class TestCustomSocialAccountAdapter:
|
||||||
def test_is_open_for_signup(self) -> None:
|
@pytest.mark.django_db
|
||||||
|
def test_is_open_for_signup(self, settings: SettingsWrapper) -> None:
|
||||||
adapter = get_social_adapter()
|
adapter = get_social_adapter()
|
||||||
|
|
||||||
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is True
|
|
||||||
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = True
|
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = True
|
||||||
self.assertTrue(adapter.is_open_for_signup(None, None))
|
assert adapter.is_open_for_signup(None, None)
|
||||||
|
|
||||||
# Test when SOCIALACCOUNT_ALLOW_SIGNUPS is False
|
|
||||||
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = False
|
settings.SOCIALACCOUNT_ALLOW_SIGNUPS = False
|
||||||
self.assertFalse(adapter.is_open_for_signup(None, None))
|
assert not adapter.is_open_for_signup(None, None)
|
||||||
|
|
||||||
def test_get_connect_redirect_url(self) -> None:
|
def test_get_connect_redirect_url(self) -> None:
|
||||||
adapter = get_social_adapter()
|
adapter = get_social_adapter()
|
||||||
request = None
|
assert adapter.get_connect_redirect_url(None, None) == reverse("base")
|
||||||
socialaccount = None
|
|
||||||
|
|
||||||
# Test the default URL
|
@pytest.mark.django_db
|
||||||
expected_url = reverse("base")
|
def test_save_user_adds_groups(
|
||||||
self.assertEqual(
|
self,
|
||||||
adapter.get_connect_redirect_url(request, socialaccount),
|
settings: SettingsWrapper,
|
||||||
expected_url,
|
mocker: MockerFixture,
|
||||||
)
|
) -> None:
|
||||||
|
settings.SOCIAL_ACCOUNT_DEFAULT_GROUPS = ["group1", "group2"]
|
||||||
@override_settings(SOCIAL_ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
|
|
||||||
def test_save_user_adds_groups(self) -> None:
|
|
||||||
Group.objects.create(name="group1")
|
Group.objects.create(name="group1")
|
||||||
adapter = get_social_adapter()
|
adapter = get_social_adapter()
|
||||||
request = HttpRequest()
|
|
||||||
user = User.objects.create_user("testuser")
|
user = User.objects.create_user("testuser")
|
||||||
sociallogin = mock.Mock(
|
sociallogin = mocker.MagicMock(user=user)
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = adapter.save_user(request, sociallogin, None)
|
user = adapter.save_user(HttpRequest(), sociallogin, None)
|
||||||
|
|
||||||
self.assertEqual(user.groups.count(), 1)
|
assert user.groups.count() == 1
|
||||||
self.assertTrue(user.groups.filter(name="group1").exists())
|
assert user.groups.filter(name="group1").exists()
|
||||||
self.assertFalse(user.groups.filter(name="group2").exists())
|
assert not user.groups.filter(name="group2").exists()
|
||||||
|
|
||||||
def test_error_logged_on_authentication_error(self) -> None:
|
def test_error_logged_on_authentication_error(
|
||||||
|
self,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
adapter = get_social_adapter()
|
adapter = get_social_adapter()
|
||||||
request = HttpRequest()
|
with caplog.at_level(logging.INFO, logger="paperless.auth"):
|
||||||
with self.assertLogs("paperless.auth", level="INFO") as log_cm:
|
|
||||||
adapter.on_authentication_error(
|
adapter.on_authentication_error(
|
||||||
request,
|
HttpRequest(),
|
||||||
provider="test-provider",
|
provider="test-provider",
|
||||||
error="Error",
|
error="Error",
|
||||||
exception="Test authentication error",
|
exception="Test authentication error",
|
||||||
)
|
)
|
||||||
self.assertTrue(
|
assert any("Test authentication error" in msg for msg in caplog.messages)
|
||||||
any("Test authentication error" in message for message in log_cm.output),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDrfTokenStrategy(TestCase):
|
@pytest.mark.django_db
|
||||||
|
class TestDrfTokenStrategy:
|
||||||
def test_create_access_token_creates_new_token(self) -> None:
|
def test_create_access_token_creates_new_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -201,7 +187,6 @@ class TestDrfTokenStrategy(TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- A new token is created and its key is returned
|
- A new token is created and its key is returned
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user = User.objects.create_user("testuser")
|
user = User.objects.create_user("testuser")
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
request.user = user
|
request.user = user
|
||||||
@@ -209,13 +194,9 @@ class TestDrfTokenStrategy(TestCase):
|
|||||||
strategy = DrfTokenStrategy()
|
strategy = DrfTokenStrategy()
|
||||||
token_key = strategy.create_access_token(request)
|
token_key = strategy.create_access_token(request)
|
||||||
|
|
||||||
# Verify a token was created
|
assert token_key is not None
|
||||||
self.assertIsNotNone(token_key)
|
assert Token.objects.filter(user=user).exists()
|
||||||
self.assertTrue(Token.objects.filter(user=user).exists())
|
assert token_key == Token.objects.get(user=user).key
|
||||||
|
|
||||||
# Verify the returned key matches the created token
|
|
||||||
token = Token.objects.get(user=user)
|
|
||||||
self.assertEqual(token_key, token.key)
|
|
||||||
|
|
||||||
def test_create_access_token_returns_existing_token(self) -> None:
|
def test_create_access_token_returns_existing_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -226,7 +207,6 @@ class TestDrfTokenStrategy(TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- The same token key is returned (no new token created)
|
- The same token key is returned (no new token created)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user = User.objects.create_user("testuser")
|
user = User.objects.create_user("testuser")
|
||||||
existing_token = Token.objects.create(user=user)
|
existing_token = Token.objects.create(user=user)
|
||||||
|
|
||||||
@@ -236,11 +216,8 @@ class TestDrfTokenStrategy(TestCase):
|
|||||||
strategy = DrfTokenStrategy()
|
strategy = DrfTokenStrategy()
|
||||||
token_key = strategy.create_access_token(request)
|
token_key = strategy.create_access_token(request)
|
||||||
|
|
||||||
# Verify the existing token key is returned
|
assert token_key == existing_token.key
|
||||||
self.assertEqual(token_key, existing_token.key)
|
assert Token.objects.filter(user=user).count() == 1
|
||||||
|
|
||||||
# Verify only one token exists (no duplicate created)
|
|
||||||
self.assertEqual(Token.objects.filter(user=user).count(), 1)
|
|
||||||
|
|
||||||
def test_create_access_token_returns_none_for_unauthenticated_user(self) -> None:
|
def test_create_access_token_returns_none_for_unauthenticated_user(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -251,12 +228,11 @@ class TestDrfTokenStrategy(TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- None is returned and no token is created
|
- None is returned and no token is created
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request = HttpRequest()
|
request = HttpRequest()
|
||||||
request.user = AnonymousUser()
|
request.user = AnonymousUser()
|
||||||
|
|
||||||
strategy = DrfTokenStrategy()
|
strategy = DrfTokenStrategy()
|
||||||
token_key = strategy.create_access_token(request)
|
token_key = strategy.create_access_token(request)
|
||||||
|
|
||||||
self.assertIsNone(token_key)
|
assert token_key is None
|
||||||
self.assertEqual(Token.objects.count(), 0)
|
assert Token.objects.count() == 0
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.core.checks import Error
|
from django.core.checks import Error
|
||||||
from django.core.checks import Warning
|
from django.core.checks import Warning
|
||||||
from django.test import TestCase
|
from pytest_django.fixtures import SettingsWrapper
|
||||||
from django.test import override_settings
|
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from documents.tests.utils import DirectoriesMixin
|
|
||||||
from documents.tests.utils import FileSystemAssertsMixin
|
|
||||||
from paperless.checks import audit_log_check
|
from paperless.checks import audit_log_check
|
||||||
from paperless.checks import binaries_check
|
from paperless.checks import binaries_check
|
||||||
from paperless.checks import check_deprecated_db_settings
|
from paperless.checks import check_deprecated_db_settings
|
||||||
@@ -20,54 +19,84 @@ from paperless.checks import paths_check
|
|||||||
from paperless.checks import settings_values_check
|
from paperless.checks import settings_values_check
|
||||||
|
|
||||||
|
|
||||||
class TestChecks(DirectoriesMixin, TestCase):
|
@dataclass(frozen=True, slots=True)
|
||||||
def test_binaries(self) -> None:
|
class PaperlessTestDirs:
|
||||||
self.assertEqual(binaries_check(None), [])
|
data_dir: Path
|
||||||
|
media_dir: Path
|
||||||
|
consumption_dir: Path
|
||||||
|
|
||||||
@override_settings(CONVERT_BINARY="uuuhh")
|
|
||||||
def test_binaries_fail(self) -> None:
|
|
||||||
self.assertEqual(len(binaries_check(None)), 1)
|
|
||||||
|
|
||||||
def test_paths_check(self) -> None:
|
# TODO: consolidate with documents/tests/conftest.py PaperlessDirs/paperless_dirs
|
||||||
self.assertEqual(paths_check(None), [])
|
# once the paperless and documents test suites are ready to share fixtures.
|
||||||
|
@pytest.fixture()
|
||||||
|
def directories(tmp_path: Path, settings: SettingsWrapper) -> PaperlessTestDirs:
|
||||||
|
data_dir = tmp_path / "data"
|
||||||
|
media_dir = tmp_path / "media"
|
||||||
|
consumption_dir = tmp_path / "consumption"
|
||||||
|
|
||||||
@override_settings(
|
for d in (data_dir, media_dir, consumption_dir):
|
||||||
MEDIA_ROOT=Path("uuh"),
|
d.mkdir()
|
||||||
DATA_DIR=Path("whatever"),
|
|
||||||
CONSUMPTION_DIR=Path("idontcare"),
|
settings.DATA_DIR = data_dir
|
||||||
|
settings.MEDIA_ROOT = media_dir
|
||||||
|
settings.CONSUMPTION_DIR = consumption_dir
|
||||||
|
|
||||||
|
return PaperlessTestDirs(
|
||||||
|
data_dir=data_dir,
|
||||||
|
media_dir=media_dir,
|
||||||
|
consumption_dir=consumption_dir,
|
||||||
)
|
)
|
||||||
def test_paths_check_dont_exist(self) -> None:
|
|
||||||
msgs = paths_check(None)
|
|
||||||
self.assertEqual(len(msgs), 3, str(msgs))
|
|
||||||
|
|
||||||
for msg in msgs:
|
|
||||||
self.assertTrue(msg.msg.endswith("is set but doesn't exist."))
|
|
||||||
|
|
||||||
def test_paths_check_no_access(self) -> None:
|
class TestChecks:
|
||||||
Path(self.dirs.data_dir).chmod(0o000)
|
def test_binaries(self) -> None:
|
||||||
Path(self.dirs.media_dir).chmod(0o000)
|
assert binaries_check(None) == []
|
||||||
Path(self.dirs.consumption_dir).chmod(0o000)
|
|
||||||
|
|
||||||
self.addCleanup(os.chmod, self.dirs.data_dir, 0o777)
|
def test_binaries_fail(self, settings: SettingsWrapper) -> None:
|
||||||
self.addCleanup(os.chmod, self.dirs.media_dir, 0o777)
|
settings.CONVERT_BINARY = "uuuhh"
|
||||||
self.addCleanup(os.chmod, self.dirs.consumption_dir, 0o777)
|
assert len(binaries_check(None)) == 1
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("directories")
|
||||||
|
def test_paths_check(self) -> None:
|
||||||
|
assert paths_check(None) == []
|
||||||
|
|
||||||
|
def test_paths_check_dont_exist(self, settings: SettingsWrapper) -> None:
|
||||||
|
settings.MEDIA_ROOT = Path("uuh")
|
||||||
|
settings.DATA_DIR = Path("whatever")
|
||||||
|
settings.CONSUMPTION_DIR = Path("idontcare")
|
||||||
|
|
||||||
msgs = paths_check(None)
|
msgs = paths_check(None)
|
||||||
self.assertEqual(len(msgs), 3)
|
|
||||||
|
|
||||||
|
assert len(msgs) == 3, str(msgs)
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
self.assertTrue(msg.msg.endswith("is not writeable"))
|
assert msg.msg.endswith("is set but doesn't exist.")
|
||||||
|
|
||||||
@override_settings(DEBUG=False)
|
def test_paths_check_no_access(self, directories: PaperlessTestDirs) -> None:
|
||||||
def test_debug_disabled(self) -> None:
|
directories.data_dir.chmod(0o000)
|
||||||
self.assertEqual(debug_mode_check(None), [])
|
directories.media_dir.chmod(0o000)
|
||||||
|
directories.consumption_dir.chmod(0o000)
|
||||||
|
|
||||||
@override_settings(DEBUG=True)
|
try:
|
||||||
def test_debug_enabled(self) -> None:
|
msgs = paths_check(None)
|
||||||
self.assertEqual(len(debug_mode_check(None)), 1)
|
finally:
|
||||||
|
directories.data_dir.chmod(0o777)
|
||||||
|
directories.media_dir.chmod(0o777)
|
||||||
|
directories.consumption_dir.chmod(0o777)
|
||||||
|
|
||||||
|
assert len(msgs) == 3
|
||||||
|
for msg in msgs:
|
||||||
|
assert msg.msg.endswith("is not writeable")
|
||||||
|
|
||||||
|
def test_debug_disabled(self, settings: SettingsWrapper) -> None:
|
||||||
|
settings.DEBUG = False
|
||||||
|
assert debug_mode_check(None) == []
|
||||||
|
|
||||||
|
def test_debug_enabled(self, settings: SettingsWrapper) -> None:
|
||||||
|
settings.DEBUG = True
|
||||||
|
assert len(debug_mode_check(None)) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestSettingsChecksAgainstDefaults(DirectoriesMixin, TestCase):
|
class TestSettingsChecksAgainstDefaults:
|
||||||
def test_all_valid(self) -> None:
|
def test_all_valid(self) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -78,104 +107,71 @@ class TestSettingsChecksAgainstDefaults(DirectoriesMixin, TestCase):
|
|||||||
- No system check errors reported
|
- No system check errors reported
|
||||||
"""
|
"""
|
||||||
msgs = settings_values_check(None)
|
msgs = settings_values_check(None)
|
||||||
self.assertEqual(len(msgs), 0)
|
assert len(msgs) == 0
|
||||||
|
|
||||||
|
|
||||||
class TestOcrSettingsChecks(DirectoriesMixin, TestCase):
|
class TestOcrSettingsChecks:
|
||||||
@override_settings(OCR_OUTPUT_TYPE="notapdf")
|
@pytest.mark.parametrize(
|
||||||
def test_invalid_output_type(self) -> None:
|
("setting", "value", "expected_msg"),
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
"OCR_OUTPUT_TYPE",
|
||||||
|
"notapdf",
|
||||||
|
'OCR output type "notapdf"',
|
||||||
|
id="invalid-output-type",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"OCR_MODE",
|
||||||
|
"makeitso",
|
||||||
|
'OCR output mode "makeitso"',
|
||||||
|
id="invalid-mode",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"OCR_MODE",
|
||||||
|
"skip_noarchive",
|
||||||
|
"deprecated",
|
||||||
|
id="deprecated-mode",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"OCR_SKIP_ARCHIVE_FILE",
|
||||||
|
"invalid",
|
||||||
|
'OCR_SKIP_ARCHIVE_FILE setting "invalid"',
|
||||||
|
id="invalid-skip-archive-file",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"OCR_CLEAN",
|
||||||
|
"cleanme",
|
||||||
|
'OCR clean mode "cleanme"',
|
||||||
|
id="invalid-clean",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_invalid_setting_produces_one_error(
|
||||||
|
self,
|
||||||
|
settings: SettingsWrapper,
|
||||||
|
setting: str,
|
||||||
|
value: str,
|
||||||
|
expected_msg: str,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Default settings
|
- Default settings
|
||||||
- OCR output type is invalid
|
- One OCR setting is set to an invalid value
|
||||||
WHEN:
|
WHEN:
|
||||||
- Settings are validated
|
- Settings are validated
|
||||||
THEN:
|
THEN:
|
||||||
- system check error reported for OCR output type
|
- Exactly one system check error is reported containing the expected message
|
||||||
"""
|
"""
|
||||||
|
setattr(settings, setting, value)
|
||||||
|
|
||||||
msgs = settings_values_check(None)
|
msgs = settings_values_check(None)
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
assert len(msgs) == 1
|
||||||
|
assert expected_msg in msgs[0].msg
|
||||||
self.assertIn('OCR output type "notapdf"', msg.msg)
|
|
||||||
|
|
||||||
@override_settings(OCR_MODE="makeitso")
|
|
||||||
def test_invalid_ocr_type(self) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Default settings
|
|
||||||
- OCR type is invalid
|
|
||||||
WHEN:
|
|
||||||
- Settings are validated
|
|
||||||
THEN:
|
|
||||||
- system check error reported for OCR type
|
|
||||||
"""
|
|
||||||
msgs = settings_values_check(None)
|
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
|
||||||
|
|
||||||
self.assertIn('OCR output mode "makeitso"', msg.msg)
|
|
||||||
|
|
||||||
@override_settings(OCR_MODE="skip_noarchive")
|
|
||||||
def test_deprecated_ocr_type(self) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Default settings
|
|
||||||
- OCR type is deprecated
|
|
||||||
WHEN:
|
|
||||||
- Settings are validated
|
|
||||||
THEN:
|
|
||||||
- deprecation warning reported for OCR type
|
|
||||||
"""
|
|
||||||
msgs = settings_values_check(None)
|
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
|
||||||
|
|
||||||
self.assertIn("deprecated", msg.msg)
|
|
||||||
|
|
||||||
@override_settings(OCR_SKIP_ARCHIVE_FILE="invalid")
|
|
||||||
def test_invalid_ocr_skip_archive_file(self) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Default settings
|
|
||||||
- OCR_SKIP_ARCHIVE_FILE is invalid
|
|
||||||
WHEN:
|
|
||||||
- Settings are validated
|
|
||||||
THEN:
|
|
||||||
- system check error reported for OCR_SKIP_ARCHIVE_FILE
|
|
||||||
"""
|
|
||||||
msgs = settings_values_check(None)
|
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
|
||||||
|
|
||||||
self.assertIn('OCR_SKIP_ARCHIVE_FILE setting "invalid"', msg.msg)
|
|
||||||
|
|
||||||
@override_settings(OCR_CLEAN="cleanme")
|
|
||||||
def test_invalid_ocr_clean(self) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Default settings
|
|
||||||
- OCR cleaning type is invalid
|
|
||||||
WHEN:
|
|
||||||
- Settings are validated
|
|
||||||
THEN:
|
|
||||||
- system check error reported for OCR cleaning type
|
|
||||||
"""
|
|
||||||
msgs = settings_values_check(None)
|
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
|
||||||
|
|
||||||
self.assertIn('OCR clean mode "cleanme"', msg.msg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimezoneSettingsChecks(DirectoriesMixin, TestCase):
|
class TestTimezoneSettingsChecks:
|
||||||
@override_settings(TIME_ZONE="TheMoon\\MyCrater")
|
def test_invalid_timezone(self, settings: SettingsWrapper) -> None:
|
||||||
def test_invalid_timezone(self) -> None:
|
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Default settings
|
- Default settings
|
||||||
@@ -185,17 +181,16 @@ class TestTimezoneSettingsChecks(DirectoriesMixin, TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- system check error reported for timezone
|
- system check error reported for timezone
|
||||||
"""
|
"""
|
||||||
|
settings.TIME_ZONE = "TheMoon\\MyCrater"
|
||||||
|
|
||||||
msgs = settings_values_check(None)
|
msgs = settings_values_check(None)
|
||||||
self.assertEqual(len(msgs), 1)
|
|
||||||
|
|
||||||
msg = msgs[0]
|
assert len(msgs) == 1
|
||||||
|
assert 'Timezone "TheMoon\\MyCrater"' in msgs[0].msg
|
||||||
self.assertIn('Timezone "TheMoon\\MyCrater"', msg.msg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmailCertSettingsChecks(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
class TestEmailCertSettingsChecks:
|
||||||
@override_settings(EMAIL_CERTIFICATE_FILE=Path("/tmp/not_actually_here.pem"))
|
def test_not_valid_file(self, settings: SettingsWrapper) -> None:
|
||||||
def test_not_valid_file(self) -> None:
|
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Default settings
|
- Default settings
|
||||||
@@ -205,19 +200,22 @@ class TestEmailCertSettingsChecks(DirectoriesMixin, FileSystemAssertsMixin, Test
|
|||||||
THEN:
|
THEN:
|
||||||
- system check error reported for email certificate
|
- system check error reported for email certificate
|
||||||
"""
|
"""
|
||||||
self.assertIsNotFile("/tmp/not_actually_here.pem")
|
cert_path = Path("/tmp/not_actually_here.pem")
|
||||||
|
assert not cert_path.is_file()
|
||||||
|
settings.EMAIL_CERTIFICATE_FILE = cert_path
|
||||||
|
|
||||||
msgs = settings_values_check(None)
|
msgs = settings_values_check(None)
|
||||||
|
|
||||||
self.assertEqual(len(msgs), 1)
|
assert len(msgs) == 1
|
||||||
|
assert "Email cert /tmp/not_actually_here.pem is not a file" in msgs[0].msg
|
||||||
msg = msgs[0]
|
|
||||||
|
|
||||||
self.assertIn("Email cert /tmp/not_actually_here.pem is not a file", msg.msg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuditLogChecks(TestCase):
|
class TestAuditLogChecks:
|
||||||
def test_was_enabled_once(self) -> None:
|
def test_was_enabled_once(
|
||||||
|
self,
|
||||||
|
settings: SettingsWrapper,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Audit log is not enabled
|
- Audit log is not enabled
|
||||||
@@ -226,23 +224,18 @@ class TestAuditLogChecks(TestCase):
|
|||||||
THEN:
|
THEN:
|
||||||
- system check error reported for disabling audit log
|
- system check error reported for disabling audit log
|
||||||
"""
|
"""
|
||||||
introspect_mock = mock.MagicMock()
|
settings.AUDIT_LOG_ENABLED = False
|
||||||
|
introspect_mock = mocker.MagicMock()
|
||||||
introspect_mock.introspection.table_names.return_value = ["auditlog_logentry"]
|
introspect_mock.introspection.table_names.return_value = ["auditlog_logentry"]
|
||||||
with override_settings(AUDIT_LOG_ENABLED=False):
|
mocker.patch.dict(
|
||||||
with mock.patch.dict(
|
"paperless.checks.connections",
|
||||||
"paperless.checks.connections",
|
{"default": introspect_mock},
|
||||||
{"default": introspect_mock},
|
)
|
||||||
):
|
|
||||||
msgs = audit_log_check(None)
|
|
||||||
|
|
||||||
self.assertEqual(len(msgs), 1)
|
msgs = audit_log_check(None)
|
||||||
|
|
||||||
msg = msgs[0]
|
assert len(msgs) == 1
|
||||||
|
assert "auditlog table was found but audit log is disabled." in msgs[0].msg
|
||||||
self.assertIn(
|
|
||||||
("auditlog table was found but audit log is disabled."),
|
|
||||||
msg.msg,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEPRECATED_VARS: dict[str, str] = {
|
DEPRECATED_VARS: dict[str, str] = {
|
||||||
@@ -271,20 +264,16 @@ class TestDeprecatedDbSettings:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("env_var", "db_option_key"),
|
("env_var", "db_option_key"),
|
||||||
[
|
[
|
||||||
("PAPERLESS_DB_TIMEOUT", "timeout"),
|
pytest.param("PAPERLESS_DB_TIMEOUT", "timeout", id="db-timeout"),
|
||||||
("PAPERLESS_DB_POOLSIZE", "pool.min_size / pool.max_size"),
|
pytest.param(
|
||||||
("PAPERLESS_DBSSLMODE", "sslmode"),
|
"PAPERLESS_DB_POOLSIZE",
|
||||||
("PAPERLESS_DBSSLROOTCERT", "sslrootcert"),
|
"pool.min_size / pool.max_size",
|
||||||
("PAPERLESS_DBSSLCERT", "sslcert"),
|
id="db-poolsize",
|
||||||
("PAPERLESS_DBSSLKEY", "sslkey"),
|
),
|
||||||
],
|
pytest.param("PAPERLESS_DBSSLMODE", "sslmode", id="ssl-mode"),
|
||||||
ids=[
|
pytest.param("PAPERLESS_DBSSLROOTCERT", "sslrootcert", id="ssl-rootcert"),
|
||||||
"db-timeout",
|
pytest.param("PAPERLESS_DBSSLCERT", "sslcert", id="ssl-cert"),
|
||||||
"db-poolsize",
|
pytest.param("PAPERLESS_DBSSLKEY", "sslkey", id="ssl-key"),
|
||||||
"ssl-mode",
|
|
||||||
"ssl-rootcert",
|
|
||||||
"ssl-cert",
|
|
||||||
"ssl-key",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_single_deprecated_var_produces_one_warning(
|
def test_single_deprecated_var_produces_one_warning(
|
||||||
@@ -403,7 +392,10 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
"""Test suite for check_v3_minimum_upgrade_version system check."""
|
"""Test suite for check_v3_minimum_upgrade_version system check."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def build_conn_mock(self, mocker: MockerFixture):
|
def build_conn_mock(
|
||||||
|
self,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> Callable[[list[str], list[str]], mock.MagicMock]:
|
||||||
"""Factory fixture that builds a connections['default'] mock.
|
"""Factory fixture that builds a connections['default'] mock.
|
||||||
|
|
||||||
Usage::
|
Usage::
|
||||||
@@ -423,7 +415,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_no_migrations_table_fresh_install(
|
def test_no_migrations_table_fresh_install(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -442,7 +434,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_no_documents_migrations_fresh_install(
|
def test_no_documents_migrations_fresh_install(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -461,7 +453,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_v3_state_with_0001_squashed(
|
def test_v3_state_with_0001_squashed(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -485,7 +477,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_v3_state_with_0002_squashed_only(
|
def test_v3_state_with_0002_squashed_only(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -504,7 +496,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_v2_20_9_state_ready_to_upgrade(
|
def test_v2_20_9_state_ready_to_upgrade(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -531,7 +523,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_v2_20_8_raises_error(
|
def test_v2_20_8_raises_error(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -558,7 +550,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_very_old_version_raises_error(
|
def test_very_old_version_raises_error(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
@@ -585,7 +577,7 @@ class TestV3MinimumUpgradeVersionCheck:
|
|||||||
def test_error_hint_mentions_v2_20_9(
|
def test_error_hint_mentions_v2_20_9(
|
||||||
self,
|
self,
|
||||||
mocker: MockerFixture,
|
mocker: MockerFixture,
|
||||||
build_conn_mock,
|
build_conn_mock: Callable[[list[str], list[str]], mock.MagicMock],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
|
|||||||
@@ -9,35 +9,50 @@ from paperless.utils import ocr_to_dateparser_languages
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("ocr_language", "expected"),
|
("ocr_language", "expected"),
|
||||||
[
|
[
|
||||||
# One language
|
pytest.param("eng", ["en"], id="single-language"),
|
||||||
("eng", ["en"]),
|
pytest.param("fra+ita+lao", ["fr", "it", "lo"], id="multiple-languages"),
|
||||||
# Multiple languages
|
pytest.param("fil", ["fil"], id="no-two-letter-equivalent"),
|
||||||
("fra+ita+lao", ["fr", "it", "lo"]),
|
pytest.param(
|
||||||
# Languages that don't have a two-letter equivalent
|
"aze_cyrl+srp_latn",
|
||||||
("fil", ["fil"]),
|
["az-Cyrl", "sr-Latn"],
|
||||||
# Languages with a script part supported by dateparser
|
id="script-supported-by-dateparser",
|
||||||
("aze_cyrl+srp_latn", ["az-Cyrl", "sr-Latn"]),
|
),
|
||||||
# Languages with a script part not supported by dateparser
|
pytest.param(
|
||||||
# In this case, default to the language without script
|
"deu_frak",
|
||||||
("deu_frak", ["de"]),
|
["de"],
|
||||||
# Traditional and simplified chinese don't have the same name in dateparser,
|
id="script-not-supported-falls-back-to-language",
|
||||||
# so they're converted to the general chinese language
|
),
|
||||||
("chi_tra+chi_sim", ["zh"]),
|
pytest.param(
|
||||||
# If a language is not supported by dateparser, fallback to the supported ones
|
"chi_tra+chi_sim",
|
||||||
("eng+unsupported_language+por", ["en", "pt"]),
|
["zh"],
|
||||||
# If no language is supported, fallback to default
|
id="chinese-variants-collapse-to-general",
|
||||||
("unsupported1+unsupported2", []),
|
),
|
||||||
# Duplicate languages, should not duplicate in result
|
pytest.param(
|
||||||
("eng+eng", ["en"]),
|
"eng+unsupported_language+por",
|
||||||
# Language with script, but script is not mapped
|
["en", "pt"],
|
||||||
("ita_unknownscript", ["it"]),
|
id="unsupported-language-skipped",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"unsupported1+unsupported2",
|
||||||
|
[],
|
||||||
|
id="all-unsupported-returns-empty",
|
||||||
|
),
|
||||||
|
pytest.param("eng+eng", ["en"], id="duplicates-deduplicated"),
|
||||||
|
pytest.param(
|
||||||
|
"ita_unknownscript",
|
||||||
|
["it"],
|
||||||
|
id="unknown-script-falls-back-to-language",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_ocr_to_dateparser_languages(ocr_language, expected):
|
def test_ocr_to_dateparser_languages(ocr_language: str, expected: list[str]) -> None:
|
||||||
assert sorted(ocr_to_dateparser_languages(ocr_language)) == sorted(expected)
|
assert sorted(ocr_to_dateparser_languages(ocr_language)) == sorted(expected)
|
||||||
|
|
||||||
|
|
||||||
def test_ocr_to_dateparser_languages_exception(monkeypatch, caplog):
|
def test_ocr_to_dateparser_languages_exception(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
# Patch LocaleDataLoader.get_locale_map to raise an exception
|
# Patch LocaleDataLoader.get_locale_map to raise an exception
|
||||||
class DummyLoader:
|
class DummyLoader:
|
||||||
def get_locale_map(self, locales=None):
|
def get_locale_map(self, locales=None):
|
||||||
|
|||||||
@@ -1,24 +1,31 @@
|
|||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from django.test import override_settings
|
from django.test import Client
|
||||||
|
from pytest_django.fixtures import SettingsWrapper
|
||||||
|
|
||||||
|
|
||||||
def test_favicon_view(client):
|
def test_favicon_view(
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
client: Client,
|
||||||
static_dir = Path(tmpdir)
|
tmp_path: Path,
|
||||||
favicon_path = static_dir / "paperless" / "img" / "favicon.ico"
|
settings: SettingsWrapper,
|
||||||
favicon_path.parent.mkdir(parents=True, exist_ok=True)
|
) -> None:
|
||||||
favicon_path.write_bytes(b"FAKE ICON DATA")
|
favicon_path = tmp_path / "paperless" / "img" / "favicon.ico"
|
||||||
|
favicon_path.parent.mkdir(parents=True)
|
||||||
|
favicon_path.write_bytes(b"FAKE ICON DATA")
|
||||||
|
|
||||||
with override_settings(STATIC_ROOT=static_dir):
|
settings.STATIC_ROOT = tmp_path
|
||||||
response = client.get("/favicon.ico")
|
|
||||||
assert response.status_code == 200
|
response = client.get("/favicon.ico")
|
||||||
assert response["Content-Type"] == "image/x-icon"
|
assert response.status_code == 200
|
||||||
assert b"".join(response.streaming_content) == b"FAKE ICON DATA"
|
assert response["Content-Type"] == "image/x-icon"
|
||||||
|
assert b"".join(response.streaming_content) == b"FAKE ICON DATA"
|
||||||
|
|
||||||
|
|
||||||
def test_favicon_view_missing_file(client):
|
def test_favicon_view_missing_file(
|
||||||
with override_settings(STATIC_ROOT=Path(tempfile.mkdtemp())):
|
client: Client,
|
||||||
response = client.get("/favicon.ico")
|
tmp_path: Path,
|
||||||
assert response.status_code == 404
|
settings: SettingsWrapper,
|
||||||
|
) -> None:
|
||||||
|
settings.STATIC_ROOT = tmp_path
|
||||||
|
response = client.get("/favicon.ico")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from llama_index.core.bridge.pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class DocumentClassifierSchema(BaseModel):
|
class DocumentClassifierSchema(BaseModel):
|
||||||
|
|||||||
@@ -1,10 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from llama_index.core import VectorStoreIndex
|
|
||||||
from llama_index.core.prompts import PromptTemplate
|
|
||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless_ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
from paperless_ai.indexing import load_or_build_index
|
from paperless_ai.indexing import load_or_build_index
|
||||||
@@ -14,15 +10,13 @@ logger = logging.getLogger("paperless_ai.chat")
|
|||||||
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
||||||
SINGLE_DOC_SNIPPET_CHARS = 800
|
SINGLE_DOC_SNIPPET_CHARS = 800
|
||||||
|
|
||||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
CHAT_PROMPT_TMPL = """Context information is below.
|
||||||
template="""Context information is below.
|
|
||||||
---------------------
|
---------------------
|
||||||
{context_str}
|
{context_str}
|
||||||
---------------------
|
---------------------
|
||||||
Given the context information and not prior knowledge, answer the query.
|
Given the context information and not prior knowledge, answer the query.
|
||||||
Query: {query_str}
|
Query: {query_str}
|
||||||
Answer:""",
|
Answer:"""
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||||
@@ -43,6 +37,10 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
|||||||
yield "Sorry, I couldn't find any content to answer your question."
|
yield "Sorry, I couldn't find any content to answer your question."
|
||||||
return
|
return
|
||||||
|
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.prompts import PromptTemplate
|
||||||
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
|
|
||||||
local_index = VectorStoreIndex(nodes=nodes)
|
local_index = VectorStoreIndex(nodes=nodes)
|
||||||
retriever = local_index.as_retriever(
|
retriever = local_index.as_retriever(
|
||||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||||
@@ -85,7 +83,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
|||||||
for node in top_nodes
|
for node in top_nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = CHAT_PROMPT_TMPL.partial_format(
|
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||||
|
prompt = prompt_template.partial_format(
|
||||||
context_str=context,
|
context_str=context,
|
||||||
query_str=query_str,
|
query_str=query_str,
|
||||||
).format(llm=client.llm)
|
).format(llm=client.llm)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from llama_index.core.llms import ChatMessage
|
if TYPE_CHECKING:
|
||||||
from llama_index.core.program.function_program import get_function_tool
|
from llama_index.core.llms import ChatMessage
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
from llama_index.llms.openai import OpenAI
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
from paperless_ai.base_model import DocumentClassifierSchema
|
from paperless_ai.base_model import DocumentClassifierSchema
|
||||||
@@ -20,14 +21,18 @@ class AIClient:
|
|||||||
self.settings = AIConfig()
|
self.settings = AIConfig()
|
||||||
self.llm = self.get_llm()
|
self.llm = self.get_llm()
|
||||||
|
|
||||||
def get_llm(self) -> Ollama | OpenAI:
|
def get_llm(self) -> "Ollama | OpenAI":
|
||||||
if self.settings.llm_backend == "ollama":
|
if self.settings.llm_backend == "ollama":
|
||||||
|
from llama_index.llms.ollama import Ollama
|
||||||
|
|
||||||
return Ollama(
|
return Ollama(
|
||||||
model=self.settings.llm_model or "llama3.1",
|
model=self.settings.llm_model or "llama3.1",
|
||||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
||||||
request_timeout=120,
|
request_timeout=120,
|
||||||
)
|
)
|
||||||
elif self.settings.llm_backend == "openai":
|
elif self.settings.llm_backend == "openai":
|
||||||
|
from llama_index.llms.openai import OpenAI
|
||||||
|
|
||||||
return OpenAI(
|
return OpenAI(
|
||||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||||
api_base=self.settings.llm_endpoint or None,
|
api_base=self.settings.llm_endpoint or None,
|
||||||
@@ -43,6 +48,9 @@ class AIClient:
|
|||||||
self.settings.llm_model,
|
self.settings.llm_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_index.core.llms import ChatMessage
|
||||||
|
from llama_index.core.program.function_program import get_function_tool
|
||||||
|
|
||||||
user_msg = ChatMessage(role="user", content=prompt)
|
user_msg = ChatMessage(role="user", content=prompt)
|
||||||
tool = get_function_tool(DocumentClassifierSchema)
|
tool = get_function_tool(DocumentClassifierSchema)
|
||||||
result = self.llm.chat_with_tools(
|
result = self.llm.chat_with_tools(
|
||||||
@@ -58,7 +66,7 @@ class AIClient:
|
|||||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||||
return parsed.model_dump()
|
return parsed.model_dump()
|
||||||
|
|
||||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
def run_chat(self, messages: list["ChatMessage"]) -> str:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Running chat query against %s with model %s",
|
"Running chat query against %s with model %s",
|
||||||
self.settings.llm_backend,
|
self.settings.llm_backend,
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from django.conf import settings
|
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import Note
|
from documents.models import Note
|
||||||
@@ -15,17 +14,21 @@ from paperless.config import AIConfig
|
|||||||
from paperless.models import LLMEmbeddingBackend
|
from paperless.models import LLMEmbeddingBackend
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model() -> BaseEmbedding:
|
def get_embedding_model() -> "BaseEmbedding":
|
||||||
config = AIConfig()
|
config = AIConfig()
|
||||||
|
|
||||||
match config.llm_embedding_backend:
|
match config.llm_embedding_backend:
|
||||||
case LLMEmbeddingBackend.OPENAI:
|
case LLMEmbeddingBackend.OPENAI:
|
||||||
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
|
||||||
return OpenAIEmbedding(
|
return OpenAIEmbedding(
|
||||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||||
api_key=config.llm_api_key,
|
api_key=config.llm_api_key,
|
||||||
api_base=config.llm_endpoint or None,
|
api_base=config.llm_endpoint or None,
|
||||||
)
|
)
|
||||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
|
||||||
return HuggingFaceEmbedding(
|
return HuggingFaceEmbedding(
|
||||||
model_name=config.llm_embedding_model
|
model_name=config.llm_embedding_model
|
||||||
or "sentence-transformers/all-MiniLM-L6-v2",
|
or "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
|||||||
@@ -4,26 +4,12 @@ from collections.abc import Callable
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import faiss
|
|
||||||
import llama_index.core.settings as llama_settings
|
|
||||||
from celery import states
|
from celery import states
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from llama_index.core import Document as LlamaDocument
|
|
||||||
from llama_index.core import StorageContext
|
|
||||||
from llama_index.core import VectorStoreIndex
|
|
||||||
from llama_index.core import load_index_from_storage
|
|
||||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
|
||||||
from llama_index.core.node_parser import SimpleNodeParser
|
|
||||||
from llama_index.core.prompts import PromptTemplate
|
|
||||||
from llama_index.core.retrievers import VectorIndexRetriever
|
|
||||||
from llama_index.core.schema import BaseNode
|
|
||||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
|
||||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
|
||||||
from llama_index.core.text_splitter import TokenTextSplitter
|
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import PaperlessTask
|
from documents.models import PaperlessTask
|
||||||
@@ -34,6 +20,10 @@ from paperless_ai.embedding import get_embedding_model
|
|||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
|
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core.schema import BaseNode
|
||||||
|
|
||||||
|
|
||||||
def _identity(iterable: Iterable[_T]) -> Iterable[_T]:
|
def _identity(iterable: Iterable[_T]) -> Iterable[_T]:
|
||||||
return iterable
|
return iterable
|
||||||
@@ -75,12 +65,23 @@ def get_or_create_storage_context(*, rebuild=False):
|
|||||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||||
|
import faiss
|
||||||
|
from llama_index.core import StorageContext
|
||||||
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
|
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
embedding_dim = get_embedding_dim()
|
embedding_dim = get_embedding_dim()
|
||||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
docstore = SimpleDocumentStore()
|
docstore = SimpleDocumentStore()
|
||||||
index_store = SimpleIndexStore()
|
index_store = SimpleIndexStore()
|
||||||
else:
|
else:
|
||||||
|
from llama_index.core import StorageContext
|
||||||
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
|
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
@@ -93,7 +94,7 @@ def get_or_create_storage_context(*, rebuild=False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_document_node(document: Document) -> list[BaseNode]:
|
def build_document_node(document: Document) -> list["BaseNode"]:
|
||||||
"""
|
"""
|
||||||
Given a Document, returns parsed Nodes ready for indexing.
|
Given a Document, returns parsed Nodes ready for indexing.
|
||||||
"""
|
"""
|
||||||
@@ -112,6 +113,9 @@ def build_document_node(document: Document) -> list[BaseNode]:
|
|||||||
"added": document.added.isoformat() if document.added else None,
|
"added": document.added.isoformat() if document.added else None,
|
||||||
"modified": document.modified.isoformat(),
|
"modified": document.modified.isoformat(),
|
||||||
}
|
}
|
||||||
|
from llama_index.core import Document as LlamaDocument
|
||||||
|
from llama_index.core.node_parser import SimpleNodeParser
|
||||||
|
|
||||||
doc = LlamaDocument(text=text, metadata=metadata)
|
doc = LlamaDocument(text=text, metadata=metadata)
|
||||||
parser = SimpleNodeParser()
|
parser = SimpleNodeParser()
|
||||||
return parser.get_nodes_from_documents([doc])
|
return parser.get_nodes_from_documents([doc])
|
||||||
@@ -122,6 +126,10 @@ def load_or_build_index(nodes=None):
|
|||||||
Load an existing VectorStoreIndex if present,
|
Load an existing VectorStoreIndex if present,
|
||||||
or build a new one using provided nodes if storage is empty.
|
or build a new one using provided nodes if storage is empty.
|
||||||
"""
|
"""
|
||||||
|
import llama_index.core.settings as llama_settings
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
embed_model = get_embedding_model()
|
embed_model = get_embedding_model()
|
||||||
llama_settings.Settings.embed_model = embed_model
|
llama_settings.Settings.embed_model = embed_model
|
||||||
storage_context = get_or_create_storage_context()
|
storage_context = get_or_create_storage_context()
|
||||||
@@ -143,7 +151,7 @@ def load_or_build_index(nodes=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
|
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
|
||||||
"""
|
"""
|
||||||
Removes existing documents from docstore for a given document from the index.
|
Removes existing documents from docstore for a given document from the index.
|
||||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||||
@@ -174,6 +182,8 @@ def update_llm_index(
|
|||||||
"""
|
"""
|
||||||
Rebuild or update the LLM index.
|
Rebuild or update the LLM index.
|
||||||
"""
|
"""
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
|
|
||||||
documents = Document.objects.all()
|
documents = Document.objects.all()
|
||||||
@@ -187,6 +197,8 @@ def update_llm_index(
|
|||||||
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||||
# Rebuild index from scratch
|
# Rebuild index from scratch
|
||||||
logger.info("Rebuilding LLM index.")
|
logger.info("Rebuilding LLM index.")
|
||||||
|
import llama_index.core.settings as llama_settings
|
||||||
|
|
||||||
embed_model = get_embedding_model()
|
embed_model = get_embedding_model()
|
||||||
llama_settings.Settings.embed_model = embed_model
|
llama_settings.Settings.embed_model = embed_model
|
||||||
storage_context = get_or_create_storage_context(rebuild=True)
|
storage_context = get_or_create_storage_context(rebuild=True)
|
||||||
@@ -271,6 +283,10 @@ def llm_index_remove_document(document: Document):
|
|||||||
|
|
||||||
|
|
||||||
def truncate_content(content: str) -> str:
|
def truncate_content(content: str) -> str:
|
||||||
|
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||||
|
from llama_index.core.prompts import PromptTemplate
|
||||||
|
from llama_index.core.text_splitter import TokenTextSplitter
|
||||||
|
|
||||||
prompt_helper = PromptHelper(
|
prompt_helper = PromptHelper(
|
||||||
context_window=8192,
|
context_window=8192,
|
||||||
num_output=512,
|
num_output=512,
|
||||||
@@ -315,6 +331,8 @@ def query_similar_documents(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_index.core.retrievers import VectorIndexRetriever
|
||||||
|
|
||||||
retriever = VectorIndexRetriever(
|
retriever = VectorIndexRetriever(
|
||||||
index=index,
|
index=index,
|
||||||
similarity_top_k=top_k,
|
similarity_top_k=top_k,
|
||||||
|
|||||||
@@ -181,11 +181,11 @@ def test_load_or_build_index_builds_when_nodes_given(
|
|||||||
) -> None:
|
) -> None:
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.load_index_from_storage",
|
"llama_index.core.load_index_from_storage",
|
||||||
side_effect=ValueError("Index not found"),
|
side_effect=ValueError("Index not found"),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.VectorStoreIndex",
|
"llama_index.core.VectorStoreIndex",
|
||||||
return_value=MagicMock(),
|
return_value=MagicMock(),
|
||||||
) as mock_index_cls,
|
) as mock_index_cls,
|
||||||
patch(
|
patch(
|
||||||
@@ -206,7 +206,7 @@ def test_load_or_build_index_raises_exception_when_no_nodes(
|
|||||||
) -> None:
|
) -> None:
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.load_index_from_storage",
|
"llama_index.core.load_index_from_storage",
|
||||||
side_effect=ValueError("Index not found"),
|
side_effect=ValueError("Index not found"),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -225,11 +225,11 @@ def test_load_or_build_index_succeeds_when_nodes_given(
|
|||||||
) -> None:
|
) -> None:
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.load_index_from_storage",
|
"llama_index.core.load_index_from_storage",
|
||||||
side_effect=ValueError("Index not found"),
|
side_effect=ValueError("Index not found"),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.VectorStoreIndex",
|
"llama_index.core.VectorStoreIndex",
|
||||||
return_value=MagicMock(),
|
return_value=MagicMock(),
|
||||||
) as mock_index_cls,
|
) as mock_index_cls,
|
||||||
patch(
|
patch(
|
||||||
@@ -334,7 +334,7 @@ def test_query_similar_documents(
|
|||||||
patch(
|
patch(
|
||||||
"paperless_ai.indexing.vector_store_file_exists",
|
"paperless_ai.indexing.vector_store_file_exists",
|
||||||
) as mock_vector_store_exists,
|
) as mock_vector_store_exists,
|
||||||
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||||
):
|
):
|
||||||
mock_storage.return_value = MagicMock()
|
mock_storage.return_value = MagicMock()
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
|||||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||||
) as mock_query_engine_cls,
|
) as mock_query_engine_cls,
|
||||||
):
|
):
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
@@ -76,7 +76,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
|||||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||||
patch(
|
patch(
|
||||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||||
) as mock_query_engine_cls,
|
) as mock_query_engine_cls,
|
||||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -18,13 +18,13 @@ def mock_ai_config():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_ollama_llm():
|
def mock_ollama_llm():
|
||||||
with patch("paperless_ai.client.Ollama") as MockOllama:
|
with patch("llama_index.llms.ollama.Ollama") as MockOllama:
|
||||||
yield MockOllama
|
yield MockOllama
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_openai_llm():
|
def mock_openai_llm():
|
||||||
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
|
with patch("llama_index.llms.openai.OpenAI") as MockOpenAI:
|
||||||
yield MockOpenAI
|
yield MockOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def test_get_embedding_model_openai(mock_ai_config):
|
|||||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||||
|
|
||||||
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
with patch("llama_index.embeddings.openai.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||||
model = get_embedding_model()
|
model = get_embedding_model()
|
||||||
MockOpenAIEmbedding.assert_called_once_with(
|
MockOpenAIEmbedding.assert_called_once_with(
|
||||||
model="text-embedding-3-small",
|
model="text-embedding-3-small",
|
||||||
@@ -84,7 +84,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"paperless_ai.embedding.HuggingFaceEmbedding",
|
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
|
||||||
) as MockHuggingFaceEmbedding:
|
) as MockHuggingFaceEmbedding:
|
||||||
model = get_embedding_model()
|
model = get_embedding_model()
|
||||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||||
|
|||||||
Reference in New Issue
Block a user