Compare commits

..

10 Commits

Author SHA1 Message Date
stumpylog 1f4a871b8f Refactor(beta): extract visible_document_ids_for_user helper
The owner-aware "resolve user to visible document pks" block was duplicated
verbatim between get_context_for_document and get_taxonomy_hints_for_document.
Extract it into indexing.visible_document_ids_for_user, next to its sibling
normalize_document_ids, and call it from both paths.

No behavior change: the helper returns None when user is None (unfiltered
retrieval) and the same pk list otherwise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 29f9475818 Test(beta): use documents factories for taxonomy hint test fixtures
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog d06f66b618 Test(beta): use pytest-django fixtures and drop needless DB markers in taxonomy hint tests
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog f3f55e3866 Enhancement(beta): feed taxonomy hints into AI document suggestions
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 24b81c15f6 Enhancement(beta): splice taxonomy hints into the AI classifier prompt
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 5202b0880e Enhancement(beta): let name matching short-circuit on taxonomy hints
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 7ed58f9664 Enhancement(beta): gate and assemble taxonomy hints for a document
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 43eb3295ce Enhancement(beta): format taxonomy hints into prompt blocks
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog e0ba4cfada Enhancement(beta): add taxonomy hint builder from RAG node metadata
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
stumpylog 73062bd5ab Refactor(beta): extract retrieve_similar_nodes from query_similar_documents
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:07:31 -07:00
56 changed files with 1257 additions and 3297 deletions
-7
View File
@@ -2068,13 +2068,6 @@ context by default.
Defaults to 8192.
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
local or slow inference servers that need more time to generate responses.
Defaults to 120.
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
+1 -1
View File
@@ -26,7 +26,7 @@ module.exports = {
'abstract-paperless-service',
],
transformIgnorePatterns: [
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
],
moduleNameMapper: {
...esmPreset.moduleNameMapper,
-1
View File
@@ -32,7 +32,6 @@
"ngx-cookie-service": "^21.3.1",
"ngx-device-detector": "^11.0.0",
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
"normalize-diacritics": "^5.0.0",
"pdfjs-dist": "^5.7.284",
"rxjs": "^7.8.2",
"tslib": "^2.8.1",
-11
View File
@@ -71,9 +71,6 @@ importers:
ngx-ui-tour-ng-bootstrap:
specifier: ^18.0.0
version: 18.0.0(f910a33494d223bd6dd07ce1bf22a35e)
normalize-diacritics:
specifier: ^5.0.0
version: 5.0.0
pdfjs-dist:
specifier: ^5.7.284
version: 5.7.284
@@ -5519,10 +5516,6 @@ packages:
engines: {node: ^20.17.0 || >=22.9.0}
hasBin: true
normalize-diacritics@5.0.0:
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
engines: {node: '>= 14.x', npm: '>= 6.x'}
normalize-path@3.0.0:
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
engines: {node: '>=0.10.0'}
@@ -12938,10 +12931,6 @@ snapshots:
dependencies:
abbrev: 4.0.0
normalize-diacritics@5.0.0:
dependencies:
tslib: 2.8.1
normalize-path@3.0.0: {}
npm-bundled@5.0.0:
@@ -23,7 +23,6 @@ import {
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
import { ToastService } from 'src/app/services/toast.service'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
@@ -70,7 +69,9 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
public get filteredFields(): CustomField[] {
return this.unusedFields.filter(
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
(f) =>
!this.filterText ||
f.name.toLowerCase().includes(this.filterText.toLowerCase())
)
}
@@ -63,7 +63,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
@@ -82,7 +81,6 @@
[disabled]="disabled"
bindLabel="name"
bindValue="id"
[searchFn]="customFieldSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
@@ -127,7 +125,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[multiple]="true"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
}
@@ -36,7 +36,6 @@ import {
CustomFieldQueryExpression,
} from 'src/app/utils/custom-field-query-element'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
@@ -282,14 +281,6 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
public readonly today: string = new Date().toLocaleDateString('en-CA')
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
matchesSearchText(field?.name, term)
public selectOptionSearchFn = (
term: string,
option: { id: string; label: string }
): boolean => matchesSearchText(option?.label, term)
constructor() {
super()
this.selectionModel = new CustomFieldQueriesModel()
@@ -28,7 +28,6 @@
[notFoundText]="notFoundText"
[multiple]="multiple"
[bindLabel]="bindLabel"
[searchFn]="searchFn"
bindValue="id"
[virtualScroll]="items?.length > 100"
(change)="onChange(value)"
@@ -112,15 +112,6 @@ describe('SelectComponent', () => {
expect(createNewVal).toEqual('baz')
})
it('should search items by independent normalized terms', () => {
expect(
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeTruthy()
expect(
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeFalsy()
})
it('should clear search term on blur after delay', fakeAsync(() => {
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
component.onBlur()
@@ -13,7 +13,6 @@ import {
import { RouterModule } from '@angular/router'
import { NgSelectModule } from '@ng-select/ng-select'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { matchesSearchText } from 'src/app/utils/text-search'
import { AbstractInputComponent } from '../abstract-input'
@Component({
@@ -100,9 +99,6 @@ export class SelectComponent extends AbstractInputComponent<number> {
@Input()
bindLabel: string = 'name'
public searchFn = (term: string, item: any): boolean =>
matchesSearchText(item?.[this.bindLabel], term)
@Input()
showFilter: boolean = false
@@ -14,7 +14,6 @@
[clearSearchOnAdd]="true"
[hideSelected]="tags.length > 0"
[addTag]="allowCreate ? createTagRef : false"
[searchFn]="searchFn"
addTagText="Add tag"
i18n-addTagText
(add)="onAdd($event)"
@@ -171,15 +171,6 @@ describe('TagsComponent', () => {
expect(component.getTag(4)).toBeUndefined()
})
it('should search tags by independent normalized terms including parents', () => {
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
component.tags = [parent, child]
expect(component.searchFn('finance 26', child)).toBeTruthy()
expect(component.searchFn('finance receipt', child)).toBeFalsy()
})
it('should emit filtered documents', () => {
component.value = [10]
component.tags = tags
@@ -21,7 +21,6 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { first, firstValueFrom, tap } from 'rxjs'
import { Tag } from 'src/app/data/tag'
import { TagService } from 'src/app/services/rest/tag.service'
import { matchesSearchText } from 'src/app/utils/text-search'
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
import { TagComponent } from '../../tag/tag.component'
@@ -115,14 +114,6 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
public createTagRef: (name) => void
public searchFn = (term: string, tag: Tag): boolean =>
matchesSearchText(
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
.flat()
.join(' '),
term
)
getTag(id: number) {
if (this.tags) {
return this.tags.find((tag) => tag.id == id)
@@ -131,9 +131,7 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
}
</button>
<ng-template #celeryStatus>
-9
View File
@@ -360,14 +360,6 @@ export const PaperlessConfigOptions: ConfigOption[] = [
category: ConfigCategory.AI,
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
},
{
key: 'llm_request_timeout',
title: $localize`LLM Request Timeout`,
type: ConfigOptionType.Number,
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
category: ConfigCategory.AI,
note: $localize`Timeout in seconds for LLM requests.`,
},
]
export interface PaperlessConfig extends ObjectWithId {
@@ -409,5 +401,4 @@ export interface PaperlessConfig extends ObjectWithId {
llm_api_key: string
llm_endpoint: string
llm_output_language: string
llm_request_timeout: number
}
+3 -2
View File
@@ -1,6 +1,5 @@
import { Pipe, PipeTransform } from '@angular/core'
import { MatchingModel } from '../data/matching-model'
import { matchesSearchText } from '../utils/text-search'
@Pipe({
name: 'filter',
@@ -22,7 +21,9 @@ export class FilterPipe implements PipeTransform {
typeof item[key] === 'string' || typeof item[key] === 'number'
)
return keys.some((key) => {
return matchesSearchText(item[key], searchText)
return String(item[key])
.toLowerCase()
.includes(searchText.toLowerCase())
})
})
}
-17
View File
@@ -1,17 +0,0 @@
import { matchesSearchText } from './text-search'
describe('text search utilities', () => {
it('matches text accent-insensitively', () => {
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
})
it('matches all whitespace-separated search terms independently', () => {
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
})
})
-23
View File
@@ -1,23 +0,0 @@
import { normalizeSync } from 'normalize-diacritics'
export type SearchTextValue =
| string
| number
| boolean
| bigint
| null
| undefined
export function normalizeSearchText(value: SearchTextValue): string {
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
}
export function matchesSearchText(
value: SearchTextValue,
searchText: SearchTextValue
): boolean {
const normalizedValue = normalizeSearchText(value)
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
return searchTerms.every((term) => normalizedValue.includes(term))
}
@@ -169,10 +169,6 @@ class FileStabilityTracker:
self._tracked.pop(path, None)
yield path
def is_tracking(self, path: Path) -> bool:
"""Check whether a path is currently being tracked for stability."""
return path.resolve() in self._tracked
def has_pending_files(self) -> bool:
"""Check if there are files waiting for stability check."""
return len(self._tracked) > 0
@@ -374,16 +370,6 @@ class Command(BaseCommand):
# Testing timeout in seconds
testing_timeout_s: Final[float] = 0.5
# How often to perform a full-glob rescan of the consume directory as a
# safety net. Each watchfiles watcher is torn down and recreated on every
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
# current directory contents as its baseline. A file that appears between
# one batch and the next watcher's baseline is therefore never reported and
# would sit in the consume directory forever. This periodic rescan re-injects
# such files into the stability tracker (see GH issue #13011). Not currently
# user-configurable; instances may override for testing.
rescan_interval_s: float = 300.0
def add_arguments(self, parser) -> None:
parser.add_argument(
"directory",
@@ -439,7 +425,7 @@ class Command(BaseCommand):
)
# Process existing files
queued = self._process_existing_files(
self._process_existing_files(
directory=directory,
recursive=recursive,
subdirs_as_tags=subdirs_as_tags,
@@ -459,7 +445,6 @@ class Command(BaseCommand):
polling_interval=polling_interval,
stability_delay=stability_delay,
is_testing=is_testing,
queued=queued,
)
logger.debug("Consumer exiting")
@@ -471,18 +456,11 @@ class Command(BaseCommand):
recursive: bool,
subdirs_as_tags: bool,
consumer_filter: ConsumerFilter,
) -> set[Path]:
"""
Process any existing files in the consumption directory.
Returns the set of resolved paths that were queued, so the watch loop
can seed its in-flight set and avoid re-queuing them on the first
rescan before the consume tasks have removed them from disk.
"""
) -> None:
"""Process any existing files in the consumption directory."""
logger.info(f"Processing existing files in {directory}")
glob_pattern = "**/*" if recursive else "*"
queued: set[Path] = set()
for filepath in directory.glob(glob_pattern):
# Use filter to check if file should be processed
@@ -497,48 +475,6 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
queued.add(filepath.resolve())
return queued
def _rescan_existing_files(
self,
*,
directory: Path,
recursive: bool,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
) -> None:
"""
Re-inject on-disk files the watcher never reported into the tracker.
Acts as a safety net for files stranded by the watcher-recreation gap
(see ``rescan_interval_s``). Files already being tracked or already
queued and awaiting consumption are skipped, so a file is never queued
twice. Queued paths that have since left the directory are pruned so a
later file reusing the same name is not skipped forever.
"""
# Prune in-flight paths that have left the directory
for path in list(queued):
if not path.exists():
queued.discard(path)
glob_pattern = "**/*" if recursive else "*"
for filepath in directory.glob(glob_pattern):
if not filepath.is_file():
continue
if not consumer_filter(Change.added, str(filepath)):
continue
resolved = filepath.resolve()
if tracker.is_tracking(resolved) or resolved in queued:
continue
logger.debug(f"Rescan found untracked file: {resolved}")
tracker.track(resolved, Change.added)
def _watch_directory(
self,
@@ -550,24 +486,11 @@ class Command(BaseCommand):
polling_interval: float,
stability_delay: float,
is_testing: bool,
queued: set[Path] | None = None,
) -> None:
"""Watch directory for changes and process stable files."""
use_polling = polling_interval > 0
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
# Resolved paths that have been queued and are awaiting consumption.
# Seeded from the startup scan so the first rescan does not re-queue
# files whose consume tasks have not yet removed them from disk.
queued = set() if queued is None else queued
# Full-glob safety net cadence (0 disables)
rescan_interval_s = self.rescan_interval_s
rescan_timeout_ms = (
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
)
last_rescan = monotonic()
if use_polling:
logger.info(
f"Watching {directory} using polling (interval: {polling_interval}s)",
@@ -582,20 +505,6 @@ class Command(BaseCommand):
stability_timeout_ms = int(stability_delay * 1000)
testing_timeout_ms = int(self.testing_timeout_s * 1000)
def cap_for_rescan(ms: int) -> int:
"""
Ensure the watch loop wakes often enough to run the rescan.
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
only run that often. A timeout of 0 means "wait indefinitely",
which would never wake to rescan; cap it at the rescan interval.
"""
if rescan_timeout_ms <= 0:
return ms
if ms <= 0:
return rescan_timeout_ms
return min(ms, rescan_timeout_ms)
# Calculate appropriate timeout for watch loop
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
# to ensure poll cycles can complete before timing out
@@ -613,8 +522,6 @@ class Command(BaseCommand):
# Not testing, wait indefinitely for first event
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
self.stop_flag.clear()
while not self.stop_flag.is_set():
@@ -644,26 +551,10 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
# Remember it so the rescan does not re-queue it while
# the consume task has yet to remove it from disk
queued.add(stable_path)
# Exit watch loop to reconfigure timeout
break
# Periodic full-glob safety net for files the watcher missed
if rescan_timeout_ms > 0 and (
monotonic() - last_rescan >= rescan_interval_s
):
self._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
last_rescan = monotonic()
# Determine next timeout
if tracker.has_pending_files():
# Check pending files at stability interval
@@ -681,8 +572,6 @@ class Command(BaseCommand):
# No pending files, wait indefinitely
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
except KeyboardInterrupt: # pragma: nocover
logger.info("Received interrupt, stopping consumer")
self.stop_flag.set()
@@ -1,63 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0003_remove_document_storage_type"),
("documents", "0004_workflowtrigger_filter_has_any_correspondents_and_more"),
("documents", "0005_alter_document_checksum_unique"),
]
dependencies = [
("documents", "0002_squashed"),
]
operations = [
migrations.RemoveField(
model_name="document",
name="storage_type",
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_correspondents",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_correspondent",
to="documents.correspondent",
verbose_name="has one of these correspondents",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_document_types",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_document_type",
to="documents.documenttype",
verbose_name="has one of these document types",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_storage_paths",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_storage_path",
to="documents.storagepath",
verbose_name="has one of these storage paths",
),
),
migrations.AlterField(
model_name="document",
name="checksum",
field=models.CharField(
editable=False,
help_text="The checksum of the original document.",
max_length=32,
verbose_name="checksum",
),
),
]
@@ -1,252 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
import django.db.models.deletion
import django.db.models.functions.text
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0008_workflowaction_passwords_alter_workflowaction_type"),
("documents", "0009_alter_document_content_length"),
("documents", "0010_optimize_integer_field_sizes"),
("documents", "0011_alter_workflowaction_type"),
("documents", "0012_document_root_document"),
]
dependencies = [
("documents", "0007_sharelinkbundle"),
]
operations = [
migrations.AddField(
model_name="workflowaction",
name="passwords",
field=models.JSONField(
blank=True,
help_text="Passwords to try when removing PDF protection. Separate with commas or new lines.",
null=True,
verbose_name="passwords",
),
),
migrations.AlterField(
model_name="document",
name="content_length",
field=models.GeneratedField(
db_persist=True,
expression=django.db.models.functions.text.Length("content"),
help_text="Length of the content field in characters. Automatically maintained by the database for faster statistics computation.",
output_field=models.PositiveIntegerField(default=0),
serialize=False,
),
),
migrations.AlterField(
model_name="correspondent",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="documenttype",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="savedviewfilterrule",
name="rule_type",
field=models.PositiveSmallIntegerField(
choices=[
(0, "title contains"),
(1, "content contains"),
(2, "ASN is"),
(3, "correspondent is"),
(4, "document type is"),
(5, "is in inbox"),
(6, "has tag"),
(7, "has any tag"),
(8, "created before"),
(9, "created after"),
(10, "created year is"),
(11, "created month is"),
(12, "created day is"),
(13, "added before"),
(14, "added after"),
(15, "modified before"),
(16, "modified after"),
(17, "does not have tag"),
(18, "does not have ASN"),
(19, "title or content contains"),
(20, "fulltext query"),
(21, "more like this"),
(22, "has tags in"),
(23, "ASN greater than"),
(24, "ASN less than"),
(25, "storage path is"),
(26, "has correspondent in"),
(27, "does not have correspondent in"),
(28, "has document type in"),
(29, "does not have document type in"),
(30, "has storage path in"),
(31, "does not have storage path in"),
(32, "owner is"),
(33, "has owner in"),
(34, "does not have owner"),
(35, "does not have owner in"),
(36, "has custom field value"),
(37, "is shared by me"),
(38, "has custom fields"),
(39, "has custom field in"),
(40, "does not have custom field in"),
(41, "does not have custom field"),
(42, "custom fields query"),
(43, "created to"),
(44, "created from"),
(45, "added to"),
(46, "added from"),
(47, "mime type is"),
],
verbose_name="rule type",
),
),
migrations.AlterField(
model_name="storagepath",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="tag",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowrun",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
null=True,
verbose_name="workflow trigger type",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
],
default=0,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
default=1,
verbose_name="Workflow Trigger Type",
),
),
migrations.AlterField(
model_name="workflowaction",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Assignment"),
(2, "Removal"),
(3, "Email"),
(4, "Webhook"),
(5, "Password removal"),
(6, "Move to trash"),
],
default=1,
verbose_name="Workflow Action Type",
),
),
migrations.AddField(
model_name="document",
name="root_document",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="versions",
to="documents.document",
verbose_name="root document for this version",
),
),
migrations.AddField(
model_name="document",
name="version_label",
field=models.CharField(
blank=True,
help_text="Optional short label for a document version.",
max_length=64,
null=True,
verbose_name="version label",
),
),
]
-4
View File
@@ -8,15 +8,11 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+2 -18
View File
@@ -866,24 +866,8 @@ class TantivyBackend:
final_query = self._apply_permission_filter(mlt_query, user)
effective_limit = limit if limit is not None else searcher.num_docs
try:
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
except BaseException: # pragma: no cover
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
# soft-deleted documents (doc_freq can exceed the alive doc count),
# which only surfaces for the More Like This query. The panic crosses
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
# BaseException, not an Exception — so catch BaseException and degrade
# to "no similar documents" instead of bubbling a 500 to the client.
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
# Remove once the bundled tantivy includes that fix.
logger.warning(
"More Like This scoring panicked (likely stale tantivy segment "
"stats after deletions); returning no results. A search index "
"reindex will rebuild consistent statistics.",
)
return []
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
addrs = [addr for _score, addr in results.hits]
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
-163
View File
@@ -1,163 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+405 -27
View File
@@ -1,35 +1,88 @@
from __future__ import annotations
import logging
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._dates import (
_date_only_range, # noqa: F401 — re-exported for test imports
)
from documents.search._dates import (
_datetime_range, # noqa: F401 — re-exported for test imports
)
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -64,12 +117,303 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
)
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
)
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
)
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC — matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
)
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (year range rewrite timed out)")
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Delegates to ``translate_query`` which handles all date forms, comma
expansion, field aliasing, relative ranges, and operator normalization.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Args:
query: Raw user query string
@@ -81,15 +425,35 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
Note:
Bare keywords without field prefixes pass through unchanged.
"""
return translate_query(query, tz)
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
)
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Delegates to ``translate_query`` which handles comma expansion, whitespace
collapsing, operator normalization, and field aliasing.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar → tag:foo AND tag:bar
- multiple spaces → single spaces
Args:
query: Query string after date rewriting
@@ -97,7 +461,29 @@ def normalize_query(query: str) -> str:
Returns:
Normalized query string ready for Tantivy parsing
"""
return translate_query(query, UTC)
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (normalization timed out)")
def build_permission_filter(
@@ -217,16 +603,8 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
exact = index.parse_query(
query_str,
-566
View File
@@ -1,566 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
-12
View File
@@ -1,15 +1,11 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -35,11 +31,3 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+10 -88
View File
@@ -13,6 +13,7 @@ import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
@@ -20,7 +21,6 @@ from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -405,14 +405,12 @@ class TestWhooshQueryRewriting:
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
rewrite_natural_date_keywords("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
class TestParseUserQuery:
@@ -465,67 +463,6 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -605,16 +542,11 @@ class TestYearRangeRewriting:
assert rewrite_natural_date_keywords(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
class TestNonDateFieldsNotRewritten:
@@ -674,16 +606,6 @@ class TestNormalizeQuery:
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert normalize_query(q) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
@@ -1,742 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@@ -82,7 +82,6 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"llm_api_key": None,
"llm_endpoint": None,
"llm_output_language": None,
"llm_request_timeout": None,
},
)
+3 -6
View File
@@ -725,11 +725,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with an invalid added date
- Query with invalid added date
THEN:
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
"""
d1 = Document.objects.create(
title="invoice",
@@ -742,9 +740,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# An unparsable date is reported as a malformed query, not silently empty.
# Tantivy rejects unparsable field queries with a 400
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
-71
View File
@@ -216,77 +216,6 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
@@ -684,7 +684,6 @@ class ConsumerThread(Thread):
subdirs_as_tags: bool = False,
polling_interval: float = 0,
stability_delay: float = 0.1,
rescan_interval: float | None = None,
) -> None:
super().__init__()
self.consumption_dir = consumption_dir
@@ -694,8 +693,6 @@ class ConsumerThread(Thread):
self.polling_interval = polling_interval
self.stability_delay = stability_delay
self.cmd = Command()
if rescan_interval is not None:
self.cmd.rescan_interval_s = rescan_interval
self.cmd.stop_flag.clear()
# Non-daemon ensures finally block runs and connections are closed
self.daemon = False
@@ -1055,200 +1052,3 @@ class TestCommandWatchEdgeCases:
thread.stop_and_wait(timeout=5.0)
# Clean up any Tags created by the thread
Tag.objects.all().delete()
class TestRescanExistingFiles:
"""
Unit tests for the rescan safety net.
Each ``watch()`` recreation silently adopts the current directory contents
as its baseline, so a file appearing between one batch and the next
watcher's baseline is never reported and would sit in the consume directory
forever. ``_rescan_existing_files`` re-injects such files into the
stability tracker as a periodic safety net (see GH issue #13011).
"""
@pytest.fixture
def pdf_only_filter(self) -> ConsumerFilter:
return ConsumerFilter(
supported_extensions=frozenset({".pdf"}),
ignore_patterns=[],
)
def _rescan(
self,
directory: Path,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
*,
recursive: bool = False,
) -> None:
Command()._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
def test_tracks_stranded_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A supported on-disk file the watcher never reported gets tracked."""
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.is_tracking(target) is True
assert tracker.pending_count == 1
def test_skips_already_tracked_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already being tracked by the watcher is not double-tracked."""
target = consumption_dir / "tracked.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
tracker.track(target, Change.added)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 1
def test_skips_queued_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already queued and awaiting consumption is not re-tracked."""
target = consumption_dir / "inflight.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {target.resolve()}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert tracker.pending_count == 0
def test_prunes_vanished_queued_paths(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Queued paths no longer on disk are dropped so the name can recur."""
gone = (consumption_dir / "gone.pdf").resolve()
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {gone}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert gone not in queued
def test_skips_unsupported_extension(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Files filtered out by the consumer filter are not tracked."""
(consumption_dir / "notes.xyz").write_bytes(b"content")
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 0
def test_recursive_respects_flag(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Nested files are only found when recursive scanning is enabled."""
subdir = consumption_dir / "nested"
subdir.mkdir()
target = subdir / "deep.pdf"
shutil.copy(sample_pdf, target)
shallow = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
assert shallow.pending_count == 0
deep = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
assert deep.is_tracking(target) is True
class TestProcessExistingFilesQueued:
"""Tests that startup processing reports which paths it queued."""
@pytest.mark.usefixtures("mock_supported_extensions")
def test_returns_queued_paths(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
settings: SettingsWrapper,
) -> None:
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
target = consumption_dir / "document.pdf"
shutil.copy(sample_pdf, target)
settings.CONSUMER_IGNORE_PATTERNS = []
queued = Command()._process_existing_files(
directory=consumption_dir,
recursive=False,
subdirs_as_tags=False,
consumer_filter=ConsumerFilter(ignore_patterns=[]),
)
assert target.resolve() in queued
@pytest.mark.management
@pytest.mark.django_db
class TestCommandRescanRecovery:
"""End-to-end test that the rescan recovers files the watcher misses."""
def test_rescan_consumes_file_the_watcher_never_reports(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
start_consumer: Callable[..., ConsumerThread],
) -> None:
"""
Isolate the rescan path: a long polling interval guarantees the
watcher cannot report the file within the test window, so only the
periodic rescan can consume it.
"""
# poll interval far longer than the test window -> watcher stays silent
thread = start_consumer(
polling_interval=30.0,
stability_delay=0.1,
rescan_interval=0.5,
)
# created after startup, so _process_existing_files did not see it
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.apply_async.assert_called()
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
"input_doc"
]
assert call_args.original_file.name == "stranded.pdf"
+3 -28
View File
@@ -30,7 +30,6 @@ from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
from paperless.models import ApplicationConfiguration
from paperless_ai.exceptions import LLMTimeoutError
class TestViews(DirectoriesMixin, TestCase):
@@ -369,6 +368,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
None,
hints=None,
)
@patch("documents.views.get_ai_document_classification")
@@ -400,6 +400,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
"de-de",
hints=None,
)
self.assertEqual(
get_llm_suggestion_cache(
@@ -439,6 +440,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
self.document,
self.user,
"fr-fr",
hints=None,
)
self.assertEqual(
get_llm_suggestion_cache(
@@ -477,33 +479,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="openai-like",
)
def test_ai_suggestions_with_llm_timeout(
self,
mock_get_ai_classification,
) -> None:
mock_get_ai_classification.side_effect = LLMTimeoutError()
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
self.assertEqual(
response.json(),
{
"ai": ["AI backend request timed out."],
},
)
self.assertIsNone(
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
def test_invalidate_suggestions_cache(self) -> None:
self.client.force_login(user=self.user)
suggestions = {
+13 -42
View File
@@ -12,7 +12,6 @@ from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -241,12 +240,12 @@ from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.exceptions import LLMTimeoutError
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.oauth import PaperlessMailOAuth2Manager
@@ -1496,11 +1495,14 @@ class DocumentViewSet(
refresh_suggestions_cache(doc.pk)
return Response(cached_llm_suggestions.suggestions)
hints = get_taxonomy_hints_for_document(doc, request.user)
try:
llm_suggestions = get_ai_document_classification(
doc,
request.user,
output_language,
hints=hints,
)
except ValueError as exc:
logger.exception(
@@ -1511,33 +1513,26 @@ class DocumentViewSet(
exc_info=True,
)
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
except LLMTimeoutError as exc:
logger.exception(
"AI backend timed out while generating suggestions for document %s: %s",
doc.pk,
exc,
exc_info=True,
)
return Response(
{"ai": [_("AI backend request timed out.")]},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
request.user,
hinted_names=set(hints["tags"]) if hints else None,
)
matched_correspondents = match_correspondents_by_name(
llm_suggestions.get("correspondents", []),
request.user,
hinted_names=set(hints["correspondents"]) if hints else None,
)
matched_types = match_document_types_by_name(
llm_suggestions.get("document_types", []),
request.user,
hinted_names=set(hints["document_types"]) if hints else None,
)
matched_paths = match_storage_paths_by_name(
llm_suggestions.get("storage_paths", []),
request.user,
hinted_names=set(hints["storage_paths"]) if hints else None,
)
resp_data = {
@@ -2289,7 +2284,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2482,11 +2476,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -5009,29 +4998,11 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
except Exception as e:
celery_active = "ERROR"
logger.exception(
-4
View File
@@ -197,7 +197,6 @@ class AIConfig(BaseConfig):
llm_embedding_endpoint: str = dataclasses.field(init=False)
llm_embedding_chunk_size: int = dataclasses.field(init=False)
llm_context_size: int = dataclasses.field(init=False)
llm_request_timeout: int = dataclasses.field(init=False)
llm_backend: str = dataclasses.field(init=False)
llm_model: str = dataclasses.field(init=False)
llm_api_key: str = dataclasses.field(init=False)
@@ -222,9 +221,6 @@ class AIConfig(BaseConfig):
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
)
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
self.llm_request_timeout = (
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
)
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
@@ -1,365 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:30
import django.core.validators
from django.db import migrations
from django.db import models
def _create_singleton(apps, schema_editor):
settings_model = apps.get_model("paperless", "ApplicationConfiguration")
settings_model.objects.create()
class Migration(migrations.Migration):
replaces = [
("paperless", "0001_initial"),
("paperless", "0002_applicationconfiguration_app_logo_and_more"),
("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"),
("paperless", "0005_applicationconfiguration_ai_enabled_and_more"),
("paperless", "0006_applicationconfiguration_barcode_tag_split"),
]
dependencies = []
operations = [
migrations.CreateModel(
name="ApplicationConfiguration",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"output_type",
models.CharField(
blank=True,
choices=[
("pdf", "pdf"),
("pdfa", "pdfa"),
("pdfa-1", "pdfa-1"),
("pdfa-2", "pdfa-2"),
("pdfa-3", "pdfa-3"),
],
max_length=8,
null=True,
verbose_name="Sets the output PDF type",
),
),
(
"pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Do OCR from page 1 to this value",
),
),
(
"language",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Do OCR using these languages",
),
),
(
"mode",
models.CharField(
blank=True,
choices=[
("skip", "skip"),
("redo", "redo"),
("force", "force"),
("skip_noarchive", "skip_noarchive"),
],
max_length=16,
null=True,
verbose_name="Sets the OCR mode",
),
),
(
"skip_archive_file",
models.CharField(
blank=True,
choices=[
("never", "never"),
("with_text", "with_text"),
("always", "always"),
],
max_length=16,
null=True,
verbose_name="Controls the generation of an archive file",
),
),
(
"image_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets image DPI fallback value",
),
),
(
"unpaper_clean",
models.CharField(
blank=True,
choices=[
("clean", "clean"),
("clean-final", "clean-final"),
("none", "none"),
],
max_length=16,
null=True,
verbose_name="Controls the unpaper cleaning",
),
),
(
"deskew",
models.BooleanField(null=True, verbose_name="Enables deskew"),
),
(
"rotate_pages",
models.BooleanField(
null=True,
verbose_name="Enables page rotation",
),
),
(
"rotate_pages_threshold",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the threshold for rotation of pages",
),
),
(
"max_image_pixels",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the maximum image size for decompression",
),
),
(
"color_conversion_strategy",
models.CharField(
blank=True,
choices=[
("LeaveColorUnchanged", "LeaveColorUnchanged"),
("RGB", "RGB"),
("UseDeviceIndependentColor", "UseDeviceIndependentColor"),
("Gray", "Gray"),
("CMYK", "CMYK"),
],
max_length=32,
null=True,
verbose_name="Sets the Ghostscript color conversion strategy",
),
),
(
"user_args",
models.JSONField(
null=True,
verbose_name="Adds additional user arguments for OCRMyPDF",
),
),
(
"app_logo",
models.FileField(
blank=True,
null=True,
upload_to="logo/",
validators=[
django.core.validators.FileExtensionValidator(
allowed_extensions=["jpg", "png", "gif", "svg"],
),
],
verbose_name="Application logo",
),
),
(
"app_title",
models.CharField(
blank=True,
max_length=48,
null=True,
verbose_name="Application title",
),
),
(
"barcode_asn_prefix",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the ASN barcode prefix",
),
),
(
"barcode_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the barcode DPI",
),
),
(
"barcode_enable_asn",
models.BooleanField(
null=True,
verbose_name="Enables ASN barcode",
),
),
(
"barcode_enable_tag",
models.BooleanField(
null=True,
verbose_name="Enables tag barcode",
),
),
(
"barcode_enable_tiff_support",
models.BooleanField(
null=True,
verbose_name="Enables barcode TIFF support",
),
),
(
"barcode_max_pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the maximum pages for barcode",
),
),
(
"barcode_retain_split_pages",
models.BooleanField(
null=True,
verbose_name="Retains split pages",
),
),
(
"barcode_string",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the barcode string",
),
),
(
"barcode_tag_mapping",
models.JSONField(
null=True,
verbose_name="Sets the tag barcode mapping",
),
),
(
"barcode_upscale",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(1.0)],
verbose_name="Sets the barcode upscale factor",
),
),
(
"barcodes_enabled",
models.BooleanField(
null=True,
verbose_name="Enables barcode scanning",
),
),
(
"ai_enabled",
models.BooleanField(
default=False,
null=True,
verbose_name="Enables AI features",
),
),
(
"llm_api_key",
models.CharField(
blank=True,
max_length=1024,
null=True,
verbose_name="Sets the LLM API key",
),
),
(
"llm_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM backend",
),
),
(
"llm_embedding_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
(
"llm_embedding_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM embedding model",
),
),
(
"llm_endpoint",
models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM endpoint, optional",
),
),
(
"llm_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM model",
),
),
(
"barcode_tag_split",
models.BooleanField(
null=True,
verbose_name="Enables splitting on tag barcodes",
),
),
],
options={
"verbose_name": "paperless application settings",
},
),
migrations.RunPython(
code=_create_singleton,
reverse_code=migrations.RunPython.noop,
),
]
@@ -1,94 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:19
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("paperless", "0009_alter_applicationconfiguration_options"),
("paperless", "0010_alter_applicationconfiguration_llm_embedding_backend"),
("paperless", "0011_applicationconfiguration_llm_embedding_chunk_size"),
("paperless", "0012_applicationconfiguration_llm_output_language"),
("paperless", "0013_applicationconfiguration_llm_request_timeout"),
]
dependencies = [
("paperless", "0008_replace_skip_archive_file"),
]
operations = [
migrations.AlterModelOptions(
name="applicationconfiguration",
options={
"permissions": [
("view_global_statistics", "Can view global object counts"),
("view_system_monitoring", "Can view system status information"),
],
"verbose_name": "paperless application settings",
},
),
migrations.AlterField(
model_name="applicationconfiguration",
name="llm_embedding_backend",
field=models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_endpoint",
field=models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM embedding endpoint, optional",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_chunk_size",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM embedding chunk size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_context_size",
field=models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM context size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_output_language",
field=models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the LLM output language",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
@@ -1,23 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-14 14:22
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("paperless", "0012_applicationconfiguration_llm_output_language"),
]
operations = [
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
-6
View File
@@ -366,12 +366,6 @@ class ApplicationConfiguration(AbstractSingletonModel):
max_length=32,
)
llm_request_timeout = models.PositiveSmallIntegerField(
verbose_name=_("Sets the LLM timeout in seconds"),
null=True,
validators=[MinValueValidator(1)],
)
class Meta:
verbose_name = _("paperless application settings")
permissions = [
-3
View File
@@ -1206,9 +1206,6 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
if LLM_CONTEXT_SIZE < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
if LLM_REQUEST_TIMEOUT < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
LLM_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_BACKEND",
{"ollama", "openai-like"},
+20 -19
View File
@@ -1,16 +1,21 @@
import json
import logging
from typing import TYPE_CHECKING
from django.conf import settings
from django.contrib.auth.models import User
from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
from paperless_ai.indexing import visible_document_ids_for_user
from paperless_ai.taxonomy import format_hints_for_prompt
if TYPE_CHECKING:
from paperless_ai.taxonomy import TaxonomyHints
logger = logging.getLogger("paperless_ai.rag_classifier")
@@ -26,6 +31,7 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
hints: "TaxonomyHints | None" = None,
) -> str:
filename = document.filename or ""
content = truncate_content(
@@ -34,10 +40,16 @@ def build_prompt_without_rag(
context_size=config.llm_context_size,
)
hints_block = format_hints_for_prompt(hints) if hints else ""
# Splice the block (if any) immediately before the "Analyze ..." instruction.
# When there is no block this expands to nothing, so the prompt is identical
# to the pre-hints baseline.
hints_section = f"{hints_block}\n\n " if hints_block else ""
return f"""
You are a document classification assistant.
Analyze the following document and extract the following information:
{hints_section}Analyze the following document and extract the following information:
- A short descriptive title
- Tags that reflect the content
- Names of people or organizations mentioned
@@ -57,8 +69,9 @@ def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
hints: "TaxonomyHints | None" = None,
) -> str:
base_prompt = build_prompt_without_rag(document, config)
base_prompt = build_prompt_without_rag(document, config, hints=hints)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
@@ -96,20 +109,7 @@ def get_context_for_document(
user: User | None = None,
max_docs: int = 5,
) -> str:
visible_documents = (
get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
if user
else None
)
visible_document_ids = (
list(visible_documents.values_list("pk", flat=True))
if visible_documents is not None
else None
)
visible_document_ids = visible_document_ids_for_user(user)
similar_docs = query_similar_documents(
document=doc,
document_ids=visible_document_ids,
@@ -137,13 +137,14 @@ def get_ai_document_classification(
document: Document,
user: User | None = None,
output_language: str | None = None,
hints: "TaxonomyHints | None" = None,
) -> dict:
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, ai_config, user)
build_prompt_with_rag(document, ai_config, user, hints=hints)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document, ai_config)
else build_prompt_without_rag(document, ai_config, hints=hints)
)
client = AIClient()
+28 -49
View File
@@ -1,14 +1,11 @@
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
import httpx
from paperless.models import LLMBackend
if TYPE_CHECKING:
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai_like import OpenAILike
@@ -19,7 +16,6 @@ from paperless.network import create_pinned_async_httpx_client
from paperless.network import create_pinned_httpx_client
from paperless.network import validate_outbound_http_url
from paperless_ai.base_model import DocumentClassifierSchema
from paperless_ai.exceptions import LLMTimeoutError
logger = logging.getLogger("paperless_ai.client")
@@ -65,16 +61,16 @@ class AIClient:
model=self.settings.llm_model or "llama3.1",
base_url=endpoint,
context_window=self.settings.llm_context_size,
request_timeout=self.settings.llm_request_timeout,
request_timeout=120,
system_prompt=LLM_SYSTEM_PROMPT,
client=Client(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=transport,
),
async_client=AsyncClient(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=async_transport,
),
)
@@ -88,18 +84,15 @@ class AIClient:
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
return OpenAILike(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_base=endpoint,
api_key=self.settings.llm_api_key,
timeout=self.settings.llm_request_timeout,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -120,12 +113,11 @@ class AIClient:
user_msg = ChatMessage(role="user", content=prompt)
if self.settings.llm_backend == LLMBackend.OLLAMA:
with self._normalize_timeouts():
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
logger.debug("LLM query result: %s", result)
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
return parsed.model_dump()
@@ -133,39 +125,26 @@ class AIClient:
from llama_index.core.program.function_program import get_function_tool
tool = get_function_tool(DocumentClassifierSchema)
with self._normalize_timeouts():
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
tool_required=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
@contextmanager
def _normalize_timeouts(self) -> Iterator[None]:
try:
yield
except httpx.TimeoutException as exc:
raise LLMTimeoutError from exc
except Exception as exc:
if self._is_openai_timeout(exc):
raise LLMTimeoutError from exc
raise
def _is_openai_timeout(self, exc: Exception) -> bool:
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
return False
# Keep OpenAI imports out of module import paths and only load the SDK
# when translating an error from an OpenAI-backed request.
from openai import APITimeoutError
return isinstance(exc, APITimeoutError)
def run_chat(self, messages: list["ChatMessage"]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.chat(messages)
logger.debug("Chat result: %s", result)
return result
-5
View File
@@ -32,18 +32,15 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
return OpenAILikeEmbedding(
model_name=config.llm_embedding_model or "text-embedding-3-small",
api_key=config.llm_api_key,
api_base=endpoint,
timeout=config.llm_request_timeout,
http_client=http_client,
async_http_client=async_http_client,
)
@@ -76,14 +73,12 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
)
embedding._client = Client(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
)
embedding._async_client = AsyncClient(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostAsyncHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
-2
View File
@@ -1,2 +0,0 @@
class LLMTimeoutError(Exception):
pass
+46 -5
View File
@@ -5,6 +5,7 @@ from datetime import timedelta
from typing import TYPE_CHECKING
from django.conf import settings
from django.contrib.auth.models import User
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
@@ -12,6 +13,7 @@ from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
from documents.permissions import get_objects_for_user_owner_aware
from documents.utils import IterWrapper
from documents.utils import identity
from paperless.config import AIConfig
@@ -22,6 +24,7 @@ from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core.schema import BaseNode
from llama_index.core.schema import NodeWithScore
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@@ -449,12 +452,36 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
return {str(document_id) for document_id in document_ids}
def query_similar_documents(
def visible_document_ids_for_user(user: User | None) -> list[int] | None:
"""Return the pks of documents ``user`` may view, or ``None`` for no filter.
Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used
by both the similarity-context and taxonomy-hints paths to scope RAG
neighbours to documents the requesting user is allowed to see.
"""
if user is None:
return None
visible_documents = get_objects_for_user_owner_aware(
user,
"view_document",
Document,
)
return list(visible_documents.values_list("pk", flat=True))
def retrieve_similar_nodes(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
top_k: int = 5,
) -> list["NodeWithScore"]:
"""Run ANN retrieval and return the raw NodeWithScore results.
Returns ``[]`` when the allow-list normalizes to empty, or when no index
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
embedding request, so it runs inside ``db_connection_released()`` to avoid
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
and the taxonomy-hints path go through here, so they share that behavior.
"""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
@@ -494,7 +521,21 @@ def query_similar_documents(
filters=filters,
)
with db_connection_released():
results = retriever.retrieve(query_text)
return retriever.retrieve(query_text)
def query_similar_documents(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
results = retrieve_similar_nodes(
document=document,
document_ids=allowed_document_ids,
top_k=top_k,
)
retrieved_document_ids: list[int] = []
for node in results:
+38 -11
View File
@@ -15,40 +15,56 @@ MATCH_THRESHOLD = 0.8
logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
def match_tags_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Tag]:
queryset = get_objects_for_user_owner_aware(
user,
["view_tag"],
Tag,
)
return _match_names_to_queryset(names, queryset, "name")
return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
def match_correspondents_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[Correspondent]:
queryset = get_objects_for_user_owner_aware(
user,
["view_correspondent"],
Correspondent,
)
return _match_names_to_queryset(names, queryset, "name")
return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
def match_document_types_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[DocumentType]:
queryset = get_objects_for_user_owner_aware(
user,
["view_documenttype"],
DocumentType,
)
return _match_names_to_queryset(names, queryset, "name")
return _match_names_to_queryset(names, queryset, "name", hinted_names)
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
def match_storage_paths_by_name(
names: list[str],
user: User,
hinted_names: set[str] | None = None,
) -> list[StoragePath]:
queryset = get_objects_for_user_owner_aware(
user,
["view_storagepath"],
StoragePath,
)
return _match_names_to_queryset(names, queryset, "name")
return _match_names_to_queryset(names, queryset, "name", hinted_names)
def _normalize(s: str) -> str:
@@ -58,10 +74,18 @@ def _normalize(s: str) -> str:
return s
def _match_names_to_queryset(names: list[str], queryset, attr: str):
def _match_names_to_queryset(
names: list[str],
queryset,
attr: str,
hinted_names: set[str] | None = None,
):
results = []
objects = list(queryset)
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
normalized_hints = (
{_normalize(name) for name in hinted_names} if hinted_names else set()
)
for name in names:
if not name:
@@ -76,6 +100,11 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
results.append(matched)
continue
# A hinted name that didn't exact-match came from existing taxonomy
# verbatim; do not fuzzy-map it onto a different object.
if target in normalized_hints:
continue
# Fuzzy match fallback
matches = difflib.get_close_matches(
target,
@@ -88,8 +117,6 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
matched = objects.pop(index)
object_names.pop(index)
results.append(matched)
else:
pass
return results
+115
View File
@@ -0,0 +1,115 @@
from typing import TYPE_CHECKING
from typing import TypedDict
from django.contrib.auth.models import User
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.indexing import retrieve_similar_nodes
from paperless_ai.indexing import visible_document_ids_for_user
if TYPE_CHECKING:
from llama_index.core.schema import NodeWithScore
class TaxonomyHints(TypedDict):
tags: list[str]
document_types: list[str]
correspondents: list[str]
storage_paths: list[str]
def build_taxonomy_hints_from_nodes(
nodes: list["NodeWithScore"],
) -> TaxonomyHints:
"""Collect the unique, sorted taxonomy names carried on retrieved nodes.
Reads ``tags`` (a list), ``document_type``, ``correspondent``, and
``storage_path`` from each node's metadata. Empty / ``None`` values and
missing keys are skipped. The result is naturally bounded by the retrieval
``top_k``, so no cap is applied.
"""
tags: set[str] = set()
document_types: set[str] = set()
correspondents: set[str] = set()
storage_paths: set[str] = set()
for node in nodes:
metadata = node.metadata or {}
for tag in metadata.get("tags") or []:
if tag:
tags.add(tag)
document_type = metadata.get("document_type")
if document_type:
document_types.add(document_type)
correspondent = metadata.get("correspondent")
if correspondent:
correspondents.add(correspondent)
storage_path = metadata.get("storage_path")
if storage_path:
storage_paths.add(storage_path)
return TaxonomyHints(
tags=sorted(tags),
document_types=sorted(document_types),
correspondents=sorted(correspondents),
storage_paths=sorted(storage_paths),
)
_HINT_INSTRUCTION = (
"Prefer existing names from these lists verbatim. Only propose a new value "
"if none of the existing names fits."
)
def format_hints_for_prompt(hints: TaxonomyHints) -> str:
"""Render non-empty hint categories as labelled blocks plus one instruction.
Returns "" when every category is empty, so callers can treat the result
the same as no hints at all.
"""
# Literal-key access keeps this TypedDict-safe for mypy; the order here is
# the order the blocks appear in the prompt.
labelled_values: list[tuple[str, list[str]]] = [
("Available tags", hints["tags"]),
("Available document types", hints["document_types"]),
("Available correspondents", hints["correspondents"]),
("Available storage paths", hints["storage_paths"]),
]
blocks: list[str] = []
for label, values in labelled_values:
if values:
listing = "\n".join(f"- {value}" for value in values)
blocks.append(f"{label}:\n{listing}")
if not blocks:
return ""
return "\n\n".join([*blocks, _HINT_INSTRUCTION])
def get_taxonomy_hints_for_document(
document: Document,
user: User | None,
) -> TaxonomyHints | None:
"""Build taxonomy hints from a document's RAG neighbours.
Returns ``None`` when no embedding backend is configured (the gate) so the
caller's prompt and matching are identical to today. Otherwise returns a
``TaxonomyHints`` -- possibly all-empty when no similar documents exist.
Applies the same owner-aware visible-document filter as
``get_context_for_document``.
"""
if not AIConfig().llm_embedding_backend:
return None
nodes = retrieve_similar_nodes(
document=document,
document_ids=visible_document_ids_for_user(user),
)
return build_taxonomy_hints_from_nodes(nodes)
@@ -1,8 +1,11 @@
import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import pytest_mock
from django.test import override_settings
from documents.models import Document
@@ -261,3 +264,111 @@ def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""
class TestPromptHints:
@pytest.fixture
def config(self) -> AIConfig:
# build_prompt_* only read these two numeric settings off config;
# a stand-in avoids constructing a DB-backed AIConfig.
return cast(
"AIConfig",
SimpleNamespace(llm_embedding_chunk_size=1000, llm_context_size=8000),
)
def test_without_rag_includes_hints_block(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
assert "Available tags:" in prompt
assert "- Bloodwork" in prompt
assert "Prefer existing names from these lists verbatim" in prompt
def test_without_rag_none_matches_baseline(
self,
mock_document: MagicMock,
config: AIConfig,
) -> None:
baseline = build_prompt_without_rag(mock_document, config)
with_none = build_prompt_without_rag(mock_document, config, hints=None)
assert with_none == baseline
assert "Available tags:" not in with_none
def test_with_rag_includes_context_and_hints(
self,
mock_document: MagicMock,
config: AIConfig,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="TITLE: Neighbour\nsome context",
)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
assert "Additional context from similar documents" in prompt
assert "Available tags:" in prompt
def test_classification_forwards_hints(
self,
mock_document: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.AIConfig",
return_value=SimpleNamespace(
llm_embedding_backend=None,
llm_embedding_chunk_size=1000,
llm_context_size=8000,
),
)
build = mocker.patch(
"paperless_ai.ai_classifier.build_prompt_without_rag",
return_value="PROMPT",
)
mock_client = MagicMock()
mock_client.run_llm_query.return_value = {
"title": "t",
"tags": [],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
}
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = get_ai_document_classification(
mock_document,
user=None,
hints=hints,
)
_, build_kwargs = build.call_args
assert build_kwargs["hints"] == hints
assert set(result.keys()) == {
"title",
"tags",
"correspondents",
"document_types",
"storage_paths",
"dates",
}
@@ -1,4 +1,5 @@
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -726,3 +727,58 @@ class TestQuerySimilarDocuments:
results = indexing.query_similar_documents(a, document_ids=[b.id])
assert all(doc.id == b.id for doc in results)
class TestRetrieveSimilarNodes:
@pytest.mark.django_db
def test_returns_raw_nodes_from_retriever(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
mocker.patch("paperless_ai.indexing.load_or_build_index")
node1 = SimpleNamespace(metadata={"document_id": "1"})
node2 = SimpleNamespace(metadata={"document_id": "2"})
retriever = mocker.MagicMock()
retriever.retrieve.return_value = [node1, node2]
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=retriever,
)
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
assert result == [node1, node2]
@pytest.mark.django_db
def test_empty_allow_list_fails_closed(
self,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
assert result == []
load.assert_not_called()
@pytest.mark.django_db
def test_queues_update_when_index_missing(
self,
temp_llm_index_dir: Path,
real_document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
assert result == []
queue.assert_called_once_with(
rebuild=False,
reason="LLM index not found for similarity query.",
)
+7 -32
View File
@@ -3,14 +3,12 @@ from unittest.mock import ANY
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import openai
import pytest
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import LLM_SYSTEM_PROMPT
from paperless_ai.client import AIClient
from paperless_ai.exceptions import LLMTimeoutError
@pytest.fixture
@@ -19,7 +17,6 @@ def mock_ai_config():
mock_config = MagicMock()
mock_config.llm_allow_internal_endpoints = True
mock_config.llm_context_size = 8192
mock_config.llm_request_timeout = 120
MockAIConfig.return_value = mock_config
yield mock_config
@@ -67,7 +64,6 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
model="test_model",
api_base="http://test-url",
api_key="test_api_key",
timeout=120,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -155,38 +151,17 @@ def test_run_llm_query_openai_uses_tools(mock_ai_config, mock_openai_llm):
mock_llm_instance.chat_with_tools.assert_called_once()
def test_run_llm_query_openai_timeout_raises_local_error(
mock_ai_config,
mock_openai_llm,
):
mock_ai_config.llm_backend = "openai-like"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_api_key = "test_api_key"
mock_ai_config.llm_endpoint = "http://test-url"
request = httpx.Request("POST", "http://test-url/v1/chat/completions")
mock_openai_llm.return_value.chat_with_tools.side_effect = openai.APITimeoutError(
request,
)
client = AIClient()
with pytest.raises(LLMTimeoutError):
client.run_llm_query("test_prompt")
def test_run_llm_query_httpx_timeout_raises_local_error(
mock_ai_config,
mock_ollama_llm,
):
def test_run_chat(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_endpoint = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.chat.side_effect = httpx.ReadTimeout("timed out")
mock_llm_instance.chat.return_value = "test_chat_result"
client = AIClient()
messages = [ChatMessage(role="user", content="Hello")]
result = client.run_chat(messages)
with pytest.raises(LLMTimeoutError):
client.run_llm_query("test_prompt")
mock_llm_instance.chat.assert_called_once_with(messages)
assert result == "test_chat_result"
-3
View File
@@ -19,7 +19,6 @@ def mock_ai_config():
MockAIConfig.return_value.llm_embedding_endpoint = None
MockAIConfig.return_value.llm_allow_internal_endpoints = True
MockAIConfig.return_value.llm_context_size = 8192
MockAIConfig.return_value.llm_request_timeout = 120
yield MockAIConfig
@@ -72,7 +71,6 @@ def test_get_embedding_model_openai(mock_ai_config):
model_name="text-embedding-3-small",
api_key="test_api_key",
api_base="http://test-url",
timeout=120,
http_client=ANY,
async_http_client=ANY,
)
@@ -94,7 +92,6 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
model_name="text-embedding-3-small",
api_key="test_api_key",
api_base="http://embedding-url",
timeout=120,
http_client=ANY,
async_http_client=ANY,
)
+92
View File
@@ -1,12 +1,15 @@
import difflib
from unittest.mock import patch
import pytest
import pytest_mock
from django.test import TestCase
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.factories import TagFactory
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
@@ -87,6 +90,95 @@ class TestAIMatching(TestCase):
self.assertEqual(result[1].name, "Test Tag 2")
class TestHintedMatching:
def test_hinted_verbatim_skips_fuzzy(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork"],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_unhinted_name_still_fuzzy_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
# "Bloodwrok" is a typo not in hints -> fuzzy still maps it to Bloodwork.
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Taxes"},
)
assert [t.name for t in result] == ["Bloodwork"]
def test_hinted_name_with_whitespace_exact_matches(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
spy = mocker.spy(difflib, "get_close_matches")
result = match_tags_by_name(
["Bloodwork "],
user=None,
hinted_names={"Bloodwork"},
)
assert [t.name for t in result] == ["Bloodwork"]
spy.assert_not_called()
def test_hinted_name_absent_from_queryset_is_skipped_not_fuzzed(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
# A hint with no exact object must not fall through to fuzzy.
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Bloodwork")],
)
result = match_tags_by_name(
["Bloodwrok"],
user=None,
hinted_names={"Bloodwrok"},
)
assert result == []
def test_backward_compatible_without_kwarg(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.matching.get_objects_for_user_owner_aware",
return_value=[TagFactory.build(name="Test Tag 1")],
)
result = match_tags_by_name(["Test Tag 1", "Nonexistent"], user=None)
assert [t.name for t in result] == ["Test Tag 1"]
@pytest.mark.django_db
class TestExtractUnmatchedNamesNormalization:
def test_punctuated_name_already_matched_is_not_returned_as_unmatched(
+220
View File
@@ -0,0 +1,220 @@
from types import SimpleNamespace
import pytest_mock
from documents.tests.factories import DocumentFactory
from paperless_ai.taxonomy import TaxonomyHints
from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes
from paperless_ai.taxonomy import format_hints_for_prompt
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
def make_node(**metadata: object) -> SimpleNamespace:
"""A stand-in for NodeWithScore: only ``.metadata`` is accessed."""
return SimpleNamespace(metadata=metadata)
class TestBuildTaxonomyHintsFromNodes:
def test_returns_all_four_keys(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert set(hints.keys()) == {
"tags",
"document_types",
"correspondents",
"storage_paths",
}
def test_collects_and_sorts_values(self) -> None:
nodes = [
make_node(
tags=["Taxes", "Bloodwork"],
document_type="Invoice",
correspondent="IRS",
storage_path="Financial",
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Bloodwork", "Taxes"]
assert hints["document_types"] == ["Invoice"]
assert hints["correspondents"] == ["IRS"]
assert hints["storage_paths"] == ["Financial"]
def test_deduplicates_across_nodes(self) -> None:
nodes = [
make_node(tags=["Taxes"], document_type="Invoice"),
make_node(tags=["Taxes", "Medical"], document_type="Invoice"),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Medical", "Taxes"]
assert hints["document_types"] == ["Invoice"]
def test_none_values_skipped(self) -> None:
nodes = [
make_node(
tags=["Taxes", None, ""],
document_type=None,
correspondent=None,
storage_path=None,
),
]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["tags"] == ["Taxes"]
assert hints["document_types"] == []
assert hints["correspondents"] == []
assert hints["storage_paths"] == []
def test_missing_storage_path_key_handled(self) -> None:
# Pre-enrichment nodes have no storage_path key at all.
nodes = [make_node(tags=["Taxes"], document_type="Invoice")]
hints = build_taxonomy_hints_from_nodes(nodes)
assert hints["storage_paths"] == []
def test_empty_node_list_all_empty(self) -> None:
hints = build_taxonomy_hints_from_nodes([])
assert hints == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
def test_output_stable_across_calls(self) -> None:
nodes = [make_node(tags=["b", "a", "c"])]
assert build_taxonomy_hints_from_nodes(
nodes,
) == build_taxonomy_hints_from_nodes(nodes)
class TestFormatHintsForPrompt:
def test_all_blocks_present_when_all_categories_nonempty(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": ["IRS"],
"storage_paths": ["Financial"],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" in result
assert "Available correspondents:" in result
assert "Available storage paths:" in result
assert "- Bloodwork" in result
def test_empty_category_produces_no_block(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert "Available tags:" in result
assert "Available document types:" not in result
assert "Available correspondents:" not in result
assert "Available storage paths:" not in result
def test_all_empty_produces_empty_string(self) -> None:
hints: TaxonomyHints = {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
assert format_hints_for_prompt(hints) == ""
def test_instruction_line_appears_once(self) -> None:
hints: TaxonomyHints = {
"tags": ["Bloodwork"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
result = format_hints_for_prompt(hints)
assert result.count("Prefer existing names from these lists verbatim") == 1
class TestGetTaxonomyHintsForDocument:
def test_returns_none_when_embedding_backend_off(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend=None),
)
retrieve = mocker.patch("paperless_ai.taxonomy.retrieve_similar_nodes")
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result is None
retrieve.assert_not_called()
def test_passes_owner_aware_ids_when_user_present(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.visible_document_ids_for_user",
return_value=[1, 2, 3],
)
retrieve = mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
document = DocumentFactory.build()
user = mocker.MagicMock()
get_taxonomy_hints_for_document(document, user=user)
retrieve.assert_called_once_with(
document=document,
document_ids=[1, 2, 3],
)
def test_returns_populated_hints_when_nodes_found(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[make_node(tags=["Taxes"], document_type="Invoice")],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": ["Taxes"],
"document_types": ["Invoice"],
"correspondents": [],
"storage_paths": [],
}
def test_returns_empty_hints_not_none_when_no_nodes(
self,
mocker: pytest_mock.MockerFixture,
) -> None:
mocker.patch(
"paperless_ai.taxonomy.AIConfig",
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
)
mocker.patch(
"paperless_ai.taxonomy.retrieve_similar_nodes",
return_value=[],
)
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
assert result == {
"tags": [],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
@@ -0,0 +1,77 @@
from types import SimpleNamespace
import pytest
import pytest_mock
from django.contrib.auth.models import User
from rest_framework.test import APIClient
from documents.models import Document
from documents.tests.factories import DocumentFactory
@pytest.mark.django_db
class TestSuggestionsHintWiring:
@pytest.fixture
def document(self) -> Document:
return DocumentFactory() # type: ignore[return-value]
@pytest.fixture
def api_client(self, admin_user: User) -> APIClient:
client = APIClient()
client.force_authenticate(user=admin_user)
return client
def test_hints_passed_to_classifier_and_matchers(
self,
api_client: APIClient,
document: Document,
mocker: pytest_mock.MockerFixture,
) -> None:
hints = {
"tags": ["Bloodwork"],
"document_types": [],
"correspondents": [],
"storage_paths": [],
}
mocker.patch(
"documents.views.get_taxonomy_hints_for_document",
return_value=hints,
)
mocker.patch(
"documents.views.AIConfig",
return_value=SimpleNamespace(
ai_enabled=True,
llm_backend="ollama",
llm_output_language=None,
),
)
# No cached suggestion -> the view reaches the classifier path.
mocker.patch(
"documents.views.get_llm_suggestion_cache",
return_value=None,
)
mocker.patch("documents.views.set_llm_suggestions_cache")
classify = mocker.patch(
"documents.views.get_ai_document_classification",
return_value={
"title": "Doc",
"tags": ["Bloodwork"],
"correspondents": [],
"document_types": [],
"storage_paths": [],
"dates": [],
},
)
match_tags = mocker.patch(
"documents.views.match_tags_by_name",
return_value=[],
)
mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
mocker.patch("documents.views.match_document_types_by_name", return_value=[])
mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
assert response.status_code == 200
assert classify.call_args.kwargs["hints"] == hints
assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}
@@ -1,158 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:10
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("paperless_mail", "0002_optimize_integer_field_sizes"),
("paperless_mail", "0003_mailrule_stop_processing"),
]
dependencies = [
("paperless_mail", "0001_squashed"),
]
operations = [
migrations.AlterField(
model_name="mailaccount",
name="account_type",
field=models.PositiveSmallIntegerField(
choices=[(1, "IMAP"), (2, "Gmail OAuth"), (3, "Outlook OAuth")],
default=1,
verbose_name="account type",
),
),
migrations.AlterField(
model_name="mailaccount",
name="imap_port",
field=models.PositiveIntegerField(
blank=True,
help_text="This is usually 143 for unencrypted and STARTTLS connections, and 993 for SSL connections.",
null=True,
verbose_name="IMAP port",
),
),
migrations.AlterField(
model_name="mailaccount",
name="imap_security",
field=models.PositiveSmallIntegerField(
choices=[(1, "No encryption"), (2, "Use SSL"), (3, "Use STARTTLS")],
default=2,
verbose_name="IMAP security",
),
),
migrations.AlterField(
model_name="mailrule",
name="action",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Delete"),
(2, "Move to specified folder"),
(3, "Mark as read, don't process read mails"),
(4, "Flag the mail, don't process flagged mails"),
(5, "Tag the mail with specified tag, don't process tagged mails"),
],
default=3,
verbose_name="action",
),
),
migrations.AlterField(
model_name="mailrule",
name="assign_correspondent_from",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Do not assign a correspondent"),
(2, "Use mail address"),
(3, "Use name (or mail address if not available)"),
(4, "Use correspondent selected below"),
],
default=1,
verbose_name="assign correspondent from",
),
),
migrations.AlterField(
model_name="mailrule",
name="assign_title_from",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Use subject as title"),
(2, "Use attachment filename as title"),
(3, "Do not assign title from rule"),
],
default=1,
verbose_name="assign title from",
),
),
migrations.AlterField(
model_name="mailrule",
name="attachment_type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Only process attachments."),
(2, "Process all files, including 'inline' attachments."),
],
default=1,
help_text="Inline attachments include embedded images, so it's best to combine this option with a filename filter.",
verbose_name="attachment type",
),
),
migrations.AlterField(
model_name="mailrule",
name="consumption_scope",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Only process attachments."),
(
2,
"Process full Mail (with embedded attachments in file) as .eml",
),
(
3,
"Process full Mail (with embedded attachments in file) as .eml + process attachments as separate documents",
),
],
default=1,
verbose_name="consumption scope",
),
),
migrations.AlterField(
model_name="mailrule",
name="maximum_age",
field=models.PositiveSmallIntegerField(
default=30,
help_text="Specified in days.",
verbose_name="maximum age",
),
),
migrations.AlterField(
model_name="mailrule",
name="order",
field=models.SmallIntegerField(default=0, verbose_name="order"),
),
migrations.AlterField(
model_name="mailrule",
name="pdf_layout",
field=models.PositiveSmallIntegerField(
choices=[
(0, "System default"),
(1, "Text, then HTML"),
(2, "HTML, then text"),
(3, "HTML only"),
(4, "Text only"),
],
default=0,
verbose_name="pdf layout",
),
),
migrations.AddField(
model_name="mailrule",
name="stop_processing",
field=models.BooleanField(
default=False,
help_text="If True, no further rules will be processed after this one if any document is queued.",
verbose_name="Stop processing further rules",
),
),
]