Compare commits

..

19 Commits

Author SHA1 Message Date
stumpylog 3a891b38a8 Drops the search shims 2026-06-15 19:22:59 -07:00
Trenton H f4fa916579 Fix (beta): restore v2 (Whoosh) advanced-search query compatibility (#13010)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-15 15:32:44 -07:00
shamoon 75f0c4c92e Fix (beta): retry celery ping and report warning on no response (#13012) 2026-06-15 15:05:43 -07:00
Trenton H a020f64d08 Enhancement(beta): replace LanceDB vector store with sqlite-vec (#12990)
* Chore(beta): add sqlite-vec 0.1.9 dependency

Pinned exactly: the 0.1.9 wheels carry no baked SIMD flags (safe on
pre-AVX2 CPUs, the point of this migration); the 0.1.10 alphas bake
-mavx and would reintroduce the #12970 crash class.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Test(beta): port vector store tests to sqlite-vec backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Enhancement(beta): switch AI vector store from LanceDB to sqlite-vec

Fixes the non-AVX2 SIGILL class (#12970) at the root: lancedb is no
longer imported. sqlite-vec 0.1.9 wheels carry no baked SIMD, vec0
metadata columns give parameterized EQ/IN filtering, WAL preserves the
lock-free-reader model, and compact() rebuilds the table because vec0
DELETEs never reclaim space.

Implementation notes vs. the Task 3A draft:
- compact() uses a file-swap approach (new db file + Path.replace) rather
  than ALTER TABLE RENAME, which does not cascade to shadow tables in
  sqlite-vec 0.1.9 (upstream limitation).
- Bloat is tracked via a cumulative total_inserts counter in index_meta
  because the _rowids shadow table does not accumulate deleted rows in
  0.1.9 (contrary to the design doc assumption from #54).
- None distances from the zero-vector cosine edge case are mapped to
  similarity 0.0 rather than raising TypeError.
- Test suite updated accordingly: _bloat_ratio reads index_meta instead
  of _rowids; seed collision in force-compact test fixed (seed=100.0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Enhancement(beta): wire indexing pipeline to the sqlite-vec store

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Enhancement(beta): move filename/storage path/ASN to node metadata

Same treatment as title/tags/correspondent in #12944: excluded from
the embedded text, visible to the LLM via metadata prepend. Changes
embedded text for every document, so it ships inside the sqlite-vec
transition, whose forced rebuild re-embeds everything anyway.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Test(beta): cover legacy LanceDB index cleanup and forced rebuild

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Chore(beta): drop lancedb dependency

Fixes #12970: the package whose wheels SIGILL on non-AVX2 CPUs is no
longer installed at all.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Chore(beta): partial pyrefly cleanup on sqlite-vec vector store

- Add MetadataFilter import and isinstance guard in _build_where()
- Add query_embedding None guard in query()
- Fix dict.get() type-checker ambiguity in get_configured_model_name()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Chore(beta): drop automatic LanceDB index cleanup on startup

Leave legacy Lance directory removal to the user rather than deleting it
automatically on first run. Beta policy: user is expected to do a clean
re-embed anyway; no need for the system to silently delete their data.

Remove _cleanup_legacy_lance_index(), the forced-rebuild path that called
it, and the associated tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Chore(beta): ruff format pass on sqlite-vec AI files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Removes the benchmarking file

* Try to resolve or silence some semgrep.  But we're using SQL here, not an ORM and we control the inputs, not users

* Enhancement(beta): add schema migration machinery to sqlite-vec vector store

Adds versioned schema migration support modelled after PR #12968's LanceDB
approach, adapted for sqlite-vec's file-swap compaction pattern.

- SCHEMA_VERSION = 1 written to index_meta at table creation and preserved
  through compact()
- Migration dataclass with from_version, to_version, kind ("structural" or
  "re-embed"), description, and an optional apply(src, dst, dim) callable
- MIGRATIONS registry (empty at v1 baseline); add entries and bump
  SCHEMA_VERSION when the schema changes
- check_and_run_migrations(): structural migrations run via the same
  file-swap as compact() (no re-embed); re-embed migrations return True
  so the caller forces a full rebuild
- update_llm_index() calls check_and_run_migrations() under the write lock
  before any indexing work

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Chore(beta): deduplicate vector store internals via helper methods

Extract three helpers to remove copy-paste between compact() and
_run_structural_migration():
- _meta_set_on(conn, key, value): static upsert into any connection's
  index_meta; _meta_set() now delegates to it
- _create_vec_table(conn, dim): CREATE VIRTUAL TABLE DDL (carries the
  nosemgrep annotation)
- _swap_in_compact(compact_path, db_path): close/replace/reconnect
  sequence used by both file-swap callers

Also normalises compact() error-path cleanup to unlink(missing_ok=True).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Adds equality test and no covers some defensive error handling stuff

* Ensures an embed migration stops the migration chain, just in case

* Silence one kind right but not really semgrep

* Trims dead assignment

* Fix(beta): address Copilot review on sqlite-vec vector store

Three findings from the PR review:

- compact() failure cleanup now unlinks the temporary .compact-wal and
  .compact-shm files, matching _run_structural_migration(); previously
  only the main .compact file was removed.
- _build_where() fails closed (1 = 0) when filters are requested but none
  translate, instead of emitting "()" which is invalid SQL; filters scope
  document access, so an empty translation must match no rows.
- Drop the unused table_name constructor parameter (all SQL hardcodes
  DEFAULT_TABLE_NAME) and its callers in indexing.py.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* Enhancement(beta): guard sqlite-vec compaction swap against concurrent readers

The compaction/migration file swap replaces the database via os.replace,
but the -wal/-shm files are keyed by path, not inode. A reader holding an
open connection across the swap leaves the old WAL aliased onto the new
file; a subsequent write then corrupts the database (reproduced via
PRAGMA integrity_check).

Add a cross-process read/write lock (filelock.ReadWriteLock) over the
index:

- read_store() holds it shared for the whole connection lifetime (and
  closes the connection on exit); concurrent readers do not block.
- compaction and the migration check run under an exclusive lock that
  drains readers, and skip with an info log on Timeout (maintenance op,
  retries next run).
- Normal writes are untouched: WAL gives reader/writer concurrency and
  LLM_INDEX_LOCK still serializes writers, so they never block readers.

load_or_build_index() now takes the store from the caller's read_store()
so the lock and connection span the whole retrieval; chat holds it across
the streamed response. Two new settings: LLM_INDEX_RWLOCK and
LLM_INDEX_COMPACTION_LOCK_TIMEOUT.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* Ensures the store alays cleans up SQLite connections for any operations, even on errors

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-15 13:20:41 -07:00
Yuki MIZUNO 11fb09e4f4 Fix (beta): don't send chat message on Enter while composing with IME (CJK) (#12999)
Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
2026-06-13 13:48:19 +00:00
Trenton H 8ed4bf2011 Fix: Apply unicode normalization to all paths and path components (#12993) 2026-06-13 12:45:54 +00:00
Trenton H 92c016ce47 Fix: Handle the UTF 16 and BOM text files better (#12994) 2026-06-13 05:35:38 -07:00
shamoon fb3816486c Fix (beta): avoid DRF update calling save on all fields (#12992) 2026-06-12 11:14:26 -07:00
Trenton H 4394403beb Fix: release pooled DB connection during AI LLM/embedding calls (#12983) 2026-06-11 13:07:31 -07:00
Trenton H f188d308eb Fix: health-check pooled DB connections and close the pool on worker shutdown (#12977) 2026-06-11 05:49:10 -07:00
shamoon a5d6ff5f15 Fix: wrap long titles in delete confirm dialog (#12973) 2026-06-10 06:56:02 -07:00
shamoon 8405f66e38 Fix (beta): fix re-ordering in merge dialog (#12967) 2026-06-09 07:03:44 -07:00
shamoon c3459d8f62 Fix (beta): move task filtering to backend fully (#12956) 2026-06-07 22:45:15 +00:00
shamoon 6f8e39c2e0 Fix: avoid unnecessary creating new PDF with pw removal workflow (#12948) 2026-06-07 20:30:08 +00:00
Trenton H eb292baa69 Enhancement (beta): Switch the AI vector store to LanceDB (#12944)
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-authored-by: shamoon <shamoon@users.noreply.github.com>
2026-06-07 11:31:26 -07:00
shamoon 3d0b8343b9 Fixhancement (beta): tasks dismiss all (#12949) 2026-06-07 03:42:06 +00:00
shamoon a7cec673bb Fix (beta): correct chat message bg color (#12955) 2026-06-06 16:00:03 -07:00
shamoon 449fd97b1f Fix (beta): respect disable state for suggest endpoint, require change perms (#12942) 2026-06-05 14:16:53 +00:00
Trenton H fa0c4368d7 Fix: Ensure checksum comparison is using SHA256 in file handling (#12939) 2026-06-05 06:46:45 -07:00
87 changed files with 5777 additions and 1943 deletions
+1 -1
View File
@@ -61,7 +61,7 @@ def replace_with_symlinks(
total_duplicates = 0
space_saved = 0
for file_list in duplicate_groups.values():
for file_hash, file_list in duplicate_groups.items():
# Keep the first file as the original, replace others with symlinks
original_file = file_list[0]
duplicates = file_list[1:]
+2 -8
View File
@@ -42,7 +42,6 @@ dependencies = [
"drf-spectacular~=0.28",
"drf-spectacular-sidecar~=2026.5.1",
"drf-writable-nested~=0.7.1",
"faiss-cpu>=1.10",
"filelock~=3.29.0",
"flower~=2.0.1",
"gotenberg-client~=0.14.0",
@@ -57,7 +56,6 @@ dependencies = [
"llama-index-embeddings-openai-like>=0.2.2",
"llama-index-llms-ollama>=0.9.1",
"llama-index-llms-openai-like>=0.7.1",
"llama-index-vector-stores-faiss>=0.5.2",
"nltk~=3.9.1",
"ocrmypdf~=17.4.2",
"openai>=2.32",
@@ -74,6 +72,7 @@ dependencies = [
"scikit-learn~=1.8.0",
"sentence-transformers>=5.4.1",
"setproctitle~=1.3.4",
"sqlite-vec==0.1.9",
"tantivy~=0.26.0",
"tika-client~=0.11.0",
"torch~=2.11.0",
@@ -185,16 +184,12 @@ line-ending = "lf"
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
extend-select = [
"B", # https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"COM", # https://docs.astral.sh/ruff/rules/#flake8-commas-com
"DTZ", # https://docs.astral.sh/ruff/rules/#flake8-datetimez-dtz
"PERF", # https://docs.astral.sh/ruff/rules/#perflint-perf
"S324", # https://docs.astral.sh/ruff/rules/hashlib-insecure-hash-functions/
"DJ", # https://docs.astral.sh/ruff/rules/#flake8-django-dj
"EXE", # https://docs.astral.sh/ruff/rules/#flake8-executable-exe
"FBT", # https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt
"FLY", # https://docs.astral.sh/ruff/rules/#flynt-fly
"G", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
"G201", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"ICN", # https://docs.astral.sh/ruff/rules/#flake8-import-conventions-icn
"INP", # https://docs.astral.sh/ruff/rules/#flake8-no-pep420-inp
@@ -215,7 +210,6 @@ extend-select = [
]
ignore = [
"DJ001",
"G004", # f-strings in logging: accepted style in this codebase
"PLC0415",
"RUF012",
"SIM105",
@@ -11,6 +11,9 @@
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
</button>
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
</button>
<div class="form-check form-switch mb-0 ms-2">
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
@@ -81,7 +84,7 @@
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
<div class="dropdown-menu shadow" ngbDropdownMenu>
@for (t of filterTargets; track t.id) {
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
}
</div>
</div>
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
import { RouterTestingModule } from '@angular/router/testing'
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { throwError } from 'rxjs'
import { of, throwError } from 'rxjs'
import { routes } from 'src/app/app-routing.module'
import {
PaperlessTask,
@@ -29,7 +29,11 @@ import { ToastService } from 'src/app/services/toast.service'
import { environment } from 'src/environments/environment'
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
import { TasksComponent, TaskSection } from './tasks.component'
import {
TaskFilterTargetID,
TasksComponent,
TaskSection,
} from './tasks.component'
const tasks: PaperlessTask[] = [
{
@@ -154,6 +158,13 @@ const paginatedTasks: Results<PaperlessTask> = {
results: tasks,
}
const sectionCountResponse = {
all: 7,
needs_attention: 2,
in_progress: 3,
completed: 2,
}
describe('TasksComponent', () => {
let component: TasksComponent
let fixture: ComponentFixture<TasksComponent>
@@ -221,6 +232,15 @@ describe('TasksComponent', () => {
req.params.get('page') === '1'
)
.flush(paginatedTasks)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
!req.params.has('status')
)
.flush(sectionCountResponse)
})
it('should display task sections with counts', () => {
@@ -295,6 +315,7 @@ describe('TasksComponent', () => {
const headerText = header.nativeElement.textContent
expect(headerText).toContain('Dismiss visible')
expect(headerText).toContain('Dismiss all')
expect(headerText).toContain('Auto refresh')
expect(headerText).not.toContain('All types')
expect(headerText).not.toContain('All sources')
@@ -327,6 +348,74 @@ describe('TasksComponent', () => {
expect(pagination).not.toBeNull()
})
it('should apply the selected section to the server-side task query', () => {
component.setSection(TaskSection.NeedsAttention)
const req = httpTestingController.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page') === '1' &&
request.params.get('page_size') === '25' &&
request.params.get('acknowledged') === 'false' &&
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
)
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
expect(component.totalTasks).toBe(2)
})
it('should apply task type and trigger source filters to the server-side task query', () => {
component.setTaskType(PaperlessTaskType.SanityCheck)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck
)
.flush({ count: 1, results: [tasks[6]] })
component.setTriggerSource(PaperlessTaskTriggerSource.System)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
request.params.get('trigger_source') ===
PaperlessTaskTriggerSource.System
)
.flush({ count: 1, results: [tasks[6]] })
})
it('should apply text filters to the server-side task query', () => {
component.filterText = 'invoice'
jest.advanceTimersByTime(150)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('name') === 'invoice'
)
.flush({ count: 1, results: [tasks[0]] })
component.setFilterTarget(TaskFilterTargetID.Result)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('result') === 'invoice'
)
.flush({ count: 0, results: [] })
})
it('should load a different task page when pagination changes', () => {
component.setPage(2)
@@ -350,6 +439,27 @@ describe('TasksComponent', () => {
expect(component.pagedTasks).toEqual([tasks[0]])
})
it('should not replace section counts with current-page counts', () => {
component.setPage(2)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '25' &&
req.params.get('page') === '2'
)
.flush({
count: 30,
results: [tasks[0]],
})
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
})
it('should expose stable task type options and disable empty ones', () => {
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
PaperlessTaskType.TrainClassifier
@@ -495,6 +605,46 @@ describe('TasksComponent', () => {
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
})
it('should support dismiss all tasks', () => {
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(of({}))
const reloadPageSpy = jest
.spyOn(component as any, 'reloadPage')
.mockImplementation(() => undefined)
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(reloadPageSpy).toHaveBeenCalledWith(false)
expect(component.selectedTasks.size).toBe(0)
})
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
const error = new Error('dismiss all failed')
const toastSpy = jest.spyOn(toastService, 'showError')
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(throwError(() => error))
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
expect(modal.componentInstance.buttonsEnabled).toBe(true)
})
it('should dismiss the currently visible scoped and filtered tasks', () => {
component.setSection(TaskSection.InProgress)
component.setTaskType(PaperlessTaskType.SanityCheck)
@@ -673,6 +823,9 @@ describe('TasksComponent', () => {
})
it('should keep clearing selection independent from resetting filters', () => {
component.resetFilter()
expect(component.filterText).toBe('')
component.setTaskType(PaperlessTaskType.ConsumeFile)
component.toggleSelected(tasks[0])
expect(component.selectedTasks.size).toBe(1)
@@ -40,7 +40,7 @@ export enum TaskSection {
Completed = 'completed',
}
enum TaskFilterTargetID {
export enum TaskFilterTargetID {
Name,
Result,
}
@@ -167,6 +167,12 @@ export class TasksComponent
public readonly pageSize = 25
public page: number = 1
public totalTasks: number = 0
public sectionCounts: Record<TaskSection, number> = {
[TaskSection.All]: 0,
[TaskSection.NeedsAttention]: 0,
[TaskSection.InProgress]: 0,
[TaskSection.Completed]: 0,
}
public pagedTasks: PaperlessTask[] = []
public selectedSection: TaskSection = TaskSection.All
public selectedTaskType: PaperlessTaskType | null = null
@@ -282,6 +288,7 @@ export class TasksComponent
.subscribe((query) => {
this._filterText = query
this.clearSelection()
this.reloadPage(true)
})
}
@@ -334,6 +341,30 @@ export class TasksComponent
}
}
dismissAllTasks() {
let modal = this.modalService.open(ConfirmDialogComponent, {
backdrop: 'static',
})
modal.componentInstance.title = $localize`Confirm Dismiss All`
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
modal.componentInstance.btnClass = 'btn-warning'
modal.componentInstance.btnCaption = $localize`Dismiss`
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
modal.componentInstance.buttonsEnabled = false
modal.close()
this.tasksService.dismissAllTasks().subscribe({
next: () => {
this.reloadPage(false)
},
error: (e) => {
this.toastService.showError($localize`Error dismissing tasks`, e)
modal.componentInstance.buttonsEnabled = true
},
})
this.clearSelection()
})
}
expandTask(task: PaperlessTask) {
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
}
@@ -446,9 +477,7 @@ export class TasksComponent
}
sectionCount(section: TaskSection): number {
return this.pagedTasks.filter((task) =>
this.taskBelongsToSection(task, section)
).length
return this.sectionCounts[section]
}
sectionShowsResults(section: TaskSection): boolean {
@@ -458,16 +487,27 @@ export class TasksComponent
setSection(section: TaskSection) {
this.selectedSection = section
this.clearSelection()
this.reloadPage(true)
}
setTaskType(taskType: PaperlessTaskType | null) {
this.selectedTaskType = taskType
this.clearSelection()
this.reloadPage(true)
}
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
this.selectedTriggerSource = triggerSource
this.clearSelection()
this.reloadPage(true)
}
setFilterTarget(filterTargetID: TaskFilterTargetID) {
this.filterTargetID = filterTargetID
if (this._filterText.length) {
this.clearSelection()
this.reloadPage(true)
}
}
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
@@ -505,19 +545,32 @@ export class TasksComponent
}
public resetFilter() {
if (!this._filterText.length) {
return
}
this._filterText = ''
this.clearSelection()
this.reloadPage(true)
}
public resetFilters() {
const hadFilter = this.isFiltered
this.selectedTaskType = null
this.selectedTriggerSource = null
this.resetFilter()
this._filterText = ''
this.clearSelection()
if (hadFilter) {
this.reloadPage(true)
}
}
filterInputKeyup(event: KeyboardEvent) {
if (event.key == 'Enter') {
this._filterText = (event.target as HTMLInputElement).value
this.clearSelection()
this.reloadPage(true)
} else if (event.key === 'Escape') {
this.resetFilter()
}
@@ -606,19 +659,86 @@ export class TasksComponent
)
}
private reloadSectionCounts() {
this.tasksService
.statusCounts(this.getParamsForSection(TaskSection.All))
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe((counts) => {
this.sectionCounts[TaskSection.All] = counts.all
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
this.sectionCounts[TaskSection.Completed] = counts.completed
})
}
private getParamsForSection(
section: TaskSection
): Record<string, string | number | boolean | readonly string[]> {
const params: Record<
string,
string | number | boolean | readonly string[]
> = {
acknowledged: false,
}
const statuses = this.statusesForSection(section)
if (statuses.length) {
params.status = statuses
}
if (this.selectedTaskType !== null) {
params.task_type = this.selectedTaskType
}
if (this.selectedTriggerSource !== null) {
params.trigger_source = this.selectedTriggerSource
}
if (this._filterText.length) {
params[
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
] = this._filterText
}
return params
}
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
switch (section) {
case TaskSection.NeedsAttention:
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
case TaskSection.InProgress:
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
case TaskSection.Completed:
return [PaperlessTaskStatus.Success]
default:
return []
}
}
private reloadPage(resetToFirstPage: boolean = false) {
if (resetToFirstPage) {
this.page = 1
}
this.reloadSectionCounts()
this.loading = true
this.tasksService
.list(this.page, this.pageSize, { acknowledged: false })
.list(
this.page,
this.pageSize,
this.getParamsForSection(this.selectedSection)
)
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe({
next: (result) => {
this.pagedTasks = result.results
this.totalTasks = result.count
this.sectionCounts[TaskSection.All] = result.count
if (this.selectedSection !== TaskSection.All) {
this.sectionCounts[this.selectedSection] = result.count
}
this.loading = false
if (
this.page > 1 &&
@@ -8,7 +8,7 @@
<div class="chat-messages font-monospace small">
@for (message of messages; track message) {
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
<span>
{{ message.content }}
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
@@ -188,4 +188,14 @@ describe('ChatComponent', () => {
component.searchInputKeyDown(event)
expect(component.sendMessage).toHaveBeenCalled()
})
it('should not send message on Enter key press while composing with IME', () => {
jest.spyOn(component, 'sendMessage')
const event = new KeyboardEvent('keydown', {
key: 'Enter',
isComposing: true,
})
component.searchInputKeyDown(event)
expect(component.sendMessage).not.toHaveBeenCalled()
})
})
@@ -155,7 +155,10 @@ export class ChatComponent implements OnInit {
}
public searchInputKeyDown(event: KeyboardEvent) {
if (event.key === 'Enter') {
if (
event.key === 'Enter' &&
!(event.isComposing || event.keyCode === 229)
) {
event.preventDefault()
this.sendMessage()
}
@@ -5,10 +5,10 @@
</div>
<div class="modal-body">
@if (messageBold) {
<p><b>{{messageBold}}</b></p>
<p class="text-break"><b>{{messageBold}}</b></p>
}
@if (message) {
<p class="mb-0" [innerHTML]="message"></p>
<p class="mb-0 text-break" [innerHTML]="message"></p>
}
</div>
<div class="modal-footer">
@@ -9,8 +9,11 @@
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
<ul class="list-group"
cdkDropList
[cdkDropListData]="documentIDs"
(cdkDropListDropped)="onDrop($event)">
@for (document of documents; track document.id) {
@for (documentID of documentIDs; track documentID) {
@let document = getDocument(documentID);
@if (document) {
<li class="list-group-item d-flex align-items-center" cdkDrag>
<i-bs name="grip-vertical" class="me-2"></i-bs>
<div class="d-flex flex-column">
@@ -27,6 +30,7 @@
</small>
</div>
</li>
}
}
</ul>
</div>
@@ -1,5 +1,5 @@
<div class="btn-group">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
@if (loading) {
<div class="spinner-border spinner-border-sm" role="status"></div>
} @else {
@@ -13,7 +13,7 @@
@if (aiEnabled) {
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<span class="visually-hidden" i18n>Show suggestions</span>
</button>
@@ -37,6 +37,18 @@ describe('SuggestionsDropdownComponent', () => {
expect(component.getSuggestions.emit).toHaveBeenCalled()
})
it('should not emit getSuggestions when disabled', () => {
jest.spyOn(component.getSuggestions, 'emit')
component.disabled = true
component.suggestions = null
fixture.detectChanges()
component.clickSuggest()
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
})
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
component.aiEnabled = true
fixture.detectChanges()
@@ -47,6 +47,14 @@ export class SuggestionsDropdownComponent {
addCorrespondent: EventEmitter<string> = new EventEmitter()
public clickSuggest(): void {
if (
this.disabled ||
this.loading ||
(this.suggestions && !this.aiEnabled)
) {
return
}
if (!this.suggestions) {
this.getSuggestions.emit(this)
} else {
@@ -131,7 +131,9 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
}
</button>
<ng-template #celeryStatus>
+7
View File
@@ -64,3 +64,10 @@ export interface PaperlessTaskSummary {
last_success: Date | null
last_failure: Date | null
}
export interface PaperlessTaskStatusCounts {
all: number
needs_attention: number
in_progress: number
completed: number
}
@@ -80,6 +80,27 @@ describe('TasksService', () => {
.flush({ count: 0, results: [] })
})
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
tasksService.dismissAllTasks().subscribe()
const req = httpTestingController.expectOne(
`${environment.apiBaseUrl}tasks/acknowledge/`
)
expect(req.request.method).toEqual('POST')
expect(req.request.body).toEqual({
all: true,
})
req.flush([])
// reload is then called
httpTestingController
.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '1000'
)
.flush({ count: 0, results: [] })
})
it('groups mixed task types by status when reloading', () => {
expect(tasksService.total).toEqual(0)
const mockTasks = [
@@ -221,4 +242,34 @@ describe('TasksService', () => {
task_id: 'abc-123',
})
})
it('loads filtered task status counts', () => {
tasksService
.statusCounts({
acknowledged: false,
task_type: PaperlessTaskType.ConsumeFile,
})
.subscribe((res) => {
expect(res).toEqual({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
const req = httpTestingController.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
)
expect(req.request.method).toEqual('GET')
req.flush({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
})
+27 -1
View File
@@ -5,6 +5,7 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
import {
PaperlessTask,
PaperlessTaskStatus,
PaperlessTaskStatusCounts,
PaperlessTaskType,
} from 'src/app/data/paperless-task'
import { Results } from 'src/app/data/results'
@@ -88,7 +89,7 @@ export class TasksService {
public list(
page: number,
pageSize: number,
extraParams?: Record<string, string | number | boolean>
extraParams?: Record<string, string | number | boolean | readonly string[]>
): Observable<Results<PaperlessTask>> {
return this.http.get<Results<PaperlessTask>>(
`${this.baseUrl}${this.endpoint}/`,
@@ -102,6 +103,17 @@ export class TasksService {
)
}
public statusCounts(
extraParams?: Record<string, string | number | boolean | readonly string[]>
): Observable<PaperlessTaskStatusCounts> {
return this.http.get<PaperlessTaskStatusCounts>(
`${this.baseUrl}${this.endpoint}/status_counts/`,
{
params: extraParams,
}
)
}
public dismissTasks(task_ids: Set<number>): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
@@ -116,6 +128,20 @@ export class TasksService {
)
}
public dismissAllTasks(): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
all: true,
})
.pipe(
first(),
takeUntil(this.unsubscribeNotifer),
tap(() => {
this.reload()
})
)
}
public cancelPending(): void {
this.unsubscribeNotifer.next(true)
}
+13
View File
@@ -904,6 +904,19 @@ def remove_password(
doc.id,
pair.source_doc.source_path,
)
try:
with pikepdf.open(source_path) as pdf:
if not pdf.is_encrypted:
logger.info(
"Skipping password removal for document %s because the "
"source PDF is not encrypted",
pair.root_doc.id,
)
continue
except pikepdf.PasswordError:
# Password-protected PDFs need the supplied password below.
pass
with pikepdf.open(source_path, password=password) as pdf:
filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
+2 -3
View File
@@ -834,9 +834,8 @@ class ConsumerPlugin(
self.log.debug(f"Creation date from parse_date: {create_date}")
else:
stats = Path(self.input_doc.original_file).stat()
create_date = datetime.datetime.fromtimestamp(
stats.st_mtime,
tz=datetime.UTC,
create_date = timezone.make_aware(
datetime.datetime.fromtimestamp(stats.st_mtime),
)
self.log.debug(f"Creation date from st_mtime: {create_date}")
+4 -4
View File
@@ -1,3 +1,4 @@
import datetime as dt
import logging
import os
import shutil
@@ -5,7 +6,6 @@ from pathlib import Path
from typing import Final
from django.conf import settings
from django.utils import timezone
from pikepdf import Pdf
from documents.consumer import ConsumerError
@@ -78,7 +78,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
stats = staging.stat()
# if the file is older than the timeout, we don't consider
# it valid
if (timezone.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
if (dt.datetime.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
logger.warning("Outdated double sided staging file exists, deleting it")
staging.unlink()
else:
@@ -99,7 +99,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
"two uploaded files don't belong to the same double-"
"sided scan. Please retry, starting with the odd "
"numbered pages again.",
) from None
)
# Merged file has the same path, but without the
# double-sided subdir. Therefore, it is also in the
# consumption dir and will be picked up for processing
@@ -134,7 +134,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
shutil.move(pdf_file, staging)
# update access to modification time so we know if the file
# is outdated when another file gets uploaded
timestamp = timezone.now().timestamp()
timestamp = dt.datetime.now().timestamp()
os.utime(staging, (timestamp, timestamp))
logger.info(
"Got scan with odd numbered pages of double-sided scan, moved it to %s",
+67 -5
View File
@@ -28,6 +28,7 @@ from django.db.models.functions import Cast
from django.utils.translation import gettext_lazy as _
from django_filters import DateFilter
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import CharFilter
from django_filters.rest_framework import DateTimeFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
@@ -350,7 +351,7 @@ def handle_validation_prefix(func: Callable):
try:
return func(*args, **kwargs)
except serializers.ValidationError as e:
raise serializers.ValidationError({validation_prefix: e.detail}) from e
raise serializers.ValidationError({validation_prefix: e.detail})
# Update the signature to include the validation_prefix argument
old_sig = inspect.signature(func)
@@ -461,7 +462,7 @@ class CustomFieldQueryParser:
except json.JSONDecodeError:
raise serializers.ValidationError(
{self._validation_prefix: [_("Value must be valid JSON.")]},
) from None
)
return (
self._parse_expr(expr, validation_prefix=self._validation_prefix),
self._annotations,
@@ -589,7 +590,7 @@ class CustomFieldQueryParser:
except CustomField.DoesNotExist:
raise serializers.ValidationError(
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
) from None
)
self._custom_fields[custom_field.id] = custom_field
self._custom_fields[custom_field.name] = custom_field
return custom_field
@@ -900,6 +901,16 @@ class ShareLinkBundleFilterSet(FilterSet):
class PaperlessTaskFilterSet(FilterSet):
name = CharFilter(
method="filter_name",
label="Name",
)
result = CharFilter(
method="filter_result",
label="Result",
)
task_type = MultipleChoiceFilter(
choices=PaperlessTask.TaskType.choices,
label="Task Type",
@@ -939,7 +950,58 @@ class PaperlessTaskFilterSet(FilterSet):
class Meta:
model = PaperlessTask
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
fields = [
"task_type",
"trigger_source",
"status",
"acknowledged",
"owner",
"name",
"result",
]
def filter_name(self, queryset, name, value):
if not value:
return queryset
matching_task_types = [
task_type
for task_type, label in PaperlessTask.TaskType.choices
if value.lower() in str(label).lower()
]
matching_trigger_sources = [
trigger_source
for trigger_source, label in PaperlessTask.TriggerSource.choices
if value.lower() in str(label).lower()
]
return queryset.filter(
Q(input_data__filename__icontains=value)
| Q(task_type__in=matching_task_types)
| Q(trigger_source__in=matching_trigger_sources),
)
def filter_result(self, queryset, name, value):
if not value:
return queryset
query = Q(result_data__reason__icontains=value) | Q(
result_data__error_message__icontains=value,
)
try:
numeric_value = int(value)
except (TypeError, ValueError):
pass
else:
query |= Q(result_data__document_id=numeric_value) | Q(
result_data__duplicate_of=numeric_value,
)
if "duplicate" in value.lower():
query |= Q(result_data__duplicate_of__isnull=False)
return queryset.filter(query)
def filter_is_complete(self, queryset, name, value):
if value:
@@ -988,7 +1050,7 @@ class DocumentsOrderingFilter(OrderingFilter):
except CustomField.DoesNotExist:
raise serializers.ValidationError(
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
) from None
)
annotation = None
match field.data_type:
@@ -480,7 +480,7 @@ class Command(CryptMixin, PaperlessCommand):
}
# 3. Export files from each document
for _, document_dict in enumerate(
for index, document_dict in enumerate(
self.track(
document_manifest,
description="Exporting documents...",
@@ -2,6 +2,7 @@ from typing import Any
from documents.management.commands.base import PaperlessCommand
from documents.tasks import llmindex_index
from paperless_ai.indexing import llm_index_compact
class Command(PaperlessCommand):
@@ -12,9 +13,12 @@ class Command(PaperlessCommand):
def add_arguments(self, parser: Any) -> None:
super().add_arguments(parser)
parser.add_argument("command", choices=["rebuild", "update"])
parser.add_argument("command", choices=["rebuild", "update", "compact"])
def handle(self, *args: Any, **options: Any) -> None:
if options["command"] == "compact":
llm_index_compact()
return
llmindex_index(
rebuild=options["command"] == "rebuild",
iter_wrapper=lambda docs: self.track(
@@ -133,14 +133,11 @@ def _build_suggestion_table(
else:
doc_cell = Text(f"{doc} [{doc.pk}]")
tag_parts: list[str] = [
f"[green]+{tag.name}[/green]"
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name)
]
tag_parts.extend(
f"[red]-{tag.name}[/red]"
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name)
)
tag_parts: list[str] = []
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name):
tag_parts.append(f"[green]+{tag.name}[/green]")
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name):
tag_parts.append(f"[red]-{tag.name}[/red]")
tag_cell = Text.from_markup(", ".join(tag_parts)) if tag_parts else Text("-")
table.add_row(
+3 -3
View File
@@ -369,7 +369,7 @@ class Document(SoftDeleteModel, ModelWithOwner): # type: ignore[django-manager-
If the queryset already annotated ``effective_content``, that value is used.
"""
if hasattr(self, "effective_content"):
return self.effective_content
return getattr(self, "effective_content")
if self.root_document_id is not None or self.pk is None:
return self.content
@@ -1204,8 +1204,8 @@ class CustomFieldInstance(SoftDeleteModel):
def get_value_field_name(cls, data_type: CustomField.FieldDataType):
try:
return cls.TYPE_TO_DATA_STORE_NAME_MAP[data_type]
except KeyError as exc: # pragma: no cover
raise NotImplementedError(data_type) from exc
except KeyError: # pragma: no cover
raise NotImplementedError(data_type)
@property
def value(self):
+1 -1
View File
@@ -110,7 +110,7 @@ def run_convert(
args += ["-define", "pdf:use-cropbox=true"] if use_cropbox else []
args += [str(input_file), str(output_file)]
logger.debug("Execute: %s", " ".join(args), extra={"group": logging_group})
logger.debug("Execute: " + " ".join(args), extra={"group": logging_group})
try:
run_subprocess(args, environment, logger)
+2 -1
View File
@@ -67,7 +67,8 @@ class DateParserPluginBase(ABC):
Subclasses can override this to release resources.
"""
return None
# Default implementation does nothing.
# Returning None implies exceptions are propagated.
def _parse_string(
self,
+4
View File
@@ -8,11 +8,15 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+3 -7
View File
@@ -195,12 +195,12 @@ class WriteBatch:
try:
self._lock.acquire(timeout=self._lock_timeout)
break
except filelock.Timeout as exc:
except filelock.Timeout:
if attempt == _LOCK_RETRY_ATTEMPTS - 1:
raise SearchIndexLockError(
f"Could not acquire index lock after {_LOCK_RETRY_ATTEMPTS} "
f"attempts (timeout={self._lock_timeout}s each)",
) from exc
)
sleep_s = random.uniform(
0,
min(_LOCK_BACKOFF_CAP, _LOCK_BACKOFF_BASE * (2**attempt)),
@@ -651,11 +651,7 @@ class TantivyBackend:
result_ids = cast("list[int]", searcher.fast_field_values("id", result_addrs))
addr_by_id: dict[int, tuple[float, tantivy.DocAddress]] = {
doc_id: (score, addr)
for (score, addr), doc_id in zip(
batch_results.hits,
result_ids,
strict=False,
)
for (score, addr), doc_id in zip(batch_results.hits, result_ids)
}
snippet_generator = None
+163
View File
@@ -0,0 +1,163 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+15 -440
View File
@@ -1,88 +1,28 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
import logging
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -117,379 +57,6 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
) from None
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
) from None
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
) from None
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC — matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
) from None
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (year range rewrite timed out)",
) from None
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Args:
query: Raw user query string
tz: Timezone for converting local date boundaries to UTC
Returns:
Query with date syntax rewritten to ISO 8601 ranges
Note:
Bare keywords without field prefixes pass through unchanged.
"""
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
) from None
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar → tag:foo AND tag:bar
- multiple spaces → single spaces
Args:
query: Query string after date rewriting
Returns:
Normalized query string ready for Tantivy parsing
"""
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (normalization timed out)",
) from None
def build_permission_filter(
schema: tantivy.Schema,
user: AbstractBaseUser,
@@ -607,8 +174,16 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
exact = index.parse_query(
query_str,
+566
View File
@@ -0,0 +1,566 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
+87 -32
View File
@@ -48,6 +48,7 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied
from rest_framework.fields import SerializerMethodField
from rest_framework.filters import OrderingFilter
from rest_framework.utils import model_meta
if settings.AUDIT_LOG_ENABLED:
from auditlog.context import set_actor
@@ -121,6 +122,45 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
self.fields.pop(field_name)
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
def _get_update_fields(self, validated_data) -> list[str]:
model_fields = {
field.name
for field in self.Meta.model._meta.concrete_fields
if field.name not in self.stale_update_excluded_fields
}
update_fields = [
field_name for field_name in validated_data if field_name in model_fields
]
if "modified" in model_fields and "modified" not in update_fields:
update_fields.append("modified")
return update_fields
def update(self, instance, validated_data):
serializers.raise_errors_on_nested_writes("update", self, validated_data)
info = model_meta.get_field_info(instance)
m2m_fields = []
for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many:
m2m_fields.append((attr, value))
else:
setattr(instance, attr, value)
# File names are managed by post-save file handling. Saving only the
# serializer-updated fields prevents stale in-memory path values from
# overwriting a concurrent move.
instance.save(update_fields=self._get_update_fields(validated_data))
for attr, value in m2m_fields:
field = getattr(instance, attr)
field.set(value)
return instance
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
document_count = serializers.IntegerField(read_only=True)
@@ -163,7 +203,7 @@ class MatchingModelSerializer(serializers.ModelSerializer[Any]):
logger.debug(f"Invalid regular expression: {e!s}")
raise serializers.ValidationError(
"Invalid regular expression, see log for details.",
) from None
)
return match
@@ -867,9 +907,7 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
try:
value_int = int(data["value"])
except (TypeError, ValueError):
raise serializers.ValidationError(
"Enter a valid integer.",
) from None
raise serializers.ValidationError("Enter a valid integer.")
# Keep values within the PostgreSQL integer range
MinValueValidator(-2147483648)(value_int)
MaxValueValidator(2147483647)(value_int)
@@ -901,7 +939,7 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
except Exception:
raise serializers.ValidationError(
f"Value must be an id of an element in {select_options}",
) from None
)
elif field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
if not (isinstance(data["value"], list) or data["value"] is None):
raise serializers.ValidationError(
@@ -991,7 +1029,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
class DocumentSerializer(
OwnedObjectSerializer,
NestedUpdateMixin,
DynamicFieldsModelSerializer,
DocumentUpdateFieldsModelSerializer,
):
correspondent = CorrespondentField(allow_null=True)
tags = TagsField(many=True)
@@ -1092,7 +1130,7 @@ class DocumentSerializer(
def to_representation(self, instance):
doc = super().to_representation(instance)
if "content" in self.fields and hasattr(instance, "effective_content"):
doc["content"] = instance.effective_content or ""
doc["content"] = getattr(instance, "effective_content") or ""
if self.truncate_content and "content" in self.fields:
doc["content"] = doc.get("content")[0:550]
return doc
@@ -1130,10 +1168,9 @@ class DocumentSerializer(
return super().validate(attrs)
def update(self, instance: Document, validated_data):
if "created_date" in validated_data and "created" not in validated_data:
instance.created = validated_data.get("created_date")
instance.save()
if "created_date" in validated_data:
if "created" not in validated_data:
validated_data["created"] = validated_data["created_date"]
logger.warning(
"created_date is deprecated, use created instead",
)
@@ -1203,11 +1240,13 @@ class DocumentSerializer(
for tag in instance.tags.all()
if tag not in inbox_tags_not_being_added
]
if settings.AUDIT_LOG_ENABLED:
with set_actor(self.user):
super().update(instance, validated_data)
else:
super().update(instance, validated_data)
# hard delete custom field instances that were soft deleted
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
return instance
@@ -1454,7 +1493,7 @@ class SavedViewSerializer(OwnedObjectSerializer):
)
)
except serializers.ValidationError as exc:
raise serializers.ValidationError({field_name: exc.detail}) from exc
raise serializers.ValidationError({field_name: exc.detail})
del normalized_data[field_name]
ret = super().to_internal_value(normalized_data)
@@ -1758,7 +1797,7 @@ class BulkEditSerializer(
logger.exception(f"Error validating custom fields: {e}")
raise serializers.ValidationError(
f"{name} must be a list of integers or a dict of id:value pairs, see the log for details",
) from None
)
elif not isinstance(custom_fields, list) or not all(
isinstance(i, int) for i in ids
):
@@ -1826,7 +1865,7 @@ class BulkEditSerializer(
try:
Tag.objects.get(id=tag_id)
except Tag.DoesNotExist:
raise serializers.ValidationError("Tag does not exist") from None
raise serializers.ValidationError("Tag does not exist")
else:
raise serializers.ValidationError("tag not specified")
@@ -1839,9 +1878,7 @@ class BulkEditSerializer(
try:
DocumentType.objects.get(id=document_type_id)
except DocumentType.DoesNotExist:
raise serializers.ValidationError(
"Document type does not exist",
) from None
raise serializers.ValidationError("Document type does not exist")
else:
raise serializers.ValidationError("document_type not specified")
@@ -1853,9 +1890,7 @@ class BulkEditSerializer(
try:
Correspondent.objects.get(id=correspondent_id)
except Correspondent.DoesNotExist:
raise serializers.ValidationError(
"Correspondent does not exist",
) from None
raise serializers.ValidationError("Correspondent does not exist")
else:
raise serializers.ValidationError("correspondent not specified")
@@ -1869,7 +1904,7 @@ class BulkEditSerializer(
except StoragePath.DoesNotExist:
raise serializers.ValidationError(
"Storage path does not exist",
) from None
)
else:
raise serializers.ValidationError("storage path not specified")
@@ -1924,7 +1959,7 @@ class BulkEditSerializer(
):
raise serializers.ValidationError("invalid rotation degrees")
except ValueError:
raise serializers.ValidationError("invalid rotation degrees") from None
raise serializers.ValidationError("invalid rotation degrees")
def _validate_source_mode(self, parameters) -> None:
source_mode = parameters.get(
@@ -1954,7 +1989,7 @@ class BulkEditSerializer(
pages.append([int(doc)])
parameters["pages"] = pages
except ValueError:
raise serializers.ValidationError("invalid pages specified") from None
raise serializers.ValidationError("invalid pages specified")
if "delete_originals" in parameters:
if not isinstance(parameters["delete_originals"], bool):
@@ -2224,14 +2259,14 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
raise serializers.ValidationError(
_("Custom field id must be an integer: %(id)s")
% {"id": field_id},
) from None
)
try:
field = CustomField.objects.get(id=field_id_int)
except CustomField.DoesNotExist:
raise serializers.ValidationError(
_("Custom field with id %(id)s does not exist")
% {"id": field_id_int},
) from None
)
custom_field_serializer.validate(
{
"field": field,
@@ -2248,7 +2283,7 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
_(
"Custom fields must be a list of integers or an object mapping ids to values.",
),
) from None
)
if CustomField.objects.filter(id__in=ids).count() != len(set(ids)):
raise serializers.ValidationError(
_("Some custom fields don't exist or were specified twice."),
@@ -2359,9 +2394,7 @@ class EmailSerializer(DocumentListSerializer):
for address in address_list:
email_validator(address)
except ValidationError:
raise serializers.ValidationError(
f"Invalid email address: {address}",
) from None
raise serializers.ValidationError(f"Invalid email address: {address}")
return ",".join(address_list)
@@ -2640,18 +2673,25 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
tasks = serializers.ListField(
required=True,
required=False,
label="Tasks",
write_only=True,
child=serializers.IntegerField(),
)
all = serializers.BooleanField(
required=False,
default=False,
label="All",
write_only=True,
)
def _validate_task_id_list(self, tasks, name="tasks") -> None:
if not isinstance(tasks, list):
raise serializers.ValidationError(f"{name} must be a list")
if not all(isinstance(i, int) for i in tasks):
raise serializers.ValidationError(f"{name} must be a list of integers")
count = PaperlessTask.objects.filter(id__in=tasks).count()
queryset = self.context.get("queryset", PaperlessTask.objects.all())
count = queryset.filter(id__in=tasks).count()
if not count == len(tasks):
raise serializers.ValidationError(
f"Some tasks in {name} don't exist or were specified twice.",
@@ -2661,6 +2701,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
self._validate_task_id_list(tasks)
return tasks
def validate(self, attrs):
acknowledge_all = attrs.get("all", False)
task_ids = attrs.get("tasks")
if acknowledge_all and task_ids is not None:
raise serializers.ValidationError(
"Set either all or tasks, not both.",
)
if not acknowledge_all and task_ids is None:
raise serializers.ValidationError(
"Either all must be true or tasks must be provided.",
)
return attrs
class ShareLinkSerializer(OwnedObjectSerializer):
class Meta:
@@ -2785,7 +2840,7 @@ class ShareLinkBundleSerializer(OwnedObjectSerializer):
return share_link_bundle
def get_document_count(self, obj: ShareLinkBundle) -> int:
return obj.document_total or obj.documents.count()
return getattr(obj, "document_total") or obj.documents.count()
class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
@@ -3133,7 +3188,7 @@ class WorkflowActionSerializer(serializers.ModelSerializer[WorkflowAction]):
except (ValueError, KeyError) as e:
raise serializers.ValidationError(
{"assign_title": f'Invalid f-string detected: "{e.args[0]}"'},
) from None
)
if (
"type" in attrs
+17 -3
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import datetime
import hashlib
import logging
import shutil
import traceback as _tb
@@ -16,6 +15,7 @@ from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import worker_process_init
from celery.signals import worker_process_shutdown
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
@@ -54,6 +54,7 @@ from documents.models import WorkflowTrigger
from documents.permissions import get_objects_for_user_owner_aware
from documents.plugins.helpers import DocumentsStatusManager
from documents.templating.utils import convert_format_str_to_template_format
from documents.utils import compute_checksum
from documents.workflows.actions import build_workflow_action_context
from documents.workflows.actions import execute_email_action
from documents.workflows.actions import execute_move_to_trash_action
@@ -410,8 +411,7 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
if checksum is None or not path.is_file():
return False
with path.open("rb") as f:
return hashlib.md5(f.read(), usedforsecurity=False).hexdigest() == checksum
return compute_checksum(path) == checksum
def _filename_template_uses_custom_fields(doc: Document) -> bool:
@@ -1340,6 +1340,20 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
conn.close_pool()
@worker_process_shutdown.connect
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
"""
Close the DB connection pool when a Celery child process exits.
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
single task. Without closing the pool on shutdown, its connections linger
on the server until TCP keepalive reaps them, accumulating over time.
"""
for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool()
def add_or_update_document_in_llm_index(sender, document, **kwargs):
"""
Add or update a document in the LLM index when it is created or updated.
+24 -15
View File
@@ -1,6 +1,7 @@
import logging
import os
import re
import unicodedata
from collections.abc import Iterable
from pathlib import PurePath
@@ -36,10 +37,12 @@ class FilePathTemplate(Template):
def clean_filepath(value: str) -> str:
"""
Clean up a filepath by:
1. Removing newlines and carriage returns
2. Removing extra spaces before and after forward slashes
3. Preserving spaces in other parts of the path
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
2. Removing newlines and carriage returns
3. Removing extra spaces before and after forward slashes
4. Preserving spaces in other parts of the path
"""
value = unicodedata.normalize("NFC", value)
value = value.replace("\n", "").replace("\r", "")
value = re.sub(r"\s*/\s*", "/", value)
@@ -181,17 +184,17 @@ def get_basic_metadata_context(
"""
return {
"title": pathvalidate.sanitize_filename(
document.title,
unicodedata.normalize("NFC", document.title),
replacement_text="-",
),
"correspondent": pathvalidate.sanitize_filename(
document.correspondent.name,
unicodedata.normalize("NFC", document.correspondent.name),
replacement_text="-",
)
if document.correspondent
else no_value_default,
"document_type": pathvalidate.sanitize_filename(
document.document_type.name,
unicodedata.normalize("NFC", document.document_type.name),
replacement_text="-",
)
if document.document_type
@@ -202,7 +205,10 @@ def get_basic_metadata_context(
"owner_username": document.owner.username
if document.owner
else no_value_default,
"original_name": PurePath(document.original_filename).with_suffix("").name
"original_name": unicodedata.normalize(
"NFC",
PurePath(document.original_filename).with_suffix("").name,
)
if document.original_filename
else no_value_default,
"doc_pk": f"{document.pk:07}",
@@ -269,12 +275,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
return {
"tag_list": pathvalidate.sanitize_filename(
",".join(
sorted(tag.name for tag in tags),
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
),
replacement_text="-",
),
# Assumed to be ordered, but a template could loop through to find what they want
"tag_name_list": [x.name for x in tags],
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
}
@@ -301,7 +307,7 @@ def get_custom_fields_context(
CustomField.FieldDataType.LONG_TEXT,
}:
value = pathvalidate.sanitize_filename(
field_instance.value,
unicodedata.normalize("NFC", field_instance.value),
replacement_text="-",
)
elif (
@@ -310,10 +316,13 @@ def get_custom_fields_context(
):
options = field_instance.field.extra_data["select_options"]
value = pathvalidate.sanitize_filename(
next(
option["label"]
for option in options
if option["id"] == field_instance.value
unicodedata.normalize(
"NFC",
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
),
replacement_text="-",
)
@@ -321,7 +330,7 @@ def get_custom_fields_context(
value = field_instance.value
field_data["custom_fields"][
pathvalidate.sanitize_filename(
field_instance.field.name,
unicodedata.normalize("NFC", field_instance.field.name),
replacement_text="-",
)
] = {
@@ -29,7 +29,9 @@ class SimpleCommand(PaperlessCommand):
def handle(self, *args, **options):
items = list(range(5))
results = [item * 2 for item in self.track(items, description="Processing...")]
results = []
for item in self.track(items, description="Processing..."):
results.append(item * 2)
self.stdout.write(f"Results: {results}")
@@ -55,13 +57,13 @@ class MultiprocessCommand(PaperlessCommand):
def handle(self, *args, **options):
items = list(range(5))
results = list(
self.process_parallel(
_double_value,
items,
description="Processing...",
),
)
results = []
for result in self.process_parallel(
_double_value,
items,
description="Processing...",
):
results.append(result)
successes = sum(1 for r in results if r.success)
self.stdout.write(f"Successes: {successes}")
@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from django.core.management import call_command
if TYPE_CHECKING:
from pytest_mock import MockerFixture
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
class TestDocumentLlmindexCommand:
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
mock_compact = mocker.patch(_COMPACT)
call_command("document_llmindex", "compact")
mock_compact.assert_called_once_with()
def test_rebuild_calls_llmindex_index_with_rebuild_true(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "rebuild")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is True
def test_update_calls_llmindex_index_with_rebuild_false(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "update")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is False
+12
View File
@@ -1,11 +1,15 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -31,3 +35,11 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+135 -56
View File
@@ -11,16 +11,15 @@ import pytest
import tantivy
import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
from documents.search._translate import translate_query
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -57,7 +56,7 @@ class TestCreatedDateField:
)
@time_machine.travel(datetime(2026, 3, 28, 15, 30, tzinfo=UTC), tick=False)
def test_today(self, tz: tzinfo, expected_lo: str, expected_hi: str) -> None:
lo, hi = _range(rewrite_natural_date_keywords("created:today", tz), "created")
lo, hi = _range(translate_query("created:today", tz), "created")
assert lo == expected_lo
assert hi == expected_hi
@@ -65,7 +64,7 @@ class TestCreatedDateField:
def test_today_auckland_ahead_of_utc(self) -> None:
# UTC 03:00 -> Auckland (UTC+13) = 16:00 same date; local date = 2026-03-28
lo, _ = _range(
rewrite_natural_date_keywords("created:today", AUCKLAND),
translate_query("created:today", AUCKLAND),
"created",
)
assert lo == "2026-03-28T00:00:00Z"
@@ -127,7 +126,7 @@ class TestCreatedDateField:
) -> None:
# 2026-03-28 is Saturday; Mon-Sun week calculation built into expectations
query = f"{field}:{keyword}"
lo, hi = _range(rewrite_natural_date_keywords(query, UTC), field)
lo, hi = _range(translate_query(query, UTC), field)
assert lo == expected_lo
assert hi == expected_hi
@@ -135,7 +134,7 @@ class TestCreatedDateField:
def test_this_month_december_wraps_to_next_year(self) -> None:
# December: next month must roll over to January 1 of next year
lo, hi = _range(
rewrite_natural_date_keywords("created:this month", UTC),
translate_query("created:this month", UTC),
"created",
)
assert lo == "2026-12-01T00:00:00Z"
@@ -145,7 +144,7 @@ class TestCreatedDateField:
def test_last_month_january_wraps_to_previous_year(self) -> None:
# January: last month must roll back to December 1 of previous year
lo, hi = _range(
rewrite_natural_date_keywords("created:previous month", UTC),
translate_query("created:previous month", UTC),
"created",
)
assert lo == "2025-12-01T00:00:00Z"
@@ -154,7 +153,7 @@ class TestCreatedDateField:
@time_machine.travel(datetime(2026, 7, 15, 12, 0, tzinfo=UTC), tick=False)
def test_previous_quarter(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords('created:"previous quarter"', UTC),
translate_query('created:"previous quarter"', UTC),
"created",
)
assert lo == "2026-04-01T00:00:00Z"
@@ -174,7 +173,7 @@ class TestDateTimeFields:
@time_machine.travel(datetime(2026, 3, 28, 15, 30, tzinfo=UTC), tick=False)
def test_added_today_eastern(self) -> None:
# EDT = UTC-4; local midnight 2026-03-28 00:00 EDT = 2026-03-28 04:00 UTC
lo, hi = _range(rewrite_natural_date_keywords("added:today", EASTERN), "added")
lo, hi = _range(translate_query("added:today", EASTERN), "added")
assert lo == "2026-03-28T04:00:00Z"
assert hi == "2026-03-29T04:00:00Z"
@@ -182,14 +181,14 @@ class TestDateTimeFields:
def test_added_today_auckland_midnight_crossing(self) -> None:
# UTC 02:00 on 2026-03-29 -> Auckland (UTC+13) = 2026-03-29 15:00 local
# Auckland midnight = UTC 2026-03-28 11:00
lo, hi = _range(rewrite_natural_date_keywords("added:today", AUCKLAND), "added")
lo, hi = _range(translate_query("added:today", AUCKLAND), "added")
assert lo == "2026-03-28T11:00:00Z"
assert hi == "2026-03-29T11:00:00Z"
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_modified_today_utc(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("modified:today", UTC),
translate_query("modified:today", UTC),
"modified",
)
assert lo == "2026-03-28T00:00:00Z"
@@ -244,14 +243,14 @@ class TestDateTimeFields:
expected_hi: str,
) -> None:
# 2026-03-28 is Saturday; weekday()==5 so Monday=2026-03-23
lo, hi = _range(rewrite_natural_date_keywords(f"added:{keyword}", UTC), "added")
lo, hi = _range(translate_query(f"added:{keyword}", UTC), "added")
assert lo == expected_lo
assert hi == expected_hi
@time_machine.travel(datetime(2026, 12, 15, 12, 0, tzinfo=UTC), tick=False)
def test_this_month_december_wraps_to_next_year(self) -> None:
# December: next month wraps to January of next year
lo, hi = _range(rewrite_natural_date_keywords("added:this month", UTC), "added")
lo, hi = _range(translate_query("added:this month", UTC), "added")
assert lo == "2026-12-01T00:00:00Z"
assert hi == "2027-01-01T00:00:00Z"
@@ -259,7 +258,7 @@ class TestDateTimeFields:
def test_last_month_january_wraps_to_previous_year(self) -> None:
# January: last month wraps back to December of previous year
lo, hi = _range(
rewrite_natural_date_keywords("added:previous month", UTC),
translate_query("added:previous month", UTC),
"added",
)
assert lo == "2025-12-01T00:00:00Z"
@@ -295,7 +294,7 @@ class TestDateTimeFields:
expected_lo: str,
expected_hi: str,
) -> None:
lo, hi = _range(rewrite_natural_date_keywords(query, UTC), "added")
lo, hi = _range(translate_query(query, UTC), "added")
assert lo == expected_lo
assert hi == expected_hi
@@ -309,20 +308,20 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_compact_date_shim_rewrites_to_iso(self) -> None:
result = rewrite_natural_date_keywords("created:20240115120000", UTC)
result = translate_query("created:20240115120000", UTC)
assert "2024-01-15" in result
assert "20240115120000" not in result
@time_machine.travel(datetime(2026, 3, 28, 15, 0, tzinfo=UTC), tick=False)
def test_relative_range_shim_removes_now(self) -> None:
result = rewrite_natural_date_keywords("added:[now-7d TO now]", UTC)
result = translate_query("added:[now-7d TO now]", UTC)
assert "now" not in result
assert "2026-03-" in result
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_7_days(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-7 days to now]", UTC),
translate_query("added:[-7 days to now]", UTC),
"added",
)
assert lo == "2026-03-21T12:00:00Z"
@@ -331,7 +330,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_1_week(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-1 week to now]", UTC),
translate_query("added:[-1 week to now]", UTC),
"added",
)
assert lo == "2026-03-21T12:00:00Z"
@@ -341,7 +340,7 @@ class TestWhooshQueryRewriting:
def test_bracket_minus_1_month_uses_relativedelta(self) -> None:
# relativedelta(months=1) from 2026-03-28 = 2026-02-28 (not 29)
lo, hi = _range(
rewrite_natural_date_keywords("created:[-1 month to now]", UTC),
translate_query("created:[-1 month to now]", UTC),
"created",
)
assert lo == "2026-02-28T12:00:00Z"
@@ -350,7 +349,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_minus_1_year(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("modified:[-1 year to now]", UTC),
translate_query("modified:[-1 year to now]", UTC),
"modified",
)
assert lo == "2025-03-28T12:00:00Z"
@@ -359,7 +358,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_plural_unit_hours(self) -> None:
lo, hi = _range(
rewrite_natural_date_keywords("added:[-3 hours to now]", UTC),
translate_query("added:[-3 hours to now]", UTC),
"added",
)
assert lo == "2026-03-28T09:00:00Z"
@@ -367,7 +366,7 @@ class TestWhooshQueryRewriting:
@time_machine.travel(datetime(2026, 3, 28, 12, 0, tzinfo=UTC), tick=False)
def test_bracket_case_insensitive(self) -> None:
result = rewrite_natural_date_keywords("added:[-1 WEEK TO NOW]", UTC)
result = translate_query("added:[-1 WEEK TO NOW]", UTC)
assert "now" not in result.lower()
lo, hi = _range(result, "added")
assert lo == "2026-03-21T12:00:00Z"
@@ -377,7 +376,7 @@ class TestWhooshQueryRewriting:
def test_relative_range_swaps_bounds_when_lo_exceeds_hi(self) -> None:
# [now+1h TO now-1h] has lo > hi before substitution; they must be swapped
lo, hi = _range(
rewrite_natural_date_keywords("added:[now+1h TO now-1h]", UTC),
translate_query("added:[now+1h TO now-1h]", UTC),
"added",
)
assert lo == "2026-03-28T11:00:00Z"
@@ -385,14 +384,14 @@ class TestWhooshQueryRewriting:
def test_8digit_created_date_field_always_uses_utc_midnight(self) -> None:
# created is a DateField: boundaries are always UTC midnight, no TZ offset
result = rewrite_natural_date_keywords("created:20231201", EASTERN)
result = translate_query("created:20231201", EASTERN)
lo, hi = _range(result, "created")
assert lo == "2023-12-01T00:00:00Z"
assert hi == "2023-12-02T00:00:00Z"
def test_8digit_added_datetime_field_converts_local_midnight_to_utc(self) -> None:
# added is DateTimeField: midnight Dec 1 Eastern (EST = UTC-5) = 05:00 UTC
result = rewrite_natural_date_keywords("added:20231201", EASTERN)
result = translate_query("added:20231201", EASTERN)
lo, hi = _range(result, "added")
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
@@ -400,17 +399,19 @@ class TestWhooshQueryRewriting:
def test_8digit_modified_datetime_field_converts_local_midnight_to_utc(
self,
) -> None:
result = rewrite_natural_date_keywords("modified:20231201", EASTERN)
result = translate_query("modified:20231201", EASTERN)
lo, hi = _range(result, "modified")
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_query("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
class TestParseUserQuery:
@@ -463,6 +464,67 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -514,7 +576,7 @@ class TestYearRangeRewriting:
expected_lo: str,
expected_hi: str,
) -> None:
result = rewrite_natural_date_keywords(query, UTC)
result = translate_query(query, UTC)
lo, hi = _range(result, field)
assert lo == expected_lo
assert hi == expected_hi
@@ -522,14 +584,14 @@ class TestYearRangeRewriting:
def test_reversed_year_range_is_swapped(self) -> None:
# A reversed range must not yield lo > hi, which Tantivy treats as an
# empty range (silently zero results). The bounds are swapped instead.
result = rewrite_natural_date_keywords("created:[2025 TO 2020]", UTC)
result = translate_query("created:[2025 TO 2020]", UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2026-01-01T00:00:00Z"
def test_year_range_in_complex_boolean_query(self) -> None:
query = "tag:steuer AND (title:2020 OR (NOT title:2019 AND NOT title:2018 AND created:[2020 TO 2020]))"
result = rewrite_natural_date_keywords(query, UTC)
result = translate_query(query, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
@@ -539,14 +601,19 @@ class TestYearRangeRewriting:
def test_already_iso_date_range_passes_through_unchanged(self) -> None:
original = "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
assert rewrite_natural_date_keywords(original, UTC) == original
assert translate_query(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
result = translate_query(original, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
class TestNonDateFieldsNotRewritten:
@@ -566,7 +633,7 @@ class TestNonDateFieldsNotRewritten:
],
)
def test_8digit_on_integer_field_passes_through_unchanged(self, query: str) -> None:
assert rewrite_natural_date_keywords(query, EASTERN) == query
assert translate_query(query, EASTERN) == query
@pytest.mark.parametrize(
"query",
@@ -580,12 +647,12 @@ class TestNonDateFieldsNotRewritten:
self,
query: str,
) -> None:
assert rewrite_natural_date_keywords(query, UTC) == query
assert translate_query(query, UTC) == query
def test_unknown_field_keyword_passes_through_unchanged(self) -> None:
# foobar is not a date field: 'foobar:today' must not become a date range,
# which Tantivy would otherwise reject as an unknown/typed field.
assert rewrite_natural_date_keywords("foobar:today", UTC) == "foobar:today"
assert translate_query("foobar:today", UTC) == "foobar:today"
class TestPassthrough:
@@ -593,27 +660,39 @@ class TestPassthrough:
def test_bare_keyword_no_field_prefix_unchanged(self) -> None:
# Bare 'today' with no field: prefix passes through unchanged
result = rewrite_natural_date_keywords("bank statement today", UTC)
result = translate_query("bank statement today", UTC)
assert "today" in result
def test_unrelated_query_unchanged(self) -> None:
assert rewrite_natural_date_keywords("title:invoice", UTC) == "title:invoice"
assert translate_query("title:invoice", UTC) == "title:invoice"
class TestNormalizeQuery:
"""normalize_query expands comma-separated values and collapses whitespace."""
"""translate_query expands comma-separated values and collapses whitespace."""
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
assert translate_query("tag:foo,bar", UTC) == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
assert (
translate_query("tag:foo,bar,baz", UTC) == "tag:foo AND tag:bar AND tag:baz"
)
def test_normalize_collapses_whitespace(self) -> None:
assert normalize_query("bank statement") == "bank statement"
assert translate_query("bank statement", UTC) == "bank statement"
def test_normalize_no_commas_unchanged(self) -> None:
assert normalize_query("bank statement") == "bank statement"
assert translate_query("bank statement", UTC) == "bank statement"
@pytest.mark.parametrize(
("raw", "expected"),
@@ -656,7 +735,7 @@ class TestNormalizeQuery:
],
)
def test_normalize_strips_dangling_operators(self, raw: str, expected: str) -> None:
assert normalize_query(raw) == expected
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"query",
@@ -668,7 +747,7 @@ class TestNormalizeQuery:
],
)
def test_normalize_preserves_valid_operators(self, query: str) -> None:
assert normalize_query(query) == query
assert translate_query(query, UTC) == query
class TestParseSimpleTextHighlightQuery:
@@ -0,0 +1,742 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
+4 -4
View File
@@ -844,7 +844,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
@@ -869,7 +869,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -890,7 +890,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -928,7 +928,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -6,6 +6,7 @@ import zipfile
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from rest_framework import status
from rest_framework.test import APITestCase
@@ -32,21 +33,21 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase):
filename="docA.pdf",
mime_type="application/pdf",
checksum="B",
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
)
self.doc2b = Document.objects.create(
title="document A",
filename="docA2.pdf",
mime_type="application/pdf",
checksum="D",
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
)
self.doc3 = Document.objects.create(
title="document B",
filename="docB.jpg",
mime_type="image/jpeg",
checksum="C",
created=datetime.datetime(2020, 3, 21, tzinfo=datetime.UTC),
created=timezone.make_aware(datetime.datetime(2020, 3, 21)),
archive_filename="docB.pdf",
archive_checksum="D",
)
@@ -1,5 +1,5 @@
import datetime
import json
from datetime import date
from unittest import mock
from unittest.mock import ANY
@@ -456,7 +456,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
},
)
date_value = datetime.datetime.now(tz=datetime.UTC).date()
date_value = date.today()
resp = self.client.patch(
f"/api/documents/{doc.id}/",
@@ -618,7 +618,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
data_type=CustomField.FieldDataType.DATE,
)
date_value = datetime.datetime.now(tz=datetime.UTC).date()
date_value = date.today()
resp = self.client.patch(
f"/api/documents/{doc.id}/",
+1 -1
View File
@@ -265,7 +265,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
created=date(2023, 1, 1),
)
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0, tzinfo=datetime.UTC)
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0)
response = self.client.patch(
f"/api/documents/{doc.pk}/",
{"created": created_datetime},
@@ -0,0 +1,95 @@
import unicodedata
from typing import TYPE_CHECKING
from unittest import mock
import celery.result
import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
if TYPE_CHECKING:
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
@pytest.fixture()
def consume_file_mock():
with mock.patch("documents.tasks.consume_file.apply_async") as m:
m.return_value = celery.result.AsyncResult(id="test-task-id")
yield m
@pytest.fixture()
def directories(tmp_path, settings, _media_settings):
scratch = tmp_path / "scratch"
scratch.mkdir()
settings.SCRATCH_DIR = scratch
return scratch
@pytest.mark.django_db
class TestPostDocumentNFCNormalization:
def test_nfd_filename_normalized_to_nfc(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with NFD filename must have its name stored as NFC."""
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Verify our test strings actually differ at the byte level
assert nfd != nfc
uploaded = SimpleUploadedFile(
nfd,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
input_doc: ConsumableDocument = task_kwargs["input_doc"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
# The temp file on disk must have an NFC name
assert input_doc.original_file.name == nfc, (
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
)
# The override filename stored for later use must also be NFC
assert overrides.filename == nfc, (
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
)
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_filename_unchanged(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with already-NFC filename must pass through unchanged."""
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
uploaded = SimpleUploadedFile(
nfc,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
assert overrides.filename == nfc
assert unicodedata.is_normalized("NFC", overrides.filename)
+20 -33
View File
@@ -700,7 +700,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
pk=3,
checksum="C",
# specific time zone aware date
added=datetime.datetime(2023, 12, 1, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2023, 12, 1)),
)
# refresh doc instance to ensure we operate on date objects that Django uses
# Django converts dates to UTC
@@ -725,9 +725,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with invalid added date
- Query with an invalid added date
THEN:
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
"""
d1 = Document.objects.create(
title="invoice",
@@ -740,8 +742,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# Tantivy rejects unparsable field queries with a 400
# An unparsable date is reported as a malformed query, not silently empty.
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
@@ -994,25 +997,25 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
title="invoice",
content="the thing i bought at a shop and paid with bank account",
created=datetime.date(2018, 1, 1),
added=datetime.datetime(2018, 1, 1, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2018, 1, 1)),
)
d2 = DocumentFactory(
title="bank statement 1",
content="things i paid for in august",
created=datetime.date(2019, 3, 4),
added=datetime.datetime(2019, 3, 4, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2019, 3, 4)),
)
d3 = DocumentFactory(
title="bank statement 3",
content="things i paid for in september",
created=datetime.date(2020, 7, 9),
added=datetime.datetime(2020, 7, 9, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2020, 7, 9)),
)
d4 = DocumentFactory(
title="Quarterly Report",
content="quarterly revenue profit margin earnings growth",
created=datetime.date(2021, 11, 30),
added=datetime.datetime(2021, 11, 30, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2021, 11, 30)),
)
backend = get_backend()
backend.add_or_update(d1)
@@ -1131,7 +1134,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.tags.add(t2)
d5 = Document.objects.create(
checksum="5",
added=datetime.datetime(2020, 7, 13, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2020, 7, 13)),
content="test",
original_filename="doc5.pdf",
)
@@ -1241,18 +1244,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.id,
search_query(
"&created__date__lt="
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
),
)
self.assertNotIn(
d4.id,
search_query(
"&created__date__gt="
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
),
)
@@ -1260,18 +1259,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.id,
search_query(
"&created__date__lt="
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
),
)
self.assertIn(
d4.id,
search_query(
"&created__date__gt="
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
),
)
@@ -1279,18 +1274,14 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__lt="
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
),
)
self.assertNotIn(
d5.id,
search_query(
"&added__date__gt="
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
),
)
@@ -1298,9 +1289,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__lt="
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
),
)
@@ -1308,9 +1297,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__gt="
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
),
)
+71
View File
@@ -216,6 +216,77 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
+181
View File
@@ -18,6 +18,7 @@ from guardian.shortcuts import assign_perm
from rest_framework import status
from rest_framework.test import APIClient
from documents.filters import PaperlessTaskFilterSet
from documents.models import PaperlessTask
from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
@@ -169,6 +170,165 @@ class TestGetTasksV10:
PaperlessTask.Status.STARTED,
}
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
"""?name= searches task filenames, task types, and trigger sources."""
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
source_task = PaperlessTaskFactory(
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
)
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
response = admin_client.get(ENDPOINT, {"name": "invoice"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == filename_task.task_id
response = admin_client.get(ENDPOINT, {"name": "sanity"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == type_task.task_id
response = admin_client.get(ENDPOINT, {"name": "email"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == source_task.task_id
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
"""?result= searches common structured task result messages."""
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
error_task = PaperlessTaskFactory(
result_data={"error_message": "Duplicate detected"},
)
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
PaperlessTaskFactory(result_data={"reason": "unrelated"})
response = admin_client.get(ENDPOINT, {"result": "manual"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == reason_task.task_id
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
assert response.status_code == status.HTTP_200_OK
returned_ids = {task["task_id"] for task in response.data["results"]}
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
response = admin_client.get(ENDPOINT, {"result": "321"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == document_task.task_id
def test_empty_task_name_and_result_filters(self) -> None:
"""Empty name/result values leave the queryset unchanged."""
PaperlessTaskFactory.create_batch(2)
queryset = PaperlessTask.objects.all()
filterset = PaperlessTaskFilterSet()
assert filterset.filter_name(queryset, "name", "").count() == 2
assert filterset.filter_result(queryset, "result", "").count() == 2
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
"""status_counts/ returns section counts for the filtered task queryset."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.REVOKED,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.STARTED,
input_data={"filename": "invoice-d.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-e.pdf"},
)
PaperlessTaskFactory(
acknowledged=True,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-acknowledged.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{"acknowledged": "false", "name": "invoice"},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 5,
"needs_attention": 2,
"in_progress": 2,
"completed": 1,
}
def test_status_counts_ignores_section_filters(
self,
admin_client: APIClient,
) -> None:
"""status_counts/ ignores status-like filters for the sections it counts."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{
"acknowledged": "false",
"name": "invoice",
"status": PaperlessTask.Status.FAILURE,
"is_complete": "false",
},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 3,
"needs_attention": 1,
"in_progress": 1,
"completed": 1,
}
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
"""Tasks are returned in descending date_created order (newest first)."""
base = timezone.now()
@@ -522,6 +682,27 @@ class TestAcknowledge:
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
response = admin_client.post(
ENDPOINT + "acknowledge/",
{"all": True},
format="json",
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
unacknowledged_task1.refresh_from_db()
unacknowledged_task2.refresh_from_db()
acknowledged_task.refresh_from_db()
assert unacknowledged_task1.acknowledged
assert unacknowledged_task2.acknowledged
assert acknowledged_task.acknowledged
def test_acknowledged_tasks_excluded_from_unacked_filter(
self,
admin_client: APIClient,
+120 -13
View File
@@ -3,6 +3,7 @@ from datetime import date
from pathlib import Path
from unittest import mock
import pikepdf
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.test import TestCase
@@ -615,6 +616,18 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.img_doc.archive_filename = img_doc_archive
self.img_doc.save()
@staticmethod
def mock_password_required_pdf(
mock_open: mock.Mock,
fake_pdf: mock.Mock,
) -> None:
password_context = mock.MagicMock()
password_context.__enter__.return_value = fake_pdf
mock_open.side_effect = [
pikepdf.PasswordError("password required"),
password_context,
]
@mock.patch("documents.tasks.consume_file.s")
def test_merge(self, mock_consume_file) -> None:
"""
@@ -764,7 +777,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
sig.set.return_value.apply_async.side_effect = Exception("boom")
mock_consume_file.return_value = sig
with self.assertRaisesRegex(Exception, "boom"):
with self.assertRaises(Exception):
bulk_edit.merge(doc_ids, delete_originals=True)
self.doc1.refresh_from_db()
@@ -1047,7 +1060,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
for call, expected_id in zip(
mock_consume_delay.call_args_list,
doc_ids,
strict=False,
):
task_kwargs = call.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id)
@@ -1306,7 +1318,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
sig.apply_async.side_effect = Exception("boom")
mock_chord.return_value = sig
with self.assertRaisesRegex(Exception, "boom"):
with self.assertRaises(Exception):
bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.doc2.refresh_from_db()
@@ -1418,7 +1430,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
{"page": 9999}, # invalid page, forces error during PDF load
]
with self.assertLogs("paperless.bulk_edit", level="ERROR"):
with self.assertRaises(ValueError):
with self.assertRaises(Exception):
bulk_edit.edit_pdf(doc_ids, operations)
mock_group.assert_not_called()
mock_consume_file.assert_not_called()
@@ -1467,6 +1479,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
fake_pdf.is_encrypted = True
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
@@ -1481,7 +1494,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
fake_pdf.remove_unreferenced_resources.assert_called_once()
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1495,6 +1514,33 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_update_document_skips_unencrypted_pdf(
self,
mock_open,
mock_mkdtemp,
mock_consume_delay,
) -> None:
doc = self.doc1
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_delay.assert_not_called()
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@@ -1514,12 +1560,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_mkdtemp.return_value = str(temp_dir)
fake_pdf = mock.MagicMock()
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
@@ -1529,7 +1575,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(source_file, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(source_file),
mock.call(source_file, password="secret"),
],
)
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1548,7 +1600,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
root_document=self.doc1,
)
fake_pdf = mock.MagicMock()
mock_open.return_value.__enter__.return_value = fake_pdf
self.mock_password_required_pdf(mock_open, fake_pdf)
result = bulk_edit.remove_password(
[self.doc1.id],
@@ -1558,7 +1610,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(self.doc1.source_path),
mock.call(self.doc1.source_path, password="secret"),
],
)
mock_consume_delay.assert_called_once()
@mock.patch("documents.bulk_edit.chord")
@@ -1581,12 +1639,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_group.return_value.delay.return_value = None
user = User.objects.create(username="owner")
@@ -1601,7 +1659,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_consume_file.assert_called_once()
call_kwargs = mock_consume_file.call_args.kwargs
consumable_document = call_kwargs["input_doc"]
@@ -1619,6 +1683,43 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_group.return_value.delay.assert_called_once()
mock_chord.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@mock.patch("documents.tasks.consume_file.s")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_skips_unencrypted_pdf_without_queueing(
self,
mock_open: mock.Mock,
mock_mkdtemp: mock.Mock,
mock_consume_file: mock.Mock,
mock_group: mock.Mock,
mock_chord: mock.Mock,
mock_delete: mock.Mock,
) -> None:
doc = self.doc2
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=False,
delete_original=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_file.assert_not_called()
mock_group.assert_not_called()
mock_chord.assert_not_called()
mock_delete.si.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@@ -1641,12 +1742,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_chord.return_value.delay.return_value = None
result = bulk_edit.remove_password(
@@ -1658,7 +1759,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_consume_file.assert_called_once()
mock_group.assert_not_called()
mock_chord.assert_called_once()
+2 -2
View File
@@ -782,8 +782,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
load_classifier(raise_exception=True)
Path(settings.MODEL_FILE).touch()
mock_load.side_effect = RuntimeError()
with self.assertRaises(RuntimeError):
mock_load.side_effect = Exception()
with self.assertRaises(Exception):
load_classifier(raise_exception=True)
+4 -4
View File
@@ -59,7 +59,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
def create_staging_file(self, src="double-sided-odd.pdf", datetime=None) -> None:
shutil.copy(self.SAMPLE_DIR / src, self.staging_file)
if datetime is None:
datetime = dt.datetime.now(tz=dt.UTC)
datetime = dt.datetime.now()
os.utime(str(self.staging_file), (datetime.timestamp(),) * 2)
def test_odd_numbered_moved_to_staging(self) -> None:
@@ -79,8 +79,8 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertIsFile(self.staging_file)
self.assertAlmostEqual(
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime, tz=dt.UTC),
dt.datetime.now(tz=dt.UTC),
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime),
dt.datetime.now(),
delta=dt.timedelta(seconds=5),
)
self.assertIn("Received odd numbered pages", msg["reason"])
@@ -124,7 +124,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
"""
self.create_staging_file(
datetime=dt.datetime.now(tz=dt.UTC)
datetime=dt.datetime.now()
- dt.timedelta(minutes=TIMEOUT_MINUTES, seconds=1),
)
msg = self.consume_file("double-sided-odd.pdf")
+57 -25
View File
@@ -12,6 +12,7 @@ from django.contrib.auth.models import User
from django.db import DatabaseError
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
@@ -23,6 +24,7 @@ from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.serialisers import DocumentSerializer
from documents.tasks import empty_trash
from documents.tests.factories import DocumentFactory
from documents.tests.utils import DirectoriesMixin
@@ -220,11 +222,8 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
doc = Document.objects.create(
title="document",
mime_type="application/pdf",
checksum=hashlib.md5(original_bytes, usedforsecurity=False).hexdigest(),
archive_checksum=hashlib.md5(
archive_bytes,
usedforsecurity=False,
).hexdigest(),
checksum=hashlib.sha256(original_bytes).hexdigest(),
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
filename="old/document.pdf",
archive_filename="old/document.pdf",
storage_path=old_storage_path,
@@ -253,6 +252,46 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
@override_settings(FILENAME_FORMAT="{title}")
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
old_path = settings.ORIGINALS_DIR / "original.pdf"
old_path.touch()
doc = Document.objects.create(
title="original",
mime_type="application/pdf",
checksum=hashlib.sha256(b"").hexdigest(),
filename="original.pdf",
)
first_instance = Document.objects.get(pk=doc.pk)
stale_instance = Document.objects.get(pk=doc.pk)
serializer = DocumentSerializer(
first_instance,
data={"title": "first"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "first.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
serializer = DocumentSerializer(
stale_instance,
data={"title": "second"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "second.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
self.assertIsNotFile(old_path)
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
def test_document_delete(self) -> None:
document = Document()
@@ -413,7 +452,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
FILENAME_FORMAT="{created_year}-{created_month}-{created_day}",
)
def test_created_year_month_day(self) -> None:
d1 = datetime.datetime(2020, 3, 6, 1, 1, 1, tzinfo=datetime.UTC)
d1 = timezone.make_aware(datetime.datetime(2020, 3, 6, 1, 1, 1))
doc1 = Document.objects.create(
title="doc1",
mime_type="application/pdf",
@@ -430,7 +469,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
FILENAME_FORMAT="{added_year}-{added_month}-{added_day}",
)
def test_added_year_month_day(self) -> None:
d1 = datetime.datetime(1232, 1, 9, 1, 1, 1, tzinfo=datetime.UTC)
d1 = timezone.make_aware(datetime.datetime(1232, 1, 9, 1, 1, 1))
doc1 = Document.objects.create(
title="doc1",
mime_type="application/pdf",
@@ -443,7 +482,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertEqual(generate_filename(doc1), expected_filename)
doc1.added = datetime.datetime(2020, 11, 16, 1, 1, 1, tzinfo=datetime.UTC)
doc1.added = timezone.make_aware(datetime.datetime(2020, 11, 16, 1, 1, 1))
self.assertEqual(generate_filename(doc1), Path("2020-11-16.pdf"))
@@ -1227,7 +1266,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
def test_short_names_added(self) -> None:
doc = Document.objects.create(
title="The Title",
added=datetime.datetime(1984, 8, 21, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(1984, 8, 21, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1466,7 +1505,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1538,7 +1577,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc = Document.objects.create(
title="scan_017562",
created=datetime.date(2025, 7, 2),
added=datetime.datetime(2026, 3, 3, 11, 53, 16, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2026, 3, 3, 11, 53, 16)),
mime_type="application/pdf",
checksum="test-checksum",
storage_path=sp,
@@ -1567,7 +1606,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1602,7 +1641,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1634,7 +1673,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Some Title",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1739,7 +1778,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Some Title",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1753,15 +1792,8 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
CustomFieldInstance.objects.create(
document=doc_a,
field=CustomField.objects.get(name="Invoice Date"),
value_date=datetime.datetime(
2024,
10,
1,
7,
36,
51,
153,
tzinfo=datetime.UTC,
value_date=timezone.make_aware(
datetime.datetime(2024, 10, 1, 7, 36, 51, 153),
),
)
@@ -1801,7 +1833,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc = Document.objects.create(
title="Some Title! With @ Special # Characters",
created=datetime.date(2020, 6, 25),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
mime_type="application/pdf",
pk=2,
checksum="2",
+187
View File
@@ -0,0 +1,187 @@
"""
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
All paths produced by the templating system must be NFC-normalized.
"""
import unicodedata
import pytest
from documents.file_handling import generate_filename
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
@pytest.mark.django_db
class TestGenerateFilenameNFCNormalization:
@pytest.mark.parametrize(
"raw,display",
[
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
(unicodedata.normalize("NFD", "Café"), "Café"),
(unicodedata.normalize("NFD", "naïve"), "naïve"),
],
)
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
"""NFD title must produce NFC path bytes."""
settings.FILENAME_FORMAT = "{{ title }}"
nfc = unicodedata.normalize("NFC", display)
assert raw != nfc # confirm byte-level difference
doc = DocumentFactory(title=raw, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_correspondent_normalized_to_nfc(self, settings):
"""NFD correspondent name must produce NFC path component."""
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Müller")
nfc = unicodedata.normalize("NFC", "Müller")
correspondent = CorrespondentFactory(name=nfd)
doc = DocumentFactory(
title="invoice",
correspondent=correspondent,
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result) == f"{nfc}/invoice.pdf"
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
def test_nfd_storage_path_normalized_to_nfc(self, settings):
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
settings.FILENAME_FORMAT = None
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
# StoragePath.path is used directly as the format/template string.
# Literal NFD characters in the template must survive rendering as NFC.
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
settings.FILENAME_FORMAT = "{{ document.title }}"
nfd = unicodedata.normalize("NFD", "Café")
nfc = unicodedata.normalize("NFC", "Café")
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
@pytest.mark.django_db
class TestContextBuilderNFCNormalization:
"""
Defense-in-depth: context builder functions must NFC-normalize string inputs
before passing them to sanitize_filename(). Task 1 already normalizes the
final rendered path via clean_filepath(), so these tests may already pass;
they exist as regression guards for the context-builder layer.
"""
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm they differ at byte level
tag = TagFactory(name=nfd)
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.set([tag])
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_original_name_normalized_to_nfc(self, settings):
settings.FILENAME_FORMAT = "{{ original_name }}"
nfd = unicodedata.normalize("NFD", "Rechnung März")
nfc = unicodedata.normalize("NFC", "Rechnung März")
doc = DocumentFactory(
original_filename=f"{nfd}.pdf",
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_custom_field_string_value_normalized(self, settings):
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
settings.FILENAME_FORMAT = (
"{{ custom_fields['Location']['value'] }}/{{ title }}"
)
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
assert nfd_value != nfc_value
doc = DocumentFactory(title="report", mime_type="application/pdf")
cf = CustomField.objects.create(
name="Location",
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text=nfd_value,
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
def test_nfd_custom_field_name_normalized_as_key(self, settings):
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
nfd_name = unicodedata.normalize("NFD", "Größe")
nfc_name = unicodedata.normalize("NFC", "Größe")
assert nfd_name != nfc_name
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
doc = DocumentFactory(title="letter", mime_type="application/pdf")
cf = CustomField.objects.create(
name=nfd_name,
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text="Berlin",
)
result = generate_filename(doc)
# If field name key is NFC-normalized, the template condition succeeds
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
assert str(result) == "Berlin/letter.pdf"
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
settings.FILENAME_FORMAT = (
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
)
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm byte-level difference
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.add(TagFactory(name=nfd))
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
+1 -1
View File
@@ -243,7 +243,7 @@ class TestViews(DirectoriesMixin, TestCase):
"change": {"users": [], "groups": []},
}
else:
raise AssertionError(f"Unexpected tag found: {tag['name']}")
assert False, f"Unexpected tag found: {tag['name']}"
def test_list_no_n_plus_1_queries(self) -> None:
"""
+1 -8
View File
@@ -2760,14 +2760,7 @@ class TestWorkflows(
doc = Document.objects.create(
title="test",
)
self.assertRaisesRegex(
Exception,
"not yet supported",
document_matches_workflow,
doc,
w,
99,
)
self.assertRaises(Exception, document_matches_workflow, doc, w, 99)
def test_removal_action_document_updated_workflow(self) -> None:
"""
+2 -3
View File
@@ -129,12 +129,11 @@ def util_call_with_backoff(
status_codes.append(cause_exec.response.status_code)
warnings.warn(
f"HTTP Exception for {cause_exec.request.url} - {cause_exec}",
stacklevel=2,
)
else:
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
warnings.warn(f"Unexpected error: {e}")
except Exception as e: # pragma: no cover
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
warnings.warn(f"Unexpected error: {e}")
retry_count = retry_count + 1
+142 -50
View File
@@ -7,11 +7,12 @@ import tempfile
import zipfile
from collections import defaultdict
from collections import deque
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -60,6 +61,7 @@ from django.http import StreamingHttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.utils.timezone import make_aware
from django.utils.translation import get_language
from django.utils.translation import gettext_lazy as _
from django.views import View
@@ -284,7 +286,7 @@ def _get_more_like_id(query_params: dict[str, Any], user: User | None) -> int:
pk=more_like_doc_id,
)
except (TypeError, ValueError, Document.DoesNotExist):
raise PermissionDenied(_("Invalid more_like_id")) from None
raise PermissionDenied(_("Invalid more_like_id"))
if user and not has_perms_owner_aware(
user,
@@ -1100,7 +1102,7 @@ class DocumentViewSet(
"root_document",
).get(pk=pk)
except Document.DoesNotExist:
raise Http404 from None
raise Http404
root_doc = get_root_document(doc)
if request.user is not None and not has_perms_owner_aware(
@@ -1263,7 +1265,7 @@ class DocumentViewSet(
"root_document",
).get(id=pk)
except Document.DoesNotExist:
raise Http404 from None
raise Http404
root_doc = get_root_document(
request_doc,
@@ -1399,7 +1401,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"view_document",
"change_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1459,7 +1461,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"view_document",
"change_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1505,6 +1507,7 @@ class DocumentViewSet(
"document %s: %s",
doc.pk,
exc,
exc_info=True,
)
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
@@ -1578,7 +1581,7 @@ class DocumentViewSet(
disposition="inline",
)
except FileNotFoundError:
raise Http404 from None
raise Http404
@action(methods=["get"], detail=True, filter_backends=[])
@method_decorator(cache_control(no_cache=True))
@@ -1603,14 +1606,14 @@ class DocumentViewSet(
return FileResponse(handle, content_type="image/webp")
except FileNotFoundError:
raise Http404 from None
raise Http404
@action(methods=["get"], detail=True)
def download(self, request, pk=None):
try:
return self.file_response(pk, request, "attachment")
except (FileNotFoundError, Document.DoesNotExist):
raise Http404 from None
raise Http404
@action(
methods=["get", "post", "delete"],
@@ -1635,7 +1638,7 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions to view notes")
except Document.DoesNotExist:
raise Http404 from None
raise Http404
serializer = self.get_serializer(doc)
@@ -1706,7 +1709,7 @@ class DocumentViewSet(
try:
note_id_int = int(note_id)
except ValueError:
raise ValidationError({"id": "A valid integer is required."}) from None
raise ValidationError({"id": "A valid integer is required."})
note = get_object_or_404(Note, id=note_id_int, document=doc)
if settings.AUDIT_LOG_ENABLED:
LogEntry.objects.log_create(
@@ -1750,7 +1753,7 @@ class DocumentViewSet(
"Insufficient permissions to add share link",
)
except Document.DoesNotExist:
raise Http404 from None
raise Http404
if request.method == "GET":
now = timezone.now()
@@ -1778,7 +1781,7 @@ class DocumentViewSet(
"Insufficient permissions",
)
except Document.DoesNotExist: # pragma: no cover
raise Http404 from None
raise Http404
# documents
entries = [
@@ -1799,28 +1802,28 @@ class DocumentViewSet(
]
# custom fields
entries.extend(
{
"id": entry.id,
"timestamp": entry.timestamp,
"action": entry.get_action_display(),
"changes": {
"custom_fields": {
"type": "custom_field",
"field": str(entry.object_repr).split(":")[0].strip(),
"value": str(entry.object_repr).split(":")[1].strip(),
for entry in LogEntry.objects.get_for_objects(
doc.custom_fields.all(),
).select_related("actor"):
entries.append(
{
"id": entry.id,
"timestamp": entry.timestamp,
"action": entry.get_action_display(),
"changes": {
"custom_fields": {
"type": "custom_field",
"field": str(entry.object_repr).split(":")[0].strip(),
"value": str(entry.object_repr).split(":")[1].strip(),
},
},
"actor": (
{"id": entry.actor.id, "username": entry.actor.username}
if entry.actor
else None
),
},
"actor": (
{"id": entry.actor.id, "username": entry.actor.username}
if entry.actor
else None
),
}
for entry in LogEntry.objects.get_for_objects(
doc.custom_fields.all(),
).select_related("actor")
)
)
return Response(sorted(entries, key=lambda x: x["timestamp"], reverse=True))
@@ -1928,13 +1931,13 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions")
except Document.DoesNotExist:
raise Http404 from None
raise Http404
try:
doc_name, doc_data = serializer.validated_data.get("document")
version_label = serializer.validated_data.get("version_label")
t = int(timezone.now().timestamp())
t = int(mktime(datetime.now().timetuple()))
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
@@ -1979,7 +1982,7 @@ class DocumentViewSet(
"root_document",
).get(pk=pk)
except Document.DoesNotExist:
raise Http404 from None
raise Http404
return get_root_document(root_doc)
def _get_version_doc_for_root(self, root_doc: Document, version_id) -> Document:
@@ -1988,7 +1991,7 @@ class DocumentViewSet(
pk=version_id,
)
except Document.DoesNotExist:
raise Http404 from None
raise Http404
if (
version_doc.id != root_doc.id
@@ -2274,6 +2277,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2466,6 +2470,11 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -2543,7 +2552,7 @@ class LogViewSet(ViewSet):
try:
limit = int(limit_param)
except (TypeError, ValueError):
raise ValidationError({"limit": "Must be a positive integer"}) from None
raise ValidationError({"limit": "Must be a positive integer"})
if limit < 1:
raise ValidationError({"limit": "Must be a positive integer"})
else:
@@ -3124,6 +3133,7 @@ class PostDocumentView(GenericAPIView[Any]):
serializer.is_valid(raise_exception=True)
doc_name, doc_data = serializer.validated_data.get("document")
doc_name = normalize("NFC", doc_name)
correspondent_id = serializer.validated_data.get("correspondent")
document_type_id = serializer.validated_data.get("document_type")
storage_path_id = serializer.validated_data.get("storage_path")
@@ -3134,7 +3144,7 @@ class PostDocumentView(GenericAPIView[Any]):
cf = serializer.validated_data.get("custom_fields")
from_webui = serializer.validated_data.get("from_webui")
t = int(timezone.now().timestamp())
t = int(mktime(datetime.now().timetuple()))
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
@@ -4009,7 +4019,7 @@ class RemoteVersionView(GenericAPIView[Any]):
class _TasksViewSetSchema(AutoSchema):
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
def _get_paginator(self):
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
@@ -4031,7 +4041,7 @@ class _TasksViewSetSchema(AutoSchema):
),
acknowledge=extend_schema(
operation_id="acknowledge_tasks",
description="Acknowledge a list of tasks",
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
request=AcknowledgeTasksViewSerializer,
responses={
(200, "application/json"): inline_serializer(
@@ -4069,6 +4079,19 @@ class _TasksViewSetSchema(AutoSchema):
),
],
),
status_counts=extend_schema(
responses={
200: inline_serializer(
name="TaskStatusCounts",
fields={
"all": serializers.IntegerField(),
"needs_attention": serializers.IntegerField(),
"in_progress": serializers.IntegerField(),
"completed": serializers.IntegerField(),
},
),
},
),
active=extend_schema(
description="Currently pending and running tasks (capped at 50).",
responses={200: TaskSerializerV10(many=True)},
@@ -4122,6 +4145,7 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
}
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
def get_serializer_class(self):
# v9: use backwards-compatible serializer with old field names
@@ -4162,16 +4186,38 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
queryset = queryset.filter(task_id=task_id)
return queryset
def get_status_count_queryset(self):
"""Apply task filters except the status dimensions represented by the counts."""
query_params = self.request.query_params.copy()
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
query_params.pop(param, None)
filterset = self.filterset_class(
data=query_params,
queryset=self.get_queryset(),
request=self.request,
)
if not filterset.is_valid():
raise ValidationError(filterset.errors)
return filterset.qs
@action(
methods=["post"],
detail=False,
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
)
def acknowledge(self, request):
serializer = AcknowledgeTasksViewSerializer(data=request.data)
queryset = self.get_queryset()
serializer = AcknowledgeTasksViewSerializer(
data=request.data,
context={"queryset": queryset},
)
serializer.is_valid(raise_exception=True)
task_ids = serializer.validated_data.get("tasks")
tasks = self.get_queryset().filter(id__in=task_ids)
if serializer.validated_data.get("all", False):
tasks = queryset.filter(acknowledged=False)
else:
task_ids = serializer.validated_data.get("tasks")
tasks = queryset.filter(id__in=task_ids)
count = tasks.update(acknowledged=True)
return Response({"result": count})
@@ -4224,6 +4270,34 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
serializer = TaskSummarySerializer(data, many=True)
return Response(serializer.data)
@action(methods=["get"], detail=False)
def status_counts(self, request):
"""Aggregated task counts for task UI sections."""
queryset = self.get_status_count_queryset()
counts = queryset.aggregate(
all=Count("id"),
needs_attention=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.FAILURE,
PaperlessTask.Status.REVOKED,
],
),
),
in_progress=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.PENDING,
PaperlessTask.Status.STARTED,
],
),
),
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
)
return Response(counts)
@action(methods=["get"], detail=False)
def active(self, request):
"""Currently pending and running tasks (capped at 50)."""
@@ -4923,11 +4997,29 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
except Exception as e:
celery_active = "ERROR"
logger.exception(
@@ -4946,7 +5038,7 @@ class SystemStatusView(PassUserMixin):
index_dir = settings.INDEX_DIR
mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()]
index_last_modified = (
datetime.fromtimestamp(max(mtimes), tz=UTC) if mtimes else None
make_aware(datetime.fromtimestamp(max(mtimes))) if mtimes else None
)
except Exception as e:
index_status = "ERROR"
+13 -14
View File
@@ -84,11 +84,10 @@ def binaries_check(app_configs: Any, **kwargs: Any) -> list[Error]:
binaries = (settings.CONVERT_BINARY, "tesseract", "gs")
check_messages = [
Warning(error.format(binary), hint)
for binary in binaries
if shutil.which(binary) is None
]
check_messages = []
for binary in binaries:
if shutil.which(binary) is None:
check_messages.append(Warning(error.format(binary), hint))
return check_messages
@@ -384,14 +383,14 @@ def check_default_language_available(app_configs: Any, **kwargs: Any) -> list[Er
specified_langs = [x.strip() for x in settings.OCR_LANGUAGE.split("+")]
errs.extend(
Error(
f"The selected ocr language {lang} is "
f"not installed. Paperless cannot OCR your documents "
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
)
for lang in specified_langs
if lang not in installed_langs
)
for lang in specified_langs:
if lang not in installed_langs:
errs.append(
Error(
f"The selected ocr language {lang} is "
f"not installed. Paperless cannot OCR your documents "
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
),
)
return errs
+5 -4
View File
@@ -649,10 +649,11 @@ class MailDocumentParser:
if data["bcc"]:
data["bcc_label"] = "BCC"
att = [
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})"
for a in mail.attachments
]
att = []
for a in mail.attachments:
att.append(
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})",
)
data["attachments"] = clean_html(", ".join(att))
if data["attachments"]:
data["attachments_label"] = "Attachments"
+2 -28
View File
@@ -20,6 +20,7 @@ from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from paperless.parsers.utils import read_file_handle_unicode_errors
from paperless.version import __full_version_str__
if TYPE_CHECKING:
@@ -183,7 +184,7 @@ class TextDocumentParser:
documents.parsers.ParseError
If the file cannot be read.
"""
self._text = self._read_text(document_path)
self._text = read_file_handle_unicode_errors(document_path, log=logger)
# ------------------------------------------------------------------
# Result accessors
@@ -295,30 +296,3 @@ class TextDocumentParser:
Always ``[]`` plain text files carry no structured metadata.
"""
return []
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _read_text(self, filepath: Path) -> str:
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
Parameters
----------
filepath:
Path to the file to read.
Returns
-------
str
File content as a string.
"""
try:
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
logger.warning(
"Unicode error reading %s, replacing bad bytes: %s",
filepath,
exc,
)
return filepath.read_bytes().decode("utf-8", errors="replace")
+18 -5
View File
@@ -8,6 +8,7 @@ share implementation.
from __future__ import annotations
import codecs
import logging
import re
import tempfile
@@ -114,7 +115,7 @@ def read_file_handle_unicode_errors(
filepath: Path,
log: logging.Logger | None = None,
) -> str:
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
Parameters
----------
@@ -127,15 +128,27 @@ def read_file_handle_unicode_errors(
Returns
-------
str
File content as a string, with any invalid UTF-8 sequences replaced
by the Unicode replacement character.
File content as a string, with NUL bytes removed so the result is
safe to store in PostgreSQL text fields.
"""
_log = log or logger
raw = filepath.read_bytes()
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
encoding = "utf-16"
elif raw.startswith(codecs.BOM_UTF8):
encoding = "utf-8-sig"
else:
encoding = "utf-8"
try:
return filepath.read_text(encoding="utf-8")
text = raw.decode(encoding)
except UnicodeDecodeError as e:
_log.warning("Unicode error during text reading, continuing: %s", e)
return filepath.read_bytes().decode("utf-8", errors="replace")
text = raw.decode("utf-8", errors="replace")
# PostgreSQL rejects NUL (0x00) bytes in text fields
return text.replace("\x00", "")
def get_page_count_for_pdf(
+9 -2
View File
@@ -97,8 +97,14 @@ MODEL_FILE = get_path_from_env(
DATA_DIR / "classification_model.pickle",
)
LLM_INDEX_DIR = DATA_DIR / "llm_index"
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
# Cross-process read/write lock guarding the LLM index compaction/migration
# file swap. Readers hold it shared; the swap takes it exclusively so it never
# runs while a reader connection is open. Must be a SQLite (.db) file.
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
# Seconds the compaction swap waits for active readers to drain before skipping
# this cycle (it is a maintenance operation; the next run retries).
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
@@ -644,6 +650,7 @@ LOGGING = {
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
"httpx": {"level": "WARNING"},
},
}
+4 -1
View File
@@ -252,6 +252,9 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
# Validate pooled connections so a connection closed server-side
# is replaced rather than handed out as "the connection is closed".
"CONN_HEALTH_CHECKS": True,
}
base_options = {
@@ -331,7 +334,7 @@ def parse_dateparser_languages(languages: str | None) -> list[str]:
language_list = languages.split("+") if languages else []
# There is an unfixed issue in zh-Hant and zh-Hans locales in the dateparser lib.
# See: https://github.com/scrapinghub/dateparser/issues/875
for _, language in enumerate(language_list):
for index, language in enumerate(language_list):
if language.startswith("zh-") and "zh" not in language_list:
logger.warning(
f"Chinese locale detected: {language}. dateparser might fail to parse"
@@ -398,6 +398,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "localhost",
"NAME": "paperless",
"USER": "paperless",
@@ -426,6 +427,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "paperless-db-host",
"PORT": 1111,
"NAME": "customdb",
@@ -455,6 +457,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
@@ -485,6 +488,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
+37
View File
@@ -2,13 +2,50 @@
from __future__ import annotations
import codecs
from pathlib import Path
from paperless.parsers.utils import is_tagged_pdf
from paperless.parsers.utils import read_file_handle_unicode_errors
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
class TestReadFileHandleUnicodeErrors:
def test_plain_utf8(self, tmp_path: Path) -> None:
f = tmp_path / "plain.txt"
f.write_bytes(b"hello world")
assert read_file_handle_unicode_errors(f) == "hello world"
def test_utf8_bom(self, tmp_path: Path) -> None:
f = tmp_path / "bom.txt"
f.write_bytes(codecs.BOM_UTF8 + b"hello")
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_le(self, tmp_path: Path) -> None:
f = tmp_path / "utf16le.txt"
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_be(self, tmp_path: Path) -> None:
f = tmp_path / "utf16be.txt"
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
f = tmp_path / "null-bytes.txt"
f.write_bytes(b"foo\x00bar")
assert read_file_handle_unicode_errors(f) == "foobar"
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
f = tmp_path / "bad.txt"
f.write_bytes(b"ok\x80\x81bad")
result = read_file_handle_unicode_errors(f)
assert "ok" in result
assert "bad" in result
assert "\x00" not in result
class TestIsTaggedPdf:
def test_tagged_pdf_returns_true(self) -> None:
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
+1 -1
View File
@@ -193,7 +193,7 @@ def reject_dangerous_svg(file: UploadedFile) -> None:
tree = etree.parse(file, parser)
root = tree.getroot()
except etree.XMLSyntaxError:
raise ValidationError("Invalid SVG file.") from None
raise ValidationError("Invalid SVG file.")
for element in root.iter():
tag: str = etree.QName(element.tag).localname.lower()
+2 -2
View File
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer
from paperless_ai.indexing import vector_store_file_exists
from paperless_ai.indexing import llm_index_exists
class PaperlessObtainAuthTokenView(ObtainAuthToken):
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
or old_llm_context_size != new_llm_context_size
)
rebuild_needed = new_ai_index_enabled and (
not vector_store_file_exists() or embedding_config_changed
not llm_index_exists() or embedding_config_changed
)
if rebuild_needed:
+35 -21
View File
@@ -8,6 +8,7 @@ from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
@@ -24,9 +25,14 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
) -> str:
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
content = truncate_content(
document.content[:4000] or "",
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""
You are a document classification assistant.
@@ -49,10 +55,15 @@ def build_prompt_without_rag(
def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
base_prompt = build_prompt_without_rag(document, config)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""{base_prompt}
@@ -130,26 +141,29 @@ def get_ai_document_classification(
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, user)
build_prompt_with_rag(document, ai_config, user)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document)
else build_prompt_without_rag(document, ai_config)
)
client = AIClient()
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
# Hand the pooled DB connection back while the (slow) LLM query runs so it
# is not pinned for the call's duration; see paperless_ai.db and #12976.
with db_connection_released():
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
return suggestions
+56 -122
View File
@@ -3,9 +3,13 @@ import logging
import sys
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
from paperless_ai.indexing import read_store
logger = logging.getLogger("paperless_ai.chat")
@@ -75,82 +79,6 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
)
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredFaissRetriever(BaseRetriever):
def __init__(self):
super().__init__()
self._cached_query_str = None
self._cached_nodes = []
def _retrieve(self, query_bundle):
if query_bundle.query_str == self._cached_query_str:
return self._cached_nodes
if query_bundle.embedding is None:
query_bundle.embedding = (
index._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs,
)
)
faiss_index = index.vector_store._faiss_index
max_top_k = faiss_index.ntotal
if max_top_k == 0:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
query_top_k = min(max(similarity_top_k, 1), max_top_k)
allowed_nodes: list[NodeWithScore] = []
seen_node_ids: set[str] = set()
while query_top_k <= max_top_k:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
for vector_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
node_id = index.index_struct.nodes_dict.get(vector_id)
if node_id is None or node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
if len(allowed_nodes) >= similarity_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
if query_top_k == max_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
query_top_k = min(query_top_k * 2, max_top_k)
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
return DocumentFilteredFaissRetriever()
def stream_chat_with_documents(query_str: str, documents: list[Document]):
try:
yield from _stream_chat_with_documents(query_str, documents)
@@ -160,63 +88,69 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_or_build_index()
doc_ids = [str(doc.pk) for doc in documents]
# Filter only the node(s) that match the document IDs
nodes = [
node
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in doc_ids
]
if len(nodes) == 0:
logger.warning("No nodes found for the given documents.")
if not documents:
yield CHAT_NO_CONTENT_MESSAGE
return
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
retriever = _get_document_filtered_retriever(
index,
set(doc_ids),
CHAT_RETRIEVER_TOP_K,
)
config = AIConfig()
filters = _document_id_filters(str(doc.pk) for doc in documents)
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
# Hold the shared read lock for the whole operation: the query engine
# retrieves from the vector store again during synthesis, so the connection
# must stay open (and the swap must not run) until the stream finishes.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
references = _get_document_references(documents, top_nodes)
# Slow query-embedding + vector search; no Django ORM access happens
# during it, so release the pooled DB connection for its duration. See
# #12976.
with db_connection_released():
top_nodes = retriever.retrieve(query_str)
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(),
text_qa_template=prompt_template,
streaming=True,
)
client = AIClient()
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
references = _get_document_references(documents, top_nodes)
logger.debug("Document chat query: %s", query_str)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
response_stream = query_engine.query(query_str)
logger.debug("Document chat query: %s", query_str)
# Release the pooled DB connection for the slow streaming LLM response
# so it is not pinned for the whole stream; see paperless_ai.db and
# #12976.
with db_connection_released():
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
if references:
yield _format_chat_metadata_trailer(references)
if references:
yield _format_chat_metadata_trailer(references)
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from contextlib import contextmanager
from django.db import connections
@contextmanager
def db_connection_released():
"""
Return any checked-out DB connections to the pool for the duration of the
wrapped block.
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
or a streaming response (``chat``). Django keeps the request's database
connection checked out for the entire request/response, so a blocking LLM
call - which can take many seconds - pins a pooled connection the whole
time. With connection pooling enabled, enough concurrent AI requests check
out every slot and all other requests then fail with
``psycopg_pool.PoolTimeout`` (see issue #12976).
No Django ORM access happens during the LLM call, so we hand the connection
back to the pool first; Django transparently re-checks-out a connection on
the next ORM use after the block.
"""
connections.close_all()
try:
yield
finally:
connections.close_all()
+24 -54
View File
@@ -1,12 +1,9 @@
import json
import re
from typing import TYPE_CHECKING
from django.conf import settings
if TYPE_CHECKING:
from pathlib import Path
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
@@ -23,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
def get_embedding_model() -> "BaseEmbedding":
config = AIConfig()
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
match config.llm_embedding_backend:
case LLMEmbeddingBackend.OPENAI_LIKE:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
@@ -95,41 +90,24 @@ def get_embedding_model() -> "BaseEmbedding":
)
def get_embedding_dim() -> int:
"""
Loads embedding dimension from meta.json if available, otherwise infers it
from a dummy embedding and stores it for future use.
"""
config = AIConfig()
default_model = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}.get(
config.llm_embedding_backend,
"sentence-transformers/all-MiniLM-L6-v2",
_DEFAULT_MODEL_NAMES = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}
def get_configured_model_name(config: AIConfig) -> str:
"""Return the canonical name of the currently configured embedding model."""
# dict.get(key, default) overload resolution fails for TextChoices keys in some
# type checkers; use `or` fallback to avoid the ambiguity.
default = (
_DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
)
or "sentence-transformers/all-MiniLM-L6-v2"
)
model = config.llm_embedding_model or default_model
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
if meta_path.exists():
with meta_path.open() as f:
meta = json.load(f)
if meta.get("embedding_model") != model:
raise RuntimeError(
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
"You must rebuild the index.",
)
return meta["dim"]
embedding_model = get_embedding_model()
test_embed = embedding_model.get_text_embedding("test")
dim = len(test_embed)
with meta_path.open("w") as f:
json.dump({"embedding_model": model, "dim": dim}, f)
return dim
return config.llm_embedding_model or default
def _normalize_llm_index_text(text: str) -> str:
@@ -138,24 +116,16 @@ def _normalize_llm_index_text(text: str) -> str:
def build_llm_index_text(doc: Document) -> str:
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
# text, Custom Fields are dynamic in count and best kept in the embedding.
lines = [
f"Title: {doc.title}",
f"Filename: {doc.filename}",
f"Created: {doc.created}",
f"Added: {doc.added}",
f"Modified: {doc.modified}",
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
f"Archive Serial Number: {doc.archive_serial_number or ''}",
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
]
lines.extend(
f"Custom Field - {instance.field.name}: {instance}"
for instance in doc.custom_fields.all()
)
for instance in doc.custom_fields.all():
lines.append(f"Custom Field - {instance.field.name}: {instance}")
lines.append("\nContent:\n")
lines.append(doc.content or "")
+269 -243
View File
@@ -1,28 +1,30 @@
import logging
import shutil
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from django.conf import settings
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
from documents.utils import IterWrapper
from documents.utils import identity
from paperless.config import AIConfig
from paperless_ai.db import db_connection_released
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_configured_model_name
from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import BaseNode
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
logger = logging.getLogger("paperless_ai.indexing")
@@ -30,21 +32,11 @@ RAG_NUM_OUTPUT = 512
RAG_CHUNK_OVERLAP = 200
def _index_lock_path() -> Path:
"""Return the path used as the file lock for FAISS index mutations.
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
rebuild which calls shutil.rmtree(LLM_INDEX_DIR) cannot delete the lock
while another worker still holds it.
"""
return settings.LLM_INDEX_LOCK
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
# concurrent workers can both observe no running task and both enqueue a
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
# is itself protected by _index_lock_path(), so only one rebuild runs at a
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
# time and the second one is serialised after the first completes.
from documents.tasks import llmindex_index
@@ -71,46 +63,110 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
return True
def get_or_create_storage_context(*, rebuild=False):
"""
Loads or creates the StorageContext (vector store, docstore, index store).
If rebuild=True, deletes and recreates everything.
"""
if rebuild:
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
if rebuild or not settings.LLM_INDEX_DIR.exists():
import faiss
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
embedding_dim = get_embedding_dim()
faiss_index = faiss.IndexFlatL2(embedding_dim)
vector_store = FaissVectorStore(faiss_index=faiss_index)
docstore = SimpleDocumentStore()
index_store = SimpleIndexStore()
else:
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
return StorageContext.from_defaults(
docstore=docstore,
index_store=index_store,
vector_store=vector_store,
persist_dir=settings.LLM_INDEX_DIR,
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
)
# --- LLM index locking ---------------------------------------------------
#
# Two locks guard the index; they answer different questions and are NOT
# interchangeable:
#
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
# Taken by write_store(). Readers never take it, so it never blocks reads.
#
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
# compaction/migration file swap. read_store() takes it SHARED (readers run
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
# the database file is never replaced while a reader connection is open (that
# would alias the old WAL onto the new file and corrupt it).
#
# | vs another writer | vs a reader
# -----------------+-------------------+----------------------------
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
#
# They can't be merged into one ReadWriteLock: a normal write must exclude other
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
# the swap needs to exclude readers.
def _index_rwlock() -> ReadWriteLock:
"""Return a fresh read/write lock instance for the index swap.
``is_singleton=False`` so reads and the swap always coordinate through
SQLite (the actual cross-process case) rather than hitting the in-process
reentrant-upgrade guard; callers must ``close()`` it (the context managers
below do).
"""
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
@contextmanager
def read_store():
"""Acquire the shared read lock and yield the vector store for a read.
The shared lock is held for the whole lifetime of the connection (and
closed on exit) so the compaction/migration swap, which takes the exclusive
lock, never runs while this connection is open. Concurrent readers do not
block each other; only the swap does.
"""
lock = _index_rwlock()
try:
with lock.read_lock(), get_vector_store() as store:
yield store
finally:
lock.close()
@contextmanager
def _exclude_readers():
"""Acquire exclusive index access, blocking until readers have drained.
The exclusive counterpart to ``read_store()``: a compaction or migration
must not run while any reader connection is open. Raises
:class:`filelock.Timeout` if active readers do not drain within
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
"""
lock = _index_rwlock()
try:
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
yield
finally:
lock.close()
@contextmanager
def write_store(embed_model_name: str | None = None):
"""Acquire the write lock and yield the vector store.
All mutating operations (upsert, delete, rebuild, compact) must go through
this context manager to serialise concurrent Celery writers.
Read paths use ``read_store()`` so they hold the shared read lock.
Pass ``embed_model_name`` whenever the operation may create the table so
the model name is recorded in the schema metadata for future mismatch checks.
"""
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
with (
FileLock(settings.LLM_INDEX_LOCK),
PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
embed_model_name=embed_model_name,
) as store,
):
yield store
def build_document_node(
document: Document,
*,
@@ -130,6 +186,9 @@ def build_document_node(
"document_type": document.document_type.name
if document.document_type
else None,
"filename": document.filename,
"storage_path": document.storage_path.name if document.storage_path else None,
"archive_serial_number": document.archive_serial_number,
"created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
@@ -142,9 +201,11 @@ def build_document_node(
# the token count and exceed embedding models with small context windows
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
doc = LlamaDocument(
id_=str(document.id),
text=text,
metadata=metadata,
excluded_embed_metadata_keys=list(metadata.keys()),
excluded_llm_metadata_keys=["document_id"],
)
chunk_size = chunk_size or get_rag_chunk_size()
parser = SimpleNodeParser(
@@ -154,76 +215,33 @@ def build_document_node(
return parser.get_nodes_from_documents([doc])
def load_or_build_index(nodes=None):
"""
Load an existing VectorStoreIndex if present,
or build a new one using provided nodes if storage is empty.
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
"""Return a VectorStoreIndex backed by ``store``.
``store`` is supplied by the caller's ``read_store()`` context so the shared
read lock and the connection stay alive for the whole retrieval.
"""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
embed_model = get_embedding_model()
embed_model = get_embedding_model(config)
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context()
try:
return load_index_from_storage(storage_context=storage_context)
except ValueError as e:
logger.warning("Failed to load index from storage: %s", e)
if not nodes:
queue_llm_index_update_if_needed(
rebuild=vector_store_file_exists(),
reason="LLM index missing or invalid while loading.",
)
logger.info("No nodes provided for index creation.")
raise
return VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
return VectorStoreIndex.from_vector_store(
vector_store=store,
embed_model=embed_model,
)
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
"""
Removes existing documents from docstore for a given document from the index.
This is necessary because FAISS IndexFlatL2 is append-only.
"""
all_node_ids = list(index.docstore.docs.keys())
existing_nodes = [
node.node_id
for node in index.docstore.get_nodes(all_node_ids)
if node.metadata.get("document_id") == str(document.id)
]
for node_id in existing_nodes:
# Delete from docstore, FAISS IndexFlatL2 are append-only
index.docstore.delete_document(node_id)
# Also purge the FAISS position -> UUID mapping so subsequent similarity
# queries don't raise KeyError on ghost vector positions.
stale_keys = [
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
]
for key in stale_keys:
del index.index_struct.nodes_dict[key]
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
index.storage_context.index_store.add_index_struct(index.index_struct)
def vector_store_file_exists():
"""
Check if the vector store file exists in the LLM index directory.
"""
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
def llm_index_exists() -> bool:
"""True when the index table exists on disk."""
with read_store() as store:
return store.table_exists()
def get_rag_chunk_size() -> int:
return AIConfig().llm_embedding_chunk_size
def get_rag_context_size() -> int:
return AIConfig().llm_context_size
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
chunk_size = chunk_size or get_rag_chunk_size()
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
@@ -249,123 +267,149 @@ def get_rag_prompt_helper(
)
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
"""Embed ``nodes`` in place using ``embed_model``."""
from llama_index.core.schema import MetadataMode
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
def _document_id_filters(doc_ids):
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(doc_ids),
),
],
)
def update_llm_index(
*,
iter_wrapper: IterWrapper[Document] = identity,
rebuild=False,
) -> str:
"""
Rebuild or update the LLM index.
"""
from llama_index.core import VectorStoreIndex
nodes = []
"""Rebuild or incrementally update the LLM index."""
with write_store() as store:
try:
with _exclude_readers():
needs_reembed = store.check_and_run_migrations()
except Timeout:
logger.info(
"Skipping LLM index migration check: index readers are active; "
"will retry next run.",
)
needs_reembed = False
if needs_reembed:
logger.warning(
"LLM index migration requires re-embedding; forcing rebuild.",
)
rebuild = True
documents = Document.objects.all()
if not documents.exists():
no_documents = not documents.exists()
# Fast exit before touching config: nothing to index and no existing index.
if no_documents and not rebuild and not llm_index_exists():
logger.warning("No documents found to index.")
if not rebuild and not vector_store_file_exists():
return "No documents found to index."
return "No documents found to index."
config = AIConfig()
model_name = get_configured_model_name(config)
if not rebuild and llm_index_exists():
with read_store() as store:
config_mismatch = store.config_mismatch(model_name)
if config_mismatch:
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
if no_documents:
logger.warning("No documents found to index.")
chunk_size = config.llm_embedding_chunk_size
embed_model = get_embedding_model(config)
with FileLock(_index_lock_path()):
if rebuild or not vector_store_file_exists():
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
with write_store(embed_model_name=model_name) as store:
if rebuild or not store.table_exists():
logger.info("Rebuilding LLM index.")
import llama_index.core.settings as llama_settings
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=True)
store.drop_table()
for document in iter_wrapper(documents):
document_nodes = build_document_node(document, chunk_size=chunk_size)
nodes.extend(document_nodes)
index = VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
show_progress=False,
)
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.add(nodes)
msg = "LLM index rebuilt successfully."
else:
# Update existing index
index = load_or_build_index()
existing_nodes: defaultdict[str, list] = defaultdict(list)
for node in index.docstore.docs.values():
doc_id = node.metadata.get("document_id")
if doc_id is not None:
existing_nodes[doc_id].append(node)
existing = store.get_modified_times()
changed = 0
for document in iter_wrapper(documents):
doc_id = str(document.id)
document_modified = document.modified.isoformat()
if existing.get(doc_id) == document.modified.isoformat():
continue
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.upsert_document(doc_id, nodes)
changed += 1
msg = (
"LLM index updated successfully."
if changed
else "No changes detected in LLM index."
)
if doc_id in existing_nodes:
doc_nodes = existing_nodes[doc_id]
node_modified = doc_nodes[0].metadata.get("modified")
if node_modified == document_modified:
continue
# Delete from docstore, FAISS IndexFlatL2 are append-only
for _ in doc_nodes:
remove_document_docstore_nodes(document, index)
nodes.extend(build_document_node(document, chunk_size=chunk_size))
if nodes:
msg = "LLM index updated successfully."
logger.info(
"Updating %d nodes in LLM index.",
len(nodes),
)
index.insert_nodes(nodes)
else:
msg = "No changes detected in LLM index."
logger.info(msg)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
try:
with _exclude_readers():
store.compact()
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
return msg
def llm_index_add_or_update_document(document: Document):
"""
Adds or updates a document in the LLM index.
If the document already exists, it will be replaced.
"""
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
if not new_nodes:
logger.warning(
"No indexable content for document %s; skipping LLM index update.",
document.pk,
)
return
"""Add or atomically replace a document's chunks in the index."""
config = AIConfig()
new_nodes = build_document_node(
document,
chunk_size=config.llm_embedding_chunk_size,
)
if new_nodes:
_embed_nodes(new_nodes, get_embedding_model(config))
with FileLock(_index_lock_path()):
index = load_or_build_index(nodes=new_nodes)
with write_store(embed_model_name=get_configured_model_name(config)) as store:
store.upsert_document(str(document.id), new_nodes)
remove_document_docstore_nodes(document, index)
index.insert_nodes(new_nodes)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_compact() -> None:
"""Compact the index immediately, rebuilding the table to reclaim space."""
with write_store() as store:
try:
with _exclude_readers():
store.compact(force=True)
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
def llm_index_remove_document(document: Document):
"""
Removes a document from the LLM index.
"""
with FileLock(_index_lock_path()):
index = load_or_build_index()
remove_document_docstore_nodes(document, index)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
"""Remove a document's chunks from the LLM index."""
with write_store() as store:
store.delete(str(document.id))
def truncate_content(
@@ -410,77 +454,59 @@ def query_similar_documents(
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""
Runs a similarity query and returns top-k similar Document objects.
"""
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
if not vector_store_file_exists():
if not llm_index_exists():
queue_llm_index_update_if_needed(
rebuild=False,
reason="LLM index not found for similarity query.",
)
return []
with FileLock(_index_lock_path()):
index = load_or_build_index()
config = AIConfig()
# constrain only the node(s) that match the document IDs, if given
doc_node_ids = (
[
node.node_id
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in allowed_document_ids
]
if allowed_document_ids is not None
else None
)
if doc_node_ids is not None and not doc_node_ids:
return []
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.retrievers import VectorIndexRetriever
filters = (
_document_id_filters(allowed_document_ids)
if allowed_document_ids is not None
else None
)
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
# Hold the shared read lock for the whole retrieval so the connection is
# never open across a compaction swap. The retrieve() call generates a
# query embedding (a slow external request) and searches the vector store;
# no Django ORM access happens during it, so release the pooled DB
# connection for its duration. See #12976.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
doc_ids=doc_node_ids,
filters=filters,
)
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
try:
with db_connection_released():
results = retriever.retrieve(query_text)
except KeyError as e:
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
# append-only. Treat them as absent and return no results.
logger.debug(
"Skipping LLM similarity query for document %s due to a stale "
"FAISS position with no docstore node: %s",
document.pk,
e,
)
return []
retrieved_document_ids: list[int] = []
for node in results:
document_id = node.metadata.get("document_id")
if document_id is None:
continue
normalized_document_id = str(document_id)
if (
allowed_document_ids is not None
and normalized_document_id not in allowed_document_ids
):
normalized = str(document_id)
if allowed_document_ids is not None and normalized not in allowed_document_ids:
continue
try:
retrieved_document_ids.append(int(normalized_document_id))
except ValueError:
retrieved_document_ids.append(int(normalized))
except ValueError: # pragma: no cover
logger.warning(
"Skipping LLM index result with invalid document_id %r.",
document_id,
+27 -1
View File
@@ -1,10 +1,36 @@
from pathlib import Path
import pytest
import pytest_mock
from llama_index.core.base.embeddings.base import BaseEmbedding
from pytest_django.fixtures import SettingsWrapper
@pytest.fixture
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
settings.LLM_INDEX_DIR = tmp_path
settings.LLM_INDEX_LOCK = tmp_path / "index.lock"
settings.LLM_INDEX_RWLOCK = tmp_path / "llmindex.rwlock.db"
return tmp_path
class FakeEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384
@pytest.fixture
def mock_embed_model(mocker: pytest_mock.MockerFixture) -> pytest_mock.MockType:
fake = FakeEmbedding()
mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake)
mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake)
return fake
+5 -3
View File
@@ -6,6 +6,7 @@ import pytest
from django.test import override_settings
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.ai_classifier import build_localization_prompt
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
@@ -155,7 +156,7 @@ def test_get_ai_document_classification_failure(mock_run_llm_query, mock_documen
mock_run_llm_query.side_effect = Exception("LLM query failed")
# assert raises an exception
with pytest.raises(ValueError, match="Unsupported LLM backend"):
with pytest.raises(Exception):
get_ai_document_classification(mock_document)
@@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document):
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
config = AIConfig()
prompt = build_prompt_without_rag(mock_document, config)
assert "Additional context from similar documents" not in prompt
assert "for generated" not in prompt
prompt = build_prompt_with_rag(mock_document)
prompt = build_prompt_with_rag(mock_document, config)
assert "Additional context from similar documents" in prompt
prompt = build_localization_prompt(
+265 -435
View File
@@ -1,15 +1,12 @@
import json
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import pytest_mock
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from faker import Faker
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.schema import MetadataMode
from documents.models import Document
from documents.models import PaperlessTask
@@ -19,10 +16,12 @@ from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
from paperless.models import ApplicationConfiguration
from paperless_ai import indexing
from paperless_ai.tests.conftest import FakeEmbedding
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@pytest.fixture
def real_document(db):
def real_document(db: None) -> Document:
return Document.objects.create(
title="Test Document",
content="This is some test content.",
@@ -30,44 +29,39 @@ def real_document(db):
)
@pytest.fixture
def mock_embed_model():
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384 # Match your real FAISS config
@pytest.mark.django_db
def test_build_document_node(real_document) -> None:
def test_build_document_node(real_document: Document) -> None:
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == str(real_document.id)
assert nodes[0].metadata["filename"] == real_document.filename
assert nodes[0].metadata["storage_path"] == (
real_document.storage_path.name if real_document.storage_path else None
)
assert (
nodes[0].metadata["archive_serial_number"]
== real_document.archive_serial_number
)
assert "filename" in nodes[0].excluded_embed_metadata_keys
assert "filename" not in nodes[0].excluded_llm_metadata_keys
@pytest.mark.django_db
def test_build_document_node_excludes_metadata_from_embedding(real_document) -> None:
def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None:
"""Every node produced by build_document_node must carry the paperless document id
as its ref_doc_id so that the vector store's delete(str(doc.id)) works correctly."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0, "Expected at least one node"
for node in nodes:
assert node.ref_doc_id == str(real_document.id), (
f"Expected ref_doc_id={real_document.id!r}, got {node.ref_doc_id!r}"
)
@pytest.mark.django_db
def test_build_document_node_excludes_metadata_from_embedding(
real_document: Document,
) -> None:
"""Metadata keys must not be prepended to the embedding text.
build_llm_index_text already encodes all metadata in the body text, so
@@ -75,8 +69,6 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
double the token count and exceed embedding models with small context
windows (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
"""
from llama_index.core.schema import MetadataMode
nodes = indexing.build_document_node(real_document)
for node in nodes:
embed_text = node.get_content(metadata_mode=MetadataMode.EMBED)
@@ -87,7 +79,36 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
def test_build_document_node_structured_fields_in_metadata(
real_document: Document,
) -> None:
"""Structured fields must be in node.metadata so the LLM receives them via metadata prepend."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
for node in nodes:
assert "title" in node.metadata
assert "tags" in node.metadata
assert "correspondent" in node.metadata
assert "document_type" in node.metadata
assert "created" in node.metadata
assert "added" in node.metadata
assert "modified" in node.metadata
@pytest.mark.django_db
def test_build_document_node_excludes_document_id_from_llm_context(
real_document: Document,
) -> None:
"""document_id is an internal key and must not appear in LLM context text."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
for node in nodes:
assert "document_id" in node.excluded_llm_metadata_keys
assert "document_id" not in node.get_content(metadata_mode=MetadataMode.LLM)
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document: Document) -> None:
app_config, _ = ApplicationConfiguration.objects.get_or_create()
app_config.llm_embedding_chunk_size = 512
app_config.save()
@@ -118,9 +139,9 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None:
@pytest.mark.django_db
def test_update_llm_index(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
mock_config = MagicMock()
mock_config.llm_embedding_chunk_size = 512
@@ -138,44 +159,49 @@ def test_update_llm_index(
ai_config.assert_called_once()
build_document_node.assert_called_once_with(real_document, chunk_size=512)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
def test_update_llm_index_rebuilds_on_model_name_change(
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)
# Build initial index with model "model-a".
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
with patch(
"paperless_ai.indexing.get_configured_model_name",
return_value="model-a",
):
indexing.update_llm_index(rebuild=True)
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig
# Simulate config change to "model-b"; the incremental run must force a rebuild.
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
with patch(
"paperless_ai.indexing.get_configured_model_name",
return_value="model-b",
):
indexing.update_llm_index(rebuild=False)
config = AIConfig()
expected_model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai-like"
else "sentence-transformers/all-MiniLM-L6-v2"
)
assert meta == {"embedding_model": expected_model, "dim": 384}
with indexing.get_vector_store() as store:
# Schema metadata only updates when the table is dropped and recreated, never
# on incremental writes -- so "model-b" here proves a full rebuild happened.
assert store.stored_model_name() == "model-b"
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
doc2 = Document.objects.create(
title="Test Document 2",
@@ -210,131 +236,34 @@ def test_update_llm_index_partial_update(
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
mock_all.return_value = mock_queryset
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.",
2,
)
indexing.update_llm_index(rebuild=False)
assert any(temp_llm_index_dir.glob("*.json"))
def test_get_or_create_storage_context_raises_exception(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with pytest.raises(ValueError):
indexing.get_or_create_storage_context(rebuild=False)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[indexing.build_document_node(real_document)],
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after incremental update"
)
mock_index_cls.assert_called_once()
def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
),
):
with pytest.raises(Exception): # noqa: B017
indexing.load_or_build_index()
@pytest.mark.django_db
def test_load_or_build_index_succeeds_when_nodes_given(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[MagicMock()],
)
mock_index_cls.assert_called_once()
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
temp_llm_index_dir,
real_document,
mock_embed_model,
) -> None:
indexing.update_llm_index(rebuild=True)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 1
indexing.llm_index_remove_document(real_document)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 0
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after add-or-update"
)
@pytest.mark.django_db
def test_query_after_remove_does_not_raise_key_error(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
indexing.update_llm_index(rebuild=True)
@@ -352,8 +281,8 @@ def test_query_after_remove_does_not_raise_key_error(
@pytest.mark.django_db
def test_update_llm_index_no_documents(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
) -> None:
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
@@ -369,6 +298,22 @@ def test_update_llm_index_no_documents(
)
@pytest.mark.django_db
def test_update_no_documents_no_index_returns_early(
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""update with no documents and no existing index must return early."""
mock_qs = MagicMock()
mock_qs.exists.return_value = False
mock_qs.__iter__ = MagicMock(return_value=iter([]))
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
result = indexing.update_llm_index(rebuild=False)
assert result == "No documents found to index."
@pytest.mark.django_db
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
# No existing tasks
@@ -406,20 +351,17 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
LLM_BACKEND="ollama",
)
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
) -> None:
with (
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
) as mock_vector_store_exists,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
mock_vector_store_exists.return_value = True
mock_index = MagicMock()
@@ -453,12 +395,12 @@ def test_query_similar_documents(
@pytest.mark.django_db
def test_query_similar_documents_triggers_update_when_index_missing(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
) -> None:
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
return_value=False,
),
patch(
@@ -479,120 +421,13 @@ def test_query_similar_documents_triggers_update_when_index_missing(
assert result == []
@pytest.mark.django_db
def test_query_similar_documents_normalizes_and_post_filters_allowed_ids(
real_document,
) -> None:
real_document.owner = User.objects.create_user(username="rag-owner")
real_document.save()
private_owner = User.objects.create_user(username="rag-private-owner")
private_document = Document.objects.create(
title="Private similar document",
content="Similar private content that must not reach RAG.",
owner=private_owner,
added=timezone.now(),
)
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
return_value=True,
),
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
):
allowed_node = MagicMock()
allowed_node.node_id = "allowed-node"
allowed_node.metadata = {"document_id": str(real_document.pk)}
private_node = MagicMock()
private_node.node_id = "private-node"
private_node.metadata = {"document_id": str(private_document.pk)}
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [allowed_node, private_node]
mock_load_or_build_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [private_node, allowed_node]
mock_retriever_cls.return_value = mock_retriever
result = indexing.query_similar_documents(
real_document,
top_k=2,
document_ids=[real_document.pk],
)
mock_retriever_cls.assert_called_once_with(
index=mock_index,
similarity_top_k=2,
doc_ids=["allowed-node"],
)
assert result == [real_document]
assert private_document not in result
class TestUpdateLlmIndexStaleNodes:
"""Tests that update_llm_index removes ALL nodes for a multi-chunk document."""
@pytest.mark.django_db
def test_incremental_update_removes_all_old_nodes_for_multi_chunk_document(
self,
temp_llm_index_dir,
mock_embed_model: MagicMock,
) -> None:
"""Ghost nodes from all chunks of a modified document must be removed.
When a document is split into multiple chunks (chunk_size=1024), the
incremental update path must delete every old node, not just the last
one captured by a dict comprehension keyed on document_id.
"""
# Content long enough to produce at least two chunks at chunk_size=1024.
# Generate many paragraphs so the token count comfortably exceeds 1024.
fake = Faker()
long_content = "\n\n".join(fake.paragraph(nb_sentences=20) for _ in range(20))
doc = DocumentFactory(content=long_content)
# Build the initial index (rebuild=True) so it has multiple nodes
indexing.update_llm_index(rebuild=True)
# Verify the initial index has more than one node for this document
initial_index = indexing.load_or_build_index()
initial_node_ids = [
nid
for nid, node in initial_index.docstore.docs.items()
if node.metadata.get("document_id") == str(doc.id)
]
assert len(initial_node_ids) > 1, (
f"Expected multiple chunks but got {len(initial_node_ids)}; "
"increase long_content length"
)
# Simulate a modification so the incremental path treats it as changed.
# Use queryset.update() to bypass auto_now and actually change the DB value.
new_modified = timezone.now()
Document.objects.filter(pk=doc.pk).update(modified=new_modified)
# Run incremental update (rebuild=False) with the modified document
indexing.update_llm_index(rebuild=False)
# Reload the persisted index and check that no OLD node ids remain
updated_index = indexing.load_or_build_index()
remaining_old_node_ids = [
nid for nid in initial_node_ids if nid in updated_index.docstore.docs
]
assert remaining_old_node_ids == [], (
f"Ghost nodes still present after incremental update: "
f"{remaining_old_node_ids}"
)
@pytest.mark.django_db
def test_query_similar_documents_empty_allow_list_fails_closed(
real_document,
real_document: Document,
) -> None:
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
return_value=True,
) as mock_vector_store_exists,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
@@ -610,27 +445,25 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
class TestUpdateLlmIndexEmptyDocumentSet:
"""update_llm_index must persist an empty index when all documents are deleted.
"""update_llm_index must clear the vector store table when all documents are deleted.
Without this, the stale on-disk FAISS vectors are never cleared and
subsequent similarity searches return phantom hits for document IDs that
no longer exist in the DB.
Without this, the stale vectors are never cleared and subsequent similarity
searches return phantom hits for document IDs that no longer exist in the DB.
"""
@pytest.mark.django_db
def test_rebuild_clears_stale_index_when_no_documents_exist(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mock_embed_model: FakeEmbedding,
) -> None:
"""After deleting all documents, rebuild=True must persist an empty index.
"""After deleting all documents, rebuild=True must produce a table with zero rows.
Steps:
1. Build an index with one document so the on-disk state is non-empty.
2. Delete all documents from the DB.
3. Call update_llm_index(rebuild=True).
4. Reload the index from disk.
5. Assert the reloaded index has zero nodes (no phantom vectors).
4. Open the LanceDB table directly and assert zero rows.
"""
# Step 1: create a document and build a non-empty index
Document.objects.create(
@@ -640,27 +473,26 @@ class TestUpdateLlmIndexEmptyDocumentSet:
)
indexing.update_llm_index(rebuild=True)
initial_index = indexing.load_or_build_index()
assert len(initial_index.docstore.docs) > 0, (
"Precondition failed: expected at least one node before deletion"
)
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Precondition failed: expected the vector store table to exist "
"before deletion"
)
# Step 2: delete all documents
Document.objects.all().delete()
assert not Document.objects.exists()
# Step 3: rebuild with no documents
# Step 3: rebuild with no documents — drop_table is called so the table
# is removed (no rows to re-insert, so it stays absent).
indexing.update_llm_index(rebuild=True)
# Step 4: reload the persisted index from disk
reloaded_index = indexing.load_or_build_index()
# Step 5: phantom vectors must be gone
assert len(reloaded_index.docstore.docs) == 0, (
f"Expected 0 nodes after clearing all documents, "
f"but found {len(reloaded_index.docstore.docs)}: "
f"{list(reloaded_index.docstore.docs.keys())}"
)
# Step 4: the table must be absent (no rows) — phantom vectors gone
with indexing.get_vector_store() as store2:
assert not store2.table_exists(), (
"Expected the vector store table to be absent after rebuilding "
"with no documents"
)
class TestDocumentUpdatedSignalTriggersLlmReindex:
@@ -709,10 +541,14 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
def test_returns_without_error_when_build_document_node_returns_empty(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
"""When build_document_node returns [], the function must return without error
and must not call load_or_build_index at all."""
"""When build_document_node returns [], the function must return without error.
The store's upsert_document treats an empty node list as a removal (no-op
delete), so load_or_build_index must not be called.
"""
mocker.patch(
"paperless_ai.indexing.build_document_node",
return_value=[],
@@ -720,6 +556,7 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
doc = MagicMock(spec=Document)
doc.id = 42
# Must not raise
indexing.llm_index_add_or_update_document(doc)
@@ -727,172 +564,165 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
@pytest.mark.django_db
class TestLlmIndexLocking:
"""The FAISS index mutation functions must acquire the index lock before touching the index.
def test_llm_index_compact_uses_force(
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""compact must use force=True to rebuild the table and reclaim space immediately."""
mock_store = mocker.MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
Without locking, two concurrent Celery workers can each load the same
on-disk index, make independent modifications, and the last writer silently
overwrites the first's changes.
indexing.llm_index_compact()
mock_store.compact.assert_called_once_with(force=True)
@pytest.mark.django_db
class TestLlmIndexLocking:
"""Index mutation functions must go through write_store(), which holds the lock.
Without locking, two concurrent Celery workers can open the same store,
make independent modifications, and trigger CommitConflictError.
"""
def test_add_or_update_document_acquires_lock(
def test_add_or_update_document_uses_write_store(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
mocker: pytest_mock.MockerFixture,
) -> None:
"""llm_index_add_or_update_document must enter the file lock before touching the index."""
call_order: list[str] = []
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
mock_load = mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or MagicMock()
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
mock_node = MagicMock()
mock_node.get_content.return_value = "fake node text"
mocker.patch(
"paperless_ai.indexing.build_document_node",
return_value=[MagicMock()],
return_value=[mock_node],
)
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
doc = MagicMock(spec=Document)
doc.id = 1
indexing.llm_index_add_or_update_document(doc)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_load.assert_called_once()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
mock_store.upsert_document.assert_called_once()
def test_remove_document_acquires_lock(
def test_remove_document_uses_write_store(
self,
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""llm_index_remove_document must enter the file lock before loading the index."""
call_order: list[str] = []
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
mock_load = mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or MagicMock()
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
doc = MagicMock(spec=Document)
doc.id = 1
indexing.llm_index_remove_document(doc)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_load.assert_called_once()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
mock_store.delete.assert_called_once_with("1")
def test_update_llm_index_rebuild_acquires_lock(
def test_update_llm_index_rebuild_uses_write_store(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mock_embed_model: FakeEmbedding,
mocker: pytest_mock.MockerFixture,
) -> None:
"""update_llm_index must enter the file lock during the rebuild/persist cycle."""
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(return_value=None)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
# exists=True so the code reaches the lock; iterate over an empty
# queryset so VectorStoreIndex is called with no nodes (still exercises
# the lock path without needing heavy FAISS fixture data)
mock_qs = MagicMock()
mock_qs.exists.return_value = True
mock_qs.__iter__ = MagicMock(return_value=iter([]))
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
mocker.patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
)
indexing.update_llm_index(rebuild=True)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_store.drop_table.assert_called_once()
def test_query_similar_documents_acquires_lock(
@pytest.mark.django_db
@pytest.mark.django_db
class TestVectorStoreIndexing:
def test_get_vector_store_roundtrip(
self,
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
mock_embed_model: FakeEmbedding,
) -> None:
"""query_similar_documents must enter the file lock before loading the index."""
call_order: list[str] = []
with indexing.get_vector_store() as store:
assert isinstance(store, PaperlessSqliteVecVectorStore)
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
def test_add_then_remove_document(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
real_document: Document,
) -> None:
indexing.llm_index_add_or_update_document(real_document)
with indexing.get_vector_store() as store:
assert store.table_exists()
count_sql = "SELECT count(*) FROM documents"
assert store.client.execute(count_sql).fetchone()[0] >= 1
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
indexing.llm_index_remove_document(real_document)
assert store.client.execute(count_sql).fetchone()[0] == 0
mocker.patch(
"paperless_ai.indexing.vector_store_file_exists",
return_value=True,
)
def test_update_shrinks_chunks_without_orphans(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
real_document: Document,
) -> None:
real_document.content = "word " * 4000 # many chunks
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
count_sql = "SELECT count(*) FROM documents"
with indexing.get_vector_store() as store:
big = store.client.execute(count_sql).fetchone()[0]
mock_index = MagicMock()
mock_index.docstore.docs = {}
real_document.content = "short" # one chunk
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or mock_index
),
)
rows = store.client.execute(count_sql).fetchone()[0]
assert rows < big
assert rows >= 1
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = []
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever,
)
mocker.patch("paperless_ai.indexing.truncate_content", return_value="")
@pytest.mark.django_db
class TestQuerySimilarDocuments:
def test_query_similar_documents_respects_allowed_ids(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
) -> None:
a = DocumentFactory.create(content="alpha shared content here")
b = DocumentFactory.create(content="beta shared content here")
c = DocumentFactory.create(content="gamma shared content here")
for doc in (a, b, c):
indexing.llm_index_add_or_update_document(doc)
indexing.query_similar_documents(MagicMock(spec=Document))
results = indexing.query_similar_documents(a, document_ids=[b.id])
mock_file_lock_cls.assert_called()
mock_lock_instance.__enter__.assert_called()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
assert all(doc.id == b.id for doc in results)
+110 -130
View File
@@ -3,19 +3,20 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import TextNode
from documents.tests.factories import DocumentFactory
from paperless_ai import chat
from paperless_ai import indexing
from paperless_ai.chat import CHAT_ERROR_MESSAGE
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import _get_document_filtered_retriever
from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
def patch_embed_model():
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
yield
@@ -58,91 +59,6 @@ def assert_chat_output(
}
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
mock_index.index_struct.nodes_dict = {
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
}
mock_index.docstore.docs.get.side_effect = {
node.node_id: node for node in nodes
}.get
mock_index.vector_store._faiss_index.ntotal = len(nodes)
mock_index.vector_store.query.return_value = MagicMock(
ids=list(mock_index.index_struct.nodes_dict),
similarities=[0.1] * len(nodes),
)
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
allowed_node1 = TextNode(
text="Allowed content 1.",
metadata={"document_id": "1", "title": "Allowed 1"},
)
allowed_node2 = TextNode(
text="Allowed content 2.",
metadata={"document_id": "2", "title": "Allowed 2"},
)
foreign_node = TextNode(
text="Foreign content.",
metadata={"document_id": "3", "title": "Foreign"},
)
missing_node = TextNode(
text="Missing content.",
metadata={"document_id": "1", "title": "Missing"},
)
mock_index = MagicMock()
mock_index.index_struct.nodes_dict = {
"0": foreign_node.node_id,
"1": missing_node.node_id,
"2": allowed_node1.node_id,
"3": allowed_node2.node_id,
}
mock_index.docstore.docs.get.side_effect = {
allowed_node1.node_id: allowed_node1,
allowed_node2.node_id: allowed_node2,
foreign_node.node_id: foreign_node,
}.get
mock_index.vector_store._faiss_index.ntotal = 4
mock_index.vector_store.query.side_effect = [
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
]
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1", "2"},
similarity_top_k=2,
)
nodes = retriever.retrieve("question")
cached_nodes = retriever.retrieve("question")
assert [node.node.node_id for node in nodes] == [
allowed_node1.node_id,
allowed_node2.node_id,
]
assert cached_nodes == nodes
assert mock_index.vector_store.query.call_count == 2
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
mock_index = MagicMock()
mock_index.vector_store._faiss_index.ntotal = 0
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1"},
similarity_top_k=2,
)
assert retriever.retrieve("question") == []
mock_index.vector_store.query.assert_not_called()
@pytest.mark.django_db
def test_stream_chat_with_one_document_retrieval(
mock_document,
@@ -164,17 +80,31 @@ def test_stream_chat_with_one_document_retrieval(
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
add_vector_query_results(mock_index, [mock_node])
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(
metadata={
"document_id": str(mock_document.pk),
"title": "Test Document",
},
),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
output = list(stream_chat_with_documents("What is this?", [mock_document]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What is this?", [mock_document]))
mock_query_engine.query.assert_called_once_with("What is this?")
patch_embed_nodes.assert_not_called()
@@ -196,12 +126,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
@@ -210,41 +138,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [
mock_node1,
mock_node2,
mock_duplicate_node,
mock_foreign_node,
]
add_vector_query_results(
mock_index,
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
)
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
# Mock response stream
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
mock_query_engine.query.assert_called_once_with("What's up?")
patch_embed_nodes.assert_not_called()
@@ -258,8 +177,16 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
)
def test_stream_chat_empty_document_list() -> None:
with patch("paperless_ai.chat.load_or_build_index") as mock_load_index:
output = list(stream_chat_with_documents("Any info?", []))
mock_load_index.assert_not_called()
assert output == ["Sorry, I couldn't find any content to answer your question."]
def test_stream_chat_no_matching_nodes() -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
@@ -268,8 +195,8 @@ def test_stream_chat_no_matching_nodes() -> None:
mock_client.llm = MagicMock()
mock_index = MagicMock()
# No matching nodes
mock_index.docstore.docs.values.return_value = []
# No matching nodes in the store
mock_index.vector_store.get_nodes.return_value = []
mock_load_index.return_value = mock_index
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
@@ -279,30 +206,83 @@ def test_stream_chat_no_matching_nodes() -> None:
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat._get_document_filtered_retriever",
) as mock_get_retriever,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": "1", "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
# Nodes found so we get past the pre-check
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
mock_load_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
mock_get_retriever.return_value = mock_retriever
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
) as mock_retriever_cls:
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError(
"private provider detail",
)
mock_retriever_cls.return_value = mock_retriever
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
assert output == [CHAT_ERROR_MESSAGE]
assert "Failed to stream document chat response" in caplog.text
assert "private provider detail" in caplog.text
@pytest.mark.django_db
class TestStreamChatRetrieval:
def test_no_nodes_yields_no_content_message(
self,
temp_llm_index_dir,
mock_embed_model,
) -> None:
doc = DocumentFactory.create(content="hello world")
# Nothing indexed for this document yet.
out = list(chat.stream_chat_with_documents("question?", [doc]))
assert chat.CHAT_NO_CONTENT_MESSAGE in out
def test_chat_filter_contains_only_requested_document_ids(
self,
temp_llm_index_dir,
mock_embed_model,
mocker,
) -> None:
"""The MetadataFilter passed to the retriever must be scoped to the
requested documents only content from other indexed documents must
not be surfaced.
"""
included = DocumentFactory.create(content="included document content")
excluded = DocumentFactory.create(content="excluded document content")
indexing.llm_index_add_or_update_document(included)
indexing.llm_index_add_or_update_document(excluded)
# VectorIndexRetriever is imported inside _stream_chat_with_documents;
# patch it at the llama_index source so the lazy import picks it up.
captured_filters = []
mock_retriever = mocker.MagicMock()
mock_retriever.retrieve.return_value = []
def capture_retriever(*args, **kwargs):
captured_filters.append(kwargs.get("filters"))
return mock_retriever
mocker.patch("paperless_ai.chat.AIClient")
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
side_effect=capture_retriever,
)
list(chat.stream_chat_with_documents("question?", [included]))
assert captured_filters, "VectorIndexRetriever was never constructed"
filt = captured_filters[0]
assert filt is not None, "Retriever must receive a MetadataFilters"
filter_values = filt.filters[0].value
assert str(included.pk) in filter_values
assert str(excluded.pk) not in filter_values
+46 -60
View File
@@ -1,4 +1,3 @@
import json
from unittest.mock import ANY
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -10,7 +9,7 @@ from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import _normalize_llm_index_text
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_configured_model_name
from paperless_ai.embedding import get_embedding_model
@@ -67,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config):
with patch(
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
) as MockOpenAIEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOpenAIEmbedding.assert_called_once_with(
model_name="text-embedding-3-small",
api_key="test_api_key",
@@ -88,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
with patch(
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
) as MockOpenAIEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOpenAIEmbedding.assert_called_once_with(
model_name="text-embedding-3-small",
api_key="test_api_key",
@@ -109,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
mock_ai_config.return_value.llm_allow_internal_endpoints = False
with pytest.raises(ValueError, match="non-public address"):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
def test_get_embedding_model_huggingface(mock_ai_config):
@@ -121,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
with patch(
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockHuggingFaceEmbedding.assert_called_once_with(
model_name="sentence-transformers/all-MiniLM-L6-v2",
cache_folder=str(settings.DATA_DIR / "hf_cache"),
@@ -137,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config):
with patch(
"llama_index.embeddings.ollama.OllamaEmbedding",
) as MockOllamaEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOllamaEmbedding.assert_called_once_with(
model_name="embeddinggemma",
base_url="http://test-url",
@@ -155,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
with patch(
"llama_index.embeddings.ollama.OllamaEmbedding",
) as MockOllamaEmbedding:
model = get_embedding_model()
model = get_embedding_model(mock_ai_config.return_value)
MockOllamaEmbedding.assert_called_once_with(
model_name="embeddinggemma",
base_url="http://embedding-url",
@@ -173,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
mock_ai_config.return_value.llm_allow_internal_endpoints = False
with pytest.raises(ValueError, match="non-public address"):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
def test_get_embedding_model_invalid_backend(mock_ai_config):
@@ -183,55 +182,37 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
ValueError,
match="Unsupported embedding backend: INVALID_BACKEND",
):
get_embedding_model()
get_embedding_model(mock_ai_config.return_value)
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
mock_ai_config.return_value.llm_embedding_model = None
class DummyEmbedding:
def get_text_embedding(self, text):
return [0.0] * 7
with patch(
"paperless_ai.embedding.get_embedding_model",
return_value=DummyEmbedding(),
) as mock_get:
dim = get_embedding_dim()
mock_get.assert_called_once()
assert dim == 7
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
@pytest.mark.parametrize(
("backend", "expected_default"),
[
(LLMEmbeddingBackend.OPENAI_LIKE, "text-embedding-3-small"),
(LLMEmbeddingBackend.HUGGINGFACE, "sentence-transformers/all-MiniLM-L6-v2"),
(LLMEmbeddingBackend.OLLAMA, "embeddinggemma"),
],
)
def test_get_configured_model_name_falls_back_to_backend_default(
mock_ai_config,
backend,
expected_default,
):
"""When no model is explicitly configured, each backend has a distinct default."""
config = mock_ai_config.return_value
config.llm_embedding_backend = backend
config.llm_embedding_model = None
assert get_configured_model_name(config) == expected_default
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
mock_ai_config.return_value.llm_embedding_model = None
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
)
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
assert get_embedding_dim() == 11
mock_get.assert_not_called()
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
mock_ai_config.return_value.llm_embedding_model = None
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 11}),
)
with pytest.raises(
RuntimeError,
match="Embedding model changed from old to text-embedding-3-small",
):
get_embedding_dim()
def test_get_configured_model_name_explicit_overrides_default(mock_ai_config):
"""An explicit model name overrides the backend default for all backends."""
config = mock_ai_config.return_value
config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
config.llm_embedding_model = "my-custom-model"
# The backend default for OPENAI_LIKE is "text-embedding-3-small", so if
# the explicit name was ignored we'd get the wrong result.
assert get_configured_model_name(config) == "my-custom-model"
def test_build_llm_index_text(mock_document):
@@ -243,12 +224,17 @@ def test_build_llm_index_text(mock_document):
result = build_llm_index_text(mock_document)
assert "Title: Test Title" in result
assert "Filename: test_file.pdf" in result
assert "Created: 2023-01-01" in result
assert "Tags: Tag1, Tag2" in result
assert "Document Type: Invoice" in result
assert "Correspondent: Test Correspondent" in result
# Structured fields live in node.metadata for LLM context -- not body text
assert "Title: Test Title" not in result
assert "Created: 2023-01-01" not in result
assert "Tags: Tag1, Tag2" not in result
assert "Document Type: Invoice" not in result
assert "Correspondent: Test Correspondent" not in result
assert "Filename:" not in result
assert "Storage Path:" not in result
assert "Archive Serial Number:" not in result
# Fields without a metadata equivalent stay in body text
assert "Notes: Note1,Note2" in result
assert "Content:\n\nThis is the document content." in result
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
@@ -0,0 +1,134 @@
import logging
import sqlite3
import threading
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from django.conf import settings
from filelock import ReadWriteLock
from llama_index.core.schema import TextNode
from pytest_django.fixtures import SettingsWrapper
from paperless_ai import indexing
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
DIM = 8
def _node(node_id: str, document_id: str, *, seed: float = 0.0) -> TextNode:
node = TextNode(
id_=node_id,
text="chunk",
metadata={"document_id": document_id, "modified": "2026-06-01T00:00:00"},
)
node.relationships = {}
node.embedding = [seed + i / 100 for i in range(DIM)]
return node
def _seed_bloated_index(index_dir: Path) -> None:
"""Create an index whose cumulative inserts far exceed live rows."""
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
store.add([_node(f"d{j}", str(j), seed=float(j)) for j in range(20)])
for cycle in range(6):
for j in range(20):
store.upsert_document(
str(j),
[_node(f"d{j}-c{cycle}", str(j), seed=float(j))],
)
store.client.close()
def _bloat_ratio(index_dir: Path) -> float:
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
live = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
).fetchone()
total = int(row["value"]) if row else live
store.client.close()
return total / max(live, 1)
def _integrity_ok(index_dir: Path) -> bool:
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
result = store.client.execute("PRAGMA integrity_check").fetchone()[0]
rows = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
store.client.close()
return result == "ok" and rows == 20
def _reader_lock() -> ReadWriteLock:
# A distinct instance simulates a reader in another process: it coordinates
# with the production lock purely through SQLite, never reentrant upgrade.
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
class TestCompactionLock:
def test_compaction_skips_when_a_reader_holds_the_lock(
self,
temp_llm_index_dir: Path,
settings: SettingsWrapper,
caplog: pytest.LogCaptureFixture,
) -> None:
_seed_bloated_index(temp_llm_index_dir)
settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 0.3
lock = _reader_lock()
with lock.read_lock(), caplog.at_level(logging.INFO):
indexing.llm_index_compact() # must not raise
lock.close()
# Swap was skipped: bloat remains, nothing corrupted, data intact.
assert _integrity_ok(temp_llm_index_dir)
assert _bloat_ratio(temp_llm_index_dir) > 2
assert "Skipping LLM index compaction" in caplog.text
def test_compaction_runs_when_no_reader_holds_the_lock(
self,
temp_llm_index_dir: Path,
) -> None:
_seed_bloated_index(temp_llm_index_dir)
assert _bloat_ratio(temp_llm_index_dir) > 2
indexing.llm_index_compact()
assert _bloat_ratio(temp_llm_index_dir) == pytest.approx(1.0)
assert _integrity_ok(temp_llm_index_dir)
def test_normal_write_is_not_gated_by_the_compaction_lock(
self,
temp_llm_index_dir: Path,
) -> None:
"""A held exclusive lock must not block ordinary writes (WAL handles them)."""
_seed_bloated_index(temp_llm_index_dir)
done = threading.Event()
def remove() -> None:
indexing.llm_index_remove_document(MagicMock(id=999))
done.set()
holder = _reader_lock()
with holder.write_lock():
t = threading.Thread(target=remove)
t.start()
finished = done.wait(timeout=5)
t.join(timeout=2)
holder.close()
assert finished, "a normal write blocked on the compaction lock"
class TestReadStore:
def test_closes_connection_on_exit(self, temp_llm_index_dir: Path) -> None:
with indexing.read_store() as store:
conn = store.client
assert conn.execute("SELECT 1").fetchone()[0] == 1
with pytest.raises(sqlite3.ProgrammingError):
conn.execute("SELECT 1")
def test_concurrent_readers_do_not_block(self, temp_llm_index_dir: Path) -> None:
_seed_bloated_index(temp_llm_index_dir)
with indexing.read_store() as a, indexing.read_store() as b:
assert a.table_exists()
assert b.table_exists()
@@ -0,0 +1,25 @@
import subprocess
import sys
from pathlib import Path
_SRC_DIR = Path(__file__).parent.parent.parent
class TestLazyAiImports:
def test_importing_tasks_does_not_load_ai_libraries(self) -> None:
code = (
"import os, django, sys\n"
"os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'paperless.settings')\n"
"django.setup()\n"
"import documents.tasks # noqa: F401\n"
"leaked = [m for m in ('lancedb', 'pyarrow', 'llama_index', 'sqlite_vec') "
"if m in sys.modules]\n"
"assert not leaked, f'AI libraries leaked into the light path: {leaked}'\n"
)
result = subprocess.run(
[sys.executable, "-c", code],
capture_output=True,
text=True,
cwd=_SRC_DIR,
)
assert result.returncode == 0, result.stdout + result.stderr
+606
View File
@@ -0,0 +1,606 @@
import sqlite3
from collections.abc import Generator
from pathlib import Path
import pytest
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from paperless_ai.vector_store import DB_FILENAME
from paperless_ai.vector_store import DEFAULT_TABLE_NAME
from paperless_ai.vector_store import MIGRATIONS
from paperless_ai.vector_store import SCHEMA_VERSION
from paperless_ai.vector_store import Migration
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
from paperless_ai.vector_store import _build_where
DIM = 16
def make_node(
node_id: str,
document_id: str,
*,
modified: str = "2026-06-10T00:00:00",
seed: float = 0.0,
text: str = "some text",
) -> TextNode:
node = TextNode(
id_=node_id,
text=text,
metadata={"document_id": document_id, "modified": modified},
)
node.relationships = {}
node.embedding = [seed + i / 100 for i in range(DIM)]
return node
@pytest.fixture
def store(tmp_path: Path) -> Generator[PaperlessSqliteVecVectorStore, None, None]:
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store:
yield store
def _query(
store: PaperlessSqliteVecVectorStore,
embedding: list[float],
top_k: int = 5,
filters=None,
):
return store.query(
VectorStoreQuery(
query_embedding=embedding,
similarity_top_k=top_k,
filters=filters,
),
)
def _eq_filter(key: str, value: str):
return MetadataFilters(
filters=[MetadataFilter(key=key, operator=FilterOperator.EQ, value=value)],
)
def _in_filter(document_ids: list[str]):
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=document_ids,
),
],
)
class TestCrud:
def test_add_then_query_returns_node(self, store) -> None:
node = make_node("n1", "1")
assert store.add([node]) == ["n1"]
result = _query(store, node.embedding, top_k=1)
assert result.ids == ["n1"]
assert result.nodes[0].metadata["document_id"] == "1"
# cosine distance of the identical vector is 0 -> similarity 1
assert result.similarities[0] == pytest.approx(1.0)
def test_query_empty_store_returns_empty_no_raise(self, store) -> None:
result = _query(store, [0.0] * DIM)
assert result.ids == [] and result.nodes == [] and result.similarities == []
def test_add_empty_list_is_noop(self, store) -> None:
assert store.add([]) == []
assert not store.table_exists()
def test_delete_removes_all_chunks_of_document(self, store) -> None:
store.add([make_node("a1", "1"), make_node("a2", "1"), make_node("b1", "2")])
store.delete("1")
result = _query(store, [0.0] * DIM, top_k=10)
assert result.ids == ["b1"]
def test_query_with_in_filter_scopes_results(self, store) -> None:
store.add(
[
make_node("a1", "1", seed=0.0),
make_node("b1", "2", seed=1.0),
make_node("c1", "3", seed=2.0),
],
)
result = _query(store, [0.0] * DIM, top_k=10, filters=_in_filter(["2", "3"]))
assert sorted(result.ids) == ["b1", "c1"]
def test_query_respects_top_k_with_filter(self, store) -> None:
# k semantics: global top-k even with IN filters (document_id is a
# metadata column, not a partition key -- see design doc).
store.add(
[make_node(f"n{i}", str(i % 4), seed=float(i)) for i in range(12)],
)
result = _query(
store,
[0.0] * DIM,
top_k=3,
filters=_in_filter(["0", "1", "2", "3"]),
)
assert len(result.ids) == 3
assert result.similarities == sorted(result.similarities, reverse=True)
def test_get_nodes_filter_and_empty_paths(self, store) -> None:
assert store.get_nodes(filters=_in_filter(["1"])) == [] # no table yet
store.add([make_node("a1", "1"), make_node("b1", "2")])
nodes = store.get_nodes(filters=_in_filter(["1"]))
assert [n.node_id for n in nodes] == ["a1"]
assert nodes[0].embedding is not None
assert store.get_nodes(filters=_in_filter(["999"])) == []
def test_query_with_eq_filter_scopes_results(self, store) -> None:
store.add(
[
make_node("a1", "1", seed=0.0),
make_node("b1", "2", seed=1.0),
make_node("c1", "3", seed=2.0),
],
)
result = _query(
store,
[0.0] * DIM,
top_k=10,
filters=_eq_filter("document_id", "2"),
)
assert result.ids == ["b1"]
def test_get_nodes_node_ids_not_implemented(self, store) -> None:
with pytest.raises(NotImplementedError):
store.get_nodes(node_ids=["x"])
def test_fresh_instance_sees_existing_table(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
assert reopened.table_exists()
assert reopened.vector_dim() == DIM
assert _query(reopened, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_table_exists_and_drop(self, store) -> None:
assert not store.table_exists()
store.add([make_node("a1", "1")])
assert store.table_exists()
store.drop_table()
assert not store.table_exists()
assert store.vector_dim() is None
class TestBuildWhere:
def test_fails_closed_when_no_filter_is_translatable(self) -> None:
# A nested MetadataFilters is not a MetadataFilter, so it is skipped.
# With no translatable clauses, the function must fail closed rather
# than emit "()" (invalid SQL) and never widen document access.
nested = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="1",
),
],
)
where, params = _build_where(MetadataFilters(filters=[nested]))
assert where == "1 = 0"
assert params == []
def test_query_with_untranslatable_filter_returns_no_rows(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
nested = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.EQ,
value="1",
),
],
)
filters = MetadataFilters(filters=[nested])
# Must not raise (no "WHERE ()") and must return nothing (fail closed).
assert _query(store, [0.0] * DIM, top_k=5, filters=filters).ids == []
assert store.get_nodes(filters=filters) == []
class TestUpsert:
def test_upsert_replaces_and_prunes_stale_chunks(self, store) -> None:
store.add(
[make_node("d1c1", "1"), make_node("d1c2", "1"), make_node("d2c1", "2")],
)
store.upsert_document("1", [make_node("d1new", "1")])
result = _query(store, [0.0] * DIM, top_k=10)
assert sorted(result.ids) == ["d1new", "d2c1"]
def test_upsert_creates_table_when_missing(self, store) -> None:
store.upsert_document("1", [make_node("a1", "1")])
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_upsert_empty_nodes_removes_document(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
store.upsert_document("1", [])
assert _query(store, [0.0] * DIM, top_k=10).ids == ["b1"]
def test_upsert_is_atomic_for_concurrent_readers(
self,
store,
tmp_path: Path,
) -> None:
"""A second connection must never observe document 1 half-replaced."""
store.add([make_node("a1", "1"), make_node("a2", "1")])
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reader:
store.upsert_document("1", [make_node("a3", "1")])
ids = [n.node_id for n in reader.get_nodes(filters=_in_filter(["1"]))]
assert ids == ["a3"]
class TestMetadataCoercion:
def test_none_metadata_values_become_empty_strings(self, store) -> None:
node = make_node("a1", "1")
node.metadata["modified"] = None
store.add([node]) # must not raise (vec0 rejects NULL metadata)
assert store.get_modified_times() == {"1": ""}
class TestModelNameTracking:
def test_stored_model_name_none_without_table(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
assert store.stored_model_name() is None
def test_model_name_stored_after_add_and_persists(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
store.add([make_node("a1", "1")])
assert store.stored_model_name() == "model-a"
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
assert reopened.stored_model_name() == "model-a"
def test_config_mismatch_semantics(self, tmp_path: Path) -> None:
with PaperlessSqliteVecVectorStore(
uri=str(tmp_path),
embed_model_name="model-a",
) as store:
assert not store.config_mismatch("anything") # no table yet
store.add([make_node("a1", "1")])
assert not store.config_mismatch("model-a")
assert store.config_mismatch("model-b")
def test_config_mismatch_false_when_table_predates_tracking(
self,
tmp_path: Path,
) -> None:
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store: # no model name
store.add([make_node("a1", "1")])
assert not store.config_mismatch("model-a")
class TestGetModifiedTimes:
def test_empty_store_returns_empty_dict(self, store) -> None:
assert store.get_modified_times() == {}
def test_returns_one_entry_per_document(self, store) -> None:
store.add(
[
make_node("a1", "1", modified="2026-01-01T00:00:00"),
make_node("a2", "1", modified="2026-01-01T00:00:00"),
make_node("b1", "2", modified="2026-02-02T00:00:00"),
],
)
assert store.get_modified_times() == {
"1": "2026-01-01T00:00:00",
"2": "2026-02-02T00:00:00",
}
class TestCompact:
def _bloat_ratio(self, store) -> float:
live = store.client.execute(
"SELECT count(*) FROM documents",
).fetchone()[0]
# vec0 0.1.9 does not accumulate deleted rows in the _rowids shadow
# table, so we track cumulative inserts in index_meta instead.
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
).fetchone()
total = int(row["value"]) if row else live
return total / max(live, 1)
def _churn(self, store, cycles: int) -> None:
for i in range(cycles):
store.upsert_document(
"1",
[make_node(f"gen{i}-{j}", "1", seed=float(j)) for j in range(20)],
)
def test_compact_noop_below_threshold(self, store) -> None:
store.add([make_node("a1", "1")])
store.compact()
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
def test_force_compact_preserves_rows_and_metadata(self, store) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2", seed=3.0)])
self._churn(store, 5)
before = {
n.node_id: n.metadata
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
}
store.compact(force=True)
after = {
n.node_id: n.metadata
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
}
assert after == before
assert self._bloat_ratio(store) == pytest.approx(1.0)
# store remains fully usable after the rebuild; use a seed far from all
# existing nodes (gen4-0..gen4-19 have seeds 0..19) so cosine KNN is
# unambiguous at top_k=1.
store.upsert_document("3", [make_node("c1", "3", seed=100.0)])
assert "c1" in _query(store, [100.0] * DIM, top_k=1).ids
def test_auto_compact_triggers_on_churn(self, store) -> None:
store.add([make_node(f"s{j}", "1", seed=float(j)) for j in range(20)])
self._churn(store, 5)
assert self._bloat_ratio(store) > 2
store.compact()
assert self._bloat_ratio(store) == pytest.approx(1.0)
def test_compact_on_missing_table_is_noop(self, store) -> None:
store.compact()
store.compact(force=True)
def test_failed_compact_removes_temp_wal_and_shm(
self,
store,
tmp_path: Path,
monkeypatch,
) -> None:
"""A compact() that raises mid-rebuild must leave no .compact* files.
Normally the sole connection's close() checkpoints the temp WAL away,
but a concurrent reader keeps -wal/-shm alive, so the cleanup must
unlink them explicitly (as the structural-migration path does).
"""
store.add([make_node("a1", "1")])
compact_path = str(tmp_path / DB_FILENAME) + ".compact"
held: list[sqlite3.Connection] = []
def boom(conn: sqlite3.Connection, dim: int) -> None:
# Hold an extra connection so close() of the rebuild connection is
# not the last one -> the temp -wal/-shm survive the checkpoint.
extra = sqlite3.connect(compact_path)
extra.execute("SELECT 1").fetchall()
held.append(extra)
raise RuntimeError("boom")
monkeypatch.setattr(
PaperlessSqliteVecVectorStore,
"_create_vec_table",
staticmethod(boom),
)
try:
with pytest.raises(RuntimeError):
store.compact(force=True)
assert sorted(p.name for p in tmp_path.glob("*.compact*")) == []
finally:
for c in held:
c.close()
class TestDbFile:
def test_single_db_file_in_index_dir(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
assert (tmp_path / DB_FILENAME).exists()
def test_wal_mode_enabled(self, store) -> None:
assert (
store.client.execute("PRAGMA journal_mode").fetchone()[0].lower() == "wal"
)
class TestMigrations:
"""Tests for the schema migration machinery."""
def _schema_version(self, store: PaperlessSqliteVecVectorStore) -> int | None:
row = store.client.execute(
"SELECT value FROM index_meta WHERE key = 'schema_version'",
).fetchone()
return int(row[0]) if row else None
def test_new_table_records_schema_version(self, store) -> None:
store.add([make_node("a1", "1")])
assert self._schema_version(store) == SCHEMA_VERSION
def test_check_migrations_no_table_returns_false(self, store) -> None:
assert store.check_and_run_migrations() is False
def test_check_migrations_current_version_returns_false(self, store) -> None:
store.add([make_node("a1", "1")])
assert store.check_and_run_migrations() is False
def test_reembed_migration_returns_true(self, store, tmp_path: Path) -> None:
store.add([make_node("a1", "1")])
migration = Migration(
from_version=1,
to_version=2,
kind="re-embed",
description="test re-embed",
)
MIGRATIONS.append(migration)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 2
result = store.check_and_run_migrations()
finally:
MIGRATIONS.remove(migration)
vs_mod.SCHEMA_VERSION = original
assert result is True
def test_structural_migration_copies_rows_and_updates_version(
self,
store,
tmp_path: Path,
) -> None:
store.add([make_node("a1", "1"), make_node("b1", "2")])
def apply(
src: sqlite3.Connection,
dst: sqlite3.Connection,
dim: int,
) -> None:
dst.execute( # nosemgrep
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
")",
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(dim),),
)
rows = src.execute(
"SELECT id, document_id, modified, node_content, embedding "
f"FROM {DEFAULT_TABLE_NAME}",
).fetchall()
dst.execute("BEGIN IMMEDIATE")
dst.executemany(
f"INSERT INTO {DEFAULT_TABLE_NAME} "
"(id, document_id, modified, node_content, embedding) "
"VALUES (?, ?, ?, ?, ?)",
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('total_inserts', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(len(rows)),),
)
dst.execute("COMMIT")
migration = Migration(
from_version=1,
to_version=2,
kind="structural",
description="test structural",
apply=apply,
)
MIGRATIONS.append(migration)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 2
result = store.check_and_run_migrations()
finally:
MIGRATIONS.remove(migration)
vs_mod.SCHEMA_VERSION = original
assert result is False
assert self._schema_version(store) == 2
ids = {n.node_id for n in store.get_nodes()}
assert ids == {"a1", "b1"}
def test_compact_preserves_schema_version(self, store) -> None:
store.add([make_node("a1", "1")])
assert self._schema_version(store) == SCHEMA_VERSION
store.compact(force=True)
assert self._schema_version(store) == SCHEMA_VERSION
def test_stop_at_reembed_boundary(self, store) -> None:
# Registry: structural v2, re-embed v3, structural v4.
# Only v2 should apply; the re-embed boundary must stop execution
# before v4 runs, and the stored version must stay at 2.
store.add([make_node("a1", "1"), make_node("b1", "2")])
def copy_apply(
src: sqlite3.Connection,
dst: sqlite3.Connection,
dim: int,
) -> None:
dst.execute( # nosemgrep
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
")",
)
dst.execute(
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(str(dim),),
)
rows = src.execute(
"SELECT id, document_id, modified, node_content, embedding "
f"FROM {DEFAULT_TABLE_NAME}",
).fetchall()
dst.execute("BEGIN IMMEDIATE")
dst.executemany(
f"INSERT INTO {DEFAULT_TABLE_NAME} "
"(id, document_id, modified, node_content, embedding) "
"VALUES (?, ?, ?, ?, ?)",
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
dst.execute("COMMIT")
migrations = [
Migration(
from_version=1,
to_version=2,
kind="structural",
description="v2 structural",
apply=copy_apply,
),
Migration(
from_version=2,
to_version=3,
kind="re-embed",
description="v3 re-embed boundary",
),
Migration(
from_version=3,
to_version=4,
kind="structural",
description="v4 structural - must not run",
apply=copy_apply,
),
]
MIGRATIONS.extend(migrations)
try:
from paperless_ai import vector_store as vs_mod
original = vs_mod.SCHEMA_VERSION
vs_mod.SCHEMA_VERSION = 4
result = store.check_and_run_migrations()
finally:
for m in migrations:
MIGRATIONS.remove(m)
vs_mod.SCHEMA_VERSION = original
assert result is True
assert self._schema_version(store) == 2
+604
View File
@@ -0,0 +1,604 @@
import json
import logging
import sqlite3
import struct
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Literal
import sqlite_vec
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.vector_stores.types import FilterCondition
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.core.vector_stores.utils import metadata_dict_to_node
from llama_index.core.vector_stores.utils import node_to_metadata_dict
logger = logging.getLogger("paperless_ai.vector_store")
DB_FILENAME = "llmindex.db"
DEFAULT_TABLE_NAME = "documents"
# Current schema version. Written to index_meta at table creation and bumped
# whenever a Migration is added to MIGRATIONS. check_and_run_migrations() uses
# this to decide which migrations to run on an existing store.
SCHEMA_VERSION = 1
# compact(): rebuild when the cumulative rowid count exceeds this multiple of
# the live row count. DELETEs on vec0 tables never reclaim space (upstream
# asg017/sqlite-vec#54), so per-document re-index churn grows the file until
# a rebuild copies the live rows into a fresh table.
COMPACT_BLOAT_RATIO = 2.0
# Filterable vec0 metadata columns. _build_where() only ever receives filter
# keys we construct ourselves, but allowlisting keeps SQL identifiers safe by
# construction.
_FILTER_COLUMNS = frozenset({"document_id", "modified"})
@dataclass
class Migration:
"""A schema migration for the sqlite-vec vector store.
kind="structural": rows are copied into a new-schema file with no
re-embedding needed. Supply ``apply(src_conn, dst_conn, dim)`` which
must create the vec0 table in ``dst_conn``, copy all rows from
``src_conn``, and write ``dim`` / ``embed_model`` / ``total_inserts`` to
``dst_conn``'s ``index_meta``. ``schema_version`` is written by the
migration runner after ``apply`` returns.
kind="re-embed": the new schema requires fresh embeddings.
``check_and_run_migrations()`` returns True when it encounters one of
these so the caller can force a full rebuild (which recreates the table
at the current SCHEMA_VERSION).
"""
from_version: int
to_version: int
kind: Literal["structural", "re-embed"]
description: str
apply: Callable[[sqlite3.Connection, sqlite3.Connection, int], None] | None = field(
default=None,
repr=False,
)
# Registry of all schema migrations in order. Empty at v1 -- this is the
# baseline. Add entries here (and bump SCHEMA_VERSION) when the schema changes.
MIGRATIONS: list[Migration] = []
def _pack(embedding: Sequence[float]) -> bytes:
return struct.pack(f"{len(embedding)}f", *embedding)
def _unpack(blob: bytes) -> list[float]:
return list(struct.unpack(f"{len(blob) // 4}f", blob))
def _build_where(filters: MetadataFilters | None) -> tuple[str, list[str]]:
"""Translate the EQ / IN filters we use into a parameterized SQL clause
on vec0 metadata columns. Returns ("", []) when there is nothing to filter.
"""
if filters is None or not filters.filters:
return "", []
clauses: list[str] = []
params: list[str] = []
for f in filters.filters:
# filters.filters is Union[MetadataFilter, ExactMatchFilter, MetadataFilters];
# we only build MetadataFilter entries, so skip anything else at runtime.
if not isinstance(f, MetadataFilter):
continue
if f.key not in _FILTER_COLUMNS: # pragma: no cover - we build the keys
raise NotImplementedError(f"Unsupported filter column: {f.key}")
if f.operator == FilterOperator.IN:
values = [str(v) for v in f.value] # type: ignore[union-attr] # value is list when operator is IN
if not values: # pragma: no cover
clauses.append("1 = 0")
continue
placeholders = ",".join("?" for _ in values)
clauses.append(f"{f.key} IN ({placeholders})")
params.extend(values)
elif f.operator == FilterOperator.EQ:
clauses.append(f"{f.key} = ?")
params.append(str(f.value))
else: # pragma: no cover - we only ever build EQ/IN filters
raise NotImplementedError(f"Unsupported filter operator: {f.operator}")
if not clauses:
# Filters were requested but none could be translated. Fail closed
# rather than emit "()" (invalid SQL): filters scope document access,
# so an empty translation must match no rows, never widen the scope.
return "1 = 0", []
joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
return "(" + joiner.join(clauses) + ")", params
class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
"""A llama-index vector store backed by a sqlite-vec vec0 table.
Stores one row per node: the node id (TEXT primary key), its document id
(metadata column, used for EQ/IN filtering and per-document delete), the
document's modified timestamp, the embedding (float32, cosine metric), and
the serialized node (text + metadata) as JSON in an auxiliary column.
``stores_text`` lets llama-index run off this store alone, with no
separate docstore or index store.
Everything lives in one SQLite database file (``DB_FILENAME``) inside the
directory given as ``uri`` (kept as a directory for compatibility with the
previous LanceDB layout). WAL mode allows readers in other processes to
proceed while the (FileLock-serialized) writer holds a transaction.
Implemented surface of ``BasePydanticVectorStore``
---------------------------------------------------
Only the methods actively used by this codebase are implemented.
``delete_nodes`` and the ``node_ids`` lookup path of ``get_nodes`` are
part of the llama-index interface contract and may be needed if a future
retriever or extension invokes them add them then, with tests.
"""
stores_text: bool = True
flat_metadata: bool = False
_uri: str = PrivateAttr()
_embed_model_name: str | None = PrivateAttr()
_conn: Any = PrivateAttr()
def __init__(
self,
uri: str,
embed_model_name: str | None = None,
) -> None:
super().__init__(stores_text=True, flat_metadata=False)
self._uri = uri
self._embed_model_name = embed_model_name
self._conn = self._open_connection(str(Path(uri) / DB_FILENAME))
@staticmethod
def _open_connection(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(
db_path,
timeout=30,
isolation_level=None, # autocommit; explicit transactions below
)
conn.row_factory = sqlite3.Row
conn.enable_load_extension(True) # noqa: FBT003
sqlite_vec.load(conn)
conn.enable_load_extension(False) # noqa: FBT003
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute(
"CREATE TABLE IF NOT EXISTS index_meta (key TEXT PRIMARY KEY, value TEXT)",
)
return conn
@property
def client(self) -> Any:
return self._conn
def close(self) -> None:
"""Close the underlying SQLite connection (idempotent)."""
self._conn.close()
def __enter__(self) -> "PaperlessSqliteVecVectorStore":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
# Deterministically release the connection (and its WAL/SHM handles) so
# it is never left open across a compaction/migration file swap.
self.close()
@contextmanager
def _transaction(self) -> Iterator[None]:
self._conn.execute("BEGIN IMMEDIATE")
try:
yield
except BaseException: # pragma: no cover
self._conn.execute("ROLLBACK")
raise
else:
self._conn.execute("COMMIT")
def _meta_get(self, key: str) -> str | None:
row = self._conn.execute(
"SELECT value FROM index_meta WHERE key = ?",
(key,),
).fetchone()
return row["value"] if row else None
@staticmethod
def _meta_set_on(conn: sqlite3.Connection, key: str, value: str) -> None:
conn.execute(
"INSERT INTO index_meta (key, value) VALUES (?, ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
(key, value),
)
def _meta_set(self, key: str, value: str) -> None:
self._meta_set_on(self._conn, key, value)
def table_exists(self) -> bool:
return (
self._conn.execute(
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
(DEFAULT_TABLE_NAME,),
).fetchone()
is not None
)
def vector_dim(self) -> int | None:
if not self.table_exists():
return None
value = self._meta_get("dim")
return int(value) if value else None
def drop_table(self) -> None:
self._conn.execute("DROP TABLE IF EXISTS " + DEFAULT_TABLE_NAME)
self._conn.execute("DELETE FROM index_meta")
def stored_model_name(self) -> str | None:
"""Return the embedding model name recorded at table creation, or None."""
if not self.table_exists():
return None
return self._meta_get("embed_model")
def config_mismatch(self, model_name: str) -> bool:
"""True when the stored model name differs from ``model_name``.
Returns False when no table exists or when the table predates
model-name tracking conservative default avoids spurious rebuilds.
"""
stored = self.stored_model_name()
if stored is None:
return False
return stored != model_name
@staticmethod
def _create_vec_table(conn: sqlite3.Connection, dim: int) -> None:
# document_id is deliberately a metadata column, NOT a partition key:
# partition keys change KNN `k` to per-partition semantics under IN
# filters (asg017/sqlite-vec#142); metadata columns give a correct
# global top-k.
conn.execute( # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query
"CREATE VIRTUAL TABLE "
+ DEFAULT_TABLE_NAME
+ " USING vec0("
+ "id TEXT PRIMARY KEY,"
+ " document_id TEXT,"
+ " modified TEXT,"
+ " +node_content TEXT,"
+ " embedding float["
+ str(int(dim))
+ "] distance_metric=cosine"
+ ")",
)
def _create_table(self, dim: int) -> None:
self._create_vec_table(self._conn, dim)
self._meta_set("dim", str(dim))
self._meta_set("schema_version", str(SCHEMA_VERSION))
if self._embed_model_name:
self._meta_set("embed_model", self._embed_model_name)
def _ensure_table(self, dim: int) -> None:
if not self.table_exists():
self._create_table(dim)
def _row(self, node: BaseNode) -> tuple[str, str, str, str, bytes]:
meta = node_to_metadata_dict(
node,
remove_text=False,
flat_metadata=self.flat_metadata,
)
# vec0 metadata columns reject NULL (asg017/sqlite-vec#141): coerce
# every value to a string, with "" as the absent sentinel.
document_id = node.ref_doc_id or node.metadata.get("document_id")
return (
node.node_id,
str(document_id or ""),
str(node.metadata.get("modified") or ""),
json.dumps(meta),
_pack(node.get_embedding()),
)
_INSERT = (
"INSERT INTO "
+ DEFAULT_TABLE_NAME
+ " (id, document_id, modified, node_content, embedding) VALUES (?, ?, ?, ?, ?)"
)
def _increment_total_inserts(self, count: int) -> None:
"""Increment the cumulative insert counter stored in index_meta.
This counter never decreases (DELETEs do not decrement it) and is
used by compact() to estimate the bloat ratio: when total_inserts /
live_rows exceeds COMPACT_BLOAT_RATIO the table has accumulated
enough deleted-but-not-freed rows to warrant a rebuild.
"""
current = int(self._meta_get("total_inserts") or "0")
self._meta_set("total_inserts", str(current + count))
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
if not nodes:
return []
rows = [self._row(node) for node in nodes]
with self._transaction():
self._ensure_table(len(nodes[0].get_embedding()))
self._conn.executemany(self._INSERT, rows)
self._increment_total_inserts(len(rows))
return [node.node_id for node in nodes]
def upsert_document(self, document_id: str, nodes: list[BaseNode]) -> list[str]:
"""Atomically replace all stored chunks of ``document_id`` with ``nodes``.
One transaction deletes the document's existing rows and inserts the
new set (vec0's INSERT OR REPLACE is broken upstream, #259, so
delete+insert it is). WAL readers in other processes see either the
old or the new chunk set, never a partial state.
"""
rows = [self._row(node) for node in nodes]
with self._transaction():
if nodes:
self._ensure_table(len(nodes[0].get_embedding()))
if self.table_exists():
self._conn.execute(
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
(str(document_id),),
)
if rows:
self._conn.executemany(self._INSERT, rows)
self._increment_total_inserts(len(rows))
return [node.node_id for node in nodes]
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
if self.table_exists():
with self._transaction():
self._conn.execute(
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
(str(ref_doc_id),),
)
def _rows_to_nodes(self, rows: list[sqlite3.Row]) -> list[BaseNode]:
nodes: list[BaseNode] = []
for row in rows:
node = metadata_dict_to_node(json.loads(row["node_content"]))
node.embedding = _unpack(row["embedding"])
nodes.append(node)
return nodes
def get_nodes(
self,
node_ids: list[str] | None = None,
filters: MetadataFilters | None = None,
**kwargs: Any,
) -> list[BaseNode]:
if node_ids is not None: # pragma: no cover
# node_ids lookup is not implemented; see class docstring.
raise NotImplementedError(
"PaperlessSqliteVecVectorStore does not support node_ids lookup",
)
if not self.table_exists():
return []
where, params = _build_where(filters)
sql = "SELECT node_content, embedding FROM " + DEFAULT_TABLE_NAME
if where:
sql += " WHERE " + where
return self._rows_to_nodes(self._conn.execute(sql, params).fetchall())
def query(
self,
query: VectorStoreQuery,
**kwargs: Any,
) -> VectorStoreQueryResult:
if not self.table_exists():
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
if query.query_embedding is None: # pragma: no cover
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
top_k = query.similarity_top_k if query.similarity_top_k is not None else 10
where, params = _build_where(query.filters)
sql = (
"SELECT id, node_content, embedding, distance FROM "
+ DEFAULT_TABLE_NAME
+ " WHERE embedding MATCH ? AND k = ?"
)
if where:
sql += " AND " + where
rows = self._conn.execute(
sql,
[_pack(query.query_embedding), top_k, *params],
).fetchall()
# vec0 returns rows distance-sorted ascending; slice defensively in
# case future schema changes alter k semantics (e.g. partition keys
# return k rows per partition).
rows = rows[:top_k]
nodes = self._rows_to_nodes(rows)
# Cosine distance in [0, 2]; map to a descending similarity.
# vec0 returns None distance when the query embedding is the zero vector
# (no meaningful cosine angle); treat that as maximum distance (1.0) so
# the row is included but ranked last.
sims = [
1.0 - float(row["distance"] if row["distance"] is not None else 1.0)
for row in rows
]
ids = [row["id"] for row in rows]
return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)
def get_modified_times(self) -> dict[str, str]:
"""Return {document_id: stored_modified_isoformat} for all indexed documents.
All chunks of a document share the same ``modified`` value, so the
first row seen per document is sufficient.
"""
if not self.table_exists():
return {}
result: dict[str, str] = {}
for row in self._conn.execute(
"SELECT document_id, modified FROM " + DEFAULT_TABLE_NAME,
):
doc_id = str(row["document_id"])
if doc_id not in result:
result[doc_id] = str(row["modified"] or "")
return result
def compact(self, *, force: bool = False) -> None:
"""Rebuild the database file to reclaim space left behind by DELETEs.
vec0 DELETE only invalidates rows; the vector data stays in the file
forever (asg017/sqlite-vec#54), and per-document re-indexing is a
delete+insert. The cumulative insert counter in ``index_meta`` tracks
total rows ever written; when that exceeds ``COMPACT_BLOAT_RATIO`` x
the live row count (or when forced), live rows are copied into a fresh
database file and swapped in via ``os.replace``.
Note: ``ALTER TABLE ... RENAME TO`` on vec0 virtual tables does NOT
rename the shadow tables (sqlite-vec upstream limitation), so
an in-place rename-based rebuild is not safe. The file-swap approach
is the maintainer-endorsed workaround (asg017/sqlite-vec#205).
"""
if not self.table_exists():
return
live = self._conn.execute(
"SELECT count(*) FROM " + DEFAULT_TABLE_NAME,
).fetchone()[0]
total = int(self._meta_get("total_inserts") or str(live))
if not force and total <= max(live, 1) * COMPACT_BLOAT_RATIO:
return
dim = self.vector_dim()
if dim is None: # pragma: no cover - dim is written at creation
logger.warning("Skipping compact: no stored vector dimension")
return
logger.info(
"Compacting LLM index (%d live rows, %d cumulative inserts)",
live,
total,
)
db_path = str(Path(self._uri) / DB_FILENAME)
compact_path = db_path + ".compact"
# Copy all live rows into a fresh database file.
new_conn = self._open_connection(compact_path)
try:
self._create_vec_table(new_conn, dim)
self._meta_set_on(new_conn, "dim", str(dim))
for key in ("embed_model", "schema_version"):
value = self._meta_get(key)
if value is not None:
self._meta_set_on(new_conn, key, value)
rows = self._conn.execute(
"SELECT id, document_id, modified, node_content, embedding "
"FROM " + DEFAULT_TABLE_NAME,
).fetchall()
new_conn.execute("BEGIN IMMEDIATE")
new_conn.executemany(
self._INSERT,
[
(
r["id"],
r["document_id"],
r["modified"],
r["node_content"],
bytes(r["embedding"]),
)
for r in rows
],
)
# Reset the cumulative counter: after compact, total_inserts == live.
self._meta_set_on(new_conn, "total_inserts", str(live))
new_conn.execute("COMMIT")
except BaseException:
new_conn.close()
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
Path(p).unlink(missing_ok=True)
raise
new_conn.close()
self._swap_in_compact(compact_path, db_path)
def _swap_in_compact(self, compact_path: str, db_path: str) -> None:
"""Atomically replace the live database with the compacted copy."""
self._conn.close()
for suffix in ["-wal", "-shm"]:
stale = Path(compact_path + suffix)
if stale.exists(): # pragma: no cover
stale.unlink()
Path(compact_path).replace(db_path)
self._conn = self._open_connection(db_path)
def check_and_run_migrations(self) -> bool:
"""Apply any pending schema migrations to the store.
Structural migrations copy live rows into a new-schema file with no
re-embedding. Re-embed migrations cannot be applied automatically;
this method returns True when one is encountered so the caller can
force a full rebuild (which recreates the table at SCHEMA_VERSION).
Must be called under the write FileLock. No-op when the table does
not exist or is already at SCHEMA_VERSION.
"""
if not self.table_exists():
return False
raw = self._meta_get("schema_version")
current = int(raw) if raw is not None else SCHEMA_VERSION
if current >= SCHEMA_VERSION:
return False
pending = sorted(
[m for m in MIGRATIONS if current <= m.from_version < SCHEMA_VERSION],
key=lambda m: m.from_version,
)
for migration in pending:
if migration.kind == "re-embed":
logger.warning(
"LLM index schema v%d -> v%d requires re-embedding (%s); "
"forcing full rebuild.",
migration.from_version,
migration.to_version,
migration.description,
)
return True
logger.info(
"Running structural LLM index migration v%d -> v%d: %s",
migration.from_version,
migration.to_version,
migration.description,
)
self._run_structural_migration(migration)
return False
def _run_structural_migration(self, migration: Migration) -> None:
"""Execute a structural migration using the same file-swap as compact()."""
assert migration.apply is not None, "structural migration must have apply()"
dim = self.vector_dim()
if dim is None: # pragma: no cover
raise RuntimeError("Cannot migrate: no stored vector dimension")
db_path = str(Path(self._uri) / DB_FILENAME)
compact_path = db_path + ".compact"
new_conn = self._open_connection(compact_path)
try:
migration.apply(self._conn, new_conn, dim)
self._meta_set_on(new_conn, "schema_version", str(migration.to_version))
except BaseException: # pragma: no cover
new_conn.close()
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
Path(p).unlink(missing_ok=True)
raise
new_conn.close()
self._swap_in_compact(compact_path, db_path)
+14 -8
View File
@@ -4,6 +4,8 @@ import logging
import ssl
import tempfile
import traceback
import unicodedata
from datetime import date
from datetime import timedelta
from fnmatch import fnmatch
from pathlib import Path
@@ -384,7 +386,7 @@ def make_criterias(rule: MailRule, *, supports_gmail_labels: bool):
Returns criteria to be applied to MailBox.fetch for the given rule.
"""
maximum_age = timezone.now().date() - timedelta(days=rule.maximum_age)
maximum_age = date.today() - timedelta(days=rule.maximum_age)
criterias = {}
if rule.maximum_age > 0:
criterias["date_gte"] = maximum_age
@@ -495,10 +497,10 @@ class MailAccountHandler(LoggingMixin):
rule: MailRule,
) -> str | None:
if rule.assign_title_from == MailRule.TitleSource.FROM_SUBJECT:
return message.subject
return unicodedata.normalize("NFC", message.subject)
elif rule.assign_title_from == MailRule.TitleSource.FROM_FILENAME:
return Path(att.filename).stem
return unicodedata.normalize("NFC", Path(att.filename).stem)
elif rule.assign_title_from == MailRule.TitleSource.NONE:
return None
@@ -636,8 +638,8 @@ class MailAccountHandler(LoggingMixin):
self.log.info(f"Located folder: {folder_info.name}")
except Exception as e:
self.log.error(
"Exception during folder listing, unable to provide list folders: %s",
e,
"Exception during folder listing, unable to provide list folders: "
+ str(e),
)
raise MailError(
@@ -865,7 +867,9 @@ class MailAccountHandler(LoggingMixin):
),
)
attachment_name = pathvalidate.sanitize_filename(att.filename)
attachment_name = pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", att.filename),
)
if attachment_name:
temp_filename = temp_dir / attachment_name
else: # pragma: no cover
@@ -881,7 +885,7 @@ class MailAccountHandler(LoggingMixin):
)
doc_overrides = DocumentMetadataOverrides(
title=title,
filename=pathvalidate.sanitize_filename(att.filename),
filename=attachment_name,
correspondent_id=correspondent.id if correspondent else None,
document_type_id=doc_type.id if doc_type else None,
tag_ids=tag_ids,
@@ -987,7 +991,9 @@ class MailAccountHandler(LoggingMixin):
)
doc_overrides = DocumentMetadataOverrides(
title=message.subject,
filename=pathvalidate.sanitize_filename(f"{message.subject}.eml"),
filename=pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", f"{message.subject}.eml"),
),
correspondent_id=correspondent.id if correspondent else None,
document_type_id=doc_type.id if doc_type else None,
tag_ids=tag_ids,
+2 -4
View File
@@ -349,10 +349,9 @@ class MailMocker(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
len(expected_call_args),
)
for (_, mock_kwargs), expected_signatures in zip(
for (mock_args, mock_kwargs), expected_signatures in zip(
self._queue_consumption_tasks_mock.call_args_list,
expected_call_args,
strict=False,
):
consume_tasks = mock_kwargs["consume_tasks"]
@@ -362,7 +361,6 @@ class MailMocker(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
for consume_task, expected_signature in zip(
consume_tasks,
expected_signatures,
strict=False,
):
input_doc = consume_task.kwargs["input_doc"]
overrides = consume_task.kwargs["overrides"]
@@ -385,7 +383,7 @@ class MailMocker(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
"""
Applies pending actions to mails by inspecting calls to the queue_consumption_tasks method.
"""
for _, kwargs in self._queue_consumption_tasks_mock.call_args_list:
for args, kwargs in self._queue_consumption_tasks_mock.call_args_list:
message = kwargs["message"]
rule = kwargs["rule"]
apply_mail_action([], rule.pk, message.uid, message.subject, message.date)
+182
View File
@@ -0,0 +1,182 @@
"""
Tests that mail attachment filenames and EML subject filenames are
normalized to NFC Unicode before being stored as document overrides.
Filenames from MIME headers can arrive in NFD form (e.g. from macOS Mail),
and must be normalized to NFC so filenames are consistent regardless of the
sending client.
"""
import unicodedata
from pathlib import Path
from unittest import mock
import pytest
from documents.tests.utils import remove_dirs
from documents.tests.utils import setup_directories
from paperless_mail.models import MailRule
from paperless_mail.tests.factories import MailAccountFactory
from paperless_mail.tests.test_mail import MessageBuilder
from paperless_mail.tests.test_mail import _AttachmentDef
from paperless_mail.tests.test_mail import fake_magic_from_buffer
@pytest.fixture()
def directories(settings):
dirs = setup_directories()
yield dirs
remove_dirs(dirs)
@pytest.fixture()
def queue_consumption_tasks_mock():
with mock.patch("paperless_mail.mail.queue_consumption_tasks") as m:
yield m
@pytest.fixture()
def mail_account(db):
return MailAccountFactory()
@pytest.fixture()
def attachment_rule(mail_account):
rule = MailRule(
name="attachment rule",
account=mail_account,
assign_title_from=MailRule.TitleSource.FROM_FILENAME,
consumption_scope=MailRule.ConsumptionScope.ATTACHMENTS_ONLY,
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
)
rule.save()
return rule
@pytest.fixture()
def eml_rule(mail_account):
rule = MailRule(
name="eml rule",
account=mail_account,
assign_title_from=MailRule.TitleSource.FROM_SUBJECT,
consumption_scope=MailRule.ConsumptionScope.EML_ONLY,
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
)
rule.save()
return rule
@pytest.fixture()
def message_builder():
return MessageBuilder()
@pytest.mark.django_db
@mock.patch("paperless_mail.mail.magic.from_buffer", fake_magic_from_buffer)
class TestMailNFCNormalization:
"""Attachment filenames and EML subject filenames must be NFC-normalized."""
def test_attachment_nfd_filename_normalized_to_nfc(
self,
directories,
queue_consumption_tasks_mock,
attachment_rule,
mail_account_handler,
message_builder,
):
"""Attachment filename arriving as NFD must be stored as NFC in both
the overrides and the temp file written to disk.
"""
nfd_filename = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc_filename = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Confirm the fixture is actually NFD (not already NFC)
assert unicodedata.is_normalized("NFD", nfd_filename)
assert not unicodedata.is_normalized("NFC", nfd_filename)
message = message_builder.create_message(
subject="Test invoice",
from_="sender@example.com",
attachments=[
_AttachmentDef(filename=nfd_filename, content=b"%PDF-1.4 test"),
],
)
result = mail_account_handler._handle_message(message, attachment_rule)
assert result == 1
queue_consumption_tasks_mock.assert_called_once()
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
assert len(consume_tasks) == 1
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_filename
assert unicodedata.is_normalized("NFC", overrides.filename)
assert unicodedata.is_normalized("NFC", overrides.title)
input_doc = consume_tasks[0].kwargs["input_doc"]
original_file = Path(input_doc.original_file)
assert original_file.exists()
assert original_file.name == nfc_filename
def test_eml_subject_filename_nfc(
self,
directories,
queue_consumption_tasks_mock,
eml_rule,
mail_account_handler,
message_builder,
):
"""EML filename derived from subject arriving as NFD must be stored as NFC."""
nfd_subject = unicodedata.normalize("NFD", "Rechnung März 2024")
nfc_expected_filename = unicodedata.normalize("NFC", "Rechnung März 2024.eml")
# Confirm the fixture is actually NFD
assert unicodedata.is_normalized("NFD", nfd_subject)
message = message_builder.create_message(
subject=nfd_subject,
from_="sender@example.com",
attachments=0,
)
mail_account_handler._handle_message(message, eml_rule)
queue_consumption_tasks_mock.assert_called_once()
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
assert len(consume_tasks) == 1
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_expected_filename
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_attachment_filename_unchanged(
self,
directories,
queue_consumption_tasks_mock,
attachment_rule,
mail_account_handler,
message_builder,
):
"""An attachment filename already in NFC must pass through unchanged."""
nfc_filename = "Invoice_2024.pdf"
assert unicodedata.is_normalized("NFC", nfc_filename)
message = message_builder.create_message(
subject="Invoice",
from_="sender@example.com",
attachments=[
_AttachmentDef(filename=nfc_filename, content=b"%PDF-1.4 test"),
],
)
mail_account_handler._handle_message(message, attachment_rule)
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
consume_tasks = call_kwargs["consume_tasks"]
overrides = consume_tasks[0].kwargs["overrides"]
assert overrides.filename == nfc_filename
@@ -184,12 +184,7 @@ class TestMailMessageGpgDecryptor(TestMail):
EMAIL_GNUPG_HOME=empty_gpg_home,
):
message_decryptor = MailMessageDecryptor()
self.assertRaisesRegex(
Exception,
"Decryption failed",
message_decryptor.run,
encrypted_message,
)
self.assertRaises(Exception, message_decryptor.run, encrypted_message)
finally:
# Clean up the temporary GPG home used only by this test
try:
+2 -1
View File
@@ -1,3 +1,4 @@
import datetime
import logging
from datetime import timedelta
from http import HTTPStatus
@@ -85,7 +86,7 @@ class MailAccountViewSet(PassUserMixin, ModelViewSet[MailAccount]):
@action(methods=["post"], detail=False)
def test(self, request):
logger = logging.getLogger("paperless_mail")
request.data["name"] = timezone.now().isoformat()
request.data["name"] = datetime.datetime.now().isoformat()
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
existing_account = None
Generated
+13 -33
View File
@@ -1200,23 +1200,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/27/8d/2bc5f5546ff2ccb3f7de06742853483ab75bf74f36a92254702f8baecc79/factory_boy-3.3.3-py2.py3-none-any.whl", hash = "sha256:1c39e3289f7e667c4285433f305f8d506efc2fe9c73aaea4151ebd5cdea394fc", size = 37036, upload-time = "2025-02-03T09:49:01.659Z" },
]
[[package]]
name = "faiss-cpu"
version = "1.13.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/c9/671f66f6b31ec48e5825d36435f0cb91189fa8bb6b50724029dbff4ca83c/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_arm64.whl", hash = "sha256:a9064eb34f8f64438dd5b95c8f03a780b1a3f0b99c46eeacb1f0b5d15fc02dc1", size = 3452776, upload-time = "2025-12-24T10:27:01.419Z" },
{ url = "https://files.pythonhosted.org/packages/5a/4a/97150aa1582fb9c2bca95bd8fc37f27d3b470acec6f0a6833844b21e4b40/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_x86_64.whl", hash = "sha256:c8d097884521e1ecaea6467aeebbf1aa56ee4a36350b48b2ca6b39366565c317", size = 7896434, upload-time = "2025-12-24T10:27:03.592Z" },
{ url = "https://files.pythonhosted.org/packages/0b/d0/0940575f059591ca31b63a881058adb16a387020af1709dcb7669460115c/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ee330a284042c2480f2e90450a10378fd95655d62220159b1408f59ee83ebf1", size = 11485825, upload-time = "2025-12-24T10:27:05.681Z" },
{ url = "https://files.pythonhosted.org/packages/e7/e1/a5acac02aa593809f0123539afe7b4aff61d1db149e7093239888c9053e1/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab88ee287c25a119213153d033f7dd64c3ccec466ace267395872f554b648cd7", size = 23845772, upload-time = "2025-12-24T10:27:08.194Z" },
{ url = "https://files.pythonhosted.org/packages/9c/7b/49dcaf354834ec457e85ca769d50bc9b5f3003fab7c94a9dcf08cf742793/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:85511129b34f890d19c98b82a0cd5ffb27d89d1cec2ee41d2621ee9f9ef8cf3f", size = 13477567, upload-time = "2025-12-24T10:27:10.822Z" },
{ url = "https://files.pythonhosted.org/packages/f7/6b/12bb4037921c38bb2c0b4cfc213ca7e04bbbebbfea89b0b5746248ce446e/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b32eb4065bac352b52a9f5ae07223567fab0a976c7d05017c01c45a1c24264f", size = 25102239, upload-time = "2025-12-24T10:27:13.476Z" },
]
[[package]]
name = "faker"
version = "40.15.0"
@@ -2280,18 +2263,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f4/0c/fdddaee5391d915d3d568d2d8dbdb7c95647e65bb94d4ddb31d47cef5daf/llama_index_llms_openai_like-0.7.2-py3-none-any.whl", hash = "sha256:1f45a7b1cec8fb3f5997684327ffe6c19f93e789c2fff35dc5522465850faf0b", size = 6602, upload-time = "2026-04-23T23:05:31.708Z" },
]
[[package]]
name = "llama-index-vector-stores-faiss"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "llama-index-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7c/32/89a04e38fa9595b7116c61955d9a67085f0a5480738e9c14063e374724c2/llama_index_vector_stores_faiss-0.6.0.tar.gz", hash = "sha256:00bfeb6cb7571e0e856566cb4f10c89b415b6108f151d9ad48ee9c31da563f5e", size = 6045, upload-time = "2026-03-12T20:46:31.454Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5b/85/465b4f199075ae7773c181b2f98cf689f3107a8de031e7a9d4cd5e906446/llama_index_vector_stores_faiss-0.6.0-py3-none-any.whl", hash = "sha256:d4600c60ef5411d9e35ba573b4f416a5e13ea04c6f942c8e6f49f03f2feb4f3b", size = 7739, upload-time = "2026-03-12T20:46:30.736Z" },
]
[[package]]
name = "llama-index-workflows"
version = "2.20.0"
@@ -2912,7 +2883,6 @@ dependencies = [
{ name = "drf-spectacular", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "drf-spectacular-sidecar", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "drf-writable-nested", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "faiss-cpu", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "flower", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "gotenberg-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2927,7 +2897,6 @@ dependencies = [
{ name = "llama-index-embeddings-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "llama-index-llms-ollama", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "llama-index-llms-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "llama-index-vector-stores-faiss", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "nltk", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "ocrmypdf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2944,6 +2913,7 @@ dependencies = [
{ name = "scikit-learn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentence-transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "setproctitle", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sqlite-vec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tantivy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "torch", version = "2.11.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
@@ -3062,7 +3032,6 @@ requires-dist = [
{ name = "drf-spectacular", specifier = "~=0.28" },
{ name = "drf-spectacular-sidecar", specifier = "~=2026.5.1" },
{ name = "drf-writable-nested", specifier = "~=0.7.1" },
{ name = "faiss-cpu", specifier = ">=1.10" },
{ name = "filelock", specifier = "~=3.29.0" },
{ name = "flower", specifier = "~=2.0.1" },
{ name = "gotenberg-client", specifier = "~=0.14.0" },
@@ -3078,7 +3047,6 @@ requires-dist = [
{ name = "llama-index-embeddings-openai-like", specifier = ">=0.2.2" },
{ name = "llama-index-llms-ollama", specifier = ">=0.9.1" },
{ name = "llama-index-llms-openai-like", specifier = ">=0.7.1" },
{ name = "llama-index-vector-stores-faiss", specifier = ">=0.5.2" },
{ name = "mysqlclient", marker = "extra == 'mariadb'", specifier = "~=2.2.7" },
{ name = "nltk", specifier = "~=3.9.1" },
{ name = "ocrmypdf", specifier = "~=17.4.2" },
@@ -3101,6 +3069,7 @@ requires-dist = [
{ name = "scikit-learn", specifier = "~=1.8.0" },
{ name = "sentence-transformers", specifier = ">=5.4.1" },
{ name = "setproctitle", specifier = "~=1.3.4" },
{ name = "sqlite-vec", specifier = "==0.1.9" },
{ name = "tantivy", specifier = "~=0.26.0" },
{ name = "tika-client", specifier = "~=0.11.0" },
{ name = "torch", specifier = "~=2.11.0", index = "https://download.pytorch.org/whl/cpu" },
@@ -4699,6 +4668,17 @@ asyncio = [
{ name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "sqlite-vec"
version = "0.1.9"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/85/9fad0045d8e7c8df3e0fa5a56c630e8e15ad6e5ca2e6106fceb666aa6638/sqlite_vec-0.1.9-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:1b62a7f0a060d9475575d4e599bbf94a13d85af896bc1ce86ee80d1b5b48e5fb", size = 131171, upload-time = "2026-03-31T08:02:31.717Z" },
{ url = "https://files.pythonhosted.org/packages/a4/3d/3677e0cd2f92e5ebc43cd29fbf565b75582bff1ccfa0b8327c7508e1084f/sqlite_vec-0.1.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d52e30513bae4cc9778ddbf6145610434081be4c3afe57cd877893bad9f6b6c", size = 165434, upload-time = "2026-03-31T08:02:32.712Z" },
{ url = "https://files.pythonhosted.org/packages/00/d4/f2b936d3bdc38eadcbd2a87875815db36430fab0363182ba5d12cd8e0b51/sqlite_vec-0.1.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e921e592f24a5f9a18f590b6ddd530eb637e2d474e3b1972f9bbeb773aa3cb9", size = 160076, upload-time = "2026-03-31T08:02:33.796Z" },
{ url = "https://files.pythonhosted.org/packages/6f/ad/6afd073b0f817b3e03f9e37ad626ae341805891f23c74b5292818f49ac63/sqlite_vec-0.1.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:1515727990b49e79bcaf75fdee2ffc7d461f8b66905013231251f1c8938e7786", size = 163388, upload-time = "2026-03-31T08:02:34.888Z" },
]
[[package]]
name = "sqlparse"
version = "0.5.5"