Compare commits


14 Commits

Author SHA1 Message Date
Trenton H
0887203d45 feat(profiling): add workflow trigger matching profiling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 14:55:25 -07:00
Trenton H
ea14c0b06f fix(profiling): use sha256 for sanity corpus checksums
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 14:50:21 -07:00
Trenton H
a8dc332abb feat(profiling): add sanity checker profiling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 14:42:00 -07:00
Trenton H
e64b9a4cfd feat(profiling): add matching pipeline profiling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 14:30:33 -07:00
Trenton H
6ba1acd7d3 fix(profiling): fix stale docstring and add module_db docstring in doclist test 2026-04-11 14:19:51 -07:00
Trenton H
d006b79fd1 feat(profiling): add document list API and selection_data profiling
Adds test_doclist_profile.py with 8 profiling tests covering the
/api/documents/ list path (ORM ordering, page sizes, single-doc detail,
cProfile) and _get_selection_data_for_queryset in isolation and via API.
Also registers the 'profiling' pytest marker in pyproject.toml.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 14:13:17 -07:00
Trenton H
24b754b44c fix(profiling): fix stale run paths in docstrings and consolidate profiling imports 2026-04-11 13:57:00 -07:00
Trenton H
a1a3520a8c refactor(profiling): use shared profile_cpu in backend search test
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 13:46:43 -07:00
Trenton H
23449cda17 refactor(profiling): use shared profile_cpu/measure_memory in classifier test
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 13:44:57 -07:00
Trenton H
ca3f5665ba fix(profiling): correct docstring import path and add Callable type annotation 2026-04-11 13:29:48 -07:00
Trenton H
9aa0914c3f feat(profiling): add profile_cpu and measure_memory shared helpers
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-11 13:22:43 -07:00
shamoon
fdd5e3ecb2 Update SECURITY.md 2026-04-10 12:34:47 -07:00
shamoon
df3b656352 Add tests 2026-04-10 12:06:28 -07:00
shamoon
51e721733f Enhancement: validate and sanitize uploaded logos (#12551) 2026-04-10 11:50:58 -07:00
25 changed files with 3373 additions and 945 deletions

View File

@@ -2,8 +2,83 @@
## Reporting a Vulnerability
The Paperless-ngx team and community take security bugs seriously. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
The Paperless-ngx team and community take security issues seriously. We appreciate good-faith reports and will make every effort to review legitimate findings responsibly.
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/paperless-ngx/paperless-ngx/security/advisories/new) tab.
The team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
After the initial reply to your report, the team may ask for additional information, reproduction steps, affected versions, configuration details, or proof-of-concept material needed to verify the issue.
## What makes a helpful report
Please include as much of the following as possible:
- A clear description of the issue and why it is a security vulnerability.
- Affected Paperless-ngx version(s).
- Required configuration, permissions, or preconditions.
- Step-by-step reproduction instructions.
- A realistic impact statement.
- Logs, screenshots, requests, or a proof of concept where appropriate.
- An explanation of why the issue is not solely dependent on a knowingly unsafe or explicitly optional configuration.
Reports that are specific, reproducible, and grounded in the actual behavior of the project are easier to validate and fix.
## AI-assisted submissions
AI tools can be useful for security research, but they can also produce plausible-sounding reports that do not reflect actual vulnerabilities in Paperless-ngx.
If you use AI tools to help draft or investigate a report, please disclose this and verify the result yourself before submission. Make sure the report:
- identifies a real code path in this project
- includes a real reproduction, not just a theoretical concern
- demonstrates realistic impact
- accounts for existing permission checks, configuration requirements, and trust boundaries
- is about Paperless-ngx itself, rather than a hypothetical deployment, downstream customization, third-party app, scanner output, reverse proxy behavior, or external service unless you can show a direct product vulnerability
Submitting unverified AI-generated claims creates unnecessary review work and slows down handling of legitimate reports.
<!-- Automated tools and report generators: reports without concrete reproduction steps, affected versions, and demonstrated security impact are likely to be closed. -->
## Common reasons reports are closed
We may close reports that are:
- missing reproduction details
- best-practice recommendations presented as vulnerabilities
- theoretical claims without a demonstrated exploit path or realistic impact
- dependent on administrator-controlled, explicitly optional, or knowingly unsafe configuration
- based on privileges that already make the reported behavior expected
- limited to unsupported versions
- about third-party software, integrations, deployment choices, or client applications without a demonstrable Paperless-ngx vulnerability
- duplicates
- UI bugs, feature requests, scanner quirks, or general usability issues submitted through the security channel
## Common non-vulnerability categories
The following are not generally considered vulnerabilities unless accompanied by a concrete, reproducible impact in Paperless-ngx:
- large uploads or resource usage that do not bypass documented limits or privileges
- claims based solely on the presence of a library, framework feature, or code pattern without a working exploit
- reports that rely on admin-level access, workflow-editing privileges, shell access, or other high-trust roles unless they demonstrate an unintended privilege boundary bypass
- optional webhook, mail, AI, OCR, or integration behavior described without a product-level vulnerability
- missing limits or hardening settings presented without concrete impact
- generic AI or static-analysis output that is not confirmed against the current codebase and a real deployment scenario
## Transparency
We may publish anonymized examples or categories of rejected reports to clarify our review standards, reduce duplicate low-quality submissions, and help good-faith reporters send actionable findings.
A mistaken report made in good faith is not misconduct. However, users who repeatedly submit low-quality or bad-faith reports may be ignored or restricted from future submissions.
## Scope and expectations
Please use the security reporting channel only for security vulnerabilities in Paperless-ngx.
Please do not use the security advisory system for:
- support questions
- general bug reports
- feature requests
- browser compatibility issues
- issues in third-party mobile apps, reverse proxies, or deployment tooling unless you can demonstrate a Paperless-ngx vulnerability
The team will review reports as time permits, but submission does not guarantee that a report is valid, in scope, or will result in a fix. Reports that do not describe a reproducible product-level issue may be closed without extended back-and-forth.

profiling.py Normal file
View File

@@ -0,0 +1,150 @@
"""
Temporary profiling utilities for comparing implementations.
Usage in a management command or shell::
from profiling import profile_block, profile_cpu, measure_memory
with profile_block("new check_sanity"):
messages = check_sanity()
with profile_block("old check_sanity"):
messages = check_sanity_old()
Drop this file when done.
"""
from __future__ import annotations
import tracemalloc
from collections.abc import Callable # noqa: TC003
from collections.abc import Generator # noqa: TC003
from contextlib import contextmanager
from time import perf_counter
from typing import Any
from django.db import connection
from django.db import reset_queries
from django.test.utils import override_settings
@contextmanager
def profile_block(label: str = "block") -> Generator[None, None, None]:
"""Profile memory, wall time, and DB queries for a code block.
Prints a summary to stdout on exit. Requires no external packages.
Enables DEBUG temporarily to capture Django's query log.
"""
tracemalloc.start()
snapshot_before = tracemalloc.take_snapshot()
with override_settings(DEBUG=True):
reset_queries()
start = perf_counter()
yield
elapsed = perf_counter() - start
queries = list(connection.queries)
snapshot_after = tracemalloc.take_snapshot()
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
# Compare snapshots for top allocations
stats = snapshot_after.compare_to(snapshot_before, "lineno")
query_time = sum(float(q["time"]) for q in queries)
mem_diff = sum(s.size_diff for s in stats)
print(f"\n{'=' * 60}") # noqa: T201
print(f" Profile: {label}") # noqa: T201
print(f"{'=' * 60}") # noqa: T201
print(f" Wall time: {elapsed:.4f}s") # noqa: T201
print(f" Queries: {len(queries)} ({query_time:.4f}s)") # noqa: T201
print(f" Memory delta: {mem_diff / 1024:.1f} KiB") # noqa: T201
print(f" Peak memory: {peak / 1024:.1f} KiB") # noqa: T201
print("\n Top 5 allocations:") # noqa: T201
for stat in stats[:5]:
print(f" {stat}") # noqa: T201
print(f"{'=' * 60}\n") # noqa: T201
def profile_cpu(
fn: Callable[[], Any],
*,
label: str,
top: int = 30,
sort: str = "cumtime",
) -> tuple[Any, float]:
"""Run *fn()* under cProfile, print stats, return (result, elapsed_s).
Args:
fn: Zero-argument callable to profile.
label: Human-readable label printed in the header.
top: Number of cProfile rows to print.
sort: cProfile sort key (default: cumulative time).
Returns:
``(result, elapsed_s)`` where *result* is the return value of *fn()*.
"""
import cProfile
import io
import pstats
pr = cProfile.Profile()
t0 = perf_counter()
pr.enable()
result = fn()
pr.disable()
elapsed = perf_counter() - t0
buf = io.StringIO()
ps = pstats.Stats(pr, stream=buf).sort_stats(sort)
ps.print_stats(top)
print(f"\n{'=' * 72}") # noqa: T201
print(f" {label}") # noqa: T201
print(f" wall time: {elapsed * 1000:.1f} ms") # noqa: T201
print(f"{'=' * 72}") # noqa: T201
print(buf.getvalue()) # noqa: T201
return result, elapsed
def measure_memory(fn: Callable[[], Any], *, label: str) -> tuple[Any, float, float]:
"""Run *fn()* under tracemalloc, print allocation report.
Args:
fn: Zero-argument callable to profile.
label: Human-readable label printed in the header.
Returns:
``(result, peak_kib, delta_kib)``.
"""
tracemalloc.start()
snapshot_before = tracemalloc.take_snapshot()
t0 = perf_counter()
result = fn()
elapsed = perf_counter() - t0
snapshot_after = tracemalloc.take_snapshot()
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
stats = snapshot_after.compare_to(snapshot_before, "lineno")
delta_kib = sum(s.size_diff for s in stats) / 1024
print(f"\n{'=' * 72}") # noqa: T201
print(f" [memory] {label}") # noqa: T201
print(f" wall time: {elapsed * 1000:.1f} ms") # noqa: T201
print(f" memory delta: {delta_kib:+.1f} KiB") # noqa: T201
print(f" peak traced: {peak / 1024:.1f} KiB") # noqa: T201
print(f"{'=' * 72}") # noqa: T201
print(" Top allocation sites (by size_diff):") # noqa: T201
for stat in stats[:20]:
if stat.size_diff != 0:
print( # noqa: T201
f" {stat.size_diff / 1024:+8.1f} KiB {stat.traceback.format()[0]}",
)
return result, peak / 1024, delta_kib
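The `profile_cpu` helper above wraps `cProfile` around a zero-argument callable and times it with `perf_counter`. Stripped of the project-specific printing, the core pattern can be sketched standalone (the function name and label here are illustrative, not part of the project's API):

```python
import cProfile
import io
import pstats
from collections.abc import Callable
from time import perf_counter
from typing import Any


def profile_cpu_sketch(
    fn: Callable[[], Any],
    *,
    label: str,
    top: int = 10,
) -> tuple[Any, float]:
    """Run fn() under cProfile; return (result, elapsed seconds)."""
    pr = cProfile.Profile()
    t0 = perf_counter()
    pr.enable()
    result = fn()
    pr.disable()
    elapsed = perf_counter() - t0

    # Render the profile into a string buffer, sorted by cumulative time.
    buf = io.StringIO()
    pstats.Stats(pr, stream=buf).sort_stats("cumtime").print_stats(top)
    print(f"{label}: {elapsed * 1000:.1f} ms")
    print(buf.getvalue())
    return result, elapsed


result, elapsed = profile_cpu_sketch(lambda: sum(range(100_000)), label="sum")
```

Enabling the profiler inside the timed window means `elapsed` includes cProfile's own overhead, which is acceptable for the comparative measurements these tests print.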

View File

@@ -312,6 +312,7 @@ markers = [
"date_parsing: Tests which cover date parsing from content or filename",
"management: Tests which cover management commands/functionality",
"search: Tests for the Tantivy search backend",
"profiling: Performance profiling tests — print measurements, no assertions",
]
[tool.pytest_env]

View File

@@ -43,7 +43,7 @@
</div>
<p class="card-text">
@if (document) {
@if (hasSearchHighlights) {
@if (document.__search_hit__ && document.__search_hit__.highlights) {
<span [innerHtml]="document.__search_hit__.highlights"></span>
}
@for (highlight of searchNoteHighlights; track highlight) {
@@ -52,7 +52,7 @@
<span [innerHtml]="highlight"></span>
</span>
}
@if (shouldShowContentFallback) {
@if (!document.__search_hit__?.score) {
<span class="result-content">{{contentTrimmed}}</span>
}
} @else {

View File

@@ -127,19 +127,6 @@ describe('DocumentCardLargeComponent', () => {
expect(component.searchNoteHighlights).toContain('<span>bananas</span>')
})
it('should fall back to document content when a search hit has no highlights', () => {
component.document.__search_hit__ = {
score: 0.9,
rank: 1,
highlights: '',
note_highlights: null,
}
fixture.detectChanges()
expect(fixture.nativeElement.textContent).toContain('Cupcake ipsum')
expect(component.shouldShowContentFallback).toBe(true)
})
it('should try to close the preview on mouse leave', () => {
component.popupPreview = {
close: jest.fn(),

View File

@@ -164,17 +164,6 @@ export class DocumentCardLargeComponent
)
}
get hasSearchHighlights() {
return Boolean(this.document?.__search_hit__?.highlights?.trim()?.length)
}
get shouldShowContentFallback() {
return (
this.document?.__search_hit__?.score == null ||
(!this.hasSearchHighlights && this.searchNoteHighlights.length === 0)
)
}
get notesEnabled(): boolean {
return this.settingsService.get(SETTINGS_KEYS.NOTES_ENABLED)
}

View File

@@ -1,6 +1,6 @@
from documents.search._backend import SearchHit
from documents.search._backend import SearchIndexLockError
from documents.search._backend import SearchMode
from documents.search._backend import SearchResults
from documents.search._backend import TantivyBackend
from documents.search._backend import TantivyRelevanceList
from documents.search._backend import WriteBatch
@@ -10,9 +10,9 @@ from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
__all__ = [
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchResults",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
import logging
import re
import threading
from collections import Counter
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from enum import StrEnum
from html import escape
from typing import TYPE_CHECKING
from typing import Self
from typing import TypedDict
@@ -55,36 +54,6 @@ class SearchMode(StrEnum):
TITLE = "title"
def _render_snippet_html(snippet: tantivy.Snippet) -> str:
fragment = snippet.fragment()
highlighted = sorted(snippet.highlighted(), key=lambda r: r.start)
if not highlighted:
return escape(fragment)
parts: list[str] = []
cursor = 0
fragment_len = len(fragment)
for highlight in highlighted:
start = max(0, min(fragment_len, highlight.start))
end = max(start, min(fragment_len, highlight.end))
if end <= cursor:
continue
if start > cursor:
parts.append(escape(fragment[cursor:start]))
parts.append(f'<span class="match">{escape(fragment[start:end])}</span>')
cursor = end
if cursor < fragment_len:
parts.append(escape(fragment[cursor:]))
return "".join(parts)
def _extract_autocomplete_words(text_sources: list[str]) -> set[str]:
"""Extract and normalize words for autocomplete.
@@ -119,63 +88,45 @@ class SearchHit(TypedDict):
highlights: dict[str, str]
@dataclass(frozen=True, slots=True)
class SearchResults:
"""
Container for search results with pagination metadata.
Attributes:
hits: List of search results with scores and highlights
total: Total matching documents across all pages (for pagination)
query: Preprocessed query string after date/syntax rewriting
"""
hits: list[SearchHit]
total: int # total matching documents (for pagination)
query: str # preprocessed query string
class TantivyRelevanceList:
"""
DRF-compatible list wrapper for Tantivy search results.
DRF-compatible list wrapper for Tantivy search hits.
Holds a lightweight ordered list of IDs (for pagination count and
``selection_data``) together with a small page of rich ``SearchHit``
dicts (for serialization). DRF's ``PageNumberPagination`` calls
``__len__`` to compute the total page count and ``__getitem__`` to
slice the displayed page.
Provides paginated access to search results while storing all hits in memory
for efficient ID retrieval. Used by Django REST framework for pagination.
Args:
ordered_ids: All matching document IDs in display order.
page_hits: Rich SearchHit dicts for the requested DRF page only.
page_offset: Index into *ordered_ids* where *page_hits* starts.
Methods:
__len__: Returns total hit count for pagination calculations
__getitem__: Slices the hit list for page-specific results
Note: Stores ALL post-filter hits so get_all_result_ids() can return
every matching document ID without requiring a second search query.
"""
def __init__(
self,
ordered_ids: list[int],
page_hits: list[SearchHit],
page_offset: int = 0,
) -> None:
self._ordered_ids = ordered_ids
self._page_hits = page_hits
self._page_offset = page_offset
def __init__(self, hits: list[SearchHit]) -> None:
self._hits = hits
def __len__(self) -> int:
return len(self._ordered_ids)
return len(self._hits)
def __getitem__(self, key: int | slice) -> SearchHit | list[SearchHit]:
if isinstance(key, int):
idx = key if key >= 0 else len(self._ordered_ids) + key
if self._page_offset <= idx < self._page_offset + len(self._page_hits):
return self._page_hits[idx - self._page_offset]
return SearchHit(
id=self._ordered_ids[key],
score=0.0,
rank=idx + 1,
highlights={},
)
start = key.start or 0
stop = key.stop or len(self._ordered_ids)
# DRF slices to extract the current page. If the slice aligns
# with our pre-fetched page_hits, return them directly.
# We only check start — DRF always slices with stop=start+page_size,
# which exceeds page_hits length on the last page.
if start == self._page_offset:
return self._page_hits[: stop - start]
# Fallback: return stub dicts (no highlights).
return [
SearchHit(id=doc_id, score=0.0, rank=start + i + 1, highlights={})
for i, doc_id in enumerate(self._ordered_ids[key])
]
def get_all_ids(self) -> list[int]:
"""Return all matching document IDs in display order."""
return self._ordered_ids
def __getitem__(self, key: slice) -> list[SearchHit]:
return self._hits[key]
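DRF's `PageNumberPagination` only ever calls `__len__` (for the total count) and `__getitem__` with a page-aligned slice, which is what makes the two-tier wrapper described in the docstring work. A minimal sketch of that protocol, with illustrative names rather than the project's actual classes:

```python
class LazyResultList:
    """Sketch: full ID list plus one pre-fetched page of rich hit dicts."""

    def __init__(self, ordered_ids, page_hits, page_offset=0):
        self._ordered_ids = ordered_ids
        self._page_hits = page_hits
        self._page_offset = page_offset

    def __len__(self):
        # The paginator uses this to compute the total page count.
        return len(self._ordered_ids)

    def __getitem__(self, key):
        if isinstance(key, slice):
            start = key.start or 0
            stop = key.stop if key.stop is not None else len(self._ordered_ids)
            if start == self._page_offset:
                # Slice aligns with the pre-fetched page: return rich hits.
                return self._page_hits[: stop - start]
            # Otherwise return lightweight stubs (IDs only, no highlights).
            return [{"id": doc_id} for doc_id in self._ordered_ids[key]]
        return {"id": self._ordered_ids[key]}


results = LazyResultList(
    [10, 11, 12, 13],
    [{"id": 12, "score": 0.9}, {"id": 13, "score": 0.5}],
    page_offset=2,
)
```

Slicing at the pre-fetched offset yields the rich dicts; any other slice falls back to stubs, so no second index query is needed for the count.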
class SearchIndexLockError(Exception):
@@ -255,13 +206,10 @@ class WriteBatch:
"""
Remove a document from the batch by its primary key.
Uses range_query instead of term_query to work around a tantivy-py bug
where Python integers are inferred as i64, producing Terms that never
match u64 fields.
TODO: Replace with term_query("id", doc_id) once
https://github.com/quickwit-oss/tantivy-py/pull/642 lands.
Uses range query instead of term query to work around unsigned integer
type detection bug in tantivy-py 0.25.
"""
# Use range query to work around u64 deletion bug
self._writer.delete_documents_by_query(
tantivy.Query.range_query(
self._backend._schema,
@@ -286,34 +234,6 @@ class TantivyBackend:
the underlying index directory changes (e.g., during test isolation).
"""
# Maps DRF ordering field names to Tantivy index field names.
SORT_FIELD_MAP: dict[str, str] = {
"title": "title_sort",
"correspondent__name": "correspondent_sort",
"document_type__name": "type_sort",
"created": "created",
"added": "added",
"modified": "modified",
"archive_serial_number": "asn",
"page_count": "page_count",
"num_notes": "num_notes",
}
# Fields where Tantivy's sort order matches the ORM's sort order.
# Text-based fields (title, correspondent__name, document_type__name)
# are excluded because Tantivy's tokenized fast fields produce different
# ordering than the ORM's collation-based ordering.
SORTABLE_FIELDS: frozenset[str] = frozenset(
{
"created",
"added",
"modified",
"archive_serial_number",
"page_count",
"num_notes",
},
)
def __init__(self, path: Path | None = None):
# path=None → in-memory index (for tests)
# path=some_dir → on-disk index (for production)
@@ -352,36 +272,6 @@ class TantivyBackend:
if self._index is None:
self.open() # pragma: no cover
def _parse_query(
self,
query: str,
search_mode: SearchMode,
) -> tantivy.Query:
"""Parse a user query string into a Tantivy Query object."""
tz = get_current_timezone()
if search_mode is SearchMode.TEXT:
return parse_simple_text_query(self._index, query)
elif search_mode is SearchMode.TITLE:
return parse_simple_title_query(self._index, query)
else:
return parse_user_query(self._index, query, tz)
def _apply_permission_filter(
self,
query: tantivy.Query,
user: AbstractBaseUser | None,
) -> tantivy.Query:
"""Wrap a query with a permission filter if the user is not a superuser."""
if user is not None:
permission_filter = build_permission_filter(self._schema, user)
return tantivy.Query.boolean_query(
[
(tantivy.Occur.Must, query),
(tantivy.Occur.Must, permission_filter),
],
)
return query
def _build_tantivy_doc(
self,
document: Document,
@@ -436,17 +326,12 @@ class TantivyBackend:
doc.add_unsigned("tag_id", tag.pk)
tag_names.append(tag.name)
# Notes — JSON for structured queries (notes.user:alice, notes.note:text).
# notes_text is a plain-text companion for snippet/highlight generation;
# tantivy's SnippetGenerator does not support JSON fields.
# Notes — JSON for structured queries (notes.user:alice, notes.note:text),
# companion text field for default full-text search.
num_notes = 0
note_texts: list[str] = []
for note in document.notes.all():
num_notes += 1
doc.add_json("notes", {"note": note.note, "user": note.user.username})
note_texts.append(note.note)
if note_texts:
doc.add_text("notes_text", " ".join(note_texts))
# Custom fields — JSON for structured queries (custom_fields.name:x, custom_fields.value:y),
# companion text field for default full-text search.
@@ -540,127 +425,155 @@ class TantivyBackend:
with self.batch_update(lock_timeout=5.0) as batch:
batch.remove(doc_id)
def highlight_hits(
def search(
self,
query: str,
doc_ids: list[int],
user: AbstractBaseUser | None,
page: int,
page_size: int,
sort_field: str | None,
*,
sort_reverse: bool,
search_mode: SearchMode = SearchMode.QUERY,
rank_start: int = 1,
) -> list[SearchHit]:
) -> SearchResults:
"""
Generate SearchHit dicts with highlights for specific document IDs.
Execute a search query against the document index.
Unlike search(), this does not execute a ranked query — it looks up
each document by ID and generates snippets against the provided query.
Use this when you already know which documents to display (from
search_ids + ORM filtering) and just need highlight data.
Processes the user query through date rewriting, normalization, and
permission filtering before executing against Tantivy. Supports both
relevance-based and field-based sorting.
Note: Each doc_id requires an individual index lookup because tantivy-py
does not yet expose a batch fast-field read API. This is acceptable for
page-sized batches (typically 25 docs) but should not be called with
thousands of IDs.
TODO: When https://github.com/quickwit-oss/tantivy-py/pull/641 lands,
the per-doc range_query lookups here can be replaced with a single
collect_u64_fast_field("id", doc_addresses) call.
QUERY search mode supports natural date keywords, field filters, etc.
TITLE search mode treats the query as plain text to search in the title only.
TEXT search mode treats the query as plain text to search in the title and content.
Args:
query: The search query (used for snippet generation)
doc_ids: Ordered list of document IDs to generate hits for
search_mode: Query parsing mode (for building the snippet query)
rank_start: Starting rank value (1-based absolute position in the
full result set; pass ``page_offset + 1`` for paginated calls)
query: User's search query
user: User for permission filtering (None for superuser/no filtering)
page: Page number (1-indexed) for pagination
page_size: Number of results per page
sort_field: Field to sort by (None for relevance ranking)
sort_reverse: Whether to reverse the sort order
search_mode: "query" for advanced Tantivy syntax, "text" for
plain-text search over title and content only, "title" for
plain-text search over title only
Returns:
List of SearchHit dicts in the same order as doc_ids
SearchResults with hits, total count, and processed query
"""
if not doc_ids:
return []
self._ensure_open()
user_query = self._parse_query(query, search_mode)
tz = get_current_timezone()
if search_mode is SearchMode.TEXT:
user_query = parse_simple_text_query(self._index, query)
elif search_mode is SearchMode.TITLE:
user_query = parse_simple_title_query(self._index, query)
else:
user_query = parse_user_query(self._index, query, tz)
# For notes_text snippet generation, we need a query that targets the
# notes_text field directly. user_query may contain JSON-field terms
# (e.g. notes.note:urgent) that the SnippetGenerator cannot resolve
# against a text field. Strip field:value prefixes so bare terms like
# "urgent" are re-parsed against notes_text, producing highlights even
# when the original query used structured syntax.
bare_query = re.sub(r"\w[\w.]*:", "", query).strip()
try:
notes_text_query = (
self._index.parse_query(bare_query, ["notes_text"])
if bare_query
else user_query
)
except Exception:
notes_text_query = user_query
searcher = self._index.searcher()
snippet_generator = None
notes_snippet_generator = None
hits: list[SearchHit] = []
for rank, doc_id in enumerate(doc_ids, start=rank_start):
# Look up document by ID, scoring against the user query so that
# the returned SearchHit carries a real BM25 relevance score.
id_query = tantivy.Query.range_query(
self._schema,
"id",
tantivy.FieldType.Unsigned,
doc_id,
doc_id,
)
scored_query = tantivy.Query.boolean_query(
# Apply permission filter if user is not None (not superuser)
if user is not None:
permission_filter = build_permission_filter(self._schema, user)
final_query = tantivy.Query.boolean_query(
[
(tantivy.Occur.Must, user_query),
(tantivy.Occur.Must, id_query),
(tantivy.Occur.Must, permission_filter),
],
)
results = searcher.search(scored_query, limit=1)
else:
final_query = user_query
if not results.hits:
continue
searcher = self._index.searcher()
offset = (page - 1) * page_size
score, doc_address = results.hits[0]
# Map sort fields
sort_field_map = {
"title": "title_sort",
"correspondent__name": "correspondent_sort",
"document_type__name": "type_sort",
"created": "created",
"added": "added",
"modified": "modified",
"archive_serial_number": "asn",
"page_count": "page_count",
"num_notes": "num_notes",
}
# Perform search
if sort_field and sort_field in sort_field_map:
mapped_field = sort_field_map[sort_field]
results = searcher.search(
final_query,
limit=offset + page_size,
order_by_field=mapped_field,
order=tantivy.Order.Desc if sort_reverse else tantivy.Order.Asc,
)
# Field sorting: hits are still (score, DocAddress) tuples; score unused
all_hits = [(hit[1], 0.0) for hit in results.hits]
else:
# Score-based search: hits are (score, DocAddress) tuples
results = searcher.search(final_query, limit=offset + page_size)
all_hits = [(hit[1], hit[0]) for hit in results.hits]
total = results.count
# Normalize scores for score-based searches
if not sort_field and all_hits:
max_score = max(hit[1] for hit in all_hits) or 1.0
all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]
# Apply threshold filter if configured (score-based search only)
threshold = settings.ADVANCED_FUZZY_SEARCH_THRESHOLD
if threshold is not None and not sort_field:
all_hits = [hit for hit in all_hits if hit[1] >= threshold]
# Get the page's hits
page_hits = all_hits[offset : offset + page_size]
# Build result hits with highlights
hits: list[SearchHit] = []
snippet_generator = None
notes_snippet_generator = None
for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
# Get the actual document from the searcher using the doc address
actual_doc = searcher.doc(doc_address)
doc_dict = actual_doc.to_dict()
doc_id = doc_dict["id"][0]
highlights: dict[str, str] = {}
try:
if snippet_generator is None:
snippet_generator = tantivy.SnippetGenerator.create(
searcher,
user_query,
self._schema,
"content",
)
content_html = _render_snippet_html(
snippet_generator.snippet_from_doc(actual_doc),
)
if content_html:
highlights["content"] = content_html
if "notes_text" in doc_dict:
# Use notes_text (plain text) for snippet generation — tantivy's
# SnippetGenerator does not support JSON fields.
if notes_snippet_generator is None:
notes_snippet_generator = tantivy.SnippetGenerator.create(
# Generate highlights if score > 0
if score > 0:
try:
if snippet_generator is None:
snippet_generator = tantivy.SnippetGenerator.create(
searcher,
notes_text_query,
final_query,
self._schema,
"notes_text",
"content",
)
notes_html = _render_snippet_html(
notes_snippet_generator.snippet_from_doc(actual_doc),
)
if notes_html:
highlights["notes"] = notes_html
except Exception: # pragma: no cover
logger.debug("Failed to generate highlights for doc %s", doc_id)
content_snippet = snippet_generator.snippet_from_doc(actual_doc)
if content_snippet:
highlights["content"] = str(content_snippet)
# Try notes highlights
if "notes" in doc_dict:
if notes_snippet_generator is None:
notes_snippet_generator = tantivy.SnippetGenerator.create(
searcher,
final_query,
self._schema,
"notes",
)
notes_snippet = notes_snippet_generator.snippet_from_doc(
actual_doc,
)
if notes_snippet:
highlights["notes"] = str(notes_snippet)
except Exception: # pragma: no cover
logger.debug("Failed to generate highlights for doc %s", doc_id)
hits.append(
SearchHit(
@@ -671,69 +584,11 @@ class TantivyBackend:
),
)
return hits
def search_ids(
self,
query: str,
user: AbstractBaseUser | None,
*,
sort_field: str | None = None,
sort_reverse: bool = False,
search_mode: SearchMode = SearchMode.QUERY,
limit: int | None = None,
) -> list[int]:
"""
Return document IDs matching a query — no highlights or scores.
This is the lightweight companion to search(). Use it when you need the
full set of matching IDs (e.g. for ``selection_data``) but don't need
scores, ranks, or highlights.
Args:
query: User's search query
user: User for permission filtering (None for superuser/no filtering)
sort_field: Field to sort by (None for relevance ranking)
sort_reverse: Whether to reverse the sort order
search_mode: Query parsing mode (QUERY, TEXT, or TITLE)
limit: Maximum number of IDs to return (None = all matching docs)
Returns:
List of document IDs in the requested order
"""
self._ensure_open()
user_query = self._parse_query(query, search_mode)
final_query = self._apply_permission_filter(user_query, user)
searcher = self._index.searcher()
effective_limit = limit if limit is not None else searcher.num_docs
if sort_field and sort_field in self.SORT_FIELD_MAP:
mapped_field = self.SORT_FIELD_MAP[sort_field]
results = searcher.search(
final_query,
limit=effective_limit,
order_by_field=mapped_field,
order=tantivy.Order.Desc if sort_reverse else tantivy.Order.Asc,
)
all_hits = [(hit[1],) for hit in results.hits]
else:
results = searcher.search(final_query, limit=effective_limit)
all_hits = [(hit[1], hit[0]) for hit in results.hits]
# Normalize scores and apply threshold (relevance search only)
if all_hits:
max_score = max(hit[1] for hit in all_hits) or 1.0
all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]
threshold = settings.ADVANCED_FUZZY_SEARCH_THRESHOLD
if threshold is not None:
all_hits = [hit for hit in all_hits if hit[1] >= threshold]
# TODO: Replace with searcher.collect_u64_fast_field("id", addrs) once
# https://github.com/quickwit-oss/tantivy-py/pull/641 lands — eliminates
# one stored-doc fetch per result (~80% reduction in search_ids latency).
return [searcher.doc(doc_addr).to_dict()["id"][0] for doc_addr, *_ in all_hits]
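Both search paths normalize raw BM25 scores against the page's maximum and then apply the optional `ADVANCED_FUZZY_SEARCH_THRESHOLD` cutoff. The step can be illustrated in isolation (the helper name is illustrative; the real code operates on `(DocAddress, score)` tuples):

```python
def filter_by_relative_score(hits, threshold):
    """hits: list of (doc_id, raw_score) pairs.

    Normalize scores by the page maximum so the best hit scores 1.0,
    then drop hits below the relative threshold (None disables it).
    """
    if not hits:
        return []
    # `or 1.0` guards against an all-zero page (avoid division by zero).
    max_score = max(score for _, score in hits) or 1.0
    normalized = [(doc_id, score / max_score) for doc_id, score in hits]
    if threshold is None:
        return normalized
    return [hit for hit in normalized if hit[1] >= threshold]


hits = [(1, 8.0), (2, 4.0), (3, 2.0)]
```

Because the threshold is relative to the best hit on the page, the same setting behaves consistently across queries with very different absolute BM25 ranges.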
return SearchResults(
hits=hits,
total=total,
query=query,
)
def autocomplete(
self,
@@ -748,10 +603,6 @@ class TantivyBackend:
frequency (how many documents contain each word). Optionally filters
results to only words from documents visible to the specified user.
NOTE: This is the hottest search path (called per keystroke).
A future improvement would be to cache results in Redis, keyed by
(prefix, user_id), and invalidate on index writes.
Args:
term: Prefix to match against autocomplete words
limit: Maximum number of suggestions to return
@@ -762,94 +613,64 @@ class TantivyBackend:
"""
self._ensure_open()
normalized_term = ascii_fold(term.lower())
if not normalized_term:
return []
searcher = self._index.searcher()
# Build a prefix query on autocomplete_word so we only scan docs
# containing words that start with the prefix, not the entire index.
# tantivy regex is implicitly anchored; .+ avoids the empty-match
# error that .* triggers. We OR with term_query to also match the
# exact prefix as a complete word.
escaped = re.escape(normalized_term)
prefix_query = tantivy.Query.boolean_query(
[
(
tantivy.Occur.Should,
tantivy.Query.term_query(
self._schema,
"autocomplete_word",
normalized_term,
),
),
(
tantivy.Occur.Should,
tantivy.Query.regex_query(
self._schema,
"autocomplete_word",
f"{escaped}.+",
),
),
],
)
# Intersect with permission filter so autocomplete words from
# invisible documents don't leak to other users.
# Apply permission filter for non-superusers so autocomplete words
# from invisible documents don't leak to other users.
if user is not None and not user.is_superuser:
final_query = tantivy.Query.boolean_query(
[
(tantivy.Occur.Must, prefix_query),
(tantivy.Occur.Must, build_permission_filter(self._schema, user)),
],
)
base_query = build_permission_filter(self._schema, user)
else:
final_query = prefix_query
base_query = tantivy.Query.all_query()
results = searcher.search(final_query, limit=searcher.num_docs)
results = searcher.search(base_query, limit=10000)
# Count how many visible documents each matching word appears in.
# Count how many visible documents each word appears in.
# Using Counter (not set) preserves per-word document frequency so
# we can rank suggestions by how commonly they occur — the same
# signal Whoosh used for Tf/Idf-based autocomplete ordering.
word_counts: Counter[str] = Counter()
for _score, doc_address in results.hits:
stored_doc = searcher.doc(doc_address)
doc_dict = stored_doc.to_dict()
if "autocomplete_word" in doc_dict:
for word in doc_dict["autocomplete_word"]:
if word.startswith(normalized_term):
word_counts[word] += 1
word_counts.update(doc_dict["autocomplete_word"])
# Sort by document frequency descending; break ties alphabetically.
# Filter to prefix matches, sort by document frequency descending;
# break ties alphabetically for stable, deterministic output.
matches = sorted(
word_counts,
(w for w in word_counts if w.startswith(normalized_term)),
key=lambda w: (-word_counts[w], w),
)
return matches[:limit]
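The Counter-based ranking used above can be shown standalone. This is a sketch only: `words_per_doc` stands in for the per-document `autocomplete_word` lists fetched from stored documents, and each list is assumed to hold unique words so the count is a true document frequency.

```python
from collections import Counter

def rank_suggestions(words_per_doc, prefix, limit):
    # Count how many documents contain each word (document frequency).
    word_counts = Counter()
    for words in words_per_doc:
        word_counts.update(words)
    # Keep only prefix matches; rank by frequency descending and break
    # ties alphabetically for stable, deterministic output.
    matches = sorted(
        (w for w in word_counts if w.startswith(prefix)),
        key=lambda w: (-word_counts[w], w),
    )
    return matches[:limit]
```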
def more_like_this_ids(
def more_like_this(
self,
doc_id: int,
user: AbstractBaseUser | None,
*,
limit: int | None = None,
) -> list[int]:
page: int,
page_size: int,
) -> SearchResults:
"""
Return IDs of documents similar to the given document — no highlights.
Find documents similar to the given document using content analysis.
Lightweight companion to more_like_this(). The original document is
excluded from results.
Uses Tantivy's "more like this" query to find documents with similar
content patterns. The original document is excluded from results.
Args:
doc_id: Primary key of the reference document
user: User for permission filtering (None for no filtering)
limit: Maximum number of IDs to return (None = all matching docs)
page: Page number (1-indexed) for pagination
page_size: Number of results per page
Returns:
List of similar document IDs (excluding the original)
SearchResults with similar documents (excluding the original)
"""
self._ensure_open()
searcher = self._index.searcher()
# First find the document address
id_query = tantivy.Query.range_query(
self._schema,
"id",
@@ -860,9 +681,13 @@ class TantivyBackend:
results = searcher.search(id_query, limit=1)
if not results.hits:
return []
# Document not found
return SearchResults(hits=[], total=0, query=f"more_like:{doc_id}")
# Extract doc_address from (score, doc_address) tuple
doc_address = results.hits[0][1]
# Build more like this query
mlt_query = tantivy.Query.more_like_this_query(
doc_address,
min_doc_frequency=1,
@@ -874,21 +699,59 @@ class TantivyBackend:
boost_factor=None,
)
final_query = self._apply_permission_filter(mlt_query, user)
# Apply permission filter
if user is not None:
permission_filter = build_permission_filter(self._schema, user)
final_query = tantivy.Query.boolean_query(
[
(tantivy.Occur.Must, mlt_query),
(tantivy.Occur.Must, permission_filter),
],
)
else:
final_query = mlt_query
effective_limit = limit if limit is not None else searcher.num_docs
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
# Search
offset = (page - 1) * page_size
results = searcher.search(final_query, limit=offset + page_size)
# TODO: Replace with collect_u64_fast_field("id", addrs) once
# https://github.com/quickwit-oss/tantivy-py/pull/641 lands.
ids = []
for _score, doc_address in results.hits:
result_doc_id = searcher.doc(doc_address).to_dict()["id"][0]
if result_doc_id != doc_id:
ids.append(result_doc_id)
total = results.count
# Convert from (score, doc_address) to (doc_address, score)
all_hits = [(hit[1], hit[0]) for hit in results.hits]
return ids[:limit] if limit is not None else ids
# Normalize scores
if all_hits:
max_score = max(hit[1] for hit in all_hits) or 1.0
all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]
# Get page hits
page_hits = all_hits[offset : offset + page_size]
# Build results
hits: list[SearchHit] = []
for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
actual_doc = searcher.doc(doc_address)
doc_dict = actual_doc.to_dict()
result_doc_id = doc_dict["id"][0]
# Skip the original document
if result_doc_id == doc_id:
continue
hits.append(
SearchHit(
id=result_doc_id,
score=score,
rank=rank,
highlights={}, # MLT doesn't generate highlights
),
)
return SearchResults(
hits=hits,
total=max(0, total - 1), # Subtract 1 for the original document
query=f"more_like:{doc_id}",
)
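The pagination-plus-exclusion logic in more_like_this can be sketched as a small helper. This is an illustrative standalone version; `ranked_ids` stands in for the score-ordered document IDs decoded from the hit list.

```python
def paginate_similar(ranked_ids, original_id, page, page_size):
    # Offset-based pagination: the search fetches offset + page_size
    # ranked hits, then slices out the requested page.
    offset = (page - 1) * page_size
    page_ids = ranked_ids[offset : offset + page_size]
    # The reference document always matches itself best, so drop it.
    return [doc_id for doc_id in page_ids if doc_id != original_id]
```

Note that because the original document is filtered after slicing, the page containing it comes back one item short; subtracting one from the reported total accounts for the exclusion.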
def batch_update(self, lock_timeout: float = 30.0) -> WriteBatch:
"""


@@ -396,17 +396,10 @@ def build_permission_filter(
Tantivy query that filters results to visible documents
Implementation Notes:
- Uses range_query instead of term_query for owner_id/viewer_id to work
around a tantivy-py bug where Python ints are inferred as i64, causing
term_query to return no hits on u64 fields.
TODO: Replace with term_query once
https://github.com/quickwit-oss/tantivy-py/pull/642 lands.
- Uses range_query(owner_id, 1, MAX_U64) as an "owner exists" check
because exists_query is not yet available in tantivy-py 0.25.
TODO: Replace with exists_query("owner_id") once that is exposed in
a tantivy-py release.
- Uses range_query instead of term_query to work around unsigned integer
type detection bug in tantivy-py 0.25
- Uses boolean_query for "no owner" check since exists_query is not
available in tantivy-py 0.25.1 (available in master)
- Uses disjunction_max_query to combine permission clauses with OR logic
"""
owner_any = tantivy.Query.range_query(


@@ -72,9 +72,6 @@ def build_schema() -> tantivy.Schema:
# JSON fields — structured queries: notes.user:alice, custom_fields.name:invoice
sb.add_json_field("notes", stored=True, tokenizer_name="paperless_text")
# Plain-text companion for notes — tantivy's SnippetGenerator does not support
# JSON fields, so highlights require a text field with the same content.
sb.add_text_field("notes_text", stored=True, tokenizer_name="paperless_text")
sb.add_json_field("custom_fields", stored=True, tokenizer_name="paperless_text")
for field in (


@@ -33,12 +33,19 @@ class TestWriteBatch:
except RuntimeError:
pass
ids = backend.search_ids("should survive", user=None)
assert len(ids) == 1
r = backend.search(
"should survive",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert r.total == 1
class TestSearch:
"""Test search query parsing and matching via search_ids."""
"""Test search functionality."""
def test_text_mode_limits_default_search_to_title_and_content(
self,
@@ -53,20 +60,27 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(
backend.search_ids(
"document_type:invoice",
user=None,
search_mode=SearchMode.TEXT,
),
)
== 0
metadata_only = backend.search(
"document_type:invoice",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert (
len(backend.search_ids("monthly", user=None, search_mode=SearchMode.TEXT))
== 1
assert metadata_only.total == 0
content_match = backend.search(
"monthly",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert content_match.total == 1
def test_title_mode_limits_default_search_to_title_only(
self,
@@ -81,14 +95,27 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("monthly", user=None, search_mode=SearchMode.TITLE))
== 0
content_only = backend.search(
"monthly",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TITLE,
)
assert (
len(backend.search_ids("invoice", user=None, search_mode=SearchMode.TITLE))
== 1
assert content_only.total == 0
title_match = backend.search(
"invoice",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TITLE,
)
assert title_match.total == 1
def test_text_mode_matches_partial_term_substrings(
self,
@@ -103,16 +130,38 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("pass", user=None, search_mode=SearchMode.TEXT)) == 1
prefix_match = backend.search(
"pass",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert (
len(backend.search_ids("sswo", user=None, search_mode=SearchMode.TEXT)) == 1
assert prefix_match.total == 1
infix_match = backend.search(
"sswo",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert (
len(backend.search_ids("sswo re", user=None, search_mode=SearchMode.TEXT))
== 1
assert infix_match.total == 1
phrase_match = backend.search(
"sswo re",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert phrase_match.total == 1
def test_text_mode_does_not_match_on_partial_term_overlap(
self,
@@ -127,10 +176,16 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("raptor", user=None, search_mode=SearchMode.TEXT))
== 0
non_match = backend.search(
"raptor",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert non_match.total == 0
def test_text_mode_anchors_later_query_tokens_to_token_starts(
self,
@@ -159,9 +214,16 @@ class TestSearch:
backend.add_or_update(prefix_doc)
backend.add_or_update(false_positive)
result_ids = set(
backend.search_ids("Z-Berichte 6", user=None, search_mode=SearchMode.TEXT),
results = backend.search(
"Z-Berichte 6",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
result_ids = {hit["id"] for hit in results.hits}
assert exact_doc.id in result_ids
assert prefix_doc.id in result_ids
@@ -180,9 +242,16 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("!!!", user=None, search_mode=SearchMode.TEXT)) == 0
no_tokens = backend.search(
"!!!",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
)
assert no_tokens.total == 0
def test_title_mode_matches_partial_term_substrings(
self,
@@ -197,18 +266,59 @@ class TestSearch:
)
backend.add_or_update(doc)
assert (
len(backend.search_ids("pass", user=None, search_mode=SearchMode.TITLE))
== 1
prefix_match = backend.search(
"pass",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TITLE,
)
assert (
len(backend.search_ids("sswo", user=None, search_mode=SearchMode.TITLE))
== 1
assert prefix_match.total == 1
infix_match = backend.search(
"sswo",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TITLE,
)
assert (
len(backend.search_ids("sswo gu", user=None, search_mode=SearchMode.TITLE))
== 1
assert infix_match.total == 1
phrase_match = backend.search(
"sswo gu",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TITLE,
)
assert phrase_match.total == 1
def test_scores_normalised_top_hit_is_one(self, backend: TantivyBackend):
"""Search scores must be normalized so top hit has score 1.0 for UI consistency."""
for i, title in enumerate(["bank invoice", "bank statement", "bank receipt"]):
doc = Document.objects.create(
title=title,
content=title,
checksum=f"SN{i}",
pk=10 + i,
)
backend.add_or_update(doc)
r = backend.search(
"bank",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert r.hits[0]["score"] == pytest.approx(1.0)
assert all(0.0 <= h["score"] <= 1.0 for h in r.hits)
def test_sort_field_ascending(self, backend: TantivyBackend):
"""Searching with sort_reverse=False must return results in ascending ASN order."""
@@ -221,14 +331,16 @@ class TestSearch:
)
backend.add_or_update(doc)
ids = backend.search_ids(
r = backend.search(
"sortable",
user=None,
page=1,
page_size=10,
sort_field="archive_serial_number",
sort_reverse=False,
)
assert len(ids) == 3
asns = [Document.objects.get(pk=doc_id).archive_serial_number for doc_id in ids]
assert r.total == 3
asns = [Document.objects.get(pk=h["id"]).archive_serial_number for h in r.hits]
assert asns == [10, 20, 30]
def test_sort_field_descending(self, backend: TantivyBackend):
@@ -242,91 +354,79 @@ class TestSearch:
)
backend.add_or_update(doc)
ids = backend.search_ids(
r = backend.search(
"sortable",
user=None,
page=1,
page_size=10,
sort_field="archive_serial_number",
sort_reverse=True,
)
assert len(ids) == 3
asns = [Document.objects.get(pk=doc_id).archive_serial_number for doc_id in ids]
assert r.total == 3
asns = [Document.objects.get(pk=h["id"]).archive_serial_number for h in r.hits]
assert asns == [30, 20, 10]
class TestSearchIds:
"""Test lightweight ID-only search."""
def test_returns_matching_ids(self, backend: TantivyBackend):
"""search_ids must return IDs of all matching documents."""
docs = []
for i in range(5):
doc = Document.objects.create(
title=f"findable doc {i}",
content="common keyword",
checksum=f"SI{i}",
)
backend.add_or_update(doc)
docs.append(doc)
other = Document.objects.create(
title="unrelated",
content="nothing here",
checksum="SI_other",
)
backend.add_or_update(other)
ids = backend.search_ids(
"common keyword",
user=None,
search_mode=SearchMode.QUERY,
)
assert set(ids) == {d.pk for d in docs}
assert other.pk not in ids
def test_respects_permission_filter(self, backend: TantivyBackend):
"""search_ids must respect user permission filtering."""
owner = User.objects.create_user("ids_owner")
other = User.objects.create_user("ids_other")
def test_fuzzy_threshold_filters_low_score_hits(
self,
backend: TantivyBackend,
settings,
):
"""When ADVANCED_FUZZY_SEARCH_THRESHOLD exceeds all normalized scores, hits must be filtered out."""
doc = Document.objects.create(
title="private doc",
content="secret keyword",
checksum="SIP1",
title="Invoice document",
content="financial report",
checksum="FT1",
pk=120,
)
backend.add_or_update(doc)
# Threshold above 1.0 filters every hit (normalized scores top out at 1.0)
settings.ADVANCED_FUZZY_SEARCH_THRESHOLD = 1.1
r = backend.search(
"invoice",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert r.hits == []
def test_owner_filter(self, backend: TantivyBackend):
"""Document owners can search their private documents; other users cannot access them."""
owner = User.objects.create_user("owner")
other = User.objects.create_user("other")
doc = Document.objects.create(
title="Private",
content="secret",
checksum="PF1",
pk=20,
owner=owner,
)
backend.add_or_update(doc)
assert backend.search_ids(
"secret",
user=owner,
search_mode=SearchMode.QUERY,
) == [doc.pk]
assert (
backend.search_ids("secret", user=other, search_mode=SearchMode.QUERY) == []
backend.search(
"secret",
user=owner,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
).total
== 1
)
def test_respects_fuzzy_threshold(self, backend: TantivyBackend, settings):
"""search_ids must apply the same fuzzy threshold as search()."""
doc = Document.objects.create(
title="threshold test",
content="unique term",
checksum="SIT1",
assert (
backend.search(
"secret",
user=other,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
).total
== 0
)
backend.add_or_update(doc)
settings.ADVANCED_FUZZY_SEARCH_THRESHOLD = 1.1
ids = backend.search_ids("unique", user=None, search_mode=SearchMode.QUERY)
assert ids == []
def test_returns_ids_for_text_mode(self, backend: TantivyBackend):
"""search_ids must work with TEXT search mode."""
doc = Document.objects.create(
title="text mode doc",
content="findable phrase",
checksum="SIM1",
)
backend.add_or_update(doc)
ids = backend.search_ids("findable", user=None, search_mode=SearchMode.TEXT)
assert ids == [doc.pk]
class TestRebuild:
@@ -390,26 +490,57 @@ class TestAutocomplete:
class TestMoreLikeThis:
"""Test more like this functionality."""
def test_more_like_this_ids_excludes_original(self, backend: TantivyBackend):
"""more_like_this_ids must return IDs of similar documents, excluding the original."""
def test_excludes_original(self, backend: TantivyBackend):
"""More like this queries must exclude the reference document from results."""
doc1 = Document.objects.create(
title="Important document",
content="financial information report",
checksum="MLTI1",
pk=150,
content="financial information",
checksum="MLT1",
pk=50,
)
doc2 = Document.objects.create(
title="Another document",
content="financial information report",
checksum="MLTI2",
pk=151,
content="financial report",
checksum="MLT2",
pk=51,
)
backend.add_or_update(doc1)
backend.add_or_update(doc2)
ids = backend.more_like_this_ids(doc_id=150, user=None)
assert 150 not in ids
assert 151 in ids
results = backend.more_like_this(doc_id=50, user=None, page=1, page_size=10)
returned_ids = [hit["id"] for hit in results.hits]
assert 50 not in returned_ids # Original document excluded
def test_with_user_applies_permission_filter(self, backend: TantivyBackend):
"""more_like_this with a user must exclude documents that user cannot see."""
viewer = User.objects.create_user("mlt_viewer")
other = User.objects.create_user("mlt_other")
public_doc = Document.objects.create(
title="Public financial document",
content="quarterly financial analysis report figures",
checksum="MLT3",
pk=52,
)
private_doc = Document.objects.create(
title="Private financial document",
content="quarterly financial analysis report figures",
checksum="MLT4",
pk=53,
owner=other,
)
backend.add_or_update(public_doc)
backend.add_or_update(private_doc)
results = backend.more_like_this(doc_id=52, user=viewer, page=1, page_size=10)
returned_ids = [hit["id"] for hit in results.hits]
# private_doc is owned by other, so viewer cannot see it
assert 53 not in returned_ids
def test_document_not_in_index_returns_empty(self, backend: TantivyBackend):
"""more_like_this for a doc_id absent from the index must return empty results."""
results = backend.more_like_this(doc_id=9999, user=None, page=1, page_size=10)
assert results.hits == []
assert results.total == 0
class TestSingleton:
@@ -462,10 +593,19 @@ class TestFieldHandling:
# Should not raise an exception
backend.add_or_update(doc)
assert len(backend.search_ids("test", user=None)) == 1
results = backend.search(
"test",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 1
def test_custom_fields_include_name_and_value(self, backend: TantivyBackend):
"""Custom fields must be indexed with both field name and value for structured queries."""
# Create a custom field
field = CustomField.objects.create(
name="Invoice Number",
data_type=CustomField.FieldDataType.STRING,
@@ -482,9 +622,18 @@ class TestFieldHandling:
value_text="INV-2024-001",
)
# Should not raise an exception during indexing
backend.add_or_update(doc)
assert len(backend.search_ids("invoice", user=None)) == 1
results = backend.search(
"invoice",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 1
def test_select_custom_field_indexes_label_not_id(self, backend: TantivyBackend):
"""SELECT custom fields must index the human-readable label, not the opaque option ID."""
@@ -511,8 +660,27 @@ class TestFieldHandling:
)
backend.add_or_update(doc)
assert len(backend.search_ids("custom_fields.value:invoice", user=None)) == 1
assert len(backend.search_ids("custom_fields.value:opt_abc", user=None)) == 0
# Label should be findable
results = backend.search(
"custom_fields.value:invoice",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 1
# Opaque ID must not appear in the index
results = backend.search(
"custom_fields.value:opt_abc",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 0
def test_none_custom_field_value_not_indexed(self, backend: TantivyBackend):
"""Custom field instances with no value set must not produce an index entry."""
@@ -534,7 +702,16 @@ class TestFieldHandling:
)
backend.add_or_update(doc)
assert len(backend.search_ids("custom_fields.value:none", user=None)) == 0
# The string "none" must not appear as an indexed value
results = backend.search(
"custom_fields.value:none",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 0
def test_notes_include_user_information(self, backend: TantivyBackend):
"""Notes must be indexed with user information when available for structured queries."""
@@ -547,101 +724,32 @@ class TestFieldHandling:
)
Note.objects.create(document=doc, note="Important note", user=user)
# Should not raise an exception during indexing
backend.add_or_update(doc)
ids = backend.search_ids("test", user=None)
assert len(ids) == 1, (
f"Expected 1, got {len(ids)}. Document content should be searchable."
# Test basic document search first
results = backend.search(
"test",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
assert results.total == 1, (
f"Expected 1, got {results.total}. Document content should be searchable."
)
ids = backend.search_ids("notes.note:important", user=None)
assert len(ids) == 1, (
f"Expected 1, got {len(ids)}. Note content should be searchable via notes.note: prefix."
# Test notes search — must use structured JSON syntax now that note
# is no longer in DEFAULT_SEARCH_FIELDS
results = backend.search(
"notes.note:important",
user=None,
page=1,
page_size=10,
sort_field=None,
sort_reverse=False,
)
class TestHighlightHits:
"""Test highlight_hits returns proper HTML strings, not raw Snippet objects."""
def test_highlights_content_returns_match_span_html(
self,
backend: TantivyBackend,
):
"""highlight_hits must return frontend-ready highlight spans."""
doc = Document.objects.create(
title="Highlight Test",
content="The quick brown fox jumps over the lazy dog",
checksum="HH1",
pk=90,
assert results.total == 1, (
f"Expected 1, got {results.total}. Note content should be searchable via notes.note: prefix."
)
backend.add_or_update(doc)
hits = backend.highlight_hits("quick", [doc.pk])
assert len(hits) == 1
highlights = hits[0]["highlights"]
assert "content" in highlights
content_highlight = highlights["content"]
assert isinstance(content_highlight, str), (
f"Expected str, got {type(content_highlight)}: {content_highlight!r}"
)
assert '<span class="match">' in content_highlight, (
f"Expected HTML with match span, got: {content_highlight!r}"
)
def test_highlights_notes_returns_match_span_html(
self,
backend: TantivyBackend,
):
"""Note highlights must be frontend-ready HTML via notes_text companion field.
The notes JSON field does not support tantivy SnippetGenerator; the
notes_text plain-text field is used instead. We use the full-text
query "urgent" (not notes.note:) because notes_text IS in
DEFAULT_SEARCH_FIELDS via the normal search path… actually, we use
notes.note: prefix so the query targets notes content directly, but
the snippet is generated from notes_text which stores the same text.
"""
user = User.objects.create_user("hl_noteuser")
doc = Document.objects.create(
title="Doc with matching note",
content="unrelated content",
checksum="HH2",
pk=91,
)
Note.objects.create(document=doc, note="urgent payment required", user=user)
backend.add_or_update(doc)
# Use notes.note: prefix so the document matches the query and the
# notes_text snippet generator can produce highlights.
hits = backend.highlight_hits("notes.note:urgent", [doc.pk])
assert len(hits) == 1
highlights = hits[0]["highlights"]
assert "notes" in highlights
note_highlight = highlights["notes"]
assert isinstance(note_highlight, str), (
f"Expected str, got {type(note_highlight)}: {note_highlight!r}"
)
assert '<span class="match">' in note_highlight, (
f"Expected HTML with match span, got: {note_highlight!r}"
)
def test_empty_doc_list_returns_empty_hits(self, backend: TantivyBackend):
"""highlight_hits with no doc IDs must return an empty list."""
hits = backend.highlight_hits("anything", [])
assert hits == []
def test_no_highlights_when_no_match(self, backend: TantivyBackend):
"""Documents not matching the query should not appear in results."""
doc = Document.objects.create(
title="Unrelated",
content="completely different text",
checksum="HH3",
pk=92,
)
backend.add_or_update(doc)
hits = backend.highlight_hits("quick", [doc.pk])
assert len(hits) == 0


@@ -6,6 +6,8 @@ from unittest.mock import patch
from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from rest_framework import status
from rest_framework.test import APITestCase
@@ -201,6 +203,156 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
)
self.assertFalse(Path(old_logo.path).exists())
def test_api_strips_exif_data_from_uploaded_logo(self) -> None:
"""
GIVEN:
- A JPEG logo upload containing EXIF metadata
WHEN:
- Uploaded via PATCH to app config
THEN:
- Stored logo image has EXIF metadata removed
"""
image = Image.new("RGB", (12, 12), "blue")
exif = Image.Exif()
exif[315] = "Paperless Test Author"
logo = BytesIO()
image.save(logo, format="JPEG", exif=exif)
logo.seek(0)
response = self.client.patch(
f"{self.ENDPOINT}1/",
{
"app_logo": SimpleUploadedFile(
name="logo-with-exif.jpg",
content=logo.getvalue(),
content_type="image/jpeg",
),
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
with Image.open(config.app_logo.path) as stored_logo:
stored_exif = stored_logo.getexif()
self.assertEqual(len(stored_exif), 0)
def test_api_strips_png_metadata_from_uploaded_logo(self) -> None:
"""
GIVEN:
- A PNG logo upload containing text metadata
WHEN:
- Uploaded via PATCH to app config
THEN:
- Stored logo image has metadata removed
"""
image = Image.new("RGB", (12, 12), "green")
pnginfo = PngInfo()
pnginfo.add_text("Author", "Paperless Test Author")
logo = BytesIO()
image.save(logo, format="PNG", pnginfo=pnginfo)
logo.seek(0)
response = self.client.patch(
f"{self.ENDPOINT}1/",
{
"app_logo": SimpleUploadedFile(
name="logo-with-metadata.png",
content=logo.getvalue(),
content_type="image/png",
),
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
config = ApplicationConfiguration.objects.first()
with Image.open(config.app_logo.path) as stored_logo:
stored_text = stored_logo.text
self.assertEqual(stored_text, {})
def test_api_accepts_valid_gif_logo(self) -> None:
"""
GIVEN:
- A valid GIF logo upload
WHEN:
- Uploaded via PATCH to app config
THEN:
- Upload succeeds
"""
image = Image.new("RGB", (12, 12), "red")
logo = BytesIO()
image.save(logo, format="GIF", comment=b"Paperless Test Comment")
logo.seek(0)
response = self.client.patch(
f"{self.ENDPOINT}1/",
{
"app_logo": SimpleUploadedFile(
name="logo.gif",
content=logo.getvalue(),
content_type="image/gif",
),
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_api_rejects_invalid_raster_logo(self) -> None:
"""
GIVEN:
- A file named as a JPEG but containing non-image payload data
WHEN:
- Uploaded via PATCH to app config
THEN:
- Upload is rejected with 400
"""
response = self.client.patch(
f"{self.ENDPOINT}1/",
{
"app_logo": SimpleUploadedFile(
name="not-an-image.jpg",
content=b"<script>alert('xss')</script>",
content_type="image/jpeg",
),
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid logo image", str(response.data).lower())
@override_settings(MAX_IMAGE_PIXELS=100)
def test_api_rejects_logo_exceeding_max_image_pixels(self) -> None:
"""
GIVEN:
- A raster logo larger than the configured MAX_IMAGE_PIXELS limit
WHEN:
- Uploaded via PATCH to app config
THEN:
- Upload is rejected with 400
"""
image = Image.new("RGB", (12, 12), "purple")
logo = BytesIO()
image.save(logo, format="PNG")
logo.seek(0)
response = self.client.patch(
f"{self.ENDPOINT}1/",
{
"app_logo": SimpleUploadedFile(
name="too-large.png",
content=logo.getvalue(),
content_type="image/png",
),
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn(
"uploaded logo exceeds the maximum allowed image size",
str(response.data).lower(),
)
def test_api_rejects_malicious_svg_logo(self) -> None:
"""
GIVEN:


@@ -18,6 +18,7 @@ from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.core import mail
from django.core.cache import cache
from django.core.files.uploadedfile import SimpleUploadedFile
from django.db import DataError
from django.test import override_settings
from django.utils import timezone
@@ -1377,6 +1378,79 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertIsNone(overrides.document_type_id)
self.assertIsNone(overrides.tag_ids)
def test_upload_with_path_traversal_filename_is_reduced_to_basename(self) -> None:
self.consume_file_mock.return_value = celery.result.AsyncResult(
id=str(uuid.uuid4()),
)
payload = SimpleUploadedFile(
"../../outside.pdf",
(Path(__file__).parent / "samples" / "simple.pdf").read_bytes(),
content_type="application/pdf",
)
response = self.client.post(
"/api/documents/post_document/",
{"document": payload},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
self.assertEqual(input_doc.original_file.name, "outside.pdf")
self.assertEqual(overrides.filename, "outside.pdf")
self.assertNotIn("..", input_doc.original_file.name)
self.assertNotIn("..", overrides.filename)
self.assertTrue(
input_doc.original_file.resolve(strict=False).is_relative_to(
Path(settings.SCRATCH_DIR).resolve(strict=False),
),
)
def test_upload_with_path_traversal_content_disposition_filename_is_reduced_to_basename(
self,
) -> None:
self.consume_file_mock.return_value = celery.result.AsyncResult(
id=str(uuid.uuid4()),
)
pdf_bytes = (Path(__file__).parent / "samples" / "simple.pdf").read_bytes()
boundary = "paperless-boundary"
payload = (
(
f"--{boundary}\r\n"
'Content-Disposition: form-data; name="document"; '
'filename="../../outside.pdf"\r\n'
"Content-Type: application/pdf\r\n\r\n"
).encode()
+ pdf_bytes
+ f"\r\n--{boundary}--\r\n".encode()
)
response = self.client.generic(
"POST",
"/api/documents/post_document/",
payload,
content_type=f"multipart/form-data; boundary={boundary}",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
self.assertEqual(input_doc.original_file.name, "outside.pdf")
self.assertEqual(overrides.filename, "outside.pdf")
self.assertNotIn("..", input_doc.original_file.name)
self.assertNotIn("..", overrides.filename)
self.assertTrue(
input_doc.original_file.resolve(strict=False).is_relative_to(
Path(settings.SCRATCH_DIR).resolve(strict=False),
),
)
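Both tests above assert the same invariant: any client-supplied path, whether from the multipart form filename or a raw Content-Disposition header, is reduced to its basename before consumption. A minimal stdlib sketch of that normalization (a hypothetical helper, not the actual consumer code):

```python
from pathlib import PurePosixPath, PureWindowsPath

def sanitize_filename(name: str) -> str:
    # Strip Windows-style components first, then POSIX-style ones,
    # so "../../outside.pdf" and "..\\..\\outside.pdf" both collapse.
    return PurePosixPath(PureWindowsPath(name).name).name

print(sanitize_filename("../../outside.pdf"))    # outside.pdf
print(sanitize_filename("..\\..\\outside.pdf"))  # outside.pdf
```

The double pass matters because `PurePosixPath` treats backslashes as ordinary filename characters, so a Windows-style traversal would otherwise survive intact.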
def test_document_filters_use_latest_version_content(self) -> None:
root = Document.objects.create(
title="versioned root",


@@ -1503,126 +1503,6 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
[d2.id, d1.id, d3.id],
)
def test_search_ordering_by_score(self) -> None:
"""ordering=-score must return results in descending relevance order (best first)."""
backend = get_backend()
# doc_high has more occurrences of the search term → higher BM25 score
doc_low = Document.objects.create(
title="score sort low",
content="apple",
checksum="SCL1",
)
doc_high = Document.objects.create(
title="score sort high",
content="apple apple apple apple apple",
checksum="SCH1",
)
backend.add_or_update(doc_low)
backend.add_or_update(doc_high)
# -score = descending = best first (highest score)
response = self.client.get("/api/documents/?query=apple&ordering=-score")
self.assertEqual(response.status_code, status.HTTP_200_OK)
ids = [r["id"] for r in response.data["results"]]
self.assertEqual(
ids[0],
doc_high.id,
"Most relevant doc should be first for -score",
)
# score = ascending = worst first (lowest score)
response = self.client.get("/api/documents/?query=apple&ordering=score")
self.assertEqual(response.status_code, status.HTTP_200_OK)
ids = [r["id"] for r in response.data["results"]]
self.assertEqual(
ids[0],
doc_low.id,
"Least relevant doc should be first for +score",
)
def test_search_with_tantivy_native_sort(self) -> None:
"""When ordering by a Tantivy-sortable field, results must be correctly sorted."""
backend = get_backend()
for i, asn in enumerate([30, 10, 20]):
doc = Document.objects.create(
title=f"sortable doc {i}",
content="searchable content",
checksum=f"TNS{i}",
archive_serial_number=asn,
)
backend.add_or_update(doc)
response = self.client.get(
"/api/documents/?query=searchable&ordering=archive_serial_number",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
asns = [doc["archive_serial_number"] for doc in response.data["results"]]
self.assertEqual(asns, [10, 20, 30])
response = self.client.get(
"/api/documents/?query=searchable&ordering=-archive_serial_number",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
asns = [doc["archive_serial_number"] for doc in response.data["results"]]
self.assertEqual(asns, [30, 20, 10])
def test_search_page_2_returns_correct_slice(self) -> None:
"""Page 2 must return the second slice, not overlap with page 1."""
backend = get_backend()
for i in range(10):
doc = Document.objects.create(
title=f"doc {i}",
content="paginated content",
checksum=f"PG2{i}",
archive_serial_number=i + 1,
)
backend.add_or_update(doc)
response = self.client.get(
"/api/documents/?query=paginated&ordering=archive_serial_number&page=1&page_size=3",
)
page1_ids = [r["id"] for r in response.data["results"]]
self.assertEqual(len(page1_ids), 3)
response = self.client.get(
"/api/documents/?query=paginated&ordering=archive_serial_number&page=2&page_size=3",
)
page2_ids = [r["id"] for r in response.data["results"]]
self.assertEqual(len(page2_ids), 3)
# No overlap between pages
self.assertEqual(set(page1_ids) & set(page2_ids), set())
# Page 2 ASNs are higher than page 1
page1_asns = [
Document.objects.get(pk=pk).archive_serial_number for pk in page1_ids
]
page2_asns = [
Document.objects.get(pk=pk).archive_serial_number for pk in page2_ids
]
self.assertTrue(max(page1_asns) < min(page2_asns))
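The no-overlap property this test asserts follows from plain offset arithmetic; a standalone sketch mirroring the page/page_size semantics assumed from the API:

```python
def page_slice(ordered_ids, page_num, page_size):
    # 1-based page number -> half-open slice of the ranked ID list
    offset = (page_num - 1) * page_size
    return ordered_ids[offset : offset + page_size]

ids = list(range(1, 11))        # ten ranked IDs
page1 = page_slice(ids, 1, 3)   # [1, 2, 3]
page2 = page_slice(ids, 2, 3)   # [4, 5, 6]
assert set(page1) & set(page2) == set()
```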
def test_search_all_field_contains_all_ids_when_paginated(self) -> None:
"""The 'all' field must contain every matching ID, even when paginated."""
backend = get_backend()
doc_ids = []
for i in range(10):
doc = Document.objects.create(
title=f"all field doc {i}",
content="allfield content",
checksum=f"AF{i}",
)
backend.add_or_update(doc)
doc_ids.append(doc.pk)
response = self.client.get(
"/api/documents/?query=allfield&page=1&page_size=3",
headers={"Accept": "application/json; version=9"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]), 3)
# "all" must contain ALL 10 matching IDs
self.assertCountEqual(response.data["all"], doc_ids)
@mock.patch("documents.bulk_edit.bulk_update_documents")
def test_global_search(self, m) -> None:
"""


@@ -38,7 +38,6 @@ from django.db.models import Model
from django.db.models import OuterRef
from django.db.models import Prefetch
from django.db.models import Q
from django.db.models import QuerySet
from django.db.models import Subquery
from django.db.models import Sum
from django.db.models import When
@@ -249,13 +248,6 @@ if settings.AUDIT_LOG_ENABLED:
logger = logging.getLogger("paperless.api")
# Crossover point for intersect_and_order: below this count use a targeted
# IN-clause query; at or above this count fall back to a full-table scan +
# Python set intersection. The IN-clause is faster for small result sets but
# degrades on SQLite with thousands of parameters. PostgreSQL handles large IN
# clauses efficiently, so this threshold mainly protects SQLite users.
_TANTIVY_INTERSECT_THRESHOLD = 5_000
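Whichever SQL shape fetches the visible IDs, the step this threshold gates is an order-preserving intersection of the ranked search IDs with the permission-filtered set; a stdlib sketch of that core operation:

```python
def intersect_preserving_order(ranked_ids, visible_ids):
    # Keep the search backend's ranking; drop IDs the user cannot see.
    visible = set(visible_ids)
    return [doc_id for doc_id in ranked_ids if doc_id in visible]

print(intersect_preserving_order([7, 3, 9, 1], {1, 2, 3, 7}))  # [7, 3, 1]
```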
class IndexView(TemplateView):
template_name = "index.html"
@@ -2068,16 +2060,19 @@ class UnifiedSearchViewSet(DocumentViewSet):
if not self._is_search_request():
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchMode
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
def parse_search_params() -> tuple[str | None, bool, bool, int, int]:
"""Extract query string, search mode, and ordering from request."""
active = self._get_active_search_params(request)
if len(active) > 1:
try:
backend = get_backend()
# ORM-filtered queryset: permissions + field filters + ordering (DRF backends applied)
filtered_qs = self.filter_queryset(self.get_queryset())
user = None if request.user.is_superuser else request.user
active_search_params = self._get_active_search_params(request)
if len(active_search_params) > 1:
raise ValidationError(
{
"detail": _(
@@ -2086,161 +2081,73 @@ class UnifiedSearchViewSet(DocumentViewSet):
},
)
ordering_param = request.query_params.get("ordering", "")
sort_reverse = ordering_param.startswith("-")
sort_field_name = ordering_param.lstrip("-") or None
# "score" means relevance order — Tantivy handles it natively,
# so treat it as a Tantivy sort to preserve the ranked order through
# the ORM intersection step.
use_tantivy_sort = (
sort_field_name in TantivyBackend.SORTABLE_FIELDS
or sort_field_name is None
or sort_field_name == "score"
)
try:
page_num = int(request.query_params.get("page", 1))
except (TypeError, ValueError):
page_num = 1
page_size = (
self.paginator.get_page_size(request) or self.paginator.page_size
)
return sort_field_name, sort_reverse, use_tantivy_sort, page_num, page_size
def intersect_and_order(
all_ids: list[int],
filtered_qs: QuerySet[Document],
*,
use_tantivy_sort: bool,
) -> list[int]:
"""Intersect search IDs with ORM-visible IDs, preserving order."""
if not all_ids:
return []
if use_tantivy_sort:
if len(all_ids) <= _TANTIVY_INTERSECT_THRESHOLD:
# Small result set: targeted IN-clause avoids a full-table scan.
visible_ids = set(
filtered_qs.filter(pk__in=all_ids).values_list("pk", flat=True),
)
else:
# Large result set: full-table scan + Python intersection is faster
# than a large IN-clause on SQLite.
visible_ids = set(
filtered_qs.values_list("pk", flat=True),
)
return [doc_id for doc_id in all_ids if doc_id in visible_ids]
return list(
filtered_qs.filter(id__in=all_ids).values_list("pk", flat=True),
)
def run_text_search(
backend: TantivyBackend,
user: User | None,
filtered_qs: QuerySet[Document],
) -> tuple[list[int], list[SearchHit], int]:
"""Handle text/title/query search: IDs, ORM intersection, page highlights."""
if "text" in request.query_params:
search_mode = SearchMode.TEXT
query_str = request.query_params["text"]
elif "title_search" in request.query_params:
search_mode = SearchMode.TITLE
query_str = request.query_params["title_search"]
else:
search_mode = SearchMode.QUERY
query_str = request.query_params["query"]
# "score" is not a real Tantivy sort field — it means relevance order,
# which is Tantivy's default when no sort field is specified.
is_score_sort = sort_field_name == "score"
all_ids = backend.search_ids(
query_str,
user=user,
sort_field=(
None if (not use_tantivy_sort or is_score_sort) else sort_field_name
),
sort_reverse=sort_reverse,
search_mode=search_mode,
)
ordered_ids = intersect_and_order(
all_ids,
filtered_qs,
use_tantivy_sort=use_tantivy_sort,
)
# Tantivy returns relevance results best-first (descending score).
# ordering=score (ascending, worst-first) requires a reversal.
if is_score_sort and not sort_reverse:
ordered_ids = list(reversed(ordered_ids))
page_offset = (page_num - 1) * page_size
page_ids = ordered_ids[page_offset : page_offset + page_size]
page_hits = backend.highlight_hits(
query_str,
page_ids,
search_mode=search_mode,
rank_start=page_offset + 1,
)
return ordered_ids, page_hits, page_offset
def run_more_like_this(
backend: TantivyBackend,
user: User | None,
filtered_qs: QuerySet[Document],
) -> tuple[list[int], list[SearchHit], int]:
"""Handle more_like_id search: permission check, IDs, stub hits."""
try:
more_like_doc_id = int(request.query_params["more_like_id"])
more_like_doc = Document.objects.select_related("owner").get(
pk=more_like_doc_id,
)
except (TypeError, ValueError, Document.DoesNotExist):
raise PermissionDenied(_("Invalid more_like_id"))
if not has_perms_owner_aware(
request.user,
"view_document",
more_like_doc,
):
raise PermissionDenied(_("Insufficient permissions."))
if (
"text" in request.query_params
or "title_search" in request.query_params
or "query" in request.query_params
):
all_ids = backend.more_like_this_ids(more_like_doc_id, user=user)
ordered_ids = intersect_and_order(
all_ids,
filtered_qs,
use_tantivy_sort=True,
)
page_offset = (page_num - 1) * page_size
page_ids = ordered_ids[page_offset : page_offset + page_size]
page_hits = [
SearchHit(id=doc_id, score=0.0, rank=rank, highlights={})
for rank, doc_id in enumerate(page_ids, start=page_offset + 1)
]
return ordered_ids, page_hits, page_offset
try:
sort_field_name, sort_reverse, use_tantivy_sort, page_num, page_size = (
parse_search_params()
)
backend = get_backend()
filtered_qs = self.filter_queryset(self.get_queryset())
user = None if request.user.is_superuser else request.user
if "more_like_id" in request.query_params:
ordered_ids, page_hits, page_offset = run_more_like_this(
backend,
user,
filtered_qs,
)
if "text" in request.query_params:
search_mode = SearchMode.TEXT
query_str = request.query_params["text"]
elif "title_search" in request.query_params:
search_mode = SearchMode.TITLE
query_str = request.query_params["title_search"]
else:
search_mode = SearchMode.QUERY
query_str = request.query_params["query"]
results = backend.search(
query_str,
user=user,
page=1,
page_size=10000,
sort_field=None,
sort_reverse=False,
search_mode=search_mode,
)
else:
ordered_ids, page_hits, page_offset = run_text_search(
backend,
user,
filtered_qs,
)
# more_like_id — validate permission on the seed document first
try:
more_like_doc_id = int(request.query_params["more_like_id"])
more_like_doc = Document.objects.select_related("owner").get(
pk=more_like_doc_id,
)
except (TypeError, ValueError, Document.DoesNotExist):
raise PermissionDenied(_("Invalid more_like_id"))
if not has_perms_owner_aware(
request.user,
"view_document",
more_like_doc,
):
raise PermissionDenied(_("Insufficient permissions."))
results = backend.more_like_this(
more_like_doc_id,
user=user,
page=1,
page_size=10000,
)
rl = TantivyRelevanceList(ordered_ids, page_hits, page_offset)
hits_by_id = {h["id"]: h for h in results.hits}
# Determine sort order: no ordering param -> Tantivy relevance; otherwise -> ORM order
ordering_param = request.query_params.get("ordering", "").lstrip("-")
if not ordering_param:
# Preserve Tantivy relevance order; intersect with ORM-visible IDs
orm_ids = set(filtered_qs.values_list("pk", flat=True))
ordered_hits = [h for h in results.hits if h["id"] in orm_ids]
else:
# Use ORM ordering (already applied by DocumentsOrderingFilter)
hit_ids = set(hits_by_id.keys())
orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list(
"pk",
flat=True,
)
ordered_hits = [
hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id
]
rl = TantivyRelevanceList(ordered_hits)
page = self.paginate_queryset(rl)
if page is not None:
@@ -2250,18 +2157,15 @@ class UnifiedSearchViewSet(DocumentViewSet):
if get_boolean(
str(request.query_params.get("include_selection_data", "false")),
):
# NOTE: pk__in=ordered_ids generates a large SQL IN clause
# for big result sets. Acceptable today but may need a temp
# table or chunked approach if selection_data becomes slow
# at scale (tens of thousands of matching documents).
all_ids = [h["id"] for h in ordered_hits]
response.data["selection_data"] = (
self._get_selection_data_for_queryset(
filtered_qs.filter(pk__in=ordered_ids),
filtered_qs.filter(pk__in=all_ids),
)
)
return response
serializer = self.get_serializer(page_hits, many=True)
serializer = self.get_serializer(ordered_hits, many=True)
return Response(serializer.data)
except NotFound:
@@ -3167,17 +3071,20 @@ class GlobalSearchView(PassUserMixin):
docs = all_docs.filter(title__icontains=query)[:OBJECT_LIMIT]
else:
user = None if request.user.is_superuser else request.user
matching_ids = get_backend().search_ids(
fts_results = get_backend().search(
query,
user=user,
page=1,
page_size=1000,
sort_field=None,
sort_reverse=False,
search_mode=SearchMode.TEXT,
limit=OBJECT_LIMIT * 3,
)
docs_by_id = all_docs.in_bulk(matching_ids)
docs_by_id = all_docs.in_bulk([hit["id"] for hit in fts_results.hits])
docs = [
docs_by_id[doc_id]
for doc_id in matching_ids
if doc_id in docs_by_id
docs_by_id[hit["id"]]
for hit in fts_results.hits
if hit["id"] in docs_by_id
][:OBJECT_LIMIT]
saved_views = (
get_objects_for_user_owner_aware(


@@ -1,4 +1,5 @@
import logging
from io import BytesIO
import magic
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
@@ -11,13 +12,16 @@ from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.auth.password_validation import validate_password
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.core.files.uploadedfile import UploadedFile
from PIL import Image
from rest_framework import serializers
from rest_framework.authtoken.serializers import AuthTokenSerializer
from paperless.models import ApplicationConfiguration
from paperless.network import validate_outbound_http_url
from paperless.validators import reject_dangerous_svg
from paperless.validators import validate_raster_image
from paperless_mail.serialisers import ObfuscatedPasswordField
logger = logging.getLogger("paperless.settings")
@@ -233,9 +237,40 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer):
instance.app_logo.delete()
return super().update(instance, validated_data)
def _sanitize_raster_image(self, file: UploadedFile) -> UploadedFile:
data = BytesIO()
# The context manager closes the PIL image even if save() raises, and avoids
# referencing an unbound name when open() itself fails.
with Image.open(file) as image:
image.save(data, format=image.format)
data.seek(0)
return InMemoryUploadedFile(
file=data,
field_name=file.field_name,
name=file.name,
content_type=file.content_type,
size=data.getbuffer().nbytes,
charset=getattr(file, "charset", None),
)
def validate_app_logo(self, file: UploadedFile):
"""
Validates and sanitizes the uploaded app logo image. Model field already restricts to
jpg/png/gif/svg.
"""
if file and magic.from_buffer(file.read(2048), mime=True) == "image/svg+xml":
reject_dangerous_svg(file)
if file:
mime_type = magic.from_buffer(file.read(2048), mime=True)
if mime_type == "image/svg+xml":
reject_dangerous_svg(file)
else:
validate_raster_image(file)
if mime_type in {"image/jpeg", "image/png"}:
file = self._sanitize_raster_image(file)
return file
def validate_llm_endpoint(self, value: str | None) -> str | None:


@@ -1,6 +1,10 @@
from io import BytesIO
from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
from lxml import etree
from PIL import Image
ALLOWED_SVG_TAGS: set[str] = {
# Basic shapes
@@ -254,3 +258,30 @@ def reject_dangerous_svg(file: UploadedFile) -> None:
raise ValidationError(
f"URI scheme not allowed in {attr_name}: must be #anchor, relative path, or data:image/*",
)
def validate_raster_image(file: UploadedFile) -> None:
"""
Validates that the uploaded file is a valid raster image (JPEG, PNG, etc.)
and does not exceed maximum pixel limits.
Raises ValidationError if the image is invalid or exceeds the allowed size.
"""
file.seek(0)
image_data = file.read()
try:
with Image.open(BytesIO(image_data)) as image:
image.verify()
if (
settings.MAX_IMAGE_PIXELS is not None
and settings.MAX_IMAGE_PIXELS > 0
and image.width * image.height > settings.MAX_IMAGE_PIXELS
):
raise ValidationError(
"Uploaded logo exceeds the maximum allowed image size.",
)
if image.format is None: # pragma: no cover
raise ValidationError("Invalid logo image.")
except (OSError, Image.DecompressionBombError) as e:
raise ValidationError("Invalid logo image.") from e


@@ -89,7 +89,7 @@ class StandardPagination(PageNumberPagination):
query = self.page.paginator.object_list
if isinstance(query, TantivyRelevanceList):
return query.get_all_ids()
return [h["id"] for h in query._hits]
return self.page.paginator.object_list.values_list("pk", flat=True)
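Under the assumptions visible in this hunk (a `_hits` list of dicts carrying an `"id"` key), a minimal sketch of the list-like wrapper the paginator consumes; the class name comes from the diff, the implementation is hypothetical:

```python
class TantivyRelevanceList:
    """Sketch: wraps ranked search hits so DRF pagination can slice them."""

    def __init__(self, hits):
        self._hits = list(hits)  # e.g. [{"id": 3, "score": 1.2}, ...]

    def __len__(self):
        return len(self._hits)

    def __getitem__(self, item):
        return self._hits[item]

    def get_all_ids(self):
        # Every matching ID, regardless of the current page
        return [h["id"] for h in self._hits]

rl = TantivyRelevanceList([{"id": 3}, {"id": 1}, {"id": 2}])
print(rl.get_all_ids())  # [3, 1, 2]
```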
def get_paginated_response_schema(self, schema):

test_backend_profile.py (new file, 346 lines)

@@ -0,0 +1,346 @@
# ruff: noqa: T201
"""
cProfile-based search pipeline profiling with a 20k-document dataset.
Run with:
uv run pytest ../test_backend_profile.py \
-m profiling --override-ini="addopts=" -s -v
Each scenario prints:
- Wall time for the operation
- cProfile stats sorted by cumulative time (top 25 callers)
This is a developer tool, not a correctness test. Nothing here should
fail unless the code is broken.
"""
from __future__ import annotations
import random
import time
from typing import TYPE_CHECKING
import pytest
from profiling import profile_cpu
from documents.models import Document
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
if TYPE_CHECKING:
from pathlib import Path
# transaction=False (default): tests roll back, but the module-scoped fixture
# commits its data outside the test transaction so it remains visible throughout.
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
# ---------------------------------------------------------------------------
# Dataset constants
# ---------------------------------------------------------------------------
NUM_DOCS = 20_000
SEED = 42
# Terms and their approximate match rates across the corpus.
# "rechnung" -> ~70% of docs (~14 000)
# "mahnung" -> ~20% of docs (~4 000)
# "kontonummer" -> ~5% of docs (~1 000)
# "rarewort" -> ~1% of docs (~200)
COMMON_TERM = "rechnung"
MEDIUM_TERM = "mahnung"
RARE_TERM = "kontonummer"
VERY_RARE_TERM = "rarewort"
PAGE_SIZE = 25
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FILLER_WORDS = [
"dokument", # codespell:ignore
"seite",
"datum",
"betrag",
"nummer",
"konto",
"firma",
"vertrag",
"lieferant",
"bestellung",
"steuer",
"mwst",
"leistung",
"auftrag",
"zahlung",
]
def _build_content(rng: random.Random) -> str:
"""Return a short paragraph with terms embedded at the desired rates."""
words = rng.choices(_FILLER_WORDS, k=15)
if rng.random() < 0.70:
words.append(COMMON_TERM)
if rng.random() < 0.20:
words.append(MEDIUM_TERM)
if rng.random() < 0.05:
words.append(RARE_TERM)
if rng.random() < 0.01:
words.append(VERY_RARE_TERM)
rng.shuffle(words)
return " ".join(words)
def _time(fn, *, label: str, runs: int = 3):
"""Run *fn()* several times and report min/avg/max wall time (no cProfile)."""
times = []
result = None
for _ in range(runs):
t0 = time.perf_counter()
result = fn()
times.append(time.perf_counter() - t0)
mn, avg, mx = min(times), sum(times) / len(times), max(times)
print(
f" {label}: min={mn * 1000:.1f}ms avg={avg * 1000:.1f}ms max={mx * 1000:.1f}ms (n={runs})",
)
return result
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def large_backend(tmp_path_factory, module_db) -> TantivyBackend:
"""
Build a 20 000-document DB + on-disk Tantivy index, shared across all
profiling scenarios in this module. Teardown deletes all documents.
"""
index_path: Path = tmp_path_factory.mktemp("tantivy_profile")
# ---- 1. Bulk-create Document rows ----------------------------------------
rng = random.Random(SEED)
docs = [
Document(
title=f"Document {i:05d}",
content=_build_content(rng),
checksum=f"{i:064x}",
pk=i + 1,
)
for i in range(NUM_DOCS)
]
t0 = time.perf_counter()
Document.objects.bulk_create(docs, batch_size=1_000)
db_time = time.perf_counter() - t0
print(f"\n[setup] bulk_create {NUM_DOCS} docs: {db_time:.2f}s")
# ---- 2. Build Tantivy index -----------------------------------------------
backend = TantivyBackend(path=index_path)
backend.open()
t0 = time.perf_counter()
with backend.batch_update() as batch:
for doc in Document.objects.iterator(chunk_size=500):
batch.add_or_update(doc)
idx_time = time.perf_counter() - t0
print(f"[setup] index {NUM_DOCS} docs: {idx_time:.2f}s")
# ---- 3. Report corpus stats -----------------------------------------------
for term in (COMMON_TERM, MEDIUM_TERM, RARE_TERM, VERY_RARE_TERM):
count = len(backend.search_ids(term, user=None))
print(f"[setup] '{term}' -> {count} hits")
yield backend
# ---- Teardown ------------------------------------------------------------
backend.close()
reset_backend()
Document.objects.all().delete()
# ---------------------------------------------------------------------------
# Profiling tests — each scenario is a separate function so pytest can run
# them individually or all together with -m profiling.
# ---------------------------------------------------------------------------
class TestSearchIdsProfile:
"""Profile backend.search_ids() — pure Tantivy, no DB."""
def test_search_ids_large(self, large_backend: TantivyBackend):
"""~14 000 hits: how long does Tantivy take to collect all IDs?"""
profile_cpu(
lambda: large_backend.search_ids(COMMON_TERM, user=None),
label=f"search_ids('{COMMON_TERM}') [large result set ~14k]",
)
def test_search_ids_medium(self, large_backend: TantivyBackend):
"""~4 000 hits."""
profile_cpu(
lambda: large_backend.search_ids(MEDIUM_TERM, user=None),
label=f"search_ids('{MEDIUM_TERM}') [medium result set ~4k]",
)
def test_search_ids_rare(self, large_backend: TantivyBackend):
"""~1 000 hits."""
profile_cpu(
lambda: large_backend.search_ids(RARE_TERM, user=None),
label=f"search_ids('{RARE_TERM}') [rare result set ~1k]",
)
class TestIntersectAndOrderProfile:
"""
Profile the DB intersection step: filter(pk__in=search_ids).
This is the 'intersect_and_order' logic from views.py.
"""
def test_intersect_large(self, large_backend: TantivyBackend):
"""Intersect 14k Tantivy IDs with all 20k ORM-visible docs."""
all_ids = large_backend.search_ids(COMMON_TERM, user=None)
qs = Document.objects.all()
print(f"\n Tantivy returned {len(all_ids)} IDs")
profile_cpu(
lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
label=f"filter(pk__in={len(all_ids)} ids) [large, use_tantivy_sort=True path]",
)
# Also time it a few times to get stable numbers
print()
_time(
lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
label=f"filter(pk__in={len(all_ids)}) repeated",
)
def test_intersect_rare(self, large_backend: TantivyBackend):
"""Intersect ~1k Tantivy IDs — the happy path."""
all_ids = large_backend.search_ids(RARE_TERM, user=None)
qs = Document.objects.all()
print(f"\n Tantivy returned {len(all_ids)} IDs")
profile_cpu(
lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
label=f"filter(pk__in={len(all_ids)} ids) [rare, use_tantivy_sort=True path]",
)
class TestHighlightHitsProfile:
"""Profile backend.highlight_hits() — per-doc Tantivy lookups with BM25 scoring."""
def test_highlight_page1(self, large_backend: TantivyBackend):
"""25-doc highlight for page 1 (rank_start=1)."""
all_ids = large_backend.search_ids(COMMON_TERM, user=None)
page_ids = all_ids[:PAGE_SIZE]
profile_cpu(
lambda: large_backend.highlight_hits(
COMMON_TERM,
page_ids,
rank_start=1,
),
label=f"highlight_hits page 1 (ids {all_ids[0]}..{all_ids[PAGE_SIZE - 1]})",
)
def test_highlight_page_middle(self, large_backend: TantivyBackend):
"""25-doc highlight for a mid-corpus page (rank_start=page_offset+1)."""
all_ids = large_backend.search_ids(COMMON_TERM, user=None)
mid = len(all_ids) // 2
page_ids = all_ids[mid : mid + PAGE_SIZE]
page_offset = mid
profile_cpu(
lambda: large_backend.highlight_hits(
COMMON_TERM,
page_ids,
rank_start=page_offset + 1,
),
label=f"highlight_hits page ~{mid // PAGE_SIZE} (offset {page_offset})",
)
def test_highlight_repeated(self, large_backend: TantivyBackend):
"""Multiple runs of page-1 highlight to see variance."""
all_ids = large_backend.search_ids(COMMON_TERM, user=None)
page_ids = all_ids[:PAGE_SIZE]
print()
_time(
lambda: large_backend.highlight_hits(COMMON_TERM, page_ids, rank_start=1),
label="highlight_hits page 1",
runs=5,
)
class TestFullPipelineProfile:
"""
Profile the combined pipeline as it runs in views.py:
search_ids -> filter(pk__in) -> highlight_hits
"""
def _run_pipeline(
self,
backend: TantivyBackend,
term: str,
page: int = 1,
):
all_ids = backend.search_ids(term, user=None)
qs = Document.objects.all()
visible_ids = set(qs.filter(pk__in=all_ids).values_list("pk", flat=True))
ordered_ids = [i for i in all_ids if i in visible_ids]
page_offset = (page - 1) * PAGE_SIZE
page_ids = ordered_ids[page_offset : page_offset + PAGE_SIZE]
hits = backend.highlight_hits(
term,
page_ids,
rank_start=page_offset + 1,
)
return ordered_ids, hits
def test_pipeline_large_page1(self, large_backend: TantivyBackend):
"""Full pipeline: large result set, page 1."""
ordered_ids, hits = profile_cpu(
lambda: self._run_pipeline(large_backend, COMMON_TERM, page=1),
label=f"full pipeline '{COMMON_TERM}' page 1",
)[0]
print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page")
def test_pipeline_large_page5(self, large_backend: TantivyBackend):
"""Full pipeline: large result set, page 5."""
ordered_ids, hits = profile_cpu(
lambda: self._run_pipeline(large_backend, COMMON_TERM, page=5),
label=f"full pipeline '{COMMON_TERM}' page 5",
)[0]
print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page")
def test_pipeline_rare(self, large_backend: TantivyBackend):
"""Full pipeline: rare term, page 1 (fast path)."""
ordered_ids, hits = profile_cpu(
lambda: self._run_pipeline(large_backend, RARE_TERM, page=1),
label=f"full pipeline '{RARE_TERM}' page 1",
)[0]
print(f" -> {len(ordered_ids)} total results, {len(hits)} hits on page")
def test_pipeline_repeated(self, large_backend: TantivyBackend):
"""Repeated runs to get stable timing (no cProfile overhead)."""
print()
for term, label in [
(COMMON_TERM, f"'{COMMON_TERM}' (large)"),
(MEDIUM_TERM, f"'{MEDIUM_TERM}' (medium)"),
(RARE_TERM, f"'{RARE_TERM}' (rare)"),
]:
_time(
lambda t=term: self._run_pipeline(large_backend, t, page=1),
label=f"full pipeline {label} page 1",
runs=3,
)
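The `profile_cpu` helper imported from `profiling` at the top of this file is not part of the diff shown. Judging by the module docstring (cProfile, cumulative sort, top-25 callers) it plausibly looks like the following sketch; this is an assumption, not the actual shared module:

```python
import cProfile
import io
import pstats
import time

def profile_cpu(fn, *, label: str, top: int = 25):
    """Run fn under cProfile, print wall time and top cumulative entries."""
    profiler = cProfile.Profile()
    t0 = time.perf_counter()
    result = profiler.runcall(fn)  # runcall returns fn's own return value
    wall = time.perf_counter() - t0
    buf = io.StringIO()
    pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(top)
    print(f"\n  {label}: {wall * 1000:.1f}ms\n{buf.getvalue()}")
    return result, profiler
```

Returning a `(result, profiler)` tuple is consistent with the `profile_cpu(...)[0]` unpacking used in the pipeline tests above.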

test_classifier_profile.py (new file, 605 lines)

@@ -0,0 +1,605 @@
# ruff: noqa: T201
"""
cProfile + tracemalloc classifier profiling test.
Run with:
uv run pytest ../test_classifier_profile.py \
-m profiling --override-ini="addopts=" -s -v
Corpus: 5 000 documents, 40 correspondents (25 AUTO), 25 doc types (15 AUTO),
50 tags (30 AUTO), 20 storage paths (12 AUTO).
Document content is generated with Faker for realistic base text, with a
per-label fingerprint injected so the MLP has a real learning signal.
Scenarios:
- train() full corpus — memory and CPU profiles
- second train() no-op path — shows cost of the skip check
- save()/load() round-trip — model file size and memory cost
- _update_data_vectorizer_hash() isolated hash overhead
- predict_*() four independent calls per document — the 4x redundant
vectorization path used by the signal handlers
- _vectorize() cache-miss vs cache-hit breakdown
Memory: tracemalloc (delta + peak + top-20 allocation sites).
CPU: cProfile sorted by cumulative time (top 30).
"""
from __future__ import annotations
import random
import time
from typing import TYPE_CHECKING
import pytest
from django.test import override_settings
from faker import Faker
from profiling import measure_memory
from profiling import profile_cpu
from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
if TYPE_CHECKING:
from pathlib import Path
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------
NUM_DOCS = 5_000
NUM_CORRESPONDENTS = 40 # first 25 are MATCH_AUTO
NUM_DOC_TYPES = 25 # first 15 are MATCH_AUTO
NUM_TAGS = 50 # first 30 are MATCH_AUTO
NUM_STORAGE_PATHS = 20 # first 12 are MATCH_AUTO
NUM_AUTO_CORRESPONDENTS = 25
NUM_AUTO_DOC_TYPES = 15
NUM_AUTO_TAGS = 30
NUM_AUTO_STORAGE_PATHS = 12
SEED = 42
# ---------------------------------------------------------------------------
# Content generation
# ---------------------------------------------------------------------------
def _make_label_fingerprint(
fake: Faker,
label_seed: int,
n_words: int = 6,
) -> list[str]:
"""
Generate a small set of unique-looking words to use as the learning
fingerprint for a label. Each label gets its own seeded Faker so the
fingerprints are distinct and reproducible.
"""
per_label_fake = Faker()
per_label_fake.seed_instance(label_seed)
# word() yields varied, pronounceable tokens; the seeded instance keeps them reproducible
words: list[str] = []
while len(words) < n_words:
w = per_label_fake.word().lower()
if w not in words:
words.append(w)
return words
def _build_fingerprints(
num_correspondents: int,
num_doc_types: int,
num_tags: int,
num_paths: int,
) -> tuple[list[list[str]], list[list[str]], list[list[str]], list[list[str]]]:
"""Pre-generate per-label fingerprints. Expensive once, free to reuse."""
fake = Faker()
# Use deterministic seeds offset by type so fingerprints don't collide
corr_fps = [
_make_label_fingerprint(fake, 1_000 + i) for i in range(num_correspondents)
]
dtype_fps = [_make_label_fingerprint(fake, 2_000 + i) for i in range(num_doc_types)]
tag_fps = [_make_label_fingerprint(fake, 3_000 + i) for i in range(num_tags)]
path_fps = [_make_label_fingerprint(fake, 4_000 + i) for i in range(num_paths)]
return corr_fps, dtype_fps, tag_fps, path_fps
def _build_content(
fake: Faker,
corr_fp: list[str] | None,
dtype_fp: list[str] | None,
tag_fps: list[list[str]],
path_fp: list[str] | None,
) -> str:
"""
Combine a Faker paragraph (realistic base text) with per-label
fingerprint words so the classifier has a genuine learning signal.
"""
# 3-sentence paragraph provides realistic vocabulary
base = fake.paragraph(nb_sentences=3)
extras: list[str] = []
if corr_fp:
extras.extend(corr_fp)
if dtype_fp:
extras.extend(dtype_fp)
for fp in tag_fps:
extras.extend(fp)
if path_fp:
extras.extend(path_fp)
if extras:
return base + " " + " ".join(extras)
return base
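The helper above is the core of the learnable corpus: Faker prose supplies realistic vocabulary, and the appended fingerprint words give each label a distinct bag-of-words signal. A dependency-free sketch of that concatenation pattern (the strings here are illustrative placeholders, not corpus values):

```python
def demo_build_content(base: str, fingerprints: list[list[str]]) -> str:
    # Append every assigned label's fingerprint words after the base prose,
    # mirroring the shape of _build_content() above.
    extras: list[str] = []
    for fp in fingerprints:
        extras.extend(fp)
    return f"{base} {' '.join(extras)}" if extras else base
```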
# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def classifier_corpus(tmp_path_factory, module_db):
"""
Build the full 5 000-document corpus once for all profiling tests.
Label objects are created individually (small number), documents are
bulk-inserted, and tag M2M rows go through the through-table.
Yields a dict with the model path and a sample content string for
prediction tests. All rows are deleted on teardown.
"""
model_path: Path = tmp_path_factory.mktemp("cls_profile") / "model.pickle"
with override_settings(MODEL_FILE=model_path):
fake = Faker()
Faker.seed(SEED)
rng = random.Random(SEED)
# Pre-generate fingerprints for all labels
print("\n[setup] Generating label fingerprints...")
corr_fps, dtype_fps, tag_fps, path_fps = _build_fingerprints(
NUM_CORRESPONDENTS,
NUM_DOC_TYPES,
NUM_TAGS,
NUM_STORAGE_PATHS,
)
# -----------------------------------------------------------------
# 1. Create label objects
# -----------------------------------------------------------------
print(f"[setup] Creating {NUM_CORRESPONDENTS} correspondents...")
correspondents: list[Correspondent] = []
for i in range(NUM_CORRESPONDENTS):
algo = (
MatchingModel.MATCH_AUTO
if i < NUM_AUTO_CORRESPONDENTS
else MatchingModel.MATCH_NONE
)
correspondents.append(
Correspondent.objects.create(
name=fake.company(),
matching_algorithm=algo,
),
)
print(f"[setup] Creating {NUM_DOC_TYPES} document types...")
doc_types: list[DocumentType] = []
for i in range(NUM_DOC_TYPES):
algo = (
MatchingModel.MATCH_AUTO
if i < NUM_AUTO_DOC_TYPES
else MatchingModel.MATCH_NONE
)
doc_types.append(
DocumentType.objects.create(
name=fake.bs()[:64],
matching_algorithm=algo,
),
)
print(f"[setup] Creating {NUM_TAGS} tags...")
tags: list[Tag] = []
for i in range(NUM_TAGS):
algo = (
MatchingModel.MATCH_AUTO
if i < NUM_AUTO_TAGS
else MatchingModel.MATCH_NONE
)
tags.append(
Tag.objects.create(
name=f"{fake.word()} {i}",
matching_algorithm=algo,
is_inbox_tag=False,
),
)
print(f"[setup] Creating {NUM_STORAGE_PATHS} storage paths...")
storage_paths: list[StoragePath] = []
for i in range(NUM_STORAGE_PATHS):
algo = (
MatchingModel.MATCH_AUTO
if i < NUM_AUTO_STORAGE_PATHS
else MatchingModel.MATCH_NONE
)
storage_paths.append(
StoragePath.objects.create(
name=fake.word(),
path=f"{fake.word()}/{fake.word()}/{{title}}",
matching_algorithm=algo,
),
)
# -----------------------------------------------------------------
# 2. Build document rows and M2M assignments
# -----------------------------------------------------------------
print(f"[setup] Building {NUM_DOCS} document rows...")
doc_rows: list[Document] = []
doc_tag_map: list[tuple[int, int]] = [] # (doc_position, tag_index)
for i in range(NUM_DOCS):
corr_idx = (
rng.randrange(NUM_CORRESPONDENTS) if rng.random() < 0.80 else None
)
dt_idx = rng.randrange(NUM_DOC_TYPES) if rng.random() < 0.80 else None
sp_idx = rng.randrange(NUM_STORAGE_PATHS) if rng.random() < 0.70 else None
# 1-4 tags; most documents get at least one
n_tags = rng.randint(1, 4) if rng.random() < 0.85 else 0
assigned_tag_indices = rng.sample(range(NUM_TAGS), min(n_tags, NUM_TAGS))
content = _build_content(
fake,
corr_fp=corr_fps[corr_idx] if corr_idx is not None else None,
dtype_fp=dtype_fps[dt_idx] if dt_idx is not None else None,
tag_fps=[tag_fps[ti] for ti in assigned_tag_indices],
path_fp=path_fps[sp_idx] if sp_idx is not None else None,
)
doc_rows.append(
Document(
title=fake.sentence(nb_words=5),
content=content,
checksum=f"{i:064x}",
correspondent=correspondents[corr_idx]
if corr_idx is not None
else None,
document_type=doc_types[dt_idx] if dt_idx is not None else None,
storage_path=storage_paths[sp_idx] if sp_idx is not None else None,
),
)
for ti in assigned_tag_indices:
doc_tag_map.append((i, ti))
t0 = time.perf_counter()
Document.objects.bulk_create(doc_rows, batch_size=500)
print(
f"[setup] bulk_create {NUM_DOCS} documents: {time.perf_counter() - t0:.2f}s",
)
# -----------------------------------------------------------------
# 3. Bulk-create M2M through-table rows
# -----------------------------------------------------------------
created_docs = list(Document.objects.order_by("pk"))
through_rows = [
Document.tags.through(
document_id=created_docs[pos].pk,
tag_id=tags[ti].pk,
)
for pos, ti in doc_tag_map
if pos < len(created_docs)
]
t0 = time.perf_counter()
Document.tags.through.objects.bulk_create(
through_rows,
batch_size=1_000,
ignore_conflicts=True,
)
print(
f"[setup] bulk_create {len(through_rows)} tag M2M rows: "
f"{time.perf_counter() - t0:.2f}s",
)
# Sample content for prediction tests
sample_content = _build_content(
fake,
corr_fp=corr_fps[0],
dtype_fp=dtype_fps[0],
tag_fps=[tag_fps[0], tag_fps[1], tag_fps[5]],
path_fp=path_fps[0],
)
yield {
"model_path": model_path,
"sample_content": sample_content,
}
# Teardown
print("\n[teardown] Removing corpus...")
Document.objects.all().delete()
Correspondent.objects.all().delete()
DocumentType.objects.all().delete()
Tag.objects.all().delete()
StoragePath.objects.all().delete()
# ---------------------------------------------------------------------------
# Training profiles
# ---------------------------------------------------------------------------
class TestClassifierTrainingProfile:
"""Profile DocumentClassifier.train() on the full corpus."""
def test_train_memory(self, classifier_corpus, tmp_path):
"""
Peak memory allocated during train().
tracemalloc reports the delta and top allocation sites.
"""
model_path = tmp_path / "model.pickle"
with override_settings(MODEL_FILE=model_path):
classifier = DocumentClassifier()
result, _, _ = measure_memory(
classifier.train,
label=(
f"train() [{NUM_DOCS} docs | "
f"{NUM_CORRESPONDENTS} correspondents ({NUM_AUTO_CORRESPONDENTS} AUTO) | "
f"{NUM_DOC_TYPES} doc types ({NUM_AUTO_DOC_TYPES} AUTO) | "
f"{NUM_TAGS} tags ({NUM_AUTO_TAGS} AUTO) | "
f"{NUM_STORAGE_PATHS} paths ({NUM_AUTO_STORAGE_PATHS} AUTO)]"
),
)
assert result is True, "train() must return True on first run"
print("\n Classifiers trained:")
print(
f" tags_classifier: {classifier.tags_classifier is not None}",
)
print(
f" correspondent_classifier: {classifier.correspondent_classifier is not None}",
)
print(
f" document_type_classifier: {classifier.document_type_classifier is not None}",
)
print(
f" storage_path_classifier: {classifier.storage_path_classifier is not None}",
)
if classifier.data_vectorizer is not None:
vocab_size = len(classifier.data_vectorizer.vocabulary_)
print(f" vocabulary size: {vocab_size} terms")
def test_train_cpu(self, classifier_corpus, tmp_path):
"""
CPU profile of train() — shows time spent in DB queries,
CountVectorizer.fit_transform(), and four MLPClassifier.fit() calls.
"""
model_path = tmp_path / "model_cpu.pickle"
with override_settings(MODEL_FILE=model_path):
classifier = DocumentClassifier()
profile_cpu(
classifier.train,
label=f"train() [{NUM_DOCS} docs]",
top=30,
)
def test_train_second_call_noop(self, classifier_corpus, tmp_path):
"""
No-op path: second train() on unchanged data should return False.
Still queries the DB to build the hash — shown here as the remaining cost.
"""
model_path = tmp_path / "model_noop.pickle"
with override_settings(MODEL_FILE=model_path):
classifier = DocumentClassifier()
t0 = time.perf_counter()
classifier.train()
first_ms = (time.perf_counter() - t0) * 1000
result, second_elapsed = profile_cpu(
classifier.train,
label="train() second call (no-op — same data unchanged)",
top=20,
)
assert result is False, "second train() should skip and return False"
print(f"\n First train: {first_ms:.1f} ms (full fit)")
print(f" Second train: {second_elapsed * 1000:.1f} ms (skip)")
print(f" Speedup: {first_ms / (second_elapsed * 1000):.1f}x")
def test_vectorizer_hash_cost(self, classifier_corpus, tmp_path):
"""
Isolate _update_data_vectorizer_hash() — pickles the entire
CountVectorizer just to SHA256 it. Called at both save and load.
"""
import pickle
model_path = tmp_path / "model_hash.pickle"
with override_settings(MODEL_FILE=model_path):
classifier = DocumentClassifier()
classifier.train()
profile_cpu(
classifier._update_data_vectorizer_hash,
label="_update_data_vectorizer_hash() [pickle.dumps vectorizer + sha256]",
top=10,
)
pickled_size = len(pickle.dumps(classifier.data_vectorizer))
vocab_size = len(classifier.data_vectorizer.vocabulary_)
print(f"\n Vocabulary size: {vocab_size} terms")
print(f" Pickled vectorizer: {pickled_size / 1024:.1f} KiB")
def test_save_load_roundtrip(self, classifier_corpus, tmp_path):
"""
Profile save() and load() — model file size directly reflects how
much memory the classifier occupies on disk (and roughly in RAM).
"""
model_path = tmp_path / "model_saveload.pickle"
with override_settings(MODEL_FILE=model_path):
classifier = DocumentClassifier()
classifier.train()
_, save_peak, _ = measure_memory(
classifier.save,
label="save() [pickle.dumps + HMAC + atomic rename]",
)
file_size_kib = model_path.stat().st_size / 1024
print(f"\n Model file size: {file_size_kib:.1f} KiB")
classifier2 = DocumentClassifier()
_, load_peak, _ = measure_memory(
classifier2.load,
label="load() [read file + verify HMAC + pickle.loads]",
)
print("\n Summary:")
print(f" Model file size: {file_size_kib:.1f} KiB")
print(f" Save peak memory: {save_peak:.1f} KiB")
print(f" Load peak memory: {load_peak:.1f} KiB")
# ---------------------------------------------------------------------------
# Prediction profiles
# ---------------------------------------------------------------------------
class TestClassifierPredictionProfile:
"""
Profile the four predict_*() methods — specifically the redundant
per-call vectorization overhead from the signal handler pattern.
"""
@pytest.fixture(autouse=True)
def trained_classifier(self, classifier_corpus, tmp_path):
model_path = tmp_path / "model_pred.pickle"
self._ctx = override_settings(MODEL_FILE=model_path)
self._ctx.enable()
self.classifier = DocumentClassifier()
self.classifier.train()
self.content = classifier_corpus["sample_content"]
yield
self._ctx.disable()
def test_predict_all_four_separately_cpu(self):
"""
Profile all four predict_*() calls in the order the signal handlers
fire them. Call 1 is a cache miss; calls 2-4 hit the locmem cache
but still pay sha256 + pickle.loads each time.
"""
from django.core.cache import caches
caches["read-cache"].clear()
content = self.content
print(f"\n Content length: {len(content)} chars")
calls = [
("predict_correspondent", self.classifier.predict_correspondent),
("predict_document_type", self.classifier.predict_document_type),
("predict_tags", self.classifier.predict_tags),
("predict_storage_path", self.classifier.predict_storage_path),
]
timings: list[tuple[str, float]] = []
for name, fn in calls:
_, elapsed = profile_cpu(
lambda f=fn: f(content),
label=f"{name}() [call {len(timings) + 1}/4]",
top=15,
)
timings.append((name, elapsed * 1000))
print("\n Per-call timings (sequential, locmem cache):")
for name, ms in timings:
print(f" {name:<32s} {ms:8.3f} ms")
print(f" {'TOTAL':<32s} {sum(t for _, t in timings):8.3f} ms")
def test_predict_all_four_memory(self):
"""
Memory allocated for the full four-prediction sequence, both cold
and warm, to show pickle serialization allocation per call.
"""
from django.core.cache import caches
content = self.content
calls = [
self.classifier.predict_correspondent,
self.classifier.predict_document_type,
self.classifier.predict_tags,
self.classifier.predict_storage_path,
]
caches["read-cache"].clear()
measure_memory(
lambda: [fn(content) for fn in calls],
label="all four predict_*() [cache COLD — first call misses]",
)
measure_memory(
lambda: [fn(content) for fn in calls],
label="all four predict_*() [cache WARM — all calls hit]",
)
def test_vectorize_cache_miss_vs_hit(self):
"""
Isolate the cost of a cache miss (sha256 + transform + pickle.dumps)
vs a cache hit (sha256 + pickle.loads).
"""
from django.core.cache import caches
read_cache = caches["read-cache"]
content = self.content
read_cache.clear()
_, miss_elapsed = profile_cpu(
lambda: self.classifier._vectorize(content),
label="_vectorize() [MISS: sha256 + transform + pickle.dumps]",
top=15,
)
_, hit_elapsed = profile_cpu(
lambda: self.classifier._vectorize(content),
label="_vectorize() [HIT: sha256 + pickle.loads]",
top=15,
)
print(f"\n Cache miss: {miss_elapsed * 1000:.3f} ms")
print(f" Cache hit: {hit_elapsed * 1000:.3f} ms")
print(f" Hit is {miss_elapsed / hit_elapsed:.1f}x faster than miss")
def test_content_hash_overhead(self):
"""
        Micro-benchmark the sha256 of the content string — paid on every
        _vectorize() call regardless of cache state, i.e. four times per
        document (once per predict_*() call).
"""
import hashlib
content = self.content
encoded = content.encode()
runs = 5_000
t0 = time.perf_counter()
for _ in range(runs):
hashlib.sha256(encoded).hexdigest()
us_per_call = (time.perf_counter() - t0) / runs * 1_000_000
print(f"\n Content: {len(content)} chars / {len(encoded)} bytes")
print(f" sha256 cost per call: {us_per_call:.2f} us (avg over {runs} runs)")
print(f" x4 calls per document: {us_per_call * 4:.2f} us total overhead")

test_doclist_profile.py Normal file
@@ -0,0 +1,293 @@
"""
Document list API profiling — no search, pure ORM path.
Run with:
uv run pytest ../test_doclist_profile.py \
-m profiling --override-ini="addopts=" -s -v
Corpus: 5 000 documents, 30 correspondents, 20 doc types, 80 tags,
~500 notes (10 %), 10 custom fields with instances on ~50 % of docs.
Scenarios
---------
TestDocListProfile
- test_list_default_ordering GET /api/documents/ created desc, page 1, page_size=25
- test_list_title_ordering same with ordering=title
- test_list_page_size_comparison page_size=10 / 25 / 100 in sequence
- test_list_detail_fields GET /api/documents/{id}/ — single document serializer cost
- test_list_cpu_profile cProfile of one list request
TestSelectionDataProfile
- test_selection_data_unfiltered _get_selection_data_for_queryset(all docs) in isolation
- test_selection_data_via_api GET /api/documents/?include_selection_data=true
- test_selection_data_filtered filtered vs unfiltered COUNT query comparison
"""
from __future__ import annotations
import datetime
import random
import time
import pytest
from django.contrib.auth.models import User
from faker import Faker
from profiling import profile_block
from profiling import profile_cpu
from rest_framework.test import APIClient
from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import Note
from documents.models import Tag
from documents.views import DocumentViewSet
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------
NUM_DOCS = 5_000
NUM_CORRESPONDENTS = 30
NUM_DOC_TYPES = 20
NUM_TAGS = 80
NOTE_FRACTION = 0.10
CUSTOM_FIELD_COUNT = 10
CUSTOM_FIELD_FRACTION = 0.50
PAGE_SIZE = 25
SEED = 42
# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def doclist_corpus(module_db):
"""
Build a 5 000-document corpus with tags, notes, custom fields, correspondents,
and doc types. All objects are deleted on teardown.
"""
fake = Faker()
Faker.seed(SEED)
rng = random.Random(SEED)
print(f"\n[setup] Creating {NUM_CORRESPONDENTS} correspondents...") # noqa: T201
correspondents = [
Correspondent.objects.create(name=f"dlcorp-{i}-{fake.company()}"[:128])
for i in range(NUM_CORRESPONDENTS)
]
print(f"[setup] Creating {NUM_DOC_TYPES} doc types...") # noqa: T201
doc_types = [
DocumentType.objects.create(name=f"dltype-{i}-{fake.word()}"[:128])
for i in range(NUM_DOC_TYPES)
]
print(f"[setup] Creating {NUM_TAGS} tags...") # noqa: T201
tags = [
Tag.objects.create(name=f"dltag-{i}-{fake.word()}"[:100])
for i in range(NUM_TAGS)
]
print(f"[setup] Creating {CUSTOM_FIELD_COUNT} custom fields...") # noqa: T201
custom_fields = [
CustomField.objects.create(
name=f"Field {i}",
data_type=CustomField.FieldDataType.STRING,
)
for i in range(CUSTOM_FIELD_COUNT)
]
note_user = User.objects.create_user(username="doclistnoteuser", password="x")
owner = User.objects.create_superuser(username="doclistowner", password="admin")
print(f"[setup] Building {NUM_DOCS} document rows...") # noqa: T201
base_date = datetime.date(2018, 1, 1)
raw_docs = []
for i in range(NUM_DOCS):
day_offset = rng.randint(0, 6 * 365)
raw_docs.append(
Document(
title=fake.sentence(nb_words=rng.randint(3, 8)).rstrip("."),
content="\n\n".join(
fake.paragraph(nb_sentences=rng.randint(2, 5))
for _ in range(rng.randint(1, 3))
),
checksum=f"DL{i:07d}",
correspondent=rng.choice(correspondents + [None] * 5),
document_type=rng.choice(doc_types + [None] * 4),
created=base_date + datetime.timedelta(days=day_offset),
owner=owner if rng.random() < 0.8 else None,
),
)
t0 = time.perf_counter()
documents = Document.objects.bulk_create(raw_docs)
print(f"[setup] bulk_create {NUM_DOCS} docs: {time.perf_counter() - t0:.2f}s") # noqa: T201
t0 = time.perf_counter()
for doc in documents:
k = rng.randint(0, 5)
if k:
doc.tags.add(*rng.sample(tags, k))
print(f"[setup] tag M2M assignments: {time.perf_counter() - t0:.2f}s") # noqa: T201
note_docs = rng.sample(documents, int(NUM_DOCS * NOTE_FRACTION))
Note.objects.bulk_create(
[
Note(
document=doc,
note=fake.sentence(nb_words=rng.randint(4, 15)),
user=note_user,
)
for doc in note_docs
],
)
cf_docs = rng.sample(documents, int(NUM_DOCS * CUSTOM_FIELD_FRACTION))
CustomFieldInstance.objects.bulk_create(
[
CustomFieldInstance(
document=doc,
field=rng.choice(custom_fields),
value_text=fake.word(),
)
for doc in cf_docs
],
)
first_doc_pk = documents[0].pk
yield {"owner": owner, "first_doc_pk": first_doc_pk, "tags": tags}
print("\n[teardown] Removing doclist corpus...") # noqa: T201
Document.objects.all().delete()
Correspondent.objects.all().delete()
DocumentType.objects.all().delete()
Tag.objects.all().delete()
CustomField.objects.all().delete()
User.objects.filter(username__in=["doclistnoteuser", "doclistowner"]).delete()
# ---------------------------------------------------------------------------
# TestDocListProfile
# ---------------------------------------------------------------------------
class TestDocListProfile:
"""Profile GET /api/documents/ — pure ORM path, no Tantivy."""
@pytest.fixture(autouse=True)
def _client(self, doclist_corpus):
owner = doclist_corpus["owner"]
self.client = APIClient()
self.client.force_authenticate(user=owner)
self.first_doc_pk = doclist_corpus["first_doc_pk"]
def test_list_default_ordering(self):
"""GET /api/documents/ default ordering (-created), page 1, page_size=25."""
with profile_block(
f"GET /api/documents/ default ordering [page_size={PAGE_SIZE}]",
):
response = self.client.get(
f"/api/documents/?page=1&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
def test_list_title_ordering(self):
"""GET /api/documents/ ordered by title — tests ORM sort path."""
with profile_block(
f"GET /api/documents/?ordering=title [page_size={PAGE_SIZE}]",
):
response = self.client.get(
f"/api/documents/?ordering=title&page=1&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
def test_list_page_size_comparison(self):
"""Compare serializer cost at page_size=10, 25, 100."""
for page_size in [10, 25, 100]:
with profile_block(f"GET /api/documents/ [page_size={page_size}]"):
response = self.client.get(
f"/api/documents/?page=1&page_size={page_size}",
)
assert response.status_code == 200
def test_list_detail_fields(self):
"""GET /api/documents/{id}/ — per-doc serializer cost with all relations."""
pk = self.first_doc_pk
with profile_block(f"GET /api/documents/{pk}/ — single doc serializer"):
response = self.client.get(f"/api/documents/{pk}/")
assert response.status_code == 200
def test_list_cpu_profile(self):
"""cProfile of one list request — surfaces hot frames in serializer."""
profile_cpu(
lambda: self.client.get(
f"/api/documents/?page=1&page_size={PAGE_SIZE}",
),
label=f"GET /api/documents/ cProfile [page_size={PAGE_SIZE}]",
top=30,
)
# ---------------------------------------------------------------------------
# TestSelectionDataProfile
# ---------------------------------------------------------------------------
class TestSelectionDataProfile:
"""Profile _get_selection_data_for_queryset — the 5+ COUNT queries per request."""
@pytest.fixture(autouse=True)
def _setup(self, doclist_corpus):
owner = doclist_corpus["owner"]
self.client = APIClient()
self.client.force_authenticate(user=owner)
self.tags = doclist_corpus["tags"]
def test_selection_data_unfiltered(self):
"""Call _get_selection_data_for_queryset(all docs) directly — COUNT queries in isolation."""
viewset = DocumentViewSet()
qs = Document.objects.all()
with profile_block("_get_selection_data_for_queryset(all docs) — direct call"):
viewset._get_selection_data_for_queryset(qs)
def test_selection_data_via_api(self):
"""Full API round-trip with include_selection_data=true."""
with profile_block(
f"GET /api/documents/?include_selection_data=true [page_size={PAGE_SIZE}]",
):
response = self.client.get(
f"/api/documents/?page=1&page_size={PAGE_SIZE}&include_selection_data=true",
)
assert response.status_code == 200
assert "selection_data" in response.data
def test_selection_data_filtered(self):
"""selection_data on a tag-filtered queryset — filtered COUNT vs unfiltered."""
tag = self.tags[0]
viewset = DocumentViewSet()
filtered_qs = Document.objects.filter(tags=tag)
unfiltered_qs = Document.objects.all()
print(f"\n Tag '{tag.name}' matches {filtered_qs.count()} docs") # noqa: T201
with profile_block("_get_selection_data_for_queryset(unfiltered)"):
viewset._get_selection_data_for_queryset(unfiltered_qs)
with profile_block("_get_selection_data_for_queryset(filtered by tag)"):
viewset._get_selection_data_for_queryset(filtered_qs)

test_matching_profile.py Normal file
@@ -0,0 +1,284 @@
"""
Matching pipeline profiling.
Run with:
uv run pytest ../test_matching_profile.py \
-m profiling --override-ini="addopts=" -s -v
Corpus: 1 document + 50 correspondents, 100 tags, 25 doc types, 20 storage
paths. Labels are spread across all seven matching algorithms
(NONE, ANY, ALL, LITERAL, REGEX, FUZZY, AUTO).
Classifier is passed as None -- MATCH_AUTO models skip prediction gracefully,
which is correct for isolating the ORM query and Python-side evaluation cost.
Scenarios
---------
TestMatchingPipelineProfile
- test_match_correspondents 50 correspondents, algorithm mix
- test_match_tags 100 tags
- test_match_document_types 25 doc types
- test_match_storage_paths 20 storage paths
- test_full_match_sequence all four in order (cumulative consumption cost)
- test_algorithm_breakdown each MATCH_* algorithm in isolation
"""
from __future__ import annotations
import random
import pytest
from faker import Faker
from profiling import profile_block
from documents.matching import match_correspondents
from documents.matching import match_document_types
from documents.matching import match_storage_paths
from documents.matching import match_tags
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
NUM_CORRESPONDENTS = 50
NUM_TAGS = 100
NUM_DOC_TYPES = 25
NUM_STORAGE_PATHS = 20
SEED = 42
# Algorithm distribution across labels (cycles through in order)
_ALGORITHMS = [
MatchingModel.MATCH_NONE,
MatchingModel.MATCH_ANY,
MatchingModel.MATCH_ALL,
MatchingModel.MATCH_LITERAL,
MatchingModel.MATCH_REGEX,
MatchingModel.MATCH_FUZZY,
MatchingModel.MATCH_AUTO,
]
def _algo(i: int) -> int:
return _ALGORITHMS[i % len(_ALGORITHMS)]
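Cycling `i % len(_ALGORITHMS)` spreads labels almost evenly: the 50 correspondents split 8/7/7/7/7/7/7 across the seven algorithms (index 0 wraps one extra time). A quick self-contained check of that arithmetic:

```python
from collections import Counter

# One bucket per algorithm index; 50 labels round-robin across 7 buckets.
distribution = Counter(i % 7 for i in range(50))
assert distribution[0] == 8
assert all(distribution[a] == 7 for a in range(1, 7))
```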
# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def matching_corpus(module_db):
"""
1 document with realistic content + dense matching model sets.
Classifier=None so MATCH_AUTO models are simply skipped.
"""
fake = Faker()
Faker.seed(SEED)
random.seed(SEED)
# ---- matching models ---------------------------------------------------
print(f"\n[setup] Creating {NUM_CORRESPONDENTS} correspondents...") # noqa: T201
correspondents = []
for i in range(NUM_CORRESPONDENTS):
algo = _algo(i)
match_text = (
fake.word()
if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
else ""
)
if algo == MatchingModel.MATCH_REGEX:
match_text = r"\b" + fake.word() + r"\b"
correspondents.append(
Correspondent.objects.create(
name=f"mcorp-{i}-{fake.company()}"[:128],
matching_algorithm=algo,
match=match_text,
),
)
print(f"[setup] Creating {NUM_TAGS} tags...") # noqa: T201
tags = []
for i in range(NUM_TAGS):
algo = _algo(i)
match_text = (
fake.word()
if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
else ""
)
if algo == MatchingModel.MATCH_REGEX:
match_text = r"\b" + fake.word() + r"\b"
tags.append(
Tag.objects.create(
name=f"mtag-{i}-{fake.word()}"[:100],
matching_algorithm=algo,
match=match_text,
),
)
print(f"[setup] Creating {NUM_DOC_TYPES} doc types...") # noqa: T201
doc_types = []
for i in range(NUM_DOC_TYPES):
algo = _algo(i)
match_text = (
fake.word()
if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
else ""
)
if algo == MatchingModel.MATCH_REGEX:
match_text = r"\b" + fake.word() + r"\b"
doc_types.append(
DocumentType.objects.create(
name=f"mtype-{i}-{fake.word()}"[:128],
matching_algorithm=algo,
match=match_text,
),
)
print(f"[setup] Creating {NUM_STORAGE_PATHS} storage paths...") # noqa: T201
storage_paths = []
for i in range(NUM_STORAGE_PATHS):
algo = _algo(i)
match_text = (
fake.word()
if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
else ""
)
if algo == MatchingModel.MATCH_REGEX:
match_text = r"\b" + fake.word() + r"\b"
storage_paths.append(
StoragePath.objects.create(
name=f"mpath-{i}-{fake.word()}",
path=f"{fake.word()}/{{title}}",
matching_algorithm=algo,
match=match_text,
),
)
# ---- document with diverse content ------------------------------------
doc = Document.objects.create(
title="quarterly invoice payment tax financial statement",
content=" ".join(fake.paragraph(nb_sentences=5) for _ in range(3)),
checksum="MATCHPROF0001",
)
print(f"[setup] Document pk={doc.pk}, content length={len(doc.content)} chars") # noqa: T201
print( # noqa: T201
f" Correspondents: {NUM_CORRESPONDENTS} "
f"({sum(1 for c in correspondents if c.matching_algorithm == MatchingModel.MATCH_AUTO)} AUTO)",
)
print( # noqa: T201
f" Tags: {NUM_TAGS} "
f"({sum(1 for t in tags if t.matching_algorithm == MatchingModel.MATCH_AUTO)} AUTO)",
)
yield {"doc": doc}
# Teardown
print("\n[teardown] Removing matching corpus...") # noqa: T201
Document.objects.all().delete()
Correspondent.objects.all().delete()
Tag.objects.all().delete()
DocumentType.objects.all().delete()
StoragePath.objects.all().delete()
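The `\b`-delimited patterns assigned to MATCH_REGEX labels above restrict the match word to whole-word hits. A minimal illustration with `re` (the token "invoice" is an arbitrary example, not one of the Faker-generated match words):

```python
import re

# Word boundaries on both sides: matches the standalone word only.
pattern = re.compile(r"\binvoice\b", re.IGNORECASE)
assert pattern.search("The Invoice arrived today")
assert pattern.search("reinvoiced balance") is None
```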
# ---------------------------------------------------------------------------
# TestMatchingPipelineProfile
# ---------------------------------------------------------------------------
class TestMatchingPipelineProfile:
"""Profile the matching functions called per document during consumption."""
@pytest.fixture(autouse=True)
def _setup(self, matching_corpus):
self.doc = matching_corpus["doc"]
def test_match_correspondents(self):
"""50 correspondents, algorithm mix. Query count + time."""
with profile_block(
f"match_correspondents() [{NUM_CORRESPONDENTS} correspondents, mixed algorithms]",
):
result = match_correspondents(self.doc, classifier=None)
print(f" -> {len(result)} matched") # noqa: T201
def test_match_tags(self):
"""100 tags -- densest set in real installs."""
with profile_block(f"match_tags() [{NUM_TAGS} tags, mixed algorithms]"):
result = match_tags(self.doc, classifier=None)
print(f" -> {len(result)} matched") # noqa: T201
def test_match_document_types(self):
"""25 doc types."""
with profile_block(
f"match_document_types() [{NUM_DOC_TYPES} types, mixed algorithms]",
):
result = match_document_types(self.doc, classifier=None)
print(f" -> {len(result)} matched") # noqa: T201
def test_match_storage_paths(self):
"""20 storage paths."""
with profile_block(
f"match_storage_paths() [{NUM_STORAGE_PATHS} paths, mixed algorithms]",
):
result = match_storage_paths(self.doc, classifier=None)
print(f" -> {len(result)} matched") # noqa: T201
def test_full_match_sequence(self):
"""All four match_*() calls in order -- cumulative cost per document consumed."""
with profile_block(
"full match sequence: correspondents + doc_types + tags + storage_paths",
):
match_correspondents(self.doc, classifier=None)
match_document_types(self.doc, classifier=None)
match_tags(self.doc, classifier=None)
match_storage_paths(self.doc, classifier=None)
def test_algorithm_breakdown(self):
"""Create one correspondent per algorithm and time each independently."""
import time
from documents.matching import matches
fake = Faker()
algo_names = {
MatchingModel.MATCH_NONE: "MATCH_NONE",
MatchingModel.MATCH_ANY: "MATCH_ANY",
MatchingModel.MATCH_ALL: "MATCH_ALL",
MatchingModel.MATCH_LITERAL: "MATCH_LITERAL",
MatchingModel.MATCH_REGEX: "MATCH_REGEX",
MatchingModel.MATCH_FUZZY: "MATCH_FUZZY",
}
doc = self.doc
print() # noqa: T201
for algo, name in algo_names.items():
match_text = fake.word() if algo != MatchingModel.MATCH_NONE else ""
if algo == MatchingModel.MATCH_REGEX:
match_text = r"\b" + fake.word() + r"\b"
model = Correspondent(
name=f"algo-test-{name}",
matching_algorithm=algo,
match=match_text,
)
# Time 1000 iterations to get stable microsecond readings
runs = 1_000
t0 = time.perf_counter()
for _ in range(runs):
matches(model, doc)
us_per_call = (time.perf_counter() - t0) / runs * 1_000_000
print( # noqa: T201
f" {name:<20s} {us_per_call:8.2f} us/call (match={match_text[:20]!r})",
)

test_sanity_profile.py Normal file
@@ -0,0 +1,154 @@
"""
Sanity checker profiling.
Run with:
uv run pytest ../test_sanity_profile.py \
-m profiling --override-ini="addopts=" -s -v
Corpus: 2 000 documents with stub files (original + archive + thumbnail)
created in a temp MEDIA_ROOT.
Scenarios
---------
TestSanityCheckerProfile
- test_sanity_full_corpus full check_sanity() -- cProfile + tracemalloc
- test_sanity_query_pattern profile_block summary: query count + time
"""
from __future__ import annotations
import hashlib
import time
import pytest
from django.test import override_settings
from profiling import measure_memory
from profiling import profile_block
from profiling import profile_cpu
from documents.models import Document
from documents.sanity_checker import check_sanity
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
NUM_DOCS = 2_000
SEED = 42
# ---------------------------------------------------------------------------
# Module-scoped fixture: temp directories + corpus
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def sanity_corpus(tmp_path_factory, module_db):
"""
Build a 2 000-document corpus. For each document create stub files
(1-byte placeholders) in ORIGINALS_DIR, ARCHIVE_DIR, and THUMBNAIL_DIR
so the sanity checker's file-existence and checksum checks have real targets.
"""
media = tmp_path_factory.mktemp("sanity_media")
originals_dir = media / "documents" / "originals"
archive_dir = media / "documents" / "archive"
thumb_dir = media / "documents" / "thumbnails"
for d in (originals_dir, archive_dir, thumb_dir):
d.mkdir(parents=True)
# Use override_settings as a context manager for the whole fixture lifetime
settings_ctx = override_settings(
MEDIA_ROOT=media,
ORIGINALS_DIR=originals_dir,
ARCHIVE_DIR=archive_dir,
THUMBNAIL_DIR=thumb_dir,
MEDIA_LOCK=media / "media.lock",
)
settings_ctx.enable()
print(f"\n[setup] Creating {NUM_DOCS} documents with stub files...") # noqa: T201
t0 = time.perf_counter()
docs = []
for i in range(NUM_DOCS):
content = f"document content for doc {i}"
checksum = hashlib.sha256(content.encode()).hexdigest()
orig_filename = f"{i:07d}.pdf"
arch_filename = f"{i:07d}.pdf"
orig_path = originals_dir / orig_filename
arch_path = archive_dir / arch_filename
orig_path.write_bytes(content.encode())
arch_path.write_bytes(content.encode())
docs.append(
Document(
title=f"Document {i:05d}",
content=content,
checksum=checksum,
archive_checksum=checksum,
filename=orig_filename,
archive_filename=arch_filename,
mime_type="application/pdf",
),
)
created = Document.objects.bulk_create(docs, batch_size=500)
# Thumbnails use doc.pk, so create them after bulk_create assigns pks
for doc in created:
thumb_path = thumb_dir / f"{doc.pk:07d}.webp"
thumb_path.write_bytes(b"\x00") # minimal thumbnail stub
print( # noqa: T201
f"[setup] bulk_create + file creation: {time.perf_counter() - t0:.2f}s",
)
yield {"media": media}
# Teardown
print("\n[teardown] Removing sanity corpus...") # noqa: T201
Document.objects.all().delete()
settings_ctx.disable()
# ---------------------------------------------------------------------------
# TestSanityCheckerProfile
# ---------------------------------------------------------------------------
class TestSanityCheckerProfile:
"""Profile check_sanity() on a realistic corpus with real files."""
@pytest.fixture(autouse=True)
def _setup(self, sanity_corpus):
self.media = sanity_corpus["media"]
def test_sanity_full_corpus(self):
"""Full check_sanity() -- cProfile surfaces hot frames, tracemalloc shows peak."""
_, elapsed = profile_cpu(
lambda: check_sanity(scheduled=False),
label=f"check_sanity() [{NUM_DOCS} docs, real files]",
top=25,
)
_, peak_kib, delta_kib = measure_memory(
lambda: check_sanity(scheduled=False),
label=f"check_sanity() [{NUM_DOCS} docs] -- memory",
)
print("\n Summary:") # noqa: T201
print(f" Wall time (CPU profile run): {elapsed * 1000:.1f} ms") # noqa: T201
print(f" Peak memory (second run): {peak_kib:.1f} KiB") # noqa: T201
print(f" Memory delta: {delta_kib:+.1f} KiB") # noqa: T201
def test_sanity_query_pattern(self):
"""profile_block view: query count + query time + wall time in one summary."""
with profile_block(f"check_sanity() [{NUM_DOCS} docs] -- query count"):
check_sanity(scheduled=False)
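`test_sanity_full_corpus` calls a shared `measure_memory()` helper that returns the function result plus peak and delta memory in KiB. A hedged sketch of that contract, built on `tracemalloc` (the repo's real implementation may differ):

```python
# Illustrative sketch of measure_memory(); names and return shape are
# inferred from the call sites above, not copied from the repo.
import tracemalloc
from typing import Any, Callable


def measure_memory(fn: Callable[[], Any], label: str = "") -> "tuple[Any, float, float]":
    """Run fn() under tracemalloc; return (result, peak_kib, delta_kib)."""
    tracemalloc.start()
    before, _ = tracemalloc.get_traced_memory()
    result = fn()
    after, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    peak_kib = peak / 1024
    delta_kib = (after - before) / 1024
    if label:
        print(f"[memory] {label}: peak={peak_kib:.1f} KiB delta={delta_kib:+.1f} KiB")
    return result, peak_kib, delta_kib
```

Running the measured function a second time (as the test does, after the cProfile pass) keeps tracemalloc's own overhead out of the CPU numbers.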

test_search_profiling.py Normal file

@@ -0,0 +1,273 @@
"""
Search performance profiling tests.
Run explicitly — excluded from the normal test suite:
uv run pytest -m profiling -s -p no:xdist --override-ini="addopts=" -v
The ``-s`` flag is required to see profile_block() output.
The ``-p no:xdist`` flag disables parallel execution for accurate measurements.
Corpus: 5 000 documents generated deterministically from a fixed Faker seed,
with realistic variety: 30 correspondents, 15 document types, 50 tags, ~500
notes spread across ~10 % of documents.
"""
from __future__ import annotations
import random
import pytest
from django.contrib.auth.models import User
from faker import Faker
from profiling import profile_block
from rest_framework.test import APIClient
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import Note
from documents.models import Tag
from documents.search import get_backend
from documents.search import reset_backend
from documents.search._backend import SearchMode
pytestmark = [pytest.mark.profiling, pytest.mark.search, pytest.mark.django_db]
# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------
DOC_COUNT = 5_000
SEED = 42
NUM_CORRESPONDENTS = 30
NUM_DOC_TYPES = 15
NUM_TAGS = 50
NOTE_FRACTION = 0.10 # ~500 documents get a note
PAGE_SIZE = 25
def _build_corpus(rng: random.Random, fake: Faker) -> None:
"""
Insert the full corpus into the database and index it.
Uses bulk_create for the Document rows (fast) then handles the M2M tag
relationships and notes individually. Indexes the full corpus with a
single backend.rebuild() call.
"""
import datetime
# ---- lookup objects -------------------------------------------------
correspondents = [
Correspondent.objects.create(name=f"profcorp-{i}-{fake.company()}"[:128])
for i in range(NUM_CORRESPONDENTS)
]
doc_types = [
DocumentType.objects.create(name=f"proftype-{i}-{fake.word()}"[:128])
for i in range(NUM_DOC_TYPES)
]
tags = [
Tag.objects.create(name=f"proftag-{i}-{fake.word()}"[:100])
for i in range(NUM_TAGS)
]
note_user = User.objects.create_user(username="profnoteuser", password="x")
# ---- bulk-create documents ------------------------------------------
base_date = datetime.date(2018, 1, 1)
raw_docs = []
for i in range(DOC_COUNT):
day_offset = rng.randint(0, 6 * 365)
created = base_date + datetime.timedelta(days=day_offset)
raw_docs.append(
Document(
title=fake.sentence(nb_words=rng.randint(3, 9)).rstrip("."),
content="\n\n".join(
fake.paragraph(nb_sentences=rng.randint(3, 7))
for _ in range(rng.randint(2, 5))
),
checksum=f"PROF{i:07d}",
correspondent=rng.choice(correspondents + [None] * 8),
document_type=rng.choice(doc_types + [None] * 4),
created=created,
),
)
documents = Document.objects.bulk_create(raw_docs)
# ---- tags (M2M, post-bulk) ------------------------------------------
for doc in documents:
k = rng.randint(0, 5)
if k:
doc.tags.add(*rng.sample(tags, k))
# ---- notes on ~10 % of docs -----------------------------------------
note_docs = rng.sample(documents, int(DOC_COUNT * NOTE_FRACTION))
for doc in note_docs:
Note.objects.create(
document=doc,
note=fake.sentence(nb_words=rng.randint(6, 20)),
user=note_user,
)
# ---- build Tantivy index --------------------------------------------
backend = get_backend()
qs = Document.objects.select_related(
"correspondent",
"document_type",
"storage_path",
"owner",
).prefetch_related("tags", "notes__user", "custom_fields__field")
backend.rebuild(qs)
class TestSearchProfiling:
"""
Performance profiling for the Tantivy search backend and DRF API layer.
Each test builds a fresh 5 000-document corpus, exercises one hot path,
and prints profile_block() measurements to stdout. No correctness
assertions — the goal is to surface hot spots and track regressions.
"""
@pytest.fixture(autouse=True)
def _setup(self, tmp_path, settings):
index_dir = tmp_path / "index"
index_dir.mkdir()
settings.INDEX_DIR = index_dir
reset_backend()
rng = random.Random(SEED)
fake = Faker()
Faker.seed(SEED)
self.user = User.objects.create_superuser(
username="profiler",
password="admin",
)
self.client = APIClient()
self.client.force_authenticate(user=self.user)
_build_corpus(rng, fake)
yield
reset_backend()
# -- 1. Backend: search_ids relevance ---------------------------------
def test_profile_search_ids_relevance(self):
"""Profile: search_ids() with relevance ordering across several queries."""
backend = get_backend()
queries = [
"invoice payment",
"annual report",
"bank statement",
"contract agreement",
"receipt",
]
with profile_block(f"search_ids — relevance ({len(queries)} queries)"):
for q in queries:
backend.search_ids(q, user=None)
# -- 2. Backend: search_ids with Tantivy-native sort ------------------
def test_profile_search_ids_sorted(self):
"""Profile: search_ids() sorted by a Tantivy fast field (created)."""
backend = get_backend()
with profile_block("search_ids — sorted by created (asc + desc)"):
backend.search_ids(
"the",
user=None,
sort_field="created",
sort_reverse=False,
)
backend.search_ids(
"the",
user=None,
sort_field="created",
sort_reverse=True,
)
# -- 3. Backend: highlight_hits for a page of 25 ----------------------
def test_profile_highlight_hits(self):
"""Profile: highlight_hits() for a 25-document page."""
backend = get_backend()
all_ids = backend.search_ids("report", user=None)
page_ids = all_ids[:PAGE_SIZE]
with profile_block(f"highlight_hits — {len(page_ids)} docs"):
backend.highlight_hits("report", page_ids)
# -- 4. Backend: autocomplete -----------------------------------------
def test_profile_autocomplete(self):
"""Profile: autocomplete() with eight common prefixes."""
backend = get_backend()
prefixes = ["inv", "pay", "con", "rep", "sta", "acc", "doc", "fin"]
with profile_block(f"autocomplete — {len(prefixes)} prefixes"):
for prefix in prefixes:
backend.autocomplete(prefix, limit=10)
# -- 5. Backend: simple-mode search (TEXT and TITLE) ------------------
def test_profile_search_ids_simple_modes(self):
"""Profile: search_ids() in TEXT and TITLE simple-search modes."""
backend = get_backend()
queries = ["invoice 2023", "annual report", "bank statement"]
with profile_block(
f"search_ids — TEXT + TITLE modes ({len(queries)} queries each)",
):
for q in queries:
backend.search_ids(q, user=None, search_mode=SearchMode.TEXT)
backend.search_ids(q, user=None, search_mode=SearchMode.TITLE)
# -- 6. API: full round-trip, relevance + page 1 ----------------------
def test_profile_api_relevance_search(self):
"""Profile: full API search round-trip, relevance order, page 1."""
with profile_block(
f"API /documents/?query=… relevance (page 1, page_size={PAGE_SIZE})",
):
response = self.client.get(
f"/api/documents/?query=invoice+payment&page=1&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
# -- 7. API: full round-trip, ORM-ordered (title) ---------------------
def test_profile_api_orm_sorted_search(self):
"""Profile: full API search round-trip with ORM-delegated sort (title)."""
with profile_block("API /documents/?query=…&ordering=title"):
response = self.client.get(
f"/api/documents/?query=report&ordering=title&page=1&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
# -- 8. API: full round-trip, score sort ------------------------------
def test_profile_api_score_sort(self):
"""Profile: full API search with ordering=-score (relevance, preserve order)."""
with profile_block("API /documents/?query=…&ordering=-score"):
response = self.client.get(
f"/api/documents/?query=statement&ordering=-score&page=1&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
# -- 9. API: full round-trip, with selection_data ---------------------
def test_profile_api_with_selection_data(self):
"""Profile: full API search including include_selection_data=true."""
with profile_block("API /documents/?query=…&include_selection_data=true"):
response = self.client.get(
f"/api/documents/?query=contract&page=1&page_size={PAGE_SIZE}"
"&include_selection_data=true",
)
assert response.status_code == 200
assert "selection_data" in response.data
# -- 10. API: paginated (page 2) --------------------------------------
def test_profile_api_page_2(self):
"""Profile: full API search, page 2 — exercises page offset arithmetic."""
with profile_block(f"API /documents/?query=…&page=2&page_size={PAGE_SIZE}"):
response = self.client.get(
f"/api/documents/?query=the&page=2&page_size={PAGE_SIZE}",
)
assert response.status_code == 200
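The sanity and search tests both call a shared `profile_cpu()` helper that runs a callable under cProfile, prints the hottest frames, and returns the result with the elapsed wall time. A minimal sketch consistent with those call sites (`label` and `top` keyword arguments are assumptions from usage, not the repo's verified signature):

```python
# Hedged sketch of profile_cpu(); the repo's profiling module owns the
# real implementation, which may format its report differently.
import cProfile
import io
import pstats
import time
from typing import Any, Callable


def profile_cpu(fn: Callable[[], Any], label: str = "", top: int = 20) -> "tuple[Any, float]":
    """Run fn() under cProfile, print the top cumulative frames, return (result, seconds)."""
    profiler = cProfile.Profile()
    t0 = time.perf_counter()
    profiler.enable()
    result = fn()
    profiler.disable()
    elapsed = time.perf_counter() - t0
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(top)
    print(f"[cpu] {label}: {elapsed * 1000:.1f} ms\n{stream.getvalue()}")
    return result, elapsed
```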

test_workflow_profile.py Normal file

@@ -0,0 +1,231 @@
"""
Workflow trigger matching profiling.
Run with:
uv run pytest ../test_workflow_profile.py \
-m profiling --override-ini="addopts=" -s -v
Corpus: 500 documents + correspondents + tags, plus WorkflowTrigger sets of
5 and 20 to allow scaling comparisons.
Scenarios
---------
TestWorkflowMatchingProfile
- test_existing_document_5_workflows existing_document_matches_workflow x 5 triggers
- test_existing_document_20_workflows same x 20 triggers
- test_workflow_prefilter prefilter_documents_by_workflowtrigger on 500 docs
- test_trigger_type_comparison compare DOCUMENT_ADDED vs DOCUMENT_UPDATED overhead
"""
from __future__ import annotations
import random
import time
import pytest
from faker import Faker
from profiling import profile_block
from documents.matching import existing_document_matches_workflow
from documents.matching import prefilter_documents_by_workflowtrigger
from documents.models import Correspondent
from documents.models import Document
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
NUM_DOCS = 500
NUM_CORRESPONDENTS = 10
NUM_TAGS = 20
SEED = 42
# ---------------------------------------------------------------------------
# Module-scoped fixture
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
"""Unlock the DB for the whole module (module-scoped)."""
with django_db_blocker.unblock():
yield
@pytest.fixture(scope="module")
def workflow_corpus(module_db):
"""
500 documents + correspondents + tags, plus workflow trigger sets of
5 and 20 to allow scaling comparisons.
"""
fake = Faker()
Faker.seed(SEED)
rng = random.Random(SEED)
# ---- lookup objects ---------------------------------------------------
print("\n[setup] Creating lookup objects...") # noqa: T201
correspondents = [
Correspondent.objects.create(name=f"wfcorp-{i}-{fake.company()}"[:128])
for i in range(NUM_CORRESPONDENTS)
]
tags = [
Tag.objects.create(name=f"wftag-{i}-{fake.word()}"[:100])
for i in range(NUM_TAGS)
]
# ---- documents --------------------------------------------------------
print(f"[setup] Building {NUM_DOCS} documents...") # noqa: T201
raw_docs = []
for i in range(NUM_DOCS):
raw_docs.append(
Document(
title=fake.sentence(nb_words=4).rstrip("."),
content=fake.paragraph(nb_sentences=3),
checksum=f"WF{i:07d}",
correspondent=rng.choice(correspondents + [None] * 3),
),
)
documents = Document.objects.bulk_create(raw_docs, batch_size=500)
for doc in documents:
k = rng.randint(0, 3)
if k:
doc.tags.add(*rng.sample(tags, k))
sample_doc = documents[0]
print(f"[setup] Sample doc pk={sample_doc.pk}") # noqa: T201
# ---- build triggers at scale 5 and 20 --------------------------------
_wf_counter = [0]
def _make_triggers(n: int, trigger_type: int) -> list[WorkflowTrigger]:
triggers = []
for i in range(n):
            # Every third trigger gets a correspondent filter; the rest are unfiltered
            corr = correspondents[i % NUM_CORRESPONDENTS] if i % 3 == 0 else None
trigger = WorkflowTrigger.objects.create(
type=trigger_type,
filter_has_correspondent=corr,
)
action = WorkflowAction.objects.create(
type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
)
idx = _wf_counter[0]
_wf_counter[0] += 1
wf = Workflow.objects.create(name=f"wf-profile-{idx}")
wf.triggers.add(trigger)
wf.actions.add(action)
triggers.append(trigger)
return triggers
print("[setup] Creating workflow triggers...") # noqa: T201
triggers_5 = _make_triggers(5, WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED)
triggers_20 = _make_triggers(
20,
WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED,
)
triggers_added = _make_triggers(
5,
WorkflowTrigger.WorkflowTriggerType.DOCUMENT_ADDED,
)
yield {
"doc": sample_doc,
"triggers_5": triggers_5,
"triggers_20": triggers_20,
"triggers_added": triggers_added,
}
# Teardown
print("\n[teardown] Removing workflow corpus...") # noqa: T201
Workflow.objects.all().delete()
WorkflowTrigger.objects.all().delete()
WorkflowAction.objects.all().delete()
Document.objects.all().delete()
Correspondent.objects.all().delete()
Tag.objects.all().delete()
# ---------------------------------------------------------------------------
# TestWorkflowMatchingProfile
# ---------------------------------------------------------------------------
class TestWorkflowMatchingProfile:
"""Profile workflow trigger evaluation per document save."""
@pytest.fixture(autouse=True)
def _setup(self, workflow_corpus):
self.doc = workflow_corpus["doc"]
self.triggers_5 = workflow_corpus["triggers_5"]
self.triggers_20 = workflow_corpus["triggers_20"]
self.triggers_added = workflow_corpus["triggers_added"]
def test_existing_document_5_workflows(self):
"""existing_document_matches_workflow x 5 DOCUMENT_UPDATED triggers."""
doc = self.doc
triggers = self.triggers_5
with profile_block(
f"existing_document_matches_workflow [{len(triggers)} triggers]",
):
for trigger in triggers:
existing_document_matches_workflow(doc, trigger)
def test_existing_document_20_workflows(self):
"""existing_document_matches_workflow x 20 triggers -- shows linear scaling."""
doc = self.doc
triggers = self.triggers_20
with profile_block(
f"existing_document_matches_workflow [{len(triggers)} triggers]",
):
for trigger in triggers:
existing_document_matches_workflow(doc, trigger)
# Also time each call individually to show per-trigger overhead
timings = []
for trigger in triggers:
t0 = time.perf_counter()
existing_document_matches_workflow(doc, trigger)
timings.append((time.perf_counter() - t0) * 1_000_000)
avg_us = sum(timings) / len(timings)
print(f"\n Per-trigger avg: {avg_us:.1f} us (n={len(timings)})") # noqa: T201
def test_workflow_prefilter(self):
"""prefilter_documents_by_workflowtrigger on 500 docs -- tag + correspondent filters."""
qs = Document.objects.all()
print(f"\n Corpus: {qs.count()} documents") # noqa: T201
for trigger in self.triggers_20[:3]:
label = (
f"prefilter_documents_by_workflowtrigger "
f"[corr={trigger.filter_has_correspondent_id}]"
)
with profile_block(label):
result = prefilter_documents_by_workflowtrigger(qs, trigger)
# Evaluate the queryset
count = result.count()
print(f" -> {count} docs passed filter") # noqa: T201
def test_trigger_type_comparison(self):
"""Compare per-call overhead of DOCUMENT_UPDATED vs DOCUMENT_ADDED."""
doc = self.doc
runs = 200
for label, triggers in [
("DOCUMENT_UPDATED", self.triggers_5),
("DOCUMENT_ADDED", self.triggers_added),
]:
t0 = time.perf_counter()
for _ in range(runs):
for trigger in triggers:
existing_document_matches_workflow(doc, trigger)
total_calls = runs * len(triggers)
us_per_call = (time.perf_counter() - t0) / total_calls * 1_000_000
print( # noqa: T201
f" {label:<22s} {us_per_call:.2f} us/call "
f"({total_calls} calls, {len(triggers)} triggers)",
)
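The manual `perf_counter` loops in `test_existing_document_20_workflows` and `test_trigger_type_comparison` repeat the same microbenchmark pattern; it can be factored into a tiny helper. An illustrative sketch (not part of the repo's profiling module):

```python
# Hypothetical helper capturing the repeated timing pattern above.
import time
from typing import Callable


def us_per_call(fn: Callable[[], object], runs: int = 200) -> float:
    """Average microseconds per call of fn() over `runs` invocations."""
    t0 = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - t0) / runs * 1_000_000
```

With it, the trigger-type comparison collapses to one line per label, e.g. `us_per_call(lambda: existing_document_matches_workflow(doc, trigger))`.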