Compare commits

..

3 Commits

Author SHA1 Message Date
Trenton H e149c139db Encodes the string just once for compare json 2026-06-12 14:23:34 -07:00
Trenton H a51cb6e231 Fix: Add directory marker entries to zip exports
Without explicit directory entries, some zip viewers (simpler tools,
web-based viewers) don't show the folder structure when browsing the
archive. Add a _ensure_zip_dirs() helper that writes directory markers
for all parent paths of each file entry, deduplicating via a set.
Uses ZipFile.mkdir() (available since Python 3.11, the project minimum).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-12 14:23:34 -07:00
Trenton H 57e7d2f0ce Refactor: Write zip exports directly into ZipFile instead of temp dir
Replace the temp-dir + shutil.make_archive() workaround with direct
zipfile.ZipFile writes. Document files are added via zf.write() and
JSON manifests via zf.writestr()/StringIO buffering, eliminating the
double-I/O and 2x disk usage of the previous approach.

Key changes:
- Removed tempfile.TemporaryDirectory and shutil.make_archive() from handle()
- ZipFile opened on a .tmp path; renamed to final .zip atomically on success;
  .tmp cleaned up on failure
- StreamingManifestWriter: zip mode buffers manifest in io.StringIO and
  writes to zip atomically on close() (zipfile allows only one open write
  handle at a time)
- check_and_copy(): zip mode calls zf.write(source, arcname=...) directly
- check_and_write_json(): zip mode calls zf.writestr(arcname, ...) directly
- files_in_export_dir scan skipped in zip mode (always fresh write)
- --compare-checksums and --compare-json emit warnings when used with --zip
- --delete in zip mode removes pre-existing files from target dir, skipping
  the in-progress .tmp and any prior .zip
