mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-05-05 14:15:24 +00:00
Feature: Further reduce document importer memory usage (#12707)
* Replaces loaddata with streaming bulk_create
Replaces call_command('loaddata') with a streaming implementation that
reads manifest records one at a time via ijson, accumulates per-model
batches up to --batch-size, and flushes via bulk_create. This reduces
peak memory and no longer scales directly with the size of the import.
* fix(importer): avoid guardian lru_cache poisoning; include M2M through tables in check_constraints
clear_cache() inside the import transaction emptied Django's ContentType
manager cache while fixture PKs were live, causing downstream ContentType
lookups to repopulate guardian's separate @lru_cache(None) with
fixture-PK objects. After the TestCase transaction rolled back to
original PKs, guardian's lru_cache held stale fixture ContentType
objects, causing MixedContentTypeError in unrelated subsequent tests.
Remove clear_cache() since it was defending against a theoretical
stale-cache scenario that doesn't occur in a proper same-install restore.
Fix check_constraints() to explicitly include auto-created M2M through
tables (populated by .set() after bulk_create) alongside the model tables,
addressing the gap where join-table FK violations would have gone
undetected.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Excludes the consumer and AnonymousUser users from any models which might have a FK relation to them. This prevents orphaned records, such as UI settings, which would otherwise reference a user that does not exist in the export
* Splits into more sub functions for Sonar
* Improvements to the typing of the new functions
* Adds coverage for some error cases and removes handling for PK-only models, which do not need to be supported
* Final coverage gaps
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -357,12 +357,13 @@ and the script does the rest of the work:
|
||||
document_importer source
|
||||
```
|
||||
|
||||
| Option | Required | Default | Description |
|
||||
| ------------------- | -------- | ------- | ------------------------------------------------------------------------- |
|
||||
| source | Yes | N/A | The directory containing an export |
|
||||
| `--no-progress-bar` | No | False | If provided, the progress bar will be hidden |
|
||||
| `--data-only` | No | False | If provided, only import data, do not import document files or thumbnails |
|
||||
| `--passphrase` | No | N/A | If your export was encrypted with a passphrase, must be provided |
|
||||
| Option | Required | Default | Description |
|
||||
| ------------------- | -------- | ------- | ------------------------------------------------------------------------------------------------------------ |
|
||||
| source | Yes | N/A | The directory containing an export |
|
||||
| `--no-progress-bar` | No | False | If provided, the progress bar will be hidden |
|
||||
| `--data-only` | No | False | If provided, only import data, do not import document files or thumbnails |
|
||||
| `--passphrase` | No | N/A | If your export was encrypted with a passphrase, must be provided |
|
||||
| `--batch-size` | No | 500 | Number of database records inserted per batch. Lower values reduce peak memory usage on very large installs. |
|
||||
|
||||
When you use the provided docker compose script, put the export inside
|
||||
the `export` folder in your paperless source directory. Specify
|
||||
|
||||
@@ -365,6 +365,7 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
|
||||
# 2. Create manifest, containing all correspondents, types, tags, storage paths
|
||||
# note, documents and ui_settings
|
||||
_excluded_usernames = ["consumer", "AnonymousUser"]
|
||||
manifest_key_to_object_query: dict[str, QuerySet[Any]] = {
|
||||
"correspondents": Correspondent.objects.all(),
|
||||
"tags": Tag.objects.all(),
|
||||
@@ -376,12 +377,16 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
"saved_view_filter_rules": SavedViewFilterRule.objects.all(),
|
||||
"groups": Group.objects.all(),
|
||||
"users": User.objects.exclude(
|
||||
username__in=["consumer", "AnonymousUser"],
|
||||
username__in=_excluded_usernames,
|
||||
).all(),
|
||||
"ui_settings": UiSettings.objects.all(),
|
||||
"ui_settings": UiSettings.objects.exclude(
|
||||
user__username__in=_excluded_usernames,
|
||||
),
|
||||
"content_types": ContentType.objects.all(),
|
||||
"permissions": Permission.objects.all(),
|
||||
"user_object_permissions": UserObjectPermission.objects.all(),
|
||||
"user_object_permissions": UserObjectPermission.objects.exclude(
|
||||
user__username__in=_excluded_usernames,
|
||||
),
|
||||
"group_object_permissions": GroupObjectPermission.objects.all(),
|
||||
"workflow_triggers": WorkflowTrigger.objects.all(),
|
||||
"workflow_actions": WorkflowAction.objects.all(),
|
||||
@@ -395,10 +400,16 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
"documents": Document.global_objects.order_by("id").all(),
|
||||
"share_links": ShareLink.global_objects.all(),
|
||||
"share_link_bundles": ShareLinkBundle.objects.order_by("id").all(),
|
||||
"social_accounts": SocialAccount.objects.all(),
|
||||
"social_accounts": SocialAccount.objects.exclude(
|
||||
user__username__in=_excluded_usernames,
|
||||
),
|
||||
"social_apps": SocialApp.objects.all(),
|
||||
"social_tokens": SocialToken.objects.all(),
|
||||
"authenticators": Authenticator.objects.all(),
|
||||
"social_tokens": SocialToken.objects.exclude(
|
||||
account__user__username__in=_excluded_usernames,
|
||||
),
|
||||
"authenticators": Authenticator.objects.exclude(
|
||||
user__username__in=_excluded_usernames,
|
||||
),
|
||||
}
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
|
||||
@@ -2,13 +2,16 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import TypeAlias
|
||||
from zipfile import ZipFile
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import ijson
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.auth.models import User
|
||||
@@ -16,9 +19,14 @@ from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.core.management import call_command
|
||||
from django.core.management.base import CommandError
|
||||
from django.core.management.color import no_style
|
||||
from django.core.serializers.base import DeserializationError
|
||||
from django.db import IntegrityError
|
||||
from django.db import connection
|
||||
from django.db import models as django_models
|
||||
from django.db import transaction
|
||||
from django.db.models import GeneratedField
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import m2m_changed
|
||||
from django.db.models.signals import post_save
|
||||
from filelock import FileLock
|
||||
@@ -47,6 +55,9 @@ from paperless import version
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.registry import auditlog
|
||||
|
||||
# Maps M2M field names to the list of related PKs to apply after bulk_create.
|
||||
M2MData: TypeAlias = dict[str, list[int]]
|
||||
|
||||
|
||||
def iter_manifest_records(path: Path) -> Generator[dict, None, None]:
|
||||
"""Yield records one at a time from a manifest JSON array via ijson."""
|
||||
@@ -57,6 +68,107 @@ def iter_manifest_records(path: Path) -> Generator[dict, None, None]:
|
||||
raise CommandError(f"Failed to parse manifest file {path}: {e}") from e
|
||||
|
||||
|
||||
def _deserialize_record(
    record: dict,
) -> tuple[type[Model], Model, M2MData]:
    """
    Convert a single manifest record dict into a model instance and M2M data.

    Returns (Model class, unsaved instance, m2m_data) where m2m_data maps
    M2M field names to lists of integer PKs to be applied after the instance
    is saved via bulk_create.

    Raises DeserializationError for a missing or unknown model identifier and
    for any pk, FK, M2M or scalar value that cannot be coerced, so every
    bad-data failure reaches the caller as the same exception type.
    Raises FieldDoesNotExist for fields not present on the model.

    Note: CommandError from iter_manifest_records (malformed JSON mid-stream)
    propagates through the caller unchanged, it is not caught here.
    """
    # .get() so a record without a "model" key is reported below as a
    # DeserializationError instead of escaping as a bare KeyError.
    model_label = record.get("model")
    pk_value = record.get("pk")

    try:
        Model = apps.get_model(model_label)
    except (LookupError, TypeError, ValueError, AttributeError) as e:
        # ValueError / AttributeError cover labels that are not a single
        # "app_label.model_name" string (wrong number of dots, or not a str).
        raise DeserializationError(
            f"Invalid model identifier: {model_label}",
        ) from e

    def coerce(target_field, field_name: str, raw_value):
        """Run target_field.to_python, re-raising any failure uniformly."""
        try:
            return target_field.to_python(raw_value)
        except Exception as e:
            raise DeserializationError(
                f"Could not coerce {field_name}={raw_value!r} "
                f"for {model_label}(pk={pk_value}): {e}",
            ) from e

    data: dict = {}
    m2m_data: M2MData = {}

    try:
        data[Model._meta.pk.attname] = Model._meta.pk.to_python(pk_value)
    except Exception as e:
        raise DeserializationError(
            f"Could not coerce pk={pk_value} for {model_label}: {e}",
        ) from e

    # "fields" may be absent or an explicit null; both mean "no fields".
    for field_name, field_value in (record.get("fields") or {}).items():
        field = Model._meta.get_field(field_name)
        remote = field.remote_field

        if isinstance(remote, django_models.ManyToManyRel):
            # Collect M2M PKs; .set() is called after bulk_create in flush_model.
            try:
                related_values = list(field_value or [])
            except TypeError as e:
                # A non-iterable value (e.g. a bare int) is bad manifest data.
                raise DeserializationError(
                    f"Could not coerce {field_name}={field_value!r} "
                    f"for {model_label}(pk={pk_value}): {e}",
                ) from e
            target_pk = field.related_model._meta.pk
            m2m_data[field.name] = [
                coerce(target_pk, field_name, related_pk)
                for related_pk in related_values
            ]

        elif isinstance(remote, django_models.ManyToOneRel):
            # FK: store the integer PK on field.attname (e.g. correspondent_id)
            # to avoid triggering the descriptor and avoid an extra DB lookup.
            if field_value is None:
                data[field.attname] = None
            else:
                data[field.attname] = coerce(
                    field.related_model._meta.pk,
                    field_name,
                    field_value,
                )

        else:
            data[field.name] = coerce(field, field_name, field_value)

    return Model, Model(**data), m2m_data
|
||||
|
||||
|
||||
def _iter_document_copy_records(
    manifest_paths: list[Path],
) -> Generator[dict, None, None]:
    """
    Lazily yield the file-copy details of every Document record.

    Each manifest is streamed in turn; only the pk and the exported
    original/thumbnail/archive names are kept, so no list of records is
    ever built up in memory.
    """
    for path in manifest_paths:
        for entry in iter_manifest_records(path):
            # Only Document rows carry files that need copying.
            if entry["model"] != "documents.document":
                continue
            copy_info = {"pk": entry["pk"]}
            # The original file name is mandatory; thumbnail and archive are optional.
            copy_info[EXPORTER_FILE_NAME] = entry[EXPORTER_FILE_NAME]
            copy_info[EXPORTER_THUMBNAIL_NAME] = entry.get(EXPORTER_THUMBNAIL_NAME)
            copy_info[EXPORTER_ARCHIVE_NAME] = entry.get(EXPORTER_ARCHIVE_NAME)
            yield copy_info
|
||||
|
||||
|
||||
def _iter_share_link_bundle_copy_records(
    manifest_paths: list[Path],
) -> Generator[dict, None, None]:
    """
    Lazily yield pk and bundle file name for ShareLinkBundle records.

    Records of other models, and bundles whose exported file name is missing
    or empty, are skipped.
    """
    for path in manifest_paths:
        for entry in iter_manifest_records(path):
            if entry["model"] != "documents.sharelinkbundle":
                continue
            bundle_name = entry.get(EXPORTER_SHARE_LINK_BUNDLE_NAME)
            if not bundle_name:
                # Nothing was exported for this bundle, so nothing to copy.
                continue
            yield {
                "pk": entry["pk"],
                EXPORTER_SHARE_LINK_BUNDLE_NAME: bundle_name,
            }
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_signal(sig, receiver, sender, *, weak: bool | None = None) -> Generator:
|
||||
try:
|
||||
@@ -92,6 +204,14 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
help="If provided, is used to sensitive fields in the export",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of records to insert per batch during database load. "
|
||||
"Lower values reduce peak memory usage.",
|
||||
)
|
||||
|
||||
def pre_check(self) -> None:
|
||||
"""
|
||||
Runs some initial checks against the state of the install and source, including:
|
||||
@@ -197,36 +317,117 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
),
|
||||
)
|
||||
|
||||
def _finalize_db_load(self, loaded_models: set[type[Model]]) -> None:
    """
    Check constraints on every loaded table, then reset PK sequences.

    The constraint check covers the tables of all loaded models plus the
    auto-created M2M through tables of those models. Nothing is done when
    no model was loaded.
    """
    if not loaded_models:
        return

    table_names = [model._meta.db_table for model in loaded_models]

    # Auto-created join tables are not models in loaded_models, so they are
    # gathered separately (de-duplicated via a set) and appended.
    through_tables: set[str] = set()
    for model in loaded_models:
        for m2m_field in model._meta.many_to_many:
            through = m2m_field.remote_field.through
            if through is not None and through._meta.auto_created:
                through_tables.add(through._meta.db_table)
    table_names.extend(through_tables)

    connection.check_constraints(table_names=table_names)

    reset_statements = connection.ops.sequence_reset_sql(
        no_style(),
        list(loaded_models),
    )
    with connection.cursor() as cursor:
        for statement in reset_statements:
            cursor.execute(statement)  # pragma: no cover
|
||||
|
||||
def _import_error_context_message(self) -> str:
    """
    Build the hint printed after a failed database import.

    Reports a version mismatch when the manifest version is known and differs
    from the running version. In every other case (version unknown, or equal
    to the running one) the "No version information present" text is returned.
    """
    running_version = version.__full_version_str__
    if self.version is None or self.version == running_version:
        return "No version information present"
    return (  # pragma: no cover
        "Version mismatch: "
        f"Currently {running_version},"
        f" importing {self.version}"
    )
|
||||
|
||||
def load_data_to_database(self) -> None:
|
||||
"""
|
||||
As the name implies, loads data from the JSON file(s) into the database
|
||||
Streams records from each manifest path and loads them into the database
|
||||
using bulk_create with bounded batch sizes, avoiding holding the entire
|
||||
manifest in memory at once.
|
||||
|
||||
Memory bound: at most batch_size * (number of distinct model types
|
||||
present simultaneously in the manifest) instances at any time.
|
||||
For the standard non-split manifest, records are grouped by model, so
|
||||
in practice only one model's batch accumulates at a time.
|
||||
"""
|
||||
# Maps model class -> list of (instance, m2m_data) waiting to be flushed
|
||||
pending: defaultdict[type[Model], list[tuple[Model, M2MData]]] = defaultdict(
|
||||
list,
|
||||
)
|
||||
# All model classes inserted (needed for sequence reset after the load)
|
||||
loaded_models: set[type[Model]] = set()
|
||||
|
||||
def flush_model(model: type[Model]) -> None:
|
||||
"""bulk_create the pending batch for model, then apply M2M."""
|
||||
batch = pending.pop(model, [])
|
||||
if not batch: # pragma: no cover
|
||||
return
|
||||
instances = [inst for inst, _ in batch]
|
||||
# GeneratedField is excluded because it is generated and trying to insert it will fail
|
||||
update_fields = [
|
||||
f.attname
|
||||
for f in model._meta.concrete_fields
|
||||
if not f.primary_key and not isinstance(f, GeneratedField)
|
||||
]
|
||||
if not update_fields: # pragma: no cover
|
||||
raise DeserializationError(
|
||||
f"{model.__name__} has no updatable fields; PK-only models are not supported by the importer",
|
||||
)
|
||||
model.objects.bulk_create( # type: ignore[attr-defined]
|
||||
instances,
|
||||
update_conflicts=True,
|
||||
unique_fields=[model._meta.pk.attname],
|
||||
update_fields=update_fields,
|
||||
)
|
||||
loaded_models.add(model)
|
||||
for instance, m2m_data in batch:
|
||||
for field_name, pk_list in m2m_data.items():
|
||||
getattr(instance, field_name).set(pk_list)
|
||||
|
||||
def flush_all() -> None:
|
||||
for model in list(pending):
|
||||
flush_model(model)
|
||||
|
||||
try:
|
||||
with transaction.atomic():
|
||||
# delete these since pk can change, re-created from import
|
||||
# ContentType and Permission have auto-assigned PKs on a fresh
|
||||
# install that conflict with exported PKs. Delete and re-import.
|
||||
ContentType.objects.all().delete()
|
||||
Permission.objects.all().delete()
|
||||
for manifest_path in self.manifest_paths:
|
||||
call_command("loaddata", manifest_path, skip_checks=True)
|
||||
except (FieldDoesNotExist, DeserializationError, IntegrityError) as e:
|
||||
|
||||
# Constraint checks are disabled so FK/M2M inserts succeed
|
||||
# regardless of record order within the manifest.
|
||||
# Note: on SQLite inside a transaction this context manager is a
|
||||
# no-op; the constraint-deferral path is only exercised on
|
||||
# PostgreSQL in production.
|
||||
with connection.constraint_checks_disabled():
|
||||
for manifest_path in self.manifest_paths:
|
||||
for record in iter_manifest_records(manifest_path):
|
||||
model, instance, m2m_data = _deserialize_record(record)
|
||||
pending[model].append((instance, m2m_data))
|
||||
if len(pending[model]) >= self.batch_size:
|
||||
flush_model(model)
|
||||
|
||||
flush_all()
|
||||
|
||||
self._finalize_db_load(loaded_models)
|
||||
|
||||
except (FieldDoesNotExist, DeserializationError, IntegrityError):
|
||||
self.stdout.write(self.style.ERROR("Database import failed"))
|
||||
if (
|
||||
self.version is not None
|
||||
and self.version != version.__full_version_str__
|
||||
): # pragma: no cover
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
"Version mismatch: "
|
||||
f"Currently {version.__full_version_str__},"
|
||||
f" importing {self.version}",
|
||||
),
|
||||
)
|
||||
raise e
|
||||
else:
|
||||
self.stdout.write(
|
||||
self.style.ERROR("No version information present"),
|
||||
)
|
||||
raise e
|
||||
self.stdout.write(self.style.ERROR(self._import_error_context_message()))
|
||||
raise
|
||||
|
||||
def handle(self, *args, **options) -> None:
|
||||
logging.getLogger().handlers[0].level = logging.ERROR
|
||||
@@ -234,6 +435,7 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
self.source = Path(options["source"]).resolve()
|
||||
self.data_only: bool = options["data_only"]
|
||||
self.passphrase: str | None = options.get("passphrase")
|
||||
self.batch_size: int = options["batch_size"]
|
||||
self.version: str | None = None
|
||||
self.salt: str | None = None
|
||||
self.manifest_paths = []
|
||||
@@ -389,31 +591,10 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
|
||||
self.stdout.write("Copy files into paperless...")
|
||||
|
||||
document_records = [
|
||||
{
|
||||
"pk": record["pk"],
|
||||
EXPORTER_FILE_NAME: record[EXPORTER_FILE_NAME],
|
||||
EXPORTER_THUMBNAIL_NAME: record.get(EXPORTER_THUMBNAIL_NAME),
|
||||
EXPORTER_ARCHIVE_NAME: record.get(EXPORTER_ARCHIVE_NAME),
|
||||
}
|
||||
for manifest_path in self.manifest_paths
|
||||
for record in iter_manifest_records(manifest_path)
|
||||
if record["model"] == "documents.document"
|
||||
]
|
||||
share_link_bundle_records = [
|
||||
{
|
||||
"pk": record["pk"],
|
||||
EXPORTER_SHARE_LINK_BUNDLE_NAME: record.get(
|
||||
EXPORTER_SHARE_LINK_BUNDLE_NAME,
|
||||
),
|
||||
}
|
||||
for manifest_path in self.manifest_paths
|
||||
for record in iter_manifest_records(manifest_path)
|
||||
if record["model"] == "documents.sharelinkbundle"
|
||||
and record.get(EXPORTER_SHARE_LINK_BUNDLE_NAME)
|
||||
]
|
||||
|
||||
for record in self.track(document_records, description="Copying files..."):
|
||||
for record in self.track(
|
||||
_iter_document_copy_records(self.manifest_paths),
|
||||
description="Copying files...",
|
||||
):
|
||||
document = Document.global_objects.get(pk=record["pk"])
|
||||
|
||||
doc_file = record[EXPORTER_FILE_NAME]
|
||||
@@ -452,10 +633,8 @@ class Command(CryptMixin, PaperlessCommand):
|
||||
# archived files
|
||||
copy_file_with_basic_stats(archive_path, document.archive_path)
|
||||
|
||||
document.save()
|
||||
|
||||
for record in self.track(
|
||||
share_link_bundle_records,
|
||||
_iter_share_link_bundle_copy_records(self.manifest_paths),
|
||||
description="Copying share link bundles...",
|
||||
):
|
||||
bundle = ShareLinkBundle.objects.get(pk=record["pk"])
|
||||
|
||||
@@ -11,6 +11,7 @@ from django.core.management.base import CommandError
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.management.commands.document_importer import Command
|
||||
from documents.management.commands.document_importer import _deserialize_record
|
||||
from documents.models import Document
|
||||
from documents.settings import EXPORTER_ARCHIVE_NAME
|
||||
from documents.settings import EXPORTER_FILE_NAME
|
||||
@@ -397,3 +398,307 @@ class TestCommandImport(
|
||||
|
||||
# There should be no error or warnings. Therefore the output should be empty.
|
||||
self.assertEqual(stdout_str, "")
|
||||
|
||||
def test_batch_size_argument_accepted(self) -> None:
    """
    GIVEN:
        - A source directory whose manifest is an empty JSON array
    WHEN:
        - The importer is invoked with --batch-size 100
    THEN:
        - Argument parsing succeeds (no SystemExit from the parser)
    """
    (self.dirs.scratch_dir / "manifest.json").write_text("[]")

    importer_args = (
        "--no-progress-bar",
        "--batch-size",
        "100",
        str(self.dirs.scratch_dir),
    )
    try:
        call_command("document_importer", *importer_args, skip_checks=True)
    except CommandError:
        # Tolerated: the command may reject the empty source; only a parser
        # failure (SystemExit) counts against --batch-size.
        pass
    except SystemExit as e:
        self.fail(f"--batch-size raised SystemExit (unrecognized argument?): {e}")
|
||||
|
||||
def test_m2m_relations_restored_after_data_only_import(self) -> None:
    """
    GIVEN:
        - A manifest holding a Tag (pk=100) followed by a Document (pk=100)
          whose fields list tags: [100]
    WHEN:
        - A data-only import is run
    THEN:
        - The imported document has exactly one tag
        - That tag carries the name given in the manifest
    """
    manifest_records = [
        {
            "model": "documents.tag",
            "pk": 100,
            "fields": {"name": "imported-tag"},
        },
        {
            "model": "documents.document",
            "pk": 100,
            "fields": {
                "title": "Tagged Doc",
                "content": "test content",
                "checksum": "1093cf6e32adbd16b06969df09215d42c4a3a8938cc18b39455953f08d1ff2ab",
                "filename": "0001000.pdf",
                "mime_type": "application/pdf",
                "modified": "2024-01-01T00:00:00Z",
                "added": "2024-01-01T00:00:00Z",
                "tags": [100],
                "correspondent": None,
                "document_type": None,
                "storage_path": None,
            },
        },
    ]
    manifest_path = self.dirs.scratch_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest_records))

    importer_args = (
        "--no-progress-bar",
        "--data-only",
        str(self.dirs.scratch_dir),
    )
    call_command("document_importer", *importer_args, skip_checks=True)

    imported_doc = Document.objects.get(pk=100)
    self.assertEqual(imported_doc.tags.count(), 1)
    self.assertEqual(imported_doc.tags.first().name, "imported-tag")
|
||||
|
||||
def test_mid_batch_flush_triggered_by_small_batch_size(self) -> None:
    """
    GIVEN:
        - A manifest with two records (a Tag, then a Document using it)
        - --batch-size 1, so every record fills its batch on arrival
    WHEN:
        - A data-only import is run
    THEN:
        - The size-triggered flush inside the record loop is exercised and
          the import still ends with the document linked to its tag
    """
    manifest_records = [
        {
            "model": "documents.tag",
            "pk": 200,
            "fields": {"name": "batch-flush-tag"},
        },
        {
            "model": "documents.document",
            "pk": 200,
            "fields": {
                "title": "Batch Flush Doc",
                "content": "test",
                "checksum": "2093cf6e32adbd16b06969df09215d42c4a3a8938cc18b39455953f08d1ff2ab",
                "filename": "0002000.pdf",
                "mime_type": "application/pdf",
                "modified": "2024-01-01T00:00:00Z",
                "added": "2024-01-01T00:00:00Z",
                "tags": [200],
                "correspondent": None,
                "document_type": None,
                "storage_path": None,
            },
        },
    ]
    manifest_path = self.dirs.scratch_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest_records))

    importer_args = (
        "--no-progress-bar",
        "--data-only",
        "--batch-size",
        "1",
        str(self.dirs.scratch_dir),
    )
    call_command("document_importer", *importer_args, skip_checks=True)

    imported_doc = Document.objects.get(pk=200)
    self.assertEqual(imported_doc.tags.count(), 1)
    self.assertEqual(imported_doc.tags.first().name, "batch-flush-tag")
|
||||
|
||||
|
||||
@pytest.mark.management
@pytest.mark.django_db
class TestDeserializeRecord:
    """Tests calling _deserialize_record directly on hand-built manifest records."""

    def test_simple_model_no_relations(self) -> None:
        """
        GIVEN:
            - A Correspondent manifest record (the model has no M2M fields)
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - The Correspondent class is returned together with an instance
              carrying the record's pk and name, and m2m_data is empty
        """
        fields = {
            "name": "ACME Corp",
            "match": "",
            "matching_algorithm": 1,
            "is_insensitive": False,
            "owner": None,
        }
        payload = {"model": "documents.correspondent", "pk": 42, "fields": fields}

        model_cls, obj, m2m = _deserialize_record(payload)

        assert model_cls.__name__ == "Correspondent"
        assert obj.pk == 42
        assert obj.name == "ACME Corp"
        assert m2m == {}

    def test_fk_field_stored_on_attname(self) -> None:
        """
        GIVEN:
            - A Document manifest record whose correspondent FK is 42
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - The raw FK value is available as correspondent_id on the instance
        """
        fields = {
            "title": "Test Doc",
            "correspondent": 42,
            "content": "",
            "checksum": "abc123abc123abc123abc123abc123ab",
            "filename": "0000001.pdf",
            "mime_type": "application/pdf",
        }
        payload = {"model": "documents.document", "pk": 1, "fields": fields}

        obj = _deserialize_record(payload)[1]

        assert obj.correspondent_id == 42

    def test_m2m_field_collected_in_m2m_data(self) -> None:
        """
        GIVEN:
            - A Document manifest record listing three tag PKs
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - m2m_data["tags"] holds those PKs in the same order
        """
        fields = {
            "title": "Test",
            "tags": [1, 3, 7],
            "content": "",
            "checksum": "abc123abc123abc123abc123abc123ab",
            "filename": "0000001.pdf",
            "mime_type": "application/pdf",
        }
        payload = {"model": "documents.document", "pk": 1, "fields": fields}

        m2m = _deserialize_record(payload)[2]

        assert m2m["tags"] == [1, 3, 7]

    def test_null_fk_stored_as_none(self) -> None:
        """
        GIVEN:
            - A Document manifest record whose nullable correspondent FK is null
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - correspondent_id on the instance is exactly None
        """
        fields = {
            "title": "Test",
            "correspondent": None,
            "content": "",
            "checksum": "def456def456def456def456def456de",
            "filename": "0000002.pdf",
            "mime_type": "application/pdf",
        }
        payload = {"model": "documents.document", "pk": 2, "fields": fields}

        obj = _deserialize_record(payload)[1]

        assert obj.correspondent_id is None

    def test_unknown_model_raises_deserialization_error(self) -> None:
        """
        GIVEN:
            - A manifest record naming a model that is not registered
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - DeserializationError is raised
        """
        from django.core.serializers.base import DeserializationError

        payload = {"model": "documents.doesnotexist", "pk": 1, "fields": {}}

        with pytest.raises(DeserializationError):
            _deserialize_record(payload)

    def test_invalid_pk_raises_deserialization_error(self) -> None:
        """
        GIVEN:
            - A manifest record whose pk cannot be coerced to the pk field type
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - DeserializationError is raised and its message names the bad pk
        """
        from django.core.serializers.base import DeserializationError

        payload = {
            "model": "documents.correspondent",
            "pk": "not-an-int",
            "fields": {},
        }
        expected_message = "Could not coerce pk=not-an-int"

        with pytest.raises(DeserializationError, match=expected_message):
            _deserialize_record(payload)

    def test_invalid_scalar_field_value_raises_deserialization_error(self) -> None:
        """
        GIVEN:
            - A manifest record with a scalar field value that cannot be coerced
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - DeserializationError is raised and its message names the field
        """
        from django.core.serializers.base import DeserializationError

        payload = {
            "model": "documents.correspondent",
            "pk": 1,
            "fields": {"matching_algorithm": "not-an-int"},
        }
        expected_message = "Could not coerce matching_algorithm="

        with pytest.raises(DeserializationError, match=expected_message):
            _deserialize_record(payload)

    def test_unknown_field_name_raises_field_does_not_exist(self) -> None:
        """
        GIVEN:
            - A manifest record using a field name the model does not define
        WHEN:
            - The record is passed to _deserialize_record
        THEN:
            - FieldDoesNotExist is raised
        """
        from django.core.exceptions import FieldDoesNotExist

        payload = {
            "model": "documents.correspondent",
            "pk": 1,
            "fields": {"no_such_field_on_correspondent": "value"},
        }

        with pytest.raises(FieldDoesNotExist):
            _deserialize_record(payload)
|
||||
|
||||
Reference in New Issue
Block a user