Compare commits

..

3 Commits

Author SHA1 Message Date
stumpylog 3a891b38a8 Drops the search shims 2026-06-15 19:22:59 -07:00
Trenton H f4fa916579 Fix (beta): restore v2 (Whoosh) advanced-search query compatibility (#13010)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-15 15:32:44 -07:00
shamoon 75f0c4c92e Fix (beta): retry celery ping and report warning on no response (#13012) 2026-06-15 15:05:43 -07:00
21 changed files with 1782 additions and 1287 deletions
@@ -131,7 +131,9 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
}
</button>
<ng-template #celeryStatus>
+4
View File
@@ -8,11 +8,15 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+163
View File
@@ -0,0 +1,163 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+15 -436
View File
@@ -1,88 +1,28 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
import logging
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -117,375 +57,6 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
)
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
)
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
)
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC — matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
)
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (year range rewrite timed out)")
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Args:
query: Raw user query string
tz: Timezone for converting local date boundaries to UTC
Returns:
Query with date syntax rewritten to ISO 8601 ranges
Note:
Bare keywords without field prefixes pass through unchanged.
"""
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
)
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar → tag:foo AND tag:bar
- multiple spaces → single spaces
Args:
query: Query string after date rewriting
Returns:
Normalized query string ready for Tantivy parsing
"""
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (normalization timed out)")
def build_permission_filter(
schema: tantivy.Schema,
user: AbstractBaseUser,
@@ -603,8 +174,16 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
exact = index.parse_query(
query_str,
+566
View File
@@ -0,0 +1,566 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
+12
View File
@@ -1,11 +1,15 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -31,3 +35,11 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+135 -56
View File
@@ -11,16 +11,15 @@ import pytest
import tantivy
import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
from documents.search._translate import translate_query
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -57,7 +56,7 @@ class TestCreatedDateField:
)
@time_machine.travel(datetime(2026, 3, 28, 15, 30, tzinfo=UTC), tick=False)
def test_today(self, tz: tzinfo, expected_lo: str, expected_hi: str) -> None:
lo, hi = _range(rewrite_natural_date_keywords("created:today", tz), "created")
lo, hi = _range(translate_query("created:today", tz), "created")
assert lo == expected_lo
assert hi == expected_hi
@@ -65,7 +64,7 @@ class TestCreatedDateField:
def test_today_auckland_ahead_of_utc(self) -> None:
# UTC 03:00 -> Auckland (UTC+13) = 16:00 same date; local date = 2026-03-28
lo, _ = _range(
rewrite_natural_date_keywords("created:today", AUCKLAND),
translate_query("created:today", AUCKLAND),
"created",
)
assert lo == "2026-03-28T00:00:00Z"
@@ -127,7 +126,7 @@ class TestCreatedDateField:
) -> None:
# 2026-03-28 is Saturday; Mon-Sun week calculation built into expectations
query = f"{field}:{keyword}"
lo, hi = _range(rewrite_natural_date_keywords(query, UTC), field)
lo, hi = _range(translate_query(query, UTC), field)
assert lo == expected_lo
assert hi == expected_hi
@@ -135,7 +134,7 @@ class TestCreatedDateField:
def test_this_month_december_wraps_to_next_year(self) -> None:
# December: next month must roll over to January 1 of next year
lo, hi = _range(
rewrite_natural_date_keywords("created:this month", UTC),
translate_query("created:this month", UTC),
"created",
)
assert lo == "2026-12-01T00:00:00Z"
@@ -145,7 +144,7 @@ class TestCreatedDateField:
def test_last_month_january_wraps_to_previous_year(self) -> None:
# January: last month must roll back to December 1 of previous year
lo, hi = _range(
rewrite_natural_date_keywords("created:previous month", UTC),
translate_query("created:previous month", UTC),
"created",
)
assert lo == "2025-12-01T00:00:00Z"
@@ -154,7 +153,7 @@ class TestCreatedDateField:
@time_machine.travel(datetime(2026, 7, 15, 12, 0, tzinfo=UTC), tick=False)
def test_previous_quarter(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords('created:"previous quarter"', UTC),
translate_query('created:"previous quarter"', UTC),
"created",
)
assert lo == "2026-04-01T00:00:00Z"
@@ -174,7 +173,7 @@ class TestDateTimeFields:
@time_machine.travel(datetime(2026, 3, 28, 15, 30, tzinfo=UTC), tick=False)
def test_added_today_eastern(self) -> None:
# EDT = UTC-4; local midnight 2026-03-28 00:00 EDT = 2026-03-28 04:00 UTC
lo, hi = _range(rewrite_natural_date_keywords("added:today", EASTERN), "added")
lo, hi = _range(translate_query("added:today", EASTERN), "added")
assert lo == "2026-03-28T04:00:00Z"
assert hi == "2026-03-29T04:00:00Z"
@@ -182,14 +181,14 @@ class TestDateTimeFields:
def test_added_today_auckland_midnight_crossing(self) -> None:
# UTC 02:00 on 2026-03-29 -> Auckland (UTC+13) = 2026-03-29 15:00 local
# Auckland midnight = UTC 2026-03-28 11:00
lo, hi = _range(rewrite_natural_date_keywords("added:today", AUCKLAND), "added")
lo, hi = _range(translate_query("added:today", AUCKLAND), "added")
assert lo == "2026-03-28T11:00:00Z"
assert hi == "2026-03-29T11:00:00Z"
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_modified_today_utc(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("modified:today", UTC),
translate_query("modified:today", UTC),
"modified",
)
assert lo == "2026-03-28T00:00:00Z"
@@ -244,14 +243,14 @@ class TestDateTimeFields:
expected_hi: str,
) -> None:
# 2026-03-28 is Saturday; weekday()==5 so Monday=2026-03-23
lo, hi = _range(rewrite_natural_date_keywords(f"added:{keyword}", UTC), "added")
lo, hi = _range(translate_query(f"added:{keyword}", UTC), "added")
assert lo == expected_lo
assert hi == expected_hi
@time_machine.travel(datetime(2026, 12, 15, 12, 0, tzinfo=UTC), tick=False)
def test_this_month_december_wraps_to_next_year(self) -> None:
# December: next month wraps to January of next year
lo, hi = _range(rewrite_natural_date_keywords("added:this month", UTC), "added")
lo, hi = _range(translate_query("added:this month", UTC), "added")
assert lo == "2026-12-01T00:00:00Z"
assert hi == "2027-01-01T00:00:00Z"
@@ -259,7 +258,7 @@ class TestDateTimeFields:
def test_last_month_january_wraps_to_previous_year(self) -> None:
# January: last month wraps back to December of previous year
lo, hi = _range(
rewrite_natural_date_keywords("added:previous month", UTC),
translate_query("added:previous month", UTC),
"added",
)
assert lo == "2025-12-01T00:00:00Z"
@@ -295,7 +294,7 @@ class TestDateTimeFields:
expected_lo: str,
expected_hi: str,
) -> None:
lo, hi = _range(rewrite_natural_date_keywords(query, UTC), "added")
lo, hi = _range(translate_query(query, UTC), "added")
assert lo == expected_lo
assert hi == expected_hi
@@ -309,20 +308,20 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_compact_date_shim_rewrites_to_iso(self) -> None:
result = rewrite_natural_date_keywords("created:20240115120000", UTC)
result = translate_query("created:20240115120000", UTC)
assert "2024-01-15" in result
assert "20240115120000" not in result
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_relative_range_shim_removes_now(self) -> None:
result = rewrite_natural_date_keywords("added:[now-7d TO now]", UTC)
result = translate_query("added:[now-7d TO now]", UTC)
assert "now" not in result
assert "2026-03-" in result
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_7_days(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-7 days to now]", UTC),
translate_query("added:[-7 days to now]", UTC),
"added",
)
assert lo == "2026-03-21T12:00:00Z"
@@ -331,7 +330,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_1_week(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-1 week to now]", UTC),
translate_query("added:[-1 week to now]", UTC),
"added",
)
assert lo == "2026-03-21T12:00:00Z"
@@ -341,7 +340,7 @@ class TestWhooshQueryRewriting:
def test_bracket_minus_1_month_uses_relativedelta(self) -> None:
# relativedelta(months=1) from 2026-03-28 = 2026-02-28 (not 29)
lo, hi = _range(
rewrite_natural_date_keywords("created:[-1 month to now]", UTC),
translate_query("created:[-1 month to now]", UTC),
"created",
)
assert lo == "2026-02-28T12:00:00Z"
@@ -350,7 +349,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_1_year(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("modified:[-1 year to now]", UTC),
translate_query("modified:[-1 year to now]", UTC),
"modified",
)
assert lo == "2025-03-28T12:00:00Z"
@@ -359,7 +358,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_plural_unit_hours(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-3 hours to now]", UTC),
translate_query("added:[-3 hours to now]", UTC),
"added",
)
assert lo == "2026-03-28T09:00:00Z"
@@ -367,7 +366,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_case_insensitive(self) -> None:
result = rewrite_natural_date_keywords("added:[-1 WEEK TO NOW]", UTC)
result = translate_query("added:[-1 WEEK TO NOW]", UTC)
assert "now" not in result.lower()
lo, hi = _range(result, "added")
assert lo == "2026-03-21T12:00:00Z"
@@ -377,7 +376,7 @@ class TestWhooshQueryRewriting:
def test_relative_range_swaps_bounds_when_lo_exceeds_hi(self) -> None:
# [now+1h TO now-1h] has lo > hi before substitution; they must be swapped
lo, hi = _range(
rewrite_natural_date_keywords("added:[now+1h TO now-1h]", UTC),
translate_query("added:[now+1h TO now-1h]", UTC),
"added",
)
assert lo == "2026-03-28T11:00:00Z"
@@ -385,14 +384,14 @@ class TestWhooshQueryRewriting:
def test_8digit_created_date_field_always_uses_utc_midnight(self) -> None:
# created is a DateField: boundaries are always UTC midnight, no TZ offset
result = rewrite_natural_date_keywords("created:20231201", EASTERN)
result = translate_query("created:20231201", EASTERN)
lo, hi = _range(result, "created")
assert lo == "2023-12-01T00:00:00Z"
assert hi == "2023-12-02T00:00:00Z"
def test_8digit_added_datetime_field_converts_local_midnight_to_utc(self) -> None:
# added is DateTimeField: midnight Dec 1 Eastern (EST = UTC-5) = 05:00 UTC
result = rewrite_natural_date_keywords("added:20231201", EASTERN)
result = translate_query("added:20231201", EASTERN)
lo, hi = _range(result, "added")
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
@@ -400,17 +399,19 @@ class TestWhooshQueryRewriting:
def test_8digit_modified_datetime_field_converts_local_midnight_to_utc(
self,
) -> None:
result = rewrite_natural_date_keywords("modified:20231201", EASTERN)
result = translate_query("modified:20231201", EASTERN)
lo, hi = _range(result, "modified")
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_query("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
class TestParseUserQuery:
@@ -463,6 +464,67 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -514,7 +576,7 @@ class TestYearRangeRewriting:
expected_lo: str,
expected_hi: str,
) -> None:
result = rewrite_natural_date_keywords(query, UTC)
result = translate_query(query, UTC)
lo, hi = _range(result, field)
assert lo == expected_lo
assert hi == expected_hi
@@ -522,14 +584,14 @@ class TestYearRangeRewriting:
def test_reversed_year_range_is_swapped(self) -> None:
# A reversed range must not yield lo > hi, which Tantivy treats as an
# empty range (silently zero results). The bounds are swapped instead.
result = rewrite_natural_date_keywords("created:[2025 TO 2020]", UTC)
result = translate_query("created:[2025 TO 2020]", UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2026-01-01T00:00:00Z"
def test_year_range_in_complex_boolean_query(self) -> None:
query = "tag:steuer AND (title:2020 OR (NOT title:2019 AND NOT title:2018 AND created:[2020 TO 2020]))"
result = rewrite_natural_date_keywords(query, UTC)
result = translate_query(query, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
@@ -539,14 +601,19 @@ class TestYearRangeRewriting:
def test_already_iso_date_range_passes_through_unchanged(self) -> None:
original = "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
assert rewrite_natural_date_keywords(original, UTC) == original
assert translate_query(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
result = translate_query(original, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
class TestNonDateFieldsNotRewritten:
@@ -566,7 +633,7 @@ class TestNonDateFieldsNotRewritten:
],
)
def test_8digit_on_integer_field_passes_through_unchanged(self, query: str) -> None:
assert rewrite_natural_date_keywords(query, EASTERN) == query
assert translate_query(query, EASTERN) == query
@pytest.mark.parametrize(
"query",
@@ -580,12 +647,12 @@ class TestNonDateFieldsNotRewritten:
self,
query: str,
) -> None:
assert rewrite_natural_date_keywords(query, UTC) == query
assert translate_query(query, UTC) == query
def test_unknown_field_keyword_passes_through_unchanged(self) -> None:
# foobar is not a date field: 'foobar:today' must not become a date range,
# which Tantivy would otherwise reject as an unknown/typed field.
assert rewrite_natural_date_keywords("foobar:today", UTC) == "foobar:today"
assert translate_query("foobar:today", UTC) == "foobar:today"
class TestPassthrough:
@@ -593,27 +660,39 @@ class TestPassthrough:
def test_bare_keyword_no_field_prefix_unchanged(self) -> None:
# Bare 'today' with no field: prefix passes through unchanged
result = rewrite_natural_date_keywords("bank statement today", UTC)
result = translate_query("bank statement today", UTC)
assert "today" in result
def test_unrelated_query_unchanged(self) -> None:
assert rewrite_natural_date_keywords("title:invoice", UTC) == "title:invoice"
assert translate_query("title:invoice", UTC) == "title:invoice"
class TestNormalizeQuery:
"""normalize_query expands comma-separated values and collapses whitespace."""
"""translate_query expands comma-separated values and collapses whitespace."""
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
assert translate_query("tag:foo,bar", UTC) == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
assert (
translate_query("tag:foo,bar,baz", UTC) == "tag:foo AND tag:bar AND tag:baz"
)
def test_normalize_collapses_whitespace(self) -> None:
assert normalize_query("bank statement") == "bank statement"
assert translate_query("bank statement", UTC) == "bank statement"
def test_normalize_no_commas_unchanged(self) -> None:
assert normalize_query("bank statement") == "bank statement"
assert translate_query("bank statement", UTC) == "bank statement"
@pytest.mark.parametrize(
("raw", "expected"),
@@ -656,7 +735,7 @@ class TestNormalizeQuery:
],
)
def test_normalize_strips_dangling_operators(self, raw: str, expected: str) -> None:
assert normalize_query(raw) == expected
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"query",
@@ -668,7 +747,7 @@ class TestNormalizeQuery:
],
)
def test_normalize_preserves_valid_operators(self, query: str) -> None:
assert normalize_query(query) == query
assert translate_query(query, UTC) == query
class TestParseSimpleTextHighlightQuery:
@@ -0,0 +1,742 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
+6 -3
View File
@@ -725,9 +725,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with invalid added date
- Query with an invalid added date
THEN:
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
"""
d1 = Document.objects.create(
title="invoice",
@@ -740,8 +742,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# Tantivy rejects unparsable field queries with a 400
# An unparsable date is reported as a malformed query, not silently empty.
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
+71
View File
@@ -216,6 +216,77 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
-3
View File
@@ -368,7 +368,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
None,
hints=None,
)
@patch("documents.views.get_ai_document_classification")
@@ -400,7 +399,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
"de-de",
hints=None,
)
self.assertEqual(
get_llm_suggestion_cache(
@@ -440,7 +438,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
"fr-fr",
hints=None,
)
self.assertEqual(
get_llm_suggestion_cache(
+30 -13
View File
@@ -12,6 +12,7 @@ from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -245,7 +246,6 @@ from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -1495,14 +1495,11 @@ class DocumentViewSet(
refresh_suggestions_cache(doc.pk)
return Response(cached_llm_suggestions.suggestions)
hints = get_taxonomy_hints_for_document(doc, request.user)
try:
llm_suggestions = get_ai_document_classification(
doc,
request.user,
output_language,
hints=hints,
)
except ValueError as exc:
logger.exception(
@@ -1517,22 +1514,18 @@ class DocumentViewSet(
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
request.user,
hinted_names=set(hints["tags"]) if hints else None,
)
matched_correspondents = match_correspondents_by_name(
llm_suggestions.get("correspondents", []),
request.user,
hinted_names=set(hints["correspondents"]) if hints else None,
)
matched_types = match_document_types_by_name(
llm_suggestions.get("document_types", []),
request.user,
hinted_names=set(hints["document_types"]) if hints else None,
)
matched_paths = match_storage_paths_by_name(
llm_suggestions.get("storage_paths", []),
request.user,
hinted_names=set(hints["storage_paths"]) if hints else None,
)
resp_data = {
@@ -2284,6 +2277,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2476,6 +2470,11 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -4998,11 +4997,29 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
except Exception as e:
celery_active = "ERROR"
logger.exception(
+19 -20
View File
@@ -1,21 +1,16 @@
import json
import logging
from typing import TYPE_CHECKING
from django.conf import settings
from django.contrib.auth.models import User
from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
from paperless_ai.indexing import visible_document_ids_for_user
from paperless_ai.taxonomy import format_hints_for_prompt
if TYPE_CHECKING:
from paperless_ai.taxonomy import TaxonomyHints
logger = logging.getLogger("paperless_ai.rag_classifier")
@@ -31,7 +26,6 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
hints: "TaxonomyHints | None" = None,
) -> str:
filename = document.filename or ""
content = truncate_content(
@@ -40,16 +34,10 @@ def build_prompt_without_rag(
context_size=config.llm_context_size,
)
hints_block = format_hints_for_prompt(hints) if hints else ""
# Splice the block (if any) immediately before the "Analyze ..." instruction.
# When there is no block this expands to nothing, so the prompt is identical
# to the pre-hints baseline.
hints_section = f"{hints_block}\n\n " if hints_block else ""
return f"""
You are a document classification assistant.
{hints_section}Analyze the following document and extract the following information:
Analyze the following document and extract the following information:
- A short descriptive title
- Tags that reflect the content
- Names of people or organizations mentioned
@@ -69,9 +57,8 @@ def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
hints: "TaxonomyHints | None" = None,
) -> str:
base_prompt = build_prompt_without_rag(document, config, hints=hints)
base_prompt = build_prompt_without_rag(document, config)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
@@ -109,7 +96,20 @@ def get_context_for_document(
user: User | None = None,
max_docs: int = 5,
) -> str:
visible_document_ids = visible_document_ids_for_user(user)
visible_documents = (
get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
if user
else None
)
visible_document_ids = (
list(visible_documents.values_list("pk", flat=True))
if visible_documents is not None
else None
)
similar_docs = query_similar_documents(
document=doc,
document_ids=visible_document_ids,
@@ -137,14 +137,13 @@ def get_ai_document_classification(
document: Document,
user: User | None = None,
output_language: str | None = None,
hints: "TaxonomyHints | None" = None,
) -> dict:
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, ai_config, user, hints=hints)
build_prompt_with_rag(document, ai_config, user)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document, ai_config, hints=hints)
else build_prompt_without_rag(document, ai_config)
)
client = AIClient()
+5 -46
View File
@@ -5,7 +5,6 @@ from datetime import timedelta
from typing import TYPE_CHECKING
from django.conf import settings
from django.contrib.auth.models import User
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
@@ -13,7 +12,6 @@ from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
from documents.permissions import get_objects_for_user_owner_aware
from documents.utils import IterWrapper
from documents.utils import identity
from paperless.config import AIConfig
@@ -24,7 +22,6 @@ from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core.schema import BaseNode
from llama_index.core.schema import NodeWithScore
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@@ -452,36 +449,12 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
return {str(document_id) for document_id in document_ids}
def visible_document_ids_for_user(user: User | None) -> list[int] | None:
"""Return the pks of documents ``user`` may view, or ``None`` for no filter.
Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used
by both the similarity-context and taxonomy-hints paths to scope RAG
neighbours to documents the requesting user is allowed to see.
"""
if user is None:
return None
visible_documents = get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
return list(visible_documents.values_list("pk", flat=True))
def retrieve_similar_nodes(
def query_similar_documents(
document: Document,
document_ids: Iterable[int | str] | None = None,
top_k: int = 5,
) -> list["NodeWithScore"]:
"""Run ANN retrieval and return the raw NodeWithScore results.
Returns ``[]`` when the allow-list normalizes to empty, or when no index
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
embedding request, so it runs inside ``db_connection_released()`` to avoid
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
and the taxonomy-hints path go through here, so they share that behavior.
"""
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
@@ -521,21 +494,7 @@ def retrieve_similar_nodes(
filters=filters,
)
with db_connection_released():
return retriever.retrieve(query_text)
def query_similar_documents(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
results = retrieve_similar_nodes(
document=document,
document_ids=allowed_document_ids,
top_k=top_k,
)
results = retriever.retrieve(query_text)
retrieved_document_ids: list[int] = []
for node in results:
+11 -38
View File
@@ -15,56 +15,40 @@ MATCH_THRESHOLD = 0.8
logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Tag]:
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
queryset = get_objects_for_user_owner_aware(
user,
["view_tag"],
Tag,
)
return _match_names_to_queryset(names, queryset, "name", hinted_names)
return _match_names_to_queryset(names, queryset, "name")
def match_correspondents_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Correspondent]:
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
queryset = get_objects_for_user_owner_aware(
user,
["view_correspondent"],
Correspondent,
)
return _match_names_to_queryset(names, queryset, "name", hinted_names)
return _match_names_to_queryset(names, queryset, "name")
def match_document_types_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[DocumentType]:
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
queryset = get_objects_for_user_owner_aware(
user,
["view_documenttype"],
DocumentType,
)
return _match_names_to_queryset(names, queryset, "name", hinted_names)
return _match_names_to_queryset(names, queryset, "name")
def match_storage_paths_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[StoragePath]:
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
queryset = get_objects_for_user_owner_aware(
user,
["view_storagepath"],
StoragePath,
)
return _match_names_to_queryset(names, queryset, "name", hinted_names)
return _match_names_to_queryset(names, queryset, "name")
def _normalize(s: str) -> str:
@@ -74,18 +58,10 @@ def _normalize(s: str) -> str:
return s
def _match_names_to_queryset(
names: list[str],
queryset,
attr: str,
hinted_names: set[str] | None = None,
):
def _match_names_to_queryset(names: list[str], queryset, attr: str):
results = []
objects = list(queryset)
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
normalized_hints = (
{_normalize(name) for name in hinted_names} if hinted_names else set()
)
for name in names:
if not name:
@@ -100,11 +76,6 @@ def _match_names_to_queryset(
results.append(matched)
continue
# A hinted name that didn't exact-match came from existing taxonomy
# verbatim; do not fuzzy-map it onto a different object.
if target in normalized_hints:
continue
# Fuzzy match fallback
matches = difflib.get_close_matches(
target,
@@ -117,6 +88,8 @@ def _match_names_to_queryset(
matched = objects.pop(index)
object_names.pop(index)
results.append(matched)
else:
pass
return results
-115
View File
@@ -1,115 +0,0 @@
from typing import TYPE_CHECKING
from typing import TypedDict
from django.contrib.auth.models import User
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.indexing import retrieve_similar_nodes
from paperless_ai.indexing import visible_document_ids_for_user
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
class TaxonomyHints(TypedDict):
tags: list[str]
document_types: list[str]
correspondents: list[str]
storage_paths: list[str]
def build_taxonomy_hints_from_nodes(
nodes: list["NodeWithScore"],
) -> TaxonomyHints:
"""Collect the unique, sorted taxonomy names carried on retrieved nodes.
Reads ``tags`` (a list), ``document_type``, ``correspondent``, and
``storage_path`` from each node's metadata. Empty / ``None`` values and
missing keys are skipped. The result is naturally bounded by the retrieval
``top_k``, so no cap is applied.
"""
tags: set[str] = set()
document_types: set[str] = set()
correspondents: set[str] = set()
storage_paths: set[str] = set()
for node in nodes:
metadata = node.metadata or {}
for tag in metadata.get("tags") or []:
if tag:
tags.add(tag)
document_type = metadata.get("document_type")
if document_type:
document_types.add(document_type)
correspondent = metadata.get("correspondent")
if correspondent:
correspondents.add(correspondent)
storage_path = metadata.get("storage_path")
if storage_path:
storage_paths.add(storage_path)
return TaxonomyHints(
tags=sorted(tags),
document_types=sorted(document_types),
correspondents=sorted(correspondents),
storage_paths=sorted(storage_paths),
)
_HINT_INSTRUCTION = (
"Prefer existing names from these lists verbatim. Only propose a new value "
"if none of the existing names fits."
)
def format_hints_for_prompt(hints: TaxonomyHints) -> str:
"""Render non-empty hint categories as labelled blocks plus one instruction.
Returns "" when every category is empty, so callers can treat the result
the same as no hints at all.
"""
# Literal-key access keeps this TypedDict-safe for mypy; the order here is
# the order the blocks appear in the prompt.
labelled_values: list[tuple[str, list[str]]] = [
("Available tags", hints["tags"]),
("Available document types", hints["document_types"]),
("Available correspondents", hints["correspondents"]),
("Available storage paths", hints["storage_paths"]),
]
blocks: list[str] = []
for label, values in labelled_values:
if values:
listing = "\n".join(f"- {value}" for value in values)
blocks.append(f"{label}:\n{listing}")
if not blocks:
return ""
return "\n\n".join([*blocks, _HINT_INSTRUCTION])
def get_taxonomy_hints_for_document(
document: Document,
user: User | None,
) -> TaxonomyHints | None:
"""Build taxonomy hints from a document's RAG neighbours.
Returns ``None`` when no embedding backend is configured (the gate) so the
caller's prompt and matching are identical to today. Otherwise returns a
``TaxonomyHints`` -- possibly all-empty when no similar documents exist.
Applies the same owner-aware visible-document filter as
``get_context_for_document``.
"""
if not AIConfig().llm_embedding_backend:
return None
nodes = retrieve_similar_nodes(
document=document,
document_ids=visible_document_ids_for_user(user),
)
return build_taxonomy_hints_from_nodes(nodes)
@@ -1,11 +1,8 @@
import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import pytest_mock
from django.test import override_settings
from documents.models import Document
@@ -264,111 +261,3 @@ def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""
class TestPromptHints:
@pytest.fixture
def config(self) -> AIConfig:
# build_prompt_* only read these two numeric settings off config;
# a stand-in avoids constructing a DB-backed AIConfig.
return cast(
"AIConfig",
SimpleNamespace(llm_embedding_chunk_size=1000, llm_context_size=8000),
)
def test_without_rag_includes_hints_block(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
assert "Available tags:" in prompt
assert "- Bloodwork" in prompt
assert "Prefer existing names from these lists verbatim" in prompt
def test_without_rag_none_matches_baseline(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
baseline = build_prompt_without_rag(mock_document, config)
with_none = build_prompt_without_rag(mock_document, config, hints=None)
assert with_none == baseline
assert "Available tags:" not in with_none
def test_with_rag_includes_context_and_hints(
self,
mock_document: MagicMock,
config: AIConfig,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="TITLE: Neighbour\nsome context",
)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
assert "Additional context from similar documents" in prompt
assert "Available tags:" in prompt
def test_classification_forwards_hints(
self,
mock_document: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.AIConfig",
return_value=SimpleNamespace(
llm_embedding_backend=None,
llm_embedding_chunk_size=1000,
llm_context_size=8000,
),
)
build = mocker.patch(
"paperless_ai.ai_classifier.build_prompt_without_rag",
return_value="PROMPT",
)
mock_client = MagicMock()
mock_client.run_llm_query.return_value = {
"title": "t",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = get_ai_document_classification(
mock_document,
user=None,
hints=hints,
)
_, build_kwargs = build.call_args
assert build_kwargs["hints"] == hints
assert set(result.keys()) == {
"title",
"tags",
"correspondents",
"document_types",
"storage_paths",
"dates",
}
@@ -1,5 +1,4 @@
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -727,58 +726,3 @@ class TestQuerySimilarDocuments:
results = indexing.query_similar_documents(a, document_ids=[b.id])
assert all(doc.id == b.id for doc in results)
class TestRetrieveSimilarNodes:
@pytest.mark.django_db
def test_returns_raw_nodes_from_retriever(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
mocker.patch("paperless_ai.indexing.load_or_build_index")
node1 = SimpleNamespace(metadata={"document_id": "1"})
node2 = SimpleNamespace(metadata={"document_id": "2"})
retriever = mocker.MagicMock()
retriever.retrieve.return_value = [node1, node2]
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=retriever,
)
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
assert result == [node1, node2]
@pytest.mark.django_db
def test_empty_allow_list_fails_closed(
self,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
assert result == []
load.assert_not_called()
@pytest.mark.django_db
def test_queues_update_when_index_missing(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
assert result == []
queue.assert_called_once_with(
rebuild=False,
reason="LLM index not found for similarity query.",
)
-92
View File
@@ -1,15 +1,12 @@
import difflib
from unittest.mock import patch
import pytest
import pytest_mock
from django.test import TestCase
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.factories import TagFactory
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
@@ -90,95 +87,6 @@ class TestAIMatching(TestCase):
self.assertEqual(result[1].name, "Test Tag 2")
class TestHintedMatching:
def test_hinted_verbatim_skips_fuzzy(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork"],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_unhinted_name_still_fuzzy_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
# "Bloodwrok" is a typo not in hints -> fuzzy still maps it to Bloodwork.
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Taxes"},
)
assert [t.name for t in result] == ["Bloodwork"]
def test_hinted_name_with_whitespace_exact_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork "],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_hinted_name_absent_from_queryset_is_skipped_not_fuzzed(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
# A hint with no exact object must not fall through to fuzzy.
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Bloodwrok"},
)
assert result == []
def test_backward_compatible_without_kwarg(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Test Tag 1")],
)
result = match_tags_by_name(["Test Tag 1", "Nonexistent"], user=None)
assert [t.name for t in result] == ["Test Tag 1"]
@pytest.mark.django_db
class TestExtractUnmatchedNamesNormalization:
def test_punctuated_name_already_matched_is_not_returned_as_unmatched(
-220
View File
@@ -1,220 +0,0 @@
from types import SimpleNamespace
import pytest_mock
from documents.tests.factories import DocumentFactory
from paperless_ai.taxonomy import TaxonomyHints
from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes
from paperless_ai.taxonomy import format_hints_for_prompt
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
def make_node(**metadata: object) -> SimpleNamespace:
"""A stand-in for NodeWithScore: only ``.metadata`` is accessed."""
return SimpleNamespace(metadata=metadata)
class TestBuildTaxonomyHintsFromNodes:
def test_returns_all_four_keys(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert set(hints.keys()) == {
"tags",
"document_types",
"correspondents",
"storage_paths",
}
def test_collects_and_sorts_values(self) -> None:
nodes = [
make_node(
tags=["Taxes", "Bloodwork"],
document_type="Invoice",
correspondent="IRS",
storage_path="Financial",
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Bloodwork", "Taxes"]
assert hints["document_types"] == ["Invoice"]
assert hints["correspondents"] == ["IRS"]
assert hints["storage_paths"] == ["Financial"]
def test_deduplicates_across_nodes(self) -> None:
nodes = [
make_node(tags=["Taxes"], document_type="Invoice"),
make_node(tags=["Taxes", "Medical"], document_type="Invoice"),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Medical", "Taxes"]
assert hints["document_types"] == ["Invoice"]
def test_none_values_skipped(self) -> None:
nodes = [
make_node(
tags=["Taxes", None, ""],
document_type=None,
correspondent=None,
storage_path=None,
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Taxes"]
assert hints["document_types"] == []
assert hints["correspondents"] == []
assert hints["storage_paths"] == []
def test_missing_storage_path_key_handled(self) -> None:
# Pre-enrichment nodes have no storage_path key at all.
nodes = [make_node(tags=["Taxes"], document_type="Invoice")]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["storage_paths"] == []
def test_empty_node_list_all_empty(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert hints == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
def test_output_stable_across_calls(self) -> None:
nodes = [make_node(tags=["b", "a", "c"])]
assert build_taxonomy_hints_from_nodes(
nodes,
) == build_taxonomy_hints_from_nodes(nodes)
class TestFormatHintsForPrompt:
def test_all_blocks_present_when_all_categories_nonempty(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": ["IRS"],
"storage_paths": ["Financial"],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" in result
assert "Available correspondents:" in result
assert "Available storage paths:" in result
assert "- Bloodwork" in result
def test_empty_category_produces_no_block(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" not in result
assert "Available correspondents:" not in result
assert "Available storage paths:" not in result
def test_all_empty_produces_empty_string(self) -> None:
hints: TaxonomyHints = {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
assert format_hints_for_prompt(hints) == ""
def test_instruction_line_appears_once(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert result.count("Prefer existing names from these lists verbatim") == 1
class TestGetTaxonomyHintsForDocument:
def test_returns_none_when_embedding_backend_off(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend=None),
)
retrieve = mocker.patch("paperless_ai.taxonomy.retrieve_similar_nodes")
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result is None
retrieve.assert_not_called()
def test_passes_owner_aware_ids_when_user_present(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.visible_document_ids_for_user",
return_value=[1, 2, 3],
)
retrieve = mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
document = DocumentFactory.build()
user = mocker.MagicMock()
get_taxonomy_hints_for_document(document, user=user)
retrieve.assert_called_once_with(
document=document,
document_ids=[1, 2, 3],
)
def test_returns_populated_hints_when_nodes_found(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[make_node(tags=["Taxes"], document_type="Invoice")],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": ["Taxes"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
def test_returns_empty_hints_not_none_when_no_nodes(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
@@ -1,77 +0,0 @@
from types import SimpleNamespace
import pytest
import pytest_mock
from django.contrib.auth.models import User
from rest_framework.test import APIClient
from documents.models import Document
from documents.tests.factories import DocumentFactory
@pytest.mark.django_db
class TestSuggestionsHintWiring:
@pytest.fixture
def document(self) -> Document:
return DocumentFactory() # type: ignore[return-value]
@pytest.fixture
def api_client(self, admin_user: User) -> APIClient:
client = APIClient()
client.force_authenticate(user=admin_user)
return client
def test_hints_passed_to_classifier_and_matchers(
self,
api_client: APIClient,
document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
mocker.patch(
"documents.views.get_taxonomy_hints_for_document",
return_value=hints,
)
mocker.patch(
"documents.views.AIConfig",
return_value=SimpleNamespace(
ai_enabled=True,
llm_backend="ollama",
llm_output_language=None,
),
)
# No cached suggestion -> the view reaches the classifier path.
mocker.patch(
"documents.views.get_llm_suggestion_cache",
return_value=None,
)
mocker.patch("documents.views.set_llm_suggestions_cache")
classify = mocker.patch(
"documents.views.get_ai_document_classification",
return_value={
"title": "Doc",
"tags": ["Bloodwork"],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
},
)
match_tags = mocker.patch(
"documents.views.match_tags_by_name",
return_value=[],
)
mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
mocker.patch("documents.views.match_document_types_by_name", return_value=[])
mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
assert response.status_code == 200
assert classify.call_args.kwargs["hints"] == hints
assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}