- Added tests: atomicity on failure, no SCRATCH_DIR usage

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-12 14:23:34 -07:00
70 changed files with 1512 additions and 5307 deletions
-7
View File
@@ -2068,13 +2068,6 @@ context by default.
Defaults to 8192.
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
local or slow inference servers that need more time to generate responses.
Defaults to 120.
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
+2 -1
View File
@@ -49,6 +49,7 @@ dependencies = [
"ijson>=3.2",
"imap-tools~=1.13.0",
"jinja2~=3.1.5",
"lancedb~=0.33.0",
"langdetect~=1.0.9",
"llama-index-core>=0.14.21",
"llama-index-embeddings-huggingface>=0.6.1",
@@ -61,6 +62,7 @@ dependencies = [
"openai>=2.32",
"pathvalidate~=3.3.1",
"pdf2image~=1.17.0",
"pyarrow>=16",
"python-dateutil~=2.9.0",
"python-dotenv~=1.2.1",
"python-gnupg~=0.5.4",
@@ -72,7 +74,6 @@ dependencies = [
"scikit-learn~=1.8.0",
"sentence-transformers>=5.4.1",
"setproctitle~=1.3.4",
"sqlite-vec==0.1.9",
"tantivy~=0.26.0",
"tika-client~=0.11.0",
"torch~=2.11.0",
+1 -1
View File
@@ -26,7 +26,7 @@ module.exports = {
'abstract-paperless-service',
],
transformIgnorePatterns: [
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
],
moduleNameMapper: {
...esmPreset.moduleNameMapper,
-1
View File
@@ -32,7 +32,6 @@
"ngx-cookie-service": "^21.3.1",
"ngx-device-detector": "^11.0.0",
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
"normalize-diacritics": "^5.0.0",
"pdfjs-dist": "^5.7.284",
"rxjs": "^7.8.2",
"tslib": "^2.8.1",
-11
View File
@@ -71,9 +71,6 @@ importers:
ngx-ui-tour-ng-bootstrap:
specifier: ^18.0.0
version: 18.0.0(f910a33494d223bd6dd07ce1bf22a35e)
normalize-diacritics:
specifier: ^5.0.0
version: 5.0.0
pdfjs-dist:
specifier: ^5.7.284
version: 5.7.284
@@ -5519,10 +5516,6 @@ packages:
engines: {node: ^20.17.0 || >=22.9.0}
hasBin: true
normalize-diacritics@5.0.0:
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
engines: {node: '>= 14.x', npm: '>= 6.x'}
normalize-path@3.0.0:
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
engines: {node: '>=0.10.0'}
@@ -12938,10 +12931,6 @@ snapshots:
dependencies:
abbrev: 4.0.0
normalize-diacritics@5.0.0:
dependencies:
tslib: 2.8.1
normalize-path@3.0.0: {}
npm-bundled@5.0.0:
@@ -188,14 +188,4 @@ describe('ChatComponent', () => {
component.searchInputKeyDown(event)
expect(component.sendMessage).toHaveBeenCalled()
})
it('should not send message on Enter key press while composing with IME', () => {
jest.spyOn(component, 'sendMessage')
const event = new KeyboardEvent('keydown', {
key: 'Enter',
isComposing: true,
})
component.searchInputKeyDown(event)
expect(component.sendMessage).not.toHaveBeenCalled()
})
})
@@ -155,10 +155,7 @@ export class ChatComponent implements OnInit {
}
public searchInputKeyDown(event: KeyboardEvent) {
if (
event.key === 'Enter' &&
!(event.isComposing || event.keyCode === 229)
) {
if (event.key === 'Enter') {
event.preventDefault()
this.sendMessage()
}
@@ -23,7 +23,6 @@ import {
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
import { ToastService } from 'src/app/services/toast.service'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
@@ -70,7 +69,9 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
public get filteredFields(): CustomField[] {
return this.unusedFields.filter(
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
(f) =>
!this.filterText ||
f.name.toLowerCase().includes(this.filterText.toLowerCase())
)
}
@@ -63,7 +63,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
@@ -82,7 +81,6 @@
[disabled]="disabled"
bindLabel="name"
bindValue="id"
[searchFn]="customFieldSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
@@ -127,7 +125,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[multiple]="true"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
}
@@ -36,7 +36,6 @@ import {
CustomFieldQueryExpression,
} from 'src/app/utils/custom-field-query-element'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
@@ -282,14 +281,6 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
public readonly today: string = new Date().toLocaleDateString('en-CA')
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
matchesSearchText(field?.name, term)
public selectOptionSearchFn = (
term: string,
option: { id: string; label: string }
): boolean => matchesSearchText(option?.label, term)
constructor() {
super()
this.selectionModel = new CustomFieldQueriesModel()
@@ -28,7 +28,6 @@
[notFoundText]="notFoundText"
[multiple]="multiple"
[bindLabel]="bindLabel"
[searchFn]="searchFn"
bindValue="id"
[virtualScroll]="items?.length > 100"
(change)="onChange(value)"
@@ -112,15 +112,6 @@ describe('SelectComponent', () => {
expect(createNewVal).toEqual('baz')
})
it('should search items by independent normalized terms', () => {
expect(
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeTruthy()
expect(
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeFalsy()
})
it('should clear search term on blur after delay', fakeAsync(() => {
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
component.onBlur()
@@ -13,7 +13,6 @@ import {
import { RouterModule } from '@angular/router'
import { NgSelectModule } from '@ng-select/ng-select'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { matchesSearchText } from 'src/app/utils/text-search'
import { AbstractInputComponent } from '../abstract-input'
@Component({
@@ -100,9 +99,6 @@ export class SelectComponent extends AbstractInputComponent<number> {
@Input()
bindLabel: string = 'name'
public searchFn = (term: string, item: any): boolean =>
matchesSearchText(item?.[this.bindLabel], term)
@Input()
showFilter: boolean = false
@@ -14,7 +14,6 @@
[clearSearchOnAdd]="true"
[hideSelected]="tags.length > 0"
[addTag]="allowCreate ? createTagRef : false"
[searchFn]="searchFn"
addTagText="Add tag"
i18n-addTagText
(add)="onAdd($event)"
@@ -171,15 +171,6 @@ describe('TagsComponent', () => {
expect(component.getTag(4)).toBeUndefined()
})
it('should search tags by independent normalized terms including parents', () => {
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
component.tags = [parent, child]
expect(component.searchFn('finance 26', child)).toBeTruthy()
expect(component.searchFn('finance receipt', child)).toBeFalsy()
})
it('should emit filtered documents', () => {
component.value = [10]
component.tags = tags
@@ -21,7 +21,6 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { first, firstValueFrom, tap } from 'rxjs'
import { Tag } from 'src/app/data/tag'
import { TagService } from 'src/app/services/rest/tag.service'
import { matchesSearchText } from 'src/app/utils/text-search'
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
import { TagComponent } from '../../tag/tag.component'
@@ -115,14 +114,6 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
public createTagRef: (name) => void
public searchFn = (term: string, tag: Tag): boolean =>
matchesSearchText(
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
.flat()
.join(' '),
term
)
getTag(id: number) {
if (this.tags) {
return this.tags.find((tag) => tag.id == id)
@@ -131,9 +131,7 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
}
</button>
<ng-template #celeryStatus>
-9
View File
@@ -360,14 +360,6 @@ export const PaperlessConfigOptions: ConfigOption[] = [
category: ConfigCategory.AI,
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
},
{
key: 'llm_request_timeout',
title: $localize`LLM Request Timeout`,
type: ConfigOptionType.Number,
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
category: ConfigCategory.AI,
note: $localize`Timeout in seconds for LLM requests.`,
},
]
export interface PaperlessConfig extends ObjectWithId {
@@ -409,5 +401,4 @@ export interface PaperlessConfig extends ObjectWithId {
llm_api_key: string
llm_endpoint: string
llm_output_language: string
llm_request_timeout: number
}
+3 -2
View File
@@ -1,6 +1,5 @@
import { Pipe, PipeTransform } from '@angular/core'
import { MatchingModel } from '../data/matching-model'
import { matchesSearchText } from '../utils/text-search'
@Pipe({
name: 'filter',
@@ -22,7 +21,9 @@ export class FilterPipe implements PipeTransform {
typeof item[key] === 'string' || typeof item[key] === 'number'
)
return keys.some((key) => {
return matchesSearchText(item[key], searchText)
return String(item[key])
.toLowerCase()
.includes(searchText.toLowerCase())
})
})
}
-17
View File
@@ -1,17 +0,0 @@
import { matchesSearchText } from './text-search'
describe('text search utilities', () => {
it('matches text accent-insensitively', () => {
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
})
it('matches all whitespace-separated search terms independently', () => {
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
})
})
-23
View File
@@ -1,23 +0,0 @@
import { normalizeSync } from 'normalize-diacritics'
export type SearchTextValue =
| string
| number
| boolean
| bigint
| null
| undefined
export function normalizeSearchText(value: SearchTextValue): string {
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
}
export function matchesSearchText(
value: SearchTextValue,
searchText: SearchTextValue
): boolean {
const normalizedValue = normalizeSearchText(value)
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
return searchTerms.every((term) => normalizedValue.includes(term))
}
@@ -169,10 +169,6 @@ class FileStabilityTracker:
self._tracked.pop(path, None)
yield path
def is_tracking(self, path: Path) -> bool:
"""Check whether a path is currently being tracked for stability."""
return path.resolve() in self._tracked
def has_pending_files(self) -> bool:
"""Check if there are files waiting for stability check."""
return len(self._tracked) > 0
@@ -374,16 +370,6 @@ class Command(BaseCommand):
# Testing timeout in seconds
testing_timeout_s: Final[float] = 0.5
# How often to perform a full-glob rescan of the consume directory as a
# safety net. Each watchfiles watcher is torn down and recreated on every
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
# current directory contents as its baseline. A file that appears between
# one batch and the next watcher's baseline is therefore never reported and
# would sit in the consume directory forever. This periodic rescan re-injects
# such files into the stability tracker (see GH issue #13011). Not currently
# user-configurable; instances may override for testing.
rescan_interval_s: float = 300.0
def add_arguments(self, parser) -> None:
parser.add_argument(
"directory",
@@ -439,7 +425,7 @@ class Command(BaseCommand):
)
# Process existing files
queued = self._process_existing_files(
self._process_existing_files(
directory=directory,
recursive=recursive,
subdirs_as_tags=subdirs_as_tags,
@@ -459,7 +445,6 @@ class Command(BaseCommand):
polling_interval=polling_interval,
stability_delay=stability_delay,
is_testing=is_testing,
queued=queued,
)
logger.debug("Consumer exiting")
@@ -471,18 +456,11 @@ class Command(BaseCommand):
recursive: bool,
subdirs_as_tags: bool,
consumer_filter: ConsumerFilter,
) -> set[Path]:
"""
Process any existing files in the consumption directory.
Returns the set of resolved paths that were queued, so the watch loop
can seed its in-flight set and avoid re-queuing them on the first
rescan before the consume tasks have removed them from disk.
"""
) -> None:
"""Process any existing files in the consumption directory."""
logger.info(f"Processing existing files in {directory}")
glob_pattern = "**/*" if recursive else "*"
queued: set[Path] = set()
for filepath in directory.glob(glob_pattern):
# Use filter to check if file should be processed
@@ -497,48 +475,6 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
queued.add(filepath.resolve())
return queued
def _rescan_existing_files(
self,
*,
directory: Path,
recursive: bool,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
) -> None:
"""
Re-inject on-disk files the watcher never reported into the tracker.
Acts as a safety net for files stranded by the watcher-recreation gap
(see ``rescan_interval_s``). Files already being tracked or already
queued and awaiting consumption are skipped, so a file is never queued
twice. Queued paths that have since left the directory are pruned so a
later file reusing the same name is not skipped forever.
"""
# Prune in-flight paths that have left the directory
for path in list(queued):
if not path.exists():
queued.discard(path)
glob_pattern = "**/*" if recursive else "*"
for filepath in directory.glob(glob_pattern):
if not filepath.is_file():
continue
if not consumer_filter(Change.added, str(filepath)):
continue
resolved = filepath.resolve()
if tracker.is_tracking(resolved) or resolved in queued:
continue
logger.debug(f"Rescan found untracked file: {resolved}")
tracker.track(resolved, Change.added)
def _watch_directory(
self,
@@ -550,24 +486,11 @@ class Command(BaseCommand):
polling_interval: float,
stability_delay: float,
is_testing: bool,
queued: set[Path] | None = None,
) -> None:
"""Watch directory for changes and process stable files."""
use_polling = polling_interval > 0
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
# Resolved paths that have been queued and are awaiting consumption.
# Seeded from the startup scan so the first rescan does not re-queue
# files whose consume tasks have not yet removed them from disk.
queued = set() if queued is None else queued
# Full-glob safety net cadence (0 disables)
rescan_interval_s = self.rescan_interval_s
rescan_timeout_ms = (
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
)
last_rescan = monotonic()
if use_polling:
logger.info(
f"Watching {directory} using polling (interval: {polling_interval}s)",
@@ -582,20 +505,6 @@ class Command(BaseCommand):
stability_timeout_ms = int(stability_delay * 1000)
testing_timeout_ms = int(self.testing_timeout_s * 1000)
def cap_for_rescan(ms: int) -> int:
"""
Ensure the watch loop wakes often enough to run the rescan.
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
only run that often. A timeout of 0 means "wait indefinitely",
which would never wake to rescan; cap it at the rescan interval.
"""
if rescan_timeout_ms <= 0:
return ms
if ms <= 0:
return rescan_timeout_ms
return min(ms, rescan_timeout_ms)
# Calculate appropriate timeout for watch loop
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
# to ensure poll cycles can complete before timing out
@@ -613,8 +522,6 @@ class Command(BaseCommand):
# Not testing, wait indefinitely for first event
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
self.stop_flag.clear()
while not self.stop_flag.is_set():
@@ -644,26 +551,10 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
# Remember it so the rescan does not re-queue it while
# the consume task has yet to remove it from disk
queued.add(stable_path)
# Exit watch loop to reconfigure timeout
break
# Periodic full-glob safety net for files the watcher missed
if rescan_timeout_ms > 0 and (
monotonic() - last_rescan >= rescan_interval_s
):
self._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
last_rescan = monotonic()
# Determine next timeout
if tracker.has_pending_files():
# Check pending files at stability interval
@@ -681,8 +572,6 @@ class Command(BaseCommand):
# No pending files, wait indefinitely
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
except KeyboardInterrupt: # pragma: nocover
logger.info("Received interrupt, stopping consumer")
self.stop_flag.set()
@@ -1,8 +1,9 @@
import hashlib
import io
import json
import os
import shutil
import tempfile
import zipfile
from itertools import islice
from pathlib import Path
from typing import TYPE_CHECKING
@@ -98,6 +99,8 @@ class StreamingManifestWriter:
*,
compare_json: bool = False,
files_in_export_dir: "set[Path] | None" = None,
zip_file: "zipfile.ZipFile | None" = None,
zip_arcname: str | None = None,
) -> None:
self._path = path.resolve()
self._tmp_path = self._path.with_suffix(self._path.suffix + ".tmp")
@@ -105,12 +108,20 @@ class StreamingManifestWriter:
self._files_in_export_dir: set[Path] = (
files_in_export_dir if files_in_export_dir is not None else set()
)
self._zip_file = zip_file
self._zip_arcname = zip_arcname
self._zip_mode = zip_file is not None
self._file = None
self._first = True
def open(self) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
self._file = self._tmp_path.open("w", encoding="utf-8")
if self._zip_mode:
# zipfile only allows one open write handle at a time, so buffer
# the manifest in memory and write it atomically on close()
self._file = io.StringIO()
else:
self._path.parent.mkdir(parents=True, exist_ok=True)
self._file = self._tmp_path.open("w", encoding="utf-8")
self._file.write("[")
self._first = True
@@ -131,15 +142,18 @@ class StreamingManifestWriter:
if self._file is None:
return
self._file.write("\n]")
if self._zip_mode:
self._zip_file.writestr(self._zip_arcname, self._file.getvalue())
self._file.close()
self._file = None
self._finalize()
if not self._zip_mode:
self._finalize()
def discard(self) -> None:
if self._file is not None:
self._file.close()
self._file = None
if self._tmp_path.exists():
if not self._zip_mode and self._tmp_path.exists():
self._tmp_path.unlink()
def _finalize(self) -> None:
@@ -316,18 +330,13 @@ class Command(CryptMixin, PaperlessCommand):
self.files_in_export_dir: set[Path] = set()
self.exported_files: set[str] = set()
self.zip_file: zipfile.ZipFile | None = None
self._zip_dirs: set[str] = set()
# If zipping, save the original target for later and
# get a temporary directory for the target instead
temp_dir = None
self.original_target = self.target
if self.zip_export:
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
temp_dir = tempfile.TemporaryDirectory(
dir=settings.SCRATCH_DIR,
prefix="paperless-export",
)
self.target = Path(temp_dir.name).resolve()
zip_name = options["zip_name"]
self.zip_path = (self.target / zip_name).with_suffix(".zip")
self.zip_tmp_path = self.zip_path.parent / (self.zip_path.name + ".tmp")
if not self.target.exists():
raise CommandError("That path doesn't exist")
@@ -338,30 +347,53 @@ class Command(CryptMixin, PaperlessCommand):
if not os.access(self.target, os.W_OK):
raise CommandError("That path doesn't appear to be writable")
if self.zip_export:
if self.compare_checksums:
self.stdout.write(
self.style.WARNING(
"--compare-checksums is ignored when --zip is used",
),
)
if self.compare_json:
self.stdout.write(
self.style.WARNING(
"--compare-json is ignored when --zip is used",
),
)
try:
# Prevent any ongoing changes in the documents
with FileLock(settings.MEDIA_LOCK):
self.dump()
# We've written everything to the temporary directory in this case,
# now make an archive in the original target, with all files stored
if self.zip_export and temp_dir is not None:
shutil.make_archive(
self.original_target / options["zip_name"],
format="zip",
root_dir=temp_dir.name,
if self.zip_export:
self.zip_file = zipfile.ZipFile(
self.zip_tmp_path,
"w",
compression=zipfile.ZIP_DEFLATED,
allowZip64=True,
)
self.dump()
if self.zip_file is not None:
self.zip_file.close()
self.zip_file = None
self.zip_tmp_path.rename(self.zip_path)
finally:
# Always cleanup the temporary directory, if one was created
if self.zip_export and temp_dir is not None:
temp_dir.cleanup()
# Ensure zip_file is closed and the incomplete .tmp is removed on failure
if self.zip_file is not None:
self.zip_file.close()
self.zip_file = None
if self.zip_export and self.zip_tmp_path.exists():
self.zip_tmp_path.unlink()
def dump(self) -> None:
# 1. Take a snapshot of what files exist in the current export folder
for x in self.target.glob("**/*"):
if x.is_file():
self.files_in_export_dir.add(x.resolve())
# (skipped in zip mode — always write fresh, no skip/compare logic applies)
if not self.zip_export:
for x in self.target.glob("**/*"):
if x.is_file():
self.files_in_export_dir.add(x.resolve())
# 2. Create manifest, containing all correspondents, types, tags, storage paths
# note, documents and ui_settings
@@ -433,6 +465,8 @@ class Command(CryptMixin, PaperlessCommand):
manifest_path,
compare_json=self.compare_json,
files_in_export_dir=self.files_in_export_dir,
zip_file=self.zip_file,
zip_arcname="manifest.json",
) as writer:
with transaction.atomic():
for key, qs in manifest_key_to_object_query.items():
@@ -551,8 +585,12 @@ class Command(CryptMixin, PaperlessCommand):
self.target,
)
else:
# 5. Remove anything in the original location (before moving the zip)
for item in self.original_target.glob("*"):
# 5. Remove pre-existing files/dirs from target, keeping the
# in-progress zip (.tmp) and any prior zip at the final path
skip = {self.zip_path.resolve(), self.zip_tmp_path.resolve()}
for item in self.target.glob("*"):
if item.resolve() in skip:
continue
if item.is_dir():
shutil.rmtree(item)
else:
@@ -722,9 +760,23 @@ class Command(CryptMixin, PaperlessCommand):
if self.use_folder_prefix:
manifest_name = Path("json") / manifest_name
manifest_name = (self.target / manifest_name).resolve()
manifest_name.parent.mkdir(parents=True, exist_ok=True)
if not self.zip_export:
manifest_name.parent.mkdir(parents=True, exist_ok=True)
self.check_and_write_json(content, manifest_name)
def _ensure_zip_dirs(self, arcname: str) -> None:
"""Write directory marker entries for all parent directories of arcname.
Some zip viewers only show folder structure when explicit directory
entries exist, so we add them to avoid confusing users.
"""
parts = Path(arcname).parts[:-1]
for i in range(len(parts)):
dir_arc = "/".join(parts[: i + 1]) + "/"
if dir_arc not in self._zip_dirs:
self._zip_dirs.add(dir_arc)
self.zip_file.mkdir(dir_arc)
def check_and_write_json(
self,
content: list[dict] | dict,
@@ -737,32 +789,38 @@ class Command(CryptMixin, PaperlessCommand):
This preserves the file timestamps when no changes are made.
"""
target = target.resolve()
perform_write = True
if target in self.files_in_export_dir:
self.files_in_export_dir.remove(target)
if self.compare_json:
target_checksum = hashlib.blake2b(target.read_bytes()).hexdigest()
src_str = json.dumps(
content,
cls=DjangoJSONEncoder,
indent=2,
ensure_ascii=False,
)
src_checksum = hashlib.blake2b(src_str.encode("utf-8")).hexdigest()
if src_checksum == target_checksum:
perform_write = False
if perform_write:
target.write_text(
if self.zip_export:
arcname = str(target.resolve().relative_to(self.target))
self._ensure_zip_dirs(arcname)
self.zip_file.writestr(
arcname,
json.dumps(
content,
cls=DjangoJSONEncoder,
indent=2,
ensure_ascii=False,
),
encoding="utf-8",
)
return
target = target.resolve()
json_str = json.dumps(
content,
cls=DjangoJSONEncoder,
indent=2,
ensure_ascii=False,
)
perform_write = True
if target in self.files_in_export_dir:
self.files_in_export_dir.remove(target)
if self.compare_json:
target_checksum = hashlib.blake2b(target.read_bytes()).hexdigest()
src_checksum = hashlib.blake2b(json_str.encode("utf-8")).hexdigest()
if src_checksum == target_checksum:
perform_write = False
if perform_write:
target.write_text(json_str, encoding="utf-8")
def check_and_copy(
self,
@@ -775,6 +833,12 @@ class Command(CryptMixin, PaperlessCommand):
the source attributes
"""
if self.zip_export:
arcname = str(target.resolve().relative_to(self.target))
self._ensure_zip_dirs(arcname)
self.zip_file.write(source, arcname=arcname)
return
target = target.resolve()
if target in self.files_in_export_dir:
self.files_in_export_dir.remove(target)
@@ -1,63 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0003_remove_document_storage_type"),
("documents", "0004_workflowtrigger_filter_has_any_correspondents_and_more"),
("documents", "0005_alter_document_checksum_unique"),
]
dependencies = [
("documents", "0002_squashed"),
]
operations = [
migrations.RemoveField(
model_name="document",
name="storage_type",
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_correspondents",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_correspondent",
to="documents.correspondent",
verbose_name="has one of these correspondents",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_document_types",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_document_type",
to="documents.documenttype",
verbose_name="has one of these document types",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_storage_paths",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_storage_path",
to="documents.storagepath",
verbose_name="has one of these storage paths",
),
),
migrations.AlterField(
model_name="document",
name="checksum",
field=models.CharField(
editable=False,
help_text="The checksum of the original document.",
max_length=32,
verbose_name="checksum",
),
),
]
@@ -1,252 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
import django.db.models.deletion
import django.db.models.functions.text
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0008_workflowaction_passwords_alter_workflowaction_type"),
("documents", "0009_alter_document_content_length"),
("documents", "0010_optimize_integer_field_sizes"),
("documents", "0011_alter_workflowaction_type"),
("documents", "0012_document_root_document"),
]
dependencies = [
("documents", "0007_sharelinkbundle"),
]
operations = [
migrations.AddField(
model_name="workflowaction",
name="passwords",
field=models.JSONField(
blank=True,
help_text="Passwords to try when removing PDF protection. Separate with commas or new lines.",
null=True,
verbose_name="passwords",
),
),
migrations.AlterField(
model_name="document",
name="content_length",
field=models.GeneratedField(
db_persist=True,
expression=django.db.models.functions.text.Length("content"),
help_text="Length of the content field in characters. Automatically maintained by the database for faster statistics computation.",
output_field=models.PositiveIntegerField(default=0),
serialize=False,
),
),
migrations.AlterField(
model_name="correspondent",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="documenttype",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="savedviewfilterrule",
name="rule_type",
field=models.PositiveSmallIntegerField(
choices=[
(0, "title contains"),
(1, "content contains"),
(2, "ASN is"),
(3, "correspondent is"),
(4, "document type is"),
(5, "is in inbox"),
(6, "has tag"),
(7, "has any tag"),
(8, "created before"),
(9, "created after"),
(10, "created year is"),
(11, "created month is"),
(12, "created day is"),
(13, "added before"),
(14, "added after"),
(15, "modified before"),
(16, "modified after"),
(17, "does not have tag"),
(18, "does not have ASN"),
(19, "title or content contains"),
(20, "fulltext query"),
(21, "more like this"),
(22, "has tags in"),
(23, "ASN greater than"),
(24, "ASN less than"),
(25, "storage path is"),
(26, "has correspondent in"),
(27, "does not have correspondent in"),
(28, "has document type in"),
(29, "does not have document type in"),
(30, "has storage path in"),
(31, "does not have storage path in"),
(32, "owner is"),
(33, "has owner in"),
(34, "does not have owner"),
(35, "does not have owner in"),
(36, "has custom field value"),
(37, "is shared by me"),
(38, "has custom fields"),
(39, "has custom field in"),
(40, "does not have custom field in"),
(41, "does not have custom field"),
(42, "custom fields query"),
(43, "created to"),
(44, "created from"),
(45, "added to"),
(46, "added from"),
(47, "mime type is"),
],
verbose_name="rule type",
),
),
migrations.AlterField(
model_name="storagepath",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="tag",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowrun",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
null=True,
verbose_name="workflow trigger type",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
],
default=0,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
default=1,
verbose_name="Workflow Trigger Type",
),
),
migrations.AlterField(
model_name="workflowaction",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Assignment"),
(2, "Removal"),
(3, "Email"),
(4, "Webhook"),
(5, "Password removal"),
(6, "Move to trash"),
],
default=1,
verbose_name="Workflow Action Type",
),
),
migrations.AddField(
model_name="document",
name="root_document",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="versions",
to="documents.document",
verbose_name="root document for this version",
),
),
migrations.AddField(
model_name="document",
name="version_label",
field=models.CharField(
blank=True,
help_text="Optional short label for a document version.",
max_length=64,
null=True,
verbose_name="version label",
),
),
]
-4
View File
@@ -8,15 +8,11 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+2 -18
View File
@@ -866,24 +866,8 @@ class TantivyBackend:
final_query = self._apply_permission_filter(mlt_query, user)
effective_limit = limit if limit is not None else searcher.num_docs
try:
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
except BaseException: # pragma: no cover
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
# soft-deleted documents (doc_freq can exceed the alive doc count),
# which only surfaces for the More Like This query. The panic crosses
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
# BaseException, not an Exception — so catch BaseException and degrade
# to "no similar documents" instead of bubbling a 500 to the client.
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
# Remove once the bundled tantivy includes that fix.
logger.warning(
"More Like This scoring panicked (likely stale tantivy segment "
"stats after deletions); returning no results. A search index "
"reindex will rebuild consistent statistics.",
)
return []
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
addrs = [addr for _score, addr in results.hits]
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
-163
View File
@@ -1,163 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+405 -27
View File
@@ -1,35 +1,88 @@
from __future__ import annotations
import logging
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._dates import (
_date_only_range, # noqa: F401 — re-exported for test imports
)
from documents.search._dates import (
_datetime_range, # noqa: F401 — re-exported for test imports
)
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -64,12 +117,303 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
)
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
)
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
)
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC — matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
)
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (year range rewrite timed out)")
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Delegates to ``translate_query`` which handles all date forms, comma
expansion, field aliasing, relative ranges, and operator normalization.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Args:
query: Raw user query string
@@ -81,15 +425,35 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
Note:
Bare keywords without field prefixes pass through unchanged.
"""
return translate_query(query, tz)
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
)
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Delegates to ``translate_query`` which handles comma expansion, whitespace
collapsing, operator normalization, and field aliasing.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar → tag:foo AND tag:bar
- multiple spaces → single spaces
Args:
query: Query string after date rewriting
@@ -97,7 +461,29 @@ def normalize_query(query: str) -> str:
Returns:
Normalized query string ready for Tantivy parsing
"""
return translate_query(query, UTC)
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (normalization timed out)")
def build_permission_filter(
@@ -217,16 +603,8 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
exact = index.parse_query(
query_str,
-566
View File
@@ -1,566 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
+15 -24
View File
@@ -1,7 +1,6 @@
import logging
import os
import re
import unicodedata
from collections.abc import Iterable
from pathlib import PurePath
@@ -37,12 +36,10 @@ class FilePathTemplate(Template):
def clean_filepath(value: str) -> str:
"""
Clean up a filepath by:
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
2. Removing newlines and carriage returns
3. Removing extra spaces before and after forward slashes
4. Preserving spaces in other parts of the path
1. Removing newlines and carriage returns
2. Removing extra spaces before and after forward slashes
3. Preserving spaces in other parts of the path
"""
value = unicodedata.normalize("NFC", value)
value = value.replace("\n", "").replace("\r", "")
value = re.sub(r"\s*/\s*", "/", value)
@@ -184,17 +181,17 @@ def get_basic_metadata_context(
"""
return {
"title": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.title),
document.title,
replacement_text="-",
),
"correspondent": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.correspondent.name),
document.correspondent.name,
replacement_text="-",
)
if document.correspondent
else no_value_default,
"document_type": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.document_type.name),
document.document_type.name,
replacement_text="-",
)
if document.document_type
@@ -205,10 +202,7 @@ def get_basic_metadata_context(
"owner_username": document.owner.username
if document.owner
else no_value_default,
"original_name": unicodedata.normalize(
"NFC",
PurePath(document.original_filename).with_suffix("").name,
)
"original_name": PurePath(document.original_filename).with_suffix("").name
if document.original_filename
else no_value_default,
"doc_pk": f"{document.pk:07}",
@@ -275,12 +269,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
return {
"tag_list": pathvalidate.sanitize_filename(
",".join(
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
sorted(tag.name for tag in tags),
),
replacement_text="-",
),
# Assumed to be ordered, but a template could loop through to find what they want
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
"tag_name_list": [x.name for x in tags],
}
@@ -307,7 +301,7 @@ def get_custom_fields_context(
CustomField.FieldDataType.LONG_TEXT,
}:
value = pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", field_instance.value),
field_instance.value,
replacement_text="-",
)
elif (
@@ -316,13 +310,10 @@ def get_custom_fields_context(
):
options = field_instance.field.extra_data["select_options"]
value = pathvalidate.sanitize_filename(
unicodedata.normalize(
"NFC",
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
replacement_text="-",
)
@@ -330,7 +321,7 @@ def get_custom_fields_context(
value = field_instance.value
field_data["custom_fields"][
pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", field_instance.field.name),
field_instance.field.name,
replacement_text="-",
)
] = {
-12
View File
@@ -1,15 +1,11 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -35,11 +31,3 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+10 -88
View File
@@ -13,6 +13,7 @@ import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
@@ -20,7 +21,6 @@ from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -405,14 +405,12 @@ class TestWhooshQueryRewriting:
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
rewrite_natural_date_keywords("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
class TestParseUserQuery:
@@ -465,67 +463,6 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -605,16 +542,11 @@ class TestYearRangeRewriting:
assert rewrite_natural_date_keywords(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
class TestNonDateFieldsNotRewritten:
@@ -674,16 +606,6 @@ class TestNormalizeQuery:
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert normalize_query(q) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
@@ -1,742 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@@ -82,7 +82,6 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"llm_api_key": None,
"llm_endpoint": None,
"llm_output_language": None,
"llm_request_timeout": None,
},
)
@@ -1,95 +0,0 @@
import unicodedata
from typing import TYPE_CHECKING
from unittest import mock
import celery.result
import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
if TYPE_CHECKING:
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
@pytest.fixture()
def consume_file_mock():
with mock.patch("documents.tasks.consume_file.apply_async") as m:
m.return_value = celery.result.AsyncResult(id="test-task-id")
yield m
@pytest.fixture()
def directories(tmp_path, settings, _media_settings):
scratch = tmp_path / "scratch"
scratch.mkdir()
settings.SCRATCH_DIR = scratch
return scratch
@pytest.mark.django_db
class TestPostDocumentNFCNormalization:
def test_nfd_filename_normalized_to_nfc(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with NFD filename must have its name stored as NFC."""
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Verify our test strings actually differ at the byte level
assert nfd != nfc
uploaded = SimpleUploadedFile(
nfd,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
input_doc: ConsumableDocument = task_kwargs["input_doc"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
# The temp file on disk must have an NFC name
assert input_doc.original_file.name == nfc, (
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
)
# The override filename stored for later use must also be NFC
assert overrides.filename == nfc, (
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
)
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_filename_unchanged(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with already-NFC filename must pass through unchanged."""
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
uploaded = SimpleUploadedFile(
nfc,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
assert overrides.filename == nfc
assert unicodedata.is_normalized("NFC", overrides.filename)
+3 -6
View File
@@ -725,11 +725,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with an invalid added date
- Query with invalid added date
THEN:
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
"""
d1 = Document.objects.create(
title="invoice",
@@ -742,9 +740,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# An unparsable date is reported as a malformed query, not silently empty.
# Tantivy rejects unparsable field queries with a 400
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
-71
View File
@@ -216,77 +216,6 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
-187
View File
@@ -1,187 +0,0 @@
"""
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
All paths produced by the templating system must be NFC-normalized.
"""
import unicodedata
import pytest
from documents.file_handling import generate_filename
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
@pytest.mark.django_db
class TestGenerateFilenameNFCNormalization:
@pytest.mark.parametrize(
"raw,display",
[
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
(unicodedata.normalize("NFD", "Café"), "Café"),
(unicodedata.normalize("NFD", "naïve"), "naïve"),
],
)
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
"""NFD title must produce NFC path bytes."""
settings.FILENAME_FORMAT = "{{ title }}"
nfc = unicodedata.normalize("NFC", display)
assert raw != nfc # confirm byte-level difference
doc = DocumentFactory(title=raw, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_correspondent_normalized_to_nfc(self, settings):
"""NFD correspondent name must produce NFC path component."""
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Müller")
nfc = unicodedata.normalize("NFC", "Müller")
correspondent = CorrespondentFactory(name=nfd)
doc = DocumentFactory(
title="invoice",
correspondent=correspondent,
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result) == f"{nfc}/invoice.pdf"
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
def test_nfd_storage_path_normalized_to_nfc(self, settings):
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
settings.FILENAME_FORMAT = None
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
# StoragePath.path is used directly as the format/template string.
# Literal NFD characters in the template must survive rendering as NFC.
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
settings.FILENAME_FORMAT = "{{ document.title }}"
nfd = unicodedata.normalize("NFD", "Café")
nfc = unicodedata.normalize("NFC", "Café")
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
@pytest.mark.django_db
class TestContextBuilderNFCNormalization:
"""
Defense-in-depth: context builder functions must NFC-normalize string inputs
before passing them to sanitize_filename(). Task 1 already normalizes the
final rendered path via clean_filepath(), so these tests may already pass;
they exist as regression guards for the context-builder layer.
"""
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm they differ at byte level
tag = TagFactory(name=nfd)
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.set([tag])
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_original_name_normalized_to_nfc(self, settings):
settings.FILENAME_FORMAT = "{{ original_name }}"
nfd = unicodedata.normalize("NFD", "Rechnung März")
nfc = unicodedata.normalize("NFC", "Rechnung März")
doc = DocumentFactory(
original_filename=f"{nfd}.pdf",
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_custom_field_string_value_normalized(self, settings):
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
settings.FILENAME_FORMAT = (
"{{ custom_fields['Location']['value'] }}/{{ title }}"
)
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
assert nfd_value != nfc_value
doc = DocumentFactory(title="report", mime_type="application/pdf")
cf = CustomField.objects.create(
name="Location",
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text=nfd_value,
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
def test_nfd_custom_field_name_normalized_as_key(self, settings):
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
nfd_name = unicodedata.normalize("NFD", "Größe")
nfc_name = unicodedata.normalize("NFC", "Größe")
assert nfd_name != nfc_name
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
doc = DocumentFactory(title="letter", mime_type="application/pdf")
cf = CustomField.objects.create(
name=nfd_name,
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text="Berlin",
)
result = generate_filename(doc)
# If field name key is NFC-normalized, the template condition succeeds
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
assert str(result) == "Berlin/letter.pdf"
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
settings.FILENAME_FORMAT = (
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
)
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm byte-level difference
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.add(TagFactory(name=nfd))
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
@@ -684,7 +684,6 @@ class ConsumerThread(Thread):
subdirs_as_tags: bool = False,
polling_interval: float = 0,
stability_delay: float = 0.1,
rescan_interval: float | None = None,
) -> None:
super().__init__()
self.consumption_dir = consumption_dir
@@ -694,8 +693,6 @@ class ConsumerThread(Thread):
self.polling_interval = polling_interval
self.stability_delay = stability_delay
self.cmd = Command()
if rescan_interval is not None:
self.cmd.rescan_interval_s = rescan_interval
self.cmd.stop_flag.clear()
# Non-daemon ensures finally block runs and connections are closed
self.daemon = False
@@ -1055,200 +1052,3 @@ class TestCommandWatchEdgeCases:
thread.stop_and_wait(timeout=5.0)
# Clean up any Tags created by the thread
Tag.objects.all().delete()
class TestRescanExistingFiles:
"""
Unit tests for the rescan safety net.
Each ``watch()`` recreation silently adopts the current directory contents
as its baseline, so a file appearing between one batch and the next
watcher's baseline is never reported and would sit in the consume directory
forever. ``_rescan_existing_files`` re-injects such files into the
stability tracker as a periodic safety net (see GH issue #13011).
"""
@pytest.fixture
def pdf_only_filter(self) -> ConsumerFilter:
return ConsumerFilter(
supported_extensions=frozenset({".pdf"}),
ignore_patterns=[],
)
def _rescan(
self,
directory: Path,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
*,
recursive: bool = False,
) -> None:
Command()._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
def test_tracks_stranded_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A supported on-disk file the watcher never reported gets tracked."""
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.is_tracking(target) is True
assert tracker.pending_count == 1
def test_skips_already_tracked_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already being tracked by the watcher is not double-tracked."""
target = consumption_dir / "tracked.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
tracker.track(target, Change.added)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 1
def test_skips_queued_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already queued and awaiting consumption is not re-tracked."""
target = consumption_dir / "inflight.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {target.resolve()}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert tracker.pending_count == 0
def test_prunes_vanished_queued_paths(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Queued paths no longer on disk are dropped so the name can recur."""
gone = (consumption_dir / "gone.pdf").resolve()
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {gone}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert gone not in queued
def test_skips_unsupported_extension(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Files filtered out by the consumer filter are not tracked."""
(consumption_dir / "notes.xyz").write_bytes(b"content")
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 0
def test_recursive_respects_flag(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Nested files are only found when recursive scanning is enabled."""
subdir = consumption_dir / "nested"
subdir.mkdir()
target = subdir / "deep.pdf"
shutil.copy(sample_pdf, target)
shallow = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
assert shallow.pending_count == 0
deep = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
assert deep.is_tracking(target) is True
class TestProcessExistingFilesQueued:
"""Tests that startup processing reports which paths it queued."""
@pytest.mark.usefixtures("mock_supported_extensions")
def test_returns_queued_paths(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
settings: SettingsWrapper,
) -> None:
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
target = consumption_dir / "document.pdf"
shutil.copy(sample_pdf, target)
settings.CONSUMER_IGNORE_PATTERNS = []
queued = Command()._process_existing_files(
directory=consumption_dir,
recursive=False,
subdirs_as_tags=False,
consumer_filter=ConsumerFilter(ignore_patterns=[]),
)
assert target.resolve() in queued
@pytest.mark.management
@pytest.mark.django_db
class TestCommandRescanRecovery:
"""End-to-end test that the rescan recovers files the watcher misses."""
def test_rescan_consumes_file_the_watcher_never_reports(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
start_consumer: Callable[..., ConsumerThread],
) -> None:
"""
Isolate the rescan path: a long polling interval guarantees the
watcher cannot report the file within the test window, so only the
periodic rescan can consume it.
"""
# poll interval far longer than the test window -> watcher stays silent
thread = start_consumer(
polling_interval=30.0,
stability_delay=0.1,
rescan_interval=0.5,
)
# created after startup, so _process_existing_files did not see it
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.apply_async.assert_called()
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
"input_doc"
]
assert call_args.original_file.name == "stranded.pdf"
@@ -615,7 +615,7 @@ class TestExportImport(
self.assertIsFile(expected_file)
with ZipFile(expected_file) as zip:
# Extras are from the directories, which also appear in the listing
# 11 files + 3 directory marker entries for the subdirectory structure
self.assertEqual(len(zip.namelist()), 14)
self.assertIn("manifest.json", zip.namelist())
self.assertIn("metadata.json", zip.namelist())
@@ -666,6 +666,57 @@ class TestExportImport(
self.assertIn("manifest.json", zip.namelist())
self.assertIn("metadata.json", zip.namelist())
def test_export_zip_atomic_on_failure(self) -> None:
"""
GIVEN:
- Request to export documents to zipfile
WHEN:
- Export raises an exception mid-way
THEN:
- No .zip file is written at the final path
- The .tmp file is cleaned up
"""
args = ["document_exporter", self.target, "--zip"]
with mock.patch.object(
document_exporter.Command,
"dump",
side_effect=RuntimeError("simulated failure"),
):
with self.assertRaises(RuntimeError):
call_command(*args)
expected_zip = self.target / f"export-{timezone.localdate().isoformat()}.zip"
expected_tmp = (
self.target / f"export-{timezone.localdate().isoformat()}.zip.tmp"
)
self.assertIsNotFile(expected_zip)
self.assertIsNotFile(expected_tmp)
def test_export_zip_no_scratch_dir(self) -> None:
"""
GIVEN:
- Request to export documents to zipfile
WHEN:
- Documents are exported
THEN:
- No files are written under SCRATCH_DIR during the export
(the old workaround used a temp dir there)
"""
shutil.rmtree(Path(self.dirs.media_dir) / "documents")
shutil.copytree(
Path(__file__).parent / "samples" / "documents",
Path(self.dirs.media_dir) / "documents",
)
scratch_before = set(settings.SCRATCH_DIR.glob("paperless-export*"))
args = ["document_exporter", self.target, "--zip"]
call_command(*args)
scratch_after = set(settings.SCRATCH_DIR.glob("paperless-export*"))
self.assertEqual(scratch_before, scratch_after)
def test_export_target_not_exists(self) -> None:
"""
GIVEN:
-28
View File
@@ -30,7 +30,6 @@ from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
from paperless.models import ApplicationConfiguration
from paperless_ai.exceptions import LLMTimeoutError
class TestViews(DirectoriesMixin, TestCase):
@@ -477,33 +476,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="openai-like",
)
def test_ai_suggestions_with_llm_timeout(
self,
mock_get_ai_classification,
) -> None:
mock_get_ai_classification.side_effect = LLMTimeoutError()
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
self.assertEqual(
response.json(),
{
"ai": ["AI backend request timed out."],
},
)
self.assertIsNone(
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
def test_invalidate_suggestions_cache(self) -> None:
self.client.force_login(user=self.user)
suggestions = {
+5 -43
View File
@@ -12,7 +12,6 @@ from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -241,7 +240,6 @@ from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.exceptions import LLMTimeoutError
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
@@ -1511,17 +1509,6 @@ class DocumentViewSet(
exc_info=True,
)
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
except LLMTimeoutError as exc:
logger.exception(
"AI backend timed out while generating suggestions for document %s: %s",
doc.pk,
exc,
exc_info=True,
)
return Response(
{"ai": [_("AI backend request timed out.")]},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
@@ -2289,7 +2276,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2482,11 +2468,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -3145,7 +3126,6 @@ class PostDocumentView(GenericAPIView[Any]):
serializer.is_valid(raise_exception=True)
doc_name, doc_data = serializer.validated_data.get("document")
doc_name = normalize("NFC", doc_name)
correspondent_id = serializer.validated_data.get("correspondent")
document_type_id = serializer.validated_data.get("document_type")
storage_path_id = serializer.validated_data.get("storage_path")
@@ -5009,29 +4989,11 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
except Exception as e:
celery_active = "ERROR"
logger.exception(
-4
View File
@@ -197,7 +197,6 @@ class AIConfig(BaseConfig):
llm_embedding_endpoint: str = dataclasses.field(init=False)
llm_embedding_chunk_size: int = dataclasses.field(init=False)
llm_context_size: int = dataclasses.field(init=False)
llm_request_timeout: int = dataclasses.field(init=False)
llm_backend: str = dataclasses.field(init=False)
llm_model: str = dataclasses.field(init=False)
llm_api_key: str = dataclasses.field(init=False)
@@ -222,9 +221,6 @@ class AIConfig(BaseConfig):
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
)
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
self.llm_request_timeout = (
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
)
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
@@ -1,365 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:30
import django.core.validators
from django.db import migrations
from django.db import models
def _create_singleton(apps, schema_editor):
settings_model = apps.get_model("paperless", "ApplicationConfiguration")
settings_model.objects.create()
class Migration(migrations.Migration):
replaces = [
("paperless", "0001_initial"),
("paperless", "0002_applicationconfiguration_app_logo_and_more"),
("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"),
("paperless", "0005_applicationconfiguration_ai_enabled_and_more"),
("paperless", "0006_applicationconfiguration_barcode_tag_split"),
]
dependencies = []
operations = [
migrations.CreateModel(
name="ApplicationConfiguration",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"output_type",
models.CharField(
blank=True,
choices=[
("pdf", "pdf"),
("pdfa", "pdfa"),
("pdfa-1", "pdfa-1"),
("pdfa-2", "pdfa-2"),
("pdfa-3", "pdfa-3"),
],
max_length=8,
null=True,
verbose_name="Sets the output PDF type",
),
),
(
"pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Do OCR from page 1 to this value",
),
),
(
"language",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Do OCR using these languages",
),
),
(
"mode",
models.CharField(
blank=True,
choices=[
("skip", "skip"),
("redo", "redo"),
("force", "force"),
("skip_noarchive", "skip_noarchive"),
],
max_length=16,
null=True,
verbose_name="Sets the OCR mode",
),
),
(
"skip_archive_file",
models.CharField(
blank=True,
choices=[
("never", "never"),
("with_text", "with_text"),
("always", "always"),
],
max_length=16,
null=True,
verbose_name="Controls the generation of an archive file",
),
),
(
"image_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets image DPI fallback value",
),
),
(
"unpaper_clean",
models.CharField(
blank=True,
choices=[
("clean", "clean"),
("clean-final", "clean-final"),
("none", "none"),
],
max_length=16,
null=True,
verbose_name="Controls the unpaper cleaning",
),
),
(
"deskew",
models.BooleanField(null=True, verbose_name="Enables deskew"),
),
(
"rotate_pages",
models.BooleanField(
null=True,
verbose_name="Enables page rotation",
),
),
(
"rotate_pages_threshold",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the threshold for rotation of pages",
),
),
(
"max_image_pixels",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the maximum image size for decompression",
),
),
(
"color_conversion_strategy",
models.CharField(
blank=True,
choices=[
("LeaveColorUnchanged", "LeaveColorUnchanged"),
("RGB", "RGB"),
("UseDeviceIndependentColor", "UseDeviceIndependentColor"),
("Gray", "Gray"),
("CMYK", "CMYK"),
],
max_length=32,
null=True,
verbose_name="Sets the Ghostscript color conversion strategy",
),
),
(
"user_args",
models.JSONField(
null=True,
verbose_name="Adds additional user arguments for OCRMyPDF",
),
),
(
"app_logo",
models.FileField(
blank=True,
null=True,
upload_to="logo/",
validators=[
django.core.validators.FileExtensionValidator(
allowed_extensions=["jpg", "png", "gif", "svg"],
),
],
verbose_name="Application logo",
),
),
(
"app_title",
models.CharField(
blank=True,
max_length=48,
null=True,
verbose_name="Application title",
),
),
(
"barcode_asn_prefix",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the ASN barcode prefix",
),
),
(
"barcode_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the barcode DPI",
),
),
(
"barcode_enable_asn",
models.BooleanField(
null=True,
verbose_name="Enables ASN barcode",
),
),
(
"barcode_enable_tag",
models.BooleanField(
null=True,
verbose_name="Enables tag barcode",
),
),
(
"barcode_enable_tiff_support",
models.BooleanField(
null=True,
verbose_name="Enables barcode TIFF support",
),
),
(
"barcode_max_pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the maximum pages for barcode",
),
),
(
"barcode_retain_split_pages",
models.BooleanField(
null=True,
verbose_name="Retains split pages",
),
),
(
"barcode_string",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the barcode string",
),
),
(
"barcode_tag_mapping",
models.JSONField(
null=True,
verbose_name="Sets the tag barcode mapping",
),
),
(
"barcode_upscale",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(1.0)],
verbose_name="Sets the barcode upscale factor",
),
),
(
"barcodes_enabled",
models.BooleanField(
null=True,
verbose_name="Enables barcode scanning",
),
),
(
"ai_enabled",
models.BooleanField(
default=False,
null=True,
verbose_name="Enables AI features",
),
),
(
"llm_api_key",
models.CharField(
blank=True,
max_length=1024,
null=True,
verbose_name="Sets the LLM API key",
),
),
(
"llm_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM backend",
),
),
(
"llm_embedding_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
(
"llm_embedding_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM embedding model",
),
),
(
"llm_endpoint",
models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM endpoint, optional",
),
),
(
"llm_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM model",
),
),
(
"barcode_tag_split",
models.BooleanField(
null=True,
verbose_name="Enables splitting on tag barcodes",
),
),
],
options={
"verbose_name": "paperless application settings",
},
),
migrations.RunPython(
code=_create_singleton,
reverse_code=migrations.RunPython.noop,
),
]
@@ -1,94 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:19
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("paperless", "0009_alter_applicationconfiguration_options"),
("paperless", "0010_alter_applicationconfiguration_llm_embedding_backend"),
("paperless", "0011_applicationconfiguration_llm_embedding_chunk_size"),
("paperless", "0012_applicationconfiguration_llm_output_language"),
("paperless", "0013_applicationconfiguration_llm_request_timeout"),
]
dependencies = [
("paperless", "0008_replace_skip_archive_file"),
]
operations = [
migrations.AlterModelOptions(
name="applicationconfiguration",
options={
"permissions": [
("view_global_statistics", "Can view global object counts"),
("view_system_monitoring", "Can view system status information"),
],
"verbose_name": "paperless application settings",
},
),
migrations.AlterField(
model_name="applicationconfiguration",
name="llm_embedding_backend",
field=models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_endpoint",
field=models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM embedding endpoint, optional",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_chunk_size",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM embedding chunk size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_context_size",
field=models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM context size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_output_language",
field=models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the LLM output language",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
@@ -1,23 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-14 14:22
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("paperless", "0012_applicationconfiguration_llm_output_language"),
]
operations = [
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
-6
View File
@@ -366,12 +366,6 @@ class ApplicationConfiguration(AbstractSingletonModel):
max_length=32,
)
llm_request_timeout = models.PositiveSmallIntegerField(
verbose_name=_("Sets the LLM timeout in seconds"),
null=True,
validators=[MinValueValidator(1)],
)
class Meta:
verbose_name = _("paperless application settings")
permissions = [
+28 -2
View File
@@ -20,7 +20,6 @@ from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from paperless.parsers.utils import read_file_handle_unicode_errors
from paperless.version import __full_version_str__
if TYPE_CHECKING:
@@ -184,7 +183,7 @@ class TextDocumentParser:
documents.parsers.ParseError
If the file cannot be read.
"""
self._text = read_file_handle_unicode_errors(document_path, log=logger)
self._text = self._read_text(document_path)
# ------------------------------------------------------------------
# Result accessors
@@ -296,3 +295,30 @@ class TextDocumentParser:
Always ``[]`` plain text files carry no structured metadata.
"""
return []
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _read_text(self, filepath: Path) -> str:
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
Parameters
----------
filepath:
Path to the file to read.
Returns
-------
str
File content as a string.
"""
try:
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
logger.warning(
"Unicode error reading %s, replacing bad bytes: %s",
filepath,
exc,
)
return filepath.read_bytes().decode("utf-8", errors="replace")
+5 -18
View File
@@ -8,7 +8,6 @@ share implementation.
from __future__ import annotations
import codecs
import logging
import re
import tempfile
@@ -115,7 +114,7 @@ def read_file_handle_unicode_errors(
filepath: Path,
log: logging.Logger | None = None,
) -> str:
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
Parameters
----------
@@ -128,27 +127,15 @@ def read_file_handle_unicode_errors(
Returns
-------
str
File content as a string, with NUL bytes removed so the result is
safe to store in PostgreSQL text fields.
File content as a string, with any invalid UTF-8 sequences replaced
by the Unicode replacement character.
"""
_log = log or logger
raw = filepath.read_bytes()
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
encoding = "utf-16"
elif raw.startswith(codecs.BOM_UTF8):
encoding = "utf-8-sig"
else:
encoding = "utf-8"
try:
text = raw.decode(encoding)
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as e:
_log.warning("Unicode error during text reading, continuing: %s", e)
text = raw.decode("utf-8", errors="replace")
# PostgreSQL rejects NUL (0x00) bytes in text fields
return text.replace("\x00", "")
return filepath.read_bytes().decode("utf-8", errors="replace")
def get_page_count_for_pdf(
-10
View File
@@ -98,13 +98,6 @@ MODEL_FILE = get_path_from_env(
)
LLM_INDEX_DIR = DATA_DIR / "llm_index"
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
# Cross-process read/write lock guarding the LLM index compaction/migration
# file swap. Readers hold it shared; the swap takes it exclusively so it never
# runs while a reader connection is open. Must be a SQLite (.db) file.
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
# Seconds the compaction swap waits for active readers to drain before skipping
# this cycle (it is a maintenance operation; the next run retries).
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
@@ -1206,9 +1199,6 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
if LLM_CONTEXT_SIZE < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
if LLM_REQUEST_TIMEOUT < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
LLM_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_BACKEND",
{"ollama", "openai-like"},
-37
View File
@@ -2,50 +2,13 @@
from __future__ import annotations
import codecs
from pathlib import Path
from paperless.parsers.utils import is_tagged_pdf
from paperless.parsers.utils import read_file_handle_unicode_errors
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
class TestReadFileHandleUnicodeErrors:
def test_plain_utf8(self, tmp_path: Path) -> None:
f = tmp_path / "plain.txt"
f.write_bytes(b"hello world")
assert read_file_handle_unicode_errors(f) == "hello world"
def test_utf8_bom(self, tmp_path: Path) -> None:
f = tmp_path / "bom.txt"
f.write_bytes(codecs.BOM_UTF8 + b"hello")
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_le(self, tmp_path: Path) -> None:
f = tmp_path / "utf16le.txt"
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_be(self, tmp_path: Path) -> None:
f = tmp_path / "utf16be.txt"
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
f = tmp_path / "null-bytes.txt"
f.write_bytes(b"foo\x00bar")
assert read_file_handle_unicode_errors(f) == "foobar"
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
f = tmp_path / "bad.txt"
f.write_bytes(b"ok\x80\x81bad")
result = read_file_handle_unicode_errors(f)
assert "ok" in result
assert "bad" in result
assert "\x00" not in result
class TestIsTaggedPdf:
def test_tagged_pdf_returns_true(self) -> None:
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
+42 -49
View File
@@ -9,7 +9,6 @@ from paperless_ai.db import db_connection_released
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
from paperless_ai.indexing import read_store
logger = logging.getLogger("paperless_ai.chat")
@@ -98,59 +97,53 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
from llama_index.core.retrievers import VectorIndexRetriever
config = AIConfig()
index = load_or_build_index(config)
filters = _document_id_filters(str(doc.pk) for doc in documents)
# Hold the shared read lock for the whole operation: the query engine
# retrieves from the vector store again during synthesis, so the connection
# must stay open (and the swap must not run) until the stream finishes.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
# Slow query-embedding + vector search; no Django ORM access happens
# during it, so release the pooled DB connection for its duration. See
# #12976.
with db_connection_released():
top_nodes = retriever.retrieve(query_str)
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
# Slow query-embedding + vector search; no Django ORM access happens during
# it, so release the pooled DB connection for its duration. See #12976.
with db_connection_released():
top_nodes = retriever.retrieve(query_str)
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
client = AIClient()
client = AIClient()
references = _get_document_references(documents, top_nodes)
references = _get_document_references(documents, top_nodes)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
logger.debug("Document chat query: %s", query_str)
# Release the pooled DB connection for the slow streaming LLM response
# so it is not pinned for the whole stream; see paperless_ai.db and
# #12976.
with db_connection_released():
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
logger.debug("Document chat query: %s", query_str)
# Release the pooled DB connection for the slow streaming LLM response so it
# is not pinned for the whole stream; see paperless_ai.db and #12976.
with db_connection_released():
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
if references:
yield _format_chat_metadata_trailer(references)
if references:
yield _format_chat_metadata_trailer(references)
+28 -49
View File
@@ -1,14 +1,11 @@
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
import httpx
from paperless.models import LLMBackend
if TYPE_CHECKING:
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai_like import OpenAILike
@@ -19,7 +16,6 @@ from paperless.network import create_pinned_async_httpx_client
from paperless.network import create_pinned_httpx_client
from paperless.network import validate_outbound_http_url
from paperless_ai.base_model import DocumentClassifierSchema
from paperless_ai.exceptions import LLMTimeoutError
logger = logging.getLogger("paperless_ai.client")
@@ -65,16 +61,16 @@ class AIClient:
model=self.settings.llm_model or "llama3.1",
base_url=endpoint,
context_window=self.settings.llm_context_size,
request_timeout=self.settings.llm_request_timeout,
request_timeout=120,
system_prompt=LLM_SYSTEM_PROMPT,
client=Client(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=transport,
),
async_client=AsyncClient(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=async_transport,
),
)
@@ -88,18 +84,15 @@ class AIClient:
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
return OpenAILike(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_base=endpoint,
api_key=self.settings.llm_api_key,
timeout=self.settings.llm_request_timeout,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -120,12 +113,11 @@ class AIClient:
user_msg = ChatMessage(role="user", content=prompt)
if self.settings.llm_backend == LLMBackend.OLLAMA:
with self._normalize_timeouts():
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
logger.debug("LLM query result: %s", result)
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
return parsed.model_dump()
@@ -133,39 +125,26 @@ class AIClient:
from llama_index.core.program.function_program import get_function_tool
tool = get_function_tool(DocumentClassifierSchema)
with self._normalize_timeouts():
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
tool_required=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
@contextmanager
def _normalize_timeouts(self) -> Iterator[None]:
try:
yield
except httpx.TimeoutException as exc:
raise LLMTimeoutError from exc
except Exception as exc:
if self._is_openai_timeout(exc):
raise LLMTimeoutError from exc
raise
def _is_openai_timeout(self, exc: Exception) -> bool:
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
return False
# Keep OpenAI imports out of module import paths and only load the SDK
# when translating an error from an OpenAI-backed request.
from openai import APITimeoutError
return isinstance(exc, APITimeoutError)
def run_chat(self, messages: list["ChatMessage"]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.chat(messages)
logger.debug("Chat result: %s", result)
return result
+11 -16
View File
@@ -32,18 +32,15 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
return OpenAILikeEmbedding(
model_name=config.llm_embedding_model or "text-embedding-3-small",
api_key=config.llm_api_key,
api_base=endpoint,
timeout=config.llm_request_timeout,
http_client=http_client,
async_http_client=async_http_client,
)
@@ -76,14 +73,12 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
)
embedding._client = Client(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
)
embedding._async_client = AsyncClient(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostAsyncHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
@@ -104,13 +99,9 @@ _DEFAULT_MODEL_NAMES = {
def get_configured_model_name(config: AIConfig) -> str:
"""Return the canonical name of the currently configured embedding model."""
# dict.get(key, default) overload resolution fails for TextChoices keys in some
# type checkers; use `or` fallback to avoid the ambiguity.
default = (
_DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
)
or "sentence-transformers/all-MiniLM-L6-v2"
default = _DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
"sentence-transformers/all-MiniLM-L6-v2",
)
return config.llm_embedding_model or default
@@ -121,11 +112,15 @@ def _normalize_llm_index_text(text: str) -> str:
def build_llm_index_text(doc: Document) -> str:
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
# text, Custom Fields are dynamic in count and best kept in the embedding.
# TODO: Filename, Storage Path, and Archive Serial Number are short structured
# values that could move to node.metadata (excluded from embeddings, visible to
# LLM via metadata prepend) — same pattern as title/tags/correspondent. Notes
# and Custom Fields should stay here: Notes can be long free text, Custom Fields
# are dynamic in count and best kept in the embedding.
lines = [
f"Filename: {doc.filename}",
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
f"Archive Serial Number: {doc.archive_serial_number or ''}",
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
]
-2
View File
@@ -1,2 +0,0 @@
class LLMTimeoutError(Exception):
pass
+45 -149
View File
@@ -7,8 +7,6 @@ from typing import TYPE_CHECKING
from django.conf import settings
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
@@ -23,11 +21,13 @@ from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core.schema import BaseNode
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
from paperless_ai.vector_store import PaperlessLanceVectorStore
logger = logging.getLogger("paperless_ai.indexing")
LLM_INDEX_TABLE = "documents"
RAG_NUM_OUTPUT = 512
RAG_CHUNK_OVERLAP = 200
@@ -63,108 +63,36 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
return True
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
def get_vector_store() -> "PaperlessLanceVectorStore":
from paperless_ai.vector_store import PaperlessLanceVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return PaperlessSqliteVecVectorStore(
return PaperlessLanceVectorStore(
uri=str(settings.LLM_INDEX_DIR),
table_name=LLM_INDEX_TABLE,
)
# --- LLM index locking ---------------------------------------------------
#
# Two locks guard the index; they answer different questions and are NOT
# interchangeable:
#
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
# Taken by write_store(). Readers never take it, so it never blocks reads.
#
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
# compaction/migration file swap. read_store() takes it SHARED (readers run
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
# the database file is never replaced while a reader connection is open (that
# would alias the old WAL onto the new file and corrupt it).
#
# | vs another writer | vs a reader
# -----------------+-------------------+----------------------------
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
#
# They can't be merged into one ReadWriteLock: a normal write must exclude other
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
# the swap needs to exclude readers.
def _index_rwlock() -> ReadWriteLock:
"""Return a fresh read/write lock instance for the index swap.
``is_singleton=False`` so reads and the swap always coordinate through
SQLite (the actual cross-process case) rather than hitting the in-process
reentrant-upgrade guard; callers must ``close()`` it (the context managers
below do).
"""
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
@contextmanager
def read_store():
"""Acquire the shared read lock and yield the vector store for a read.
The shared lock is held for the whole lifetime of the connection (and
closed on exit) so the compaction/migration swap, which takes the exclusive
lock, never runs while this connection is open. Concurrent readers do not
block each other; only the swap does.
"""
lock = _index_rwlock()
try:
with lock.read_lock(), get_vector_store() as store:
yield store
finally:
lock.close()
@contextmanager
def _exclude_readers():
"""Acquire exclusive index access, blocking until readers have drained.
The exclusive counterpart to ``read_store()``: a compaction or migration
must not run while any reader connection is open. Raises
:class:`filelock.Timeout` if active readers do not drain within
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
"""
lock = _index_rwlock()
try:
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
yield
finally:
lock.close()
@contextmanager
def write_store(embed_model_name: str | None = None):
"""Acquire the write lock and yield the vector store.
All mutating operations (upsert, delete, rebuild, compact) must go through
this context manager to serialise concurrent Celery writers.
Read paths use ``read_store()`` so they hold the shared read lock.
Read paths use ``get_vector_store()`` directly no lock needed.
Pass ``embed_model_name`` whenever the operation may create the table so
the model name is recorded in the schema metadata for future mismatch checks.
"""
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
from paperless_ai.vector_store import PaperlessLanceVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
with (
FileLock(settings.LLM_INDEX_LOCK),
PaperlessSqliteVecVectorStore(
with FileLock(settings.LLM_INDEX_LOCK):
yield PaperlessLanceVectorStore(
uri=str(settings.LLM_INDEX_DIR),
table_name=LLM_INDEX_TABLE,
embed_model_name=embed_model_name,
) as store,
):
yield store
)
def build_document_node(
@@ -186,9 +114,6 @@ def build_document_node(
"document_type": document.document_type.name
if document.document_type
else None,
"filename": document.filename,
"storage_path": document.storage_path.name if document.storage_path else None,
"archive_serial_number": document.archive_serial_number,
"created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
@@ -215,27 +140,23 @@ def build_document_node(
return parser.get_nodes_from_documents([doc])
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
"""Return a VectorStoreIndex backed by ``store``.
``store`` is supplied by the caller's ``read_store()`` context so the shared
read lock and the connection stay alive for the whole retrieval.
"""
def load_or_build_index(config: AIConfig):
"""Return a VectorStoreIndex backed by the vector store."""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
embed_model = get_embedding_model(config)
llama_settings.Settings.embed_model = embed_model
vector_store = get_vector_store()
return VectorStoreIndex.from_vector_store(
vector_store=store,
vector_store=vector_store,
embed_model=embed_model,
)
def llm_index_exists() -> bool:
"""True when the index table exists on disk."""
with read_store() as store:
return store.table_exists()
return get_vector_store().table_exists()
def get_rag_chunk_size() -> int:
@@ -303,21 +224,6 @@ def update_llm_index(
rebuild=False,
) -> str:
"""Rebuild or incrementally update the LLM index."""
with write_store() as store:
try:
with _exclude_readers():
needs_reembed = store.check_and_run_migrations()
except Timeout:
logger.info(
"Skipping LLM index migration check: index readers are active; "
"will retry next run.",
)
needs_reembed = False
if needs_reembed:
logger.warning(
"LLM index migration requires re-embedding; forcing rebuild.",
)
rebuild = True
documents = Document.objects.all()
no_documents = not documents.exists()
@@ -329,12 +235,13 @@ def update_llm_index(
config = AIConfig()
model_name = get_configured_model_name(config)
if not rebuild and llm_index_exists():
with read_store() as store:
config_mismatch = store.config_mismatch(model_name)
if config_mismatch:
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
if (
not rebuild
and llm_index_exists()
and get_vector_store().config_mismatch(model_name)
):
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
if no_documents:
logger.warning("No documents found to index.")
@@ -344,6 +251,7 @@ def update_llm_index(
with write_store(embed_model_name=model_name) as store:
if rebuild or not store.table_exists():
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
logger.info("Rebuilding LLM index.")
store.drop_table()
for document in iter_wrapper(documents):
@@ -368,14 +276,9 @@ def update_llm_index(
else "No changes detected in LLM index."
)
try:
with _exclude_readers():
store.compact()
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
store.ensure_document_id_scalar_index()
store.maybe_create_ann_index()
store.compact(retention_seconds=60 * 60) # 1 hour: safe for in-flight readers
return msg
@@ -391,19 +294,13 @@ def llm_index_add_or_update_document(document: Document):
with write_store(embed_model_name=get_configured_model_name(config)) as store:
store.upsert_document(str(document.id), new_nodes)
store.ensure_document_id_scalar_index()
def llm_index_compact() -> None:
"""Compact the index immediately, rebuilding the table to reclaim space."""
"""Compact the index immediately, clearing all MVCC version history."""
with write_store() as store:
try:
with _exclude_readers():
store.compact(force=True)
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
store.compact(retention_seconds=0)
def llm_index_remove_document(document: Document):
@@ -470,31 +367,30 @@ def query_similar_documents(
from llama_index.core.retrievers import VectorIndexRetriever
index = load_or_build_index(config)
filters = (
_document_id_filters(allowed_document_ids)
if allowed_document_ids is not None
else None
)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
filters=filters,
)
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
# Hold the shared read lock for the whole retrieval so the connection is
# never open across a compaction swap. The retrieve() call generates a
# query embedding (a slow external request) and searches the vector store;
# no Django ORM access happens during it, so release the pooled DB
# connection for its duration. See #12976.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
filters=filters,
)
with db_connection_released():
results = retriever.retrieve(query_text)
# The retrieve() call generates a query embedding (a slow external request)
# and searches the vector store; no Django ORM access happens during it, so
# release the pooled DB connection for its duration. See #12976.
with db_connection_released():
results = retriever.retrieve(query_text)
retrieved_document_ids: list[int] = []
for node in results:
-1
View File
@@ -10,7 +10,6 @@ from pytest_django.fixtures import SettingsWrapper
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
settings.LLM_INDEX_DIR = tmp_path
settings.LLM_INDEX_LOCK = tmp_path / "index.lock"
settings.LLM_INDEX_RWLOCK = tmp_path / "llmindex.rwlock.db"
return tmp_path
+71 -57
View File
@@ -1,3 +1,4 @@
import json
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -6,7 +7,6 @@ import pytest
import pytest_mock
from django.test import override_settings
from django.utils import timezone
from llama_index.core.schema import MetadataMode
from documents.models import Document
from documents.models import PaperlessTask
@@ -17,7 +17,6 @@ from documents.tests.factories import PaperlessTaskFactory
from paperless.models import ApplicationConfiguration
from paperless_ai import indexing
from paperless_ai.tests.conftest import FakeEmbedding
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@pytest.fixture
@@ -34,22 +33,12 @@ def test_build_document_node(real_document: Document) -> None:
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == str(real_document.id)
assert nodes[0].metadata["filename"] == real_document.filename
assert nodes[0].metadata["storage_path"] == (
real_document.storage_path.name if real_document.storage_path else None
)
assert (
nodes[0].metadata["archive_serial_number"]
== real_document.archive_serial_number
)
assert "filename" in nodes[0].excluded_embed_metadata_keys
assert "filename" not in nodes[0].excluded_llm_metadata_keys
@pytest.mark.django_db
def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None:
"""Every node produced by build_document_node must carry the paperless document id
as its ref_doc_id so that the vector store's delete(str(doc.id)) works correctly."""
as its ref_doc_id so that the LanceDB adapter's delete(str(doc.id)) works correctly."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0, "Expected at least one node"
for node in nodes:
@@ -69,6 +58,8 @@ def test_build_document_node_excludes_metadata_from_embedding(
double the token count and exceed embedding models with small context
windows (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
"""
from llama_index.core.schema import MetadataMode
nodes = indexing.build_document_node(real_document)
for node in nodes:
embed_text = node.get_content(metadata_mode=MetadataMode.EMBED)
@@ -100,6 +91,8 @@ def test_build_document_node_excludes_document_id_from_llm_context(
real_document: Document,
) -> None:
"""document_id is an internal key and must not appear in LLM context text."""
from llama_index.core.schema import MetadataMode
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
for node in nodes:
@@ -161,6 +154,29 @@ def test_update_llm_index(
build_document_node.assert_called_once_with(real_document, chunk_size=512)
@pytest.mark.django_db
def test_update_llm_index_cleans_stale_meta_on_rebuild(
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
# A meta.json left over from the FAISS era (or written by older code) must be
# deleted on rebuild so stale artifacts don't accumulate on disk.
stale_meta = temp_llm_index_dir / "meta.json"
stale_meta.write_text(json.dumps({"embedding_model": "old", "dim": 1}))
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
assert not stale_meta.exists(), (
"update_llm_index(rebuild=True) must remove stale meta.json"
)
@pytest.mark.django_db
def test_update_llm_index_rebuilds_on_model_name_change(
temp_llm_index_dir: Path,
@@ -191,10 +207,10 @@ def test_update_llm_index_rebuilds_on_model_name_change(
):
indexing.update_llm_index(rebuild=False)
with indexing.get_vector_store() as store:
# Schema metadata only updates when the table is dropped and recreated, never
# on incremental writes -- so "model-b" here proves a full rebuild happened.
assert store.stored_model_name() == "model-b"
store = indexing.get_vector_store()
# Schema metadata only updates when the table is dropped and recreated, never on
# incremental writes -- so "model-b" here proves a full rebuild happened.
assert store.stored_model_name() == "model-b"
@pytest.mark.django_db
@@ -238,10 +254,10 @@ def test_update_llm_index_partial_update(
indexing.update_llm_index(rebuild=False)
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after incremental update"
)
store = indexing.get_vector_store()
assert store.table_exists(), (
"Expected the LanceDB table to exist after incremental update"
)
@pytest.mark.django_db
@@ -253,10 +269,10 @@ def test_add_or_update_document_updates_existing_entry(
indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after add-or-update"
)
store = indexing.get_vector_store()
assert store.table_exists(), (
"Expected the LanceDB table to exist after add-or-update"
)
@pytest.mark.django_db
@@ -445,7 +461,7 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
class TestUpdateLlmIndexEmptyDocumentSet:
"""update_llm_index must clear the vector store table when all documents are deleted.
"""update_llm_index must clear the LanceDB table when all documents are deleted.
Without this, the stale vectors are never cleared and subsequent similarity
searches return phantom hits for document IDs that no longer exist in the DB.
@@ -473,11 +489,10 @@ class TestUpdateLlmIndexEmptyDocumentSet:
)
indexing.update_llm_index(rebuild=True)
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Precondition failed: expected the vector store table to exist "
"before deletion"
)
store = indexing.get_vector_store()
assert store.table_exists(), (
"Precondition failed: expected the LanceDB table to exist before deletion"
)
# Step 2: delete all documents
Document.objects.all().delete()
@@ -488,11 +503,10 @@ class TestUpdateLlmIndexEmptyDocumentSet:
indexing.update_llm_index(rebuild=True)
# Step 4: the table must be absent (no rows) — phantom vectors gone
with indexing.get_vector_store() as store2:
assert not store2.table_exists(), (
"Expected the vector store table to be absent after rebuilding "
"with no documents"
)
store2 = indexing.get_vector_store()
assert not store2.table_exists(), (
"Expected the LanceDB table to be absent after rebuilding with no documents"
)
class TestDocumentUpdatedSignalTriggersLlmReindex:
@@ -564,11 +578,11 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
@pytest.mark.django_db
def test_llm_index_compact_uses_force(
def test_llm_index_compact_uses_zero_retention(
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""compact must use force=True to rebuild the table and reclaim space immediately."""
"""compact must use retention_seconds=0 to clear all MVCC history immediately."""
mock_store = mocker.MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
@@ -580,7 +594,7 @@ def test_llm_index_compact_uses_force(
indexing.llm_index_compact()
mock_store.compact.assert_called_once_with(force=True)
mock_store.compact.assert_called_once_with(retention_seconds=0)
@pytest.mark.django_db
@@ -664,14 +678,16 @@ class TestLlmIndexLocking:
@pytest.mark.django_db
@pytest.mark.django_db
class TestVectorStoreIndexing:
class TestLanceDbIndexing:
def test_get_vector_store_roundtrip(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
) -> None:
with indexing.get_vector_store() as store:
assert isinstance(store, PaperlessSqliteVecVectorStore)
from paperless_ai.vector_store import PaperlessLanceVectorStore
store = indexing.get_vector_store()
assert isinstance(store, PaperlessLanceVectorStore)
def test_add_then_remove_document(
self,
@@ -680,13 +696,12 @@ class TestVectorStoreIndexing:
real_document: Document,
) -> None:
indexing.llm_index_add_or_update_document(real_document)
with indexing.get_vector_store() as store:
assert store.table_exists()
count_sql = "SELECT count(*) FROM documents"
assert store.client.execute(count_sql).fetchone()[0] >= 1
store = indexing.get_vector_store()
table = store.client.open_table(indexing.LLM_INDEX_TABLE)
assert table.count_rows() >= 1
indexing.llm_index_remove_document(real_document)
assert store.client.execute(count_sql).fetchone()[0] == 0
indexing.llm_index_remove_document(real_document)
assert store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows() == 0
def test_update_shrinks_chunks_without_orphans(
self,
@@ -697,17 +712,16 @@ class TestVectorStoreIndexing:
real_document.content = "word " * 4000 # many chunks
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
count_sql = "SELECT count(*) FROM documents"
with indexing.get_vector_store() as store:
big = store.client.execute(count_sql).fetchone()[0]
store = indexing.get_vector_store()
big = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows()
real_document.content = "short" # one chunk
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
real_document.content = "short" # one chunk
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
rows = store.client.execute(count_sql).fetchone()[0]
assert rows < big
assert rows >= 1
rows = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows()
assert rows < big
assert rows >= 1
@pytest.mark.django_db
+8 -4
View File
@@ -3,13 +3,9 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import TextNode
from documents.tests.factories import DocumentFactory
from paperless_ai import chat
from paperless_ai import indexing
from paperless_ai.chat import CHAT_ERROR_MESSAGE
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import stream_chat_with_documents
@@ -17,6 +13,9 @@ from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
def patch_embed_model():
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
yield
@@ -242,6 +241,8 @@ class TestStreamChatRetrieval:
temp_llm_index_dir,
mock_embed_model,
) -> None:
from documents.tests.factories import DocumentFactory
doc = DocumentFactory.create(content="hello world")
# Nothing indexed for this document yet.
out = list(chat.stream_chat_with_documents("question?", [doc]))
@@ -257,6 +258,9 @@ class TestStreamChatRetrieval:
requested documents only content from other indexed documents must
not be surfaced.
"""
from documents.tests.factories import DocumentFactory
from paperless_ai import indexing
included = DocumentFactory.create(content="included document content")
excluded = DocumentFactory.create(content="excluded document content")
indexing.llm_index_add_or_update_document(included)
+7 -32
View File
@@ -3,14 +3,12 @@ from unittest.mock import ANY
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import openai
import pytest
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from paperless_ai.client import LLM_SYSTEM_PROMPT
from paperless_ai.client import AIClient
from paperless_ai.exceptions import LLMTimeoutError
@pytest.fixture
@@ -19,7 +17,6 @@ def mock_ai_config():
mock_config = MagicMock()
mock_config.llm_allow_internal_endpoints = True
mock_config.llm_context_size = 8192
mock_config.llm_request_timeout = 120
MockAIConfig.return_value = mock_config
yield mock_config
@@ -67,7 +64,6 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
model="test_model",
api_base="http://test-url",
api_key="test_api_key",
timeout=120,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -155,38 +151,17 @@ def test_run_llm_query_openai_uses_tools(mock_ai_config, mock_openai_llm):
mock_llm_instance.chat_with_tools.assert_called_once()
def test_run_llm_query_openai_timeout_raises_local_error(
mock_ai_config,
mock_openai_llm,
):
mock_ai_config.llm_backend = "openai-like"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_api_key = "test_api_key"
mock_ai_config.llm_endpoint = "http://test-url"
request = httpx.Request("POST", "http://test-url/v1/chat/completions")
mock_openai_llm.return_value.chat_with_tools.side_effect = openai.APITimeoutError(
request,
)
client = AIClient()
with pytest.raises(LLMTimeoutError):
client.run_llm_query("test_prompt")
def test_run_llm_query_httpx_timeout_raises_local_error(
mock_ai_config,
mock_ollama_llm,
):
def test_run_chat(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_endpoint = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.chat.side_effect = httpx.ReadTimeout("timed out")
mock_llm_instance.chat.return_value = "test_chat_result"
client = AIClient()
messages = [ChatMessage(role="user", content="Hello")]
result = client.run_chat(messages)
with pytest.raises(LLMTimeoutError):
client.run_llm_query("test_prompt")
mock_llm_instance.chat.assert_called_once_with(messages)
assert result == "test_chat_result"
+2 -7
View File
@@ -19,7 +19,6 @@ def mock_ai_config():
MockAIConfig.return_value.llm_embedding_endpoint = None
MockAIConfig.return_value.llm_allow_internal_endpoints = True
MockAIConfig.return_value.llm_context_size = 8192
MockAIConfig.return_value.llm_request_timeout = 120
yield MockAIConfig
@@ -72,7 +71,6 @@ def test_get_embedding_model_openai(mock_ai_config):
model_name="text-embedding-3-small",
api_key="test_api_key",
api_base="http://test-url",
timeout=120,
http_client=ANY,
async_http_client=ANY,
)
@@ -94,7 +92,6 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
model_name="text-embedding-3-small",
api_key="test_api_key",
api_base="http://embedding-url",
timeout=120,
http_client=ANY,
async_http_client=ANY,
)
@@ -227,17 +224,15 @@ def test_build_llm_index_text(mock_document):
result = build_llm_index_text(mock_document)
# Structured fields live in node.metadata for LLM context -- not body text
# Structured fields live in node.metadata for LLM context not body text
assert "Title: Test Title" not in result
assert "Created: 2023-01-01" not in result
assert "Tags: Tag1, Tag2" not in result
assert "Document Type: Invoice" not in result
assert "Correspondent: Test Correspondent" not in result
assert "Filename:" not in result
assert "Storage Path:" not in result
assert "Archive Serial Number:" not in result
# Fields without a metadata equivalent stay in body text
assert "Filename: test_file.pdf" in result
assert "Notes: Note1,Note2" in result
assert "Content:\n\nThis is the document content." in result
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
@@ -1,134 +0,0 @@
import logging
import sqlite3
import threading
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from django.conf import settings
from filelock import ReadWriteLock
from llama_index.core.schema import TextNode
from pytest_django.fixtures import SettingsWrapper
from paperless_ai import indexing
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
DIM = 8
def _node(node_id: str, document_id: str, *, seed: float = 0.0) -> TextNode:
node = TextNode(
id_=node_id,
text="chunk",
metadata={"document_id": document_id, "modified": "2026-06-01T00:00:00"},
)
node.relationships = {}
node.embedding = [seed + i / 100 for i in range(DIM)]
return node
def _seed_bloated_index(index_dir: Path) -> None:
"""Create an index whose cumulative inserts far exceed live rows."""
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
store.add([_node(f"d{j}", str(j), seed=float(j)) for j in range(20)])
for cycle in range(6):
for j in range(20):
store.upsert_document(
str(j),
[_node(f"d{j}-c{cycle}", str(j), seed=float(j))],
)
store.client.close()
def _bloat_ratio(index_dir: Path) -> float:
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
live = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
).fetchone()
total = int(row["value"]) if row else live
store.client.close()
return total / max(live, 1)
def _integrity_ok(index_dir: Path) -> bool:
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
result = store.client.execute("PRAGMA integrity_check").fetchone()[0]
rows = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
store.client.close()
return result == "ok" and rows == 20
def _reader_lock() -> ReadWriteLock:
# A distinct instance simulates a reader in another process: it coordinates
# with the production lock purely through SQLite, never reentrant upgrade.
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
class TestCompactionLock:
def test_compaction_skips_when_a_reader_holds_the_lock(
self,
temp_llm_index_dir: Path,
settings: SettingsWrapper,
caplog: pytest.LogCaptureFixture,
) -> None:
_seed_bloated_index(temp_llm_index_dir)
settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 0.3
lock = _reader_lock()
with lock.read_lock(), caplog.at_level(logging.INFO):
indexing.llm_index_compact() # must not raise
lock.close()
# Swap was skipped: bloat remains, nothing corrupted, data intact.
assert _integrity_ok(temp_llm_index_dir)
assert _bloat_ratio(temp_llm_index_dir) > 2
assert "Skipping LLM index compaction" in caplog.text
def test_compaction_runs_when_no_reader_holds_the_lock(
self,
temp_llm_index_dir: Path,
) -> None:
_seed_bloated_index(temp_llm_index_dir)
assert _bloat_ratio(temp_llm_index_dir) > 2
indexing.llm_index_compact()
assert _bloat_ratio(temp_llm_index_dir) == pytest.approx(1.0)
assert _integrity_ok(temp_llm_index_dir)
def test_normal_write_is_not_gated_by_the_compaction_lock(
self,
temp_llm_index_dir: Path,
) -> None:
"""A held exclusive lock must not block ordinary writes (WAL handles them)."""
_seed_bloated_index(temp_llm_index_dir)
done = threading.Event()
def remove() -> None:
indexing.llm_index_remove_document(MagicMock(id=999))
done.set()
holder = _reader_lock()
with holder.write_lock():
t = threading.Thread(target=remove)
t.start()
finished = done.wait(timeout=5)
t.join(timeout=2)
holder.close()
assert finished, "a normal write blocked on the compaction lock"
class TestReadStore:
def test_closes_connection_on_exit(self, temp_llm_index_dir: Path) -> None:
with indexing.read_store() as store:
conn = store.client
assert conn.execute("SELECT 1").fetchone()[0] == 1
with pytest.raises(sqlite3.ProgrammingError):
conn.execute("SELECT 1")
def test_concurrent_readers_do_not_block(self, temp_llm_index_dir: Path) -> None:
_seed_bloated_index(temp_llm_index_dir)
with indexing.read_store() as a, indexing.read_store() as b:
assert a.table_exists()
assert b.table_exists()
+1 -1
View File
@@ -12,7 +12,7 @@ class TestLazyAiImports:
"os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'paperless.settings')\n"
"django.setup()\n"
"import documents.tasks # noqa: F401\n"
"leaked = [m for m in ('lancedb', 'pyarrow', 'llama_index', 'sqlite_vec') "
"leaked = [m for m in ('lancedb', 'pyarrow', 'llama_index') "
"if m in sys.modules]\n"
"assert not leaked, f'AI libraries leaked into the light path: {leaked}'\n"
)
+357 -546
View File
@@ -1,606 +1,417 @@
import sqlite3
from collections.abc import Generator
from pathlib import Path
import pytest
from llama_index.core.schema import NodeRelationship
from llama_index.core.schema import RelatedNodeInfo
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from paperless_ai.vector_store import DB_FILENAME
from paperless_ai.vector_store import DEFAULT_TABLE_NAME
from paperless_ai.vector_store import MIGRATIONS
from paperless_ai.vector_store import SCHEMA_VERSION
from paperless_ai.vector_store import Migration
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
from paperless_ai.vector_store import _build_where
from paperless_ai.vector_store import PaperlessLanceVectorStore
DIM = 16
DIM = 8
def make_node(
node_id: str,
document_id: str,
*,
modified: str = "2026-06-10T00:00:00",
seed: float = 0.0,
text: str = "some text",
) -> TextNode:
node = TextNode(
id_=node_id,
text=text,
metadata={"document_id": document_id, "modified": modified},
)
node.relationships = {}
node.embedding = [seed + i / 100 for i in range(DIM)]
def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode:
node = TextNode(id_=node_id, text=text, metadata={"document_id": document_id})
node.set_content(text)
node.embedding = [vec] * DIM
# Use relationships so ref_doc_id resolves correctly (it's a read-only property)
node.relationships = {
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=document_id),
}
return node
@pytest.fixture
def store(tmp_path: Path) -> Generator[PaperlessSqliteVecVectorStore, None, None]:
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store:
yield store
class TestPaperlessLanceVectorStoreCrud:
@pytest.fixture
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
def test_add_then_query_returns_node(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add([_node("1-0", "1", "alpha", 0.1), _node("2-0", "2", "beta", 0.9)])
def _query(
store: PaperlessSqliteVecVectorStore,
embedding: list[float],
top_k: int = 5,
filters=None,
):
return store.query(
VectorStoreQuery(
query_embedding=embedding,
similarity_top_k=top_k,
filters=filters,
),
)
result = store.query(
VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1),
)
def _eq_filter(key: str, value: str):
return MetadataFilters(
filters=[MetadataFilter(key=key, operator=FilterOperator.EQ, value=value)],
)
def _in_filter(document_ids: list[str]):
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=document_ids,
),
],
)
class TestCrud:
def test_add_then_query_returns_node(self, store) -> None:
node = make_node("n1", "1")
assert store.add([node]) == ["n1"]
result = _query(store, node.embedding, top_k=1)
assert result.ids == ["n1"]
assert len(result.nodes) == 1
assert result.nodes[0].metadata["document_id"] == "1"
# cosine distance of the identical vector is 0 -> similarity 1
assert result.similarities[0] == pytest.approx(1.0)
def test_query_empty_store_returns_empty_no_raise(self, store) -> None:
result = _query(store, [0.0] * DIM)
assert result.ids == [] and result.nodes == [] and result.similarities == []
def test_query_empty_table_returns_empty_no_raise(
self,
store: PaperlessLanceVectorStore,
) -> None:
result = store.query(
VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=5),
)
assert result.nodes == []
assert result.ids == []
def test_add_empty_list_is_noop(self, store) -> None:
assert store.add([]) == []
assert not store.table_exists()
def test_delete_removes_all_chunks_of_document(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add([_node("1-0", "1", "a", 0.1), _node("1-1", "1", "b", 0.2)])
store.add([_node("2-0", "2", "c", 0.9)])
def test_delete_removes_all_chunks_of_document(self, store) -> None:
store.add([make_node("a1", "1"), make_node("a2", "1"), make_node("b1", "2")])
store.delete("1")
result = _query(store, [0.0] * DIM, top_k=10)
assert result.ids == ["b1"]
def test_query_with_in_filter_scopes_results(self, store) -> None:
store.add(
[
make_node("a1", "1", seed=0.0),
make_node("b1", "2", seed=1.0),
make_node("c1", "3", seed=2.0),
],
assert store.client.open_table("documents").count_rows() == 1
def test_query_with_in_filter_scopes_results(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add([_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)])
result = store.query(
VectorStoreQuery(
query_embedding=[0.1] * DIM,
similarity_top_k=5,
filters=MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=["2"],
),
],
),
),
)
result = _query(store, [0.0] * DIM, top_k=10, filters=_in_filter(["2", "3"]))
assert sorted(result.ids) == ["b1", "c1"]
def test_query_respects_top_k_with_filter(self, store) -> None:
# k semantics: global top-k even with IN filters (document_id is a
# metadata column, not a partition key -- see design doc).
store.add(
[make_node(f"n{i}", str(i % 4), seed=float(i)) for i in range(12)],
assert [n.metadata["document_id"] for n in result.nodes] == ["2"]
def test_get_nodes_filter_returns_empty_cleanly(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add([_node("1-0", "1", "a", 0.1)])
nodes = store.get_nodes(
filters=MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=["999"],
),
],
),
)
result = _query(
store,
[0.0] * DIM,
top_k=3,
filters=_in_filter(["0", "1", "2", "3"]),
assert nodes == []
def test_get_nodes_returns_empty_when_no_table(
self,
store: PaperlessLanceVectorStore,
) -> None:
result = store.get_nodes(
filters=MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=["1"],
),
],
),
)
assert len(result.ids) == 3
assert result.similarities == sorted(result.similarities, reverse=True)
assert result == []
def test_get_nodes_filter_and_empty_paths(self, store) -> None:
assert store.get_nodes(filters=_in_filter(["1"])) == [] # no table yet
store.add([make_node("a1", "1"), make_node("b1", "2")])
nodes = store.get_nodes(filters=_in_filter(["1"]))
assert [n.node_id for n in nodes] == ["a1"]
assert nodes[0].embedding is not None
assert store.get_nodes(filters=_in_filter(["999"])) == []
def test_query_with_eq_filter_scopes_results(self, store) -> None:
store.add(
[
make_node("a1", "1", seed=0.0),
make_node("b1", "2", seed=1.0),
make_node("c1", "3", seed=2.0),
],
def test_fresh_instance_filters_existing_table(
self,
tmp_path: Path,
) -> None:
uri = str(tmp_path / "idx")
PaperlessLanceVectorStore(uri=uri).add(
[_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)],
)
result = _query(
store,
[0.0] * DIM,
top_k=10,
filters=_eq_filter("document_id", "2"),
reopened = PaperlessLanceVectorStore(uri=uri)
result = reopened.query(
VectorStoreQuery(
query_embedding=[0.1] * DIM,
similarity_top_k=5,
filters=MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=["1"],
),
],
),
),
)
assert result.ids == ["b1"]
assert [n.metadata["document_id"] for n in result.nodes] == ["1"]
def test_get_nodes_node_ids_not_implemented(self, store) -> None:
with pytest.raises(NotImplementedError):
store.get_nodes(node_ids=["x"])
def test_fresh_instance_sees_existing_table(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
assert reopened.table_exists()
assert reopened.vector_dim() == DIM
assert _query(reopened, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_table_exists_and_drop(self, store) -> None:
assert not store.table_exists()
store.add([make_node("a1", "1")])
assert store.table_exists()
def test_table_exists_and_drop(
self,
store: PaperlessLanceVectorStore,
) -> None:
assert store.table_exists() is False
store.add([_node("1-0", "1", "a", 0.1)])
assert store.table_exists() is True
assert store.vector_dim() == DIM
store.drop_table()
assert not store.table_exists()
assert store.vector_dim() is None
assert store.table_exists() is False
def test_build_where_or_condition(self) -> None:
from llama_index.core.vector_stores.types import FilterCondition
from paperless_ai.vector_store import _build_where
where = _build_where(
MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="1",
),
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="2",
),
],
condition=FilterCondition.OR,
),
)
assert where == "document_id = '1' OR document_id = '2'"
class TestBuildWhere:
def test_fails_closed_when_no_filter_is_translatable(self) -> None:
# A nested MetadataFilters is not a MetadataFilter, so it is skipped.
# With no translatable clauses, the function must fail closed rather
# than emit "()" (invalid SQL) and never widen document access.
nested = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="1",
),
class TestPaperlessLanceVectorStoreUpsert:
@pytest.fixture
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
s = PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
s.add(
[
_node("1-0", "1", "old0", 0.1),
_node("1-1", "1", "old1", 0.2),
_node("1-2", "1", "old2", 0.3),
_node("2-0", "2", "keep", 0.9),
],
)
where, params = _build_where(MetadataFilters(filters=[nested]))
assert where == "1 = 0"
assert params == []
return s
def test_query_with_untranslatable_filter_returns_no_rows(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
nested = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="1",
),
],
def test_upsert_prunes_stale_chunks_and_keeps_others(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.upsert_document(
"1",
[_node("1-0", "1", "new0", 0.1), _node("1-1", "1", "new1", 0.2)],
)
filters = MetadataFilters(filters=[nested])
# Must not raise (no "WHERE ()") and must return nothing (fail closed).
assert _query(store, [0.0] * DIM, top_k=5, filters=filters).ids == []
assert store.get_nodes(filters=filters) == []
class TestUpsert:
def test_upsert_replaces_and_prunes_stale_chunks(self, store) -> None:
store.add(
[make_node("d1c1", "1"), make_node("d1c2", "1"), make_node("d2c1", "2")],
table = store.client.open_table("documents")
doc1 = sorted(
r["id"] for r in table.search().where("document_id = '1'").to_list()
)
store.upsert_document("1", [make_node("d1new", "1")])
result = _query(store, [0.0] * DIM, top_k=10)
assert sorted(result.ids) == ["d1new", "d2c1"]
assert doc1 == ["1-0", "1-1"] # 1-2 pruned
assert table.count_rows() == 3 # 2 new doc1 + 1 doc2
def test_upsert_creates_table_when_missing(self, store) -> None:
store.upsert_document("1", [make_node("a1", "1")])
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_upsert_is_single_commit(
self,
store: PaperlessLanceVectorStore,
) -> None:
table = store.client.open_table("documents")
before = table.version
store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)])
assert store.client.open_table("documents").version == before + 1
def test_upsert_empty_nodes_removes_document(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
def test_upsert_empty_nodes_removes_document(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.upsert_document("1", [])
assert _query(store, [0.0] * DIM, top_k=10).ids == ["b1"]
def test_upsert_is_atomic_for_concurrent_readers(
table = store.client.open_table("documents")
remaining = sorted(r["document_id"] for r in table.search().to_list())
assert "1" not in remaining
assert "2" in remaining
class TestPaperlessLanceVectorStoreMaintenance:
@pytest.fixture
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
def test_maybe_create_ann_index_noop_below_threshold(
self,
store,
tmp_path: Path,
store: PaperlessLanceVectorStore,
) -> None:
"""A second connection must never observe document 1 half-replaced."""
store.add([make_node("a1", "1"), make_node("a2", "1")])
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reader:
store.upsert_document("1", [make_node("a3", "1")])
ids = [n.node_id for n in reader.get_nodes(filters=_in_filter(["1"]))]
assert ids == ["a3"]
store.add([_node("1-0", "1", "a", 0.1)])
# Threshold far above row count -> no index attempted, no error.
store.maybe_create_ann_index(min_rows=1000)
# Still queryable.
result = store.query(
VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1),
)
assert len(result.nodes) == 1
class TestMetadataCoercion:
def test_none_metadata_values_become_empty_strings(self, store) -> None:
node = make_node("a1", "1")
node.metadata["modified"] = None
store.add([node]) # must not raise (vec0 rejects NULL metadata)
assert store.get_modified_times() == {"1": ""}
class TestModelNameTracking:
def test_stored_model_name_none_without_table(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
assert store.stored_model_name() is None
def test_model_name_stored_after_add_and_persists(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
store.add([make_node("a1", "1")])
assert store.stored_model_name() == "model-a"
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
assert reopened.stored_model_name() == "model-a"
def test_config_mismatch_semantics(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
assert not store.config_mismatch("anything") # no table yet
store.add([make_node("a1", "1")])
assert not store.config_mismatch("model-a")
assert store.config_mismatch("model-b")
def test_config_mismatch_false_when_table_predates_tracking(
def test_maybe_create_ann_index_non_divisible_dim_falls_back(
self,
tmp_path: Path,
store: PaperlessLanceVectorStore,
) -> None:
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store: # no model name
store.add([make_node("a1", "1")])
assert not store.config_mismatch("model-a")
# DIM=8 is not divisible by the PQ default sub-vectors; must not raise
# and must leave the table queryable (IVF_FLAT fallback or skipped).
for i in range(40):
store.add([_node(f"1-{i}", "1", f"t{i}", float(i))])
store.maybe_create_ann_index(min_rows=10)
result = store.query(
VectorStoreQuery(query_embedding=[1.0] * DIM, similarity_top_k=3),
)
assert len(result.nodes) == 3
def test_compact_reduces_to_single_version(
self,
store: PaperlessLanceVectorStore,
) -> None:
for i in range(5):
store.add([_node(f"1-{i}", "1", f"t{i}", float(i))])
assert len(store.client.open_table("documents").list_versions()) > 1
store.compact(retention_seconds=0)
assert len(store.client.open_table("documents").list_versions()) == 1
def test_upsert_after_optimize_with_scalar_index(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add(
[
_node("1-0", "1", "old0", 0.1),
_node("1-1", "1", "old1", 0.2),
_node("1-2", "1", "old2", 0.3),
_node("2-0", "2", "keep", 0.9),
],
)
store.ensure_document_id_scalar_index()
store.compact(retention_seconds=0)
store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)])
table = store.client.open_table("documents")
doc1 = sorted(
r["id"] for r in table.search().where("document_id = '1'").to_list()
)
assert doc1 == ["1-0"]
assert table.count_rows() == 2
def test_ensure_scalar_index_is_idempotent(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add([_node("1-0", "1", "text", 0.5)])
store.ensure_document_id_scalar_index()
# Second call must not raise and must not replace the existing index.
store.ensure_document_id_scalar_index()
assert store._has_index_on("document_id")
def test_ensure_scalar_index_noop_on_empty_store(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.ensure_document_id_scalar_index() # no table yet — must not raise
class TestConfigMismatch:
@pytest.fixture
def uri(self, tmp_path: Path) -> str:
return str(tmp_path / "idx")
def test_stored_model_name_returns_none_when_no_table(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri)
assert store.stored_model_name() is None
def test_model_name_stored_in_schema_after_add(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2")
store.add([_node("1-0", "1", "text", 0.1)])
assert store.stored_model_name() == "all-MiniLM-L6-v2"
def test_model_name_stored_in_schema_after_upsert(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri, embed_model_name="nomic-embed")
store.upsert_document("1", [_node("1-0", "1", "text", 0.1)])
assert store.stored_model_name() == "nomic-embed"
def test_model_name_persists_after_reopen(self, uri: str) -> None:
PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2").add(
[_node("1-0", "1", "text", 0.1)],
)
reopened = PaperlessLanceVectorStore(uri=uri)
assert reopened.stored_model_name() == "all-MiniLM-L6-v2"
def test_config_mismatch_returns_false_when_no_table(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri)
assert store.config_mismatch("any-model") is False
def test_config_mismatch_returns_false_when_model_matches(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2")
store.add([_node("1-0", "1", "text", 0.1)])
assert store.config_mismatch("all-MiniLM-L6-v2") is False
def test_config_mismatch_returns_true_when_model_differs(self, uri: str) -> None:
store = PaperlessLanceVectorStore(uri=uri, embed_model_name="old-model")
store.add([_node("1-0", "1", "text", 0.1)])
assert store.config_mismatch("new-model") is True
def test_config_mismatch_returns_false_when_no_metadata_stored(
self,
uri: str,
) -> None:
# Tables created before model-name tracking was added have no schema metadata.
# Conservative default: assume compatible rather than force a rebuild.
store = PaperlessLanceVectorStore(uri=uri)
store.add([_node("1-0", "1", "text", 0.1)])
assert store.config_mismatch("any-model") is False
class TestGetModifiedTimes:
def test_empty_store_returns_empty_dict(self, store) -> None:
@pytest.fixture
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
def _node_with_modified(
self,
node_id: str,
doc_id: str,
modified: str,
) -> TextNode:
node = TextNode(
id_=node_id,
text="text",
metadata={"document_id": doc_id, "modified": modified},
)
node.embedding = [0.1] * DIM
node.relationships = {
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc_id),
}
return node
def test_empty_store_returns_empty_dict(
self,
store: PaperlessLanceVectorStore,
) -> None:
assert store.get_modified_times() == {}
def test_returns_one_entry_per_document(self, store) -> None:
def test_returns_one_entry_per_document(
self,
store: PaperlessLanceVectorStore,
) -> None:
store.add(
[
make_node("a1", "1", modified="2026-01-01T00:00:00"),
make_node("a2", "1", modified="2026-01-01T00:00:00"),
make_node("b1", "2", modified="2026-02-02T00:00:00"),
self._node_with_modified("1-0", "1", "2024-01-01T00:00:00"),
self._node_with_modified("1-1", "1", "2024-01-01T00:00:00"),
self._node_with_modified("2-0", "2", "2024-06-01T00:00:00"),
],
)
assert store.get_modified_times() == {
"1": "2026-01-01T00:00:00",
"2": "2026-02-02T00:00:00",
result = store.get_modified_times()
assert result == {
"1": "2024-01-01T00:00:00",
"2": "2024-06-01T00:00:00",
}
class TestCompact:
def _bloat_ratio(self, store) -> float:
live = store.client.execute(
"SELECT count(*) FROM documents",
).fetchone()[0]
# vec0 0.1.9 does not accumulate deleted rows in the _rowids shadow
# table, so we track cumulative inserts in index_meta instead.
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
).fetchone()
total = int(row["value"]) if row else live
return total / max(live, 1)
def _churn(self, store, cycles: int) -> None:
for i in range(cycles):
store.upsert_document(
"1",
[make_node(f"gen{i}-{j}", "1", seed=float(j)) for j in range(20)],
)
def test_compact_noop_below_threshold(self, store) -> None:
store.add([make_node("a1", "1")])
store.compact()
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_force_compact_preserves_rows_and_metadata(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2", seed=3.0)])
self._churn(store, 5)
before = {
n.node_id: n.metadata
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
}
store.compact(force=True)
after = {
n.node_id: n.metadata
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
}
assert after == before
assert self._bloat_ratio(store) == pytest.approx(1.0)
# store remains fully usable after the rebuild; use a seed far from all
# existing nodes (gen4-0..gen4-19 have seeds 0..19) so cosine KNN is
# unambiguous at top_k=1.
store.upsert_document("3", [make_node("c1", "3", seed=100.0)])
assert "c1" in _query(store, [100.0] * DIM, top_k=1).ids
def test_auto_compact_triggers_on_churn(self, store) -> None:
store.add([make_node(f"s{j}", "1", seed=float(j)) for j in range(20)])
self._churn(store, 5)
assert self._bloat_ratio(store) > 2
store.compact()
assert self._bloat_ratio(store) == pytest.approx(1.0)
def test_compact_on_missing_table_is_noop(self, store) -> None:
store.compact()
store.compact(force=True)
def test_failed_compact_removes_temp_wal_and_shm(
self,
store,
tmp_path: Path,
monkeypatch,
) -> None:
"""A compact() that raises mid-rebuild must leave no .compact* files.
Normally the sole connection's close() checkpoints the temp WAL away,
but a concurrent reader keeps -wal/-shm alive, so the cleanup must
unlink them explicitly (as the structural-migration path does).
"""
store.add([make_node("a1", "1")])
compact_path = str(tmp_path / DB_FILENAME) + ".compact"
held: list[sqlite3.Connection] = []
def boom(conn: sqlite3.Connection, dim: int) -> None:
# Hold an extra connection so close() of the rebuild connection is
# not the last one -> the temp -wal/-shm survive the checkpoint.
extra = sqlite3.connect(compact_path)
extra.execute("SELECT 1").fetchall()
held.append(extra)
raise RuntimeError("boom")
monkeypatch.setattr(
PaperlessSqliteVecVectorStore,
"_create_vec_table",
staticmethod(boom),
)
try:
with pytest.raises(RuntimeError):
store.compact(force=True)
assert sorted(p.name for p in tmp_path.glob("*.compact*")) == []
finally:
for c in held:
c.close()
class TestDbFile:
def test_single_db_file_in_index_dir(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
assert (tmp_path / DB_FILENAME).exists()
def test_wal_mode_enabled(self, store) -> None:
assert (
store.client.execute("PRAGMA journal_mode").fetchone()[0].lower() == "wal"
)
class TestMigrations:
"""Tests for the schema migration machinery."""
def _schema_version(self, store: PaperlessSqliteVecVectorStore) -> int | None:
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'schema_version'",
).fetchone()
return int(row[0]) if row else None
def test_new_table_records_schema_version(self, store) -> None:
store.add([make_node("a1", "1")])
assert self._schema_version(store) == SCHEMA_VERSION
def test_check_migrations_no_table_returns_false(self, store) -> None:
assert store.check_and_run_migrations() is False
def test_check_migrations_current_version_returns_false(self, store) -> None:
store.add([make_node("a1", "1")])
assert store.check_and_run_migrations() is False
def test_reembed_migration_returns_true(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
migration = Migration(
from_version=1,
to_version=2,
kind="re-embed",
description="test re-embed",
)
MIGRATIONS.append(migration)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 2
result = store.check_and_run_migrations()
finally:
MIGRATIONS.remove(migration)
vs_mod.SCHEMA_VERSION = original
assert result is True
def test_structural_migration_copies_rows_and_updates_version(
self,
store,
tmp_path: Path,
) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
def apply(
src: sqlite3.Connection,
dst: sqlite3.Connection,
dim: int,
) -> None:
dst.execute( # nosemgrep
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
")",
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(dim),),
)
rows = src.execute(
"SELECT id, document_id, modified, node_content, embedding "
f"FROM {DEFAULT_TABLE_NAME}",
).fetchall()
dst.execute("BEGIN IMMEDIATE")
dst.executemany(
f"INSERT INTO {DEFAULT_TABLE_NAME} "
"(id, document_id, modified, node_content, embedding) "
"VALUES (?, ?, ?, ?, ?)",
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('total_inserts', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(len(rows)),),
)
dst.execute("COMMIT")
migration = Migration(
from_version=1,
to_version=2,
kind="structural",
description="test structural",
apply=apply,
)
MIGRATIONS.append(migration)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 2
result = store.check_and_run_migrations()
finally:
MIGRATIONS.remove(migration)
vs_mod.SCHEMA_VERSION = original
assert result is False
assert self._schema_version(store) == 2
ids = {n.node_id for n in store.get_nodes()}
assert ids == {"a1", "b1"}
def test_compact_preserves_schema_version(self, store) -> None:
store.add([make_node("a1", "1")])
assert self._schema_version(store) == SCHEMA_VERSION
store.compact(force=True)
assert self._schema_version(store) == SCHEMA_VERSION
def test_stop_at_reembed_boundary(self, store) -> None:
# Registry: structural v2, re-embed v3, structural v4.
# Only v2 should apply; the re-embed boundary must stop execution
# before v4 runs, and the stored version must stay at 2.
store.add([make_node("a1", "1"), make_node("b1", "2")])
def copy_apply(
src: sqlite3.Connection,
dst: sqlite3.Connection,
dim: int,
) -> None:
dst.execute( # nosemgrep
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
")",
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(dim),),
)
rows = src.execute(
"SELECT id, document_id, modified, node_content, embedding "
f"FROM {DEFAULT_TABLE_NAME}",
).fetchall()
dst.execute("BEGIN IMMEDIATE")
dst.executemany(
f"INSERT INTO {DEFAULT_TABLE_NAME} "
"(id, document_id, modified, node_content, embedding) "
"VALUES (?, ?, ?, ?, ?)",
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
dst.execute("COMMIT")
migrations = [
Migration(
from_version=1,
to_version=2,
kind="structural",
description="v2 structural",
apply=copy_apply,
),
Migration(
from_version=2,
to_version=3,
kind="re-embed",
description="v3 re-embed boundary",
),
Migration(
from_version=3,
to_version=4,
kind="structural",
description="v4 structural - must not run",
apply=copy_apply,
),
]
MIGRATIONS.extend(migrations)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 4
result = store.check_and_run_migrations()
finally:
for m in migrations:
MIGRATIONS.remove(m)
vs_mod.SCHEMA_VERSION = original
assert result is True
assert self._schema_version(store) == 2
+174 -445
View File
@@ -1,25 +1,15 @@
import json
import logging
import sqlite3
import struct
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Literal
import sqlite_vec
import lancedb
import pyarrow as pa
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.vector_stores.types import FilterCondition
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.core.vector_stores.types import VectorStoreQueryResult
@@ -28,118 +18,46 @@ from llama_index.core.vector_stores.utils import node_to_metadata_dict
logger = logging.getLogger("paperless_ai.vector_store")
DB_FILENAME = "llmindex.db"
DEFAULT_TABLE_NAME = "documents"
# Current schema version. Written to index_meta at table creation and bumped
# whenever a Migration is added to MIGRATIONS. check_and_run_migrations() uses
# this to decide which migrations to run on an existing store.
SCHEMA_VERSION = 1
# compact(): rebuild when the cumulative rowid count exceeds this multiple of
# the live row count. DELETEs on vec0 tables never reclaim space (upstream
# asg017/sqlite-vec#54), so per-document re-index churn grows the file until
# a rebuild copies the live rows into a fresh table.
COMPACT_BLOAT_RATIO = 2.0
# Filterable vec0 metadata columns. _build_where() only ever receives filter
# keys we construct ourselves, but allowlisting keeps SQL identifiers safe by
# construction.
_FILTER_COLUMNS = frozenset({"document_id", "modified"})
# Below this many chunks, LanceDB's exact (brute-force) search is sufficient and
# faster than building an ANN index (per LanceDB guidance, ~100K vectors).
ANN_INDEX_MIN_ROWS = 100_000
# IVF_PQ default; num_sub_vectors must evenly divide the embedding dimension.
ANN_PQ_SUB_VECTORS = 96
@dataclass
class Migration:
"""A schema migration for the sqlite-vec vector store.
kind="structural": rows are copied into a new-schema file with no
re-embedding needed. Supply ``apply(src_conn, dst_conn, dim)`` which
must create the vec0 table in ``dst_conn``, copy all rows from
``src_conn``, and write ``dim`` / ``embed_model`` / ``total_inserts`` to
``dst_conn``'s ``index_meta``. ``schema_version`` is written by the
migration runner after ``apply`` returns.
kind="re-embed": the new schema requires fresh embeddings.
``check_and_run_migrations()`` returns True when it encounters one of
these so the caller can force a full rebuild (which recreates the table
at the current SCHEMA_VERSION).
"""
from_version: int
to_version: int
kind: Literal["structural", "re-embed"]
description: str
apply: Callable[[sqlite3.Connection, sqlite3.Connection, int], None] | None = field(
default=None,
repr=False,
)
def _escape(value: str) -> str:
return str(value).replace("'", "''")
# Registry of all schema migrations in order. Empty at v1 -- this is the
# baseline. Add entries here (and bump SCHEMA_VERSION) when the schema changes.
MIGRATIONS: list[Migration] = []
def _pack(embedding: Sequence[float]) -> bytes:
return struct.pack(f"{len(embedding)}f", *embedding)
def _unpack(blob: bytes) -> list[float]:
return list(struct.unpack(f"{len(blob) // 4}f", blob))
def _build_where(filters: MetadataFilters | None) -> tuple[str, list[str]]:
"""Translate the EQ / IN filters we use into a parameterized SQL clause
on vec0 metadata columns. Returns ("", []) when there is nothing to filter.
"""
def _build_where(filters: MetadataFilters | None) -> str | None:
"""Translate the EQ / IN filters we use into a Lance SQL predicate on the
top-level ``document_id`` column."""
if filters is None or not filters.filters:
return "", []
return None
clauses: list[str] = []
params: list[str] = []
for f in filters.filters:
# filters.filters is Union[MetadataFilter, ExactMatchFilter, MetadataFilters];
# we only build MetadataFilter entries, so skip anything else at runtime.
if not isinstance(f, MetadataFilter):
continue
if f.key not in _FILTER_COLUMNS: # pragma: no cover - we build the keys
raise NotImplementedError(f"Unsupported filter column: {f.key}")
if f.operator == FilterOperator.IN:
values = [str(v) for v in f.value] # type: ignore[union-attr] # value is list when operator is IN
if not values: # pragma: no cover
clauses.append("1 = 0")
continue
placeholders = ",".join("?" for _ in values)
clauses.append(f"{f.key} IN ({placeholders})")
params.extend(values)
vals = ",".join(f"'{_escape(v)}'" for v in f.value)
clauses.append(f"{f.key} IN ({vals})")
elif f.operator == FilterOperator.EQ:
clauses.append(f"{f.key} = ?")
params.append(str(f.value))
clauses.append(f"{f.key} = '{_escape(f.value)}'")
else: # pragma: no cover - we only ever build EQ/IN filters
raise NotImplementedError(f"Unsupported filter operator: {f.operator}")
if not clauses:
# Filters were requested but none could be translated. Fail closed
# rather than emit "()" (invalid SQL): filters scope document access,
# so an empty translation must match no rows, never widen the scope.
return "1 = 0", []
joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
return "(" + joiner.join(clauses) + ")", params
return joiner.join(clauses)
class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
"""A llama-index vector store backed by a sqlite-vec vec0 table.
class PaperlessLanceVectorStore(BasePydanticVectorStore):
"""A llama-index vector store backed directly by a LanceDB table.
Stores one row per node: the node id (TEXT primary key), its document id
(metadata column, used for EQ/IN filtering and per-document delete), the
document's modified timestamp, the embedding (float32, cosine metric), and
the serialized node (text + metadata) as JSON in an auxiliary column.
``stores_text`` lets llama-index run off this store alone, with no
Stores one row per node with the node id, its document id (both as the
``ref_doc_id`` delete key ``doc_id`` and a top-level filter column
``document_id``), the embedding, and the serialised node (text + metadata)
as JSON. ``stores_text`` lets llama-index run off this store alone, with no
separate docstore or index store.
Everything lives in one SQLite database file (``DB_FILENAME``) inside the
directory given as ``uri`` (kept as a directory for compatibility with the
previous LanceDB layout). WAL mode allows readers in other processes to
proceed while the (FileLock-serialized) writer holds a transaction.
Implemented surface of ``BasePydanticVectorStore``
---------------------------------------------------
Only the methods actively used by this codebase are implemented.
@@ -152,117 +70,58 @@ class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
flat_metadata: bool = False
_uri: str = PrivateAttr()
_table_name: str = PrivateAttr()
_embed_model_name: str | None = PrivateAttr()
_conn: Any = PrivateAttr()
_table: Any = PrivateAttr()
def __init__(
self,
uri: str,
table_name: str = DEFAULT_TABLE_NAME,
embed_model_name: str | None = None,
) -> None:
super().__init__(stores_text=True, flat_metadata=False)
self._uri = uri
self._table_name = table_name
self._embed_model_name = embed_model_name
self._conn = self._open_connection(str(Path(uri) / DB_FILENAME))
@staticmethod
def _open_connection(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(
db_path,
timeout=30,
isolation_level=None, # autocommit; explicit transactions below
self._conn = lancedb.connect(uri)
existing = self._conn.list_tables().tables
self._table = (
self._conn.open_table(table_name) if table_name in existing else None
)
conn.row_factory = sqlite3.Row
conn.enable_load_extension(True) # noqa: FBT003
sqlite_vec.load(conn)
conn.enable_load_extension(False) # noqa: FBT003
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute(
"CREATE TABLE IF NOT EXISTS index_meta (key TEXT PRIMARY KEY, value TEXT)",
)
return conn
@property
def client(self) -> Any:
return self._conn
def close(self) -> None:
"""Close the underlying SQLite connection (idempotent)."""
self._conn.close()
def __enter__(self) -> "PaperlessSqliteVecVectorStore":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
# Deterministically release the connection (and its WAL/SHM handles) so
# it is never left open across a compaction/migration file swap.
self.close()
@contextmanager
def _transaction(self) -> Iterator[None]:
self._conn.execute("BEGIN IMMEDIATE")
try:
yield
except BaseException: # pragma: no cover
self._conn.execute("ROLLBACK")
raise
else:
self._conn.execute("COMMIT")
def _meta_get(self, key: str) -> str | None:
row = self._conn.execute(
"SELECT value FROM index_meta WHERE key = ?",
(key,),
).fetchone()
return row["value"] if row else None
@staticmethod
def _meta_set_on(conn: sqlite3.Connection, key: str, value: str) -> None:
conn.execute(
"INSERT INTO index_meta (key, value) VALUES (?, ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(key, value),
)
def _meta_set(self, key: str, value: str) -> None:
self._meta_set_on(self._conn, key, value)
def table_exists(self) -> bool:
return (
self._conn.execute(
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
(DEFAULT_TABLE_NAME,),
).fetchone()
is not None
)
return self._table is not None
def vector_dim(self) -> int | None:
if not self.table_exists():
if self._table is None:
return None
value = self._meta_get("dim")
return int(value) if value else None
return self._table.schema.field("vector").type.list_size
def drop_table(self) -> None:
self._conn.execute("DROP TABLE IF EXISTS " + DEFAULT_TABLE_NAME)
self._conn.execute("DELETE FROM index_meta")
if self.table_exists():
self._conn.drop_table(self._table_name)
self._table = None
def stored_model_name(self) -> str | None:
"""Return the embedding model name recorded at table creation, or None."""
if not self.table_exists():
"""Return the embedding model name stored in table schema metadata, or None."""
if self._table is None:
return None
return self._meta_get("embed_model")
meta = self._table.schema.metadata or {}
value = meta.get(b"embed_model")
return value.decode() if value else None
def config_mismatch(self, model_name: str) -> bool:
"""True when the stored model name differs from ``model_name``.
Returns False when no table exists or when the table predates
model-name tracking conservative default avoids spurious rebuilds.
Returns False when no table exists or when the table predates model-name
tracking (schema has no metadata) conservative default avoids spurious
rebuilds on upgrade.
"""
stored = self.stored_model_name()
if stored is None:
@@ -270,115 +129,97 @@ class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
return stored != model_name
@staticmethod
def _create_vec_table(conn: sqlite3.Connection, dim: int) -> None:
# document_id is deliberately a metadata column, NOT a partition key:
# partition keys change KNN `k` to per-partition semantics under IN
# filters (asg017/sqlite-vec#142); metadata columns give a correct
# global top-k.
conn.execute( # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query
"CREATE VIRTUAL TABLE "
+ DEFAULT_TABLE_NAME
+ " USING vec0("
+ "id TEXT PRIMARY KEY,"
+ " document_id TEXT,"
+ " modified TEXT,"
+ " +node_content TEXT,"
+ " embedding float["
+ str(int(dim))
+ "] distance_metric=cosine"
+ ")",
def _schema(dim: int, model_name: str | None = None) -> pa.Schema:
meta = {b"embed_model": model_name.encode()} if model_name else None
return pa.schema(
[
pa.field("id", pa.string()),
pa.field("doc_id", pa.string()),
pa.field("document_id", pa.string()),
pa.field("modified", pa.string()),
pa.field("vector", pa.list_(pa.float32(), dim)),
pa.field("node_content", pa.string()),
],
metadata=meta,
)
def _create_table(self, dim: int) -> None:
self._create_vec_table(self._conn, dim)
self._meta_set("dim", str(dim))
self._meta_set("schema_version", str(SCHEMA_VERSION))
if self._embed_model_name:
self._meta_set("embed_model", self._embed_model_name)
def _ensure_table(self, dim: int) -> None:
if not self.table_exists():
self._create_table(dim)
def _row(self, node: BaseNode) -> tuple[str, str, str, str, bytes]:
def _row(self, node: BaseNode) -> dict[str, Any]:
meta = node_to_metadata_dict(
node,
remove_text=False,
flat_metadata=self.flat_metadata,
)
# vec0 metadata columns reject NULL (asg017/sqlite-vec#141): coerce
# every value to a string, with "" as the absent sentinel.
document_id = node.ref_doc_id or node.metadata.get("document_id")
return (
node.node_id,
str(document_id or ""),
str(node.metadata.get("modified") or ""),
json.dumps(meta),
_pack(node.get_embedding()),
)
return {
"id": node.node_id,
"doc_id": node.ref_doc_id,
"document_id": str(node.metadata.get("document_id")),
"modified": str(node.metadata.get("modified", "")),
"vector": node.get_embedding(),
"node_content": json.dumps(meta),
}
_INSERT = (
"INSERT INTO "
+ DEFAULT_TABLE_NAME
+ " (id, document_id, modified, node_content, embedding) VALUES (?, ?, ?, ?, ?)"
)
def _ensure_table(self, rows: list[dict[str, Any]], dim: int) -> bool:
"""Create the table from ``rows`` if it does not exist yet.
def _increment_total_inserts(self, count: int) -> None:
"""Increment the cumulative insert counter stored in index_meta.
This counter never decreases (DELETEs do not decrement it) and is
used by compact() to estimate the bloat ratio: when total_inserts /
live_rows exceeds COMPACT_BLOAT_RATIO the table has accumulated
enough deleted-but-not-freed rows to warrant a rebuild.
Returns True if the table was just created (caller can skip the
separate add/merge step), False if the table already existed.
"""
current = int(self._meta_get("total_inserts") or "0")
self._meta_set("total_inserts", str(current + count))
if self._table is not None:
return False
self._table = self._conn.create_table(
self._table_name,
rows,
schema=self._schema(dim, self._embed_model_name),
)
return True
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
if not nodes:
return []
rows = [self._row(node) for node in nodes]
with self._transaction():
self._ensure_table(len(nodes[0].get_embedding()))
self._conn.executemany(self._INSERT, rows)
self._increment_total_inserts(len(rows))
dim = len(nodes[0].get_embedding())
if not self._ensure_table(rows, dim):
self._table.add(rows)
return [node.node_id for node in nodes]
def upsert_document(self, document_id: str, nodes: list[BaseNode]) -> list[str]:
"""Atomically replace all stored chunks of ``document_id`` with ``nodes``.
One transaction deletes the document's existing rows and inserts the
new set (vec0's INSERT OR REPLACE is broken upstream, #259, so
delete+insert it is). WAL readers in other processes see either the
old or the new chunk set, never a partial state.
A single ``merge_insert`` commit: matching node ids are updated, new ids
inserted, and any existing rows for this document that are not in the new
set are deleted (``when_not_matched_by_source_delete``). This prunes stale
trailing chunks when an edit reduces a document's chunk count, with no
transient empty state for concurrent lock-free readers.
"""
if not nodes:
# No indexable content: remove any existing chunks for this document.
if self._table is not None:
self._table.delete(f"document_id = '{_escape(document_id)}'")
return []
rows = [self._row(node) for node in nodes]
with self._transaction():
if nodes:
self._ensure_table(len(nodes[0].get_embedding()))
if self.table_exists():
self._conn.execute(
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
(str(document_id),),
)
if rows:
self._conn.executemany(self._INSERT, rows)
self._increment_total_inserts(len(rows))
dim = len(nodes[0].get_embedding())
if self._ensure_table(rows, dim):
return [node.node_id for node in nodes]
(
self._table.merge_insert("id")
.when_matched_update_all()
.when_not_matched_insert_all()
.when_not_matched_by_source_delete(
f"document_id = '{_escape(document_id)}'",
)
.execute(rows)
)
return [node.node_id for node in nodes]
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
if self.table_exists():
with self._transaction():
self._conn.execute(
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
(str(ref_doc_id),),
)
if self._table is not None:
self._table.delete(f"doc_id = '{_escape(ref_doc_id)}'")
def _rows_to_nodes(self, rows: list[sqlite3.Row]) -> list[BaseNode]:
def _rows_to_nodes(self, rows: list[dict[str, Any]]) -> list[BaseNode]:
nodes: list[BaseNode] = []
for row in rows:
node = metadata_dict_to_node(json.loads(row["node_content"]))
node.embedding = _unpack(row["embedding"])
node.embedding = list(row["vector"])
nodes.append(node)
return nodes
@@ -391,214 +232,102 @@ class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
if node_ids is not None: # pragma: no cover
# node_ids lookup is not implemented; see class docstring.
raise NotImplementedError(
"PaperlessSqliteVecVectorStore does not support node_ids lookup",
"PaperlessLanceVectorStore does not support node_ids lookup",
)
if not self.table_exists():
if self._table is None:
return []
where, params = _build_where(filters)
sql = "SELECT node_content, embedding FROM " + DEFAULT_TABLE_NAME
where = _build_where(filters)
query = self._table.search()
if where:
sql += " WHERE " + where
return self._rows_to_nodes(self._conn.execute(sql, params).fetchall())
query = query.where(where)
return self._rows_to_nodes(query.to_list())
def query(
self,
query: VectorStoreQuery,
**kwargs: Any,
) -> VectorStoreQueryResult:
if not self.table_exists():
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
if query.query_embedding is None: # pragma: no cover
if self._table is None:
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
top_k = query.similarity_top_k if query.similarity_top_k is not None else 10
where, params = _build_where(query.filters)
sql = (
"SELECT id, node_content, embedding, distance FROM "
+ DEFAULT_TABLE_NAME
+ " WHERE embedding MATCH ? AND k = ?"
)
search = self._table.search(query.query_embedding).limit(top_k)
where = _build_where(query.filters)
if where:
sql += " AND " + where
rows = self._conn.execute(
sql,
[_pack(query.query_embedding), top_k, *params],
).fetchall()
# vec0 returns rows distance-sorted ascending; slice defensively in
# case future schema changes alter k semantics (e.g. partition keys
# return k rows per partition).
rows = rows[:top_k]
search = search.where(where)
rows = search.to_list()
nodes = self._rows_to_nodes(rows)
# Cosine distance in [0, 2]; map to a descending similarity.
# vec0 returns None distance when the query embedding is the zero vector
# (no meaningful cosine angle); treat that as maximum distance (1.0) so
# the row is included but ranked last.
sims = [
1.0 - float(row["distance"] if row["distance"] is not None else 1.0)
for row in rows
]
# LanceDB returns an L2 distance (smaller = closer); map to a descending similarity.
sims = [1.0 / (1.0 + float(row["_distance"])) for row in rows]
ids = [row["id"] for row in rows]
return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)
def _has_index_on(self, column: str) -> bool:
return any(column in idx.columns for idx in self._table.list_indices())
def maybe_create_ann_index(self, min_rows: int = ANN_INDEX_MIN_ROWS) -> None:
"""Best-effort: build an IVF index once the table is large enough.
IVF_PQ is used when ``num_sub_vectors`` divides the embedding dimension,
otherwise IVF_FLAT (no divisor constraint). Any failure is logged and
leaves the table on exact search, which is always correct.
"""
if self._table is None:
return
rows = self._table.count_rows()
if rows < min_rows or self._has_index_on("vector"):
return
num_partitions = max(1, rows // 4096)
# Embedding dim from the schema's fixed-size list column.
dim = self._table.schema.field("vector").type.list_size
try:
if dim % ANN_PQ_SUB_VECTORS == 0: # pragma: no cover
self._table.create_index(
metric="l2",
num_partitions=num_partitions,
num_sub_vectors=ANN_PQ_SUB_VECTORS,
index_type="IVF_PQ",
)
else:
self._table.create_index(
metric="l2",
num_partitions=num_partitions,
index_type="IVF_FLAT",
)
except Exception as e: # pragma: no cover - depends on data/dim
logger.warning("Skipping ANN index creation: %s", e)
def get_modified_times(self) -> dict[str, str]:
"""Return {document_id: stored_modified_isoformat} for all indexed documents.
All chunks of a document share the same ``modified`` value, so the
first row seen per document is sufficient.
One representative chunk per document is fetched; all chunks share the
same ``modified`` value so the first one seen is sufficient.
"""
if not self.table_exists():
if self._table is None:
return {}
result: dict[str, str] = {}
for row in self._conn.execute(
"SELECT document_id, modified FROM " + DEFAULT_TABLE_NAME,
):
for row in self._table.search().select(["document_id", "modified"]).to_list():
doc_id = str(row["document_id"])
if doc_id not in result:
result[doc_id] = str(row["modified"] or "")
return result
def compact(self, *, force: bool = False) -> None:
"""Rebuild the database file to reclaim space left behind by DELETEs.
vec0 DELETE only invalidates rows; the vector data stays in the file
forever (asg017/sqlite-vec#54), and per-document re-indexing is a
delete+insert. The cumulative insert counter in ``index_meta`` tracks
total rows ever written; when that exceeds ``COMPACT_BLOAT_RATIO`` x
the live row count (or when forced), live rows are copied into a fresh
database file and swapped in via ``os.replace``.
Note: ``ALTER TABLE ... RENAME TO`` on vec0 virtual tables does NOT
rename the shadow tables (sqlite-vec upstream limitation), so
an in-place rename-based rebuild is not safe. The file-swap approach
is the maintainer-endorsed workaround (asg017/sqlite-vec#205).
"""
if not self.table_exists():
def ensure_document_id_scalar_index(self) -> None:
"""Create a scalar index on the filter column (never on the merge key
``id`` see https://github.com/lancedb/lancedb/issues/3177).
No-op if the index already exists."""
if self._table is None:
return
live = self._conn.execute(
"SELECT count(*) FROM " + DEFAULT_TABLE_NAME,
).fetchone()[0]
total = int(self._meta_get("total_inserts") or str(live))
if not force and total <= max(live, 1) * COMPACT_BLOAT_RATIO:
if self._has_index_on("document_id"):
return
dim = self.vector_dim()
if dim is None: # pragma: no cover - dim is written at creation
logger.warning("Skipping compact: no stored vector dimension")
return
logger.info(
"Compacting LLM index (%d live rows, %d cumulative inserts)",
live,
total,
)
db_path = str(Path(self._uri) / DB_FILENAME)
compact_path = db_path + ".compact"
# Copy all live rows into a fresh database file.
new_conn = self._open_connection(compact_path)
try:
self._create_vec_table(new_conn, dim)
self._meta_set_on(new_conn, "dim", str(dim))
for key in ("embed_model", "schema_version"):
value = self._meta_get(key)
if value is not None:
self._meta_set_on(new_conn, key, value)
rows = self._conn.execute(
"SELECT id, document_id, modified, node_content, embedding "
"FROM " + DEFAULT_TABLE_NAME,
).fetchall()
new_conn.execute("BEGIN IMMEDIATE")
new_conn.executemany(
self._INSERT,
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
# Reset the cumulative counter: after compact, total_inserts == live.
self._meta_set_on(new_conn, "total_inserts", str(live))
new_conn.execute("COMMIT")
except BaseException:
new_conn.close()
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
Path(p).unlink(missing_ok=True)
raise
new_conn.close()
self._swap_in_compact(compact_path, db_path)
self._table.create_scalar_index("document_id")
except Exception as e: # pragma: no cover
logger.warning("Skipping document_id scalar index: %s", e)
def _swap_in_compact(self, compact_path: str, db_path: str) -> None:
"""Atomically replace the live database with the compacted copy."""
self._conn.close()
for suffix in ["-wal", "-shm"]:
stale = Path(compact_path + suffix)
if stale.exists(): # pragma: no cover
stale.unlink()
Path(compact_path).replace(db_path)
self._conn = self._open_connection(db_path)
def compact(self, retention_seconds: int) -> None:
"""Compact fragments and prune old MVCC versions in one call."""
if self._table is None:
return
from datetime import timedelta
def check_and_run_migrations(self) -> bool:
"""Apply any pending schema migrations to the store.
Structural migrations copy live rows into a new-schema file with no
re-embedding. Re-embed migrations cannot be applied automatically;
this method returns True when one is encountered so the caller can
force a full rebuild (which recreates the table at SCHEMA_VERSION).
Must be called under the write FileLock. No-op when the table does
not exist or is already at SCHEMA_VERSION.
"""
if not self.table_exists():
return False
raw = self._meta_get("schema_version")
current = int(raw) if raw is not None else SCHEMA_VERSION
if current >= SCHEMA_VERSION:
return False
pending = sorted(
[m for m in MIGRATIONS if current <= m.from_version < SCHEMA_VERSION],
key=lambda m: m.from_version,
)
for migration in pending:
if migration.kind == "re-embed":
logger.warning(
"LLM index schema v%d -> v%d requires re-embedding (%s); "
"forcing full rebuild.",
migration.from_version,
migration.to_version,
migration.description,
)
return True
logger.info(
"Running structural LLM index migration v%d -> v%d: %s",
migration.from_version,
migration.to_version,
migration.description,
)
self._run_structural_migration(migration)
return False
def _run_structural_migration(self, migration: Migration) -> None:
"""Execute a structural migration using the same file-swap as compact()."""
assert migration.apply is not None, "structural migration must have apply()"
dim = self.vector_dim()
if dim is None: # pragma: no cover
raise RuntimeError("Cannot migrate: no stored vector dimension")
db_path = str(Path(self._uri) / DB_FILENAME)
compact_path = db_path + ".compact"
new_conn = self._open_connection(compact_path)
try:
migration.apply(self._conn, new_conn, dim)
self._meta_set_on(new_conn, "schema_version", str(migration.to_version))
except BaseException: # pragma: no cover
new_conn.close()
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
Path(p).unlink(missing_ok=True)
raise
new_conn.close()
self._swap_in_compact(compact_path, db_path)
self._table.optimize(cleanup_older_than=timedelta(seconds=retention_seconds))
+5 -10
View File
@@ -4,7 +4,6 @@ import logging
import ssl
import tempfile
import traceback
import unicodedata
from datetime import date
from datetime import timedelta
from fnmatch import fnmatch
@@ -497,10 +496,10 @@ class MailAccountHandler(LoggingMixin):
rule: MailRule,
) -> str | None:
if rule.assign_title_from == MailRule.TitleSource.FROM_SUBJECT:
return unicodedata.normalize("NFC", message.subject)
return message.subject
elif rule.assign_title_from == MailRule.TitleSource.FROM_FILENAME:
return unicodedata.normalize("NFC", Path(att.filename).stem)
return Path(att.filename).stem
elif rule.assign_title_from == MailRule.TitleSource.NONE:
return None
@@ -867,9 +866,7 @@ class MailAccountHandler(LoggingMixin):
),
)
attachment_name = pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", att.filename),
)
attachment_name = pathvalidate.sanitize_filename(att.filename)
if attachment_name:
temp_filename = temp_dir / attachment_name
else: # pragma: no cover
@@ -885,7 +882,7 @@ class MailAccountHandler(LoggingMixin):
)
doc_overrides = DocumentMetadataOverrides(
title=title,
filename=attachment_name,
filename=pathvalidate.sanitize_filename(att.filename),
correspondent_id=correspondent.id if correspondent else None,
document_type_id=doc_type.id if doc_type else None,
tag_ids=tag_ids,
@@ -991,9 +988,7 @@ class MailAccountHandler(LoggingMixin):
)
doc_overrides = DocumentMetadataOverrides(
title=message.subject,
filename=pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", f"{message.subject}.eml"),
),
filename=pathvalidate.sanitize_filename(f"{message.subject}.eml"),
correspondent_id=correspondent.id if correspondent else None,
document_type_id=doc_type.id if doc_type else None,
tag_ids=tag_ids,
@@ -1,158 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:10
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("paperless_mail", "0002_optimize_integer_field_sizes"),
("paperless_mail", "0003_mailrule_stop_processing"),
]
dependencies = [
("paperless_mail", "0001_squashed"),
]
operations = [
migrations.AlterField(
model_name="mailaccount",
name="account_type",
field=models.PositiveSmallIntegerField(
choices=[(1, "IMAP"), (2, "Gmail OAuth"), (3, "Outlook OAuth")],
default=1,
verbose_name="account type",
),
),
migrations.AlterField(
model_name="mailaccount",
name="imap_port",
field=models.PositiveIntegerField(
blank=True,
help_text="This is usually 143 for unencrypted and STARTTLS connections, and 993 for SSL connections.",
null=True,
verbose_name="IMAP port",
),
),
migrations.AlterField(
model_name="mailaccount",
name="imap_security",
field=models.PositiveSmallIntegerField(
choices=[(1, "No encryption"), (2, "Use SSL"), (3, "Use STARTTLS")],
default=2,
verbose_name="IMAP security",
),
),
migrations.AlterField(
model_name="mailrule",
name="action",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Delete"),
(2, "Move to specified folder"),
(3, "Mark as read, don't process read mails"),
(4, "Flag the mail, don't process flagged mails"),
(5, "Tag the mail with specified tag, don't process tagged mails"),
],
default=3,
verbose_name="action",
),
),
migrations.AlterField(
model_name="mailrule",
name="assign_correspondent_from",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Do not assign a correspondent"),
(2, "Use mail address"),
(3, "Use name (or mail address if not available)"),
(4, "Use correspondent selected below"),
],
default=1,
verbose_name="assign correspondent from",
),
),
migrations.AlterField(
model_name="mailrule",
name="assign_title_from",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Use subject as title"),
(2, "Use attachment filename as title"),
(3, "Do not assign title from rule"),
],
default=1,
verbose_name="assign title from",
),
),
migrations.AlterField(
model_name="mailrule",
name="attachment_type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Only process attachments."),
(2, "Process all files, including 'inline' attachments."),
],
default=1,
help_text="Inline attachments include embedded images, so it's best to combine this option with a filename filter.",
verbose_name="attachment type",
),
),
migrations.AlterField(
model_name="mailrule",
name="consumption_scope",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Only process attachments."),
(
2,
"Process full Mail (with embedded attachments in file) as .eml",
),
(
3,
"Process full Mail (with embedded attachments in file) as .eml + process attachments as separate documents",
),
],
default=1,
verbose_name="consumption scope",
),
),
migrations.AlterField(
model_name="mailrule",
name="maximum_age",
field=models.PositiveSmallIntegerField(
default=30,
help_text="Specified in days.",
verbose_name="maximum age",
),
),
migrations.AlterField(
model_name="mailrule",
name="order",
field=models.SmallIntegerField(default=0, verbose_name="order"),
),
migrations.AlterField(
model_name="mailrule",
name="pdf_layout",
field=models.PositiveSmallIntegerField(
choices=[
(0, "System default"),
(1, "Text, then HTML"),
(2, "HTML, then text"),
(3, "HTML only"),
(4, "Text only"),
],
default=0,
verbose_name="pdf layout",
),
),
migrations.AddField(
model_name="mailrule",
name="stop_processing",
field=models.BooleanField(
default=False,
help_text="If True, no further rules will be processed after this one if any document is queued.",
verbose_name="Stop processing further rules",
),
),
]
-182
View File
@@ -1,182 +0,0 @@
"""
Tests that mail attachment filenames and EML subject filenames are
normalized to NFC Unicode before being stored as document overrides.
Filenames from MIME headers can arrive in NFD form (e.g. from macOS Mail),
and must be normalized to NFC so filenames are consistent regardless of the
sending client.
"""
import unicodedata
from pathlib import Path
from unittest import mock
import pytest
from documents.tests.utils import remove_dirs
from documents.tests.utils import setup_directories
from paperless_mail.models import MailRule
from paperless_mail.tests.factories import MailAccountFactory
from paperless_mail.tests.test_mail import MessageBuilder
from paperless_mail.tests.test_mail import _AttachmentDef
from paperless_mail.tests.test_mail import fake_magic_from_buffer
@pytest.fixture()
def directories(settings):
dirs = setup_directories()
yield dirs
remove_dirs(dirs)
@pytest.fixture()
def queue_consumption_tasks_mock():
with mock.patch("paperless_mail.mail.queue_consumption_tasks") as m:
yield m
@pytest.fixture()
def mail_account(db):
return MailAccountFactory()
@pytest.fixture()
def attachment_rule(mail_account):
rule = MailRule(
name="attachment rule",
account=mail_account,
assign_title_from=MailRule.TitleSource.FROM_FILENAME,
consumption_scope=MailRule.ConsumptionScope.ATTACHMENTS_ONLY,
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
)
rule.save()
return rule
@pytest.fixture()
def eml_rule(mail_account):
rule = MailRule(
name="eml rule",
account=mail_account,
assign_title_from=MailRule.TitleSource.FROM_SUBJECT,
consumption_scope=MailRule.ConsumptionScope.EML_ONLY,
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
)
rule.save()
return rule
@pytest.fixture()
def message_builder():
return MessageBuilder()
@pytest.mark.django_db
@mock.patch("paperless_mail.mail.magic.from_buffer", fake_magic_from_buffer)
class TestMailNFCNormalization:
"""Attachment filenames and EML subject filenames must be NFC-normalized."""
def test_attachment_nfd_filename_normalized_to_nfc(
self,
directories,
queue_consumption_tasks_mock,
attachment_rule,
mail_account_handler,
message_builder,
):
"""Attachment filename arriving as NFD must be stored as NFC in both
the overrides and the temp file written to disk.
"""
nfd_filename = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc_filename = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Confirm the fixture is actually NFD (not already NFC)
assert unicodedata.is_normalized("NFD", nfd_filename)
assert not unicodedata.is_normalized("NFC", nfd_filename)
message = message_builder.create_message(
subject="Test invoice",
from_="sender@example.com",
attachments=[
_AttachmentDef(filename=nfd_filename, content=b"%PDF-1.4 test"),
],
)
result = mail_account_handler._handle_message(message, attachment_rule)
assert result == 1
queue_consumption_tasks_mock.assert_called_once()
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
assert len(consume_tasks) == 1
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_filename
assert unicodedata.is_normalized("NFC", overrides.filename)
assert unicodedata.is_normalized("NFC", overrides.title)
input_doc = consume_tasks[0].kwargs["input_doc"]
original_file = Path(input_doc.original_file)
assert original_file.exists()
assert original_file.name == nfc_filename
def test_eml_subject_filename_nfc(
self,
directories,
queue_consumption_tasks_mock,
eml_rule,
mail_account_handler,
message_builder,
):
"""EML filename derived from subject arriving as NFD must be stored as NFC."""
nfd_subject = unicodedata.normalize("NFD", "Rechnung März 2024")
nfc_expected_filename = unicodedata.normalize("NFC", "Rechnung März 2024.eml")
# Confirm the fixture is actually NFD
assert unicodedata.is_normalized("NFD", nfd_subject)
message = message_builder.create_message(
subject=nfd_subject,
from_="sender@example.com",
attachments=0,
)
mail_account_handler._handle_message(message, eml_rule)
queue_consumption_tasks_mock.assert_called_once()
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
assert len(consume_tasks) == 1
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_expected_filename
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_attachment_filename_unchanged(
self,
directories,
queue_consumption_tasks_mock,
attachment_rule,
mail_account_handler,
message_builder,
):
"""An attachment filename already in NFC must pass through unchanged."""
nfc_filename = "Invoice_2024.pdf"
assert unicodedata.is_normalized("NFC", nfc_filename)
message = message_builder.create_message(
subject="Invoice",
from_="sender@example.com",
attachments=[
_AttachmentDef(filename=nfc_filename, content=b"%PDF-1.4 test"),
],
)
mail_account_handler._handle_message(message, attachment_rule)
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_filename
Generated
+106 -13
View File
@@ -2052,6 +2052,55 @@ redis = [
{ name = "redis", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "lance-namespace"
version = "0.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "lance-namespace-urllib3-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/21/80/2b6eaa08c5e25915acaa6368a70211a25b5ba9d2d6006450e68a73936164/lance_namespace-0.8.0.tar.gz", hash = "sha256:c4a79ee221a3b2315c29863ad12d85fcf219a13158e26149d63e21dc4b4673a7", size = 10756, upload-time = "2026-06-01T08:47:10.183Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/bd/7b40a08fb132fab39a6caebf832fdf6b9befc71be9413beb9be0a9d927d4/lance_namespace-0.8.0-py3-none-any.whl", hash = "sha256:782cf9e332f46bf06836722dd98b53ca8495ad98bb541501ff6876c89b67ec90", size = 12579, upload-time = "2026-06-01T08:47:10.91Z" },
]
[[package]]
name = "lance-namespace-urllib3-client"
version = "0.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8c/37/06fcd5a8969381e0ba953d51990af8d331bdccbc62458bf2eed30d064573/lance_namespace_urllib3_client-0.8.0.tar.gz", hash = "sha256:4f060f05ebf3c04aeaeb0d2022cbe77648a3df290f02cd2c305e5797d0fc1fdd", size = 203710, upload-time = "2026-06-01T08:47:13.404Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/51/43/e280727feee958f303bc58d5fa912b07734a0831f756d841654d500c2c34/lance_namespace_urllib3_client-0.8.0-py3-none-any.whl", hash = "sha256:6734e341b726e5cc96a0cd257cef27eb9d03013f2d151526ee426cef8e63e228", size = 336669, upload-time = "2026-06-01T08:47:11.88Z" },
]
[[package]]
name = "lancedb"
version = "0.33.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "deprecation", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "lance-namespace", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "overrides", marker = "(python_full_version < '3.12' and sys_platform == 'darwin') or (python_full_version < '3.12' and sys_platform == 'linux')" },
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/09/2f/d5a4b2a5bb1f800936c76a6d8a4daf127a86fcab621eeb70b574a5adc774/lancedb-0.33.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:d4eaf6fa7c2eac619208f1d396f4de635ee0f535673067118a31c1181575c48b", size = 48338115, upload-time = "2026-05-28T20:37:55.88Z" },
{ url = "https://files.pythonhosted.org/packages/07/12/31787b93a856b2c31382c7771dc22fb05575b70b87c9efe454269f4f0948/lancedb-0.33.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c6c2402ed2744245ae76c4167c0461da0a7a80f1608e0ec491c1548ea2b4302", size = 51162262, upload-time = "2026-05-28T20:37:59.101Z" },
{ url = "https://files.pythonhosted.org/packages/49/b7/081cc29f8e06bf12191b99ab3fe702aceebdb0914476b821a8c0445cacc8/lancedb-0.33.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ebf1ffad811e6254a93931a79489ba1f21f48564bdfa06abae846f5fcaaf3e8", size = 54381368, upload-time = "2026-05-28T20:38:02.2Z" },
{ url = "https://files.pythonhosted.org/packages/1c/bd/e0f4bd621f10ecf96a801b0166e87799ed7ca5a9dbabcef9a6c766a58ef3/lancedb-0.33.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:13da39f80adfea59e5831fe64e4166b2d70a2f843e6507bf644c4fe4c350087c", size = 51188986, upload-time = "2026-05-28T20:38:05.375Z" },
{ url = "https://files.pythonhosted.org/packages/d9/1a/a8647a432ac6aa59cdce1fc061a7050ea4278bcab364539b78af2ecf72d2/lancedb-0.33.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:21b712825f0a00225e8974a41352c4ea84b0899ef8c23b17f672fadc38bd8346", size = 54440958, upload-time = "2026-05-28T20:38:08.474Z" },
]
[[package]]
name = "langdetect"
version = "1.0.9"
@@ -2843,6 +2892,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" },
]
[[package]]
name = "overrides"
version = "7.7.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/36/86/b585f53236dec60aba864e050778b25045f857e17f6e5ea0ae95fe80edd2/overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a", size = 22812, upload-time = "2024-01-27T21:01:33.423Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/ab/fc8290c6a4c722e5514d80f62b2dc4c4df1a68a41d1364e625c35990fcf3/overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49", size = 17832, upload-time = "2024-01-27T21:01:31.393Z" },
]
[[package]]
name = "packaging"
version = "26.0"
@@ -2890,6 +2948,7 @@ dependencies = [
{ name = "ijson", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "imap-tools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "lancedb", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "langdetect", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "llama-index-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "llama-index-embeddings-huggingface", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2902,6 +2961,7 @@ dependencies = [
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pathvalidate", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pdf2image", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "python-gnupg", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2913,7 +2973,6 @@ dependencies = [
{ name = "scikit-learn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentence-transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "setproctitle", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sqlite-vec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tantivy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "torch", version = "2.11.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
@@ -3040,6 +3099,7 @@ requires-dist = [
{ name = "ijson", specifier = ">=3.2" },
{ name = "imap-tools", specifier = "~=1.13.0" },
{ name = "jinja2", specifier = "~=3.1.5" },
{ name = "lancedb", specifier = "~=0.33.0" },
{ name = "langdetect", specifier = "~=1.0.9" },
{ name = "llama-index-core", specifier = ">=0.14.21" },
{ name = "llama-index-embeddings-huggingface", specifier = ">=0.6.1" },
@@ -3058,6 +3118,7 @@ requires-dist = [
{ name = "psycopg-c", marker = "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'postgres'", url = "https://github.com/paperless-ngx/builder/releases/download/psycopg-trixie-3.3.0/psycopg_c-3.3.0-cp312-cp312-linux_x86_64.whl" },
{ name = "psycopg-c", marker = "(python_full_version != '3.12.*' and platform_machine == 'aarch64' and extra == 'postgres') or (python_full_version != '3.12.*' and platform_machine == 'x86_64' and extra == 'postgres') or (platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'postgres') or (sys_platform != 'linux' and extra == 'postgres')", specifier = "==3.3" },
{ name = "psycopg-pool", marker = "extra == 'postgres'", specifier = "==3.3" },
{ name = "pyarrow", specifier = ">=16" },
{ name = "python-dateutil", specifier = "~=2.9.0" },
{ name = "python-dotenv", specifier = "~=1.2.1" },
{ name = "python-gnupg", specifier = "~=0.5.4" },
@@ -3069,7 +3130,6 @@ requires-dist = [
{ name = "scikit-learn", specifier = "~=1.8.0" },
{ name = "sentence-transformers", specifier = ">=5.4.1" },
{ name = "setproctitle", specifier = "~=1.3.4" },
{ name = "sqlite-vec", specifier = "==0.1.9" },
{ name = "tantivy", specifier = "~=0.26.0" },
{ name = "tika-client", specifier = "~=0.11.0" },
{ name = "torch", specifier = "~=2.11.0", index = "https://download.pytorch.org/whl/cpu" },
@@ -3557,6 +3617,50 @@ version = "0.16.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/1d/c7/28220d37e041fe1df03e857fe48f768dcd30cd151480bf6f00da8713214a/py-ubjson-0.16.1.tar.gz", hash = "sha256:b9bfb8695a1c7e3632e800fb83c943bf67ed45ddd87cd0344851610c69a5a482", size = 50316, upload-time = "2020-04-18T15:05:57.698Z" }
[[package]]
name = "pyarrow"
version = "24.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/91/13/13e1069b351bdc3881266e11147ffccf687505dbb0ea74036237f5d454a5/pyarrow-24.0.0.tar.gz", hash = "sha256:85fe721a14dd823aca09127acbb06c3ca723efbd436c004f16bca601b04dcc83", size = 1180261, upload-time = "2026-04-21T10:51:25.837Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/c9/a47ab7ece0d86cbe6678418a0fbd1ac4bb493b9184a3891dfa0e7f287ae0/pyarrow-24.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b0e131f880cda8d04e076cee175a46fc0e8bc8b65c99c6c09dff6669335fde74", size = 35068898, upload-time = "2026-04-21T10:46:36.599Z" },
{ url = "https://files.pythonhosted.org/packages/d1/bc/8db86617a9a58008acf8913d6fed68ea2a46acb6de928db28d724c891a68/pyarrow-24.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:1b2fe7f9a5566401a0ef2571f197eb92358925c1f0c8dba305d6e43ea0871bb3", size = 36679915, upload-time = "2026-04-21T10:46:42.602Z" },
{ url = "https://files.pythonhosted.org/packages/eb/8e/fb178720400ef69db251eb4a9c3ccf4af269bc1feb5055529b8fc87170d1/pyarrow-24.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:0b3537c00fb8d384f15ac1e79b6eb6db04a16514c8c1d22e59a9b95c8ba42868", size = 45697931, upload-time = "2026-04-21T10:46:48.403Z" },
{ url = "https://files.pythonhosted.org/packages/f3/27/99c42abe8e21b44f4917f62631f3aa31404882a2c41d8a4cd5c110e13d52/pyarrow-24.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:14e31a3c9e35f1ab6356c6378f6f72830e6d2d5f1791df3774a7b097d18a6a1e", size = 48837449, upload-time = "2026-04-21T10:46:55.329Z" },
{ url = "https://files.pythonhosted.org/packages/36/b6/333749e2666e9032891125bf9c691146e92901bece62030ac1430e2e7c88/pyarrow-24.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7d9a514e73bc42711e6a35aaccf3587c520024fe0a25d830a1a8a27c15f4f57", size = 49395949, upload-time = "2026-04-21T10:47:01.869Z" },
{ url = "https://files.pythonhosted.org/packages/17/25/c5201706a2dd374e8ba6ee3fd7a8c89fb7ffc16eed5217a91fd2bd7f7626/pyarrow-24.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b196eb3f931862af3fa84c2a253514d859c08e0d8fe020e07be12e75a5a9780c", size = 51912986, upload-time = "2026-04-21T10:47:09.872Z" },
{ url = "https://files.pythonhosted.org/packages/b4/a9/9686d9f07837f91f775e8932659192e02c74f9d8920524b480b85212cc68/pyarrow-24.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:6233c9ed9ab9d1db47de57d9753256d9dcffbf42db341576099f0fd9f6bf4810", size = 34981559, upload-time = "2026-04-21T10:47:22.17Z" },
{ url = "https://files.pythonhosted.org/packages/80/b6/0ddf0e9b6ead3474ab087ae598c76b031fc45532bf6a63f3a553440fb258/pyarrow-24.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:f7616236ec1bc2b15bfdec22a71ab38851c86f8f05ff64f379e1278cf20c634a", size = 36663654, upload-time = "2026-04-21T10:47:28.315Z" },
{ url = "https://files.pythonhosted.org/packages/7c/3b/926382efe8ce27ba729071d3566ade6dfb86bdf112f366000196b2f5780a/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1617043b99bd33e5318ae18eb2919af09c71322ef1ca46566cdafc6e6712fb66", size = 45679394, upload-time = "2026-04-21T10:47:34.821Z" },
{ url = "https://files.pythonhosted.org/packages/b3/7a/829f7d9dfd37c207206081d6dad474d81dde29952401f07f2ba507814818/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6165461f55ef6314f026de6638d661188e3455d3ec49834556a0ebbdbace18bb", size = 48863122, upload-time = "2026-04-21T10:47:42.056Z" },
{ url = "https://files.pythonhosted.org/packages/5f/e8/f88ce625fe8babaae64e8db2d417c7653adb3019b08aae85c5ed787dc816/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3b13dedfe76a0ad2d1d859b0811b53827a4e9d93a0bcb05cf59333ab4980cc7e", size = 49376032, upload-time = "2026-04-21T10:47:48.967Z" },
{ url = "https://files.pythonhosted.org/packages/36/7a/82c363caa145fff88fb475da50d3bf52bb024f61917be5424c3392eaf878/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:25ea65d868eb04015cd18e6df2fbe98f07e5bda2abefabcb88fce39a947716f6", size = 51929490, upload-time = "2026-04-21T10:47:55.981Z" },
{ url = "https://files.pythonhosted.org/packages/6f/d3/a1abf004482026ddc17f4503db227787fa3cfe41ec5091ff20e4fea55e57/pyarrow-24.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:02b001b3ed4723caa44f6cd1af2d5c86aa2cf9971dacc2ffa55b21237713dfba", size = 34976759, upload-time = "2026-04-21T10:48:07.258Z" },
{ url = "https://files.pythonhosted.org/packages/4f/4a/34f0a36d28a2dd32225301b79daad44e243dc1a2bb77d43b60749be255c4/pyarrow-24.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:04920d6a71aabd08a0417709efce97d45ea8e6fb733d9ca9ecffb13c67839f68", size = 36658471, upload-time = "2026-04-21T10:48:13.347Z" },
{ url = "https://files.pythonhosted.org/packages/1f/78/543b94712ae8bb1a6023bcc1acf1a740fbff8286747c289cd9468fced2a5/pyarrow-24.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a964266397740257f16f7bb2e4f08a0c81454004beab8ff59dd531b73610e9f2", size = 45675981, upload-time = "2026-04-21T10:48:20.201Z" },
{ url = "https://files.pythonhosted.org/packages/84/9f/8fb7c222b100d314137fa40ec050de56cd8c6d957d1cfff685ce72f15b17/pyarrow-24.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6f066b179d68c413374294bc1735f68475457c933258df594443bb9d88ddc2a0", size = 48859172, upload-time = "2026-04-21T10:48:27.541Z" },
{ url = "https://files.pythonhosted.org/packages/a7/d3/1ea72538e6c8b3b475ed78d1049a2c518e655761ea50fe1171fc855fcab7/pyarrow-24.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1183baeb14c5f587b1ec52831e665718ce632caab84b7cd6b85fd44f96114495", size = 49385733, upload-time = "2026-04-21T10:48:34.7Z" },
{ url = "https://files.pythonhosted.org/packages/c3/be/c3d8b06a1ba35f2260f8e1f771abbee7d5e345c0937aab90675706b1690a/pyarrow-24.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:806f24b4085453c197a5078218d1ee08783ebbba271badd153d1ae22a3ee804f", size = 51934335, upload-time = "2026-04-21T10:48:42.099Z" },
{ url = "https://files.pythonhosted.org/packages/17/1a/cff3a59f80b5b1658549d46611b67163f65e0664431c076ad728bf9d5af4/pyarrow-24.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:1a4e45017efbf115032e4475ee876d525e0e36c742214fbe405332480ecd6275", size = 35238554, upload-time = "2026-04-21T10:48:48.526Z" },
{ url = "https://files.pythonhosted.org/packages/a8/99/cce0f42a327bfef2c420fb6078a3eb834826e5d6697bf3009fe11d2ad051/pyarrow-24.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:7986f1fa71cee060ad00758bcc79d3a93bab8559bf978fab9e53472a2e25a17b", size = 36782301, upload-time = "2026-04-21T10:48:55.181Z" },
{ url = "https://files.pythonhosted.org/packages/2a/66/8e560d5ff6793ca29aca213c53eec0dd482dd46cb93b2819e5aab52e4252/pyarrow-24.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:d3e0b61e8efb24ed38898e5cdc5fffa9124be480008d401a1f8071500494ae42", size = 45721929, upload-time = "2026-04-21T10:49:03.676Z" },
{ url = "https://files.pythonhosted.org/packages/27/0c/a26e25505d030716e078d9f16eb74973cbf0b33b672884e9f9da1c83b871/pyarrow-24.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:55a3bc1e3df3b5567b7d27ef551b2283f0c68a5e86f1cd56abc569da4f31335b", size = 48825365, upload-time = "2026-04-21T10:49:11.714Z" },
{ url = "https://files.pythonhosted.org/packages/5f/eb/771f9ecb0c65e73fe9dccdd1717901b9594f08c4515d000c7c62df573811/pyarrow-24.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:641f795b361874ac9da5294f8f443dfdbee355cf2bd9e3b8d97aaac2306b9b37", size = 49451819, upload-time = "2026-04-21T10:49:21.474Z" },
{ url = "https://files.pythonhosted.org/packages/48/da/61ae89a88732f5a785646f3ec6125dbb640fa98a540eb2b9889caa561403/pyarrow-24.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8adc8e6ce5fccf5dc707046ae4914fd537def529709cc0d285d37a7f9cd442ca", size = 51909252, upload-time = "2026-04-21T10:49:31.164Z" },
{ url = "https://files.pythonhosted.org/packages/ad/80/d022a34ff05d2cbedd8ccf841fc1f532ecfa9eb5ed1711b56d0e0ea71fc9/pyarrow-24.0.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:1cc9057f0319e26333b357e17f3c2c022f1a83739b48a88b25bfd5fa2dc18838", size = 35007997, upload-time = "2026-04-21T10:49:48.796Z" },
{ url = "https://files.pythonhosted.org/packages/1a/ff/f01485fda6f4e5d441afb8dd5e7681e4db18826c1e271852f5d3957d6a80/pyarrow-24.0.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:e6f1278ee4785b6db21229374a1c9e54ec7c549de5d1efc9630b6207de7e170b", size = 36678720, upload-time = "2026-04-21T10:49:55.858Z" },
{ url = "https://files.pythonhosted.org/packages/9e/c2/2d2d5fea814237923f71b36495211f20b43a1576f9a4d6da7e751a64ec6f/pyarrow-24.0.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:adbbedc55506cbdabb830890444fb856bfb0060c46c6f8026c6c2f2cf86ae795", size = 45741852, upload-time = "2026-04-21T10:50:04.624Z" },
{ url = "https://files.pythonhosted.org/packages/8e/3a/28ba9c1c1ebdbb5f1b94dfebb46f207e52e6a554b7fe4132540fde29a3a0/pyarrow-24.0.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:ae8a1145af31d903fa9bb166824d7abe9b4681a000b0159c9fb99c11bc11ad26", size = 48889852, upload-time = "2026-04-21T10:50:12.293Z" },
{ url = "https://files.pythonhosted.org/packages/df/51/4a389acfd31dca009f8fb82d7f510bb4130f2b3a8e18cf00194d0687d8ac/pyarrow-24.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d7027eba1df3b2069e2e8d80f644fa0918b68c46432af3d088ddd390d063ecde", size = 49445207, upload-time = "2026-04-21T10:50:20.677Z" },
{ url = "https://files.pythonhosted.org/packages/19/4b/0bab2b23d2ae901b1b9a03c0efd4b2d070256f8ce3fc43f6e58c167b2081/pyarrow-24.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e56a1ffe9bf7b727432b89104cc0849c21582949dd7bdcb34f17b2001a351a76", size = 51954117, upload-time = "2026-04-21T10:50:29.14Z" },
{ url = "https://files.pythonhosted.org/packages/79/4f/46a49a63f43526da895b1a45bbb51d5baf8e4d77159f8528fc3e5490007f/pyarrow-24.0.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:418e48ce50a45a6a6c73c454677203a9c75c966cb1e92ca3370959185f197a05", size = 35250387, upload-time = "2026-04-21T10:50:35.552Z" },
{ url = "https://files.pythonhosted.org/packages/a0/da/d5e0cd5ef00796922404806d5f00325cdadc3441ce2c13fe7115f2df9a64/pyarrow-24.0.0-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:2f16197705a230a78270cdd4ea8a1d57e86b2fdcbc34a1f6aebc72e65c986f9a", size = 36797102, upload-time = "2026-04-21T10:50:42.417Z" },
{ url = "https://files.pythonhosted.org/packages/34/c7/5904145b0a593a05236c882933d439b5720f0a145381179063722fbfc123/pyarrow-24.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:fb24ac194bfc5e86839d7dcd52092ee31e5fe6733fe11f5e3b06ef0812b20072", size = 45745118, upload-time = "2026-04-21T10:50:49.324Z" },
{ url = "https://files.pythonhosted.org/packages/13/d3/cca42fe166d1c6e4d5b80e530b7949104d10e17508a90ae202dac205ce2a/pyarrow-24.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:9700ebd9a51f5895ce75ff4ac4b3c47a7d4b42bc618be8e713e5d56bacf5f931", size = 48844765, upload-time = "2026-04-21T10:50:55.579Z" },
{ url = "https://files.pythonhosted.org/packages/b0/49/942c3b79878ba928324d1e17c274ed84581db8c0a749b24bcf4cbdf15bd3/pyarrow-24.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d8ddd2768da81d3ee08cfea9b597f4abb4e8e1dc8ae7e204b608d23a0d3ab699", size = 49471890, upload-time = "2026-04-21T10:51:02.439Z" },
{ url = "https://files.pythonhosted.org/packages/76/97/ff71431000a75d84135a1ace5ca4ba11726a231a8007bbb320a4c54075d5/pyarrow-24.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:61a3d7eaa97a14768b542f3d284dc6400dd2470d9f080708b13cd46b6ae18136", size = 51932250, upload-time = "2026-04-21T10:51:10.576Z" },
]
[[package]]
name = "pyasn1"
version = "0.6.3"
@@ -4668,17 +4772,6 @@ asyncio = [
{ name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "sqlite-vec"
version = "0.1.9"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/85/9fad0045d8e7c8df3e0fa5a56c630e8e15ad6e5ca2e6106fceb666aa6638/sqlite_vec-0.1.9-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:1b62a7f0a060d9475575d4e599bbf94a13d85af896bc1ce86ee80d1b5b48e5fb", size = 131171, upload-time = "2026-03-31T08:02:31.717Z" },
{ url = "https://files.pythonhosted.org/packages/a4/3d/3677e0cd2f92e5ebc43cd29fbf565b75582bff1ccfa0b8327c7508e1084f/sqlite_vec-0.1.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d52e30513bae4cc9778ddbf6145610434081be4c3afe57cd877893bad9f6b6c", size = 165434, upload-time = "2026-03-31T08:02:32.712Z" },
{ url = "https://files.pythonhosted.org/packages/00/d4/f2b936d3bdc38eadcbd2a87875815db36430fab0363182ba5d12cd8e0b51/sqlite_vec-0.1.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e921e592f24a5f9a18f590b6ddd530eb637e2d474e3b1972f9bbeb773aa3cb9", size = 160076, upload-time = "2026-03-31T08:02:33.796Z" },
{ url = "https://files.pythonhosted.org/packages/6f/ad/6afd073b0f817b3e03f9e37ad626ae341805891f23c74b5292818f49ac63/sqlite_vec-0.1.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:1515727990b49e79bcaf75fdee2ffc7d461f8b66905013231251f1c8938e7786", size = 163388, upload-time = "2026-03-31T08:02:34.888Z" },
]
[[package]]
name = "sqlparse"
version = "0.5.5"