mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-30 09:14:17 +00:00
Store more profiling files
This commit is contained in:
committed by
stumpylog
parent
ef8b4b453d
commit
60e4715a00
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Head-to-head benchmark: PaperlessLanceVectorStore vs PaperlessSqliteVecVectorStore.
|
||||
|
||||
Run from src/ with:
|
||||
uv run python bench_vector_store.py [OPTIONS]
|
||||
|
||||
Phase 1 (skipped if bench_data.pkl already exists): generate fake documents with
|
||||
Faker and embed chunks via Ollama; save to disk for reuse.
|
||||
Phase 2: benchmark both stores against identical data and print a comparison table.
|
||||
|
||||
Requires both classes to coexist in paperless_ai.vector_store (Task 3 Phase A).
|
||||
After Phase B replaces the file, the Lance import fails gracefully and only the
|
||||
sqlite-vec half runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import statistics
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from faker import Faker
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
|
||||
try:
|
||||
from paperless_ai.vector_store import PaperlessLanceVectorStore
|
||||
|
||||
_LANCE_OK = True
|
||||
except ImportError:
|
||||
_LANCE_OK = False
|
||||
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
# Base URL of the Ollama server that embeds the generated chunks.
# NOTE(review): this default is a private LAN address — presumably the
# author's own machine; pass --ollama-url when running elsewhere.
DEFAULT_OLLAMA_URL = "http://192.168.1.87:11434"
# Model name sent in the body of each /api/embed request.
DEFAULT_EMBED_MODEL = "qwen3-embedding:4b"
# Pickle file that caches the generated documents together with their embeddings.
DEFAULT_DATA_FILE = "bench_data.pkl"
# Number of fake documents generated in Phase 1.
DEFAULT_N_DOCS = 2000
# Number of chunks each generated document body is split into.
DEFAULT_CHUNKS_PER_DOC = 3
# Number of timed queries per query scenario (also caps the timed upserts).
DEFAULT_QUERY_ITERS = 50
# Number of texts sent to the embedding endpoint per request.
_BATCH = 32
|
||||
|
||||
|
||||
def _embed(texts: list[str], url: str, model: str) -> list[list[float]]:
    """POST *texts* to ``{url}/api/embed`` and return the ``"embeddings"`` field.

    Args:
        texts: Strings to embed in a single request.
        url: Base URL of the embedding server (no trailing path).
        model: Model name forwarded in the request body.

    Returns:
        The value stored under ``"embeddings"`` in the JSON response.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success response.
    """
    endpoint = f"{url}/api/embed"
    payload = {"model": model, "input": texts}
    response = httpx.post(endpoint, json=payload, timeout=120.0)
    response.raise_for_status()
    body = response.json()
    return body["embeddings"]
|
||||
|
||||
|
||||
def warm_up(url: str, model: str) -> int:
    """Send one throwaway embed request and return the embedding dimension.

    The single request is intended to make the server load *model* before
    the timed/bulk work starts; progress is printed to stdout.
    """
    print(f"Warming up {model}...", end=" ", flush=True)
    vectors = _embed(["warm"], url, model)
    dim = len(vectors[0])
    print(f"dim={dim}")
    return dim
