diff --git a/pyproject.toml b/pyproject.toml index 3d00f4e67..36a02528d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -305,6 +305,7 @@ markers = [ "greenmail: Tests requiring Greenmail service", "date_parsing: Tests which cover date parsing from content or filename", "management: Tests which cover management commands/functionality", + "profiling: Benchmarks that profile and compare implementation performance", ] [tool.pytest_env] diff --git a/src/documents/management/commands/document_exporter.py b/src/documents/management/commands/document_exporter.py index bd962efc4..52cadb4fe 100644 --- a/src/documents/management/commands/document_exporter.py +++ b/src/documents/management/commands/document_exporter.py @@ -3,6 +3,8 @@ import json import os import shutil import tempfile +from itertools import chain +from itertools import islice from pathlib import Path from typing import TYPE_CHECKING @@ -19,6 +21,7 @@ from django.contrib.contenttypes.models import ContentType from django.core import serializers from django.core.management.base import BaseCommand from django.core.management.base import CommandError +from django.core.serializers.json import DjangoJSONEncoder from django.db import transaction from django.utils import timezone from filelock import FileLock @@ -26,6 +29,8 @@ from guardian.models import GroupObjectPermission from guardian.models import UserObjectPermission if TYPE_CHECKING: + from collections.abc import Generator + from django.db.models import QuerySet if settings.AUDIT_LOG_ENABLED: @@ -60,6 +65,22 @@ from paperless_mail.models import MailAccount from paperless_mail.models import MailRule +def serialize_queryset_batched( + queryset: "QuerySet", + *, + batch_size: int = 500, +) -> "Generator[list[dict], None, None]": + """Yield batches of serialized records from a QuerySet. + + Each batch is a list of dicts in Django's Python serialization format. + Uses QuerySet.iterator() to avoid loading the full queryset into memory, + and islice to collect chunk-sized batches serialized in a single call. + """ + iterator = queryset.iterator(chunk_size=batch_size) + while chunk := list(islice(iterator, batch_size)): + yield serializers.serialize("python", chunk) + + class Command(CryptMixin, BaseCommand): help = ( "Decrypt and rename all files in our collection into a given target " @@ -186,6 +207,17 @@ class Command(CryptMixin, BaseCommand): help="If provided, is used to encrypt sensitive data in the export", ) + parser.add_argument( + "--batch-size", + type=int, + default=500, + help=( + "Number of records to process per batch during serialization. " + "Lower values reduce peak memory usage; higher values improve " + "throughput. Default: 500." + ), + ) + def handle(self, *args, **options) -> None: self.target = Path(options["target"]).resolve() self.split_manifest: bool = options["split_manifest"] @@ -200,6 +232,7 @@ class Command(CryptMixin, BaseCommand): self.data_only: bool = options["data_only"] self.no_progress_bar: bool = options["no_progress_bar"] self.passphrase: str | None = options.get("passphrase") + self.batch_size: int = options["batch_size"] self.files_in_export_dir: set[Path] = set() self.exported_files: set[str] = set() @@ -294,8 +327,13 @@ class Command(CryptMixin, BaseCommand): # Build an overall manifest for key, object_query in manifest_key_to_object_query.items(): - manifest_dict[key] = json.loads( - serializers.serialize("json", object_query), + manifest_dict[key] = list( + chain.from_iterable( + serialize_queryset_batched( + object_query, + batch_size=self.batch_size, + ), + ), ) self.encrypt_secret_fields(manifest_dict) @@ -512,14 +550,24 @@ class Command(CryptMixin, BaseCommand): self.files_in_export_dir.remove(target) if self.compare_json: target_checksum = hashlib.md5(target.read_bytes()).hexdigest() - src_str = json.dumps(content, indent=2, ensure_ascii=False) + src_str = json.dumps( + content, + cls=DjangoJSONEncoder, + indent=2, + ensure_ascii=False, + ) src_checksum = hashlib.md5(src_str.encode("utf-8")).hexdigest() if src_checksum == target_checksum: perform_write = False if perform_write: target.write_text( - json.dumps(content, indent=2, ensure_ascii=False), + json.dumps( + content, + cls=DjangoJSONEncoder, + indent=2, + ensure_ascii=False, + ), encoding="utf-8", ) diff --git a/src/documents/profiling.py b/src/documents/profiling.py new file mode 100644 index 000000000..aca0913e4 --- /dev/null +++ b/src/documents/profiling.py @@ -0,0 +1,71 @@ +""" +Temporary profiling utilities for comparing implementations. + +Usage in a management command or shell:: + + from documents.profiling import profile_block + + with profile_block("new check_sanity"): + messages = check_sanity() + + with profile_block("old check_sanity"): + messages = check_sanity_old() + +Drop this file when done. +""" + +from __future__ import annotations + +import tracemalloc +from contextlib import contextmanager +from time import perf_counter +from typing import TYPE_CHECKING + +from django.db import connection +from django.db import reset_queries +from django.test.utils import override_settings + +if TYPE_CHECKING: + from collections.abc import Generator + + +@contextmanager +def profile_block(label: str = "block") -> Generator[None, None, None]: + """Profile memory, wall time, and DB queries for a code block. + + Prints a summary to stdout on exit. Requires no external packages. + Enables DEBUG temporarily to capture Django's query log. + """ + tracemalloc.start() + snapshot_before = tracemalloc.take_snapshot() + + with override_settings(DEBUG=True): + reset_queries() + start = perf_counter() + + yield + + elapsed = perf_counter() - start + queries = list(connection.queries) + + snapshot_after = tracemalloc.take_snapshot() + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Compare snapshots for top allocations + stats = snapshot_after.compare_to(snapshot_before, "lineno") + + query_time = sum(float(q["time"]) for q in queries) + mem_diff = sum(s.size_diff for s in stats) + + print(f"\n{'=' * 60}") # noqa: T201 + print(f" Profile: {label}") # noqa: T201 + print(f"{'=' * 60}") # noqa: T201 + print(f" Wall time: {elapsed:.4f}s") # noqa: T201 + print(f" Queries: {len(queries)} ({query_time:.4f}s)") # noqa: T201 + print(f" Memory delta: {mem_diff / 1024:.1f} KiB") # noqa: T201 + print(f" Peak memory: {peak / 1024:.1f} KiB") # noqa: T201 + print("\n Top 5 allocations:") # noqa: T201 + for stat in stats[:5]: + print(f" {stat}") # noqa: T201 + print(f"{'=' * 60}\n") # noqa: T201