From 60e4715a0075c817f1dcd1f71bca93af3bd33995 Mon Sep 17 00:00:00 2001 From: Trenton Holmes <797416+stumpylog@users.noreply.github.com> Date: Sat, 13 Jun 2026 06:08:29 -0700 Subject: [PATCH] Store more profiling files --- src/other/bench_vector_store.py | 365 +++++++++++++++++++++++++++++ src/other/test_backend_profile.py | 346 +++++++++++++++++++++++++++ src/other/test_search_profiling.py | 273 +++++++++++++++++++++ 3 files changed, 984 insertions(+) create mode 100644 src/other/bench_vector_store.py create mode 100644 src/other/test_backend_profile.py create mode 100644 src/other/test_search_profiling.py diff --git a/src/other/bench_vector_store.py b/src/other/bench_vector_store.py new file mode 100644 index 000000000..1a9bd24c1 --- /dev/null +++ b/src/other/bench_vector_store.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +"""Head-to-head benchmark: PaperlessLanceVectorStore vs PaperlessSqliteVecVectorStore. + +Run from src/ with: + uv run python bench_vector_store.py [OPTIONS] + +Phase 1 (skipped if bench_data.pkl already exists): generate fake documents with +Faker and embed chunks via Ollama; save to disk for reuse. +Phase 2: benchmark both stores against identical data and print a comparison table. + +Requires both classes to coexist in paperless_ai.vector_store (Task 3 Phase A). +After Phase B replaces the file, the Lance import fails gracefully and only the +sqlite-vec half runs. +""" + +from __future__ import annotations + +import argparse +import pickle +import statistics +import tempfile +import time +import uuid +from pathlib import Path + +import httpx +from faker import Faker +from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import FilterOperator +from llama_index.core.vector_stores.types import MetadataFilter +from llama_index.core.vector_stores.types import MetadataFilters +from llama_index.core.vector_stores.types import VectorStoreQuery + +try: + from paperless_ai.vector_store import PaperlessLanceVectorStore + + _LANCE_OK = True +except ImportError: + _LANCE_OK = False + +from paperless_ai.vector_store import PaperlessSqliteVecVectorStore + +DEFAULT_OLLAMA_URL = "http://192.168.1.87:11434" +DEFAULT_EMBED_MODEL = "qwen3-embedding:4b" +DEFAULT_DATA_FILE = "bench_data.pkl" +DEFAULT_N_DOCS = 2000 +DEFAULT_CHUNKS_PER_DOC = 3 +DEFAULT_QUERY_ITERS = 50 +_BATCH = 32 + + +def _embed(texts: list[str], url: str, model: str) -> list[list[float]]: + r = httpx.post( + f"{url}/api/embed", + json={"model": model, "input": texts}, + timeout=120.0, + ) + r.raise_for_status() + return r.json()["embeddings"] + + +def warm_up(url: str, model: str) -> int: + """Fire one embed call to load the model into GPU; return embedding dim.""" + print(f"Warming up {model}...", end=" ", flush=True) + dim = len(_embed(["warm"], url, model)[0]) + print(f"dim={dim}") + return dim + + +def generate_and_save( + n_docs: int, + chunks_per_doc: int, + url: str, + model: str, + out: str, +) -> list[dict]: + fake = Faker() + Faker.seed(42) + print(f"Generating {n_docs} docs ({chunks_per_doc} chunks each)...") + docs = [] + for i in range(n_docs): + body = "\n\n".join(fake.paragraph(nb_sentences=8) for _ in range(3)) + clen = max(1, len(body) // chunks_per_doc) + chunks = [] + for j in range(chunks_per_doc): + s = j * clen + e = s + clen if j < chunks_per_doc - 1 else len(body) + chunks.append( + {"node_id": str(uuid.uuid4()), "text": body[s:e], "embedding": None}, + ) + docs.append( + { + "doc_id": str(i + 1), + "title": fake.catch_phrase(), + "modified": fake.date_time_this_decade().isoformat(), + "chunks": chunks, + }, + ) + + all_texts = [c["text"] for d in docs for c in d["chunks"]] + print(f"Embedding {len(all_texts)} chunks in batches of {_BATCH}...") + embeddings: list[list[float]] = [] + for i in range(0, len(all_texts), _BATCH): + embeddings.extend(_embed(all_texts[i : i + _BATCH], url, model)) + print( + f" {min(i + _BATCH, len(all_texts))}/{len(all_texts)}", + end="\r", + flush=True, + ) + print() + + idx = 0 + for d in docs: + for c in d["chunks"]: + c["embedding"] = embeddings[idx] + idx += 1 + + with open(out, "wb") as f: + pickle.dump(docs, f) + print(f"Saved to {out}") + return docs + + +def _build_nodes(docs: list[dict]) -> list[TextNode]: + nodes = [] + for d in docs: + for c in d["chunks"]: + n = TextNode( + id_=c["node_id"], + text=c["text"], + metadata={"document_id": d["doc_id"], "modified": d["modified"]}, + ) + n.relationships = {} + n.embedding = c["embedding"] + nodes.append(n) + return nodes + + +def _in_filter(ids: list[str]) -> MetadataFilters: + return MetadataFilters( + filters=[ + MetadataFilter(key="document_id", operator=FilterOperator.IN, value=ids), + ], + ) + + +def _dir_bytes(path: str) -> int: + return sum(f.stat().st_size for f in Path(path).rglob("*") if f.is_file()) + + +def _sqlite_bytes(uri: str) -> int: + p = Path(uri) / "llmindex.db" + return p.stat().st_size if p.exists() else 0 + + +def run_bench( + store, + nodes: list[TextNode], + docs: list[dict], + q_iters: int, + is_lance: bool, +) -> dict: + doc_ids = [d["doc_id"] for d in docs] + filter_ids = doc_ids[: max(1, len(doc_ids) // 5)] + q_vecs = [nodes[i * 10 % len(nodes)].embedding for i in range(q_iters)] + by_doc: dict[str, list[TextNode]] = {} + for n in nodes: + by_doc.setdefault(n.metadata["document_id"], []).append(n) + uri = store._uri + + # insert + t0 = time.perf_counter() + store.add(list(nodes)) + r: dict = {"insert": time.perf_counter() - t0} + + # query plain + times = [] + for emb in q_vecs: + t0 = time.perf_counter() + store.query(VectorStoreQuery(query_embedding=emb, similarity_top_k=10)) + times.append(time.perf_counter() - t0) + r["qp50"] = statistics.median(times) + r["qp95"] = sorted(times)[int(len(times) * 0.95)] + + # query filtered + times = [] + flt = _in_filter(filter_ids) + for emb in q_vecs: + t0 = time.perf_counter() + store.query( + VectorStoreQuery(query_embedding=emb, similarity_top_k=10, filters=flt), + ) + times.append(time.perf_counter() - t0) + r["qfp50"] = statistics.median(times) + r["qfp95"] = sorted(times)[int(len(times) * 0.95)] + + # get_modified_times + times = [] + for _ in range(20): + t0 = time.perf_counter() + store.get_modified_times() + times.append(time.perf_counter() - t0) + r["gmt_p50"] = statistics.median(times) + + # upsert (fresh node IDs, same embeddings) + times = [] + for doc in docs[:q_iters]: + orig = by_doc.get(doc["doc_id"], []) + if not orig: + continue + fresh = [] + for o in orig: + fn = TextNode( + id_=str(uuid.uuid4()), + text=o.text, + metadata=o.metadata.copy(), + ) + fn.relationships = {} + fn.embedding = o.embedding + fresh.append(fn) + t0 = time.perf_counter() + store.upsert_document(doc["doc_id"], fresh) + times.append(time.perf_counter() - t0) + r["up50"] = statistics.median(times) if times else 0.0 + r["up95"] = sorted(times)[int(len(times) * 0.95)] if times else 0.0 + + r["size_pre"] = _dir_bytes(uri) if is_lance else _sqlite_bytes(uri) + + # compact + t0 = time.perf_counter() + if is_lance: + store.compact(retention_seconds=0) + else: + store.compact(force=True) + r["compact"] = time.perf_counter() - t0 + + r["size_post"] = _dir_bytes(uri) if is_lance else _sqlite_bytes(uri) + return r + + +def _pct(lv: float | None, sv: float) -> str: + if lv is None or lv == 0: + return "N/A" + p = (sv - lv) / lv * 100 + return f"{'+' if p > 0 else ''}{p:.0f}%" + + +def print_results( + nodes: list[TextNode], + q_iters: int, + lance: dict | None, + sq: dict, +) -> None: + W = 30 + n, dim = len(nodes), len(nodes[0].embedding) + print("\n=== Vector Store Benchmark ===") + print(f"Nodes: {n} | Dim: {dim} | Query iters: {q_iters}\n") + lh = "LanceDB" if lance else "LanceDB (N/A)" + print(f"{'Operation':<{W}} {lh:<22} {'sqlite-vec':<22} {'Delta'}") + print("-" * (W + 66)) + + def _s(v: float) -> str: + return f"{v:.3f}s" + + def _ms(v: float) -> str: + return f"{v * 1000:.1f}ms" + + def _mb(v: float) -> str: + return f"{v / 1e6:.1f} MB" + + def row(label: str, lv: float | None, sv: float, fmt) -> None: + ls = fmt(lv) if lv is not None else "N/A" + print(f"{label:<{W}} {ls:<22} {fmt(sv):<22} {_pct(lv, sv)}") + + def row2( + label: str, + lv1: float | None, + lv2: float | None, + sv1: float, + sv2: float, + ) -> None: + def ms_pair(a: float, b: float) -> str: + return f"{_ms(a)} / {_ms(b)}" + + ls = ms_pair(lv1, lv2) if lv1 is not None else "N/A" + print(f"{label:<{W}} {ls:<22} {ms_pair(sv1, sv2):<22} {_pct(lv1, sv1)}") + + L = lance + row(f"insert ({n} nodes)", L["insert"] if L else None, sq["insert"], _s) + row2( + "query plain p50/p95", + L["qp50"] if L else None, + L["qp95"] if L else None, + sq["qp50"], + sq["qp95"], + ) + row2( + "query filtered p50/p95", + L["qfp50"] if L else None, + L["qfp95"] if L else None, + sq["qfp50"], + sq["qfp95"], + ) + row("get_modified_times p50", L["gmt_p50"] if L else None, sq["gmt_p50"], _ms) + row2( + "upsert p50/p95", + L["up50"] if L else None, + L["up95"] if L else None, + sq["up50"], + sq["up95"], + ) + row("compact", L["compact"] if L else None, sq["compact"], _s) + row("file size pre-compact", L["size_pre"] if L else None, sq["size_pre"], _mb) + row("file size post-compact", L["size_post"] if L else None, sq["size_post"], _mb) + + +def main() -> None: + ap = argparse.ArgumentParser(description="Vector store head-to-head benchmark") + ap.add_argument("--n-docs", type=int, default=DEFAULT_N_DOCS) + ap.add_argument("--chunks-per-doc", type=int, default=DEFAULT_CHUNKS_PER_DOC) + ap.add_argument("--data-file", default=DEFAULT_DATA_FILE) + ap.add_argument("--regenerate", action="store_true") + ap.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL) + ap.add_argument("--embed-model", default=DEFAULT_EMBED_MODEL) + ap.add_argument("--query-iters", type=int, default=DEFAULT_QUERY_ITERS) + args = ap.parse_args() + + warm_up(args.ollama_url, args.embed_model) + + data_path = Path(args.data_file) + if args.regenerate or not data_path.exists(): + docs = generate_and_save( + args.n_docs, + args.chunks_per_doc, + args.ollama_url, + args.embed_model, + args.data_file, + ) + else: + print(f"Loading {args.data_file}...") + with open(data_path, "rb") as f: + docs = pickle.load(f) + print(f"Loaded {len(docs)} docs ({sum(len(d['chunks']) for d in docs)} nodes)") + + all_nodes = _build_nodes(docs) + + lance_r = None + if _LANCE_OK: + print("\nBenchmarking LanceDB...") + with tempfile.TemporaryDirectory() as d: + store = PaperlessLanceVectorStore(uri=d) + lance_r = run_bench(store, all_nodes, docs, args.query_iters, is_lance=True) + else: + print("Skipping LanceDB (PaperlessLanceVectorStore not importable).") + + print("\nBenchmarking sqlite-vec...") + with tempfile.TemporaryDirectory() as d: + store = PaperlessSqliteVecVectorStore(uri=d) + sqlite_r = run_bench(store, all_nodes, docs, args.query_iters, is_lance=False) + + print_results(all_nodes, args.query_iters, lance_r, sqlite_r) + + +if __name__ == "__main__": + main() diff --git a/src/other/test_backend_profile.py b/src/other/test_backend_profile.py new file mode 100644 index 000000000..91b725c5a --- /dev/null +++ b/src/other/test_backend_profile.py @@ -0,0 +1,346 @@ +# ruff: noqa: T201 +""" +cProfile-based search pipeline profiling with a 20k-document dataset. + +Run with: + uv run pytest ../test_backend_profile.py \ + -m profiling --override-ini="addopts=" -s -v + +Each scenario prints: + - Wall time for the operation + - cProfile stats sorted by cumulative time (top 25 callers) + +This is a developer tool, not a correctness test. Nothing here should +fail unless the code is broken. +""" + +from __future__ import annotations + +import random +import time +from typing import TYPE_CHECKING + +import pytest + +from documents.models import Document +from documents.search._backend import TantivyBackend +from documents.search._backend import reset_backend +from profiling import profile_cpu + +if TYPE_CHECKING: + from pathlib import Path + +# transaction=False (default): tests roll back, but the module-scoped fixture +# commits its data outside the test transaction so it remains visible throughout. +pytestmark = [pytest.mark.profiling, pytest.mark.django_db] + +# --------------------------------------------------------------------------- +# Dataset constants +# --------------------------------------------------------------------------- +NUM_DOCS = 20_000 +SEED = 42 + +# Terms and their approximate match rates across the corpus. +# "rechnung" -> ~70% of docs (~14 000) +# "mahnung" -> ~20% of docs (~4 000) +# "kontonummer" -> ~5% of docs (~1 000) +# "rarewort" -> ~1% of docs (~200) +COMMON_TERM = "rechnung" +MEDIUM_TERM = "mahnung" +RARE_TERM = "kontonummer" +VERY_RARE_TERM = "rarewort" + +PAGE_SIZE = 25 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FILLER_WORDS = [ + "dokument", # codespell:ignore + "seite", + "datum", + "betrag", + "nummer", + "konto", + "firma", + "vertrag", + "lieferant", + "bestellung", + "steuer", + "mwst", + "leistung", + "auftrag", + "zahlung", +] + + +def _build_content(rng: random.Random) -> str: + """Return a short paragraph with terms embedded at the desired rates.""" + words = rng.choices(_FILLER_WORDS, k=15) + if rng.random() < 0.70: + words.append(COMMON_TERM) + if rng.random() < 0.20: + words.append(MEDIUM_TERM) + if rng.random() < 0.05: + words.append(RARE_TERM) + if rng.random() < 0.01: + words.append(VERY_RARE_TERM) + rng.shuffle(words) + return " ".join(words) + + +def _time(fn, *, label: str, runs: int = 3): + """Run *fn()* several times and report min/avg/max wall time (no cProfile).""" + times = [] + result = None + for _ in range(runs): + t0 = time.perf_counter() + result = fn() + times.append(time.perf_counter() - t0) + mn, avg, mx = min(times), sum(times) / len(times), max(times) + print( + f" {label}: min={mn * 1000:.1f}ms avg={avg * 1000:.1f}ms max={mx * 1000:.1f}ms (n={runs})", + ) + return result + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def module_db(django_db_setup, django_db_blocker): + """Unlock the DB for the whole module (module-scoped).""" + with django_db_blocker.unblock(): + yield + + +@pytest.fixture(scope="module") +def large_backend(tmp_path_factory, module_db) -> TantivyBackend: + """ + Build a 20 000-document DB + on-disk Tantivy index, shared across all + profiling scenarios in this module. Teardown deletes all documents. + """ + index_path: Path = tmp_path_factory.mktemp("tantivy_profile") + + # ---- 1. Bulk-create Document rows ---------------------------------------- + rng = random.Random(SEED) + docs = [ + Document( + title=f"Document {i:05d}", + content=_build_content(rng), + checksum=f"{i:064x}", + pk=i + 1, + ) + for i in range(NUM_DOCS) + ] + t0 = time.perf_counter() + Document.objects.bulk_create(docs, batch_size=1_000) + db_time = time.perf_counter() - t0 + print(f"\n[setup] bulk_create {NUM_DOCS} docs: {db_time:.2f}s") + + # ---- 2. Build Tantivy index ----------------------------------------------- + backend = TantivyBackend(path=index_path) + backend.open() + + t0 = time.perf_counter() + with backend.batch_update() as batch: + for doc in Document.objects.iterator(chunk_size=500): + batch.add_or_update(doc) + idx_time = time.perf_counter() - t0 + print(f"[setup] index {NUM_DOCS} docs: {idx_time:.2f}s") + + # ---- 3. Report corpus stats ----------------------------------------------- + for term in (COMMON_TERM, MEDIUM_TERM, RARE_TERM, VERY_RARE_TERM): + count = len(backend.search_ids(term, user=None)) + print(f"[setup] '{term}' -> {count} hits") + + yield backend + + # ---- Teardown ------------------------------------------------------------ + backend.close() + reset_backend() + Document.objects.all().delete() + + +# --------------------------------------------------------------------------- +# Profiling tests — each scenario is a separate function so pytest can run +# them individually or all together with -m profiling. +# --------------------------------------------------------------------------- + + +class TestSearchIdsProfile: + """Profile backend.search_ids() — pure Tantivy, no DB.""" + + def test_search_ids_large(self, large_backend: TantivyBackend): + """~14 000 hits: how long does Tantivy take to collect all IDs?""" + profile_cpu( + lambda: large_backend.search_ids(COMMON_TERM, user=None), + label=f"search_ids('{COMMON_TERM}') [large result set ~14k]", + ) + + def test_search_ids_medium(self, large_backend: TantivyBackend): + """~4 000 hits.""" + profile_cpu( + lambda: large_backend.search_ids(MEDIUM_TERM, user=None), + label=f"search_ids('{MEDIUM_TERM}') [medium result set ~4k]", + ) + + def test_search_ids_rare(self, large_backend: TantivyBackend): + """~1 000 hits.""" + profile_cpu( + lambda: large_backend.search_ids(RARE_TERM, user=None), + label=f"search_ids('{RARE_TERM}') [rare result set ~1k]", + ) + + +class TestIntersectAndOrderProfile: + """ + Profile the DB intersection step: filter(pk__in=search_ids). + This is the 'intersect_and_order' logic from views.py. + """ + + def test_intersect_large(self, large_backend: TantivyBackend): + """Intersect 14k Tantivy IDs with all 20k ORM-visible docs.""" + all_ids = large_backend.search_ids(COMMON_TERM, user=None) + qs = Document.objects.all() + + print(f"\n Tantivy returned {len(all_ids)} IDs") + + profile_cpu( + lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)), + label=f"filter(pk__in={len(all_ids)} ids) [large, use_tantivy_sort=True path]", + ) + + # Also time it a few times to get stable numbers + print() + _time( + lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)), + label=f"filter(pk__in={len(all_ids)}) repeated", + ) + + def test_intersect_rare(self, large_backend: TantivyBackend): + """Intersect ~1k Tantivy IDs — the happy path.""" + all_ids = large_backend.search_ids(RARE_TERM, user=None) + qs = Document.objects.all() + + print(f"\n Tantivy returned {len(all_ids)} IDs") + + profile_cpu( + lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)), + label=f"filter(pk__in={len(all_ids)} ids) [rare, use_tantivy_sort=True path]", + ) + + +class TestHighlightHitsProfile: + """Profile backend.highlight_hits() — per-doc Tantivy lookups with BM25 scoring.""" + + def test_highlight_page1(self, large_backend: TantivyBackend): + """25-doc highlight for page 1 (rank_start=1).""" + all_ids = large_backend.search_ids(COMMON_TERM, user=None) + page_ids = all_ids[:PAGE_SIZE] + + profile_cpu( + lambda: large_backend.highlight_hits( + COMMON_TERM, + page_ids, + rank_start=1, + ), + label=f"highlight_hits page 1 (ids {all_ids[0]}..{all_ids[PAGE_SIZE - 1]})", + ) + + def test_highlight_page_middle(self, large_backend: TantivyBackend): + """25-doc highlight for a mid-corpus page (rank_start=page_offset+1).""" + all_ids = large_backend.search_ids(COMMON_TERM, user=None) + mid = len(all_ids) // 2 + page_ids = all_ids[mid : mid + PAGE_SIZE] + page_offset = mid + + profile_cpu( + lambda: large_backend.highlight_hits( + COMMON_TERM, + page_ids, + rank_start=page_offset + 1, + ), + label=f"highlight_hits page ~{mid // PAGE_SIZE} (offset {page_offset})", + ) + + def test_highlight_repeated(self, large_backend: TantivyBackend): + """Multiple runs of page-1 highlight to see variance.""" + all_ids = large_backend.search_ids(COMMON_TERM, user=None) + page_ids = all_ids[:PAGE_SIZE] + + print() + _time( + lambda: large_backend.highlight_hits(COMMON_TERM, page_ids, rank_start=1), + label="highlight_hits page 1", + runs=5, + ) + + +class TestFullPipelineProfile: + """ + Profile the combined pipeline as it runs in views.py: + search_ids -> filter(pk__in) -> highlight_hits + """ + + def _run_pipeline( + self, + backend: TantivyBackend, + term: str, + page: int = 1, + ): + all_ids = backend.search_ids(term, user=None) + qs = Document.objects.all() + visible_ids = set(qs.filter(pk__in=all_ids).values_list("pk", flat=True)) + ordered_ids = [i for i in all_ids if i in visible_ids] + + page_offset = (page - 1) * PAGE_SIZE + page_ids = ordered_ids[page_offset : page_offset + PAGE_SIZE] + hits = backend.highlight_hits( + term, + page_ids, + rank_start=page_offset + 1, + ) + return ordered_ids, hits + + def test_pipeline_large_page1(self, large_backend: TantivyBackend): + """Full pipeline: large result set, page 1.""" + ordered_ids, hits = profile_cpu( + lambda: self._run_pipeline(large_backend, COMMON_TERM, page=1), + label=f"full pipeline '{COMMON_TERM}' page 1", + )[0] + print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page") + + def test_pipeline_large_page5(self, large_backend: TantivyBackend): + """Full pipeline: large result set, page 5.""" + ordered_ids, hits = profile_cpu( + lambda: self._run_pipeline(large_backend, COMMON_TERM, page=5), + label=f"full pipeline '{COMMON_TERM}' page 5", + )[0] + print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page") + + def test_pipeline_rare(self, large_backend: TantivyBackend): + """Full pipeline: rare term, page 1 (fast path).""" + ordered_ids, hits = profile_cpu( + lambda: self._run_pipeline(large_backend, RARE_TERM, page=1), + label=f"full pipeline '{RARE_TERM}' page 1", + )[0] + print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page") + + def test_pipeline_repeated(self, large_backend: TantivyBackend): + """Repeated runs to get stable timing (no cProfile overhead).""" + print() + for term, label in [ + (COMMON_TERM, f"'{COMMON_TERM}' (large)"), + (MEDIUM_TERM, f"'{MEDIUM_TERM}' (medium)"), + (RARE_TERM, f"'{RARE_TERM}' (rare)"), + ]: + _time( + lambda t=term: self._run_pipeline(large_backend, t, page=1), + label=f"full pipeline {label} page 1", + runs=3, + ) diff --git a/src/other/test_search_profiling.py b/src/other/test_search_profiling.py new file mode 100644 index 000000000..25abf8ecc --- /dev/null +++ b/src/other/test_search_profiling.py @@ -0,0 +1,273 @@ +""" +Search performance profiling tests. + +Run explicitly — excluded from the normal test suite: + + uv run pytest -m profiling -s -p no:xdist --override-ini="addopts=" -v + +The ``-s`` flag is required to see profile_block() output. +The ``-p no:xdist`` flag disables parallel execution for accurate measurements. + +Corpus: 5 000 documents generated deterministically from a fixed Faker seed, +with realistic variety: 30 correspondents, 15 document types, 50 tags, ~500 +notes spread across ~10 % of documents. +""" + +from __future__ import annotations + +import random + +import pytest +from django.contrib.auth.models import User +from faker import Faker +from rest_framework.test import APIClient + +from documents.models import Correspondent +from documents.models import Document +from documents.models import DocumentType +from documents.models import Note +from documents.models import Tag +from documents.search import get_backend +from documents.search import reset_backend +from documents.search._backend import SearchMode +from profiling import profile_block + +pytestmark = [pytest.mark.profiling, pytest.mark.search, pytest.mark.django_db] + +# --------------------------------------------------------------------------- +# Corpus parameters +# --------------------------------------------------------------------------- + +DOC_COUNT = 5_000 +SEED = 42 +NUM_CORRESPONDENTS = 30 +NUM_DOC_TYPES = 15 +NUM_TAGS = 50 +NOTE_FRACTION = 0.10 # ~500 documents get a note +PAGE_SIZE = 25 + + +def _build_corpus(rng: random.Random, fake: Faker) -> None: + """ + Insert the full corpus into the database and index it. + + Uses bulk_create for the Document rows (fast) then handles the M2M tag + relationships and notes individually. Indexes the full corpus with a + single backend.rebuild() call. + """ + import datetime + + # ---- lookup objects ------------------------------------------------- + correspondents = [ + Correspondent.objects.create(name=f"profcorp-{i}-{fake.company()}"[:128]) + for i in range(NUM_CORRESPONDENTS) + ] + doc_types = [ + DocumentType.objects.create(name=f"proftype-{i}-{fake.word()}"[:128]) + for i in range(NUM_DOC_TYPES) + ] + tags = [ + Tag.objects.create(name=f"proftag-{i}-{fake.word()}"[:100]) + for i in range(NUM_TAGS) + ] + note_user = User.objects.create_user(username="profnoteuser", password="x") + + # ---- bulk-create documents ------------------------------------------ + base_date = datetime.date(2018, 1, 1) + raw_docs = [] + for i in range(DOC_COUNT): + day_offset = rng.randint(0, 6 * 365) + created = base_date + datetime.timedelta(days=day_offset) + raw_docs.append( + Document( + title=fake.sentence(nb_words=rng.randint(3, 9)).rstrip("."), + content="\n\n".join( + fake.paragraph(nb_sentences=rng.randint(3, 7)) + for _ in range(rng.randint(2, 5)) + ), + checksum=f"PROF{i:07d}", + correspondent=rng.choice(correspondents + [None] * 8), + document_type=rng.choice(doc_types + [None] * 4), + created=created, + ), + ) + documents = Document.objects.bulk_create(raw_docs) + + # ---- tags (M2M, post-bulk) ------------------------------------------ + for doc in documents: + k = rng.randint(0, 5) + if k: + doc.tags.add(*rng.sample(tags, k)) + + # ---- notes on ~10 % of docs ----------------------------------------- + note_docs = rng.sample(documents, int(DOC_COUNT * NOTE_FRACTION)) + for doc in note_docs: + Note.objects.create( + document=doc, + note=fake.sentence(nb_words=rng.randint(6, 20)), + user=note_user, + ) + + # ---- build Tantivy index -------------------------------------------- + backend = get_backend() + qs = Document.objects.select_related( + "correspondent", + "document_type", + "storage_path", + "owner", + ).prefetch_related("tags", "notes__user", "custom_fields__field") + backend.rebuild(qs) + + +class TestSearchProfiling: + """ + Performance profiling for the Tantivy search backend and DRF API layer. + + Each test builds a fresh 5 000-document corpus, exercises one hot path, + and prints profile_block() measurements to stdout. No correctness + assertions — the goal is to surface hot spots and track regressions. + """ + + @pytest.fixture(autouse=True) + def _setup(self, tmp_path, settings): + index_dir = tmp_path / "index" + index_dir.mkdir() + settings.INDEX_DIR = index_dir + + reset_backend() + rng = random.Random(SEED) + fake = Faker() + Faker.seed(SEED) + + self.user = User.objects.create_superuser( + username="profiler", + password="admin", + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + _build_corpus(rng, fake) + yield + reset_backend() + + # -- 1. Backend: search_ids relevance --------------------------------- + + def test_profile_search_ids_relevance(self): + """Profile: search_ids() with relevance ordering across several queries.""" + backend = get_backend() + queries = [ + "invoice payment", + "annual report", + "bank statement", + "contract agreement", + "receipt", + ] + with profile_block(f"search_ids — relevance ({len(queries)} queries)"): + for q in queries: + backend.search_ids(q, user=None) + + # -- 2. Backend: search_ids with Tantivy-native sort ------------------ + + def test_profile_search_ids_sorted(self): + """Profile: search_ids() sorted by a Tantivy fast field (created).""" + backend = get_backend() + with profile_block("search_ids — sorted by created (asc + desc)"): + backend.search_ids( + "the", + user=None, + sort_field="created", + sort_reverse=False, + ) + backend.search_ids( + "the", + user=None, + sort_field="created", + sort_reverse=True, + ) + + # -- 3. Backend: highlight_hits for a page of 25 ---------------------- + + def test_profile_highlight_hits(self): + """Profile: highlight_hits() for a 25-document page.""" + backend = get_backend() + all_ids = backend.search_ids("report", user=None) + page_ids = all_ids[:PAGE_SIZE] + with profile_block(f"highlight_hits — {len(page_ids)} docs"): + backend.highlight_hits("report", page_ids) + + # -- 4. Backend: autocomplete ----------------------------------------- + + def test_profile_autocomplete(self): + """Profile: autocomplete() with eight common prefixes.""" + backend = get_backend() + prefixes = ["inv", "pay", "con", "rep", "sta", "acc", "doc", "fin"] + with profile_block(f"autocomplete — {len(prefixes)} prefixes"): + for prefix in prefixes: + backend.autocomplete(prefix, limit=10) + + # -- 5. Backend: simple-mode search (TEXT and TITLE) ------------------ + + def test_profile_search_ids_simple_modes(self): + """Profile: search_ids() in TEXT and TITLE simple-search modes.""" + backend = get_backend() + queries = ["invoice 2023", "annual report", "bank statement"] + with profile_block( + f"search_ids — TEXT + TITLE modes ({len(queries)} queries each)", + ): + for q in queries: + backend.search_ids(q, user=None, search_mode=SearchMode.TEXT) + backend.search_ids(q, user=None, search_mode=SearchMode.TITLE) + + # -- 6. API: full round-trip, relevance + page 1 ---------------------- + + def test_profile_api_relevance_search(self): + """Profile: full API search round-trip, relevance order, page 1.""" + with profile_block( + f"API /documents/?query=… relevance (page 1, page_size={PAGE_SIZE})", + ): + response = self.client.get( + f"/api/documents/?query=invoice+payment&page=1&page_size={PAGE_SIZE}", + ) + assert response.status_code == 200 + + # -- 7. API: full round-trip, ORM-ordered (title) --------------------- + + def test_profile_api_orm_sorted_search(self): + """Profile: full API search round-trip with ORM-delegated sort (title).""" + with profile_block("API /documents/?query=…&ordering=title"): + response = self.client.get( + f"/api/documents/?query=report&ordering=title&page=1&page_size={PAGE_SIZE}", + ) + assert response.status_code == 200 + + # -- 8. API: full round-trip, score sort ------------------------------ + + def test_profile_api_score_sort(self): + """Profile: full API search with ordering=-score (relevance, preserve order).""" + with profile_block("API /documents/?query=…&ordering=-score"): + response = self.client.get( + f"/api/documents/?query=statement&ordering=-score&page=1&page_size={PAGE_SIZE}", + ) + assert response.status_code == 200 + + # -- 9. API: full round-trip, with selection_data --------------------- + + def test_profile_api_with_selection_data(self): + """Profile: full API search including include_selection_data=true.""" + with profile_block("API /documents/?query=…&include_selection_data=true"): + response = self.client.get( + f"/api/documents/?query=contract&page=1&page_size={PAGE_SIZE}" + "&include_selection_data=true", + ) + assert response.status_code == 200 + assert "selection_data" in response.data + + # -- 10. API: paginated (page 2) -------------------------------------- + + def test_profile_api_page_2(self): + """Profile: full API search, page 2 — exercises page offset arithmetic.""" + with profile_block(f"API /documents/?query=…&page=2&page_size={PAGE_SIZE}"): + response = self.client.get( + f"/api/documents/?query=the&page=2&page_size={PAGE_SIZE}", + ) + assert response.status_code == 200