|
||||
|
||||
|
||||
def generate_and_save(
    n_docs: int,
    chunks_per_doc: int,
    url: str,
    model: str,
    out: str,
) -> list[dict]:
    """Generate fake documents, embed every chunk, and pickle the result.

    Each document is a dict with ``doc_id``, ``title``, ``modified`` and
    ``chunks``; each chunk is a dict with ``node_id``, ``text`` and
    ``embedding``.

    Args:
        n_docs: Number of documents to generate.
        chunks_per_doc: Number of slices each document body is cut into.
        url: Base URL of the embedding server.
        model: Embedding model name.
        out: Path of the pickle file to write.

    Returns:
        The list of document dicts that was written to *out*.

    Raises:
        ValueError: If *chunks_per_doc* is smaller than 1.
        RuntimeError: If the server returns a different number of embeddings
            than texts were sent in a batch.
    """
    # Fail before any generation/embedding work: 0 would divide by zero below
    # and a negative value would silently produce documents without chunks.
    if chunks_per_doc < 1:
        raise ValueError(f"chunks_per_doc must be >= 1, got {chunks_per_doc}")

    fake = Faker()
    Faker.seed(42)
    print(f"Generating {n_docs} docs ({chunks_per_doc} chunks each)...")
    docs = []
    for i in range(n_docs):
        body = "\n\n".join(fake.paragraph(nb_sentences=8) for _ in range(3))
        clen = max(1, len(body) // chunks_per_doc)
        chunks = []
        for j in range(chunks_per_doc):
            s = j * clen
            # The last chunk absorbs the remainder of the body.
            e = s + clen if j < chunks_per_doc - 1 else len(body)
            chunks.append(
                {"node_id": str(uuid.uuid4()), "text": body[s:e], "embedding": None},
            )
        docs.append(
            {
                "doc_id": str(i + 1),
                "title": fake.catch_phrase(),
                "modified": fake.date_time_this_decade().isoformat(),
                "chunks": chunks,
            },
        )

    all_chunks = [c for d in docs for c in d["chunks"]]
    total = len(all_chunks)
    print(f"Embedding {total} chunks in batches of {_BATCH}...")
    for start in range(0, total, _BATCH):
        batch = all_chunks[start : start + _BATCH]
        vectors = _embed([c["text"] for c in batch], url, model)
        # A short/long response would otherwise misalign every later chunk
        # (or die with an IndexError only after the whole run).
        if len(vectors) != len(batch):
            raise RuntimeError(
                f"embedding server returned {len(vectors)} vectors "
                f"for {len(batch)} texts",
            )
        for chunk, vector in zip(batch, vectors):
            chunk["embedding"] = vector
        print(
            f" {min(start + _BATCH, total)}/{total}",
            end="\r",
            flush=True,
        )
    print()

    with open(out, "wb") as f:
        pickle.dump(docs, f)
    print(f"Saved to {out}")
    return docs
|
||||
|
||||
|
||||
def _build_nodes(docs: list[dict]) -> list[TextNode]:
    """Flatten the document dicts into one ``TextNode`` per chunk.

    Every node carries its document's ``doc_id`` (as ``document_id``) and
    ``modified`` value in its metadata, has its relationships reset to an
    empty dict, and gets the chunk's stored embedding.
    """

    def _to_node(doc: dict, chunk: dict) -> TextNode:
        node = TextNode(
            id_=chunk["node_id"],
            text=chunk["text"],
            metadata={"document_id": doc["doc_id"], "modified": doc["modified"]},
        )
        node.relationships = {}
        node.embedding = chunk["embedding"]
        return node

    return [_to_node(doc, chunk) for doc in docs for chunk in doc["chunks"]]
|
||||
|
||||
|
||||
def _in_filter(ids: list[str]) -> MetadataFilters:
    """Build a metadata filter matching nodes whose ``document_id`` is in *ids*."""
    membership = MetadataFilter(
        key="document_id",
        operator=FilterOperator.IN,
        value=ids,
    )
    return MetadataFilters(filters=[membership])
|
||||
|
||||
|
||||
def _dir_bytes(path: str) -> int:
    """Return the combined size in bytes of all regular files below *path*."""
    total = 0
    for entry in Path(path).rglob("*"):
        if entry.is_file():
            total += entry.stat().st_size
    return total
|
||||
|
||||
|
||||
def _sqlite_bytes(uri: str) -> int:
    """Return the size in bytes of ``llmindex.db`` inside *uri*, or 0 if absent.

    Only that one file is measured; any sibling files in the directory are
    not counted.
    """
    db_file = Path(uri, "llmindex.db")
    if not db_file.exists():
        return 0
    return db_file.stat().st_size
|
||||
|
||||
|
||||
def _p50_p95(times: list[float]) -> tuple[float, float]:
    """Return ``(median, p95)`` of *times*, or ``(0.0, 0.0)`` when it is empty."""
    if not times:
        return 0.0, 0.0
    ordered = sorted(times)
    # Same index rule the benchmark has always used for its p95 figure.
    return statistics.median(ordered), ordered[int(len(ordered) * 0.95)]


def run_bench(
    store,
    nodes: list[TextNode],
    docs: list[dict],
    q_iters: int,
    is_lance: bool,
) -> dict:
    """Run the full benchmark sequence against *store* and return the timings.

    The sequence is: bulk insert, plain queries, filtered queries,
    ``get_modified_times``, per-document upserts, then compaction with the
    on-disk size sampled before and after.

    Args:
        store: Vector store under test. Must provide ``_uri``, ``add``,
            ``query``, ``get_modified_times``, ``upsert_document`` and
            ``compact``.
        nodes: Nodes (with embeddings) to insert and to draw query vectors from.
        docs: Document dicts the nodes were built from.
        q_iters: Number of timed queries per scenario; also the number of
            documents re-upserted.
        is_lance: Selects the Lance-style ``compact`` call and directory-wide
            size measurement instead of the sqlite ones.

    Returns:
        Dict with keys ``insert``, ``qp50``, ``qp95``, ``qfp50``, ``qfp95``,
        ``gmt_p50``, ``up50``, ``up95``, ``size_pre``, ``compact`` and
        ``size_post``. Times are seconds, sizes are bytes. Percentile entries
        are ``0.0`` when no sample was collected (e.g. ``q_iters == 0``).
    """
    doc_ids = [d["doc_id"] for d in docs]
    # Filtered queries restrict results to the first fifth of the documents.
    filter_ids = doc_ids[: max(1, len(doc_ids) // 5)]
    # Query vectors are existing node embeddings, picked at a stride of 10.
    q_vecs = [nodes[i * 10 % len(nodes)].embedding for i in range(q_iters)]
    by_doc: dict[str, list[TextNode]] = {}
    for n in nodes:
        by_doc.setdefault(n.metadata["document_id"], []).append(n)
    uri = store._uri
    measure_size = _dir_bytes if is_lance else _sqlite_bytes

    # insert
    t0 = time.perf_counter()
    store.add(list(nodes))
    r: dict = {"insert": time.perf_counter() - t0}

    # query plain
    times = []
    for emb in q_vecs:
        t0 = time.perf_counter()
        store.query(VectorStoreQuery(query_embedding=emb, similarity_top_k=10))
        times.append(time.perf_counter() - t0)
    r["qp50"], r["qp95"] = _p50_p95(times)

    # query filtered
    times = []
    flt = _in_filter(filter_ids)
    for emb in q_vecs:
        t0 = time.perf_counter()
        store.query(
            VectorStoreQuery(query_embedding=emb, similarity_top_k=10, filters=flt),
        )
        times.append(time.perf_counter() - t0)
    r["qfp50"], r["qfp95"] = _p50_p95(times)

    # get_modified_times
    times = []
    for _ in range(20):
        t0 = time.perf_counter()
        store.get_modified_times()
        times.append(time.perf_counter() - t0)
    r["gmt_p50"] = statistics.median(times)

    # upsert (fresh node IDs, same embeddings)
    times = []
    for doc in docs[:q_iters]:
        orig = by_doc.get(doc["doc_id"], [])
        if not orig:
            continue
        fresh = []
        for o in orig:
            fn = TextNode(
                id_=str(uuid.uuid4()),
                text=o.text,
                metadata=o.metadata.copy(),
            )
            fn.relationships = {}
            fn.embedding = o.embedding
            fresh.append(fn)
        # Only the store call is timed, not the node construction above.
        t0 = time.perf_counter()
        store.upsert_document(doc["doc_id"], fresh)
        times.append(time.perf_counter() - t0)
    r["up50"], r["up95"] = _p50_p95(times)

    r["size_pre"] = measure_size(uri)

    # compact
    t0 = time.perf_counter()
    if is_lance:
        store.compact(retention_seconds=0)
    else:
        store.compact(force=True)
    r["compact"] = time.perf_counter() - t0

    r["size_post"] = measure_size(uri)
    return r
|
||||
|
||||
|
||||
def _pct(lv: float | None, sv: float) -> str:
    """Format the change from *lv* to *sv* as a signed whole-number percentage.

    Args:
        lv: Baseline value; ``None`` or ``0`` means no baseline is available.
        sv: Value being compared against the baseline.

    Returns:
        ``"N/A"`` without a usable baseline, ``"0%"`` when the change rounds
        to zero, otherwise e.g. ``"+50%"`` or ``"-25%"``.
    """
    if lv is None or lv == 0:
        return "N/A"
    p = (sv - lv) / lv * 100
    digits = f"{p:.0f}"
    # Sub-half-percent changes round to "0"/"-0"; show them without a
    # misleading sign instead of "+0%" / "-0%".
    if digits in ("0", "-0"):
        return "0%"
    return f"{'+' if p > 0 else ''}{digits}%"
|
||||
|
||||
|
||||
def print_results(
    nodes: list[TextNode],
    q_iters: int,
    lance: dict | None,
    sq: dict,
) -> None:
    """Print the side-by-side comparison table for both stores.

    Args:
        nodes: The benchmarked nodes; used for the node count and the
            embedding dimension shown in the header.
        q_iters: Query iteration count shown in the header.
        lance: Result dict of the LanceDB run, or a falsy value when that run
            was skipped (its column then shows ``N/A``).
        sq: Result dict of the sqlite-vec run.
    """
    label_w = 30
    node_count = len(nodes)
    dim = len(nodes[0].embedding)

    print("\n=== Vector Store Benchmark ===")
    print(f"Nodes: {node_count} | Dim: {dim} | Query iters: {q_iters}\n")
    lance_header = "LanceDB" if lance else "LanceDB (N/A)"
    print(f"{'Operation':<{label_w}} {lance_header:<22} {'sqlite-vec':<22} {'Delta'}")
    print("-" * (label_w + 66))

    def secs(v: float) -> str:
        return f"{v:.3f}s"

    def millis(v: float) -> str:
        return f"{v * 1000:.1f}ms"

    def megabytes(v: float) -> str:
        return f"{v / 1e6:.1f} MB"

    def lance_value(key: str) -> float | None:
        # A skipped Lance run contributes None for every metric.
        return lance[key] if lance else None

    def emit_single(label: str, key: str, fmt) -> None:
        lv, sv = lance_value(key), sq[key]
        left = "N/A" if lv is None else fmt(lv)
        print(f"{label:<{label_w}} {left:<22} {fmt(sv):<22} {_pct(lv, sv)}")

    def emit_pair(label: str, key_a: str, key_b: str) -> None:
        lv_a, lv_b = lance_value(key_a), lance_value(key_b)
        sv_a, sv_b = sq[key_a], sq[key_b]
        left = "N/A" if lv_a is None else f"{millis(lv_a)} / {millis(lv_b)}"
        right = f"{millis(sv_a)} / {millis(sv_b)}"
        # The delta column compares the first value of each pair (the p50).
        print(f"{label:<{label_w}} {left:<22} {right:<22} {_pct(lv_a, sv_a)}")

    emit_single(f"insert ({node_count} nodes)", "insert", secs)
    emit_pair("query plain p50/p95", "qp50", "qp95")
    emit_pair("query filtered p50/p95", "qfp50", "qfp95")
    emit_single("get_modified_times p50", "gmt_p50", millis)
    emit_pair("upsert p50/p95", "up50", "up95")
    emit_single("compact", "compact", secs)
    emit_single("file size pre-compact", "size_pre", megabytes)
    emit_single("file size post-compact", "size_post", megabytes)
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI options, obtain the dataset, benchmark the stores, print results.

    The dataset is generated (and embedded through Ollama) when
    ``--regenerate`` is given or the data file does not exist; otherwise it is
    loaded from the pickle file and no request to Ollama is made.
    """
    ap = argparse.ArgumentParser(description="Vector store head-to-head benchmark")
    ap.add_argument("--n-docs", type=int, default=DEFAULT_N_DOCS)
    ap.add_argument("--chunks-per-doc", type=int, default=DEFAULT_CHUNKS_PER_DOC)
    ap.add_argument("--data-file", default=DEFAULT_DATA_FILE)
    ap.add_argument("--regenerate", action="store_true")
    ap.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL)
    ap.add_argument("--embed-model", default=DEFAULT_EMBED_MODEL)
    ap.add_argument("--query-iters", type=int, default=DEFAULT_QUERY_ITERS)
    args = ap.parse_args()

    data_path = Path(args.data_file)
    if args.regenerate or not data_path.exists():
        # Only the generation phase embeds text, so only it needs the model
        # warmed up (and the Ollama server reachable at all).
        warm_up(args.ollama_url, args.embed_model)
        docs = generate_and_save(
            args.n_docs,
            args.chunks_per_doc,
            args.ollama_url,
            args.embed_model,
            args.data_file,
        )
    else:
        print(f"Loading {args.data_file}...")
        # NOTE: pickle can execute arbitrary code while loading; only point
        # --data-file at files produced by this script.
        with open(data_path, "rb") as f:
            docs = pickle.load(f)
        print(f"Loaded {len(docs)} docs ({sum(len(d['chunks']) for d in docs)} nodes)")

    all_nodes = _build_nodes(docs)

    lance_r = None
    if _LANCE_OK:
        print("\nBenchmarking LanceDB...")
        with tempfile.TemporaryDirectory() as d:
            store = PaperlessLanceVectorStore(uri=d)
            lance_r = run_bench(store, all_nodes, docs, args.query_iters, is_lance=True)
    else:
        print("Skipping LanceDB (PaperlessLanceVectorStore not importable).")

    print("\nBenchmarking sqlite-vec...")
    with tempfile.TemporaryDirectory() as d:
        store = PaperlessSqliteVecVectorStore(uri=d)
        sqlite_r = run_bench(store, all_nodes, docs, args.query_iters, is_lance=False)

    print_results(all_nodes, args.query_iters, lance_r, sqlite_r)
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,346 @@
|
||||
# ruff: noqa: T201
|
||||
"""
|
||||
cProfile-based search pipeline profiling with a 20k-document dataset.
|
||||
|
||||
Run with:
|
||||
uv run pytest ../test_backend_profile.py \
|
||||
-m profiling --override-ini="addopts=" -s -v
|
||||
|
||||
Each scenario prints:
|
||||
- Wall time for the operation
|
||||
- cProfile stats sorted by cumulative time (top 25 callers)
|
||||
|
||||
This is a developer tool, not a correctness test. Nothing here should
|
||||
fail unless the code is broken.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.models import Document
|
||||
from documents.search._backend import TantivyBackend
|
||||
from documents.search._backend import reset_backend
|
||||
from profiling import profile_cpu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
# transaction=False (default): tests roll back, but the module-scoped fixture
|
||||
# commits its data outside the test transaction so it remains visible throughout.
|
||||
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset constants
|
||||
# ---------------------------------------------------------------------------
|
||||
# Number of Document rows created (and indexed) by the module-scoped fixture.
NUM_DOCS = 20_000
# Seed for the random.Random instance that generates document content.
SEED = 42

# Terms and their approximate match rates across the corpus.
#   "rechnung"    -> ~70% of docs (~14 000)
#   "mahnung"     -> ~20% of docs (~4 000)
#   "kontonummer" -> ~5% of docs (~1 000)
#   "rarewort"    -> ~1% of docs (~200)
COMMON_TERM = "rechnung"
MEDIUM_TERM = "mahnung"
RARE_TERM = "kontonummer"
VERY_RARE_TERM = "rarewort"

# Number of document IDs treated as one result page by the highlight and
# pipeline scenarios.
PAGE_SIZE = 25
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Vocabulary the generated document content is sampled from; the search
# terms above are appended on top of these words at their target rates.
_FILLER_WORDS = [
    "dokument",  # codespell:ignore
    "seite",
    "datum",
    "betrag",
    "nummer",
    "konto",
    "firma",
    "vertrag",
    "lieferant",
    "bestellung",
    "steuer",
    "mwst",
    "leistung",
    "auftrag",
    "zahlung",
]
|
||||
|
||||
|
||||
def _build_content(rng: random.Random) -> str:
    """Build one document's content: 15 filler words plus randomly added terms.

    Each search term is appended with its own probability, then the words are
    shuffled and joined with single spaces.
    """
    tokens = rng.choices(_FILLER_WORDS, k=15)
    # Exactly one draw per term, always in this order, so a given seed keeps
    # producing the same corpus.
    term_rates = (
        (0.70, COMMON_TERM),
        (0.20, MEDIUM_TERM),
        (0.05, RARE_TERM),
        (0.01, VERY_RARE_TERM),
    )
    for rate, term in term_rates:
        if rng.random() < rate:
            tokens.append(term)
    rng.shuffle(tokens)
    return " ".join(tokens)
|
||||
|
||||
|
||||
def _time(fn, *, label: str, runs: int = 3):
    """Call *fn* *runs* times, print min/avg/max wall time, return the last result.

    No profiler is attached, so the printed numbers carry no cProfile overhead.
    """

    def to_ms(seconds: float) -> float:
        return seconds * 1000

    durations = []
    last_result = None
    for _ in range(runs):
        started = time.perf_counter()
        last_result = fn()
        durations.append(time.perf_counter() - started)

    fastest = min(durations)
    mean = sum(durations) / len(durations)
    slowest = max(durations)
    print(
        f" {label}: min={to_ms(fastest):.1f}ms avg={to_ms(mean):.1f}ms max={to_ms(slowest):.1f}ms (n={runs})",
    )
    return last_result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Keep database access unblocked for the lifetime of this module."""
    unblocked = django_db_blocker.unblock()
    with unblocked:
        yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def large_backend(tmp_path_factory, module_db) -> TantivyBackend:
    """
    Build a 20 000-document DB + on-disk Tantivy index, shared across all
    profiling scenarios in this module.

    Teardown closes the backend, resets the backend singleton and deletes all
    documents. The cleanup also runs when setup itself fails part-way, so a
    broken index build does not leave the committed rows behind.
    """
    index_path: Path = tmp_path_factory.mktemp("tantivy_profile")

    try:
        # ---- 1. Bulk-create Document rows ------------------------------------
        rng = random.Random(SEED)
        docs = [
            Document(
                title=f"Document {i:05d}",
                content=_build_content(rng),
                checksum=f"{i:064x}",
                pk=i + 1,
            )
            for i in range(NUM_DOCS)
        ]
        t0 = time.perf_counter()
        Document.objects.bulk_create(docs, batch_size=1_000)
        db_time = time.perf_counter() - t0
        print(f"\n[setup] bulk_create {NUM_DOCS} docs: {db_time:.2f}s")

        # ---- 2. Build Tantivy index ------------------------------------------
        backend = TantivyBackend(path=index_path)
        backend.open()
        try:
            t0 = time.perf_counter()
            with backend.batch_update() as batch:
                for doc in Document.objects.iterator(chunk_size=500):
                    batch.add_or_update(doc)
            idx_time = time.perf_counter() - t0
            print(f"[setup] index {NUM_DOCS} docs: {idx_time:.2f}s")

            # ---- 3. Report corpus stats --------------------------------------
            for term in (COMMON_TERM, MEDIUM_TERM, RARE_TERM, VERY_RARE_TERM):
                count = len(backend.search_ids(term, user=None))
                print(f"[setup] '{term}' -> {count} hits")

            yield backend
        finally:
            # ---- Teardown: index ---------------------------------------------
            backend.close()
            reset_backend()
    finally:
        # ---- Teardown: database ----------------------------------------------
        # These rows live outside any per-test transaction, so nothing rolls
        # them back for us.
        Document.objects.all().delete()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Profiling tests — each scenario is a separate function so pytest can run
|
||||
# them individually or all together with -m profiling.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchIdsProfile:
    """Profile backend.search_ids() for terms of decreasing frequency."""

    @staticmethod
    def _profile_term(backend: TantivyBackend, term: str, size_note: str) -> None:
        """Profile a single search_ids() call for *term* under a shared label format."""
        profile_cpu(
            lambda: backend.search_ids(term, user=None),
            label=f"search_ids('{term}') [{size_note}]",
        )

    def test_search_ids_large(self, large_backend: TantivyBackend):
        """~14 000 hits: how long does Tantivy take to collect all IDs?"""
        self._profile_term(large_backend, COMMON_TERM, "large result set ~14k")

    def test_search_ids_medium(self, large_backend: TantivyBackend):
        """~4 000 hits."""
        self._profile_term(large_backend, MEDIUM_TERM, "medium result set ~4k")

    def test_search_ids_rare(self, large_backend: TantivyBackend):
        """~1 000 hits."""
        self._profile_term(large_backend, RARE_TERM, "rare result set ~1k")
|
||||
|
||||
|
||||
class TestIntersectAndOrderProfile:
    """
    Profile the DB intersection step: filter(pk__in=search_ids).

    This mirrors the 'intersect_and_order' logic from views.py.
    """

    def test_intersect_large(self, large_backend: TantivyBackend):
        """Intersect the common-term hit IDs with every document in the ORM."""
        hit_ids = large_backend.search_ids(COMMON_TERM, user=None)
        queryset = Document.objects.all()
        hit_count = len(hit_ids)

        print(f"\n Tantivy returned {hit_count} IDs")

        def intersect():
            return list(queryset.filter(pk__in=hit_ids).values_list("pk", flat=True))

        profile_cpu(
            intersect,
            label=f"filter(pk__in={hit_count} ids) [large, use_tantivy_sort=True path]",
        )

        # A few unprofiled repetitions give steadier wall-clock numbers.
        print()
        _time(intersect, label=f"filter(pk__in={hit_count}) repeated")

    def test_intersect_rare(self, large_backend: TantivyBackend):
        """Intersect the rare-term hit IDs — the happy path."""
        hit_ids = large_backend.search_ids(RARE_TERM, user=None)
        queryset = Document.objects.all()
        hit_count = len(hit_ids)

        print(f"\n Tantivy returned {hit_count} IDs")

        def intersect():
            return list(queryset.filter(pk__in=hit_ids).values_list("pk", flat=True))

        profile_cpu(
            intersect,
            label=f"filter(pk__in={hit_count} ids) [rare, use_tantivy_sort=True path]",
        )
|
||||
|
||||
|
||||
class TestHighlightHitsProfile:
    """Profile backend.highlight_hits() on one page (PAGE_SIZE) of document IDs."""

    def test_highlight_page1(self, large_backend: TantivyBackend):
        """25-doc highlight for page 1 (rank_start=1)."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        page_ids = all_ids[:PAGE_SIZE]

        # Describe the page from the IDs actually on it, so a corpus with
        # fewer than PAGE_SIZE hits cannot break the label with an IndexError.
        if len(page_ids) > 0:
            id_span = f"ids {page_ids[0]}..{page_ids[-1]}"
        else:
            id_span = "no ids"

        profile_cpu(
            lambda: large_backend.highlight_hits(
                COMMON_TERM,
                page_ids,
                rank_start=1,
            ),
            label=f"highlight_hits page 1 ({id_span})",
        )

    def test_highlight_page_middle(self, large_backend: TantivyBackend):
        """25-doc highlight for a mid-corpus page (rank_start=page_offset+1)."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        mid = len(all_ids) // 2
        page_ids = all_ids[mid : mid + PAGE_SIZE]
        page_offset = mid

        profile_cpu(
            lambda: large_backend.highlight_hits(
                COMMON_TERM,
                page_ids,
                rank_start=page_offset + 1,
            ),
            label=f"highlight_hits page ~{mid // PAGE_SIZE} (offset {page_offset})",
        )

    def test_highlight_repeated(self, large_backend: TantivyBackend):
        """Multiple runs of page-1 highlight to see variance."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        page_ids = all_ids[:PAGE_SIZE]

        print()
        _time(
            lambda: large_backend.highlight_hits(COMMON_TERM, page_ids, rank_start=1),
            label="highlight_hits page 1",
            runs=5,
        )
|
||||
|
||||
|
||||
class TestFullPipelineProfile:
    """
    Profile the three search stages chained together, as views.py runs them:
    search_ids -> filter(pk__in) -> highlight_hits
    """

    def _run_pipeline(
        self,
        backend: TantivyBackend,
        term: str,
        page: int = 1,
    ):
        """Run all three stages for *term* and return ``(ordered_ids, hits)``."""
        hit_ids = backend.search_ids(term, user=None)
        visible = set(
            Document.objects.all()
            .filter(pk__in=hit_ids)
            .values_list("pk", flat=True),
        )
        # Keep the backend's ordering, dropping IDs the ORM did not return.
        ranked = [pk for pk in hit_ids if pk in visible]

        start = (page - 1) * PAGE_SIZE
        window = ranked[start : start + PAGE_SIZE]
        hits = backend.highlight_hits(term, window, rank_start=start + 1)
        return ranked, hits

    def _profile_pipeline(self, backend: TantivyBackend, term: str, page: int) -> None:
        """Profile one pipeline run and print the result counts."""
        ordered_ids, hits = profile_cpu(
            lambda: self._run_pipeline(backend, term, page=page),
            label=f"full pipeline '{term}' page {page}",
        )[0]
        print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page")

    def test_pipeline_large_page1(self, large_backend: TantivyBackend):
        """Full pipeline: large result set, page 1."""
        self._profile_pipeline(large_backend, COMMON_TERM, 1)

    def test_pipeline_large_page5(self, large_backend: TantivyBackend):
        """Full pipeline: large result set, page 5."""
        self._profile_pipeline(large_backend, COMMON_TERM, 5)

    def test_pipeline_rare(self, large_backend: TantivyBackend):
        """Full pipeline: rare term, page 1 (fast path)."""
        self._profile_pipeline(large_backend, RARE_TERM, 1)

    def test_pipeline_repeated(self, large_backend: TantivyBackend):
        """Repeated runs to get stable timing (no cProfile overhead)."""
        print()
        scenarios = (
            (COMMON_TERM, "large"),
            (MEDIUM_TERM, "medium"),
            (RARE_TERM, "rare"),
        )
        for term, size in scenarios:
            _time(
                # Bind the loop variable now; a bare closure would see the last term.
                lambda t=term: self._run_pipeline(large_backend, t, page=1),
                label=f"full pipeline '{term}' ({size}) page 1",
                runs=3,
            )
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Search performance profiling tests.
|
||||
|
||||
Run explicitly — excluded from the normal test suite:
|
||||
|
||||
uv run pytest -m profiling -s -p no:xdist --override-ini="addopts=" -v
|
||||
|
||||
The ``-s`` flag is required to see profile_block() output.
|
||||
The ``-p no:xdist`` flag disables parallel execution for accurate measurements.
|
||||
|
||||
Corpus: 5 000 documents generated deterministically from a fixed Faker seed,
|
||||
with realistic variety: 30 correspondents, 15 document types, 50 tags, ~500
|
||||
notes spread across ~10 % of documents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
from django.contrib.auth.models import User
|
||||
from faker import Faker
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import Note
|
||||
from documents.models import Tag
|
||||
from documents.search import get_backend
|
||||
from documents.search import reset_backend
|
||||
from documents.search._backend import SearchMode
|
||||
from profiling import profile_block
|
||||
|
||||
pytestmark = [pytest.mark.profiling, pytest.mark.search, pytest.mark.django_db]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Corpus parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Number of documents inserted and indexed for every test.
DOC_COUNT = 5_000
# Seed shared by random.Random and Faker.
SEED = 42
# Number of Correspondent rows created as lookup objects.
NUM_CORRESPONDENTS = 30
# Number of DocumentType rows created as lookup objects.
NUM_DOC_TYPES = 15
# Number of Tag rows created; each document receives 0-5 of them.
NUM_TAGS = 50
NOTE_FRACTION = 0.10  # ~500 documents get a note
# Page size used for highlight calls and API requests.
PAGE_SIZE = 25
|
||||
|
||||
|
||||
def _build_corpus(rng: random.Random, fake: Faker) -> None:
    """
    Insert the full corpus into the database and index it.

    Creates the correspondents, document types, tags and a note author, then
    bulk-creates DOC_COUNT documents, attaches 0-5 tags to each, adds one note
    to a NOTE_FRACTION sample of them, and finally indexes everything with a
    single backend.rebuild() call.

    Args:
        rng: Source of all structural randomness (dates, counts, choices).
        fake: Faker instance providing titles, paragraphs, names and notes.
    """
    import datetime

    # ---- lookup objects -------------------------------------------------
    correspondents = [
        Correspondent.objects.create(name=f"profcorp-{i}-{fake.company()}"[:128])
        for i in range(NUM_CORRESPONDENTS)
    ]
    doc_types = [
        DocumentType.objects.create(name=f"proftype-{i}-{fake.word()}"[:128])
        for i in range(NUM_DOC_TYPES)
    ]
    tags = [
        Tag.objects.create(name=f"proftag-{i}-{fake.word()}"[:100])
        for i in range(NUM_TAGS)
    ]
    note_user = User.objects.create_user(username="profnoteuser", password="x")

    # Choice pools, built once instead of once per document. The None
    # entries leave some documents without a correspondent / document type.
    correspondent_pool = correspondents + [None] * 8
    doc_type_pool = doc_types + [None] * 4

    # ---- bulk-create documents ------------------------------------------
    base_date = datetime.date(2018, 1, 1)
    raw_docs = []
    for i in range(DOC_COUNT):
        day_offset = rng.randint(0, 6 * 365)
        created = base_date + datetime.timedelta(days=day_offset)
        raw_docs.append(
            Document(
                title=fake.sentence(nb_words=rng.randint(3, 9)).rstrip("."),
                content="\n\n".join(
                    fake.paragraph(nb_sentences=rng.randint(3, 7))
                    for _ in range(rng.randint(2, 5))
                ),
                checksum=f"PROF{i:07d}",
                correspondent=rng.choice(correspondent_pool),
                document_type=rng.choice(doc_type_pool),
                created=created,
            ),
        )
    # NOTE(review): the M2M and note steps below need primary keys on the
    # returned objects — presumably guaranteed by the test database backend.
    documents = Document.objects.bulk_create(raw_docs)

    # ---- tags (M2M, post-bulk) ------------------------------------------
    for doc in documents:
        k = rng.randint(0, 5)
        if k:
            doc.tags.add(*rng.sample(tags, k))

    # ---- notes on ~10 % of docs -----------------------------------------
    note_docs = rng.sample(documents, int(DOC_COUNT * NOTE_FRACTION))
    for doc in note_docs:
        Note.objects.create(
            document=doc,
            note=fake.sentence(nb_words=rng.randint(6, 20)),
            user=note_user,
        )

    # ---- build Tantivy index --------------------------------------------
    backend = get_backend()
    qs = Document.objects.select_related(
        "correspondent",
        "document_type",
        "storage_path",
        "owner",
    ).prefetch_related("tags", "notes__user", "custom_fields__field")
    backend.rebuild(qs)
|
||||
|
||||
|
||||
class TestSearchProfiling:
    """
    Performance profiling for the Tantivy search backend and DRF API layer.

    An autouse fixture rebuilds the full 5 000-document corpus for every test;
    each test then exercises one hot path inside profile_block(), which prints
    its measurements to stdout. Apart from status-code sanity checks on the
    API tests there are no correctness assertions — the goal is to surface
    hot spots and track regressions.
    """

    @pytest.fixture(autouse=True)
    def _setup(self, tmp_path, settings):
        """Point the index at a temp dir, create an admin client, build the corpus."""
        index_root = tmp_path / "index"
        index_root.mkdir()
        settings.INDEX_DIR = index_root

        reset_backend()
        rng = random.Random(SEED)
        fake = Faker()
        Faker.seed(SEED)

        self.user = User.objects.create_superuser(
            username="profiler",
            password="admin",
        )
        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        _build_corpus(rng, fake)
        yield
        reset_backend()

    def _get_ok(self, label: str, url: str):
        """GET *url* inside profile_block(*label*), require HTTP 200, return the response."""
        with profile_block(label):
            response = self.client.get(url)
        assert response.status_code == 200
        return response

    # -- 1. Backend: search_ids relevance ---------------------------------

    def test_profile_search_ids_relevance(self):
        """Profile: search_ids() with relevance ordering across several queries."""
        backend = get_backend()
        queries = (
            "invoice payment",
            "annual report",
            "bank statement",
            "contract agreement",
            "receipt",
        )
        with profile_block(f"search_ids — relevance ({len(queries)} queries)"):
            for text in queries:
                backend.search_ids(text, user=None)

    # -- 2. Backend: search_ids with Tantivy-native sort ------------------

    def test_profile_search_ids_sorted(self):
        """Profile: search_ids() sorted by a Tantivy fast field (created)."""
        backend = get_backend()
        with profile_block("search_ids — sorted by created (asc + desc)"):
            # Ascending first, then descending.
            for descending in (False, True):
                backend.search_ids(
                    "the",
                    user=None,
                    sort_field="created",
                    sort_reverse=descending,
                )

    # -- 3. Backend: highlight_hits for a page of 25 ----------------------

    def test_profile_highlight_hits(self):
        """Profile: highlight_hits() for a 25-document page."""
        backend = get_backend()
        first_page = backend.search_ids("report", user=None)[:PAGE_SIZE]
        with profile_block(f"highlight_hits — {len(first_page)} docs"):
            backend.highlight_hits("report", first_page)

    # -- 4. Backend: autocomplete -----------------------------------------

    def test_profile_autocomplete(self):
        """Profile: autocomplete() with eight common prefixes."""
        backend = get_backend()
        prefixes = ("inv", "pay", "con", "rep", "sta", "acc", "doc", "fin")
        with profile_block(f"autocomplete — {len(prefixes)} prefixes"):
            for stem in prefixes:
                backend.autocomplete(stem, limit=10)

    # -- 5. Backend: simple-mode search (TEXT and TITLE) ------------------

    def test_profile_search_ids_simple_modes(self):
        """Profile: search_ids() in TEXT and TITLE simple-search modes."""
        backend = get_backend()
        queries = ("invoice 2023", "annual report", "bank statement")
        modes = (SearchMode.TEXT, SearchMode.TITLE)
        with profile_block(
            f"search_ids — TEXT + TITLE modes ({len(queries)} queries each)",
        ):
            for text in queries:
                for mode in modes:
                    backend.search_ids(text, user=None, search_mode=mode)

    # -- 6. API: full round-trip, relevance + page 1 ----------------------

    def test_profile_api_relevance_search(self):
        """Profile: full API search round-trip, relevance order, page 1."""
        self._get_ok(
            f"API /documents/?query=… relevance (page 1, page_size={PAGE_SIZE})",
            f"/api/documents/?query=invoice+payment&page=1&page_size={PAGE_SIZE}",
        )

    # -- 7. API: full round-trip, ORM-ordered (title) ---------------------

    def test_profile_api_orm_sorted_search(self):
        """Profile: full API search round-trip with ORM-delegated sort (title)."""
        self._get_ok(
            "API /documents/?query=…&ordering=title",
            f"/api/documents/?query=report&ordering=title&page=1&page_size={PAGE_SIZE}",
        )

    # -- 8. API: full round-trip, score sort ------------------------------

    def test_profile_api_score_sort(self):
        """Profile: full API search with ordering=-score (relevance, preserve order)."""
        self._get_ok(
            "API /documents/?query=…&ordering=-score",
            f"/api/documents/?query=statement&ordering=-score&page=1&page_size={PAGE_SIZE}",
        )

    # -- 9. API: full round-trip, with selection_data ---------------------

    def test_profile_api_with_selection_data(self):
        """Profile: full API search including include_selection_data=true."""
        response = self._get_ok(
            "API /documents/?query=…&include_selection_data=true",
            f"/api/documents/?query=contract&page=1&page_size={PAGE_SIZE}"
            "&include_selection_data=true",
        )
        assert "selection_data" in response.data

    # -- 10. API: paginated (page 2) --------------------------------------

    def test_profile_api_page_2(self):
        """Profile: full API search, page 2 — exercises page offset arithmetic."""
        self._get_ok(
            f"API /documents/?query=…&page=2&page_size={PAGE_SIZE}",
            f"/api/documents/?query=the&page=2&page_size={PAGE_SIZE}",
        )
|
||||
Reference in New Issue
Block a user