mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-12 02:58:52 +00:00
Compare commits: feature-se...feature-pr (14 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 0887203d45 | |
| | ea14c0b06f | |
| | a8dc332abb | |
| | e64b9a4cfd | |
| | 6ba1acd7d3 | |
| | d006b79fd1 | |
| | 24b754b44c | |
| | a1a3520a8c | |
| | 23449cda17 | |
| | ca3f5665ba | |
| | 9aa0914c3f | |
| | fdd5e3ecb2 | |
| | df3b656352 | |
| | 51e721733f | |
SECURITY.md (79)
@@ -2,8 +2,83 @@
 ## Reporting a Vulnerability

-The Paperless-ngx team and community take security bugs seriously. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
+The Paperless-ngx team and community take security issues seriously. We appreciate good-faith reports and will make every effort to review legitimate findings responsibly.

 To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/paperless-ngx/paperless-ngx/security/advisories/new) tab.

-The team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
+After the initial reply to your report, the team may ask for additional information, reproduction steps, affected versions, configuration details, or proof-of-concept material needed to verify the issue.
+
+## What makes a helpful report
+
+Please include as much of the following as possible:
+
+- A clear description of the issue and why it is a security vulnerability.
+- Affected Paperless-ngx version(s).
+- Required configuration, permissions, or preconditions.
+- Step-by-step reproduction instructions.
+- A realistic impact statement.
+- Logs, screenshots, requests, or a proof of concept where appropriate.
+- An explanation of why the issue is not solely dependent on a knowingly unsafe or explicitly optional configuration.
+
+Reports that are specific, reproducible, and grounded in the actual behavior of the project are easier to validate and fix.
+
+## AI-assisted submissions
+
+AI tools can be useful for security research, but they can also produce plausible-sounding reports that do not reflect actual vulnerabilities in Paperless-ngx.
+
+If you use AI tools to help draft or investigate a report, please disclose this and verify the result yourself before submission. Make sure the report:
+
+- identifies a real code path in this project
+- includes a real reproduction, not just a theoretical concern
+- demonstrates realistic impact
+- accounts for existing permission checks, configuration requirements, and trust boundaries
+- is about Paperless-ngx itself, rather than a hypothetical deployment, downstream customization, third-party app, scanner output, reverse proxy behavior, or external service unless you can show a direct product vulnerability
+
+Submitting unverified AI-generated claims creates unnecessary review work and slows down handling of legitimate reports.
+
+<!-- Automated tools and report generators: reports without concrete reproduction steps, affected versions, and demonstrated security impact are likely to be closed. -->
+
+## Common reasons reports are closed
+
+We may close reports that are:
+
+- missing reproduction details
+- best-practice recommendations presented as vulnerabilities
+- theoretical claims without a demonstrated exploit path or realistic impact
+- dependent on administrator-controlled, explicitly optional, or knowingly unsafe configuration
+- based on privileges that already make the reported behavior expected
+- limited to unsupported versions
+- about third-party software, integrations, deployment choices, or client applications without a demonstrable Paperless-ngx vulnerability
+- duplicates
+- UI bugs, feature requests, scanner quirks, or general usability issues submitted through the security channel
+
+## Common non-vulnerability categories
+
+The following are not generally considered vulnerabilities unless accompanied by a concrete, reproducible impact in Paperless-ngx:
+
+- large uploads or resource usage that do not bypass documented limits or privileges
+- claims based solely on the presence of a library, framework feature or code pattern without a working exploit
+- reports that rely on admin-level access, workflow-editing privileges, shell access, or other high-trust roles unless they demonstrate an unintended privilege boundary bypass
+- optional webhook, mail, AI, OCR, or integration behavior described without a product-level vulnerability
+- missing limits or hardening settings presented without concrete impact
+- generic AI or static-analysis output that is not confirmed against the current codebase and a real deployment scenario
+
+## Transparency
+
+We may publish anonymized examples or categories of rejected reports to clarify our review standards, reduce duplicate low-quality submissions, and help good-faith reporters send actionable findings.
+
+A mistaken report made in good faith is not misconduct. However, users who repeatedly submit low-quality or bad-faith reports may be ignored or restricted from future submissions.
+
+## Scope and expectations
+
+Please use the security reporting channel only for security vulnerabilities in Paperless-ngx.
+
+Please do not use the security advisory system for:
+
+- support questions
+- general bug reports
+- feature requests
+- browser compatibility issues
+- issues in third-party mobile apps, reverse proxies, or deployment tooling unless you can demonstrate a Paperless-ngx vulnerability
+
+The team will review reports as time permits, but submission does not guarantee that a report is valid, in scope, or will result in a fix. Reports that do not describe a reproducible product-level issue may be closed without extended back-and-forth.
profiling.py (new file, 150 lines)
@@ -0,0 +1,150 @@
"""
Temporary profiling utilities for comparing implementations.

Usage in a management command or shell::

    from profiling import profile_block, profile_cpu, measure_memory

    with profile_block("new check_sanity"):
        messages = check_sanity()

    with profile_block("old check_sanity"):
        messages = check_sanity_old()

Drop this file when done.
"""

from __future__ import annotations

import tracemalloc
from collections.abc import Callable  # noqa: TC003
from collections.abc import Generator  # noqa: TC003
from contextlib import contextmanager
from time import perf_counter
from typing import Any

from django.db import connection
from django.db import reset_queries
from django.test.utils import override_settings


@contextmanager
def profile_block(label: str = "block") -> Generator[None, None, None]:
    """Profile memory, wall time, and DB queries for a code block.

    Prints a summary to stdout on exit. Requires no external packages.
    Enables DEBUG temporarily to capture Django's query log.
    """
    tracemalloc.start()
    snapshot_before = tracemalloc.take_snapshot()

    with override_settings(DEBUG=True):
        reset_queries()
        start = perf_counter()

        yield

        elapsed = perf_counter() - start
        queries = list(connection.queries)

    snapshot_after = tracemalloc.take_snapshot()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    # Compare snapshots for top allocations
    stats = snapshot_after.compare_to(snapshot_before, "lineno")

    query_time = sum(float(q["time"]) for q in queries)
    mem_diff = sum(s.size_diff for s in stats)

    print(f"\n{'=' * 60}")  # noqa: T201
    print(f"  Profile: {label}")  # noqa: T201
    print(f"{'=' * 60}")  # noqa: T201
    print(f"  Wall time:    {elapsed:.4f}s")  # noqa: T201
    print(f"  Queries:      {len(queries)} ({query_time:.4f}s)")  # noqa: T201
    print(f"  Memory delta: {mem_diff / 1024:.1f} KiB")  # noqa: T201
    print(f"  Peak memory:  {peak / 1024:.1f} KiB")  # noqa: T201
    print("\n  Top 5 allocations:")  # noqa: T201
    for stat in stats[:5]:
        print(f"    {stat}")  # noqa: T201
    print(f"{'=' * 60}\n")  # noqa: T201


def profile_cpu(
    fn: Callable[[], Any],
    *,
    label: str,
    top: int = 30,
    sort: str = "cumtime",
) -> tuple[Any, float]:
    """Run *fn()* under cProfile, print stats, return (result, elapsed_s).

    Args:
        fn: Zero-argument callable to profile.
        label: Human-readable label printed in the header.
        top: Number of cProfile rows to print.
        sort: cProfile sort key (default: cumulative time).

    Returns:
        ``(result, elapsed_s)`` where *result* is the return value of *fn()*.
    """
    import cProfile
    import io
    import pstats

    pr = cProfile.Profile()
    t0 = perf_counter()
    pr.enable()
    result = fn()
    pr.disable()
    elapsed = perf_counter() - t0

    buf = io.StringIO()
    ps = pstats.Stats(pr, stream=buf).sort_stats(sort)
    ps.print_stats(top)

    print(f"\n{'=' * 72}")  # noqa: T201
    print(f"  {label}")  # noqa: T201
    print(f"  wall time: {elapsed * 1000:.1f} ms")  # noqa: T201
    print(f"{'=' * 72}")  # noqa: T201
    print(buf.getvalue())  # noqa: T201

    return result, elapsed


def measure_memory(fn: Callable[[], Any], *, label: str) -> tuple[Any, float, float]:
    """Run *fn()* under tracemalloc, print allocation report.

    Args:
        fn: Zero-argument callable to profile.
        label: Human-readable label printed in the header.

    Returns:
        ``(result, peak_kib, delta_kib)``.
    """
    tracemalloc.start()
    snapshot_before = tracemalloc.take_snapshot()
    t0 = perf_counter()
    result = fn()
    elapsed = perf_counter() - t0
    snapshot_after = tracemalloc.take_snapshot()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    stats = snapshot_after.compare_to(snapshot_before, "lineno")
    delta_kib = sum(s.size_diff for s in stats) / 1024

    print(f"\n{'=' * 72}")  # noqa: T201
    print(f"  [memory] {label}")  # noqa: T201
    print(f"  wall time:    {elapsed * 1000:.1f} ms")  # noqa: T201
    print(f"  memory delta: {delta_kib:+.1f} KiB")  # noqa: T201
    print(f"  peak traced:  {peak / 1024:.1f} KiB")  # noqa: T201
    print(f"{'=' * 72}")  # noqa: T201
    print("  Top allocation sites (by size_diff):")  # noqa: T201
    for stat in stats[:20]:
        if stat.size_diff != 0:
            print(  # noqa: T201
                f"    {stat.size_diff / 1024:+8.1f} KiB  {stat.traceback.format()[0]}",
            )

    return result, peak / 1024, delta_kib
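Since `profile_cpu` and `measure_memory` above are Django-independent, their measurement pattern can be exercised standalone. A minimal sketch of the cProfile half, where `profile_cpu_demo` and `workload` are hypothetical stand-ins for the real function and the `check_sanity` candidates from the module docstring:

```python
# Dependency-free sketch of the profile_cpu() pattern: run a callable
# under cProfile, time it with perf_counter, and capture the pstats
# report as a string instead of printing it.
import cProfile
import io
import pstats
from time import perf_counter


def profile_cpu_demo(fn, *, top=5, sort="cumtime"):
    pr = cProfile.Profile()
    t0 = perf_counter()
    pr.enable()
    result = fn()
    pr.disable()
    elapsed = perf_counter() - t0

    buf = io.StringIO()
    pstats.Stats(pr, stream=buf).sort_stats(sort).print_stats(top)
    return result, elapsed, buf.getvalue()


def workload():
    # Toy stand-in for a real candidate implementation.
    return sum(i * i for i in range(10_000))


result, elapsed, report = profile_cpu_demo(workload)
```

The report string contains the usual cProfile header and per-function rows, so two implementations can be compared side by side.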
@@ -312,6 +312,7 @@ markers = [
     "date_parsing: Tests which cover date parsing from content or filename",
     "management: Tests which cover management commands/functionality",
     "search: Tests for the Tantivy search backend",
+    "profiling: Performance profiling tests — print measurements, no assertions",
 ]

 [tool.pytest_env]
@@ -43,7 +43,7 @@
 </div>
 <p class="card-text">
   @if (document) {
-    @if (hasSearchHighlights) {
+    @if (document.__search_hit__ && document.__search_hit__.highlights) {
       <span [innerHtml]="document.__search_hit__.highlights"></span>
     }
     @for (highlight of searchNoteHighlights; track highlight) {
@@ -52,7 +52,7 @@
       <span [innerHtml]="highlight"></span>
     </span>
   }
-  @if (shouldShowContentFallback) {
+  @if (!document.__search_hit__?.score) {
     <span class="result-content">{{contentTrimmed}}</span>
   }
 } @else {
@@ -127,19 +127,6 @@ describe('DocumentCardLargeComponent', () => {
     expect(component.searchNoteHighlights).toContain('<span>bananas</span>')
   })

-  it('should fall back to document content when a search hit has no highlights', () => {
-    component.document.__search_hit__ = {
-      score: 0.9,
-      rank: 1,
-      highlights: '',
-      note_highlights: null,
-    }
-    fixture.detectChanges()
-
-    expect(fixture.nativeElement.textContent).toContain('Cupcake ipsum')
-    expect(component.shouldShowContentFallback).toBe(true)
-  })
-
   it('should try to close the preview on mouse leave', () => {
     component.popupPreview = {
       close: jest.fn(),
@@ -164,17 +164,6 @@ export class DocumentCardLargeComponent
     )
   }

-  get hasSearchHighlights() {
-    return Boolean(this.document?.__search_hit__?.highlights?.trim()?.length)
-  }
-
-  get shouldShowContentFallback() {
-    return (
-      this.document?.__search_hit__?.score == null ||
-      (!this.hasSearchHighlights && this.searchNoteHighlights.length === 0)
-    )
-  }
-
   get notesEnabled(): boolean {
     return this.settingsService.get(SETTINGS_KEYS.NOTES_ENABLED)
   }
@@ -1,6 +1,6 @@
from documents.search._backend import SearchHit
from documents.search._backend import SearchIndexLockError
from documents.search._backend import SearchMode
from documents.search._backend import SearchResults
from documents.search._backend import TantivyBackend
from documents.search._backend import TantivyRelevanceList
from documents.search._backend import WriteBatch
@@ -10,9 +10,9 @@ from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index

__all__ = [
    "SearchHit",
    "SearchIndexLockError",
    "SearchMode",
    "SearchResults",
    "TantivyBackend",
    "TantivyRelevanceList",
    "WriteBatch",
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from html import escape
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Self
|
||||
from typing import TypedDict
|
||||
@@ -55,36 +54,6 @@ class SearchMode(StrEnum):
|
||||
TITLE = "title"
|
||||
|
||||
|
||||
def _render_snippet_html(snippet: tantivy.Snippet) -> str:
|
||||
fragment = snippet.fragment()
|
||||
highlighted = sorted(snippet.highlighted(), key=lambda r: r.start)
|
||||
|
||||
if not highlighted:
|
||||
return escape(fragment)
|
||||
|
||||
parts: list[str] = []
|
||||
cursor = 0
|
||||
fragment_len = len(fragment)
|
||||
|
||||
for highlight in highlighted:
|
||||
start = max(0, min(fragment_len, highlight.start))
|
||||
end = max(start, min(fragment_len, highlight.end))
|
||||
|
||||
if end <= cursor:
|
||||
continue
|
||||
|
||||
if start > cursor:
|
||||
parts.append(escape(fragment[cursor:start]))
|
||||
|
||||
parts.append(f'<span class="match">{escape(fragment[start:end])}</span>')
|
||||
cursor = end
|
||||
|
||||
if cursor < fragment_len:
|
||||
parts.append(escape(fragment[cursor:]))
|
||||
|
||||
return "".join(parts)
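The span-merging walk in `_render_snippet_html` can be checked in isolation. A sketch that applies the same cursor logic to plain `(start, end)` tuples instead of tantivy highlight ranges (`render_fragment` is a hypothetical stand-in, not part of the project):

```python
from html import escape


def render_fragment(fragment: str, ranges: list[tuple[int, int]]) -> str:
    # Same cursor walk as _render_snippet_html: clamp each range to the
    # fragment, skip ranges that end before the cursor, escape all text,
    # and wrap matched spans in <span class="match">.
    if not ranges:
        return escape(fragment)

    parts: list[str] = []
    cursor = 0
    n = len(fragment)
    for start, end in sorted(ranges):
        start = max(0, min(n, start))
        end = max(start, min(n, end))
        if end <= cursor:
            continue
        if start > cursor:
            parts.append(escape(fragment[cursor:start]))
        parts.append(f'<span class="match">{escape(fragment[start:end])}</span>')
        cursor = end
    if cursor < n:
        parts.append(escape(fragment[cursor:]))
    return "".join(parts)


html = render_fragment("tax & invoice 2024", [(0, 3)])
```

Note how the unmatched text is HTML-escaped, so a fragment containing `&` or `<` cannot break the rendered snippet.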


def _extract_autocomplete_words(text_sources: list[str]) -> set[str]:
    """Extract and normalize words for autocomplete.

@@ -119,63 +88,45 @@ class SearchHit(TypedDict):
    highlights: dict[str, str]


@dataclass(frozen=True, slots=True)
class SearchResults:
    """
    Container for search results with pagination metadata.

    Attributes:
        hits: List of search results with scores and highlights
        total: Total matching documents across all pages (for pagination)
        query: Preprocessed query string after date/syntax rewriting
    """

    hits: list[SearchHit]
    total: int  # total matching documents (for pagination)
    query: str  # preprocessed query string


class TantivyRelevanceList:
    """
    DRF-compatible list wrapper for Tantivy search results.
    DRF-compatible list wrapper for Tantivy search hits.

    Holds a lightweight ordered list of IDs (for pagination count and
    ``selection_data``) together with a small page of rich ``SearchHit``
    dicts (for serialization). DRF's ``PageNumberPagination`` calls
    ``__len__`` to compute the total page count and ``__getitem__`` to
    slice the displayed page.
    Provides paginated access to search results while storing all hits in memory
    for efficient ID retrieval. Used by Django REST framework for pagination.

    Args:
        ordered_ids: All matching document IDs in display order.
        page_hits: Rich SearchHit dicts for the requested DRF page only.
        page_offset: Index into *ordered_ids* where *page_hits* starts.
    Methods:
        __len__: Returns total hit count for pagination calculations
        __getitem__: Slices the hit list for page-specific results

    Note: Stores ALL post-filter hits so get_all_result_ids() can return
    every matching document ID without requiring a second search query.
    """

    def __init__(
        self,
        ordered_ids: list[int],
        page_hits: list[SearchHit],
        page_offset: int = 0,
    ) -> None:
        self._ordered_ids = ordered_ids
        self._page_hits = page_hits
        self._page_offset = page_offset
    def __init__(self, hits: list[SearchHit]) -> None:
        self._hits = hits

    def __len__(self) -> int:
        return len(self._ordered_ids)
        return len(self._hits)

    def __getitem__(self, key: int | slice) -> SearchHit | list[SearchHit]:
        if isinstance(key, int):
            idx = key if key >= 0 else len(self._ordered_ids) + key
            if self._page_offset <= idx < self._page_offset + len(self._page_hits):
                return self._page_hits[idx - self._page_offset]
            return SearchHit(
                id=self._ordered_ids[key],
                score=0.0,
                rank=idx + 1,
                highlights={},
            )
        start = key.start or 0
        stop = key.stop or len(self._ordered_ids)
        # DRF slices to extract the current page. If the slice aligns
        # with our pre-fetched page_hits, return them directly.
        # We only check start — DRF always slices with stop=start+page_size,
        # which exceeds page_hits length on the last page.
        if start == self._page_offset:
            return self._page_hits[: stop - start]
        # Fallback: return stub dicts (no highlights).
        return [
            SearchHit(id=doc_id, score=0.0, rank=start + i + 1, highlights={})
            for i, doc_id in enumerate(self._ordered_ids[key])
        ]

    def get_all_ids(self) -> list[int]:
        """Return all matching document IDs in display order."""
        return self._ordered_ids
    def __getitem__(self, key: slice) -> list[SearchHit]:
        return self._hits[key]
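Both versions of `TantivyRelevanceList` satisfy the same DRF duck-typing contract: the paginator only needs `__len__` for the total count and `__getitem__` for a page slice. A minimal standalone sketch of that contract (plain Python, no DRF; `HitList` and `paginate` are illustrative names, not project code):

```python
# Stand-in for the wrapper: any object with __len__ and a slice-capable
# __getitem__ can back DRF's PageNumberPagination.
class HitList:
    def __init__(self, hits):
        self._hits = hits

    def __len__(self):
        return len(self._hits)

    def __getitem__(self, key):
        return self._hits[key]


def paginate(obj, *, page, page_size):
    # Mirrors what the paginator does: count first, then slice one page.
    total = len(obj)
    start = (page - 1) * page_size
    return total, obj[start : start + page_size]


total, page_items = paginate(HitList(list(range(10))), page=2, page_size=3)
```

This is why the simple variant can just delegate to a backing list, while the richer variant can answer the same two calls from `ordered_ids` plus a pre-fetched page.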


class SearchIndexLockError(Exception):
@@ -255,13 +206,10 @@ class WriteBatch:
        """
        Remove a document from the batch by its primary key.

        Uses range_query instead of term_query to work around a tantivy-py bug
        where Python integers are inferred as i64, producing Terms that never
        match u64 fields.

        TODO: Replace with term_query("id", doc_id) once
        https://github.com/quickwit-oss/tantivy-py/pull/642 lands.
        Uses range query instead of term query to work around unsigned integer
        type detection bug in tantivy-py 0.25.
        """
        # Use range query to work around u64 deletion bug
        self._writer.delete_documents_by_query(
            tantivy.Query.range_query(
                self._backend._schema,
@@ -286,34 +234,6 @@ class TantivyBackend:
    the underlying index directory changes (e.g., during test isolation).
    """

    # Maps DRF ordering field names to Tantivy index field names.
    SORT_FIELD_MAP: dict[str, str] = {
        "title": "title_sort",
        "correspondent__name": "correspondent_sort",
        "document_type__name": "type_sort",
        "created": "created",
        "added": "added",
        "modified": "modified",
        "archive_serial_number": "asn",
        "page_count": "page_count",
        "num_notes": "num_notes",
    }

    # Fields where Tantivy's sort order matches the ORM's sort order.
    # Text-based fields (title, correspondent__name, document_type__name)
    # are excluded because Tantivy's tokenized fast fields produce different
    # ordering than the ORM's collation-based ordering.
    SORTABLE_FIELDS: frozenset[str] = frozenset(
        {
            "created",
            "added",
            "modified",
            "archive_serial_number",
            "page_count",
            "num_notes",
        },
    )

    def __init__(self, path: Path | None = None):
        # path=None → in-memory index (for tests)
        # path=some_dir → on-disk index (for production)
@@ -352,36 +272,6 @@ class TantivyBackend:
        if self._index is None:
            self.open()  # pragma: no cover

    def _parse_query(
        self,
        query: str,
        search_mode: SearchMode,
    ) -> tantivy.Query:
        """Parse a user query string into a Tantivy Query object."""
        tz = get_current_timezone()
        if search_mode is SearchMode.TEXT:
            return parse_simple_text_query(self._index, query)
        elif search_mode is SearchMode.TITLE:
            return parse_simple_title_query(self._index, query)
        else:
            return parse_user_query(self._index, query, tz)

    def _apply_permission_filter(
        self,
        query: tantivy.Query,
        user: AbstractBaseUser | None,
    ) -> tantivy.Query:
        """Wrap a query with a permission filter if the user is not a superuser."""
        if user is not None:
            permission_filter = build_permission_filter(self._schema, user)
            return tantivy.Query.boolean_query(
                [
                    (tantivy.Occur.Must, query),
                    (tantivy.Occur.Must, permission_filter),
                ],
            )
        return query

    def _build_tantivy_doc(
        self,
        document: Document,
@@ -436,17 +326,12 @@ class TantivyBackend:
        doc.add_unsigned("tag_id", tag.pk)
        tag_names.append(tag.name)

        # Notes — JSON for structured queries (notes.user:alice, notes.note:text).
        # notes_text is a plain-text companion for snippet/highlight generation;
        # tantivy's SnippetGenerator does not support JSON fields.
        # Notes — JSON for structured queries (notes.user:alice, notes.note:text),
        # companion text field for default full-text search.
        num_notes = 0
        note_texts: list[str] = []
        for note in document.notes.all():
            num_notes += 1
            doc.add_json("notes", {"note": note.note, "user": note.user.username})
            note_texts.append(note.note)
        if note_texts:
            doc.add_text("notes_text", " ".join(note_texts))

        # Custom fields — JSON for structured queries (custom_fields.name:x, custom_fields.value:y),
        # companion text field for default full-text search.
@@ -540,127 +425,155 @@ class TantivyBackend:
        with self.batch_update(lock_timeout=5.0) as batch:
            batch.remove(doc_id)

    def highlight_hits(
    def search(
        self,
        query: str,
        doc_ids: list[int],
        user: AbstractBaseUser | None,
        page: int,
        page_size: int,
        sort_field: str | None,
        *,
        sort_reverse: bool,
        search_mode: SearchMode = SearchMode.QUERY,
        rank_start: int = 1,
    ) -> list[SearchHit]:
    ) -> SearchResults:
        """
        Generate SearchHit dicts with highlights for specific document IDs.
        Execute a search query against the document index.

        Unlike search(), this does not execute a ranked query — it looks up
        each document by ID and generates snippets against the provided query.
        Use this when you already know which documents to display (from
        search_ids + ORM filtering) and just need highlight data.
        Processes the user query through date rewriting, normalization, and
        permission filtering before executing against Tantivy. Supports both
        relevance-based and field-based sorting.

        Note: Each doc_id requires an individual index lookup because tantivy-py
        does not yet expose a batch fast-field read API. This is acceptable for
        page-sized batches (typically 25 docs) but should not be called with
        thousands of IDs.

        TODO: When https://github.com/quickwit-oss/tantivy-py/pull/641 lands,
        the per-doc range_query lookups here can be replaced with a single
        collect_u64_fast_field("id", doc_addresses) call.
        QUERY search mode supports natural date keywords, field filters, etc.
        TITLE search mode treats the query as plain text to search for in title only
        TEXT search mode treats the query as plain text to search for in title and content

        Args:
            query: The search query (used for snippet generation)
            doc_ids: Ordered list of document IDs to generate hits for
            search_mode: Query parsing mode (for building the snippet query)
            rank_start: Starting rank value (1-based absolute position in the
                full result set; pass ``page_offset + 1`` for paginated calls)
            query: User's search query
            user: User for permission filtering (None for superuser/no filtering)
            page: Page number (1-indexed) for pagination
            page_size: Number of results per page
            sort_field: Field to sort by (None for relevance ranking)
            sort_reverse: Whether to reverse the sort order
            search_mode: "query" for advanced Tantivy syntax, "text" for
                plain-text search over title and content only, "title" for
                plain-text search over title only

        Returns:
            List of SearchHit dicts in the same order as doc_ids
            SearchResults with hits, total count, and processed query
        """
        if not doc_ids:
            return []

        self._ensure_open()
        user_query = self._parse_query(query, search_mode)
        tz = get_current_timezone()
        if search_mode is SearchMode.TEXT:
            user_query = parse_simple_text_query(self._index, query)
        elif search_mode is SearchMode.TITLE:
            user_query = parse_simple_title_query(self._index, query)
        else:
            user_query = parse_user_query(self._index, query, tz)

        # For notes_text snippet generation, we need a query that targets the
        # notes_text field directly. user_query may contain JSON-field terms
        # (e.g. notes.note:urgent) that the SnippetGenerator cannot resolve
        # against a text field. Strip field:value prefixes so bare terms like
        # "urgent" are re-parsed against notes_text, producing highlights even
        # when the original query used structured syntax.
        bare_query = re.sub(r"\w[\w.]*:", "", query).strip()
        try:
            notes_text_query = (
                self._index.parse_query(bare_query, ["notes_text"])
                if bare_query
                else user_query
            )
        except Exception:
            notes_text_query = user_query
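The `field:` prefix stripping above is plain `re` and can be checked on its own. A standalone sketch using the same pattern (the wrapper function name and sample query are hypothetical):

```python
import re


def strip_field_prefixes(query: str) -> str:
    # Remove "field:" / "nested.field:" prefixes so only bare terms
    # remain; those terms can then be re-parsed against a text field.
    return re.sub(r"\w[\w.]*:", "", query).strip()


bare = strip_field_prefixes("notes.note:urgent tag:inbox cupcake")
```

Here `notes.note:` and `tag:` are dropped, leaving the bare terms for the snippet query.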

        searcher = self._index.searcher()
        snippet_generator = None
        notes_snippet_generator = None
        hits: list[SearchHit] = []

        for rank, doc_id in enumerate(doc_ids, start=rank_start):
            # Look up document by ID, scoring against the user query so that
            # the returned SearchHit carries a real BM25 relevance score.
            id_query = tantivy.Query.range_query(
                self._schema,
                "id",
                tantivy.FieldType.Unsigned,
                doc_id,
                doc_id,
            )
            scored_query = tantivy.Query.boolean_query(
        # Apply permission filter if user is not None (not superuser)
        if user is not None:
            permission_filter = build_permission_filter(self._schema, user)
            final_query = tantivy.Query.boolean_query(
                [
                    (tantivy.Occur.Must, user_query),
                    (tantivy.Occur.Must, id_query),
                    (tantivy.Occur.Must, permission_filter),
                ],
            )
            results = searcher.search(scored_query, limit=1)
        else:
            final_query = user_query

            if not results.hits:
                continue
        searcher = self._index.searcher()
        offset = (page - 1) * page_size

            score, doc_address = results.hits[0]
        # Map sort fields
        sort_field_map = {
            "title": "title_sort",
            "correspondent__name": "correspondent_sort",
            "document_type__name": "type_sort",
            "created": "created",
            "added": "added",
            "modified": "modified",
            "archive_serial_number": "asn",
            "page_count": "page_count",
            "num_notes": "num_notes",
        }

        # Perform search
        if sort_field and sort_field in sort_field_map:
            mapped_field = sort_field_map[sort_field]
            results = searcher.search(
                final_query,
                limit=offset + page_size,
                order_by_field=mapped_field,
                order=tantivy.Order.Desc if sort_reverse else tantivy.Order.Asc,
            )
            # Field sorting: hits are still (score, DocAddress) tuples; score unused
            all_hits = [(hit[1], 0.0) for hit in results.hits]
        else:
            # Score-based search: hits are (score, DocAddress) tuples
            results = searcher.search(final_query, limit=offset + page_size)
            all_hits = [(hit[1], hit[0]) for hit in results.hits]

        total = results.count

        # Normalize scores for score-based searches
        if not sort_field and all_hits:
            max_score = max(hit[1] for hit in all_hits) or 1.0
            all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]

        # Apply threshold filter if configured (score-based search only)
        threshold = settings.ADVANCED_FUZZY_SEARCH_THRESHOLD
        if threshold is not None and not sort_field:
            all_hits = [hit for hit in all_hits if hit[1] >= threshold]

        # Get the page's hits
        page_hits = all_hits[offset : offset + page_size]
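The normalize, threshold, and page-slice steps above can be sketched with plain tuples, independent of Tantivy (function and data are illustrative, not project code):

```python
def rank_page(raw_hits, *, threshold=None, page=1, page_size=25):
    # raw_hits: (doc_address, raw_score) pairs, best first.
    # Normalize by the max score so the top hit scores 1.0, drop hits
    # below the threshold, then slice out the requested page.
    if raw_hits:
        max_score = max(score for _, score in raw_hits) or 1.0
        raw_hits = [(d, s / max_score) for d, s in raw_hits]
    if threshold is not None:
        raw_hits = [(d, s) for d, s in raw_hits if s >= threshold]
    offset = (page - 1) * page_size
    return raw_hits[offset : offset + page_size]


page = rank_page(
    [("a", 4.0), ("b", 2.0), ("c", 1.0), ("d", 0.4)],
    threshold=0.25,
    page=1,
    page_size=2,
)
```

With a max score of 4.0 the normalized scores are 1.0, 0.5, 0.25, 0.1; the threshold drops the last hit and the page slice keeps the first two.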
|
||||
|
||||
# Build result hits with highlights
|
||||
hits: list[SearchHit] = []
|
||||
snippet_generator = None
|
||||
notes_snippet_generator = None
|
||||
|
||||
for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
|
||||
# Get the actual document from the searcher using the doc address
|
||||
actual_doc = searcher.doc(doc_address)
|
||||
doc_dict = actual_doc.to_dict()
|
||||
doc_id = doc_dict["id"][0]
|
||||
|
||||
highlights: dict[str, str] = {}
|
||||
try:
|
||||
if snippet_generator is None:
|
||||
snippet_generator = tantivy.SnippetGenerator.create(
|
||||
searcher,
|
||||
user_query,
|
||||
self._schema,
|
||||
"content",
|
||||
)
|
||||
|
||||
content_html = _render_snippet_html(
|
||||
snippet_generator.snippet_from_doc(actual_doc),
|
||||
)
|
||||
if content_html:
|
||||
highlights["content"] = content_html
|
||||
|
||||
if "notes_text" in doc_dict:
|
||||
# Use notes_text (plain text) for snippet generation — tantivy's
|
||||
# SnippetGenerator does not support JSON fields.
|
||||
if notes_snippet_generator is None:
|
||||
notes_snippet_generator = tantivy.SnippetGenerator.create(
|
||||
# Generate highlights if score > 0
|
||||
if score > 0:
|
||||
try:
|
||||
if snippet_generator is None:
|
||||
snippet_generator = tantivy.SnippetGenerator.create(
|
||||
searcher,
|
||||
notes_text_query,
|
||||
final_query,
|
||||
self._schema,
|
||||
"notes_text",
|
||||
"content",
|
||||
)
|
||||
notes_html = _render_snippet_html(
|
||||
notes_snippet_generator.snippet_from_doc(actual_doc),
|
||||
)
|
||||
if notes_html:
|
||||
highlights["notes"] = notes_html
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
logger.debug("Failed to generate highlights for doc %s", doc_id)
|
||||
content_snippet = snippet_generator.snippet_from_doc(actual_doc)
|
||||
if content_snippet:
|
||||
highlights["content"] = str(content_snippet)
|
||||
|
||||
# Try notes highlights
|
||||
if "notes" in doc_dict:
|
||||
if notes_snippet_generator is None:
|
||||
notes_snippet_generator = tantivy.SnippetGenerator.create(
|
||||
searcher,
|
||||
final_query,
|
||||
self._schema,
|
||||
"notes",
|
||||
)
|
||||
notes_snippet = notes_snippet_generator.snippet_from_doc(
|
||||
actual_doc,
|
||||
)
|
||||
if notes_snippet:
|
||||
highlights["notes"] = str(notes_snippet)
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
logger.debug("Failed to generate highlights for doc %s", doc_id)
|
||||
|
||||
hits.append(
|
||||
SearchHit(
|
||||
@@ -671,69 +584,11 @@ class TantivyBackend:
|
||||
),
|
||||
)
|
||||
|
||||
return hits
|
||||
|
||||
    def search_ids(
        self,
        query: str,
        user: AbstractBaseUser | None,
        *,
        sort_field: str | None = None,
        sort_reverse: bool = False,
        search_mode: SearchMode = SearchMode.QUERY,
        limit: int | None = None,
    ) -> list[int]:
        """
        Return document IDs matching a query — no highlights or scores.

        This is the lightweight companion to search(). Use it when you need the
        full set of matching IDs (e.g. for ``selection_data``) but don't need
        scores, ranks, or highlights.

        Args:
            query: User's search query
            user: User for permission filtering (None for superuser/no filtering)
            sort_field: Field to sort by (None for relevance ranking)
            sort_reverse: Whether to reverse the sort order
            search_mode: Query parsing mode (QUERY, TEXT, or TITLE)
            limit: Maximum number of IDs to return (None = all matching docs)

        Returns:
            List of document IDs in the requested order
        """
        self._ensure_open()
        user_query = self._parse_query(query, search_mode)
        final_query = self._apply_permission_filter(user_query, user)

        searcher = self._index.searcher()
        effective_limit = limit if limit is not None else searcher.num_docs

        if sort_field and sort_field in self.SORT_FIELD_MAP:
            mapped_field = self.SORT_FIELD_MAP[sort_field]
            results = searcher.search(
                final_query,
                limit=effective_limit,
                order_by_field=mapped_field,
                order=tantivy.Order.Desc if sort_reverse else tantivy.Order.Asc,
            )
            all_hits = [(hit[1],) for hit in results.hits]
        else:
            results = searcher.search(final_query, limit=effective_limit)
            all_hits = [(hit[1], hit[0]) for hit in results.hits]

            # Normalize scores and apply threshold (relevance search only)
            if all_hits:
                max_score = max(hit[1] for hit in all_hits) or 1.0
                all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]

                threshold = settings.ADVANCED_FUZZY_SEARCH_THRESHOLD
                if threshold is not None:
                    all_hits = [hit for hit in all_hits if hit[1] >= threshold]

        # TODO: Replace with searcher.collect_u64_fast_field("id", addrs) once
        # https://github.com/quickwit-oss/tantivy-py/pull/641 lands — eliminates
        # one stored-doc fetch per result (~80% reduction in search_ids latency).
        return [searcher.doc(doc_addr).to_dict()["id"][0] for doc_addr, *_ in all_hits]
        return SearchResults(
            hits=hits,
            total=total,
            query=query,
        )
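The pagination arithmetic shared by search() and more_like_this() can be sketched on its own. This is a minimal sketch with a plain list standing in for ranked hits: tantivy is asked for the top `offset + page_size` results, and the requested page is then sliced out.

```python
def page_slice(all_hits: list[int], page: int, page_size: int) -> list[int]:
    """Return the 1-indexed `page` of `all_hits`, `page_size` items per page."""
    offset = (page - 1) * page_size
    # Slicing past the end yields a shorter (or empty) page rather than an error
    return all_hits[offset : offset + page_size]

print(page_slice(list(range(25)), page=3, page_size=10))
# → [20, 21, 22, 23, 24]
```

Requesting only `offset + page_size` hits from the searcher keeps the fetch proportional to the page being viewed instead of the whole result set.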
    def autocomplete(
        self,
@@ -748,10 +603,6 @@ class TantivyBackend:
        frequency (how many documents contain each word). Optionally filters
        results to only words from documents visible to the specified user.

        NOTE: This is the hottest search path (called per keystroke).
        A future improvement would be to cache results in Redis, keyed by
        (prefix, user_id), and invalidate on index writes.

        Args:
            term: Prefix to match against autocomplete words
            limit: Maximum number of suggestions to return
@@ -762,94 +613,64 @@ class TantivyBackend:
        """
        self._ensure_open()
        normalized_term = ascii_fold(term.lower())
        if not normalized_term:
            return []

        searcher = self._index.searcher()

        # Build a prefix query on autocomplete_word so we only scan docs
        # containing words that start with the prefix, not the entire index.
        # tantivy regex is implicitly anchored; .+ avoids the empty-match
        # error that .* triggers. We OR with term_query to also match the
        # exact prefix as a complete word.
        escaped = re.escape(normalized_term)
        prefix_query = tantivy.Query.boolean_query(
            [
                (
                    tantivy.Occur.Should,
                    tantivy.Query.term_query(
                        self._schema,
                        "autocomplete_word",
                        normalized_term,
                    ),
                ),
                (
                    tantivy.Occur.Should,
                    tantivy.Query.regex_query(
                        self._schema,
                        "autocomplete_word",
                        f"{escaped}.+",
                    ),
                ),
            ],
        )

        # Intersect with permission filter so autocomplete words from
        # invisible documents don't leak to other users.
        # Apply permission filter for non-superusers so autocomplete words
        # from invisible documents don't leak to other users.
        if user is not None and not user.is_superuser:
            final_query = tantivy.Query.boolean_query(
                [
                    (tantivy.Occur.Must, prefix_query),
                    (tantivy.Occur.Must, build_permission_filter(self._schema, user)),
                ],
            )
            base_query = build_permission_filter(self._schema, user)
        else:
            final_query = prefix_query
            base_query = tantivy.Query.all_query()

        results = searcher.search(final_query, limit=searcher.num_docs)
        results = searcher.search(base_query, limit=10000)

        # Count how many visible documents each matching word appears in.
        # Count how many visible documents each word appears in.
        # Using Counter (not set) preserves per-word document frequency so
        # we can rank suggestions by how commonly they occur — the same
        # signal Whoosh used for Tf/Idf-based autocomplete ordering.
        word_counts: Counter[str] = Counter()
        for _score, doc_address in results.hits:
            stored_doc = searcher.doc(doc_address)
            doc_dict = stored_doc.to_dict()
            if "autocomplete_word" in doc_dict:
                for word in doc_dict["autocomplete_word"]:
                    if word.startswith(normalized_term):
                        word_counts[word] += 1
                word_counts.update(doc_dict["autocomplete_word"])

        # Sort by document frequency descending; break ties alphabetically.
        # Filter to prefix matches, sort by document frequency descending;
        # break ties alphabetically for stable, deterministic output.
        matches = sorted(
            word_counts,
            (w for w in word_counts if w.startswith(normalized_term)),
            key=lambda w: (-word_counts[w], w),
        )

        return matches[:limit]
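The Counter-based ranking above can be sketched standalone. This is a minimal sketch, assuming each inner list holds the de-duplicated autocomplete words of one visible document (names here are illustrative): count document frequency per word, keep prefix matches, then sort by frequency descending with an alphabetical tiebreak.

```python
from collections import Counter

def rank_suggestions(
    docs_words: list[list[str]],
    prefix: str,
    limit: int,
) -> list[str]:
    """Rank autocomplete words by how many documents contain them."""
    word_counts: Counter[str] = Counter()
    for words in docs_words:
        # Each occurrence counts once per document (words assumed unique per doc)
        word_counts.update(words)
    matches = sorted(
        (w for w in word_counts if w.startswith(prefix)),
        # Negate the count so higher frequency sorts first; ties fall back
        # to alphabetical order for deterministic output
        key=lambda w: (-word_counts[w], w),
    )
    return matches[:limit]

print(rank_suggestions([["invoice", "internet"], ["invoice"], ["inbox"]], "in", 2))
# → ['invoice', 'inbox']
```

"invoice" appears in two documents so it outranks "inbox" and "internet" (one document each), and "inbox" wins the tie alphabetically.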
    def more_like_this_ids(
    def more_like_this(
        self,
        doc_id: int,
        user: AbstractBaseUser | None,
        *,
        limit: int | None = None,
    ) -> list[int]:
        page: int,
        page_size: int,
    ) -> SearchResults:
        """
        Return IDs of documents similar to the given document — no highlights.
        Find documents similar to the given document using content analysis.

        Lightweight companion to more_like_this(). The original document is
        excluded from results.
        Uses Tantivy's "more like this" query to find documents with similar
        content patterns. The original document is excluded from results.

        Args:
            doc_id: Primary key of the reference document
            user: User for permission filtering (None for no filtering)
            limit: Maximum number of IDs to return (None = all matching docs)
            page: Page number (1-indexed) for pagination
            page_size: Number of results per page

        Returns:
            List of similar document IDs (excluding the original)
            SearchResults with similar documents (excluding the original)
        """
        self._ensure_open()
        searcher = self._index.searcher()

        # First find the document address
        id_query = tantivy.Query.range_query(
            self._schema,
            "id",
@@ -860,9 +681,13 @@ class TantivyBackend:
        results = searcher.search(id_query, limit=1)

        if not results.hits:
            return []
            # Document not found
            return SearchResults(hits=[], total=0, query=f"more_like:{doc_id}")

        # Extract doc_address from (score, doc_address) tuple
        doc_address = results.hits[0][1]

        # Build more like this query
        mlt_query = tantivy.Query.more_like_this_query(
            doc_address,
            min_doc_frequency=1,
@@ -874,21 +699,59 @@ class TantivyBackend:
            boost_factor=None,
        )

        final_query = self._apply_permission_filter(mlt_query, user)
        # Apply permission filter
        if user is not None:
            permission_filter = build_permission_filter(self._schema, user)
            final_query = tantivy.Query.boolean_query(
                [
                    (tantivy.Occur.Must, mlt_query),
                    (tantivy.Occur.Must, permission_filter),
                ],
            )
        else:
            final_query = mlt_query

        effective_limit = limit if limit is not None else searcher.num_docs
        # Fetch one extra to account for excluding the original document
        results = searcher.search(final_query, limit=effective_limit + 1)
        # Search
        offset = (page - 1) * page_size
        results = searcher.search(final_query, limit=offset + page_size)

        # TODO: Replace with collect_u64_fast_field("id", addrs) once
        # https://github.com/quickwit-oss/tantivy-py/pull/641 lands.
        ids = []
        for _score, doc_address in results.hits:
            result_doc_id = searcher.doc(doc_address).to_dict()["id"][0]
            if result_doc_id != doc_id:
                ids.append(result_doc_id)
        total = results.count
        # Convert from (score, doc_address) to (doc_address, score)
        all_hits = [(hit[1], hit[0]) for hit in results.hits]

        return ids[:limit] if limit is not None else ids
        # Normalize scores
        if all_hits:
            max_score = max(hit[1] for hit in all_hits) or 1.0
            all_hits = [(hit[0], hit[1] / max_score) for hit in all_hits]

        # Get page hits
        page_hits = all_hits[offset : offset + page_size]

        # Build results
        hits: list[SearchHit] = []
        for rank, (doc_address, score) in enumerate(page_hits, start=offset + 1):
            actual_doc = searcher.doc(doc_address)
            doc_dict = actual_doc.to_dict()
            result_doc_id = doc_dict["id"][0]

            # Skip the original document
            if result_doc_id == doc_id:
                continue

            hits.append(
                SearchHit(
                    id=result_doc_id,
                    score=score,
                    rank=rank,
                    highlights={},  # MLT doesn't generate highlights
                ),
            )

        return SearchResults(
            hits=hits,
            total=max(0, total - 1),  # Subtract 1 for the original document
            query=f"more_like:{doc_id}",
        )
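The original-document exclusion in more_like_this can be sketched on its own. This is a minimal sketch with hypothetical ID lists (not the backend's API): a "more like this" query matches the reference document itself, so it is skipped and the reported total reduced by one.

```python
def exclude_original(ids: list[int], original: int) -> tuple[list[int], int]:
    """Drop the reference doc from a candidate list and adjust the total."""
    kept = [i for i in ids if i != original]
    # The reference document always matches itself, so the raw total is one
    # too high; clamp at zero for the not-found case
    adjusted_total = max(0, len(ids) - 1)
    return kept, adjusted_total

print(exclude_original([50, 51, 52], original=50))
# → ([51, 52], 2)
```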
    def batch_update(self, lock_timeout: float = 30.0) -> WriteBatch:
        """

@@ -396,17 +396,10 @@ def build_permission_filter(
        Tantivy query that filters results to visible documents

    Implementation Notes:
        - Uses range_query instead of term_query for owner_id/viewer_id to work
          around a tantivy-py bug where Python ints are inferred as i64, causing
          term_query to return no hits on u64 fields.
          TODO: Replace with term_query once
          https://github.com/quickwit-oss/tantivy-py/pull/642 lands.

        - Uses range_query(owner_id, 1, MAX_U64) as an "owner exists" check
          because exists_query is not yet available in tantivy-py 0.25.
          TODO: Replace with exists_query("owner_id") once that is exposed in
          a tantivy-py release.

        - Uses range_query instead of term_query to work around unsigned integer
          type detection bug in tantivy-py 0.25
        - Uses boolean_query for "no owner" check since exists_query is not
          available in tantivy-py 0.25.1 (available in master)
        - Uses disjunction_max_query to combine permission clauses with OR logic
    """
    owner_any = tantivy.Query.range_query(

@@ -72,9 +72,6 @@ def build_schema() -> tantivy.Schema:

    # JSON fields — structured queries: notes.user:alice, custom_fields.name:invoice
    sb.add_json_field("notes", stored=True, tokenizer_name="paperless_text")
    # Plain-text companion for notes — tantivy's SnippetGenerator does not support
    # JSON fields, so highlights require a text field with the same content.
    sb.add_text_field("notes_text", stored=True, tokenizer_name="paperless_text")
    sb.add_json_field("custom_fields", stored=True, tokenizer_name="paperless_text")

    for field in (
@@ -33,12 +33,19 @@ class TestWriteBatch:
        except RuntimeError:
            pass

        ids = backend.search_ids("should survive", user=None)
        assert len(ids) == 1
        r = backend.search(
            "should survive",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
        )
        assert r.total == 1


class TestSearch:
    """Test search query parsing and matching via search_ids."""
    """Test search functionality."""

    def test_text_mode_limits_default_search_to_title_and_content(
        self,
@@ -53,20 +60,27 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(
                backend.search_ids(
                    "document_type:invoice",
                    user=None,
                    search_mode=SearchMode.TEXT,
                ),
            )
            == 0
        metadata_only = backend.search(
            "document_type:invoice",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert (
            len(backend.search_ids("monthly", user=None, search_mode=SearchMode.TEXT))
            == 1
        assert metadata_only.total == 0

        content_match = backend.search(
            "monthly",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert content_match.total == 1

    def test_title_mode_limits_default_search_to_title_only(
        self,
@@ -81,14 +95,27 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(backend.search_ids("monthly", user=None, search_mode=SearchMode.TITLE))
            == 0
        content_only = backend.search(
            "monthly",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TITLE,
        )
        assert (
            len(backend.search_ids("invoice", user=None, search_mode=SearchMode.TITLE))
            == 1
        assert content_only.total == 0

        title_match = backend.search(
            "invoice",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TITLE,
        )
        assert title_match.total == 1

    def test_text_mode_matches_partial_term_substrings(
        self,
@@ -103,16 +130,38 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(backend.search_ids("pass", user=None, search_mode=SearchMode.TEXT)) == 1
        prefix_match = backend.search(
            "pass",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert (
            len(backend.search_ids("sswo", user=None, search_mode=SearchMode.TEXT)) == 1
        assert prefix_match.total == 1

        infix_match = backend.search(
            "sswo",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert (
            len(backend.search_ids("sswo re", user=None, search_mode=SearchMode.TEXT))
            == 1
        assert infix_match.total == 1

        phrase_match = backend.search(
            "sswo re",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert phrase_match.total == 1

    def test_text_mode_does_not_match_on_partial_term_overlap(
        self,
@@ -127,10 +176,16 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(backend.search_ids("raptor", user=None, search_mode=SearchMode.TEXT))
            == 0
        non_match = backend.search(
            "raptor",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert non_match.total == 0

    def test_text_mode_anchors_later_query_tokens_to_token_starts(
        self,
@@ -159,9 +214,16 @@ class TestSearch:
        backend.add_or_update(prefix_doc)
        backend.add_or_update(false_positive)

        result_ids = set(
            backend.search_ids("Z-Berichte 6", user=None, search_mode=SearchMode.TEXT),
        results = backend.search(
            "Z-Berichte 6",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        result_ids = {hit["id"] for hit in results.hits}

        assert exact_doc.id in result_ids
        assert prefix_doc.id in result_ids
@@ -180,9 +242,16 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(backend.search_ids("!!!", user=None, search_mode=SearchMode.TEXT)) == 0
        no_tokens = backend.search(
            "!!!",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TEXT,
        )
        assert no_tokens.total == 0

    def test_title_mode_matches_partial_term_substrings(
        self,
@@ -197,18 +266,59 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        assert (
            len(backend.search_ids("pass", user=None, search_mode=SearchMode.TITLE))
            == 1
        prefix_match = backend.search(
            "pass",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TITLE,
        )
        assert (
            len(backend.search_ids("sswo", user=None, search_mode=SearchMode.TITLE))
            == 1
        assert prefix_match.total == 1

        infix_match = backend.search(
            "sswo",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TITLE,
        )
        assert (
            len(backend.search_ids("sswo gu", user=None, search_mode=SearchMode.TITLE))
            == 1
        assert infix_match.total == 1

        phrase_match = backend.search(
            "sswo gu",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
            search_mode=SearchMode.TITLE,
        )
        assert phrase_match.total == 1

    def test_scores_normalised_top_hit_is_one(self, backend: TantivyBackend):
        """Search scores must be normalized so top hit has score 1.0 for UI consistency."""
        for i, title in enumerate(["bank invoice", "bank statement", "bank receipt"]):
            doc = Document.objects.create(
                title=title,
                content=title,
                checksum=f"SN{i}",
                pk=10 + i,
            )
            backend.add_or_update(doc)
        r = backend.search(
            "bank",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
        )
        assert r.hits[0]["score"] == pytest.approx(1.0)
        assert all(0.0 <= h["score"] <= 1.0 for h in r.hits)

    def test_sort_field_ascending(self, backend: TantivyBackend):
        """Searching with sort_reverse=False must return results in ascending ASN order."""
@@ -221,14 +331,16 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        ids = backend.search_ids(
        r = backend.search(
            "sortable",
            user=None,
            page=1,
            page_size=10,
            sort_field="archive_serial_number",
            sort_reverse=False,
        )
        assert len(ids) == 3
        asns = [Document.objects.get(pk=doc_id).archive_serial_number for doc_id in ids]
        assert r.total == 3
        asns = [Document.objects.get(pk=h["id"]).archive_serial_number for h in r.hits]
        assert asns == [10, 20, 30]

    def test_sort_field_descending(self, backend: TantivyBackend):
@@ -242,91 +354,79 @@ class TestSearch:
        )
        backend.add_or_update(doc)

        ids = backend.search_ids(
        r = backend.search(
            "sortable",
            user=None,
            page=1,
            page_size=10,
            sort_field="archive_serial_number",
            sort_reverse=True,
        )
        assert len(ids) == 3
        asns = [Document.objects.get(pk=doc_id).archive_serial_number for doc_id in ids]
        assert r.total == 3
        asns = [Document.objects.get(pk=h["id"]).archive_serial_number for h in r.hits]
        assert asns == [30, 20, 10]

class TestSearchIds:
    """Test lightweight ID-only search."""

    def test_returns_matching_ids(self, backend: TantivyBackend):
        """search_ids must return IDs of all matching documents."""
        docs = []
        for i in range(5):
            doc = Document.objects.create(
                title=f"findable doc {i}",
                content="common keyword",
                checksum=f"SI{i}",
            )
            backend.add_or_update(doc)
            docs.append(doc)
        other = Document.objects.create(
            title="unrelated",
            content="nothing here",
            checksum="SI_other",
        )
        backend.add_or_update(other)

        ids = backend.search_ids(
            "common keyword",
            user=None,
            search_mode=SearchMode.QUERY,
        )
        assert set(ids) == {d.pk for d in docs}
        assert other.pk not in ids

    def test_respects_permission_filter(self, backend: TantivyBackend):
        """search_ids must respect user permission filtering."""
        owner = User.objects.create_user("ids_owner")
        other = User.objects.create_user("ids_other")
    def test_fuzzy_threshold_filters_low_score_hits(
        self,
        backend: TantivyBackend,
        settings,
    ):
        """When ADVANCED_FUZZY_SEARCH_THRESHOLD exceeds all normalized scores, hits must be filtered out."""
        doc = Document.objects.create(
            title="private doc",
            content="secret keyword",
            checksum="SIP1",
            title="Invoice document",
            content="financial report",
            checksum="FT1",
            pk=120,
        )
        backend.add_or_update(doc)

        # Threshold above 1.0 filters every hit (normalized scores top out at 1.0)
        settings.ADVANCED_FUZZY_SEARCH_THRESHOLD = 1.1
        r = backend.search(
            "invoice",
            user=None,
            page=1,
            page_size=10,
            sort_field=None,
            sort_reverse=False,
        )
        assert r.hits == []

    def test_owner_filter(self, backend: TantivyBackend):
        """Document owners can search their private documents; other users cannot access them."""
        owner = User.objects.create_user("owner")
        other = User.objects.create_user("other")
        doc = Document.objects.create(
            title="Private",
            content="secret",
            checksum="PF1",
            pk=20,
            owner=owner,
        )
        backend.add_or_update(doc)

        assert backend.search_ids(
            "secret",
            user=owner,
            search_mode=SearchMode.QUERY,
        ) == [doc.pk]
        assert (
            backend.search_ids("secret", user=other, search_mode=SearchMode.QUERY) == []
            backend.search(
                "secret",
                user=owner,
                page=1,
                page_size=10,
                sort_field=None,
                sort_reverse=False,
            ).total
            == 1
        )

    def test_respects_fuzzy_threshold(self, backend: TantivyBackend, settings):
        """search_ids must apply the same fuzzy threshold as search()."""
        doc = Document.objects.create(
            title="threshold test",
            content="unique term",
            checksum="SIT1",
        assert (
            backend.search(
                "secret",
                user=other,
                page=1,
                page_size=10,
                sort_field=None,
                sort_reverse=False,
            ).total
            == 0
        )
        backend.add_or_update(doc)

        settings.ADVANCED_FUZZY_SEARCH_THRESHOLD = 1.1
        ids = backend.search_ids("unique", user=None, search_mode=SearchMode.QUERY)
        assert ids == []

    def test_returns_ids_for_text_mode(self, backend: TantivyBackend):
        """search_ids must work with TEXT search mode."""
        doc = Document.objects.create(
            title="text mode doc",
            content="findable phrase",
            checksum="SIM1",
        )
        backend.add_or_update(doc)

        ids = backend.search_ids("findable", user=None, search_mode=SearchMode.TEXT)
        assert ids == [doc.pk]

class TestRebuild:
@@ -390,26 +490,57 @@ class TestAutocomplete:
class TestMoreLikeThis:
    """Test more like this functionality."""

    def test_more_like_this_ids_excludes_original(self, backend: TantivyBackend):
        """more_like_this_ids must return IDs of similar documents, excluding the original."""
    def test_excludes_original(self, backend: TantivyBackend):
        """More like this queries must exclude the reference document from results."""
        doc1 = Document.objects.create(
            title="Important document",
            content="financial information report",
            checksum="MLTI1",
            pk=150,
            content="financial information",
            checksum="MLT1",
            pk=50,
        )
        doc2 = Document.objects.create(
            title="Another document",
            content="financial information report",
            checksum="MLTI2",
            pk=151,
            content="financial report",
            checksum="MLT2",
            pk=51,
        )
        backend.add_or_update(doc1)
        backend.add_or_update(doc2)

        ids = backend.more_like_this_ids(doc_id=150, user=None)
        assert 150 not in ids
        assert 151 in ids
        results = backend.more_like_this(doc_id=50, user=None, page=1, page_size=10)
        returned_ids = [hit["id"] for hit in results.hits]
        assert 50 not in returned_ids  # Original document excluded

    def test_with_user_applies_permission_filter(self, backend: TantivyBackend):
        """more_like_this with a user must exclude documents that user cannot see."""
        viewer = User.objects.create_user("mlt_viewer")
        other = User.objects.create_user("mlt_other")
        public_doc = Document.objects.create(
            title="Public financial document",
            content="quarterly financial analysis report figures",
            checksum="MLT3",
            pk=52,
        )
        private_doc = Document.objects.create(
            title="Private financial document",
            content="quarterly financial analysis report figures",
            checksum="MLT4",
            pk=53,
            owner=other,
        )
        backend.add_or_update(public_doc)
        backend.add_or_update(private_doc)

        results = backend.more_like_this(doc_id=52, user=viewer, page=1, page_size=10)
        returned_ids = [hit["id"] for hit in results.hits]
        # private_doc is owned by other, so viewer cannot see it
        assert 53 not in returned_ids

    def test_document_not_in_index_returns_empty(self, backend: TantivyBackend):
        """more_like_this for a doc_id absent from the index must return empty results."""
        results = backend.more_like_this(doc_id=9999, user=None, page=1, page_size=10)
        assert results.hits == []
        assert results.total == 0

class TestSingleton:
|
||||
@@ -462,10 +593,19 @@ class TestFieldHandling:
|
||||
# Should not raise an exception
|
||||
backend.add_or_update(doc)
|
||||
|
||||
assert len(backend.search_ids("test", user=None)) == 1
|
||||
results = backend.search(
|
||||
"test",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 1
|
||||
|
||||
def test_custom_fields_include_name_and_value(self, backend: TantivyBackend):
|
||||
"""Custom fields must be indexed with both field name and value for structured queries."""
|
||||
# Create a custom field
|
||||
field = CustomField.objects.create(
|
||||
name="Invoice Number",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
@@ -482,9 +622,18 @@ class TestFieldHandling:
|
||||
value_text="INV-2024-001",
|
||||
)
|
||||
|
||||
# Should not raise an exception during indexing
|
||||
backend.add_or_update(doc)
|
||||
|
||||
assert len(backend.search_ids("invoice", user=None)) == 1
|
||||
results = backend.search(
|
||||
"invoice",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 1
|
||||
|
||||
def test_select_custom_field_indexes_label_not_id(self, backend: TantivyBackend):
|
||||
"""SELECT custom fields must index the human-readable label, not the opaque option ID."""
|
||||
@@ -511,8 +660,27 @@ class TestFieldHandling:
|
||||
)
|
||||
backend.add_or_update(doc)
|
||||
|
||||
assert len(backend.search_ids("custom_fields.value:invoice", user=None)) == 1
|
||||
assert len(backend.search_ids("custom_fields.value:opt_abc", user=None)) == 0
|
||||
# Label should be findable
|
||||
results = backend.search(
|
||||
"custom_fields.value:invoice",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 1
|
||||
|
||||
# Opaque ID must not appear in the index
|
||||
results = backend.search(
|
||||
"custom_fields.value:opt_abc",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 0
|
||||
|
||||
def test_none_custom_field_value_not_indexed(self, backend: TantivyBackend):
|
||||
"""Custom field instances with no value set must not produce an index entry."""
|
||||
@@ -534,7 +702,16 @@ class TestFieldHandling:
|
||||
)
|
||||
backend.add_or_update(doc)
|
||||
|
||||
assert len(backend.search_ids("custom_fields.value:none", user=None)) == 0
|
||||
# The string "none" must not appear as an indexed value
|
||||
results = backend.search(
|
||||
"custom_fields.value:none",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 0
|
||||
|
||||
def test_notes_include_user_information(self, backend: TantivyBackend):
|
||||
"""Notes must be indexed with user information when available for structured queries."""
|
||||
@@ -547,101 +724,32 @@ class TestFieldHandling:
|
||||
)
|
||||
Note.objects.create(document=doc, note="Important note", user=user)
|
||||
|
||||
# Should not raise an exception during indexing
|
||||
backend.add_or_update(doc)
|
||||
|
||||
ids = backend.search_ids("test", user=None)
|
||||
assert len(ids) == 1, (
|
||||
f"Expected 1, got {len(ids)}. Document content should be searchable."
|
||||
# Test basic document search first
|
||||
results = backend.search(
|
||||
"test",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
assert results.total == 1, (
|
||||
f"Expected 1, got {results.total}. Document content should be searchable."
|
||||
)
|
||||
|
||||
ids = backend.search_ids("notes.note:important", user=None)
|
||||
assert len(ids) == 1, (
|
||||
f"Expected 1, got {len(ids)}. Note content should be searchable via notes.note: prefix."
|
||||
# Test notes search — must use structured JSON syntax now that note
|
||||
# is no longer in DEFAULT_SEARCH_FIELDS
|
||||
results = backend.search(
|
||||
"notes.note:important",
|
||||
user=None,
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_field=None,
|
||||
sort_reverse=False,
|
||||
)
|
||||
|
||||
|
||||
class TestHighlightHits:
    """Test highlight_hits returns proper HTML strings, not raw Snippet objects."""

    def test_highlights_content_returns_match_span_html(
        self,
        backend: TantivyBackend,
    ):
        """highlight_hits must return frontend-ready highlight spans."""
        doc = Document.objects.create(
            title="Highlight Test",
            content="The quick brown fox jumps over the lazy dog",
            checksum="HH1",
            pk=90,
        assert results.total == 1, (
            f"Expected 1, got {results.total}. Note content should be searchable via notes.note: prefix."
        )
        backend.add_or_update(doc)

        hits = backend.highlight_hits("quick", [doc.pk])

        assert len(hits) == 1
        highlights = hits[0]["highlights"]
        assert "content" in highlights
        content_highlight = highlights["content"]
        assert isinstance(content_highlight, str), (
            f"Expected str, got {type(content_highlight)}: {content_highlight!r}"
        )
        assert '<span class="match">' in content_highlight, (
            f"Expected HTML with match span, got: {content_highlight!r}"
        )

    def test_highlights_notes_returns_match_span_html(
        self,
        backend: TantivyBackend,
    ):
        """Note highlights must be frontend-ready HTML via the notes_text companion field.

        The notes JSON field does not support tantivy's SnippetGenerator, so
        the plain-text notes_text field is used instead. The query uses the
        notes.note: prefix so it targets note content directly; the snippet
        is generated from notes_text, which stores the same text.
        """
        user = User.objects.create_user("hl_noteuser")
        doc = Document.objects.create(
            title="Doc with matching note",
            content="unrelated content",
            checksum="HH2",
            pk=91,
        )
        Note.objects.create(document=doc, note="urgent payment required", user=user)
        backend.add_or_update(doc)

        # Use notes.note: prefix so the document matches the query and the
        # notes_text snippet generator can produce highlights.
        hits = backend.highlight_hits("notes.note:urgent", [doc.pk])

        assert len(hits) == 1
        highlights = hits[0]["highlights"]
        assert "notes" in highlights
        note_highlight = highlights["notes"]
        assert isinstance(note_highlight, str), (
            f"Expected str, got {type(note_highlight)}: {note_highlight!r}"
        )
        assert '<span class="match">' in note_highlight, (
            f"Expected HTML with match span, got: {note_highlight!r}"
        )

    def test_empty_doc_list_returns_empty_hits(self, backend: TantivyBackend):
        """highlight_hits with no doc IDs must return an empty list."""
        hits = backend.highlight_hits("anything", [])
        assert hits == []

    def test_no_highlights_when_no_match(self, backend: TantivyBackend):
        """Documents not matching the query should not appear in results."""
        doc = Document.objects.create(
            title="Unrelated",
            content="completely different text",
            checksum="HH3",
            pk=92,
        )
        backend.add_or_update(doc)

        hits = backend.highlight_hits("quick", [doc.pk])

        assert len(hits) == 0

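The assertions above expect `highlight_hits` to hand the frontend ready-made HTML rather than raw tantivy Snippet objects. A minimal sketch of that conversion, where the `(start, end)` ranges and the `render_highlight` helper are hypothetical stand-ins for the offsets a tantivy snippet exposes, not the backend's actual API:

```python
import html


def render_highlight(text: str, ranges: list[tuple[int, int]]) -> str:
    """Wrap each matched (start, end) range in a <span class="match"> element.

    Hypothetical sketch: the real backend walks a tantivy Snippet object;
    here plain integer ranges stand in for its highlighted offsets.
    """
    parts: list[str] = []
    cursor = 0
    for start, end in sorted(ranges):
        parts.append(html.escape(text[cursor:start]))
        parts.append(f'<span class="match">{html.escape(text[start:end])}</span>')
        cursor = end
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


# "quick" occupies characters 4..9 of the test sentence used above.
assert render_highlight("The quick brown fox", [(4, 9)]) == (
    'The <span class="match">quick</span> brown fox'
)
```

Escaping the non-matched text (not only the match) matters: otherwise document content could smuggle markup into the frontend alongside the legitimate span.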
@@ -6,6 +6,8 @@ from unittest.mock import patch
from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from rest_framework import status
from rest_framework.test import APITestCase

@@ -201,6 +203,156 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
        )
        self.assertFalse(Path(old_logo.path).exists())

    def test_api_strips_exif_data_from_uploaded_logo(self) -> None:
        """
        GIVEN:
            - A JPEG logo upload containing EXIF metadata
        WHEN:
            - Uploaded via PATCH to app config
        THEN:
            - Stored logo image has EXIF metadata removed
        """
        image = Image.new("RGB", (12, 12), "blue")
        exif = Image.Exif()
        exif[315] = "Paperless Test Author"

        logo = BytesIO()
        image.save(logo, format="JPEG", exif=exif)
        logo.seek(0)

        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            {
                "app_logo": SimpleUploadedFile(
                    name="logo-with-exif.jpg",
                    content=logo.getvalue(),
                    content_type="image/jpeg",
                ),
            },
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        config = ApplicationConfiguration.objects.first()
        with Image.open(config.app_logo.path) as stored_logo:
            stored_exif = stored_logo.getexif()

        self.assertEqual(len(stored_exif), 0)

    def test_api_strips_png_metadata_from_uploaded_logo(self) -> None:
        """
        GIVEN:
            - A PNG logo upload containing text metadata
        WHEN:
            - Uploaded via PATCH to app config
        THEN:
            - Stored logo image has metadata removed
        """
        image = Image.new("RGB", (12, 12), "green")
        pnginfo = PngInfo()
        pnginfo.add_text("Author", "Paperless Test Author")

        logo = BytesIO()
        image.save(logo, format="PNG", pnginfo=pnginfo)
        logo.seek(0)

        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            {
                "app_logo": SimpleUploadedFile(
                    name="logo-with-metadata.png",
                    content=logo.getvalue(),
                    content_type="image/png",
                ),
            },
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        config = ApplicationConfiguration.objects.first()
        with Image.open(config.app_logo.path) as stored_logo:
            stored_text = stored_logo.text

        self.assertEqual(stored_text, {})

    def test_api_accepts_valid_gif_logo(self) -> None:
        """
        GIVEN:
            - A valid GIF logo upload
        WHEN:
            - Uploaded via PATCH to app config
        THEN:
            - Upload succeeds
        """
        image = Image.new("RGB", (12, 12), "red")

        logo = BytesIO()
        image.save(logo, format="GIF", comment=b"Paperless Test Comment")
        logo.seek(0)

        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            {
                "app_logo": SimpleUploadedFile(
                    name="logo.gif",
                    content=logo.getvalue(),
                    content_type="image/gif",
                ),
            },
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_api_rejects_invalid_raster_logo(self) -> None:
        """
        GIVEN:
            - A file named as a JPEG but containing non-image payload data
        WHEN:
            - Uploaded via PATCH to app config
        THEN:
            - Upload is rejected with 400
        """
        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            {
                "app_logo": SimpleUploadedFile(
                    name="not-an-image.jpg",
                    content=b"<script>alert('xss')</script>",
                    content_type="image/jpeg",
                ),
            },
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("invalid logo image", str(response.data).lower())

    @override_settings(MAX_IMAGE_PIXELS=100)
    def test_api_rejects_logo_exceeding_max_image_pixels(self) -> None:
        """
        GIVEN:
            - A raster logo larger than the configured MAX_IMAGE_PIXELS limit
        WHEN:
            - Uploaded via PATCH to app config
        THEN:
            - Upload is rejected with 400
        """
        image = Image.new("RGB", (12, 12), "purple")
        logo = BytesIO()
        image.save(logo, format="PNG")
        logo.seek(0)

        response = self.client.patch(
            f"{self.ENDPOINT}1/",
            {
                "app_logo": SimpleUploadedFile(
                    name="too-large.png",
                    content=logo.getvalue(),
                    content_type="image/png",
                ),
            },
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn(
            "uploaded logo exceeds the maximum allowed image size",
            str(response.data).lower(),
        )

    def test_api_rejects_malicious_svg_logo(self) -> None:
        """
        GIVEN:

@@ -18,6 +18,7 @@ from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.core import mail
from django.core.cache import cache
from django.core.files.uploadedfile import SimpleUploadedFile
from django.db import DataError
from django.test import override_settings
from django.utils import timezone
@@ -1377,6 +1378,79 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
        self.assertIsNone(overrides.document_type_id)
        self.assertIsNone(overrides.tag_ids)

    def test_upload_with_path_traversal_filename_is_reduced_to_basename(self) -> None:
        self.consume_file_mock.return_value = celery.result.AsyncResult(
            id=str(uuid.uuid4()),
        )

        payload = SimpleUploadedFile(
            "../../outside.pdf",
            (Path(__file__).parent / "samples" / "simple.pdf").read_bytes(),
            content_type="application/pdf",
        )

        response = self.client.post(
            "/api/documents/post_document/",
            {"document": payload},
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.consume_file_mock.assert_called_once()

        input_doc, overrides = self.get_last_consume_delay_call_args()

        self.assertEqual(input_doc.original_file.name, "outside.pdf")
        self.assertEqual(overrides.filename, "outside.pdf")
        self.assertNotIn("..", input_doc.original_file.name)
        self.assertNotIn("..", overrides.filename)
        self.assertTrue(
            input_doc.original_file.resolve(strict=False).is_relative_to(
                Path(settings.SCRATCH_DIR).resolve(strict=False),
            ),
        )

    def test_upload_with_path_traversal_content_disposition_filename_is_reduced_to_basename(
        self,
    ) -> None:
        self.consume_file_mock.return_value = celery.result.AsyncResult(
            id=str(uuid.uuid4()),
        )

        pdf_bytes = (Path(__file__).parent / "samples" / "simple.pdf").read_bytes()
        boundary = "paperless-boundary"
        payload = (
            (
                f"--{boundary}\r\n"
                'Content-Disposition: form-data; name="document"; '
                'filename="../../outside.pdf"\r\n'
                "Content-Type: application/pdf\r\n\r\n"
            ).encode()
            + pdf_bytes
            + f"\r\n--{boundary}--\r\n".encode()
        )

        response = self.client.generic(
            "POST",
            "/api/documents/post_document/",
            payload,
            content_type=f"multipart/form-data; boundary={boundary}",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.consume_file_mock.assert_called_once()

        input_doc, overrides = self.get_last_consume_delay_call_args()

        self.assertEqual(input_doc.original_file.name, "outside.pdf")
        self.assertEqual(overrides.filename, "outside.pdf")
        self.assertNotIn("..", input_doc.original_file.name)
        self.assertNotIn("..", overrides.filename)
        self.assertTrue(
            input_doc.original_file.resolve(strict=False).is_relative_to(
                Path(settings.SCRATCH_DIR).resolve(strict=False),
            ),
        )

    def test_document_filters_use_latest_version_content(self) -> None:
        root = Document.objects.create(
            title="versioned root",

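Both traversal tests assert the same invariant: whatever filename the client supplies, only its final path component survives. The reduction can be sketched with pathlib; `safe_filename` is an illustrative helper, not the name of the actual upload handler:

```python
from pathlib import PurePosixPath


def safe_filename(client_supplied: str) -> str:
    """Hypothetical sketch of the basename reduction the tests verify.

    Normalize Windows separators, then keep only the final component so
    "../../outside.pdf" cannot escape the scratch directory.
    """
    return PurePosixPath(client_supplied.replace("\\", "/")).name


assert safe_filename("../../outside.pdf") == "outside.pdf"
assert safe_filename("..\\..\\outside.pdf") == "outside.pdf"
```

Normalizing backslashes first matters because `PurePosixPath` treats `\` as an ordinary character, so a Windows-style traversal would otherwise pass through intact.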
@@ -1503,126 +1503,6 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
            [d2.id, d1.id, d3.id],
        )

    def test_search_ordering_by_score(self) -> None:
        """ordering=-score must return results in descending relevance order (best first)."""
        backend = get_backend()
        # doc_high has more occurrences of the search term → higher BM25 score
        doc_low = Document.objects.create(
            title="score sort low",
            content="apple",
            checksum="SCL1",
        )
        doc_high = Document.objects.create(
            title="score sort high",
            content="apple apple apple apple apple",
            checksum="SCH1",
        )
        backend.add_or_update(doc_low)
        backend.add_or_update(doc_high)

        # -score = descending = best first (highest score)
        response = self.client.get("/api/documents/?query=apple&ordering=-score")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        ids = [r["id"] for r in response.data["results"]]
        self.assertEqual(
            ids[0],
            doc_high.id,
            "Most relevant doc should be first for -score",
        )

        # score = ascending = worst first (lowest score)
        response = self.client.get("/api/documents/?query=apple&ordering=score")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        ids = [r["id"] for r in response.data["results"]]
        self.assertEqual(
            ids[0],
            doc_low.id,
            "Least relevant doc should be first for +score",
        )

    def test_search_with_tantivy_native_sort(self) -> None:
        """When ordering by a Tantivy-sortable field, results must be correctly sorted."""
        backend = get_backend()
        for i, asn in enumerate([30, 10, 20]):
            doc = Document.objects.create(
                title=f"sortable doc {i}",
                content="searchable content",
                checksum=f"TNS{i}",
                archive_serial_number=asn,
            )
            backend.add_or_update(doc)

        response = self.client.get(
            "/api/documents/?query=searchable&ordering=archive_serial_number",
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        asns = [doc["archive_serial_number"] for doc in response.data["results"]]
        self.assertEqual(asns, [10, 20, 30])

        response = self.client.get(
            "/api/documents/?query=searchable&ordering=-archive_serial_number",
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        asns = [doc["archive_serial_number"] for doc in response.data["results"]]
        self.assertEqual(asns, [30, 20, 10])

    def test_search_page_2_returns_correct_slice(self) -> None:
        """Page 2 must return the second slice, not overlap with page 1."""
        backend = get_backend()
        for i in range(10):
            doc = Document.objects.create(
                title=f"doc {i}",
                content="paginated content",
                checksum=f"PG2{i}",
                archive_serial_number=i + 1,
            )
            backend.add_or_update(doc)

        response = self.client.get(
            "/api/documents/?query=paginated&ordering=archive_serial_number&page=1&page_size=3",
        )
        page1_ids = [r["id"] for r in response.data["results"]]
        self.assertEqual(len(page1_ids), 3)

        response = self.client.get(
            "/api/documents/?query=paginated&ordering=archive_serial_number&page=2&page_size=3",
        )
        page2_ids = [r["id"] for r in response.data["results"]]
        self.assertEqual(len(page2_ids), 3)

        # No overlap between pages
        self.assertEqual(set(page1_ids) & set(page2_ids), set())
        # Page 2 ASNs are higher than page 1
        page1_asns = [
            Document.objects.get(pk=pk).archive_serial_number for pk in page1_ids
        ]
        page2_asns = [
            Document.objects.get(pk=pk).archive_serial_number for pk in page2_ids
        ]
        self.assertTrue(max(page1_asns) < min(page2_asns))

    def test_search_all_field_contains_all_ids_when_paginated(self) -> None:
        """The 'all' field must contain every matching ID, even when paginated."""
        backend = get_backend()
        doc_ids = []
        for i in range(10):
            doc = Document.objects.create(
                title=f"all field doc {i}",
                content="allfield content",
                checksum=f"AF{i}",
            )
            backend.add_or_update(doc)
            doc_ids.append(doc.pk)

        response = self.client.get(
            "/api/documents/?query=allfield&page=1&page_size=3",
            headers={"Accept": "application/json; version=9"},
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["results"]), 3)
        # "all" must contain ALL 10 matching IDs
        self.assertCountEqual(response.data["all"], doc_ids)

    @mock.patch("documents.bulk_edit.bulk_update_documents")
    def test_global_search(self, m) -> None:
        """

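The pagination tests above all reduce to one slicing rule: page N with size S covers offset (N-1)*S, so pages tile the ordered ID list without overlap. A plain-Python sketch of that rule (`page_slice` is an illustrative name, independent of the backend):

```python
def page_slice(ordered_ids: list[int], page: int, page_size: int) -> list[int]:
    """Return the slice of ordered IDs belonging to a 1-indexed page."""
    offset = (page - 1) * page_size
    return ordered_ids[offset : offset + page_size]


ids = list(range(1, 11))  # ten matching document IDs, like the test fixture
page1 = page_slice(ids, 1, 3)  # first three IDs
page2 = page_slice(ids, 2, 3)  # next three, no overlap with page 1
assert set(page1) & set(page2) == set()
```

Python slicing clamps out-of-range offsets, so a page past the end simply yields an empty list rather than raising.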
@@ -38,7 +38,6 @@ from django.db.models import Model
from django.db.models import OuterRef
from django.db.models import Prefetch
from django.db.models import Q
from django.db.models import QuerySet
from django.db.models import Subquery
from django.db.models import Sum
from django.db.models import When
@@ -249,13 +248,6 @@ if settings.AUDIT_LOG_ENABLED:

logger = logging.getLogger("paperless.api")

# Crossover point for intersect_and_order: below this count use a targeted
# IN-clause query; at or above this count fall back to a full-table scan +
# Python set intersection. The IN-clause is faster for small result sets but
# degrades on SQLite with thousands of parameters. PostgreSQL handles large IN
# clauses efficiently, so this threshold mainly protects SQLite users.
_TANTIVY_INTERSECT_THRESHOLD = 5_000


class IndexView(TemplateView):
    template_name = "index.html"
@@ -2068,16 +2060,19 @@ class UnifiedSearchViewSet(DocumentViewSet):
        if not self._is_search_request():
            return super().list(request)

        from documents.search import SearchHit
        from documents.search import SearchMode
        from documents.search import TantivyBackend
        from documents.search import TantivyRelevanceList
        from documents.search import get_backend

        def parse_search_params() -> tuple[str | None, bool, bool, int, int]:
            """Extract sort field, sort direction, and pagination from the request."""
            active = self._get_active_search_params(request)
            if len(active) > 1:
        try:
            backend = get_backend()
            # ORM-filtered queryset: permissions + field filters + ordering (DRF backends applied)
            filtered_qs = self.filter_queryset(self.get_queryset())

            user = None if request.user.is_superuser else request.user
            active_search_params = self._get_active_search_params(request)

            if len(active_search_params) > 1:
                raise ValidationError(
                    {
                        "detail": _(
@@ -2086,161 +2081,73 @@ class UnifiedSearchViewSet(DocumentViewSet):
                    },
                )

            ordering_param = request.query_params.get("ordering", "")
            sort_reverse = ordering_param.startswith("-")
            sort_field_name = ordering_param.lstrip("-") or None
            # "score" means relevance order — Tantivy handles it natively,
            # so treat it as a Tantivy sort to preserve the ranked order through
            # the ORM intersection step.
            use_tantivy_sort = (
                sort_field_name in TantivyBackend.SORTABLE_FIELDS
                or sort_field_name is None
                or sort_field_name == "score"
            )

            try:
                page_num = int(request.query_params.get("page", 1))
            except (TypeError, ValueError):
                page_num = 1
            page_size = (
                self.paginator.get_page_size(request) or self.paginator.page_size
            )

            return sort_field_name, sort_reverse, use_tantivy_sort, page_num, page_size

        def intersect_and_order(
            all_ids: list[int],
            filtered_qs: QuerySet[Document],
            *,
            use_tantivy_sort: bool,
        ) -> list[int]:
            """Intersect search IDs with ORM-visible IDs, preserving order."""
            if not all_ids:
                return []
            if use_tantivy_sort:
                if len(all_ids) <= _TANTIVY_INTERSECT_THRESHOLD:
                    # Small result set: targeted IN-clause avoids a full-table scan.
                    visible_ids = set(
                        filtered_qs.filter(pk__in=all_ids).values_list("pk", flat=True),
                    )
                else:
                    # Large result set: full-table scan + Python intersection is faster
                    # than a large IN-clause on SQLite.
                    visible_ids = set(
                        filtered_qs.values_list("pk", flat=True),
                    )
                return [doc_id for doc_id in all_ids if doc_id in visible_ids]
            return list(
                filtered_qs.filter(id__in=all_ids).values_list("pk", flat=True),
            )

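The order-preserving half of `intersect_and_order` can be sketched without Django: keep the search backend's ranked ID order, dropping anything the permission-filtered queryset would not expose. The in-memory `visible` set below stands in for the ORM query (either the IN-clause or full-scan variant picks the cheaper way to build it; the intersection step is the same):

```python
def intersect_preserving_order(
    search_ids: list[int],
    visible: set[int],
) -> list[int]:
    """Keep search-backend order; drop IDs not visible to this user.

    Illustrative stand-in: `visible` replaces the queryset lookup from
    intersect_and_order, which builds the same set via pk__in or a scan.
    """
    return [doc_id for doc_id in search_ids if doc_id in visible]


hits = [42, 7, 19, 3]  # ranked IDs from the search backend
assert intersect_preserving_order(hits, {3, 7, 99}) == [7, 3]
```

Iterating the ranked list and probing a set is O(n) and keeps the relevance order, which a plain `filter(pk__in=...)` queryset would not guarantee on its own.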
        def run_text_search(
            backend: TantivyBackend,
            user: User | None,
            filtered_qs: QuerySet[Document],
        ) -> tuple[list[int], list[SearchHit], int]:
            """Handle text/title/query search: IDs, ORM intersection, page highlights."""
            if "text" in request.query_params:
                search_mode = SearchMode.TEXT
                query_str = request.query_params["text"]
            elif "title_search" in request.query_params:
                search_mode = SearchMode.TITLE
                query_str = request.query_params["title_search"]
            else:
                search_mode = SearchMode.QUERY
                query_str = request.query_params["query"]

            # "score" is not a real Tantivy sort field — it means relevance order,
            # which is Tantivy's default when no sort field is specified.
            is_score_sort = sort_field_name == "score"
            all_ids = backend.search_ids(
                query_str,
                user=user,
                sort_field=(
                    None if (not use_tantivy_sort or is_score_sort) else sort_field_name
                ),
                sort_reverse=sort_reverse,
                search_mode=search_mode,
            )
            ordered_ids = intersect_and_order(
                all_ids,
                filtered_qs,
                use_tantivy_sort=use_tantivy_sort,
            )
            # Tantivy returns relevance results best-first (descending score).
            # ordering=score (ascending, worst-first) requires a reversal.
            if is_score_sort and not sort_reverse:
                ordered_ids = list(reversed(ordered_ids))

            page_offset = (page_num - 1) * page_size
            page_ids = ordered_ids[page_offset : page_offset + page_size]
            page_hits = backend.highlight_hits(
                query_str,
                page_ids,
                search_mode=search_mode,
                rank_start=page_offset + 1,
            )
            return ordered_ids, page_hits, page_offset

        def run_more_like_this(
            backend: TantivyBackend,
            user: User | None,
            filtered_qs: QuerySet[Document],
        ) -> tuple[list[int], list[SearchHit], int]:
            """Handle more_like_id search: permission check, IDs, stub hits."""
            try:
                more_like_doc_id = int(request.query_params["more_like_id"])
                more_like_doc = Document.objects.select_related("owner").get(
                    pk=more_like_doc_id,
                )
            except (TypeError, ValueError, Document.DoesNotExist):
                raise PermissionDenied(_("Invalid more_like_id"))

            if not has_perms_owner_aware(
                request.user,
                "view_document",
                more_like_doc,
            if (
                "text" in request.query_params
                or "title_search" in request.query_params
                or "query" in request.query_params
            ):
                raise PermissionDenied(_("Insufficient permissions."))

            all_ids = backend.more_like_this_ids(more_like_doc_id, user=user)
            ordered_ids = intersect_and_order(
                all_ids,
                filtered_qs,
                use_tantivy_sort=True,
            )

            page_offset = (page_num - 1) * page_size
            page_ids = ordered_ids[page_offset : page_offset + page_size]
            page_hits = [
                SearchHit(id=doc_id, score=0.0, rank=rank, highlights={})
                for rank, doc_id in enumerate(page_ids, start=page_offset + 1)
            ]
            return ordered_ids, page_hits, page_offset

        try:
            sort_field_name, sort_reverse, use_tantivy_sort, page_num, page_size = (
                parse_search_params()
            )

            backend = get_backend()
            filtered_qs = self.filter_queryset(self.get_queryset())
            user = None if request.user.is_superuser else request.user

            if "more_like_id" in request.query_params:
                ordered_ids, page_hits, page_offset = run_more_like_this(
                    backend,
                    user,
                    filtered_qs,
                if "text" in request.query_params:
                    search_mode = SearchMode.TEXT
                    query_str = request.query_params["text"]
                elif "title_search" in request.query_params:
                    search_mode = SearchMode.TITLE
                    query_str = request.query_params["title_search"]
                else:
                    search_mode = SearchMode.QUERY
                    query_str = request.query_params["query"]
                results = backend.search(
                    query_str,
                    user=user,
                    page=1,
                    page_size=10000,
                    sort_field=None,
                    sort_reverse=False,
                    search_mode=search_mode,
                )
            else:
                ordered_ids, page_hits, page_offset = run_text_search(
                    backend,
                    user,
                    filtered_qs,
                # more_like_id — validate permission on the seed document first
                try:
                    more_like_doc_id = int(request.query_params["more_like_id"])
                    more_like_doc = Document.objects.select_related("owner").get(
                        pk=more_like_doc_id,
                    )
                except (TypeError, ValueError, Document.DoesNotExist):
                    raise PermissionDenied(_("Invalid more_like_id"))

                if not has_perms_owner_aware(
                    request.user,
                    "view_document",
                    more_like_doc,
                ):
                    raise PermissionDenied(_("Insufficient permissions."))

                results = backend.more_like_this(
                    more_like_doc_id,
                    user=user,
                    page=1,
                    page_size=10000,
                )

            rl = TantivyRelevanceList(ordered_ids, page_hits, page_offset)
            hits_by_id = {h["id"]: h for h in results.hits}

            # Determine sort order: no ordering param -> Tantivy relevance; otherwise -> ORM order
            ordering_param = request.query_params.get("ordering", "").lstrip("-")
            if not ordering_param:
                # Preserve Tantivy relevance order; intersect with ORM-visible IDs
                orm_ids = set(filtered_qs.values_list("pk", flat=True))
                ordered_hits = [h for h in results.hits if h["id"] in orm_ids]
            else:
                # Use ORM ordering (already applied by DocumentsOrderingFilter)
                hit_ids = set(hits_by_id.keys())
                orm_ordered_ids = filtered_qs.filter(id__in=hit_ids).values_list(
                    "pk",
                    flat=True,
                )
                ordered_hits = [
                    hits_by_id[pk] for pk in orm_ordered_ids if pk in hits_by_id
                ]

            rl = TantivyRelevanceList(ordered_hits)
            page = self.paginate_queryset(rl)

            if page is not None:
@@ -2250,18 +2157,15 @@ class UnifiedSearchViewSet(DocumentViewSet):
                if get_boolean(
                    str(request.query_params.get("include_selection_data", "false")),
                ):
                    # NOTE: pk__in=ordered_ids generates a large SQL IN clause
                    # for big result sets. Acceptable today but may need a temp
                    # table or chunked approach if selection_data becomes slow
                    # at scale (tens of thousands of matching documents).
                    all_ids = [h["id"] for h in ordered_hits]
                    response.data["selection_data"] = (
                        self._get_selection_data_for_queryset(
                            filtered_qs.filter(pk__in=ordered_ids),
                            filtered_qs.filter(pk__in=all_ids),
                        )
                    )
                return response

            serializer = self.get_serializer(page_hits, many=True)
            serializer = self.get_serializer(ordered_hits, many=True)
            return Response(serializer.data)

        except NotFound:

@@ -3167,17 +3071,20 @@ class GlobalSearchView(PassUserMixin):
            docs = all_docs.filter(title__icontains=query)[:OBJECT_LIMIT]
        else:
            user = None if request.user.is_superuser else request.user
            matching_ids = get_backend().search_ids(
            fts_results = get_backend().search(
                query,
                user=user,
                page=1,
                page_size=1000,
                sort_field=None,
                sort_reverse=False,
                search_mode=SearchMode.TEXT,
                limit=OBJECT_LIMIT * 3,
            )
            docs_by_id = all_docs.in_bulk(matching_ids)
            docs_by_id = all_docs.in_bulk([hit["id"] for hit in fts_results.hits])
            docs = [
                docs_by_id[doc_id]
                for doc_id in matching_ids
                if doc_id in docs_by_id
                docs_by_id[hit["id"]]
                for hit in fts_results.hits
                if hit["id"] in docs_by_id
            ][:OBJECT_LIMIT]
            saved_views = (
                get_objects_for_user_owner_aware(

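GlobalSearchView's in_bulk pattern (fetch all matching rows in one query, then reassemble them in hit order, skipping anything permissions filtered out) can be sketched with a plain dict standing in for the queryset. `order_by_hits` and the toy `rows` mapping are illustrative names, not part of the codebase:

```python
def order_by_hits(
    rows_by_id: dict[int, str],
    hit_ids: list[int],
    limit: int,
) -> list[str]:
    """Reassemble fetched rows in search-hit order, capped at a limit.

    Mirrors the docs_by_id comprehension: preserve hit order, skip IDs
    that the permission-aware queryset did not return.
    """
    return [rows_by_id[i] for i in hit_ids if i in rows_by_id][:limit]


rows = {1: "a", 2: "b", 3: "c"}  # stand-in for all_docs.in_bulk(...)
assert order_by_hits(rows, [3, 9, 1], limit=2) == ["c", "a"]
```

Note the `if i in rows_by_id` guard does double duty: it drops IDs the user may not view and IDs deleted between indexing and lookup.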
@@ -1,4 +1,5 @@
import logging
from io import BytesIO

import magic
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
@@ -11,13 +12,16 @@ from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.auth.password_validation import validate_password
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.core.files.uploadedfile import UploadedFile
from PIL import Image
from rest_framework import serializers
from rest_framework.authtoken.serializers import AuthTokenSerializer

from paperless.models import ApplicationConfiguration
from paperless.network import validate_outbound_http_url
from paperless.validators import reject_dangerous_svg
from paperless.validators import validate_raster_image
from paperless_mail.serialisers import ObfuscatedPasswordField

logger = logging.getLogger("paperless.settings")
@@ -233,9 +237,40 @@ class ApplicationConfigurationSerializer(serializers.ModelSerializer):
            instance.app_logo.delete()
        return super().update(instance, validated_data)

    def _sanitize_raster_image(self, file: UploadedFile) -> UploadedFile:
        # Re-encoding through Pillow drops EXIF and text metadata. Open the
        # image outside the try block so a failed open does not hit
        # image.close() on an unbound name in the finally clause.
        image = Image.open(file)
        try:
            data = BytesIO()
            image.save(data, format=image.format)
            data.seek(0)

            return InMemoryUploadedFile(
                file=data,
                field_name=file.field_name,
                name=file.name,
                content_type=file.content_type,
                size=data.getbuffer().nbytes,
                charset=getattr(file, "charset", None),
            )
        finally:
            image.close()

    def validate_app_logo(self, file: UploadedFile):
        if file and magic.from_buffer(file.read(2048), mime=True) == "image/svg+xml":
            reject_dangerous_svg(file)
        """
        Validates and sanitizes the uploaded app logo image. The model field
        already restricts uploads to jpg/png/gif/svg.
        """
        if file:
            mime_type = magic.from_buffer(file.read(2048), mime=True)

            if mime_type == "image/svg+xml":
                reject_dangerous_svg(file)
            else:
                validate_raster_image(file)

            if mime_type in {"image/jpeg", "image/png"}:
                file = self._sanitize_raster_image(file)

        return file

    def validate_llm_endpoint(self, value: str | None) -> str | None:
@@ -1,6 +1,10 @@
from io import BytesIO

from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
from lxml import etree
from PIL import Image

ALLOWED_SVG_TAGS: set[str] = {
    # Basic shapes
@@ -254,3 +258,30 @@ def reject_dangerous_svg(file: UploadedFile) -> None:
            raise ValidationError(
                f"URI scheme not allowed in {attr_name}: must be #anchor, relative path, or data:image/*",
            )


def validate_raster_image(file: UploadedFile) -> None:
    """
    Validates that the uploaded file is a valid raster image (JPEG, PNG, etc.)
    and does not exceed maximum pixel limits.
    Raises ValidationError if the image is invalid or exceeds the allowed size.
    """

    file.seek(0)
    image_data = file.read()
    try:
        with Image.open(BytesIO(image_data)) as image:
            image.verify()

            if (
                settings.MAX_IMAGE_PIXELS is not None
                and settings.MAX_IMAGE_PIXELS > 0
                and image.width * image.height > settings.MAX_IMAGE_PIXELS
            ):
                raise ValidationError(
                    "Uploaded logo exceeds the maximum allowed image size.",
                )
            if image.format is None:  # pragma: no cover
                raise ValidationError("Invalid logo image.")
    except (OSError, Image.DecompressionBombError) as e:
        raise ValidationError("Invalid logo image.") from e
@@ -89,7 +89,7 @@ class StandardPagination(PageNumberPagination):

        query = self.page.paginator.object_list
        if isinstance(query, TantivyRelevanceList):
            return query.get_all_ids()
        return [h["id"] for h in query._hits]
        return self.page.paginator.object_list.values_list("pk", flat=True)

    def get_paginated_response_schema(self, schema):
346
test_backend_profile.py
Normal file
@@ -0,0 +1,346 @@
# ruff: noqa: T201
"""
cProfile-based search pipeline profiling with a 20k-document dataset.

Run with:
    uv run pytest ../test_backend_profile.py \
        -m profiling --override-ini="addopts=" -s -v

Each scenario prints:
  - Wall time for the operation
  - cProfile stats sorted by cumulative time (top 25 callers)

This is a developer tool, not a correctness test. Nothing here should
fail unless the code is broken.
"""

from __future__ import annotations

import random
import time
from typing import TYPE_CHECKING

import pytest
from profiling import profile_cpu

from documents.models import Document
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend

if TYPE_CHECKING:
    from pathlib import Path

# transaction=False (default): tests roll back, but the module-scoped fixture
# commits its data outside the test transaction so it remains visible throughout.
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]

# ---------------------------------------------------------------------------
# Dataset constants
# ---------------------------------------------------------------------------
NUM_DOCS = 20_000
SEED = 42

# Terms and their approximate match rates across the corpus.
# "rechnung"    -> ~70% of docs (~14 000)
# "mahnung"     -> ~20% of docs (~4 000)
# "kontonummer" -> ~5% of docs (~1 000)
# "rarewort"    -> ~1% of docs (~200)
COMMON_TERM = "rechnung"
MEDIUM_TERM = "mahnung"
RARE_TERM = "kontonummer"
VERY_RARE_TERM = "rarewort"

PAGE_SIZE = 25


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_FILLER_WORDS = [
    "dokument",  # codespell:ignore
    "seite",
    "datum",
    "betrag",
    "nummer",
    "konto",
    "firma",
    "vertrag",
    "lieferant",
    "bestellung",
    "steuer",
    "mwst",
    "leistung",
    "auftrag",
    "zahlung",
]


def _build_content(rng: random.Random) -> str:
    """Return a short paragraph with terms embedded at the desired rates."""
    words = rng.choices(_FILLER_WORDS, k=15)
    if rng.random() < 0.70:
        words.append(COMMON_TERM)
    if rng.random() < 0.20:
        words.append(MEDIUM_TERM)
    if rng.random() < 0.05:
        words.append(RARE_TERM)
    if rng.random() < 0.01:
        words.append(VERY_RARE_TERM)
    rng.shuffle(words)
    return " ".join(words)


def _time(fn, *, label: str, runs: int = 3):
    """Run *fn()* several times and report min/avg/max wall time (no cProfile)."""
    times = []
    result = None
    for _ in range(runs):
        t0 = time.perf_counter()
        result = fn()
        times.append(time.perf_counter() - t0)
    mn, avg, mx = min(times), sum(times) / len(times), max(times)
    print(
        f"  {label}: min={mn * 1000:.1f}ms avg={avg * 1000:.1f}ms max={mx * 1000:.1f}ms (n={runs})",
    )
    return result


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Unlock the DB for the whole module (module-scoped)."""
    with django_db_blocker.unblock():
        yield


@pytest.fixture(scope="module")
def large_backend(tmp_path_factory, module_db) -> TantivyBackend:
    """
    Build a 20 000-document DB + on-disk Tantivy index, shared across all
    profiling scenarios in this module. Teardown deletes all documents.
    """
    index_path: Path = tmp_path_factory.mktemp("tantivy_profile")

    # ---- 1. Bulk-create Document rows ----------------------------------------
    rng = random.Random(SEED)
    docs = [
        Document(
            title=f"Document {i:05d}",
            content=_build_content(rng),
            checksum=f"{i:064x}",
            pk=i + 1,
        )
        for i in range(NUM_DOCS)
    ]
    t0 = time.perf_counter()
    Document.objects.bulk_create(docs, batch_size=1_000)
    db_time = time.perf_counter() - t0
    print(f"\n[setup] bulk_create {NUM_DOCS} docs: {db_time:.2f}s")

    # ---- 2. Build Tantivy index -----------------------------------------------
    backend = TantivyBackend(path=index_path)
    backend.open()

    t0 = time.perf_counter()
    with backend.batch_update() as batch:
        for doc in Document.objects.iterator(chunk_size=500):
            batch.add_or_update(doc)
    idx_time = time.perf_counter() - t0
    print(f"[setup] index {NUM_DOCS} docs: {idx_time:.2f}s")

    # ---- 3. Report corpus stats -----------------------------------------------
    for term in (COMMON_TERM, MEDIUM_TERM, RARE_TERM, VERY_RARE_TERM):
        count = len(backend.search_ids(term, user=None))
        print(f"[setup] '{term}' -> {count} hits")

    yield backend

    # ---- Teardown ------------------------------------------------------------
    backend.close()
    reset_backend()
    Document.objects.all().delete()


# ---------------------------------------------------------------------------
# Profiling tests — each scenario is a separate function so pytest can run
# them individually or all together with -m profiling.
# ---------------------------------------------------------------------------


class TestSearchIdsProfile:
    """Profile backend.search_ids() — pure Tantivy, no DB."""

    def test_search_ids_large(self, large_backend: TantivyBackend):
        """~14 000 hits: how long does Tantivy take to collect all IDs?"""
        profile_cpu(
            lambda: large_backend.search_ids(COMMON_TERM, user=None),
            label=f"search_ids('{COMMON_TERM}') [large result set ~14k]",
        )

    def test_search_ids_medium(self, large_backend: TantivyBackend):
        """~4 000 hits."""
        profile_cpu(
            lambda: large_backend.search_ids(MEDIUM_TERM, user=None),
            label=f"search_ids('{MEDIUM_TERM}') [medium result set ~4k]",
        )

    def test_search_ids_rare(self, large_backend: TantivyBackend):
        """~1 000 hits."""
        profile_cpu(
            lambda: large_backend.search_ids(RARE_TERM, user=None),
            label=f"search_ids('{RARE_TERM}') [rare result set ~1k]",
        )


class TestIntersectAndOrderProfile:
    """
    Profile the DB intersection step: filter(pk__in=search_ids).
    This is the 'intersect_and_order' logic from views.py.
    """

    def test_intersect_large(self, large_backend: TantivyBackend):
        """Intersect 14k Tantivy IDs with all 20k ORM-visible docs."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        qs = Document.objects.all()

        print(f"\n  Tantivy returned {len(all_ids)} IDs")

        profile_cpu(
            lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
            label=f"filter(pk__in={len(all_ids)} ids) [large, use_tantivy_sort=True path]",
        )

        # Also time it a few times to get stable numbers
        print()
        _time(
            lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
            label=f"filter(pk__in={len(all_ids)}) repeated",
        )

    def test_intersect_rare(self, large_backend: TantivyBackend):
        """Intersect ~1k Tantivy IDs — the happy path."""
        all_ids = large_backend.search_ids(RARE_TERM, user=None)
        qs = Document.objects.all()

        print(f"\n  Tantivy returned {len(all_ids)} IDs")

        profile_cpu(
            lambda: list(qs.filter(pk__in=all_ids).values_list("pk", flat=True)),
            label=f"filter(pk__in={len(all_ids)} ids) [rare, use_tantivy_sort=True path]",
        )


class TestHighlightHitsProfile:
    """Profile backend.highlight_hits() — per-doc Tantivy lookups with BM25 scoring."""

    def test_highlight_page1(self, large_backend: TantivyBackend):
        """25-doc highlight for page 1 (rank_start=1)."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        page_ids = all_ids[:PAGE_SIZE]

        profile_cpu(
            lambda: large_backend.highlight_hits(
                COMMON_TERM,
                page_ids,
                rank_start=1,
            ),
            label=f"highlight_hits page 1 (ids {all_ids[0]}..{all_ids[PAGE_SIZE - 1]})",
        )

    def test_highlight_page_middle(self, large_backend: TantivyBackend):
        """25-doc highlight for a mid-corpus page (rank_start=page_offset+1)."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        mid = len(all_ids) // 2
        page_ids = all_ids[mid : mid + PAGE_SIZE]
        page_offset = mid

        profile_cpu(
            lambda: large_backend.highlight_hits(
                COMMON_TERM,
                page_ids,
                rank_start=page_offset + 1,
            ),
            label=f"highlight_hits page ~{mid // PAGE_SIZE} (offset {page_offset})",
        )

    def test_highlight_repeated(self, large_backend: TantivyBackend):
        """Multiple runs of page-1 highlight to see variance."""
        all_ids = large_backend.search_ids(COMMON_TERM, user=None)
        page_ids = all_ids[:PAGE_SIZE]

        print()
        _time(
            lambda: large_backend.highlight_hits(COMMON_TERM, page_ids, rank_start=1),
            label="highlight_hits page 1",
            runs=5,
        )


class TestFullPipelineProfile:
    """
    Profile the combined pipeline as it runs in views.py:
    search_ids -> filter(pk__in) -> highlight_hits
    """

    def _run_pipeline(
        self,
        backend: TantivyBackend,
        term: str,
        page: int = 1,
    ):
        all_ids = backend.search_ids(term, user=None)
        qs = Document.objects.all()
        visible_ids = set(qs.filter(pk__in=all_ids).values_list("pk", flat=True))
        ordered_ids = [i for i in all_ids if i in visible_ids]

        page_offset = (page - 1) * PAGE_SIZE
        page_ids = ordered_ids[page_offset : page_offset + PAGE_SIZE]
        hits = backend.highlight_hits(
            term,
            page_ids,
            rank_start=page_offset + 1,
        )
        return ordered_ids, hits

    def test_pipeline_large_page1(self, large_backend: TantivyBackend):
        """Full pipeline: large result set, page 1."""
        ordered_ids, hits = profile_cpu(
            lambda: self._run_pipeline(large_backend, COMMON_TERM, page=1),
            label=f"full pipeline '{COMMON_TERM}' page 1",
        )[0]
        print(f"  -> {len(ordered_ids)} total results, {len(hits)} hits on page")

    def test_pipeline_large_page5(self, large_backend: TantivyBackend):
        """Full pipeline: large result set, page 5."""
        ordered_ids, hits = profile_cpu(
            lambda: self._run_pipeline(large_backend, COMMON_TERM, page=5),
            label=f"full pipeline '{COMMON_TERM}' page 5",
        )[0]
        print(f"  -> {len(ordered_ids)} total results, {len(hits)} hits on page")

    def test_pipeline_rare(self, large_backend: TantivyBackend):
        """Full pipeline: rare term, page 1 (fast path)."""
        ordered_ids, hits = profile_cpu(
            lambda: self._run_pipeline(large_backend, RARE_TERM, page=1),
            label=f"full pipeline '{RARE_TERM}' page 1",
        )[0]
        print(f"  -> {len(ordered_ids)} total results, {len(hits)} hits on page")

    def test_pipeline_repeated(self, large_backend: TantivyBackend):
        """Repeated runs to get stable timing (no cProfile overhead)."""
        print()
        for term, label in [
            (COMMON_TERM, f"'{COMMON_TERM}' (large)"),
            (MEDIUM_TERM, f"'{MEDIUM_TERM}' (medium)"),
            (RARE_TERM, f"'{RARE_TERM}' (rare)"),
        ]:
            _time(
                lambda t=term: self._run_pipeline(large_backend, t, page=1),
                label=f"full pipeline {label} page 1",
                runs=3,
            )
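Both profiling modules import `profile_cpu` from a `profiling` helper that is not part of this diff. A minimal stdlib sketch of what such a wrapper could look like, assuming from the call sites that it returns `(result, elapsed)` and accepts `label` and `top` keyword arguments:

```python
import cProfile
import io
import pstats
import time


def profile_cpu(fn, *, label: str = "", top: int = 25):
    """Run fn under cProfile, print the hottest callers, return (result, elapsed)."""
    profiler = cProfile.Profile()
    t0 = time.perf_counter()
    profiler.enable()
    result = fn()
    profiler.disable()
    elapsed = time.perf_counter() - t0

    # Render the profile sorted by cumulative time, limited to the top callers.
    out = io.StringIO()
    pstats.Stats(profiler, stream=out).sort_stats("cumulative").print_stats(top)
    print(f"{label}: {elapsed * 1000:.1f} ms")
    print(out.getvalue())
    return result, elapsed


result, elapsed = profile_cpu(lambda: sum(range(100_000)), label="sum demo", top=5)
```

This matches the unpacking `result, second_elapsed = profile_cpu(...)` and the `profile_cpu(...)[0]` indexing seen in the tests; the real helper may differ.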
605
test_classifier_profile.py
Normal file
@@ -0,0 +1,605 @@
# ruff: noqa: T201
"""
cProfile + tracemalloc classifier profiling test.

Run with:
    uv run pytest ../test_classifier_profile.py \
        -m profiling --override-ini="addopts=" -s -v

Corpus: 5 000 documents, 40 correspondents (25 AUTO), 25 doc types (15 AUTO),
50 tags (30 AUTO), 20 storage paths (12 AUTO).

Document content is generated with Faker for realistic base text, with a
per-label fingerprint injected so the MLP has a real learning signal.

Scenarios:
  - train() full corpus — memory and CPU profiles
  - second train() no-op path — shows cost of the skip check
  - save()/load() round-trip — model file size and memory cost
  - _update_data_vectorizer_hash() isolated hash overhead
  - predict_*() four independent calls per document — the 4x redundant
    vectorization path used by the signal handlers
  - _vectorize() cache-miss vs cache-hit breakdown

Memory: tracemalloc (delta + peak + top-20 allocation sites).
CPU: cProfile sorted by cumulative time (top 30).
"""

from __future__ import annotations

import random
import time
from typing import TYPE_CHECKING

import pytest
from django.test import override_settings
from faker import Faker
from profiling import measure_memory
from profiling import profile_cpu

from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag

if TYPE_CHECKING:
    from pathlib import Path

pytestmark = [pytest.mark.profiling, pytest.mark.django_db]

# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------

NUM_DOCS = 5_000
NUM_CORRESPONDENTS = 40  # first 25 are MATCH_AUTO
NUM_DOC_TYPES = 25  # first 15 are MATCH_AUTO
NUM_TAGS = 50  # first 30 are MATCH_AUTO
NUM_STORAGE_PATHS = 20  # first 12 are MATCH_AUTO

NUM_AUTO_CORRESPONDENTS = 25
NUM_AUTO_DOC_TYPES = 15
NUM_AUTO_TAGS = 30
NUM_AUTO_STORAGE_PATHS = 12

SEED = 42


# ---------------------------------------------------------------------------
# Content generation
# ---------------------------------------------------------------------------


def _make_label_fingerprint(
    fake: Faker,
    label_seed: int,
    n_words: int = 6,
) -> list[str]:
    """
    Generate a small set of unique-looking words to use as the learning
    fingerprint for a label. Each label gets its own seeded Faker so the
    fingerprints are distinct and reproducible.
    """
    per_label_fake = Faker()
    per_label_fake.seed_instance(label_seed)
    # Mix word() and last_name() to get varied, pronounceable tokens
    words: list[str] = []
    while len(words) < n_words:
        w = per_label_fake.word().lower()
        if w not in words:
            words.append(w)
    return words


def _build_fingerprints(
    num_correspondents: int,
    num_doc_types: int,
    num_tags: int,
    num_paths: int,
) -> tuple[list[list[str]], list[list[str]], list[list[str]], list[list[str]]]:
    """Pre-generate per-label fingerprints. Expensive once, free to reuse."""
    fake = Faker()
    # Use deterministic seeds offset by type so fingerprints don't collide
    corr_fps = [
        _make_label_fingerprint(fake, 1_000 + i) for i in range(num_correspondents)
    ]
    dtype_fps = [_make_label_fingerprint(fake, 2_000 + i) for i in range(num_doc_types)]
    tag_fps = [_make_label_fingerprint(fake, 3_000 + i) for i in range(num_tags)]
    path_fps = [_make_label_fingerprint(fake, 4_000 + i) for i in range(num_paths)]
    return corr_fps, dtype_fps, tag_fps, path_fps


def _build_content(
    fake: Faker,
    corr_fp: list[str] | None,
    dtype_fp: list[str] | None,
    tag_fps: list[list[str]],
    path_fp: list[str] | None,
) -> str:
    """
    Combine a Faker paragraph (realistic base text) with per-label
    fingerprint words so the classifier has a genuine learning signal.
    """
    # 3-sentence paragraph provides realistic vocabulary
    base = fake.paragraph(nb_sentences=3)

    extras: list[str] = []
    if corr_fp:
        extras.extend(corr_fp)
    if dtype_fp:
        extras.extend(dtype_fp)
    for fp in tag_fps:
        extras.extend(fp)
    if path_fp:
        extras.extend(path_fp)

    if extras:
        return base + " " + " ".join(extras)
    return base


# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Unlock the DB for the whole module (module-scoped)."""
    with django_db_blocker.unblock():
        yield


@pytest.fixture(scope="module")
def classifier_corpus(tmp_path_factory, module_db):
    """
    Build the full 5 000-document corpus once for all profiling tests.

    Label objects are created individually (small number), documents are
    bulk-inserted, and tag M2M rows go through the through-table.

    Yields a dict with the model path and a sample content string for
    prediction tests. All rows are deleted on teardown.
    """
    model_path: Path = tmp_path_factory.mktemp("cls_profile") / "model.pickle"

    with override_settings(MODEL_FILE=model_path):
        fake = Faker()
        Faker.seed(SEED)
        rng = random.Random(SEED)

        # Pre-generate fingerprints for all labels
        print("\n[setup] Generating label fingerprints...")
        corr_fps, dtype_fps, tag_fps, path_fps = _build_fingerprints(
            NUM_CORRESPONDENTS,
            NUM_DOC_TYPES,
            NUM_TAGS,
            NUM_STORAGE_PATHS,
        )

        # -----------------------------------------------------------------
        # 1. Create label objects
        # -----------------------------------------------------------------
        print(f"[setup] Creating {NUM_CORRESPONDENTS} correspondents...")
        correspondents: list[Correspondent] = []
        for i in range(NUM_CORRESPONDENTS):
            algo = (
                MatchingModel.MATCH_AUTO
                if i < NUM_AUTO_CORRESPONDENTS
                else MatchingModel.MATCH_NONE
            )
            correspondents.append(
                Correspondent.objects.create(
                    name=fake.company(),
                    matching_algorithm=algo,
                ),
            )

        print(f"[setup] Creating {NUM_DOC_TYPES} document types...")
        doc_types: list[DocumentType] = []
        for i in range(NUM_DOC_TYPES):
            algo = (
                MatchingModel.MATCH_AUTO
                if i < NUM_AUTO_DOC_TYPES
                else MatchingModel.MATCH_NONE
            )
            doc_types.append(
                DocumentType.objects.create(
                    name=fake.bs()[:64],
                    matching_algorithm=algo,
                ),
            )

        print(f"[setup] Creating {NUM_TAGS} tags...")
        tags: list[Tag] = []
        for i in range(NUM_TAGS):
            algo = (
                MatchingModel.MATCH_AUTO
                if i < NUM_AUTO_TAGS
                else MatchingModel.MATCH_NONE
            )
            tags.append(
                Tag.objects.create(
                    name=f"{fake.word()} {i}",
                    matching_algorithm=algo,
                    is_inbox_tag=False,
                ),
            )

        print(f"[setup] Creating {NUM_STORAGE_PATHS} storage paths...")
        storage_paths: list[StoragePath] = []
        for i in range(NUM_STORAGE_PATHS):
            algo = (
                MatchingModel.MATCH_AUTO
                if i < NUM_AUTO_STORAGE_PATHS
                else MatchingModel.MATCH_NONE
            )
            storage_paths.append(
                StoragePath.objects.create(
                    name=fake.word(),
                    path=f"{fake.word()}/{fake.word()}/{{title}}",
                    matching_algorithm=algo,
                ),
            )

        # -----------------------------------------------------------------
        # 2. Build document rows and M2M assignments
        # -----------------------------------------------------------------
        print(f"[setup] Building {NUM_DOCS} document rows...")
        doc_rows: list[Document] = []
        doc_tag_map: list[tuple[int, int]] = []  # (doc_position, tag_index)

        for i in range(NUM_DOCS):
            corr_idx = (
                rng.randrange(NUM_CORRESPONDENTS) if rng.random() < 0.80 else None
            )
            dt_idx = rng.randrange(NUM_DOC_TYPES) if rng.random() < 0.80 else None
            sp_idx = rng.randrange(NUM_STORAGE_PATHS) if rng.random() < 0.70 else None

            # 1-4 tags; most documents get at least one
            n_tags = rng.randint(1, 4) if rng.random() < 0.85 else 0
            assigned_tag_indices = rng.sample(range(NUM_TAGS), min(n_tags, NUM_TAGS))

            content = _build_content(
                fake,
                corr_fp=corr_fps[corr_idx] if corr_idx is not None else None,
                dtype_fp=dtype_fps[dt_idx] if dt_idx is not None else None,
                tag_fps=[tag_fps[ti] for ti in assigned_tag_indices],
                path_fp=path_fps[sp_idx] if sp_idx is not None else None,
            )

            doc_rows.append(
                Document(
                    title=fake.sentence(nb_words=5),
                    content=content,
                    checksum=f"{i:064x}",
                    correspondent=correspondents[corr_idx]
                    if corr_idx is not None
                    else None,
                    document_type=doc_types[dt_idx] if dt_idx is not None else None,
                    storage_path=storage_paths[sp_idx] if sp_idx is not None else None,
                ),
            )
            for ti in assigned_tag_indices:
                doc_tag_map.append((i, ti))

        t0 = time.perf_counter()
        Document.objects.bulk_create(doc_rows, batch_size=500)
        print(
            f"[setup] bulk_create {NUM_DOCS} documents: {time.perf_counter() - t0:.2f}s",
        )

        # -----------------------------------------------------------------
        # 3. Bulk-create M2M through-table rows
        # -----------------------------------------------------------------
        created_docs = list(Document.objects.order_by("pk"))
        through_rows = [
            Document.tags.through(
                document_id=created_docs[pos].pk,
                tag_id=tags[ti].pk,
            )
            for pos, ti in doc_tag_map
            if pos < len(created_docs)
        ]
        t0 = time.perf_counter()
        Document.tags.through.objects.bulk_create(
            through_rows,
            batch_size=1_000,
            ignore_conflicts=True,
        )
        print(
            f"[setup] bulk_create {len(through_rows)} tag M2M rows: "
            f"{time.perf_counter() - t0:.2f}s",
        )

        # Sample content for prediction tests
        sample_content = _build_content(
            fake,
            corr_fp=corr_fps[0],
            dtype_fp=dtype_fps[0],
            tag_fps=[tag_fps[0], tag_fps[1], tag_fps[5]],
            path_fp=path_fps[0],
        )

        yield {
            "model_path": model_path,
            "sample_content": sample_content,
        }

        # Teardown
        print("\n[teardown] Removing corpus...")
        Document.objects.all().delete()
        Correspondent.objects.all().delete()
        DocumentType.objects.all().delete()
        Tag.objects.all().delete()
        StoragePath.objects.all().delete()


# ---------------------------------------------------------------------------
# Training profiles
# ---------------------------------------------------------------------------


class TestClassifierTrainingProfile:
    """Profile DocumentClassifier.train() on the full corpus."""

    def test_train_memory(self, classifier_corpus, tmp_path):
        """
        Peak memory allocated during train().
        tracemalloc reports the delta and top allocation sites.
        """
        model_path = tmp_path / "model.pickle"
        with override_settings(MODEL_FILE=model_path):
            classifier = DocumentClassifier()

            result, _, _ = measure_memory(
                classifier.train,
                label=(
                    f"train() [{NUM_DOCS} docs | "
                    f"{NUM_CORRESPONDENTS} correspondents ({NUM_AUTO_CORRESPONDENTS} AUTO) | "
                    f"{NUM_DOC_TYPES} doc types ({NUM_AUTO_DOC_TYPES} AUTO) | "
                    f"{NUM_TAGS} tags ({NUM_AUTO_TAGS} AUTO) | "
                    f"{NUM_STORAGE_PATHS} paths ({NUM_AUTO_STORAGE_PATHS} AUTO)]"
                ),
            )
            assert result is True, "train() must return True on first run"

            print("\n  Classifiers trained:")
            print(
                f"    tags_classifier: {classifier.tags_classifier is not None}",
            )
            print(
                f"    correspondent_classifier: {classifier.correspondent_classifier is not None}",
            )
            print(
                f"    document_type_classifier: {classifier.document_type_classifier is not None}",
            )
            print(
                f"    storage_path_classifier: {classifier.storage_path_classifier is not None}",
            )
            if classifier.data_vectorizer is not None:
                vocab_size = len(classifier.data_vectorizer.vocabulary_)
                print(f"    vocabulary size: {vocab_size} terms")

    def test_train_cpu(self, classifier_corpus, tmp_path):
        """
        CPU profile of train() — shows time spent in DB queries,
        CountVectorizer.fit_transform(), and four MLPClassifier.fit() calls.
        """
        model_path = tmp_path / "model_cpu.pickle"
        with override_settings(MODEL_FILE=model_path):
            classifier = DocumentClassifier()
            profile_cpu(
                classifier.train,
                label=f"train() [{NUM_DOCS} docs]",
                top=30,
            )

    def test_train_second_call_noop(self, classifier_corpus, tmp_path):
        """
        No-op path: second train() on unchanged data should return False.
        Still queries the DB to build the hash — shown here as the remaining cost.
        """
        model_path = tmp_path / "model_noop.pickle"
        with override_settings(MODEL_FILE=model_path):
            classifier = DocumentClassifier()

            t0 = time.perf_counter()
            classifier.train()
            first_ms = (time.perf_counter() - t0) * 1000

            result, second_elapsed = profile_cpu(
                classifier.train,
                label="train() second call (no-op — same data unchanged)",
                top=20,
            )
            assert result is False, "second train() should skip and return False"

            print(f"\n  First train:  {first_ms:.1f} ms (full fit)")
            print(f"  Second train: {second_elapsed * 1000:.1f} ms (skip)")
            print(f"  Speedup: {first_ms / (second_elapsed * 1000):.1f}x")

    def test_vectorizer_hash_cost(self, classifier_corpus, tmp_path):
        """
        Isolate _update_data_vectorizer_hash() — pickles the entire
        CountVectorizer just to SHA256 it. Called at both save and load.
        """
        import pickle

        model_path = tmp_path / "model_hash.pickle"
        with override_settings(MODEL_FILE=model_path):
            classifier = DocumentClassifier()
            classifier.train()

            profile_cpu(
                classifier._update_data_vectorizer_hash,
                label="_update_data_vectorizer_hash() [pickle.dumps vectorizer + sha256]",
                top=10,
            )

            pickled_size = len(pickle.dumps(classifier.data_vectorizer))
            vocab_size = len(classifier.data_vectorizer.vocabulary_)
            print(f"\n  Vocabulary size: {vocab_size} terms")
            print(f"  Pickled vectorizer: {pickled_size / 1024:.1f} KiB")

    def test_save_load_roundtrip(self, classifier_corpus, tmp_path):
        """
        Profile save() and load() — model file size directly reflects how
        much memory the classifier occupies on disk (and roughly in RAM).
        """
        model_path = tmp_path / "model_saveload.pickle"
        with override_settings(MODEL_FILE=model_path):
            classifier = DocumentClassifier()
            classifier.train()

            _, save_peak, _ = measure_memory(
                classifier.save,
                label="save() [pickle.dumps + HMAC + atomic rename]",
            )

            file_size_kib = model_path.stat().st_size / 1024
            print(f"\n  Model file size: {file_size_kib:.1f} KiB")

            classifier2 = DocumentClassifier()
            _, load_peak, _ = measure_memory(
                classifier2.load,
                label="load() [read file + verify HMAC + pickle.loads]",
            )

            print("\n  Summary:")
            print(f"    Model file size:  {file_size_kib:.1f} KiB")
            print(f"    Save peak memory: {save_peak:.1f} KiB")
            print(f"    Load peak memory: {load_peak:.1f} KiB")


# ---------------------------------------------------------------------------
# Prediction profiles
# ---------------------------------------------------------------------------


class TestClassifierPredictionProfile:
    """
    Profile the four predict_*() methods — specifically the redundant
    per-call vectorization overhead from the signal handler pattern.
    """

    @pytest.fixture(autouse=True)
    def trained_classifier(self, classifier_corpus, tmp_path):
        model_path = tmp_path / "model_pred.pickle"
        self._ctx = override_settings(MODEL_FILE=model_path)
        self._ctx.enable()
        self.classifier = DocumentClassifier()
        self.classifier.train()
        self.content = classifier_corpus["sample_content"]
        yield
        self._ctx.disable()

    def test_predict_all_four_separately_cpu(self):
        """
        Profile all four predict_*() calls in the order the signal handlers
        fire them. Call 1 is a cache miss; calls 2-4 hit the locmem cache
        but still pay sha256 + pickle.loads each time.
        """
        from django.core.cache import caches
|
||||
|
||||
caches["read-cache"].clear()
|
||||
|
||||
content = self.content
|
||||
print(f"\n Content length: {len(content)} chars")
|
||||
|
||||
calls = [
|
||||
("predict_correspondent", self.classifier.predict_correspondent),
|
||||
("predict_document_type", self.classifier.predict_document_type),
|
||||
("predict_tags", self.classifier.predict_tags),
|
||||
("predict_storage_path", self.classifier.predict_storage_path),
|
||||
]
|
||||
|
||||
timings: list[tuple[str, float]] = []
|
||||
for name, fn in calls:
|
||||
_, elapsed = profile_cpu(
|
||||
lambda f=fn: f(content),
|
||||
label=f"{name}() [call {len(timings) + 1}/4]",
|
||||
top=15,
|
||||
)
|
||||
timings.append((name, elapsed * 1000))
|
||||
|
||||
print("\n Per-call timings (sequential, locmem cache):")
|
||||
for name, ms in timings:
|
||||
print(f" {name:<32s} {ms:8.3f} ms")
|
||||
print(f" {'TOTAL':<32s} {sum(t for _, t in timings):8.3f} ms")
|
||||
|
||||
def test_predict_all_four_memory(self):
|
||||
"""
|
||||
Memory allocated for the full four-prediction sequence, both cold
|
||||
and warm, to show pickle serialization allocation per call.
|
||||
"""
|
||||
from django.core.cache import caches
|
||||
|
||||
content = self.content
|
||||
calls = [
|
||||
self.classifier.predict_correspondent,
|
||||
self.classifier.predict_document_type,
|
||||
self.classifier.predict_tags,
|
||||
self.classifier.predict_storage_path,
|
||||
]
|
||||
|
||||
caches["read-cache"].clear()
|
||||
measure_memory(
|
||||
lambda: [fn(content) for fn in calls],
|
||||
label="all four predict_*() [cache COLD — first call misses]",
|
||||
)
|
||||
|
||||
measure_memory(
|
||||
lambda: [fn(content) for fn in calls],
|
||||
label="all four predict_*() [cache WARM — all calls hit]",
|
||||
)
|
||||
|
||||
def test_vectorize_cache_miss_vs_hit(self):
|
||||
"""
|
||||
Isolate the cost of a cache miss (sha256 + transform + pickle.dumps)
|
||||
vs a cache hit (sha256 + pickle.loads).
|
||||
"""
|
||||
from django.core.cache import caches
|
||||
|
||||
read_cache = caches["read-cache"]
|
||||
content = self.content
|
||||
|
||||
read_cache.clear()
|
||||
_, miss_elapsed = profile_cpu(
|
||||
lambda: self.classifier._vectorize(content),
|
||||
label="_vectorize() [MISS: sha256 + transform + pickle.dumps]",
|
||||
top=15,
|
||||
)
|
||||
|
||||
_, hit_elapsed = profile_cpu(
|
||||
lambda: self.classifier._vectorize(content),
|
||||
label="_vectorize() [HIT: sha256 + pickle.loads]",
|
||||
top=15,
|
||||
)
|
||||
|
||||
print(f"\n Cache miss: {miss_elapsed * 1000:.3f} ms")
|
||||
print(f" Cache hit: {hit_elapsed * 1000:.3f} ms")
|
||||
print(f" Hit is {miss_elapsed / hit_elapsed:.1f}x faster than miss")
|
||||
|
||||
def test_content_hash_overhead(self):
|
||||
"""
|
||||
Micro-benchmark the sha256 of the content string — paid on every
|
||||
_vectorize() call regardless of cache state, including x4 per doc.
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
content = self.content
|
||||
encoded = content.encode()
|
||||
runs = 5_000
|
||||
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(runs):
|
||||
hashlib.sha256(encoded).hexdigest()
|
||||
us_per_call = (time.perf_counter() - t0) / runs * 1_000_000
|
||||
|
||||
print(f"\n Content: {len(content)} chars / {len(encoded)} bytes")
|
||||
print(f" sha256 cost per call: {us_per_call:.2f} us (avg over {runs} runs)")
|
||||
print(f" x4 calls per document: {us_per_call * 4:.2f} us total overhead")
|
||||
293
test_doclist_profile.py
Normal file
@@ -0,0 +1,293 @@
"""
Document list API profiling — no search, pure ORM path.

Run with:
    uv run pytest ../test_doclist_profile.py \
        -m profiling --override-ini="addopts=" -s -v

Corpus: 5 000 documents, 30 correspondents, 20 doc types, 80 tags,
~500 notes (10 %), 10 custom fields with instances on ~50 % of docs.

Scenarios
---------
TestDocListProfile
    - test_list_default_ordering      GET /api/documents/ created desc, page 1, page_size=25
    - test_list_title_ordering        same with ordering=title
    - test_list_page_size_comparison  page_size=10 / 25 / 100 in sequence
    - test_list_detail_fields         GET /api/documents/{id}/ — single document serializer cost
    - test_list_cpu_profile           cProfile of one list request

TestSelectionDataProfile
    - test_selection_data_unfiltered  _get_selection_data_for_queryset(all docs) in isolation
    - test_selection_data_via_api     GET /api/documents/?include_selection_data=true
    - test_selection_data_filtered    filtered vs unfiltered COUNT query comparison
"""

from __future__ import annotations

import datetime
import random
import time

import pytest
from django.contrib.auth.models import User
from faker import Faker
from profiling import profile_block
from profiling import profile_cpu
from rest_framework.test import APIClient

from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import Note
from documents.models import Tag
from documents.views import DocumentViewSet

pytestmark = [pytest.mark.profiling, pytest.mark.django_db]

# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------

NUM_DOCS = 5_000
NUM_CORRESPONDENTS = 30
NUM_DOC_TYPES = 20
NUM_TAGS = 80
NOTE_FRACTION = 0.10
CUSTOM_FIELD_COUNT = 10
CUSTOM_FIELD_FRACTION = 0.50
PAGE_SIZE = 25
SEED = 42


# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Unlock the DB for the whole module (module-scoped)."""
    with django_db_blocker.unblock():
        yield


@pytest.fixture(scope="module")
def doclist_corpus(module_db):
    """
    Build a 5 000-document corpus with tags, notes, custom fields, correspondents,
    and doc types. All objects are deleted on teardown.
    """
    fake = Faker()
    Faker.seed(SEED)
    rng = random.Random(SEED)

    print(f"\n[setup] Creating {NUM_CORRESPONDENTS} correspondents...")  # noqa: T201
    correspondents = [
        Correspondent.objects.create(name=f"dlcorp-{i}-{fake.company()}"[:128])
        for i in range(NUM_CORRESPONDENTS)
    ]

    print(f"[setup] Creating {NUM_DOC_TYPES} doc types...")  # noqa: T201
    doc_types = [
        DocumentType.objects.create(name=f"dltype-{i}-{fake.word()}"[:128])
        for i in range(NUM_DOC_TYPES)
    ]

    print(f"[setup] Creating {NUM_TAGS} tags...")  # noqa: T201
    tags = [
        Tag.objects.create(name=f"dltag-{i}-{fake.word()}"[:100])
        for i in range(NUM_TAGS)
    ]

    print(f"[setup] Creating {CUSTOM_FIELD_COUNT} custom fields...")  # noqa: T201
    custom_fields = [
        CustomField.objects.create(
            name=f"Field {i}",
            data_type=CustomField.FieldDataType.STRING,
        )
        for i in range(CUSTOM_FIELD_COUNT)
    ]

    note_user = User.objects.create_user(username="doclistnoteuser", password="x")
    owner = User.objects.create_superuser(username="doclistowner", password="admin")

    print(f"[setup] Building {NUM_DOCS} document rows...")  # noqa: T201
    base_date = datetime.date(2018, 1, 1)
    raw_docs = []
    for i in range(NUM_DOCS):
        day_offset = rng.randint(0, 6 * 365)
        raw_docs.append(
            Document(
                title=fake.sentence(nb_words=rng.randint(3, 8)).rstrip("."),
                content="\n\n".join(
                    fake.paragraph(nb_sentences=rng.randint(2, 5))
                    for _ in range(rng.randint(1, 3))
                ),
                checksum=f"DL{i:07d}",
                correspondent=rng.choice(correspondents + [None] * 5),
                document_type=rng.choice(doc_types + [None] * 4),
                created=base_date + datetime.timedelta(days=day_offset),
                owner=owner if rng.random() < 0.8 else None,
            ),
        )
    t0 = time.perf_counter()
    documents = Document.objects.bulk_create(raw_docs)
    print(f"[setup] bulk_create {NUM_DOCS} docs: {time.perf_counter() - t0:.2f}s")  # noqa: T201

    t0 = time.perf_counter()
    for doc in documents:
        k = rng.randint(0, 5)
        if k:
            doc.tags.add(*rng.sample(tags, k))
    print(f"[setup] tag M2M assignments: {time.perf_counter() - t0:.2f}s")  # noqa: T201

    note_docs = rng.sample(documents, int(NUM_DOCS * NOTE_FRACTION))
    Note.objects.bulk_create(
        [
            Note(
                document=doc,
                note=fake.sentence(nb_words=rng.randint(4, 15)),
                user=note_user,
            )
            for doc in note_docs
        ],
    )

    cf_docs = rng.sample(documents, int(NUM_DOCS * CUSTOM_FIELD_FRACTION))
    CustomFieldInstance.objects.bulk_create(
        [
            CustomFieldInstance(
                document=doc,
                field=rng.choice(custom_fields),
                value_text=fake.word(),
            )
            for doc in cf_docs
        ],
    )

    first_doc_pk = documents[0].pk

    yield {"owner": owner, "first_doc_pk": first_doc_pk, "tags": tags}

    print("\n[teardown] Removing doclist corpus...")  # noqa: T201
    Document.objects.all().delete()
    Correspondent.objects.all().delete()
    DocumentType.objects.all().delete()
    Tag.objects.all().delete()
    CustomField.objects.all().delete()
    User.objects.filter(username__in=["doclistnoteuser", "doclistowner"]).delete()


# ---------------------------------------------------------------------------
# TestDocListProfile
# ---------------------------------------------------------------------------


class TestDocListProfile:
    """Profile GET /api/documents/ — pure ORM path, no Tantivy."""

    @pytest.fixture(autouse=True)
    def _client(self, doclist_corpus):
        owner = doclist_corpus["owner"]
        self.client = APIClient()
        self.client.force_authenticate(user=owner)
        self.first_doc_pk = doclist_corpus["first_doc_pk"]

    def test_list_default_ordering(self):
        """GET /api/documents/ default ordering (-created), page 1, page_size=25."""
        with profile_block(
            f"GET /api/documents/ default ordering [page_size={PAGE_SIZE}]",
        ):
            response = self.client.get(
                f"/api/documents/?page=1&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200

    def test_list_title_ordering(self):
        """GET /api/documents/ ordered by title — tests ORM sort path."""
        with profile_block(
            f"GET /api/documents/?ordering=title [page_size={PAGE_SIZE}]",
        ):
            response = self.client.get(
                f"/api/documents/?ordering=title&page=1&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200

    def test_list_page_size_comparison(self):
        """Compare serializer cost at page_size=10, 25, 100."""
        for page_size in [10, 25, 100]:
            with profile_block(f"GET /api/documents/ [page_size={page_size}]"):
                response = self.client.get(
                    f"/api/documents/?page=1&page_size={page_size}",
                )
            assert response.status_code == 200

    def test_list_detail_fields(self):
        """GET /api/documents/{id}/ — per-doc serializer cost with all relations."""
        pk = self.first_doc_pk
        with profile_block(f"GET /api/documents/{pk}/ — single doc serializer"):
            response = self.client.get(f"/api/documents/{pk}/")
        assert response.status_code == 200

    def test_list_cpu_profile(self):
        """cProfile of one list request — surfaces hot frames in serializer."""
        profile_cpu(
            lambda: self.client.get(
                f"/api/documents/?page=1&page_size={PAGE_SIZE}",
            ),
            label=f"GET /api/documents/ cProfile [page_size={PAGE_SIZE}]",
            top=30,
        )


# ---------------------------------------------------------------------------
# TestSelectionDataProfile
# ---------------------------------------------------------------------------


class TestSelectionDataProfile:
    """Profile _get_selection_data_for_queryset — the 5+ COUNT queries per request."""

    @pytest.fixture(autouse=True)
    def _setup(self, doclist_corpus):
        owner = doclist_corpus["owner"]
        self.client = APIClient()
        self.client.force_authenticate(user=owner)
        self.tags = doclist_corpus["tags"]

    def test_selection_data_unfiltered(self):
        """Call _get_selection_data_for_queryset(all docs) directly — COUNT queries in isolation."""
        viewset = DocumentViewSet()
        qs = Document.objects.all()

        with profile_block("_get_selection_data_for_queryset(all docs) — direct call"):
            viewset._get_selection_data_for_queryset(qs)

    def test_selection_data_via_api(self):
        """Full API round-trip with include_selection_data=true."""
        with profile_block(
            f"GET /api/documents/?include_selection_data=true [page_size={PAGE_SIZE}]",
        ):
            response = self.client.get(
                f"/api/documents/?page=1&page_size={PAGE_SIZE}&include_selection_data=true",
            )
        assert response.status_code == 200
        assert "selection_data" in response.data

    def test_selection_data_filtered(self):
        """selection_data on a tag-filtered queryset — filtered COUNT vs unfiltered."""
        tag = self.tags[0]
        viewset = DocumentViewSet()
        filtered_qs = Document.objects.filter(tags=tag)
        unfiltered_qs = Document.objects.all()

        print(f"\n Tag '{tag.name}' matches {filtered_qs.count()} docs")  # noqa: T201

        with profile_block("_get_selection_data_for_queryset(unfiltered)"):
            viewset._get_selection_data_for_queryset(unfiltered_qs)

        with profile_block("_get_selection_data_for_queryset(filtered by tag)"):
            viewset._get_selection_data_for_queryset(filtered_qs)
284
test_matching_profile.py
Normal file
@@ -0,0 +1,284 @@
"""
Matching pipeline profiling.

Run with:
    uv run pytest ../test_matching_profile.py \
        -m profiling --override-ini="addopts=" -s -v

Corpus: 1 document + 50 correspondents, 100 tags, 25 doc types, 20 storage
paths. Labels are spread across all seven matching algorithms
(NONE, ANY, ALL, LITERAL, REGEX, FUZZY, AUTO).

Classifier is passed as None -- MATCH_AUTO models skip prediction gracefully,
which is correct for isolating the ORM query and Python-side evaluation cost.

Scenarios
---------
TestMatchingPipelineProfile
    - test_match_correspondents  50 correspondents, algorithm mix
    - test_match_tags            100 tags
    - test_match_document_types  25 doc types
    - test_match_storage_paths   20 storage paths
    - test_full_match_sequence   all four in order (cumulative consumption cost)
    - test_algorithm_breakdown   each MATCH_* algorithm in isolation
"""

from __future__ import annotations

import random

import pytest
from faker import Faker
from profiling import profile_block

from documents.matching import match_correspondents
from documents.matching import match_document_types
from documents.matching import match_storage_paths
from documents.matching import match_tags
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag

pytestmark = [pytest.mark.profiling, pytest.mark.django_db]

NUM_CORRESPONDENTS = 50
NUM_TAGS = 100
NUM_DOC_TYPES = 25
NUM_STORAGE_PATHS = 20
SEED = 42

# Algorithm distribution across labels (cycles through in order)
_ALGORITHMS = [
    MatchingModel.MATCH_NONE,
    MatchingModel.MATCH_ANY,
    MatchingModel.MATCH_ALL,
    MatchingModel.MATCH_LITERAL,
    MatchingModel.MATCH_REGEX,
    MatchingModel.MATCH_FUZZY,
    MatchingModel.MATCH_AUTO,
]


def _algo(i: int) -> int:
    return _ALGORITHMS[i % len(_ALGORITHMS)]


# ---------------------------------------------------------------------------
# Module-scoped corpus fixture
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Unlock the DB for the whole module (module-scoped)."""
    with django_db_blocker.unblock():
        yield


@pytest.fixture(scope="module")
def matching_corpus(module_db):
    """
    1 document with realistic content + dense matching model sets.
    Classifier=None so MATCH_AUTO models are simply skipped.
    """
    fake = Faker()
    Faker.seed(SEED)
    random.seed(SEED)

    # ---- matching models ---------------------------------------------------
    print(f"\n[setup] Creating {NUM_CORRESPONDENTS} correspondents...")  # noqa: T201
    correspondents = []
    for i in range(NUM_CORRESPONDENTS):
        algo = _algo(i)
        match_text = (
            fake.word()
            if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
            else ""
        )
        if algo == MatchingModel.MATCH_REGEX:
            match_text = r"\b" + fake.word() + r"\b"
        correspondents.append(
            Correspondent.objects.create(
                name=f"mcorp-{i}-{fake.company()}"[:128],
                matching_algorithm=algo,
                match=match_text,
            ),
        )

    print(f"[setup] Creating {NUM_TAGS} tags...")  # noqa: T201
    tags = []
    for i in range(NUM_TAGS):
        algo = _algo(i)
        match_text = (
            fake.word()
            if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
            else ""
        )
        if algo == MatchingModel.MATCH_REGEX:
            match_text = r"\b" + fake.word() + r"\b"
        tags.append(
            Tag.objects.create(
                name=f"mtag-{i}-{fake.word()}"[:100],
                matching_algorithm=algo,
                match=match_text,
            ),
        )

    print(f"[setup] Creating {NUM_DOC_TYPES} doc types...")  # noqa: T201
    doc_types = []
    for i in range(NUM_DOC_TYPES):
        algo = _algo(i)
        match_text = (
            fake.word()
            if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
            else ""
        )
        if algo == MatchingModel.MATCH_REGEX:
            match_text = r"\b" + fake.word() + r"\b"
        doc_types.append(
            DocumentType.objects.create(
                name=f"mtype-{i}-{fake.word()}"[:128],
                matching_algorithm=algo,
                match=match_text,
            ),
        )

    print(f"[setup] Creating {NUM_STORAGE_PATHS} storage paths...")  # noqa: T201
    storage_paths = []
    for i in range(NUM_STORAGE_PATHS):
        algo = _algo(i)
        match_text = (
            fake.word()
            if algo not in (MatchingModel.MATCH_NONE, MatchingModel.MATCH_AUTO)
            else ""
        )
        if algo == MatchingModel.MATCH_REGEX:
            match_text = r"\b" + fake.word() + r"\b"
        storage_paths.append(
            StoragePath.objects.create(
                name=f"mpath-{i}-{fake.word()}",
                path=f"{fake.word()}/{{title}}",
                matching_algorithm=algo,
                match=match_text,
            ),
        )

    # ---- document with diverse content ------------------------------------
    doc = Document.objects.create(
        title="quarterly invoice payment tax financial statement",
        content=" ".join(fake.paragraph(nb_sentences=5) for _ in range(3)),
        checksum="MATCHPROF0001",
    )

    print(f"[setup] Document pk={doc.pk}, content length={len(doc.content)} chars")  # noqa: T201
    print(  # noqa: T201
        f" Correspondents: {NUM_CORRESPONDENTS} "
        f"({sum(1 for c in correspondents if c.matching_algorithm == MatchingModel.MATCH_AUTO)} AUTO)",
    )
    print(  # noqa: T201
        f" Tags: {NUM_TAGS} "
        f"({sum(1 for t in tags if t.matching_algorithm == MatchingModel.MATCH_AUTO)} AUTO)",
    )

    yield {"doc": doc}

    # Teardown
    print("\n[teardown] Removing matching corpus...")  # noqa: T201
    Document.objects.all().delete()
    Correspondent.objects.all().delete()
    Tag.objects.all().delete()
    DocumentType.objects.all().delete()
    StoragePath.objects.all().delete()


# ---------------------------------------------------------------------------
# TestMatchingPipelineProfile
# ---------------------------------------------------------------------------


class TestMatchingPipelineProfile:
    """Profile the matching functions called per document during consumption."""

    @pytest.fixture(autouse=True)
    def _setup(self, matching_corpus):
        self.doc = matching_corpus["doc"]

    def test_match_correspondents(self):
        """50 correspondents, algorithm mix. Query count + time."""
        with profile_block(
            f"match_correspondents() [{NUM_CORRESPONDENTS} correspondents, mixed algorithms]",
        ):
            result = match_correspondents(self.doc, classifier=None)
        print(f" -> {len(result)} matched")  # noqa: T201

    def test_match_tags(self):
        """100 tags -- densest set in real installs."""
        with profile_block(f"match_tags() [{NUM_TAGS} tags, mixed algorithms]"):
            result = match_tags(self.doc, classifier=None)
        print(f" -> {len(result)} matched")  # noqa: T201

    def test_match_document_types(self):
        """25 doc types."""
        with profile_block(
            f"match_document_types() [{NUM_DOC_TYPES} types, mixed algorithms]",
        ):
            result = match_document_types(self.doc, classifier=None)
        print(f" -> {len(result)} matched")  # noqa: T201

    def test_match_storage_paths(self):
        """20 storage paths."""
        with profile_block(
            f"match_storage_paths() [{NUM_STORAGE_PATHS} paths, mixed algorithms]",
        ):
            result = match_storage_paths(self.doc, classifier=None)
        print(f" -> {len(result)} matched")  # noqa: T201

    def test_full_match_sequence(self):
        """All four match_*() calls in order -- cumulative cost per document consumed."""
        with profile_block(
            "full match sequence: correspondents + doc_types + tags + storage_paths",
        ):
            match_correspondents(self.doc, classifier=None)
            match_document_types(self.doc, classifier=None)
            match_tags(self.doc, classifier=None)
            match_storage_paths(self.doc, classifier=None)

    def test_algorithm_breakdown(self):
        """Create one correspondent per algorithm and time each independently."""
        import time

        from documents.matching import matches

        fake = Faker()
        algo_names = {
            MatchingModel.MATCH_NONE: "MATCH_NONE",
            MatchingModel.MATCH_ANY: "MATCH_ANY",
            MatchingModel.MATCH_ALL: "MATCH_ALL",
            MatchingModel.MATCH_LITERAL: "MATCH_LITERAL",
            MatchingModel.MATCH_REGEX: "MATCH_REGEX",
            MatchingModel.MATCH_FUZZY: "MATCH_FUZZY",
        }
        doc = self.doc
        print()  # noqa: T201

        for algo, name in algo_names.items():
            match_text = fake.word() if algo != MatchingModel.MATCH_NONE else ""
            if algo == MatchingModel.MATCH_REGEX:
                match_text = r"\b" + fake.word() + r"\b"
            model = Correspondent(
                name=f"algo-test-{name}",
                matching_algorithm=algo,
                match=match_text,
            )
            # Time 1000 iterations to get stable microsecond readings
            runs = 1_000
            t0 = time.perf_counter()
            for _ in range(runs):
                matches(model, doc)
            us_per_call = (time.perf_counter() - t0) / runs * 1_000_000
            print(  # noqa: T201
                f" {name:<20s} {us_per_call:8.2f} us/call (match={match_text[:20]!r})",
            )
154
test_sanity_profile.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Sanity checker profiling.
|
||||
|
||||
Run with:
|
||||
uv run pytest ../test_sanity_profile.py \
|
||||
-m profiling --override-ini="addopts=" -s -v
|
||||
|
||||
Corpus: 2 000 documents with stub files (original + archive + thumbnail)
|
||||
created in a temp MEDIA_ROOT.
|
||||
|
||||
Scenarios
|
||||
---------
|
||||
TestSanityCheckerProfile
|
||||
- test_sanity_full_corpus full check_sanity() -- cProfile + tracemalloc
|
||||
- test_sanity_query_pattern profile_block summary: query count + time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from profiling import measure_memory
|
||||
from profiling import profile_block
|
||||
from profiling import profile_cpu
|
||||
|
||||
from documents.models import Document
|
||||
from documents.sanity_checker import check_sanity
|
||||
|
||||
pytestmark = [pytest.mark.profiling, pytest.mark.django_db]
|
||||
|
||||
NUM_DOCS = 2_000
|
||||
SEED = 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-scoped fixture: temp directories + corpus
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def module_db(django_db_setup, django_db_blocker):
|
||||
"""Unlock the DB for the whole module (module-scoped)."""
|
||||
with django_db_blocker.unblock():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sanity_corpus(tmp_path_factory, module_db):
|
||||
"""
|
||||
Build a 2 000-document corpus. For each document create stub files
|
||||
(1-byte placeholders) in ORIGINALS_DIR, ARCHIVE_DIR, and THUMBNAIL_DIR
|
||||
so the sanity checker's file-existence and checksum checks have real targets.
|
||||
"""
|
||||
media = tmp_path_factory.mktemp("sanity_media")
|
||||
originals_dir = media / "documents" / "originals"
|
||||
archive_dir = media / "documents" / "archive"
|
||||
thumb_dir = media / "documents" / "thumbnails"
|
||||
for d in (originals_dir, archive_dir, thumb_dir):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
# Use override_settings as a context manager for the whole fixture lifetime
|
||||
settings_ctx = override_settings(
|
||||
MEDIA_ROOT=media,
|
||||
ORIGINALS_DIR=originals_dir,
|
||||
ARCHIVE_DIR=archive_dir,
|
||||
THUMBNAIL_DIR=thumb_dir,
|
||||
MEDIA_LOCK=media / "media.lock",
|
||||
)
|
||||
settings_ctx.enable()
|
||||
|
||||
print(f"\n[setup] Creating {NUM_DOCS} documents with stub files...") # noqa: T201
|
||||
t0 = time.perf_counter()
|
||||
docs = []
|
||||
for i in range(NUM_DOCS):
|
||||
content = f"document content for doc {i}"
|
||||
checksum = hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
        orig_filename = f"{i:07d}.pdf"
        arch_filename = f"{i:07d}.pdf"

        orig_path = originals_dir / orig_filename
        arch_path = archive_dir / arch_filename

        orig_path.write_bytes(content.encode())
        arch_path.write_bytes(content.encode())

        docs.append(
            Document(
                title=f"Document {i:05d}",
                content=content,
                checksum=checksum,
                archive_checksum=checksum,
                filename=orig_filename,
                archive_filename=arch_filename,
                mime_type="application/pdf",
            ),
        )

    created = Document.objects.bulk_create(docs, batch_size=500)

    # Thumbnails use doc.pk, so create them after bulk_create assigns pks
    for doc in created:
        thumb_path = thumb_dir / f"{doc.pk:07d}.webp"
        thumb_path.write_bytes(b"\x00")  # minimal thumbnail stub

    print(  # noqa: T201
        f"[setup] bulk_create + file creation: {time.perf_counter() - t0:.2f}s",
    )

    yield {"media": media}

    # Teardown
    print("\n[teardown] Removing sanity corpus...")  # noqa: T201
    Document.objects.all().delete()
    settings_ctx.disable()


# ---------------------------------------------------------------------------
# TestSanityCheckerProfile
# ---------------------------------------------------------------------------


class TestSanityCheckerProfile:
    """Profile check_sanity() on a realistic corpus with real files."""

    @pytest.fixture(autouse=True)
    def _setup(self, sanity_corpus):
        self.media = sanity_corpus["media"]

    def test_sanity_full_corpus(self):
        """Full check_sanity() -- cProfile surfaces hot frames, tracemalloc shows peak."""
        _, elapsed = profile_cpu(
            lambda: check_sanity(scheduled=False),
            label=f"check_sanity() [{NUM_DOCS} docs, real files]",
            top=25,
        )

        _, peak_kib, delta_kib = measure_memory(
            lambda: check_sanity(scheduled=False),
            label=f"check_sanity() [{NUM_DOCS} docs] -- memory",
        )

        print("\n  Summary:")  # noqa: T201
        print(f"    Wall time (CPU profile run): {elapsed * 1000:.1f} ms")  # noqa: T201
        print(f"    Peak memory (second run): {peak_kib:.1f} KiB")  # noqa: T201
        print(f"    Memory delta: {delta_kib:+.1f} KiB")  # noqa: T201

    def test_sanity_query_pattern(self):
        """profile_block view: query count + query time + wall time in one summary."""
        with profile_block(f"check_sanity() [{NUM_DOCS} docs] -- query count"):
            check_sanity(scheduled=False)
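The `profile_cpu` and `measure_memory` helpers used above come from the repo-local `profiling` module, which this diff does not include. As a hedged sketch only, inferred from the call sites (name, `label`/`top` parameters, and `(result, elapsed)` return shape are assumptions, not the project's actual implementation), `profile_cpu` could look roughly like:

```python
import cProfile
import io
import pstats
import time


def profile_cpu(fn, label: str = "", top: int = 20):
    """Run fn under cProfile, print the hottest frames, return (result, seconds)."""
    profiler = cProfile.Profile()
    t0 = time.perf_counter()
    profiler.enable()
    result = fn()
    profiler.disable()
    elapsed = time.perf_counter() - t0
    # Render the profile sorted by cumulative time, limited to the top frames
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(top)
    print(f"[cpu] {label}: {elapsed * 1000:.1f} ms")
    print(stream.getvalue())
    return result, elapsed
```

Returning the callable's result alongside the timing lets a test ignore either with `_`, which matches the unpacking pattern in `test_sanity_full_corpus`.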
test_search_profiling.py (new file, +273 lines)
@@ -0,0 +1,273 @@
"""
Search performance profiling tests.

Run explicitly — excluded from the normal test suite:

    uv run pytest -m profiling -s -p no:xdist --override-ini="addopts=" -v

The ``-s`` flag is required to see profile_block() output.
The ``-p no:xdist`` flag disables parallel execution for accurate measurements.

Corpus: 5 000 documents generated deterministically from a fixed Faker seed,
with realistic variety: 30 correspondents, 15 document types, 50 tags, ~500
notes spread across ~10 % of documents.
"""

from __future__ import annotations

import random

import pytest
from django.contrib.auth.models import User
from faker import Faker
from profiling import profile_block
from rest_framework.test import APIClient

from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import Note
from documents.models import Tag
from documents.search import get_backend
from documents.search import reset_backend
from documents.search._backend import SearchMode

pytestmark = [pytest.mark.profiling, pytest.mark.search, pytest.mark.django_db]

# ---------------------------------------------------------------------------
# Corpus parameters
# ---------------------------------------------------------------------------

DOC_COUNT = 5_000
SEED = 42
NUM_CORRESPONDENTS = 30
NUM_DOC_TYPES = 15
NUM_TAGS = 50
NOTE_FRACTION = 0.10  # ~500 documents get a note
PAGE_SIZE = 25


def _build_corpus(rng: random.Random, fake: Faker) -> None:
    """
    Insert the full corpus into the database and index it.

    Uses bulk_create for the Document rows (fast), then handles the M2M tag
    relationships and notes individually. Indexes the full corpus with a
    single backend.rebuild() call.
    """
    import datetime

    # ---- lookup objects -------------------------------------------------
    correspondents = [
        Correspondent.objects.create(name=f"profcorp-{i}-{fake.company()}"[:128])
        for i in range(NUM_CORRESPONDENTS)
    ]
    doc_types = [
        DocumentType.objects.create(name=f"proftype-{i}-{fake.word()}"[:128])
        for i in range(NUM_DOC_TYPES)
    ]
    tags = [
        Tag.objects.create(name=f"proftag-{i}-{fake.word()}"[:100])
        for i in range(NUM_TAGS)
    ]
    note_user = User.objects.create_user(username="profnoteuser", password="x")

    # ---- bulk-create documents ------------------------------------------
    base_date = datetime.date(2018, 1, 1)
    raw_docs = []
    for i in range(DOC_COUNT):
        day_offset = rng.randint(0, 6 * 365)
        created = base_date + datetime.timedelta(days=day_offset)
        raw_docs.append(
            Document(
                title=fake.sentence(nb_words=rng.randint(3, 9)).rstrip("."),
                content="\n\n".join(
                    fake.paragraph(nb_sentences=rng.randint(3, 7))
                    for _ in range(rng.randint(2, 5))
                ),
                checksum=f"PROF{i:07d}",
                correspondent=rng.choice(correspondents + [None] * 8),
                document_type=rng.choice(doc_types + [None] * 4),
                created=created,
            ),
        )
    documents = Document.objects.bulk_create(raw_docs)

    # ---- tags (M2M, post-bulk) ------------------------------------------
    for doc in documents:
        k = rng.randint(0, 5)
        if k:
            doc.tags.add(*rng.sample(tags, k))

    # ---- notes on ~10 % of docs -----------------------------------------
    note_docs = rng.sample(documents, int(DOC_COUNT * NOTE_FRACTION))
    for doc in note_docs:
        Note.objects.create(
            document=doc,
            note=fake.sentence(nb_words=rng.randint(6, 20)),
            user=note_user,
        )

    # ---- build Tantivy index --------------------------------------------
    backend = get_backend()
    qs = Document.objects.select_related(
        "correspondent",
        "document_type",
        "storage_path",
        "owner",
    ).prefetch_related("tags", "notes__user", "custom_fields__field")
    backend.rebuild(qs)


class TestSearchProfiling:
    """
    Performance profiling for the Tantivy search backend and DRF API layer.

    Each test builds a fresh 5 000-document corpus, exercises one hot path,
    and prints profile_block() measurements to stdout. No correctness
    assertions — the goal is to surface hot spots and track regressions.
    """

    @pytest.fixture(autouse=True)
    def _setup(self, tmp_path, settings):
        index_dir = tmp_path / "index"
        index_dir.mkdir()
        settings.INDEX_DIR = index_dir

        reset_backend()
        rng = random.Random(SEED)
        fake = Faker()
        Faker.seed(SEED)

        self.user = User.objects.create_superuser(
            username="profiler",
            password="admin",
        )
        self.client = APIClient()
        self.client.force_authenticate(user=self.user)

        _build_corpus(rng, fake)
        yield
        reset_backend()

    # -- 1. Backend: search_ids relevance ---------------------------------

    def test_profile_search_ids_relevance(self):
        """Profile: search_ids() with relevance ordering across several queries."""
        backend = get_backend()
        queries = [
            "invoice payment",
            "annual report",
            "bank statement",
            "contract agreement",
            "receipt",
        ]
        with profile_block(f"search_ids — relevance ({len(queries)} queries)"):
            for q in queries:
                backend.search_ids(q, user=None)

    # -- 2. Backend: search_ids with Tantivy-native sort ------------------

    def test_profile_search_ids_sorted(self):
        """Profile: search_ids() sorted by a Tantivy fast field (created)."""
        backend = get_backend()
        with profile_block("search_ids — sorted by created (asc + desc)"):
            backend.search_ids(
                "the",
                user=None,
                sort_field="created",
                sort_reverse=False,
            )
            backend.search_ids(
                "the",
                user=None,
                sort_field="created",
                sort_reverse=True,
            )

    # -- 3. Backend: highlight_hits for a page of 25 ----------------------

    def test_profile_highlight_hits(self):
        """Profile: highlight_hits() for a 25-document page."""
        backend = get_backend()
        all_ids = backend.search_ids("report", user=None)
        page_ids = all_ids[:PAGE_SIZE]
        with profile_block(f"highlight_hits — {len(page_ids)} docs"):
            backend.highlight_hits("report", page_ids)

    # -- 4. Backend: autocomplete -----------------------------------------

    def test_profile_autocomplete(self):
        """Profile: autocomplete() with eight common prefixes."""
        backend = get_backend()
        prefixes = ["inv", "pay", "con", "rep", "sta", "acc", "doc", "fin"]
        with profile_block(f"autocomplete — {len(prefixes)} prefixes"):
            for prefix in prefixes:
                backend.autocomplete(prefix, limit=10)

    # -- 5. Backend: simple-mode search (TEXT and TITLE) ------------------

    def test_profile_search_ids_simple_modes(self):
        """Profile: search_ids() in TEXT and TITLE simple-search modes."""
        backend = get_backend()
        queries = ["invoice 2023", "annual report", "bank statement"]
        with profile_block(
            f"search_ids — TEXT + TITLE modes ({len(queries)} queries each)",
        ):
            for q in queries:
                backend.search_ids(q, user=None, search_mode=SearchMode.TEXT)
                backend.search_ids(q, user=None, search_mode=SearchMode.TITLE)

    # -- 6. API: full round-trip, relevance + page 1 ----------------------

    def test_profile_api_relevance_search(self):
        """Profile: full API search round-trip, relevance order, page 1."""
        with profile_block(
            f"API /documents/?query=… relevance (page 1, page_size={PAGE_SIZE})",
        ):
            response = self.client.get(
                f"/api/documents/?query=invoice+payment&page=1&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200

    # -- 7. API: full round-trip, ORM-ordered (title) ---------------------

    def test_profile_api_orm_sorted_search(self):
        """Profile: full API search round-trip with ORM-delegated sort (title)."""
        with profile_block("API /documents/?query=…&ordering=title"):
            response = self.client.get(
                f"/api/documents/?query=report&ordering=title&page=1&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200

    # -- 8. API: full round-trip, score sort ------------------------------

    def test_profile_api_score_sort(self):
        """Profile: full API search with ordering=-score (relevance, preserve order)."""
        with profile_block("API /documents/?query=…&ordering=-score"):
            response = self.client.get(
                f"/api/documents/?query=statement&ordering=-score&page=1&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200

    # -- 9. API: full round-trip, with selection_data ---------------------

    def test_profile_api_with_selection_data(self):
        """Profile: full API search including include_selection_data=true."""
        with profile_block("API /documents/?query=…&include_selection_data=true"):
            response = self.client.get(
                f"/api/documents/?query=contract&page=1&page_size={PAGE_SIZE}"
                "&include_selection_data=true",
            )
        assert response.status_code == 200
        assert "selection_data" in response.data

    # -- 10. API: paginated (page 2) --------------------------------------

    def test_profile_api_page_2(self):
        """Profile: full API search, page 2 — exercises page offset arithmetic."""
        with profile_block(f"API /documents/?query=…&page=2&page_size={PAGE_SIZE}"):
            response = self.client.get(
                f"/api/documents/?query=the&page=2&page_size={PAGE_SIZE}",
            )
        assert response.status_code == 200
test_workflow_profile.py (new file, +231 lines)
@@ -0,0 +1,231 @@
"""
Workflow trigger matching profiling.

Run with:
    uv run pytest ../test_workflow_profile.py \
        -m profiling --override-ini="addopts=" -s -v

Corpus: 500 documents + correspondents + tags + sets of WorkflowTrigger
objects at 5 and 20 count to allow scaling comparisons.

Scenarios
---------
TestWorkflowMatchingProfile
    - test_existing_document_5_workflows    existing_document_matches_workflow x 5 triggers
    - test_existing_document_20_workflows   same x 20 triggers
    - test_workflow_prefilter               prefilter_documents_by_workflowtrigger on 500 docs
    - test_trigger_type_comparison          compare DOCUMENT_ADDED vs DOCUMENT_UPDATED overhead
"""

from __future__ import annotations

import random
import time

import pytest
from faker import Faker
from profiling import profile_block

from documents.matching import existing_document_matches_workflow
from documents.matching import prefilter_documents_by_workflowtrigger
from documents.models import Correspondent
from documents.models import Document
from documents.models import Tag
from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger

pytestmark = [pytest.mark.profiling, pytest.mark.django_db]

NUM_DOCS = 500
NUM_CORRESPONDENTS = 10
NUM_TAGS = 20
SEED = 42


# ---------------------------------------------------------------------------
# Module-scoped fixture
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def module_db(django_db_setup, django_db_blocker):
    """Unlock the DB for the whole module (module-scoped)."""
    with django_db_blocker.unblock():
        yield


@pytest.fixture(scope="module")
def workflow_corpus(module_db):
    """
    500 documents + correspondents + tags + sets of workflow triggers
    at 5 and 20 count to allow scaling comparisons.
    """
    fake = Faker()
    Faker.seed(SEED)
    rng = random.Random(SEED)

    # ---- lookup objects ---------------------------------------------------
    print("\n[setup] Creating lookup objects...")  # noqa: T201
    correspondents = [
        Correspondent.objects.create(name=f"wfcorp-{i}-{fake.company()}"[:128])
        for i in range(NUM_CORRESPONDENTS)
    ]
    tags = [
        Tag.objects.create(name=f"wftag-{i}-{fake.word()}"[:100])
        for i in range(NUM_TAGS)
    ]

    # ---- documents --------------------------------------------------------
    print(f"[setup] Building {NUM_DOCS} documents...")  # noqa: T201
    raw_docs = []
    for i in range(NUM_DOCS):
        raw_docs.append(
            Document(
                title=fake.sentence(nb_words=4).rstrip("."),
                content=fake.paragraph(nb_sentences=3),
                checksum=f"WF{i:07d}",
                correspondent=rng.choice(correspondents + [None] * 3),
            ),
        )
    documents = Document.objects.bulk_create(raw_docs, batch_size=500)
    for doc in documents:
        k = rng.randint(0, 3)
        if k:
            doc.tags.add(*rng.sample(tags, k))

    sample_doc = documents[0]
    print(f"[setup] Sample doc pk={sample_doc.pk}")  # noqa: T201

    # ---- build triggers at scale 5 and 20 --------------------------------
    _wf_counter = [0]

    def _make_triggers(n: int, trigger_type: int) -> list[WorkflowTrigger]:
        triggers = []
        for i in range(n):
            # Alternate between no filter and a correspondent filter
            corr = correspondents[i % NUM_CORRESPONDENTS] if i % 3 == 0 else None
            trigger = WorkflowTrigger.objects.create(
                type=trigger_type,
                filter_has_correspondent=corr,
            )
            action = WorkflowAction.objects.create(
                type=WorkflowAction.WorkflowActionType.ASSIGNMENT,
            )
            idx = _wf_counter[0]
            _wf_counter[0] += 1
            wf = Workflow.objects.create(name=f"wf-profile-{idx}")
            wf.triggers.add(trigger)
            wf.actions.add(action)
            triggers.append(trigger)
        return triggers

    print("[setup] Creating workflow triggers...")  # noqa: T201
    triggers_5 = _make_triggers(5, WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED)
    triggers_20 = _make_triggers(
        20,
        WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED,
    )
    triggers_added = _make_triggers(
        5,
        WorkflowTrigger.WorkflowTriggerType.DOCUMENT_ADDED,
    )

    yield {
        "doc": sample_doc,
        "triggers_5": triggers_5,
        "triggers_20": triggers_20,
        "triggers_added": triggers_added,
    }

    # Teardown
    print("\n[teardown] Removing workflow corpus...")  # noqa: T201
    Workflow.objects.all().delete()
    WorkflowTrigger.objects.all().delete()
    WorkflowAction.objects.all().delete()
    Document.objects.all().delete()
    Correspondent.objects.all().delete()
    Tag.objects.all().delete()


# ---------------------------------------------------------------------------
# TestWorkflowMatchingProfile
# ---------------------------------------------------------------------------


class TestWorkflowMatchingProfile:
    """Profile workflow trigger evaluation per document save."""

    @pytest.fixture(autouse=True)
    def _setup(self, workflow_corpus):
        self.doc = workflow_corpus["doc"]
        self.triggers_5 = workflow_corpus["triggers_5"]
        self.triggers_20 = workflow_corpus["triggers_20"]
        self.triggers_added = workflow_corpus["triggers_added"]

    def test_existing_document_5_workflows(self):
        """existing_document_matches_workflow x 5 DOCUMENT_UPDATED triggers."""
        doc = self.doc
        triggers = self.triggers_5

        with profile_block(
            f"existing_document_matches_workflow [{len(triggers)} triggers]",
        ):
            for trigger in triggers:
                existing_document_matches_workflow(doc, trigger)

    def test_existing_document_20_workflows(self):
        """existing_document_matches_workflow x 20 triggers -- shows linear scaling."""
        doc = self.doc
        triggers = self.triggers_20

        with profile_block(
            f"existing_document_matches_workflow [{len(triggers)} triggers]",
        ):
            for trigger in triggers:
                existing_document_matches_workflow(doc, trigger)

        # Also time each call individually to show per-trigger overhead
        timings = []
        for trigger in triggers:
            t0 = time.perf_counter()
            existing_document_matches_workflow(doc, trigger)
            timings.append((time.perf_counter() - t0) * 1_000_000)
        avg_us = sum(timings) / len(timings)
        print(f"\n  Per-trigger avg: {avg_us:.1f} us (n={len(timings)})")  # noqa: T201

    def test_workflow_prefilter(self):
        """prefilter_documents_by_workflowtrigger on 500 docs -- tag + correspondent filters."""
        qs = Document.objects.all()
        print(f"\n  Corpus: {qs.count()} documents")  # noqa: T201

        for trigger in self.triggers_20[:3]:
            label = (
                f"prefilter_documents_by_workflowtrigger "
                f"[corr={trigger.filter_has_correspondent_id}]"
            )
            with profile_block(label):
                result = prefilter_documents_by_workflowtrigger(qs, trigger)
                # Evaluate the queryset
                count = result.count()
            print(f"    -> {count} docs passed filter")  # noqa: T201

    def test_trigger_type_comparison(self):
        """Compare per-call overhead of DOCUMENT_UPDATED vs DOCUMENT_ADDED."""
        doc = self.doc
        runs = 200

        for label, triggers in [
            ("DOCUMENT_UPDATED", self.triggers_5),
            ("DOCUMENT_ADDED", self.triggers_added),
        ]:
            t0 = time.perf_counter()
            for _ in range(runs):
                for trigger in triggers:
                    existing_document_matches_workflow(doc, trigger)
            total_calls = runs * len(triggers)
            us_per_call = (time.perf_counter() - t0) / total_calls * 1_000_000
            print(  # noqa: T201
                f"  {label:<22s} {us_per_call:.2f} us/call "
                f"({total_calls} calls, {len(triggers)} triggers)",
            )
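The sanity-check tests earlier in this diff also unpack `_, peak_kib, delta_kib` from a `measure_memory` helper in the unshown `profiling` module. A hedged sketch of what such a helper could look like, built on the standard-library `tracemalloc` module (the name and three-tuple return shape are inferred from the call site, not the project's actual code):

```python
import tracemalloc


def measure_memory(fn, label: str = ""):
    """Hypothetical sketch: run fn under tracemalloc, return (result, peak KiB, delta KiB)."""
    tracemalloc.start()
    before, _ = tracemalloc.get_traced_memory()
    result = fn()
    # get_traced_memory() returns (current, peak) in bytes since start()
    after, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    peak_kib = peak / 1024
    delta_kib = (after - before) / 1024
    if label:
        print(f"[memory] {label}: peak={peak_kib:.1f} KiB, delta={delta_kib:+.1f} KiB")
    return result, peak_kib, delta_kib
```

Peak is the transient high-water mark during the call; delta is the net allocation still held afterwards, which is why the sanity test prints both separately.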