Compare commits

..

6 Commits

Author SHA1 Message Date
stumpylog a1e7c0614e Updates the script in docker too 2026-06-04 12:02:45 -07:00
stumpylog dac05107a7 ruff: enable S324 (hashlib insecure hash functions)
Adds usedforsecurity=False to all hashlib.md5() calls, documenting
that these are used for file checksum comparison, not security.
The production call in _path_matches_checksum will be replaced with
compute_checksum() (SHA-256) in a separate branch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 11:37:17 -07:00
stumpylog 89ce62d97d ruff: enable PERF (perflint)
Fixes 9 violations — loop-based append replaced with comprehensions
or extend throughout production and test code:
- PERF401: list comprehensions / extend for transformed lists
- PERF402: list() around a generator for copied lists

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 11:22:07 -07:00
stumpylog 50f5d5f2e9 ruff: enable DTZ (flake8-datetimez)
Fixes 44 violations — naive datetime usage replaced with tz-aware
equivalents throughout production and test code:
- datetime.now() → timezone.now() (Django) or datetime.now(tz=UTC)
- datetime.fromtimestamp() → datetime.fromtimestamp(ts, tz=UTC)
- datetime.date.today() → timezone.now().date()
- datetime.datetime(...) constructors → tzinfo=UTC in tests
- UP017 auto-converted datetime.timezone.utc → datetime.UTC (py3.11+)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 10:47:13 -07:00
stumpylog 92b59eebfc ruff: enable B (flake8-bugbear)
Fixes 71 violations across production and test code:
- B904 (~50): raise-from in except blocks; from None at API/view
  boundaries, from exc where the cause is the direct origin
- B017 (9): pytest.raises(Exception) → specific type or match= arg
- B007 (5): unused loop vars renamed to _
- B027 (1): missing @abstractmethod on DateParserPluginBase.__exit__
- B028 (3): warnings.warn without stacklevel=2 in test utils
- B011 (1): assert False → raise AssertionError()
- B905 (3): zip() without strict=False
- B009 (3): getattr with constant string (auto-fixed)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 10:26:08 -07:00
stumpylog 59fd2ff9e8 ruff: enable G (logging format), ignore G004 (f-strings)
Replaces the single G201 selector with the full G group.
Fixes 2x G003 (string concat in log calls) and 2x G202 (redundant
exc_info on logger.exception). G004 (f-strings in logging) is ignored
as f-string style is accepted throughout this codebase.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-04 09:32:52 -07:00
117 changed files with 1912 additions and 7291 deletions
+1 -1
View File
@@ -61,7 +61,7 @@ def replace_with_symlinks(
total_duplicates = 0
space_saved = 0
for file_hash, file_list in duplicate_groups.items():
for file_list in duplicate_groups.values():
# Keep the first file as the original, replace others with symlinks
original_file = file_list[0]
duplicates = file_list[1:]
-7
View File
@@ -2068,13 +2068,6 @@ context by default.
Defaults to 8192.
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
local or slow inference servers that need more time to generate responses.
Defaults to 120.
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
+8 -2
View File
@@ -42,6 +42,7 @@ dependencies = [
"drf-spectacular~=0.28",
"drf-spectacular-sidecar~=2026.5.1",
"drf-writable-nested~=0.7.1",
"faiss-cpu>=1.10",
"filelock~=3.29.0",
"flower~=2.0.1",
"gotenberg-client~=0.14.0",
@@ -56,6 +57,7 @@ dependencies = [
"llama-index-embeddings-openai-like>=0.2.2",
"llama-index-llms-ollama>=0.9.1",
"llama-index-llms-openai-like>=0.7.1",
"llama-index-vector-stores-faiss>=0.5.2",
"nltk~=3.9.1",
"ocrmypdf~=17.4.2",
"openai>=2.32",
@@ -72,7 +74,6 @@ dependencies = [
"scikit-learn~=1.8.0",
"sentence-transformers>=5.4.1",
"setproctitle~=1.3.4",
"sqlite-vec==0.1.9",
"tantivy~=0.26.0",
"tika-client~=0.11.0",
"torch~=2.11.0",
@@ -184,12 +185,16 @@ line-ending = "lf"
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
extend-select = [
"B", # https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
"COM", # https://docs.astral.sh/ruff/rules/#flake8-commas-com
"DTZ", # https://docs.astral.sh/ruff/rules/#flake8-datetimez-dtz
"PERF", # https://docs.astral.sh/ruff/rules/#perflint-perf
"S324", # https://docs.astral.sh/ruff/rules/hashlib-insecure-hash-functions/
"DJ", # https://docs.astral.sh/ruff/rules/#flake8-django-dj
"EXE", # https://docs.astral.sh/ruff/rules/#flake8-executable-exe
"FBT", # https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt
"FLY", # https://docs.astral.sh/ruff/rules/#flynt-fly
"G201", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
"G", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"ICN", # https://docs.astral.sh/ruff/rules/#flake8-import-conventions-icn
"INP", # https://docs.astral.sh/ruff/rules/#flake8-no-pep420-inp
@@ -210,6 +215,7 @@ extend-select = [
]
ignore = [
"DJ001",
"G004", # f-strings in logging: accepted style in this codebase
"PLC0415",
"RUF012",
"SIM105",
+1 -1
View File
@@ -26,7 +26,7 @@ module.exports = {
'abstract-paperless-service',
],
transformIgnorePatterns: [
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
],
moduleNameMapper: {
...esmPreset.moduleNameMapper,
-1
View File
@@ -32,7 +32,6 @@
"ngx-cookie-service": "^21.3.1",
"ngx-device-detector": "^11.0.0",
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
"normalize-diacritics": "^5.0.0",
"pdfjs-dist": "^5.7.284",
"rxjs": "^7.8.2",
"tslib": "^2.8.1",
-11
View File
@@ -71,9 +71,6 @@ importers:
ngx-ui-tour-ng-bootstrap:
specifier: ^18.0.0
version: 18.0.0(f910a33494d223bd6dd07ce1bf22a35e)
normalize-diacritics:
specifier: ^5.0.0
version: 5.0.0
pdfjs-dist:
specifier: ^5.7.284
version: 5.7.284
@@ -5519,10 +5516,6 @@ packages:
engines: {node: ^20.17.0 || >=22.9.0}
hasBin: true
normalize-diacritics@5.0.0:
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
engines: {node: '>= 14.x', npm: '>= 6.x'}
normalize-path@3.0.0:
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
engines: {node: '>=0.10.0'}
@@ -12938,10 +12931,6 @@ snapshots:
dependencies:
abbrev: 4.0.0
normalize-diacritics@5.0.0:
dependencies:
tslib: 2.8.1
normalize-path@3.0.0: {}
npm-bundled@5.0.0:
@@ -11,9 +11,6 @@
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
</button>
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
</button>
<div class="form-check form-switch mb-0 ms-2">
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
@@ -84,7 +81,7 @@
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
<div class="dropdown-menu shadow" ngbDropdownMenu>
@for (t of filterTargets; track t.id) {
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
}
</div>
</div>
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
import { RouterTestingModule } from '@angular/router/testing'
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { of, throwError } from 'rxjs'
import { throwError } from 'rxjs'
import { routes } from 'src/app/app-routing.module'
import {
PaperlessTask,
@@ -29,11 +29,7 @@ import { ToastService } from 'src/app/services/toast.service'
import { environment } from 'src/environments/environment'
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
import {
TaskFilterTargetID,
TasksComponent,
TaskSection,
} from './tasks.component'
import { TasksComponent, TaskSection } from './tasks.component'
const tasks: PaperlessTask[] = [
{
@@ -158,13 +154,6 @@ const paginatedTasks: Results<PaperlessTask> = {
results: tasks,
}
const sectionCountResponse = {
all: 7,
needs_attention: 2,
in_progress: 3,
completed: 2,
}
describe('TasksComponent', () => {
let component: TasksComponent
let fixture: ComponentFixture<TasksComponent>
@@ -232,15 +221,6 @@ describe('TasksComponent', () => {
req.params.get('page') === '1'
)
.flush(paginatedTasks)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
!req.params.has('status')
)
.flush(sectionCountResponse)
})
it('should display task sections with counts', () => {
@@ -315,7 +295,6 @@ describe('TasksComponent', () => {
const headerText = header.nativeElement.textContent
expect(headerText).toContain('Dismiss visible')
expect(headerText).toContain('Dismiss all')
expect(headerText).toContain('Auto refresh')
expect(headerText).not.toContain('All types')
expect(headerText).not.toContain('All sources')
@@ -348,74 +327,6 @@ describe('TasksComponent', () => {
expect(pagination).not.toBeNull()
})
it('should apply the selected section to the server-side task query', () => {
component.setSection(TaskSection.NeedsAttention)
const req = httpTestingController.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page') === '1' &&
request.params.get('page_size') === '25' &&
request.params.get('acknowledged') === 'false' &&
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
)
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
expect(component.totalTasks).toBe(2)
})
it('should apply task type and trigger source filters to the server-side task query', () => {
component.setTaskType(PaperlessTaskType.SanityCheck)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck
)
.flush({ count: 1, results: [tasks[6]] })
component.setTriggerSource(PaperlessTaskTriggerSource.System)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
request.params.get('trigger_source') ===
PaperlessTaskTriggerSource.System
)
.flush({ count: 1, results: [tasks[6]] })
})
it('should apply text filters to the server-side task query', () => {
component.filterText = 'invoice'
jest.advanceTimersByTime(150)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('name') === 'invoice'
)
.flush({ count: 1, results: [tasks[0]] })
component.setFilterTarget(TaskFilterTargetID.Result)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('result') === 'invoice'
)
.flush({ count: 0, results: [] })
})
it('should load a different task page when pagination changes', () => {
component.setPage(2)
@@ -439,27 +350,6 @@ describe('TasksComponent', () => {
expect(component.pagedTasks).toEqual([tasks[0]])
})
it('should not replace section counts with current-page counts', () => {
component.setPage(2)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '25' &&
req.params.get('page') === '2'
)
.flush({
count: 30,
results: [tasks[0]],
})
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
})
it('should expose stable task type options and disable empty ones', () => {
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
PaperlessTaskType.TrainClassifier
@@ -605,46 +495,6 @@ describe('TasksComponent', () => {
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
})
it('should support dismiss all tasks', () => {
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(of({}))
const reloadPageSpy = jest
.spyOn(component as any, 'reloadPage')
.mockImplementation(() => undefined)
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(reloadPageSpy).toHaveBeenCalledWith(false)
expect(component.selectedTasks.size).toBe(0)
})
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
const error = new Error('dismiss all failed')
const toastSpy = jest.spyOn(toastService, 'showError')
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(throwError(() => error))
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
expect(modal.componentInstance.buttonsEnabled).toBe(true)
})
it('should dismiss the currently visible scoped and filtered tasks', () => {
component.setSection(TaskSection.InProgress)
component.setTaskType(PaperlessTaskType.SanityCheck)
@@ -823,9 +673,6 @@ describe('TasksComponent', () => {
})
it('should keep clearing selection independent from resetting filters', () => {
component.resetFilter()
expect(component.filterText).toBe('')
component.setTaskType(PaperlessTaskType.ConsumeFile)
component.toggleSelected(tasks[0])
expect(component.selectedTasks.size).toBe(1)
@@ -40,7 +40,7 @@ export enum TaskSection {
Completed = 'completed',
}
export enum TaskFilterTargetID {
enum TaskFilterTargetID {
Name,
Result,
}
@@ -167,12 +167,6 @@ export class TasksComponent
public readonly pageSize = 25
public page: number = 1
public totalTasks: number = 0
public sectionCounts: Record<TaskSection, number> = {
[TaskSection.All]: 0,
[TaskSection.NeedsAttention]: 0,
[TaskSection.InProgress]: 0,
[TaskSection.Completed]: 0,
}
public pagedTasks: PaperlessTask[] = []
public selectedSection: TaskSection = TaskSection.All
public selectedTaskType: PaperlessTaskType | null = null
@@ -288,7 +282,6 @@ export class TasksComponent
.subscribe((query) => {
this._filterText = query
this.clearSelection()
this.reloadPage(true)
})
}
@@ -341,30 +334,6 @@ export class TasksComponent
}
}
dismissAllTasks() {
let modal = this.modalService.open(ConfirmDialogComponent, {
backdrop: 'static',
})
modal.componentInstance.title = $localize`Confirm Dismiss All`
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
modal.componentInstance.btnClass = 'btn-warning'
modal.componentInstance.btnCaption = $localize`Dismiss`
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
modal.componentInstance.buttonsEnabled = false
modal.close()
this.tasksService.dismissAllTasks().subscribe({
next: () => {
this.reloadPage(false)
},
error: (e) => {
this.toastService.showError($localize`Error dismissing tasks`, e)
modal.componentInstance.buttonsEnabled = true
},
})
this.clearSelection()
})
}
expandTask(task: PaperlessTask) {
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
}
@@ -477,7 +446,9 @@ export class TasksComponent
}
sectionCount(section: TaskSection): number {
return this.sectionCounts[section]
return this.pagedTasks.filter((task) =>
this.taskBelongsToSection(task, section)
).length
}
sectionShowsResults(section: TaskSection): boolean {
@@ -487,27 +458,16 @@ export class TasksComponent
setSection(section: TaskSection) {
this.selectedSection = section
this.clearSelection()
this.reloadPage(true)
}
setTaskType(taskType: PaperlessTaskType | null) {
this.selectedTaskType = taskType
this.clearSelection()
this.reloadPage(true)
}
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
this.selectedTriggerSource = triggerSource
this.clearSelection()
this.reloadPage(true)
}
setFilterTarget(filterTargetID: TaskFilterTargetID) {
this.filterTargetID = filterTargetID
if (this._filterText.length) {
this.clearSelection()
this.reloadPage(true)
}
}
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
@@ -545,32 +505,19 @@ export class TasksComponent
}
public resetFilter() {
if (!this._filterText.length) {
return
}
this._filterText = ''
this.clearSelection()
this.reloadPage(true)
}
public resetFilters() {
const hadFilter = this.isFiltered
this.selectedTaskType = null
this.selectedTriggerSource = null
this._filterText = ''
this.resetFilter()
this.clearSelection()
if (hadFilter) {
this.reloadPage(true)
}
}
filterInputKeyup(event: KeyboardEvent) {
if (event.key == 'Enter') {
this._filterText = (event.target as HTMLInputElement).value
this.clearSelection()
this.reloadPage(true)
} else if (event.key === 'Escape') {
this.resetFilter()
}
@@ -659,86 +606,19 @@ export class TasksComponent
)
}
private reloadSectionCounts() {
this.tasksService
.statusCounts(this.getParamsForSection(TaskSection.All))
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe((counts) => {
this.sectionCounts[TaskSection.All] = counts.all
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
this.sectionCounts[TaskSection.Completed] = counts.completed
})
}
private getParamsForSection(
section: TaskSection
): Record<string, string | number | boolean | readonly string[]> {
const params: Record<
string,
string | number | boolean | readonly string[]
> = {
acknowledged: false,
}
const statuses = this.statusesForSection(section)
if (statuses.length) {
params.status = statuses
}
if (this.selectedTaskType !== null) {
params.task_type = this.selectedTaskType
}
if (this.selectedTriggerSource !== null) {
params.trigger_source = this.selectedTriggerSource
}
if (this._filterText.length) {
params[
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
] = this._filterText
}
return params
}
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
switch (section) {
case TaskSection.NeedsAttention:
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
case TaskSection.InProgress:
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
case TaskSection.Completed:
return [PaperlessTaskStatus.Success]
default:
return []
}
}
private reloadPage(resetToFirstPage: boolean = false) {
if (resetToFirstPage) {
this.page = 1
}
this.reloadSectionCounts()
this.loading = true
this.tasksService
.list(
this.page,
this.pageSize,
this.getParamsForSection(this.selectedSection)
)
.list(this.page, this.pageSize, { acknowledged: false })
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe({
next: (result) => {
this.pagedTasks = result.results
this.totalTasks = result.count
this.sectionCounts[TaskSection.All] = result.count
if (this.selectedSection !== TaskSection.All) {
this.sectionCounts[this.selectedSection] = result.count
}
this.loading = false
if (
this.page > 1 &&
@@ -8,7 +8,7 @@
<div class="chat-messages font-monospace small">
@for (message of messages; track message) {
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
<span>
{{ message.content }}
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
@@ -188,14 +188,4 @@ describe('ChatComponent', () => {
component.searchInputKeyDown(event)
expect(component.sendMessage).toHaveBeenCalled()
})
it('should not send message on Enter key press while composing with IME', () => {
jest.spyOn(component, 'sendMessage')
const event = new KeyboardEvent('keydown', {
key: 'Enter',
isComposing: true,
})
component.searchInputKeyDown(event)
expect(component.sendMessage).not.toHaveBeenCalled()
})
})
@@ -155,10 +155,7 @@ export class ChatComponent implements OnInit {
}
public searchInputKeyDown(event: KeyboardEvent) {
if (
event.key === 'Enter' &&
!(event.isComposing || event.keyCode === 229)
) {
if (event.key === 'Enter') {
event.preventDefault()
this.sendMessage()
}
@@ -5,10 +5,10 @@
</div>
<div class="modal-body">
@if (messageBold) {
<p class="text-break"><b>{{messageBold}}</b></p>
<p><b>{{messageBold}}</b></p>
}
@if (message) {
<p class="mb-0 text-break" [innerHTML]="message"></p>
<p class="mb-0" [innerHTML]="message"></p>
}
</div>
<div class="modal-footer">
@@ -9,11 +9,8 @@
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
<ul class="list-group"
cdkDropList
[cdkDropListData]="documentIDs"
(cdkDropListDropped)="onDrop($event)">
@for (documentID of documentIDs; track documentID) {
@let document = getDocument(documentID);
@if (document) {
@for (document of documents; track document.id) {
<li class="list-group-item d-flex align-items-center" cdkDrag>
<i-bs name="grip-vertical" class="me-2"></i-bs>
<div class="d-flex flex-column">
@@ -30,7 +27,6 @@
</small>
</div>
</li>
}
}
</ul>
</div>
@@ -23,7 +23,6 @@ import {
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
import { ToastService } from 'src/app/services/toast.service'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
@@ -70,7 +69,9 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
public get filteredFields(): CustomField[] {
return this.unusedFields.filter(
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
(f) =>
!this.filterText ||
f.name.toLowerCase().includes(this.filterText.toLowerCase())
)
}
@@ -63,7 +63,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
@@ -82,7 +81,6 @@
[disabled]="disabled"
bindLabel="name"
bindValue="id"
[searchFn]="customFieldSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
@@ -127,7 +125,6 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[multiple]="true"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
}
@@ -36,7 +36,6 @@ import {
CustomFieldQueryExpression,
} from 'src/app/utils/custom-field-query-element'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
@@ -282,14 +281,6 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
public readonly today: string = new Date().toLocaleDateString('en-CA')
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
matchesSearchText(field?.name, term)
public selectOptionSearchFn = (
term: string,
option: { id: string; label: string }
): boolean => matchesSearchText(option?.label, term)
constructor() {
super()
this.selectionModel = new CustomFieldQueriesModel()
@@ -28,7 +28,6 @@
[notFoundText]="notFoundText"
[multiple]="multiple"
[bindLabel]="bindLabel"
[searchFn]="searchFn"
bindValue="id"
[virtualScroll]="items?.length > 100"
(change)="onChange(value)"
@@ -112,15 +112,6 @@ describe('SelectComponent', () => {
expect(createNewVal).toEqual('baz')
})
it('should search items by independent normalized terms', () => {
expect(
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeTruthy()
expect(
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeFalsy()
})
it('should clear search term on blur after delay', fakeAsync(() => {
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
component.onBlur()
@@ -13,7 +13,6 @@ import {
import { RouterModule } from '@angular/router'
import { NgSelectModule } from '@ng-select/ng-select'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { matchesSearchText } from 'src/app/utils/text-search'
import { AbstractInputComponent } from '../abstract-input'
@Component({
@@ -100,9 +99,6 @@ export class SelectComponent extends AbstractInputComponent<number> {
@Input()
bindLabel: string = 'name'
public searchFn = (term: string, item: any): boolean =>
matchesSearchText(item?.[this.bindLabel], term)
@Input()
showFilter: boolean = false
@@ -14,7 +14,6 @@
[clearSearchOnAdd]="true"
[hideSelected]="tags.length > 0"
[addTag]="allowCreate ? createTagRef : false"
[searchFn]="searchFn"
addTagText="Add tag"
i18n-addTagText
(add)="onAdd($event)"
@@ -171,15 +171,6 @@ describe('TagsComponent', () => {
expect(component.getTag(4)).toBeUndefined()
})
it('should search tags by independent normalized terms including parents', () => {
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
component.tags = [parent, child]
expect(component.searchFn('finance 26', child)).toBeTruthy()
expect(component.searchFn('finance receipt', child)).toBeFalsy()
})
it('should emit filtered documents', () => {
component.value = [10]
component.tags = tags
@@ -21,7 +21,6 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { first, firstValueFrom, tap } from 'rxjs'
import { Tag } from 'src/app/data/tag'
import { TagService } from 'src/app/services/rest/tag.service'
import { matchesSearchText } from 'src/app/utils/text-search'
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
import { TagComponent } from '../../tag/tag.component'
@@ -115,14 +114,6 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
public createTagRef: (name) => void
public searchFn = (term: string, tag: Tag): boolean =>
matchesSearchText(
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
.flat()
.join(' '),
term
)
getTag(id: number) {
if (this.tags) {
return this.tags.find((tag) => tag.id == id)
@@ -1,5 +1,5 @@
<div class="btn-group">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
@if (loading) {
<div class="spinner-border spinner-border-sm" role="status"></div>
} @else {
@@ -13,7 +13,7 @@
@if (aiEnabled) {
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<span class="visually-hidden" i18n>Show suggestions</span>
</button>
@@ -37,18 +37,6 @@ describe('SuggestionsDropdownComponent', () => {
expect(component.getSuggestions.emit).toHaveBeenCalled()
})
it('should not emit getSuggestions when disabled', () => {
jest.spyOn(component.getSuggestions, 'emit')
component.disabled = true
component.suggestions = null
fixture.detectChanges()
component.clickSuggest()
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
})
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
component.aiEnabled = true
fixture.detectChanges()
@@ -47,14 +47,6 @@ export class SuggestionsDropdownComponent {
addCorrespondent: EventEmitter<string> = new EventEmitter()
public clickSuggest(): void {
if (
this.disabled ||
this.loading ||
(this.suggestions && !this.aiEnabled)
) {
return
}
if (!this.suggestions) {
this.getSuggestions.emit(this)
} else {
@@ -131,9 +131,7 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
}
</button>
<ng-template #celeryStatus>
-9
View File
@@ -360,14 +360,6 @@ export const PaperlessConfigOptions: ConfigOption[] = [
category: ConfigCategory.AI,
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
},
{
key: 'llm_request_timeout',
title: $localize`LLM Request Timeout`,
type: ConfigOptionType.Number,
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
category: ConfigCategory.AI,
note: $localize`Timeout in seconds for LLM requests.`,
},
]
export interface PaperlessConfig extends ObjectWithId {
@@ -409,5 +401,4 @@ export interface PaperlessConfig extends ObjectWithId {
llm_api_key: string
llm_endpoint: string
llm_output_language: string
llm_request_timeout: number
}
-7
View File
@@ -64,10 +64,3 @@ export interface PaperlessTaskSummary {
last_success: Date | null
last_failure: Date | null
}
export interface PaperlessTaskStatusCounts {
all: number
needs_attention: number
in_progress: number
completed: number
}
+3 -2
View File
@@ -1,6 +1,5 @@
import { Pipe, PipeTransform } from '@angular/core'
import { MatchingModel } from '../data/matching-model'
import { matchesSearchText } from '../utils/text-search'
@Pipe({
name: 'filter',
@@ -22,7 +21,9 @@ export class FilterPipe implements PipeTransform {
typeof item[key] === 'string' || typeof item[key] === 'number'
)
return keys.some((key) => {
return matchesSearchText(item[key], searchText)
return String(item[key])
.toLowerCase()
.includes(searchText.toLowerCase())
})
})
}
@@ -80,27 +80,6 @@ describe('TasksService', () => {
.flush({ count: 0, results: [] })
})
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
tasksService.dismissAllTasks().subscribe()
const req = httpTestingController.expectOne(
`${environment.apiBaseUrl}tasks/acknowledge/`
)
expect(req.request.method).toEqual('POST')
expect(req.request.body).toEqual({
all: true,
})
req.flush([])
// reload is then called
httpTestingController
.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '1000'
)
.flush({ count: 0, results: [] })
})
it('groups mixed task types by status when reloading', () => {
expect(tasksService.total).toEqual(0)
const mockTasks = [
@@ -242,34 +221,4 @@ describe('TasksService', () => {
task_id: 'abc-123',
})
})
it('loads filtered task status counts', () => {
tasksService
.statusCounts({
acknowledged: false,
task_type: PaperlessTaskType.ConsumeFile,
})
.subscribe((res) => {
expect(res).toEqual({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
const req = httpTestingController.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
)
expect(req.request.method).toEqual('GET')
req.flush({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
})
+1 -27
View File
@@ -5,7 +5,6 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
import {
PaperlessTask,
PaperlessTaskStatus,
PaperlessTaskStatusCounts,
PaperlessTaskType,
} from 'src/app/data/paperless-task'
import { Results } from 'src/app/data/results'
@@ -89,7 +88,7 @@ export class TasksService {
public list(
page: number,
pageSize: number,
extraParams?: Record<string, string | number | boolean | readonly string[]>
extraParams?: Record<string, string | number | boolean>
): Observable<Results<PaperlessTask>> {
return this.http.get<Results<PaperlessTask>>(
`${this.baseUrl}${this.endpoint}/`,
@@ -103,17 +102,6 @@ export class TasksService {
)
}
public statusCounts(
extraParams?: Record<string, string | number | boolean | readonly string[]>
): Observable<PaperlessTaskStatusCounts> {
return this.http.get<PaperlessTaskStatusCounts>(
`${this.baseUrl}${this.endpoint}/status_counts/`,
{
params: extraParams,
}
)
}
public dismissTasks(task_ids: Set<number>): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
@@ -128,20 +116,6 @@ export class TasksService {
)
}
public dismissAllTasks(): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
all: true,
})
.pipe(
first(),
takeUntil(this.unsubscribeNotifer),
tap(() => {
this.reload()
})
)
}
public cancelPending(): void {
this.unsubscribeNotifer.next(true)
}
-17
View File
@@ -1,17 +0,0 @@
import { matchesSearchText } from './text-search'
describe('text search utilities', () => {
it('matches text accent-insensitively', () => {
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
})
it('matches all whitespace-separated search terms independently', () => {
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
})
})
-23
View File
@@ -1,23 +0,0 @@
import { normalizeSync } from 'normalize-diacritics'
export type SearchTextValue =
| string
| number
| boolean
| bigint
| null
| undefined
export function normalizeSearchText(value: SearchTextValue): string {
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
}
export function matchesSearchText(
value: SearchTextValue,
searchText: SearchTextValue
): boolean {
const normalizedValue = normalizeSearchText(value)
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
return searchTerms.every((term) => normalizedValue.includes(term))
}
-13
View File
@@ -904,19 +904,6 @@ def remove_password(
doc.id,
pair.source_doc.source_path,
)
try:
with pikepdf.open(source_path) as pdf:
if not pdf.is_encrypted:
logger.info(
"Skipping password removal for document %s because the "
"source PDF is not encrypted",
pair.root_doc.id,
)
continue
except pikepdf.PasswordError:
# Password-protected PDFs need the supplied password below.
pass
with pikepdf.open(source_path, password=password) as pdf:
filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
+3 -2
View File
@@ -834,8 +834,9 @@ class ConsumerPlugin(
self.log.debug(f"Creation date from parse_date: {create_date}")
else:
stats = Path(self.input_doc.original_file).stat()
create_date = timezone.make_aware(
datetime.datetime.fromtimestamp(stats.st_mtime),
create_date = datetime.datetime.fromtimestamp(
stats.st_mtime,
tz=datetime.UTC,
)
self.log.debug(f"Creation date from st_mtime: {create_date}")
+4 -4
View File
@@ -1,4 +1,3 @@
import datetime as dt
import logging
import os
import shutil
@@ -6,6 +5,7 @@ from pathlib import Path
from typing import Final
from django.conf import settings
from django.utils import timezone
from pikepdf import Pdf
from documents.consumer import ConsumerError
@@ -78,7 +78,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
stats = staging.stat()
# if the file is older than the timeout, we don't consider
# it valid
if (dt.datetime.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
if (timezone.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
logger.warning("Outdated double sided staging file exists, deleting it")
staging.unlink()
else:
@@ -99,7 +99,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
"two uploaded files don't belong to the same double-"
"sided scan. Please retry, starting with the odd "
"numbered pages again.",
)
) from None
# Merged file has the same path, but without the
# double-sided subdir. Therefore, it is also in the
# consumption dir and will be picked up for processing
@@ -134,7 +134,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
shutil.move(pdf_file, staging)
# update access to modification time so we know if the file
# is outdated when another file gets uploaded
timestamp = dt.datetime.now().timestamp()
timestamp = timezone.now().timestamp()
os.utime(staging, (timestamp, timestamp))
logger.info(
"Got scan with odd numbered pages of double-sided scan, moved it to %s",
+5 -67
View File
@@ -28,7 +28,6 @@ from django.db.models.functions import Cast
from django.utils.translation import gettext_lazy as _
from django_filters import DateFilter
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import CharFilter
from django_filters.rest_framework import DateTimeFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
@@ -351,7 +350,7 @@ def handle_validation_prefix(func: Callable):
try:
return func(*args, **kwargs)
except serializers.ValidationError as e:
raise serializers.ValidationError({validation_prefix: e.detail})
raise serializers.ValidationError({validation_prefix: e.detail}) from e
# Update the signature to include the validation_prefix argument
old_sig = inspect.signature(func)
@@ -462,7 +461,7 @@ class CustomFieldQueryParser:
except json.JSONDecodeError:
raise serializers.ValidationError(
{self._validation_prefix: [_("Value must be valid JSON.")]},
)
) from None
return (
self._parse_expr(expr, validation_prefix=self._validation_prefix),
self._annotations,
@@ -590,7 +589,7 @@ class CustomFieldQueryParser:
except CustomField.DoesNotExist:
raise serializers.ValidationError(
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
)
) from None
self._custom_fields[custom_field.id] = custom_field
self._custom_fields[custom_field.name] = custom_field
return custom_field
@@ -901,16 +900,6 @@ class ShareLinkBundleFilterSet(FilterSet):
class PaperlessTaskFilterSet(FilterSet):
name = CharFilter(
method="filter_name",
label="Name",
)
result = CharFilter(
method="filter_result",
label="Result",
)
task_type = MultipleChoiceFilter(
choices=PaperlessTask.TaskType.choices,
label="Task Type",
@@ -950,58 +939,7 @@ class PaperlessTaskFilterSet(FilterSet):
class Meta:
model = PaperlessTask
fields = [
"task_type",
"trigger_source",
"status",
"acknowledged",
"owner",
"name",
"result",
]
def filter_name(self, queryset, name, value):
if not value:
return queryset
matching_task_types = [
task_type
for task_type, label in PaperlessTask.TaskType.choices
if value.lower() in str(label).lower()
]
matching_trigger_sources = [
trigger_source
for trigger_source, label in PaperlessTask.TriggerSource.choices
if value.lower() in str(label).lower()
]
return queryset.filter(
Q(input_data__filename__icontains=value)
| Q(task_type__in=matching_task_types)
| Q(trigger_source__in=matching_trigger_sources),
)
def filter_result(self, queryset, name, value):
if not value:
return queryset
query = Q(result_data__reason__icontains=value) | Q(
result_data__error_message__icontains=value,
)
try:
numeric_value = int(value)
except (TypeError, ValueError):
pass
else:
query |= Q(result_data__document_id=numeric_value) | Q(
result_data__duplicate_of=numeric_value,
)
if "duplicate" in value.lower():
query |= Q(result_data__duplicate_of__isnull=False)
return queryset.filter(query)
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
def filter_is_complete(self, queryset, name, value):
if value:
@@ -1050,7 +988,7 @@ class DocumentsOrderingFilter(OrderingFilter):
except CustomField.DoesNotExist:
raise serializers.ValidationError(
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
)
) from None
annotation = None
match field.data_type:
@@ -169,10 +169,6 @@ class FileStabilityTracker:
self._tracked.pop(path, None)
yield path
def is_tracking(self, path: Path) -> bool:
"""Check whether a path is currently being tracked for stability."""
return path.resolve() in self._tracked
def has_pending_files(self) -> bool:
"""Check if there are files waiting for stability check."""
return len(self._tracked) > 0
@@ -374,16 +370,6 @@ class Command(BaseCommand):
# Testing timeout in seconds
testing_timeout_s: Final[float] = 0.5
# How often to perform a full-glob rescan of the consume directory as a
# safety net. Each watchfiles watcher is torn down and recreated on every
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
# current directory contents as its baseline. A file that appears between
# one batch and the next watcher's baseline is therefore never reported and
# would sit in the consume directory forever. This periodic rescan re-injects
# such files into the stability tracker (see GH issue #13011). Not currently
# user-configurable; instances may override for testing.
rescan_interval_s: float = 300.0
def add_arguments(self, parser) -> None:
parser.add_argument(
"directory",
@@ -439,7 +425,7 @@ class Command(BaseCommand):
)
# Process existing files
queued = self._process_existing_files(
self._process_existing_files(
directory=directory,
recursive=recursive,
subdirs_as_tags=subdirs_as_tags,
@@ -459,7 +445,6 @@ class Command(BaseCommand):
polling_interval=polling_interval,
stability_delay=stability_delay,
is_testing=is_testing,
queued=queued,
)
logger.debug("Consumer exiting")
@@ -471,18 +456,11 @@ class Command(BaseCommand):
recursive: bool,
subdirs_as_tags: bool,
consumer_filter: ConsumerFilter,
) -> set[Path]:
"""
Process any existing files in the consumption directory.
Returns the set of resolved paths that were queued, so the watch loop
can seed its in-flight set and avoid re-queuing them on the first
rescan before the consume tasks have removed them from disk.
"""
) -> None:
"""Process any existing files in the consumption directory."""
logger.info(f"Processing existing files in {directory}")
glob_pattern = "**/*" if recursive else "*"
queued: set[Path] = set()
for filepath in directory.glob(glob_pattern):
# Use filter to check if file should be processed
@@ -497,48 +475,6 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
queued.add(filepath.resolve())
return queued
def _rescan_existing_files(
self,
*,
directory: Path,
recursive: bool,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
) -> None:
"""
Re-inject on-disk files the watcher never reported into the tracker.
Acts as a safety net for files stranded by the watcher-recreation gap
(see ``rescan_interval_s``). Files already being tracked or already
queued and awaiting consumption are skipped, so a file is never queued
twice. Queued paths that have since left the directory are pruned so a
later file reusing the same name is not skipped forever.
"""
# Prune in-flight paths that have left the directory
for path in list(queued):
if not path.exists():
queued.discard(path)
glob_pattern = "**/*" if recursive else "*"
for filepath in directory.glob(glob_pattern):
if not filepath.is_file():
continue
if not consumer_filter(Change.added, str(filepath)):
continue
resolved = filepath.resolve()
if tracker.is_tracking(resolved) or resolved in queued:
continue
logger.debug(f"Rescan found untracked file: {resolved}")
tracker.track(resolved, Change.added)
def _watch_directory(
self,
@@ -550,24 +486,11 @@ class Command(BaseCommand):
polling_interval: float,
stability_delay: float,
is_testing: bool,
queued: set[Path] | None = None,
) -> None:
"""Watch directory for changes and process stable files."""
use_polling = polling_interval > 0
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
# Resolved paths that have been queued and are awaiting consumption.
# Seeded from the startup scan so the first rescan does not re-queue
# files whose consume tasks have not yet removed them from disk.
queued = set() if queued is None else queued
# Full-glob safety net cadence (0 disables)
rescan_interval_s = self.rescan_interval_s
rescan_timeout_ms = (
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
)
last_rescan = monotonic()
if use_polling:
logger.info(
f"Watching {directory} using polling (interval: {polling_interval}s)",
@@ -582,20 +505,6 @@ class Command(BaseCommand):
stability_timeout_ms = int(stability_delay * 1000)
testing_timeout_ms = int(self.testing_timeout_s * 1000)
def cap_for_rescan(ms: int) -> int:
"""
Ensure the watch loop wakes often enough to run the rescan.
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
only run that often. A timeout of 0 means "wait indefinitely",
which would never wake to rescan; cap it at the rescan interval.
"""
if rescan_timeout_ms <= 0:
return ms
if ms <= 0:
return rescan_timeout_ms
return min(ms, rescan_timeout_ms)
# Calculate appropriate timeout for watch loop
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
# to ensure poll cycles can complete before timing out
@@ -613,8 +522,6 @@ class Command(BaseCommand):
# Not testing, wait indefinitely for first event
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
self.stop_flag.clear()
while not self.stop_flag.is_set():
@@ -644,26 +551,10 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
# Remember it so the rescan does not re-queue it while
# the consume task has yet to remove it from disk
queued.add(stable_path)
# Exit watch loop to reconfigure timeout
break
# Periodic full-glob safety net for files the watcher missed
if rescan_timeout_ms > 0 and (
monotonic() - last_rescan >= rescan_interval_s
):
self._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
last_rescan = monotonic()
# Determine next timeout
if tracker.has_pending_files():
# Check pending files at stability interval
@@ -681,8 +572,6 @@ class Command(BaseCommand):
# No pending files, wait indefinitely
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
except KeyboardInterrupt: # pragma: nocover
logger.info("Received interrupt, stopping consumer")
self.stop_flag.set()
@@ -480,7 +480,7 @@ class Command(CryptMixin, PaperlessCommand):
}
# 3. Export files from each document
for index, document_dict in enumerate(
for _, document_dict in enumerate(
self.track(
document_manifest,
description="Exporting documents...",
@@ -2,7 +2,6 @@ from typing import Any
from documents.management.commands.base import PaperlessCommand
from documents.tasks import llmindex_index
from paperless_ai.indexing import llm_index_compact
class Command(PaperlessCommand):
@@ -13,12 +12,9 @@ class Command(PaperlessCommand):
def add_arguments(self, parser: Any) -> None:
super().add_arguments(parser)
parser.add_argument("command", choices=["rebuild", "update", "compact"])
parser.add_argument("command", choices=["rebuild", "update"])
def handle(self, *args: Any, **options: Any) -> None:
if options["command"] == "compact":
llm_index_compact()
return
llmindex_index(
rebuild=options["command"] == "rebuild",
iter_wrapper=lambda docs: self.track(
@@ -133,11 +133,14 @@ def _build_suggestion_table(
else:
doc_cell = Text(f"{doc} [{doc.pk}]")
tag_parts: list[str] = []
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name):
tag_parts.append(f"[green]+{tag.name}[/green]")
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name):
tag_parts.append(f"[red]-{tag.name}[/red]")
tag_parts: list[str] = [
f"[green]+{tag.name}[/green]"
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name)
]
tag_parts.extend(
f"[red]-{tag.name}[/red]"
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name)
)
tag_cell = Text.from_markup(", ".join(tag_parts)) if tag_parts else Text("-")
table.add_row(
@@ -1,63 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0003_remove_document_storage_type"),
("documents", "0004_workflowtrigger_filter_has_any_correspondents_and_more"),
("documents", "0005_alter_document_checksum_unique"),
]
dependencies = [
("documents", "0002_squashed"),
]
operations = [
migrations.RemoveField(
model_name="document",
name="storage_type",
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_correspondents",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_correspondent",
to="documents.correspondent",
verbose_name="has one of these correspondents",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_document_types",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_document_type",
to="documents.documenttype",
verbose_name="has one of these document types",
),
),
migrations.AddField(
model_name="workflowtrigger",
name="filter_has_any_storage_paths",
field=models.ManyToManyField(
blank=True,
related_name="workflowtriggers_has_any_storage_path",
to="documents.storagepath",
verbose_name="has one of these storage paths",
),
),
migrations.AlterField(
model_name="document",
name="checksum",
field=models.CharField(
editable=False,
help_text="The checksum of the original document.",
max_length=32,
verbose_name="checksum",
),
),
]
@@ -1,252 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:31
import django.db.models.deletion
import django.db.models.functions.text
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("documents", "0008_workflowaction_passwords_alter_workflowaction_type"),
("documents", "0009_alter_document_content_length"),
("documents", "0010_optimize_integer_field_sizes"),
("documents", "0011_alter_workflowaction_type"),
("documents", "0012_document_root_document"),
]
dependencies = [
("documents", "0007_sharelinkbundle"),
]
operations = [
migrations.AddField(
model_name="workflowaction",
name="passwords",
field=models.JSONField(
blank=True,
help_text="Passwords to try when removing PDF protection. Separate with commas or new lines.",
null=True,
verbose_name="passwords",
),
),
migrations.AlterField(
model_name="document",
name="content_length",
field=models.GeneratedField(
db_persist=True,
expression=django.db.models.functions.text.Length("content"),
help_text="Length of the content field in characters. Automatically maintained by the database for faster statistics computation.",
output_field=models.PositiveIntegerField(default=0),
serialize=False,
),
),
migrations.AlterField(
model_name="correspondent",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="documenttype",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="savedviewfilterrule",
name="rule_type",
field=models.PositiveSmallIntegerField(
choices=[
(0, "title contains"),
(1, "content contains"),
(2, "ASN is"),
(3, "correspondent is"),
(4, "document type is"),
(5, "is in inbox"),
(6, "has tag"),
(7, "has any tag"),
(8, "created before"),
(9, "created after"),
(10, "created year is"),
(11, "created month is"),
(12, "created day is"),
(13, "added before"),
(14, "added after"),
(15, "modified before"),
(16, "modified after"),
(17, "does not have tag"),
(18, "does not have ASN"),
(19, "title or content contains"),
(20, "fulltext query"),
(21, "more like this"),
(22, "has tags in"),
(23, "ASN greater than"),
(24, "ASN less than"),
(25, "storage path is"),
(26, "has correspondent in"),
(27, "does not have correspondent in"),
(28, "has document type in"),
(29, "does not have document type in"),
(30, "has storage path in"),
(31, "does not have storage path in"),
(32, "owner is"),
(33, "has owner in"),
(34, "does not have owner"),
(35, "does not have owner in"),
(36, "has custom field value"),
(37, "is shared by me"),
(38, "has custom fields"),
(39, "has custom field in"),
(40, "does not have custom field in"),
(41, "does not have custom field"),
(42, "custom fields query"),
(43, "created to"),
(44, "created from"),
(45, "added to"),
(46, "added from"),
(47, "mime type is"),
],
verbose_name="rule type",
),
),
migrations.AlterField(
model_name="storagepath",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="tag",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
(6, "Automatic"),
],
default=1,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowrun",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
null=True,
verbose_name="workflow trigger type",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="matching_algorithm",
field=models.PositiveSmallIntegerField(
choices=[
(0, "None"),
(1, "Any word"),
(2, "All words"),
(3, "Exact match"),
(4, "Regular expression"),
(5, "Fuzzy word"),
],
default=0,
verbose_name="matching algorithm",
),
),
migrations.AlterField(
model_name="workflowtrigger",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Consumption Started"),
(2, "Document Added"),
(3, "Document Updated"),
(4, "Scheduled"),
],
default=1,
verbose_name="Workflow Trigger Type",
),
),
migrations.AlterField(
model_name="workflowaction",
name="type",
field=models.PositiveSmallIntegerField(
choices=[
(1, "Assignment"),
(2, "Removal"),
(3, "Email"),
(4, "Webhook"),
(5, "Password removal"),
(6, "Move to trash"),
],
default=1,
verbose_name="Workflow Action Type",
),
),
migrations.AddField(
model_name="document",
name="root_document",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="versions",
to="documents.document",
verbose_name="root document for this version",
),
),
migrations.AddField(
model_name="document",
name="version_label",
field=models.CharField(
blank=True,
help_text="Optional short label for a document version.",
max_length=64,
null=True,
verbose_name="version label",
),
),
]
+3 -3
View File
@@ -369,7 +369,7 @@ class Document(SoftDeleteModel, ModelWithOwner): # type: ignore[django-manager-
If the queryset already annotated ``effective_content``, that value is used.
"""
if hasattr(self, "effective_content"):
return getattr(self, "effective_content")
return self.effective_content
if self.root_document_id is not None or self.pk is None:
return self.content
@@ -1204,8 +1204,8 @@ class CustomFieldInstance(SoftDeleteModel):
def get_value_field_name(cls, data_type: CustomField.FieldDataType):
try:
return cls.TYPE_TO_DATA_STORE_NAME_MAP[data_type]
except KeyError: # pragma: no cover
raise NotImplementedError(data_type)
except KeyError as exc: # pragma: no cover
raise NotImplementedError(data_type) from exc
@property
def value(self):
+1 -1
View File
@@ -110,7 +110,7 @@ def run_convert(
args += ["-define", "pdf:use-cropbox=true"] if use_cropbox else []
args += [str(input_file), str(output_file)]
logger.debug("Execute: " + " ".join(args), extra={"group": logging_group})
logger.debug("Execute: %s", " ".join(args), extra={"group": logging_group})
try:
run_subprocess(args, environment, logger)
+1 -2
View File
@@ -67,8 +67,7 @@ class DateParserPluginBase(ABC):
Subclasses can override this to release resources.
"""
# Default implementation does nothing.
# Returning None implies exceptions are propagated.
return None
def _parse_string(
self,
-4
View File
@@ -8,15 +8,11 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+9 -21
View File
@@ -195,12 +195,12 @@ class WriteBatch:
try:
self._lock.acquire(timeout=self._lock_timeout)
break
except filelock.Timeout:
except filelock.Timeout as exc:
if attempt == _LOCK_RETRY_ATTEMPTS - 1:
raise SearchIndexLockError(
f"Could not acquire index lock after {_LOCK_RETRY_ATTEMPTS} "
f"attempts (timeout={self._lock_timeout}s each)",
)
) from exc
sleep_s = random.uniform(
0,
min(_LOCK_BACKOFF_CAP, _LOCK_BACKOFF_BASE * (2**attempt)),
@@ -651,7 +651,11 @@ class TantivyBackend:
result_ids = cast("list[int]", searcher.fast_field_values("id", result_addrs))
addr_by_id: dict[int, tuple[float, tantivy.DocAddress]] = {
doc_id: (score, addr)
for (score, addr), doc_id in zip(batch_results.hits, result_ids)
for (score, addr), doc_id in zip(
batch_results.hits,
result_ids,
strict=False,
)
}
snippet_generator = None
@@ -866,24 +870,8 @@ class TantivyBackend:
final_query = self._apply_permission_filter(mlt_query, user)
effective_limit = limit if limit is not None else searcher.num_docs
try:
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
except BaseException: # pragma: no cover
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
# soft-deleted documents (doc_freq can exceed the alive doc count),
# which only surfaces for the More Like This query. The panic crosses
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
# BaseException, not an Exception — so catch BaseException and degrade
# to "no similar documents" instead of bubbling a 500 to the client.
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
# Remove once the bundled tantivy includes that fix.
logger.warning(
"More Like This scoring panicked (likely stale tantivy segment "
"stats after deletions); returning no results. A search index "
"reindex will rebuild consistent statistics.",
)
return []
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
addrs = [addr for _score, addr in results.hits]
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
-163
View File
@@ -1,163 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+409 -27
View File
@@ -1,35 +1,88 @@
from __future__ import annotations
import logging
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._dates import (
_date_only_range, # noqa: F401 — re-exported for test imports
)
from documents.search._dates import (
_datetime_range, # noqa: F401 — re-exported for test imports
)
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -64,12 +117,305 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
) from None
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
) from None
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` ``added:[2025-01-01T TO 2025-01-08T]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
) from None
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
) from None
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (year range rewrite timed out)",
) from None
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Delegates to ``translate_query`` which handles all date forms, comma
expansion, field aliasing, relative ranges, and operator normalization.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Args:
query: Raw user query string
@@ -81,15 +427,35 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
Note:
Bare keywords without field prefixes pass through unchanged.
"""
return translate_query(query, tz)
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
) from None
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Delegates to ``translate_query`` which handles comma expansion, whitespace
collapsing, operator normalization, and field aliasing.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar tag:foo AND tag:bar
- multiple spaces single spaces
Args:
query: Query string after date rewriting
@@ -97,7 +463,31 @@ def normalize_query(query: str) -> str:
Returns:
Normalized query string ready for Tantivy parsing
"""
return translate_query(query, UTC)
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (normalization timed out)",
) from None
def build_permission_filter(
@@ -217,16 +607,8 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
exact = index.parse_query(
query_str,
-566
View File
@@ -1,566 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
+32 -87
View File
@@ -48,7 +48,6 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied
from rest_framework.fields import SerializerMethodField
from rest_framework.filters import OrderingFilter
from rest_framework.utils import model_meta
if settings.AUDIT_LOG_ENABLED:
from auditlog.context import set_actor
@@ -122,45 +121,6 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
self.fields.pop(field_name)
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
def _get_update_fields(self, validated_data) -> list[str]:
model_fields = {
field.name
for field in self.Meta.model._meta.concrete_fields
if field.name not in self.stale_update_excluded_fields
}
update_fields = [
field_name for field_name in validated_data if field_name in model_fields
]
if "modified" in model_fields and "modified" not in update_fields:
update_fields.append("modified")
return update_fields
def update(self, instance, validated_data):
serializers.raise_errors_on_nested_writes("update", self, validated_data)
info = model_meta.get_field_info(instance)
m2m_fields = []
for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many:
m2m_fields.append((attr, value))
else:
setattr(instance, attr, value)
# File names are managed by post-save file handling. Saving only the
# serializer-updated fields prevents stale in-memory path values from
# overwriting a concurrent move.
instance.save(update_fields=self._get_update_fields(validated_data))
for attr, value in m2m_fields:
field = getattr(instance, attr)
field.set(value)
return instance
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
document_count = serializers.IntegerField(read_only=True)
@@ -203,7 +163,7 @@ class MatchingModelSerializer(serializers.ModelSerializer[Any]):
logger.debug(f"Invalid regular expression: {e!s}")
raise serializers.ValidationError(
"Invalid regular expression, see log for details.",
)
) from None
return match
@@ -907,7 +867,9 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
try:
value_int = int(data["value"])
except (TypeError, ValueError):
raise serializers.ValidationError("Enter a valid integer.")
raise serializers.ValidationError(
"Enter a valid integer.",
) from None
# Keep values within the PostgreSQL integer range
MinValueValidator(-2147483648)(value_int)
MaxValueValidator(2147483647)(value_int)
@@ -939,7 +901,7 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
except Exception:
raise serializers.ValidationError(
f"Value must be an id of an element in {select_options}",
)
) from None
elif field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
if not (isinstance(data["value"], list) or data["value"] is None):
raise serializers.ValidationError(
@@ -1029,7 +991,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
class DocumentSerializer(
OwnedObjectSerializer,
NestedUpdateMixin,
DocumentUpdateFieldsModelSerializer,
DynamicFieldsModelSerializer,
):
correspondent = CorrespondentField(allow_null=True)
tags = TagsField(many=True)
@@ -1130,7 +1092,7 @@ class DocumentSerializer(
def to_representation(self, instance):
doc = super().to_representation(instance)
if "content" in self.fields and hasattr(instance, "effective_content"):
doc["content"] = getattr(instance, "effective_content") or ""
doc["content"] = instance.effective_content or ""
if self.truncate_content and "content" in self.fields:
doc["content"] = doc.get("content")[0:550]
return doc
@@ -1168,9 +1130,10 @@ class DocumentSerializer(
return super().validate(attrs)
def update(self, instance: Document, validated_data):
if "created_date" in validated_data and "created" not in validated_data:
instance.created = validated_data.get("created_date")
instance.save()
if "created_date" in validated_data:
if "created" not in validated_data:
validated_data["created"] = validated_data["created_date"]
logger.warning(
"created_date is deprecated, use created instead",
)
@@ -1240,13 +1203,11 @@ class DocumentSerializer(
for tag in instance.tags.all()
if tag not in inbox_tags_not_being_added
]
if settings.AUDIT_LOG_ENABLED:
with set_actor(self.user):
super().update(instance, validated_data)
else:
super().update(instance, validated_data)
# hard delete custom field instances that were soft deleted
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
return instance
@@ -1493,7 +1454,7 @@ class SavedViewSerializer(OwnedObjectSerializer):
)
)
except serializers.ValidationError as exc:
raise serializers.ValidationError({field_name: exc.detail})
raise serializers.ValidationError({field_name: exc.detail}) from exc
del normalized_data[field_name]
ret = super().to_internal_value(normalized_data)
@@ -1797,7 +1758,7 @@ class BulkEditSerializer(
logger.exception(f"Error validating custom fields: {e}")
raise serializers.ValidationError(
f"{name} must be a list of integers or a dict of id:value pairs, see the log for details",
)
) from None
elif not isinstance(custom_fields, list) or not all(
isinstance(i, int) for i in ids
):
@@ -1865,7 +1826,7 @@ class BulkEditSerializer(
try:
Tag.objects.get(id=tag_id)
except Tag.DoesNotExist:
raise serializers.ValidationError("Tag does not exist")
raise serializers.ValidationError("Tag does not exist") from None
else:
raise serializers.ValidationError("tag not specified")
@@ -1878,7 +1839,9 @@ class BulkEditSerializer(
try:
DocumentType.objects.get(id=document_type_id)
except DocumentType.DoesNotExist:
raise serializers.ValidationError("Document type does not exist")
raise serializers.ValidationError(
"Document type does not exist",
) from None
else:
raise serializers.ValidationError("document_type not specified")
@@ -1890,7 +1853,9 @@ class BulkEditSerializer(
try:
Correspondent.objects.get(id=correspondent_id)
except Correspondent.DoesNotExist:
raise serializers.ValidationError("Correspondent does not exist")
raise serializers.ValidationError(
"Correspondent does not exist",
) from None
else:
raise serializers.ValidationError("correspondent not specified")
@@ -1904,7 +1869,7 @@ class BulkEditSerializer(
except StoragePath.DoesNotExist:
raise serializers.ValidationError(
"Storage path does not exist",
)
) from None
else:
raise serializers.ValidationError("storage path not specified")
@@ -1959,7 +1924,7 @@ class BulkEditSerializer(
):
raise serializers.ValidationError("invalid rotation degrees")
except ValueError:
raise serializers.ValidationError("invalid rotation degrees")
raise serializers.ValidationError("invalid rotation degrees") from None
def _validate_source_mode(self, parameters) -> None:
source_mode = parameters.get(
@@ -1989,7 +1954,7 @@ class BulkEditSerializer(
pages.append([int(doc)])
parameters["pages"] = pages
except ValueError:
raise serializers.ValidationError("invalid pages specified")
raise serializers.ValidationError("invalid pages specified") from None
if "delete_originals" in parameters:
if not isinstance(parameters["delete_originals"], bool):
@@ -2259,14 +2224,14 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
raise serializers.ValidationError(
_("Custom field id must be an integer: %(id)s")
% {"id": field_id},
)
) from None
try:
field = CustomField.objects.get(id=field_id_int)
except CustomField.DoesNotExist:
raise serializers.ValidationError(
_("Custom field with id %(id)s does not exist")
% {"id": field_id_int},
)
) from None
custom_field_serializer.validate(
{
"field": field,
@@ -2283,7 +2248,7 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
_(
"Custom fields must be a list of integers or an object mapping ids to values.",
),
)
) from None
if CustomField.objects.filter(id__in=ids).count() != len(set(ids)):
raise serializers.ValidationError(
_("Some custom fields don't exist or were specified twice."),
@@ -2394,7 +2359,9 @@ class EmailSerializer(DocumentListSerializer):
for address in address_list:
email_validator(address)
except ValidationError:
raise serializers.ValidationError(f"Invalid email address: {address}")
raise serializers.ValidationError(
f"Invalid email address: {address}",
) from None
return ",".join(address_list)
@@ -2673,25 +2640,18 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
tasks = serializers.ListField(
required=False,
required=True,
label="Tasks",
write_only=True,
child=serializers.IntegerField(),
)
all = serializers.BooleanField(
required=False,
default=False,
label="All",
write_only=True,
)
def _validate_task_id_list(self, tasks, name="tasks") -> None:
if not isinstance(tasks, list):
raise serializers.ValidationError(f"{name} must be a list")
if not all(isinstance(i, int) for i in tasks):
raise serializers.ValidationError(f"{name} must be a list of integers")
queryset = self.context.get("queryset", PaperlessTask.objects.all())
count = queryset.filter(id__in=tasks).count()
count = PaperlessTask.objects.filter(id__in=tasks).count()
if not count == len(tasks):
raise serializers.ValidationError(
f"Some tasks in {name} don't exist or were specified twice.",
@@ -2701,21 +2661,6 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
self._validate_task_id_list(tasks)
return tasks
def validate(self, attrs):
acknowledge_all = attrs.get("all", False)
task_ids = attrs.get("tasks")
if acknowledge_all and task_ids is not None:
raise serializers.ValidationError(
"Set either all or tasks, not both.",
)
if not acknowledge_all and task_ids is None:
raise serializers.ValidationError(
"Either all must be true or tasks must be provided.",
)
return attrs
class ShareLinkSerializer(OwnedObjectSerializer):
class Meta:
@@ -2840,7 +2785,7 @@ class ShareLinkBundleSerializer(OwnedObjectSerializer):
return share_link_bundle
def get_document_count(self, obj: ShareLinkBundle) -> int:
return getattr(obj, "document_total") or obj.documents.count()
return obj.document_total or obj.documents.count()
class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
@@ -3188,7 +3133,7 @@ class WorkflowActionSerializer(serializers.ModelSerializer[WorkflowAction]):
except (ValueError, KeyError) as e:
raise serializers.ValidationError(
{"assign_title": f'Invalid f-string detected: "{e.args[0]}"'},
)
) from None
if (
"type" in attrs
+3 -17
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import datetime
import hashlib
import logging
import shutil
import traceback as _tb
@@ -15,7 +16,6 @@ from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import worker_process_init
from celery.signals import worker_process_shutdown
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
@@ -54,7 +54,6 @@ from documents.models import WorkflowTrigger
from documents.permissions import get_objects_for_user_owner_aware
from documents.plugins.helpers import DocumentsStatusManager
from documents.templating.utils import convert_format_str_to_template_format
from documents.utils import compute_checksum
from documents.workflows.actions import build_workflow_action_context
from documents.workflows.actions import execute_email_action
from documents.workflows.actions import execute_move_to_trash_action
@@ -411,7 +410,8 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
if checksum is None or not path.is_file():
return False
return compute_checksum(path) == checksum
with path.open("rb") as f:
return hashlib.md5(f.read(), usedforsecurity=False).hexdigest() == checksum
def _filename_template_uses_custom_fields(doc: Document) -> bool:
@@ -1340,20 +1340,6 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
conn.close_pool()
@worker_process_shutdown.connect
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
"""
Close the DB connection pool when a Celery child process exits.
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
single task. Without closing the pool on shutdown, its connections linger
on the server until TCP keepalive reaps them, accumulating over time.
"""
for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool()
def add_or_update_document_in_llm_index(sender, document, **kwargs):
"""
Add or update a document in the LLM index when it is created or updated.
+15 -24
View File
@@ -1,7 +1,6 @@
import logging
import os
import re
import unicodedata
from collections.abc import Iterable
from pathlib import PurePath
@@ -37,12 +36,10 @@ class FilePathTemplate(Template):
def clean_filepath(value: str) -> str:
"""
Clean up a filepath by:
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
2. Removing newlines and carriage returns
3. Removing extra spaces before and after forward slashes
4. Preserving spaces in other parts of the path
1. Removing newlines and carriage returns
2. Removing extra spaces before and after forward slashes
3. Preserving spaces in other parts of the path
"""
value = unicodedata.normalize("NFC", value)
value = value.replace("\n", "").replace("\r", "")
value = re.sub(r"\s*/\s*", "/", value)
@@ -184,17 +181,17 @@ def get_basic_metadata_context(
"""
return {
"title": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.title),
document.title,
replacement_text="-",
),
"correspondent": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.correspondent.name),
document.correspondent.name,
replacement_text="-",
)
if document.correspondent
else no_value_default,
"document_type": pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", document.document_type.name),
document.document_type.name,
replacement_text="-",
)
if document.document_type
@@ -205,10 +202,7 @@ def get_basic_metadata_context(
"owner_username": document.owner.username
if document.owner
else no_value_default,
"original_name": unicodedata.normalize(
"NFC",
PurePath(document.original_filename).with_suffix("").name,
)
"original_name": PurePath(document.original_filename).with_suffix("").name
if document.original_filename
else no_value_default,
"doc_pk": f"{document.pk:07}",
@@ -275,12 +269,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
return {
"tag_list": pathvalidate.sanitize_filename(
",".join(
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
sorted(tag.name for tag in tags),
),
replacement_text="-",
),
# Assumed to be ordered, but a template could loop through to find what they want
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
"tag_name_list": [x.name for x in tags],
}
@@ -307,7 +301,7 @@ def get_custom_fields_context(
CustomField.FieldDataType.LONG_TEXT,
}:
value = pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", field_instance.value),
field_instance.value,
replacement_text="-",
)
elif (
@@ -316,13 +310,10 @@ def get_custom_fields_context(
):
options = field_instance.field.extra_data["select_options"]
value = pathvalidate.sanitize_filename(
unicodedata.normalize(
"NFC",
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
replacement_text="-",
)
@@ -330,7 +321,7 @@ def get_custom_fields_context(
value = field_instance.value
field_data["custom_fields"][
pathvalidate.sanitize_filename(
unicodedata.normalize("NFC", field_instance.field.name),
field_instance.field.name,
replacement_text="-",
)
] = {
@@ -29,9 +29,7 @@ class SimpleCommand(PaperlessCommand):
def handle(self, *args, **options):
items = list(range(5))
results = []
for item in self.track(items, description="Processing..."):
results.append(item * 2)
results = [item * 2 for item in self.track(items, description="Processing...")]
self.stdout.write(f"Results: {results}")
@@ -57,13 +55,13 @@ class MultiprocessCommand(PaperlessCommand):
def handle(self, *args, **options):
items = list(range(5))
results = []
for result in self.process_parallel(
_double_value,
items,
description="Processing...",
):
results.append(result)
results = list(
self.process_parallel(
_double_value,
items,
description="Processing...",
),
)
successes = sum(1 for r in results if r.success)
self.stdout.write(f"Successes: {successes}")
@@ -1,36 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from django.core.management import call_command
if TYPE_CHECKING:
from pytest_mock import MockerFixture
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
class TestDocumentLlmindexCommand:
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
mock_compact = mocker.patch(_COMPACT)
call_command("document_llmindex", "compact")
mock_compact.assert_called_once_with()
def test_rebuild_calls_llmindex_index_with_rebuild_true(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "rebuild")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is True
def test_update_calls_llmindex_index_with_rebuild_false(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "update")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is False
-12
View File
@@ -1,15 +1,11 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -35,11 +31,3 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+10 -88
View File
@@ -13,6 +13,7 @@ import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
@@ -20,7 +21,6 @@ from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -405,14 +405,12 @@ class TestWhooshQueryRewriting:
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
rewrite_natural_date_keywords("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
class TestParseUserQuery:
@@ -465,67 +463,6 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -605,16 +542,11 @@ class TestYearRangeRewriting:
assert rewrite_natural_date_keywords(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
class TestNonDateFieldsNotRewritten:
@@ -674,16 +606,6 @@ class TestNormalizeQuery:
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert normalize_query(q) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
@@ -1,742 +0,0 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
+4 -5
View File
@@ -82,7 +82,6 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"llm_api_key": None,
"llm_endpoint": None,
"llm_output_language": None,
"llm_request_timeout": None,
},
)
@@ -845,7 +844,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.llm_index_exists") as mock_exists,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
@@ -870,7 +869,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.llm_index_exists") as mock_exists,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -891,7 +890,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.llm_index_exists") as mock_exists,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -929,7 +928,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.llm_index_exists") as mock_exists,
patch("paperless.views.vector_store_file_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -6,7 +6,6 @@ import zipfile
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from rest_framework import status
from rest_framework.test import APITestCase
@@ -33,21 +32,21 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase):
filename="docA.pdf",
mime_type="application/pdf",
checksum="B",
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
)
self.doc2b = Document.objects.create(
title="document A",
filename="docA2.pdf",
mime_type="application/pdf",
checksum="D",
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
)
self.doc3 = Document.objects.create(
title="document B",
filename="docB.jpg",
mime_type="image/jpeg",
checksum="C",
created=timezone.make_aware(datetime.datetime(2020, 3, 21)),
created=datetime.datetime(2020, 3, 21, tzinfo=datetime.UTC),
archive_filename="docB.pdf",
archive_checksum="D",
)
@@ -1,5 +1,5 @@
import datetime
import json
from datetime import date
from unittest import mock
from unittest.mock import ANY
@@ -456,7 +456,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
},
)
date_value = date.today()
date_value = datetime.datetime.now(tz=datetime.UTC).date()
resp = self.client.patch(
f"/api/documents/{doc.id}/",
@@ -618,7 +618,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
data_type=CustomField.FieldDataType.DATE,
)
date_value = date.today()
date_value = datetime.datetime.now(tz=datetime.UTC).date()
resp = self.client.patch(
f"/api/documents/{doc.id}/",
+1 -1
View File
@@ -265,7 +265,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
created=date(2023, 1, 1),
)
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0)
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0, tzinfo=datetime.UTC)
response = self.client.patch(
f"/api/documents/{doc.pk}/",
{"created": created_datetime},
@@ -1,95 +0,0 @@
import unicodedata
from typing import TYPE_CHECKING
from unittest import mock
import celery.result
import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
if TYPE_CHECKING:
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
@pytest.fixture()
def consume_file_mock():
with mock.patch("documents.tasks.consume_file.apply_async") as m:
m.return_value = celery.result.AsyncResult(id="test-task-id")
yield m
@pytest.fixture()
def directories(tmp_path, settings, _media_settings):
scratch = tmp_path / "scratch"
scratch.mkdir()
settings.SCRATCH_DIR = scratch
return scratch
@pytest.mark.django_db
class TestPostDocumentNFCNormalization:
def test_nfd_filename_normalized_to_nfc(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with NFD filename must have its name stored as NFC."""
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Verify our test strings actually differ at the byte level
assert nfd != nfc
uploaded = SimpleUploadedFile(
nfd,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
input_doc: ConsumableDocument = task_kwargs["input_doc"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
# The temp file on disk must have an NFC name
assert input_doc.original_file.name == nfc, (
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
)
# The override filename stored for later use must also be NFC
assert overrides.filename == nfc, (
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
)
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_filename_unchanged(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with already-NFC filename must pass through unchanged."""
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
uploaded = SimpleUploadedFile(
nfc,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
assert overrides.filename == nfc
assert unicodedata.is_normalized("NFC", overrides.filename)
+33 -20
View File
@@ -700,7 +700,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
pk=3,
checksum="C",
# specific time zone aware date
added=timezone.make_aware(datetime.datetime(2023, 12, 1)),
added=datetime.datetime(2023, 12, 1, tzinfo=datetime.UTC),
)
# refresh doc instance to ensure we operate on date objects that Django uses
# Django converts dates to UTC
@@ -725,11 +725,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with an invalid added date
- Query with invalid added date
THEN:
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
"""
d1 = Document.objects.create(
title="invoice",
@@ -742,9 +740,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# An unparsable date is reported as a malformed query, not silently empty.
# Tantivy rejects unparsable field queries with a 400
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
@@ -997,25 +994,25 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
title="invoice",
content="the thing i bought at a shop and paid with bank account",
created=datetime.date(2018, 1, 1),
added=timezone.make_aware(datetime.datetime(2018, 1, 1)),
added=datetime.datetime(2018, 1, 1, tzinfo=datetime.UTC),
)
d2 = DocumentFactory(
title="bank statement 1",
content="things i paid for in august",
created=datetime.date(2019, 3, 4),
added=timezone.make_aware(datetime.datetime(2019, 3, 4)),
added=datetime.datetime(2019, 3, 4, tzinfo=datetime.UTC),
)
d3 = DocumentFactory(
title="bank statement 3",
content="things i paid for in september",
created=datetime.date(2020, 7, 9),
added=timezone.make_aware(datetime.datetime(2020, 7, 9)),
added=datetime.datetime(2020, 7, 9, tzinfo=datetime.UTC),
)
d4 = DocumentFactory(
title="Quarterly Report",
content="quarterly revenue profit margin earnings growth",
created=datetime.date(2021, 11, 30),
added=timezone.make_aware(datetime.datetime(2021, 11, 30)),
added=datetime.datetime(2021, 11, 30, tzinfo=datetime.UTC),
)
backend = get_backend()
backend.add_or_update(d1)
@@ -1134,7 +1131,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.tags.add(t2)
d5 = Document.objects.create(
checksum="5",
added=timezone.make_aware(datetime.datetime(2020, 7, 13)),
added=datetime.datetime(2020, 7, 13, tzinfo=datetime.UTC),
content="test",
original_filename="doc5.pdf",
)
@@ -1244,14 +1241,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.id,
search_query(
"&created__date__lt="
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
self.assertNotIn(
d4.id,
search_query(
"&created__date__gt="
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
@@ -1259,14 +1260,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d4.id,
search_query(
"&created__date__lt="
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
self.assertIn(
d4.id,
search_query(
"&created__date__gt="
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
@@ -1274,14 +1279,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__lt="
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
self.assertNotIn(
d5.id,
search_query(
"&added__date__gt="
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
@@ -1289,7 +1298,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__lt="
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
@@ -1297,7 +1308,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
d5.id,
search_query(
"&added__date__gt="
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
"%Y-%m-%d",
),
),
)
-71
View File
@@ -216,77 +216,6 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
-181
View File
@@ -18,7 +18,6 @@ from guardian.shortcuts import assign_perm
from rest_framework import status
from rest_framework.test import APIClient
from documents.filters import PaperlessTaskFilterSet
from documents.models import PaperlessTask
from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
@@ -170,165 +169,6 @@ class TestGetTasksV10:
PaperlessTask.Status.STARTED,
}
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
"""?name= searches task filenames, task types, and trigger sources."""
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
source_task = PaperlessTaskFactory(
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
)
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
response = admin_client.get(ENDPOINT, {"name": "invoice"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == filename_task.task_id
response = admin_client.get(ENDPOINT, {"name": "sanity"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == type_task.task_id
response = admin_client.get(ENDPOINT, {"name": "email"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == source_task.task_id
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
"""?result= searches common structured task result messages."""
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
error_task = PaperlessTaskFactory(
result_data={"error_message": "Duplicate detected"},
)
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
PaperlessTaskFactory(result_data={"reason": "unrelated"})
response = admin_client.get(ENDPOINT, {"result": "manual"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == reason_task.task_id
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
assert response.status_code == status.HTTP_200_OK
returned_ids = {task["task_id"] for task in response.data["results"]}
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
response = admin_client.get(ENDPOINT, {"result": "321"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == document_task.task_id
def test_empty_task_name_and_result_filters(self) -> None:
"""Empty name/result values leave the queryset unchanged."""
PaperlessTaskFactory.create_batch(2)
queryset = PaperlessTask.objects.all()
filterset = PaperlessTaskFilterSet()
assert filterset.filter_name(queryset, "name", "").count() == 2
assert filterset.filter_result(queryset, "result", "").count() == 2
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
"""status_counts/ returns section counts for the filtered task queryset."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.REVOKED,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.STARTED,
input_data={"filename": "invoice-d.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-e.pdf"},
)
PaperlessTaskFactory(
acknowledged=True,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-acknowledged.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{"acknowledged": "false", "name": "invoice"},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 5,
"needs_attention": 2,
"in_progress": 2,
"completed": 1,
}
def test_status_counts_ignores_section_filters(
self,
admin_client: APIClient,
) -> None:
"""status_counts/ ignores status-like filters for the sections it counts."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{
"acknowledged": "false",
"name": "invoice",
"status": PaperlessTask.Status.FAILURE,
"is_complete": "false",
},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 3,
"needs_attention": 1,
"in_progress": 1,
"completed": 1,
}
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
"""Tasks are returned in descending date_created order (newest first)."""
base = timezone.now()
@@ -682,27 +522,6 @@ class TestAcknowledge:
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
response = admin_client.post(
ENDPOINT + "acknowledge/",
{"all": True},
format="json",
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
unacknowledged_task1.refresh_from_db()
unacknowledged_task2.refresh_from_db()
acknowledged_task.refresh_from_db()
assert unacknowledged_task1.acknowledged
assert unacknowledged_task2.acknowledged
assert acknowledged_task.acknowledged
def test_acknowledged_tasks_excluded_from_unacked_filter(
self,
admin_client: APIClient,
+13 -120
View File
@@ -3,7 +3,6 @@ from datetime import date
from pathlib import Path
from unittest import mock
import pikepdf
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.test import TestCase
@@ -616,18 +615,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.img_doc.archive_filename = img_doc_archive
self.img_doc.save()
@staticmethod
def mock_password_required_pdf(
mock_open: mock.Mock,
fake_pdf: mock.Mock,
) -> None:
password_context = mock.MagicMock()
password_context.__enter__.return_value = fake_pdf
mock_open.side_effect = [
pikepdf.PasswordError("password required"),
password_context,
]
@mock.patch("documents.tasks.consume_file.s")
def test_merge(self, mock_consume_file) -> None:
"""
@@ -777,7 +764,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
sig.set.return_value.apply_async.side_effect = Exception("boom")
mock_consume_file.return_value = sig
with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "boom"):
bulk_edit.merge(doc_ids, delete_originals=True)
self.doc1.refresh_from_db()
@@ -1060,6 +1047,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
for call, expected_id in zip(
mock_consume_delay.call_args_list,
doc_ids,
strict=False,
):
task_kwargs = call.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id)
@@ -1318,7 +1306,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
sig.apply_async.side_effect = Exception("boom")
mock_chord.return_value = sig
with self.assertRaises(Exception):
with self.assertRaisesRegex(Exception, "boom"):
bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.doc2.refresh_from_db()
@@ -1430,7 +1418,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
{"page": 9999}, # invalid page, forces error during PDF load
]
with self.assertLogs("paperless.bulk_edit", level="ERROR"):
with self.assertRaises(Exception):
with self.assertRaises(ValueError):
bulk_edit.edit_pdf(doc_ids, operations)
mock_group.assert_not_called()
mock_consume_file.assert_not_called()
@@ -1479,7 +1467,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
fake_pdf.is_encrypted = True
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
@@ -1494,13 +1481,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_open.assert_called_once_with(doc.source_path, password="secret")
fake_pdf.remove_unreferenced_resources.assert_called_once()
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1514,33 +1495,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_update_document_skips_unencrypted_pdf(
self,
mock_open,
mock_mkdtemp,
mock_consume_delay,
) -> None:
doc = self.doc1
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_delay.assert_not_called()
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@@ -1560,12 +1514,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_mkdtemp.return_value = str(temp_dir)
fake_pdf = mock.MagicMock()
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
@@ -1575,13 +1529,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(source_file),
mock.call(source_file, password="secret"),
],
)
mock_open.assert_called_once_with(source_file, password="secret")
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1600,7 +1548,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
root_document=self.doc1,
)
fake_pdf = mock.MagicMock()
self.mock_password_required_pdf(mock_open, fake_pdf)
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[self.doc1.id],
@@ -1610,13 +1558,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(self.doc1.source_path),
mock.call(self.doc1.source_path, password="secret"),
],
)
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
mock_consume_delay.assert_called_once()
@mock.patch("documents.bulk_edit.chord")
@@ -1639,12 +1581,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_group.return_value.delay.return_value = None
user = User.objects.create(username="owner")
@@ -1659,13 +1601,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_open.assert_called_once_with(doc.source_path, password="secret")
mock_consume_file.assert_called_once()
call_kwargs = mock_consume_file.call_args.kwargs
consumable_document = call_kwargs["input_doc"]
@@ -1683,43 +1619,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_group.return_value.delay.assert_called_once()
mock_chord.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@mock.patch("documents.tasks.consume_file.s")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_skips_unencrypted_pdf_without_queueing(
self,
mock_open: mock.Mock,
mock_mkdtemp: mock.Mock,
mock_consume_file: mock.Mock,
mock_group: mock.Mock,
mock_chord: mock.Mock,
mock_delete: mock.Mock,
) -> None:
doc = self.doc2
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=False,
delete_original=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_file.assert_not_called()
mock_group.assert_not_called()
mock_chord.assert_not_called()
mock_delete.si.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@@ -1742,12 +1641,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_chord.return_value.delay.return_value = None
result = bulk_edit.remove_password(
@@ -1759,13 +1658,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_open.assert_called_once_with(doc.source_path, password="secret")
mock_consume_file.assert_called_once()
mock_group.assert_not_called()
mock_chord.assert_called_once()
+2 -2
View File
@@ -782,8 +782,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
load_classifier(raise_exception=True)
Path(settings.MODEL_FILE).touch()
mock_load.side_effect = Exception()
with self.assertRaises(Exception):
mock_load.side_effect = RuntimeError()
with self.assertRaises(RuntimeError):
load_classifier(raise_exception=True)
+4 -4
View File
@@ -59,7 +59,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
def create_staging_file(self, src="double-sided-odd.pdf", datetime=None) -> None:
shutil.copy(self.SAMPLE_DIR / src, self.staging_file)
if datetime is None:
datetime = dt.datetime.now()
datetime = dt.datetime.now(tz=dt.UTC)
os.utime(str(self.staging_file), (datetime.timestamp(),) * 2)
def test_odd_numbered_moved_to_staging(self) -> None:
@@ -79,8 +79,8 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertIsFile(self.staging_file)
self.assertAlmostEqual(
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime),
dt.datetime.now(),
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime, tz=dt.UTC),
dt.datetime.now(tz=dt.UTC),
delta=dt.timedelta(seconds=5),
)
self.assertIn("Received odd numbered pages", msg["reason"])
@@ -124,7 +124,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
"""
self.create_staging_file(
datetime=dt.datetime.now()
datetime=dt.datetime.now(tz=dt.UTC)
- dt.timedelta(minutes=TIMEOUT_MINUTES, seconds=1),
)
msg = self.consume_file("double-sided-odd.pdf")
+25 -57
View File
@@ -12,7 +12,6 @@ from django.contrib.auth.models import User
from django.db import DatabaseError
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
@@ -24,7 +23,6 @@ from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.serialisers import DocumentSerializer
from documents.tasks import empty_trash
from documents.tests.factories import DocumentFactory
from documents.tests.utils import DirectoriesMixin
@@ -222,8 +220,11 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
doc = Document.objects.create(
title="document",
mime_type="application/pdf",
checksum=hashlib.sha256(original_bytes).hexdigest(),
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
checksum=hashlib.md5(original_bytes, usedforsecurity=False).hexdigest(),
archive_checksum=hashlib.md5(
archive_bytes,
usedforsecurity=False,
).hexdigest(),
filename="old/document.pdf",
archive_filename="old/document.pdf",
storage_path=old_storage_path,
@@ -252,46 +253,6 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
@override_settings(FILENAME_FORMAT="{title}")
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
old_path = settings.ORIGINALS_DIR / "original.pdf"
old_path.touch()
doc = Document.objects.create(
title="original",
mime_type="application/pdf",
checksum=hashlib.sha256(b"").hexdigest(),
filename="original.pdf",
)
first_instance = Document.objects.get(pk=doc.pk)
stale_instance = Document.objects.get(pk=doc.pk)
serializer = DocumentSerializer(
first_instance,
data={"title": "first"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "first.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
serializer = DocumentSerializer(
stale_instance,
data={"title": "second"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "second.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
self.assertIsNotFile(old_path)
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
def test_document_delete(self) -> None:
document = Document()
@@ -452,7 +413,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
FILENAME_FORMAT="{created_year}-{created_month}-{created_day}",
)
def test_created_year_month_day(self) -> None:
d1 = timezone.make_aware(datetime.datetime(2020, 3, 6, 1, 1, 1))
d1 = datetime.datetime(2020, 3, 6, 1, 1, 1, tzinfo=datetime.UTC)
doc1 = Document.objects.create(
title="doc1",
mime_type="application/pdf",
@@ -469,7 +430,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
FILENAME_FORMAT="{added_year}-{added_month}-{added_day}",
)
def test_added_year_month_day(self) -> None:
d1 = timezone.make_aware(datetime.datetime(1232, 1, 9, 1, 1, 1))
d1 = datetime.datetime(1232, 1, 9, 1, 1, 1, tzinfo=datetime.UTC)
doc1 = Document.objects.create(
title="doc1",
mime_type="application/pdf",
@@ -482,7 +443,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertEqual(generate_filename(doc1), expected_filename)
doc1.added = timezone.make_aware(datetime.datetime(2020, 11, 16, 1, 1, 1))
doc1.added = datetime.datetime(2020, 11, 16, 1, 1, 1, tzinfo=datetime.UTC)
self.assertEqual(generate_filename(doc1), Path("2020-11-16.pdf"))
@@ -1266,7 +1227,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
def test_short_names_added(self) -> None:
doc = Document.objects.create(
title="The Title",
added=timezone.make_aware(datetime.datetime(1984, 8, 21, 7, 36, 51, 153)),
added=datetime.datetime(1984, 8, 21, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1505,7 +1466,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1577,7 +1538,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc = Document.objects.create(
title="scan_017562",
created=datetime.date(2025, 7, 2),
added=timezone.make_aware(datetime.datetime(2026, 3, 3, 11, 53, 16)),
added=datetime.datetime(2026, 3, 3, 11, 53, 16, tzinfo=datetime.UTC),
mime_type="application/pdf",
checksum="test-checksum",
storage_path=sp,
@@ -1606,7 +1567,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1641,7 +1602,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Does Matter",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1673,7 +1634,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Some Title",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1778,7 +1739,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc_a = Document.objects.create(
title="Some Title",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
@@ -1792,8 +1753,15 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
CustomFieldInstance.objects.create(
document=doc_a,
field=CustomField.objects.get(name="Invoice Date"),
value_date=timezone.make_aware(
datetime.datetime(2024, 10, 1, 7, 36, 51, 153),
value_date=datetime.datetime(
2024,
10,
1,
7,
36,
51,
153,
tzinfo=datetime.UTC,
),
)
@@ -1833,7 +1801,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
doc = Document.objects.create(
title="Some Title! With @ Special # Characters",
created=datetime.date(2020, 6, 25),
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
mime_type="application/pdf",
pk=2,
checksum="2",
-187
View File
@@ -1,187 +0,0 @@
"""
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
All paths produced by the templating system must be NFC-normalized.
"""
import unicodedata
import pytest
from documents.file_handling import generate_filename
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
@pytest.mark.django_db
class TestGenerateFilenameNFCNormalization:
@pytest.mark.parametrize(
"raw,display",
[
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
(unicodedata.normalize("NFD", "Café"), "Café"),
(unicodedata.normalize("NFD", "naïve"), "naïve"),
],
)
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
"""NFD title must produce NFC path bytes."""
settings.FILENAME_FORMAT = "{{ title }}"
nfc = unicodedata.normalize("NFC", display)
assert raw != nfc # confirm byte-level difference
doc = DocumentFactory(title=raw, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_correspondent_normalized_to_nfc(self, settings):
"""NFD correspondent name must produce NFC path component."""
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Müller")
nfc = unicodedata.normalize("NFC", "Müller")
correspondent = CorrespondentFactory(name=nfd)
doc = DocumentFactory(
title="invoice",
correspondent=correspondent,
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result) == f"{nfc}/invoice.pdf"
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
def test_nfd_storage_path_normalized_to_nfc(self, settings):
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
settings.FILENAME_FORMAT = None
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
# StoragePath.path is used directly as the format/template string.
# Literal NFD characters in the template must survive rendering as NFC.
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
settings.FILENAME_FORMAT = "{{ document.title }}"
nfd = unicodedata.normalize("NFD", "Café")
nfc = unicodedata.normalize("NFC", "Café")
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
@pytest.mark.django_db
class TestContextBuilderNFCNormalization:
"""
Defense-in-depth: context builder functions must NFC-normalize string inputs
before passing them to sanitize_filename(). Task 1 already normalizes the
final rendered path via clean_filepath(), so these tests may already pass;
they exist as regression guards for the context-builder layer.
"""
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm they differ at byte level
tag = TagFactory(name=nfd)
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.set([tag])
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_original_name_normalized_to_nfc(self, settings):
settings.FILENAME_FORMAT = "{{ original_name }}"
nfd = unicodedata.normalize("NFD", "Rechnung März")
nfc = unicodedata.normalize("NFC", "Rechnung März")
doc = DocumentFactory(
original_filename=f"{nfd}.pdf",
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_custom_field_string_value_normalized(self, settings):
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
settings.FILENAME_FORMAT = (
"{{ custom_fields['Location']['value'] }}/{{ title }}"
)
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
assert nfd_value != nfc_value
doc = DocumentFactory(title="report", mime_type="application/pdf")
cf = CustomField.objects.create(
name="Location",
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text=nfd_value,
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
def test_nfd_custom_field_name_normalized_as_key(self, settings):
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
nfd_name = unicodedata.normalize("NFD", "Größe")
nfc_name = unicodedata.normalize("NFC", "Größe")
assert nfd_name != nfc_name
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
doc = DocumentFactory(title="letter", mime_type="application/pdf")
cf = CustomField.objects.create(
name=nfd_name,
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text="Berlin",
)
result = generate_filename(doc)
# If field name key is NFC-normalized, the template condition succeeds
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
assert str(result) == "Berlin/letter.pdf"
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
settings.FILENAME_FORMAT = (
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
)
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm byte-level difference
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.add(TagFactory(name=nfd))
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
@@ -684,7 +684,6 @@ class ConsumerThread(Thread):
subdirs_as_tags: bool = False,
polling_interval: float = 0,
stability_delay: float = 0.1,
rescan_interval: float | None = None,
) -> None:
super().__init__()
self.consumption_dir = consumption_dir
@@ -694,8 +693,6 @@ class ConsumerThread(Thread):
self.polling_interval = polling_interval
self.stability_delay = stability_delay
self.cmd = Command()
if rescan_interval is not None:
self.cmd.rescan_interval_s = rescan_interval
self.cmd.stop_flag.clear()
# Non-daemon ensures finally block runs and connections are closed
self.daemon = False
@@ -1055,200 +1052,3 @@ class TestCommandWatchEdgeCases:
thread.stop_and_wait(timeout=5.0)
# Clean up any Tags created by the thread
Tag.objects.all().delete()
class TestRescanExistingFiles:
"""
Unit tests for the rescan safety net.
Each ``watch()`` recreation silently adopts the current directory contents
as its baseline, so a file appearing between one batch and the next
watcher's baseline is never reported and would sit in the consume directory
forever. ``_rescan_existing_files`` re-injects such files into the
stability tracker as a periodic safety net (see GH issue #13011).
"""
@pytest.fixture
def pdf_only_filter(self) -> ConsumerFilter:
return ConsumerFilter(
supported_extensions=frozenset({".pdf"}),
ignore_patterns=[],
)
def _rescan(
self,
directory: Path,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
*,
recursive: bool = False,
) -> None:
Command()._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
def test_tracks_stranded_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A supported on-disk file the watcher never reported gets tracked."""
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.is_tracking(target) is True
assert tracker.pending_count == 1
def test_skips_already_tracked_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already being tracked by the watcher is not double-tracked."""
target = consumption_dir / "tracked.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
tracker.track(target, Change.added)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 1
def test_skips_queued_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already queued and awaiting consumption is not re-tracked."""
target = consumption_dir / "inflight.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {target.resolve()}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert tracker.pending_count == 0
def test_prunes_vanished_queued_paths(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Queued paths no longer on disk are dropped so the name can recur."""
gone = (consumption_dir / "gone.pdf").resolve()
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {gone}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert gone not in queued
def test_skips_unsupported_extension(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Files filtered out by the consumer filter are not tracked."""
(consumption_dir / "notes.xyz").write_bytes(b"content")
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 0
def test_recursive_respects_flag(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Nested files are only found when recursive scanning is enabled."""
subdir = consumption_dir / "nested"
subdir.mkdir()
target = subdir / "deep.pdf"
shutil.copy(sample_pdf, target)
shallow = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
assert shallow.pending_count == 0
deep = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
assert deep.is_tracking(target) is True
class TestProcessExistingFilesQueued:
"""Tests that startup processing reports which paths it queued."""
@pytest.mark.usefixtures("mock_supported_extensions")
def test_returns_queued_paths(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
settings: SettingsWrapper,
) -> None:
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
target = consumption_dir / "document.pdf"
shutil.copy(sample_pdf, target)
settings.CONSUMER_IGNORE_PATTERNS = []
queued = Command()._process_existing_files(
directory=consumption_dir,
recursive=False,
subdirs_as_tags=False,
consumer_filter=ConsumerFilter(ignore_patterns=[]),
)
assert target.resolve() in queued
@pytest.mark.management
@pytest.mark.django_db
class TestCommandRescanRecovery:
"""End-to-end test that the rescan recovers files the watcher misses."""
def test_rescan_consumes_file_the_watcher_never_reports(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
start_consumer: Callable[..., ConsumerThread],
) -> None:
"""
Isolate the rescan path: a long polling interval guarantees the
watcher cannot report the file within the test window, so only the
periodic rescan can consume it.
"""
# poll interval far longer than the test window -> watcher stays silent
thread = start_consumer(
polling_interval=30.0,
stability_delay=0.1,
rescan_interval=0.5,
)
# created after startup, so _process_existing_files did not see it
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.apply_async.assert_called()
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
"input_doc"
]
assert call_args.original_file.name == "stranded.pdf"
+1 -29
View File
@@ -30,7 +30,6 @@ from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
from paperless.models import ApplicationConfiguration
from paperless_ai.exceptions import LLMTimeoutError
class TestViews(DirectoriesMixin, TestCase):
@@ -244,7 +243,7 @@ class TestViews(DirectoriesMixin, TestCase):
"change": {"users": [], "groups": []},
}
else:
assert False, f"Unexpected tag found: {tag['name']}"
raise AssertionError(f"Unexpected tag found: {tag['name']}")
def test_list_no_n_plus_1_queries(self) -> None:
"""
@@ -477,33 +476,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="openai-like",
)
def test_ai_suggestions_with_llm_timeout(
self,
mock_get_ai_classification,
) -> None:
mock_get_ai_classification.side_effect = LLMTimeoutError()
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
self.assertEqual(
response.json(),
{
"ai": ["AI backend request timed out."],
},
)
self.assertIsNone(
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
def test_invalidate_suggestions_cache(self) -> None:
self.client.force_login(user=self.user)
suggestions = {
+8 -1
View File
@@ -2760,7 +2760,14 @@ class TestWorkflows(
doc = Document.objects.create(
title="test",
)
self.assertRaises(Exception, document_matches_workflow, doc, w, 99)
self.assertRaisesRegex(
Exception,
"not yet supported",
document_matches_workflow,
doc,
w,
99,
)
def test_removal_action_document_updated_workflow(self) -> None:
"""
+3 -2
View File
@@ -129,11 +129,12 @@ def util_call_with_backoff(
status_codes.append(cause_exec.response.status_code)
warnings.warn(
f"HTTP Exception for {cause_exec.request.url} - {cause_exec}",
stacklevel=2,
)
else:
warnings.warn(f"Unexpected error: {e}")
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
except Exception as e: # pragma: no cover
warnings.warn(f"Unexpected error: {e}")
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
retry_count = retry_count + 1
+50 -154
View File
@@ -7,12 +7,11 @@ import tempfile
import zipfile
from collections import defaultdict
from collections import deque
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -61,7 +60,6 @@ from django.http import StreamingHttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.utils.timezone import make_aware
from django.utils.translation import get_language
from django.utils.translation import gettext_lazy as _
from django.views import View
@@ -241,7 +239,6 @@ from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.exceptions import LLMTimeoutError
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
@@ -287,7 +284,7 @@ def _get_more_like_id(query_params: dict[str, Any], user: User | None) -> int:
pk=more_like_doc_id,
)
except (TypeError, ValueError, Document.DoesNotExist):
raise PermissionDenied(_("Invalid more_like_id"))
raise PermissionDenied(_("Invalid more_like_id")) from None
if user and not has_perms_owner_aware(
user,
@@ -1103,7 +1100,7 @@ class DocumentViewSet(
"root_document",
).get(pk=pk)
except Document.DoesNotExist:
raise Http404
raise Http404 from None
root_doc = get_root_document(doc)
if request.user is not None and not has_perms_owner_aware(
@@ -1266,7 +1263,7 @@ class DocumentViewSet(
"root_document",
).get(id=pk)
except Document.DoesNotExist:
raise Http404
raise Http404 from None
root_doc = get_root_document(
request_doc,
@@ -1402,7 +1399,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"change_document",
"view_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1462,7 +1459,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"change_document",
"view_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1508,20 +1505,8 @@ class DocumentViewSet(
"document %s: %s",
doc.pk,
exc,
exc_info=True,
)
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
except LLMTimeoutError as exc:
logger.exception(
"AI backend timed out while generating suggestions for document %s: %s",
doc.pk,
exc,
exc_info=True,
)
return Response(
{"ai": [_("AI backend request timed out.")]},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
@@ -1593,7 +1578,7 @@ class DocumentViewSet(
disposition="inline",
)
except FileNotFoundError:
raise Http404
raise Http404 from None
@action(methods=["get"], detail=True, filter_backends=[])
@method_decorator(cache_control(no_cache=True))
@@ -1618,14 +1603,14 @@ class DocumentViewSet(
return FileResponse(handle, content_type="image/webp")
except FileNotFoundError:
raise Http404
raise Http404 from None
@action(methods=["get"], detail=True)
def download(self, request, pk=None):
try:
return self.file_response(pk, request, "attachment")
except (FileNotFoundError, Document.DoesNotExist):
raise Http404
raise Http404 from None
@action(
methods=["get", "post", "delete"],
@@ -1650,7 +1635,7 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions to view notes")
except Document.DoesNotExist:
raise Http404
raise Http404 from None
serializer = self.get_serializer(doc)
@@ -1721,7 +1706,7 @@ class DocumentViewSet(
try:
note_id_int = int(note_id)
except ValueError:
raise ValidationError({"id": "A valid integer is required."})
raise ValidationError({"id": "A valid integer is required."}) from None
note = get_object_or_404(Note, id=note_id_int, document=doc)
if settings.AUDIT_LOG_ENABLED:
LogEntry.objects.log_create(
@@ -1765,7 +1750,7 @@ class DocumentViewSet(
"Insufficient permissions to add share link",
)
except Document.DoesNotExist:
raise Http404
raise Http404 from None
if request.method == "GET":
now = timezone.now()
@@ -1793,7 +1778,7 @@ class DocumentViewSet(
"Insufficient permissions",
)
except Document.DoesNotExist: # pragma: no cover
raise Http404
raise Http404 from None
# documents
entries = [
@@ -1814,28 +1799,28 @@ class DocumentViewSet(
]
# custom fields
for entry in LogEntry.objects.get_for_objects(
doc.custom_fields.all(),
).select_related("actor"):
entries.append(
{
"id": entry.id,
"timestamp": entry.timestamp,
"action": entry.get_action_display(),
"changes": {
"custom_fields": {
"type": "custom_field",
"field": str(entry.object_repr).split(":")[0].strip(),
"value": str(entry.object_repr).split(":")[1].strip(),
},
entries.extend(
{
"id": entry.id,
"timestamp": entry.timestamp,
"action": entry.get_action_display(),
"changes": {
"custom_fields": {
"type": "custom_field",
"field": str(entry.object_repr).split(":")[0].strip(),
"value": str(entry.object_repr).split(":")[1].strip(),
},
"actor": (
{"id": entry.actor.id, "username": entry.actor.username}
if entry.actor
else None
),
},
)
"actor": (
{"id": entry.actor.id, "username": entry.actor.username}
if entry.actor
else None
),
}
for entry in LogEntry.objects.get_for_objects(
doc.custom_fields.all(),
).select_related("actor")
)
return Response(sorted(entries, key=lambda x: x["timestamp"], reverse=True))
@@ -1943,13 +1928,13 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions")
except Document.DoesNotExist:
raise Http404
raise Http404 from None
try:
doc_name, doc_data = serializer.validated_data.get("document")
version_label = serializer.validated_data.get("version_label")
t = int(mktime(datetime.now().timetuple()))
t = int(timezone.now().timestamp())
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
@@ -1994,7 +1979,7 @@ class DocumentViewSet(
"root_document",
).get(pk=pk)
except Document.DoesNotExist:
raise Http404
raise Http404 from None
return get_root_document(root_doc)
def _get_version_doc_for_root(self, root_doc: Document, version_id) -> Document:
@@ -2003,7 +1988,7 @@ class DocumentViewSet(
pk=version_id,
)
except Document.DoesNotExist:
raise Http404
raise Http404 from None
if (
version_doc.id != root_doc.id
@@ -2289,7 +2274,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2482,11 +2466,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -2564,7 +2543,7 @@ class LogViewSet(ViewSet):
try:
limit = int(limit_param)
except (TypeError, ValueError):
raise ValidationError({"limit": "Must be a positive integer"})
raise ValidationError({"limit": "Must be a positive integer"}) from None
if limit < 1:
raise ValidationError({"limit": "Must be a positive integer"})
else:
@@ -3145,7 +3124,6 @@ class PostDocumentView(GenericAPIView[Any]):
serializer.is_valid(raise_exception=True)
doc_name, doc_data = serializer.validated_data.get("document")
doc_name = normalize("NFC", doc_name)
correspondent_id = serializer.validated_data.get("correspondent")
document_type_id = serializer.validated_data.get("document_type")
storage_path_id = serializer.validated_data.get("storage_path")
@@ -3156,7 +3134,7 @@ class PostDocumentView(GenericAPIView[Any]):
cf = serializer.validated_data.get("custom_fields")
from_webui = serializer.validated_data.get("from_webui")
t = int(mktime(datetime.now().timetuple()))
t = int(timezone.now().timestamp())
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
@@ -4031,7 +4009,7 @@ class RemoteVersionView(GenericAPIView[Any]):
class _TasksViewSetSchema(AutoSchema):
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
def _get_paginator(self):
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
@@ -4053,7 +4031,7 @@ class _TasksViewSetSchema(AutoSchema):
),
acknowledge=extend_schema(
operation_id="acknowledge_tasks",
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
description="Acknowledge a list of tasks",
request=AcknowledgeTasksViewSerializer,
responses={
(200, "application/json"): inline_serializer(
@@ -4091,19 +4069,6 @@ class _TasksViewSetSchema(AutoSchema):
),
],
),
status_counts=extend_schema(
responses={
200: inline_serializer(
name="TaskStatusCounts",
fields={
"all": serializers.IntegerField(),
"needs_attention": serializers.IntegerField(),
"in_progress": serializers.IntegerField(),
"completed": serializers.IntegerField(),
},
),
},
),
active=extend_schema(
description="Currently pending and running tasks (capped at 50).",
responses={200: TaskSerializerV10(many=True)},
@@ -4157,7 +4122,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
}
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
def get_serializer_class(self):
# v9: use backwards-compatible serializer with old field names
@@ -4198,38 +4162,16 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
queryset = queryset.filter(task_id=task_id)
return queryset
def get_status_count_queryset(self):
"""Apply task filters except the status dimensions represented by the counts."""
query_params = self.request.query_params.copy()
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
query_params.pop(param, None)
filterset = self.filterset_class(
data=query_params,
queryset=self.get_queryset(),
request=self.request,
)
if not filterset.is_valid():
raise ValidationError(filterset.errors)
return filterset.qs
@action(
methods=["post"],
detail=False,
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
)
def acknowledge(self, request):
queryset = self.get_queryset()
serializer = AcknowledgeTasksViewSerializer(
data=request.data,
context={"queryset": queryset},
)
serializer = AcknowledgeTasksViewSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
if serializer.validated_data.get("all", False):
tasks = queryset.filter(acknowledged=False)
else:
task_ids = serializer.validated_data.get("tasks")
tasks = queryset.filter(id__in=task_ids)
task_ids = serializer.validated_data.get("tasks")
tasks = self.get_queryset().filter(id__in=task_ids)
count = tasks.update(acknowledged=True)
return Response({"result": count})
@@ -4282,34 +4224,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
serializer = TaskSummarySerializer(data, many=True)
return Response(serializer.data)
@action(methods=["get"], detail=False)
def status_counts(self, request):
"""Aggregated task counts for task UI sections."""
queryset = self.get_status_count_queryset()
counts = queryset.aggregate(
all=Count("id"),
needs_attention=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.FAILURE,
PaperlessTask.Status.REVOKED,
],
),
),
in_progress=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.PENDING,
PaperlessTask.Status.STARTED,
],
),
),
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
)
return Response(counts)
@action(methods=["get"], detail=False)
def active(self, request):
"""Currently pending and running tasks (capped at 50)."""
@@ -5009,29 +4923,11 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
except Exception as e:
celery_active = "ERROR"
logger.exception(
@@ -5050,7 +4946,7 @@ class SystemStatusView(PassUserMixin):
index_dir = settings.INDEX_DIR
mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()]
index_last_modified = (
make_aware(datetime.fromtimestamp(max(mtimes))) if mtimes else None
datetime.fromtimestamp(max(mtimes), tz=UTC) if mtimes else None
)
except Exception as e:
index_status = "ERROR"
+14 -13
View File
@@ -84,10 +84,11 @@ def binaries_check(app_configs: Any, **kwargs: Any) -> list[Error]:
binaries = (settings.CONVERT_BINARY, "tesseract", "gs")
check_messages = []
for binary in binaries:
if shutil.which(binary) is None:
check_messages.append(Warning(error.format(binary), hint))
check_messages = [
Warning(error.format(binary), hint)
for binary in binaries
if shutil.which(binary) is None
]
return check_messages
@@ -383,14 +384,14 @@ def check_default_language_available(app_configs: Any, **kwargs: Any) -> list[Er
specified_langs = [x.strip() for x in settings.OCR_LANGUAGE.split("+")]
for lang in specified_langs:
if lang not in installed_langs:
errs.append(
Error(
f"The selected ocr language {lang} is "
f"not installed. Paperless cannot OCR your documents "
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
),
)
errs.extend(
Error(
f"The selected ocr language {lang} is "
f"not installed. Paperless cannot OCR your documents "
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
)
for lang in specified_langs
if lang not in installed_langs
)
return errs
-4
View File
@@ -197,7 +197,6 @@ class AIConfig(BaseConfig):
llm_embedding_endpoint: str = dataclasses.field(init=False)
llm_embedding_chunk_size: int = dataclasses.field(init=False)
llm_context_size: int = dataclasses.field(init=False)
llm_request_timeout: int = dataclasses.field(init=False)
llm_backend: str = dataclasses.field(init=False)
llm_model: str = dataclasses.field(init=False)
llm_api_key: str = dataclasses.field(init=False)
@@ -222,9 +221,6 @@ class AIConfig(BaseConfig):
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
)
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
self.llm_request_timeout = (
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
)
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
@@ -1,365 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:30
import django.core.validators
from django.db import migrations
from django.db import models
def _create_singleton(apps, schema_editor):
settings_model = apps.get_model("paperless", "ApplicationConfiguration")
settings_model.objects.create()
class Migration(migrations.Migration):
replaces = [
("paperless", "0001_initial"),
("paperless", "0002_applicationconfiguration_app_logo_and_more"),
("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"),
("paperless", "0005_applicationconfiguration_ai_enabled_and_more"),
("paperless", "0006_applicationconfiguration_barcode_tag_split"),
]
dependencies = []
operations = [
migrations.CreateModel(
name="ApplicationConfiguration",
fields=[
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
(
"output_type",
models.CharField(
blank=True,
choices=[
("pdf", "pdf"),
("pdfa", "pdfa"),
("pdfa-1", "pdfa-1"),
("pdfa-2", "pdfa-2"),
("pdfa-3", "pdfa-3"),
],
max_length=8,
null=True,
verbose_name="Sets the output PDF type",
),
),
(
"pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Do OCR from page 1 to this value",
),
),
(
"language",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Do OCR using these languages",
),
),
(
"mode",
models.CharField(
blank=True,
choices=[
("skip", "skip"),
("redo", "redo"),
("force", "force"),
("skip_noarchive", "skip_noarchive"),
],
max_length=16,
null=True,
verbose_name="Sets the OCR mode",
),
),
(
"skip_archive_file",
models.CharField(
blank=True,
choices=[
("never", "never"),
("with_text", "with_text"),
("always", "always"),
],
max_length=16,
null=True,
verbose_name="Controls the generation of an archive file",
),
),
(
"image_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets image DPI fallback value",
),
),
(
"unpaper_clean",
models.CharField(
blank=True,
choices=[
("clean", "clean"),
("clean-final", "clean-final"),
("none", "none"),
],
max_length=16,
null=True,
verbose_name="Controls the unpaper cleaning",
),
),
(
"deskew",
models.BooleanField(null=True, verbose_name="Enables deskew"),
),
(
"rotate_pages",
models.BooleanField(
null=True,
verbose_name="Enables page rotation",
),
),
(
"rotate_pages_threshold",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the threshold for rotation of pages",
),
),
(
"max_image_pixels",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(0.0)],
verbose_name="Sets the maximum image size for decompression",
),
),
(
"color_conversion_strategy",
models.CharField(
blank=True,
choices=[
("LeaveColorUnchanged", "LeaveColorUnchanged"),
("RGB", "RGB"),
("UseDeviceIndependentColor", "UseDeviceIndependentColor"),
("Gray", "Gray"),
("CMYK", "CMYK"),
],
max_length=32,
null=True,
verbose_name="Sets the Ghostscript color conversion strategy",
),
),
(
"user_args",
models.JSONField(
null=True,
verbose_name="Adds additional user arguments for OCRMyPDF",
),
),
(
"app_logo",
models.FileField(
blank=True,
null=True,
upload_to="logo/",
validators=[
django.core.validators.FileExtensionValidator(
allowed_extensions=["jpg", "png", "gif", "svg"],
),
],
verbose_name="Application logo",
),
),
(
"app_title",
models.CharField(
blank=True,
max_length=48,
null=True,
verbose_name="Application title",
),
),
(
"barcode_asn_prefix",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the ASN barcode prefix",
),
),
(
"barcode_dpi",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the barcode DPI",
),
),
(
"barcode_enable_asn",
models.BooleanField(
null=True,
verbose_name="Enables ASN barcode",
),
),
(
"barcode_enable_tag",
models.BooleanField(
null=True,
verbose_name="Enables tag barcode",
),
),
(
"barcode_enable_tiff_support",
models.BooleanField(
null=True,
verbose_name="Enables barcode TIFF support",
),
),
(
"barcode_max_pages",
models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the maximum pages for barcode",
),
),
(
"barcode_retain_split_pages",
models.BooleanField(
null=True,
verbose_name="Retains split pages",
),
),
(
"barcode_string",
models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the barcode string",
),
),
(
"barcode_tag_mapping",
models.JSONField(
null=True,
verbose_name="Sets the tag barcode mapping",
),
),
(
"barcode_upscale",
models.FloatField(
null=True,
validators=[django.core.validators.MinValueValidator(1.0)],
verbose_name="Sets the barcode upscale factor",
),
),
(
"barcodes_enabled",
models.BooleanField(
null=True,
verbose_name="Enables barcode scanning",
),
),
(
"ai_enabled",
models.BooleanField(
default=False,
null=True,
verbose_name="Enables AI features",
),
),
(
"llm_api_key",
models.CharField(
blank=True,
max_length=1024,
null=True,
verbose_name="Sets the LLM API key",
),
),
(
"llm_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM backend",
),
),
(
"llm_embedding_backend",
models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
(
"llm_embedding_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM embedding model",
),
),
(
"llm_endpoint",
models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM endpoint, optional",
),
),
(
"llm_model",
models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM model",
),
),
(
"barcode_tag_split",
models.BooleanField(
null=True,
verbose_name="Enables splitting on tag barcodes",
),
),
],
options={
"verbose_name": "paperless application settings",
},
),
migrations.RunPython(
code=_create_singleton,
reverse_code=migrations.RunPython.noop,
),
]
@@ -1,94 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-04 15:19
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
replaces = [
("paperless", "0009_alter_applicationconfiguration_options"),
("paperless", "0010_alter_applicationconfiguration_llm_embedding_backend"),
("paperless", "0011_applicationconfiguration_llm_embedding_chunk_size"),
("paperless", "0012_applicationconfiguration_llm_output_language"),
("paperless", "0013_applicationconfiguration_llm_request_timeout"),
]
dependencies = [
("paperless", "0008_replace_skip_archive_file"),
]
operations = [
migrations.AlterModelOptions(
name="applicationconfiguration",
options={
"permissions": [
("view_global_statistics", "Can view global object counts"),
("view_system_monitoring", "Can view system status information"),
],
"verbose_name": "paperless application settings",
},
),
migrations.AlterField(
model_name="applicationconfiguration",
name="llm_embedding_backend",
field=models.CharField(
blank=True,
choices=[
("openai-like", "OpenAI-compatible"),
("huggingface", "Huggingface"),
("ollama", "Ollama"),
],
max_length=128,
null=True,
verbose_name="Sets the LLM embedding backend",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_endpoint",
field=models.CharField(
blank=True,
max_length=256,
null=True,
verbose_name="Sets the LLM embedding endpoint, optional",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_embedding_chunk_size",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM embedding chunk size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_context_size",
field=models.PositiveIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM context size",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_output_language",
field=models.CharField(
blank=True,
max_length=32,
null=True,
verbose_name="Sets the LLM output language",
),
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
@@ -1,23 +0,0 @@
# Generated by Django 5.2.14 on 2026-06-14 14:22
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("paperless", "0012_applicationconfiguration_llm_output_language"),
]
operations = [
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
-6
View File
@@ -366,12 +366,6 @@ class ApplicationConfiguration(AbstractSingletonModel):
max_length=32,
)
llm_request_timeout = models.PositiveSmallIntegerField(
verbose_name=_("Sets the LLM timeout in seconds"),
null=True,
validators=[MinValueValidator(1)],
)
class Meta:
verbose_name = _("paperless application settings")
permissions = [
+4 -5
View File
@@ -649,11 +649,10 @@ class MailDocumentParser:
if data["bcc"]:
data["bcc_label"] = "BCC"
att = []
for a in mail.attachments:
att.append(
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})",
)
att = [
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})"
for a in mail.attachments
]
data["attachments"] = clean_html(", ".join(att))
if data["attachments"]:
data["attachments_label"] = "Attachments"
+28 -2
View File
@@ -20,7 +20,6 @@ from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from paperless.parsers.utils import read_file_handle_unicode_errors
from paperless.version import __full_version_str__
if TYPE_CHECKING:
@@ -184,7 +183,7 @@ class TextDocumentParser:
documents.parsers.ParseError
If the file cannot be read.
"""
self._text = read_file_handle_unicode_errors(document_path, log=logger)
self._text = self._read_text(document_path)
# ------------------------------------------------------------------
# Result accessors
@@ -296,3 +295,30 @@ class TextDocumentParser:
Always ``[]`` plain text files carry no structured metadata.
"""
return []
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _read_text(self, filepath: Path) -> str:
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
Parameters
----------
filepath:
Path to the file to read.
Returns
-------
str
File content as a string.
"""
try:
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
logger.warning(
"Unicode error reading %s, replacing bad bytes: %s",
filepath,
exc,
)
return filepath.read_bytes().decode("utf-8", errors="replace")
+5 -18
View File
@@ -8,7 +8,6 @@ share implementation.
from __future__ import annotations
import codecs
import logging
import re
import tempfile
@@ -115,7 +114,7 @@ def read_file_handle_unicode_errors(
filepath: Path,
log: logging.Logger | None = None,
) -> str:
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
Parameters
----------
@@ -128,27 +127,15 @@ def read_file_handle_unicode_errors(
Returns
-------
str
File content as a string, with NUL bytes removed so the result is
safe to store in PostgreSQL text fields.
File content as a string, with any invalid UTF-8 sequences replaced
by the Unicode replacement character.
"""
_log = log or logger
raw = filepath.read_bytes()
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
encoding = "utf-16"
elif raw.startswith(codecs.BOM_UTF8):
encoding = "utf-8-sig"
else:
encoding = "utf-8"
try:
text = raw.decode(encoding)
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as e:
_log.warning("Unicode error during text reading, continuing: %s", e)
text = raw.decode("utf-8", errors="replace")
# PostgreSQL rejects NUL (0x00) bytes in text fields
return text.replace("\x00", "")
return filepath.read_bytes().decode("utf-8", errors="replace")
def get_page_count_for_pdf(
+2 -12
View File
@@ -97,14 +97,8 @@ MODEL_FILE = get_path_from_env(
DATA_DIR / "classification_model.pickle",
)
LLM_INDEX_DIR = DATA_DIR / "llm_index"
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
# Cross-process read/write lock guarding the LLM index compaction/migration
# file swap. Readers hold it shared; the swap takes it exclusively so it never
# runs while a reader connection is open. Must be a SQLite (.db) file.
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
# Seconds the compaction swap waits for active readers to drain before skipping
# this cycle (it is a maintenance operation; the next run retries).
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
@@ -650,7 +644,6 @@ LOGGING = {
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
"httpx": {"level": "WARNING"},
},
}
@@ -1206,9 +1199,6 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
if LLM_CONTEXT_SIZE < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
if LLM_REQUEST_TIMEOUT < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
LLM_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_BACKEND",
{"ollama", "openai-like"},
+1 -4
View File
@@ -252,9 +252,6 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
# Validate pooled connections so a connection closed server-side
# is replaced rather than handed out as "the connection is closed".
"CONN_HEALTH_CHECKS": True,
}
base_options = {
@@ -334,7 +331,7 @@ def parse_dateparser_languages(languages: str | None) -> list[str]:
language_list = languages.split("+") if languages else []
# There is an unfixed issue in zh-Hant and zh-Hans locales in the dateparser lib.
# See: https://github.com/scrapinghub/dateparser/issues/875
for index, language in enumerate(language_list):
for _, language in enumerate(language_list):
if language.startswith("zh-") and "zh" not in language_list:
logger.warning(
f"Chinese locale detected: {language}. dateparser might fail to parse"
@@ -398,7 +398,6 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "localhost",
"NAME": "paperless",
"USER": "paperless",
@@ -427,7 +426,6 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "paperless-db-host",
"PORT": 1111,
"NAME": "customdb",
@@ -457,7 +455,6 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
@@ -488,7 +485,6 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
-37
View File
@@ -2,50 +2,13 @@
from __future__ import annotations
import codecs
from pathlib import Path
from paperless.parsers.utils import is_tagged_pdf
from paperless.parsers.utils import read_file_handle_unicode_errors
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
class TestReadFileHandleUnicodeErrors:
def test_plain_utf8(self, tmp_path: Path) -> None:
f = tmp_path / "plain.txt"
f.write_bytes(b"hello world")
assert read_file_handle_unicode_errors(f) == "hello world"
def test_utf8_bom(self, tmp_path: Path) -> None:
f = tmp_path / "bom.txt"
f.write_bytes(codecs.BOM_UTF8 + b"hello")
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_le(self, tmp_path: Path) -> None:
f = tmp_path / "utf16le.txt"
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_be(self, tmp_path: Path) -> None:
f = tmp_path / "utf16be.txt"
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
f = tmp_path / "null-bytes.txt"
f.write_bytes(b"foo\x00bar")
assert read_file_handle_unicode_errors(f) == "foobar"
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
f = tmp_path / "bad.txt"
f.write_bytes(b"ok\x80\x81bad")
result = read_file_handle_unicode_errors(f)
assert "ok" in result
assert "bad" in result
assert "\x00" not in result
class TestIsTaggedPdf:
def test_tagged_pdf_returns_true(self) -> None:
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
+1 -1
View File
@@ -193,7 +193,7 @@ def reject_dangerous_svg(file: UploadedFile) -> None:
tree = etree.parse(file, parser)
root = tree.getroot()
except etree.XMLSyntaxError:
raise ValidationError("Invalid SVG file.")
raise ValidationError("Invalid SVG file.") from None
for element in root.iter():
tag: str = etree.QName(element.tag).localname.lower()
+2 -2
View File
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer
from paperless_ai.indexing import llm_index_exists
from paperless_ai.indexing import vector_store_file_exists
class PaperlessObtainAuthTokenView(ObtainAuthToken):
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
or old_llm_context_size != new_llm_context_size
)
rebuild_needed = new_ai_index_enabled and (
not llm_index_exists() or embedding_config_changed
not vector_store_file_exists() or embedding_config_changed
)
if rebuild_needed:
+21 -35
View File
@@ -8,7 +8,6 @@ from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
@@ -25,14 +24,9 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
) -> str:
filename = document.filename or ""
content = truncate_content(
document.content[:4000] or "",
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
content = truncate_content(document.content[:4000] or "")
return f"""
You are a document classification assistant.
@@ -55,15 +49,10 @@ def build_prompt_without_rag(
def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
) -> str:
base_prompt = build_prompt_without_rag(document, config)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
return f"""{base_prompt}
@@ -141,29 +130,26 @@ def get_ai_document_classification(
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, ai_config, user)
build_prompt_with_rag(document, user)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document, ai_config)
else build_prompt_without_rag(document)
)
client = AIClient()
# Hand the pooled DB connection back while the (slow) LLM query runs so it
# is not pinned for the call's duration; see paperless_ai.db and #12976.
with db_connection_released():
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
return suggestions
+122 -56
View File
@@ -3,13 +3,9 @@ import logging
import sys
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
from paperless_ai.indexing import read_store
logger = logging.getLogger("paperless_ai.chat")
@@ -79,6 +75,82 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
)
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredFaissRetriever(BaseRetriever):
def __init__(self):
super().__init__()
self._cached_query_str = None
self._cached_nodes = []
def _retrieve(self, query_bundle):
if query_bundle.query_str == self._cached_query_str:
return self._cached_nodes
if query_bundle.embedding is None:
query_bundle.embedding = (
index._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs,
)
)
faiss_index = index.vector_store._faiss_index
max_top_k = faiss_index.ntotal
if max_top_k == 0:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
query_top_k = min(max(similarity_top_k, 1), max_top_k)
allowed_nodes: list[NodeWithScore] = []
seen_node_ids: set[str] = set()
while query_top_k <= max_top_k:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
for vector_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
node_id = index.index_struct.nodes_dict.get(vector_id)
if node_id is None or node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
if len(allowed_nodes) >= similarity_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
if query_top_k == max_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
query_top_k = min(query_top_k * 2, max_top_k)
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
return DocumentFilteredFaissRetriever()
def stream_chat_with_documents(query_str: str, documents: list[Document]):
try:
yield from _stream_chat_with_documents(query_str, documents)
@@ -88,69 +160,63 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
if not documents:
client = AIClient()
index = load_or_build_index()
doc_ids = [str(doc.pk) for doc in documents]
# Filter only the node(s) that match the document IDs
nodes = [
node
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in doc_ids
]
if len(nodes) == 0:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
config = AIConfig()
filters = _document_id_filters(str(doc.pk) for doc in documents)
retriever = _get_document_filtered_retriever(
index,
set(doc_ids),
CHAT_RETRIEVER_TOP_K,
)
# Hold the shared read lock for the whole operation: the query engine
# retrieves from the vector store again during synthesis, so the connection
# must stay open (and the swap must not run) until the stream finishes.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
# Slow query-embedding + vector search; no Django ORM access happens
# during it, so release the pooled DB connection for its duration. See
# #12976.
with db_connection_released():
top_nodes = retriever.retrieve(query_str)
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
references = _get_document_references(documents, top_nodes)
client = AIClient()
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(),
text_qa_template=prompt_template,
streaming=True,
)
references = _get_document_references(documents, top_nodes)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
logger.debug("Document chat query: %s", query_str)
logger.debug("Document chat query: %s", query_str)
# Release the pooled DB connection for the slow streaming LLM response
# so it is not pinned for the whole stream; see paperless_ai.db and
# #12976.
with db_connection_released():
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
response_stream = query_engine.query(query_str)
if references:
yield _format_chat_metadata_trailer(references)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
if references:
yield _format_chat_metadata_trailer(references)
+28 -49
View File
@@ -1,14 +1,11 @@
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
import httpx
from paperless.models import LLMBackend
if TYPE_CHECKING:
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai_like import OpenAILike
@@ -19,7 +16,6 @@ from paperless.network import create_pinned_async_httpx_client
from paperless.network import create_pinned_httpx_client
from paperless.network import validate_outbound_http_url
from paperless_ai.base_model import DocumentClassifierSchema
from paperless_ai.exceptions import LLMTimeoutError
logger = logging.getLogger("paperless_ai.client")
@@ -65,16 +61,16 @@ class AIClient:
model=self.settings.llm_model or "llama3.1",
base_url=endpoint,
context_window=self.settings.llm_context_size,
request_timeout=self.settings.llm_request_timeout,
request_timeout=120,
system_prompt=LLM_SYSTEM_PROMPT,
client=Client(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=transport,
),
async_client=AsyncClient(
host=endpoint,
timeout=self.settings.llm_request_timeout,
timeout=120,
transport=async_transport,
),
)
@@ -88,18 +84,15 @@ class AIClient:
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
return OpenAILike(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_base=endpoint,
api_key=self.settings.llm_api_key,
timeout=self.settings.llm_request_timeout,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -120,12 +113,11 @@ class AIClient:
user_msg = ChatMessage(role="user", content=prompt)
if self.settings.llm_backend == LLMBackend.OLLAMA:
with self._normalize_timeouts():
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
logger.debug("LLM query result: %s", result)
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
return parsed.model_dump()
@@ -133,39 +125,26 @@ class AIClient:
from llama_index.core.program.function_program import get_function_tool
tool = get_function_tool(DocumentClassifierSchema)
with self._normalize_timeouts():
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
tool_required=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
@contextmanager
def _normalize_timeouts(self) -> Iterator[None]:
try:
yield
except httpx.TimeoutException as exc:
raise LLMTimeoutError from exc
except Exception as exc:
if self._is_openai_timeout(exc):
raise LLMTimeoutError from exc
raise
def _is_openai_timeout(self, exc: Exception) -> bool:
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
return False
# Keep OpenAI imports out of module import paths and only load the SDK
# when translating an error from an OpenAI-backed request.
from openai import APITimeoutError
return isinstance(exc, APITimeoutError)
def run_chat(self, messages: list["ChatMessage"]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.chat(messages)
logger.debug("Chat result: %s", result)
return result
-30
View File
@@ -1,30 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from django.db import connections
@contextmanager
def db_connection_released():
"""
Return any checked-out DB connections to the pool for the duration of the
wrapped block.
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
or a streaming response (``chat``). Django keeps the request's database
connection checked out for the entire request/response, so a blocking LLM
call - which can take many seconds - pins a pooled connection the whole
time. With connection pooling enabled, enough concurrent AI requests check
out every slot and all other requests then fail with
``psycopg_pool.PoolTimeout`` (see issue #12976).
No Django ORM access happens during the LLM call, so we hand the connection
back to the pool first; Django transparently re-checks-out a connection on
the next ORM use after the block.
"""
connections.close_all()
try:
yield
finally:
connections.close_all()
+54 -29
View File
@@ -1,9 +1,12 @@
import json
import re
from typing import TYPE_CHECKING
from django.conf import settings
if TYPE_CHECKING:
from pathlib import Path
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
@@ -20,7 +23,9 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
def get_embedding_model() -> "BaseEmbedding":
config = AIConfig()
match config.llm_embedding_backend:
case LLMEmbeddingBackend.OPENAI_LIKE:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
@@ -32,18 +37,15 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
return OpenAILikeEmbedding(
model_name=config.llm_embedding_model or "text-embedding-3-small",
api_key=config.llm_api_key,
api_base=endpoint,
timeout=config.llm_request_timeout,
http_client=http_client,
async_http_client=async_http_client,
)
@@ -76,14 +78,12 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
)
embedding._client = Client(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
)
embedding._async_client = AsyncClient(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostAsyncHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
@@ -95,24 +95,41 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
)
_DEFAULT_MODEL_NAMES = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}
def get_configured_model_name(config: AIConfig) -> str:
"""Return the canonical name of the currently configured embedding model."""
# dict.get(key, default) overload resolution fails for TextChoices keys in some
# type checkers; use `or` fallback to avoid the ambiguity.
default = (
_DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
)
or "sentence-transformers/all-MiniLM-L6-v2"
def get_embedding_dim() -> int:
"""
Loads embedding dimension from meta.json if available, otherwise infers it
from a dummy embedding and stores it for future use.
"""
config = AIConfig()
default_model = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}.get(
config.llm_embedding_backend,
"sentence-transformers/all-MiniLM-L6-v2",
)
return config.llm_embedding_model or default
model = config.llm_embedding_model or default_model
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
if meta_path.exists():
with meta_path.open() as f:
meta = json.load(f)
if meta.get("embedding_model") != model:
raise RuntimeError(
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
"You must rebuild the index.",
)
return meta["dim"]
embedding_model = get_embedding_model()
test_embed = embedding_model.get_text_embedding("test")
dim = len(test_embed)
with meta_path.open("w") as f:
json.dump({"embedding_model": model, "dim": dim}, f)
return dim
def _normalize_llm_index_text(text: str) -> str:
@@ -121,16 +138,24 @@ def _normalize_llm_index_text(text: str) -> str:
def build_llm_index_text(doc: Document) -> str:
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
# text, Custom Fields are dynamic in count and best kept in the embedding.
lines = [
f"Title: {doc.title}",
f"Filename: {doc.filename}",
f"Created: {doc.created}",
f"Added: {doc.added}",
f"Modified: {doc.modified}",
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
f"Archive Serial Number: {doc.archive_serial_number or ''}",
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
]
for instance in doc.custom_fields.all():
lines.append(f"Custom Field - {instance.field.name}: {instance}")
lines.extend(
f"Custom Field - {instance.field.name}: {instance}"
for instance in doc.custom_fields.all()
)
lines.append("\nContent:\n")
lines.append(doc.content or "")
-2
View File
@@ -1,2 +0,0 @@
class LLMTimeoutError(Exception):
pass
+243 -269
View File
@@ -1,30 +1,28 @@
import logging
import shutil
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from django.conf import settings
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
from documents.utils import IterWrapper
from documents.utils import identity
from paperless.config import AIConfig
from paperless_ai.db import db_connection_released
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_configured_model_name
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import BaseNode
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
logger = logging.getLogger("paperless_ai.indexing")
@@ -32,11 +30,21 @@ RAG_NUM_OUTPUT = 512
RAG_CHUNK_OVERLAP = 200
def _index_lock_path() -> Path:
"""Return the path used as the file lock for FAISS index mutations.
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
rebuild which calls shutil.rmtree(LLM_INDEX_DIR) cannot delete the lock
while another worker still holds it.
"""
return settings.LLM_INDEX_LOCK
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
# concurrent workers can both observe no running task and both enqueue a
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
# is itself protected by _index_lock_path(), so only one rebuild runs at a
# time and the second one is serialised after the first completes.
from documents.tasks import llmindex_index
@@ -63,110 +71,46 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
return True
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
def get_or_create_storage_context(*, rebuild=False):
"""
Loads or creates the StorageContext (vector store, docstore, index store).
If rebuild=True, deletes and recreates everything.
"""
if rebuild:
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
if rebuild or not settings.LLM_INDEX_DIR.exists():
import faiss
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
embedding_dim = get_embedding_dim()
faiss_index = faiss.IndexFlatL2(embedding_dim)
vector_store = FaissVectorStore(faiss_index=faiss_index)
docstore = SimpleDocumentStore()
index_store = SimpleIndexStore()
else:
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
return StorageContext.from_defaults(
docstore=docstore,
index_store=index_store,
vector_store=vector_store,
persist_dir=settings.LLM_INDEX_DIR,
)
# --- LLM index locking ---------------------------------------------------
#
# Two locks guard the index; they answer different questions and are NOT
# interchangeable:
#
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
# Taken by write_store(). Readers never take it, so it never blocks reads.
#
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
# compaction/migration file swap. read_store() takes it SHARED (readers run
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
# the database file is never replaced while a reader connection is open (that
# would alias the old WAL onto the new file and corrupt it).
#
# | vs another writer | vs a reader
# -----------------+-------------------+----------------------------
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
#
# They can't be merged into one ReadWriteLock: a normal write must exclude other
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
# the swap needs to exclude readers.
def _index_rwlock() -> ReadWriteLock:
"""Return a fresh read/write lock instance for the index swap.
``is_singleton=False`` so reads and the swap always coordinate through
SQLite (the actual cross-process case) rather than hitting the in-process
reentrant-upgrade guard; callers must ``close()`` it (the context managers
below do).
"""
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
@contextmanager
def read_store():
"""Acquire the shared read lock and yield the vector store for a read.
The shared lock is held for the whole lifetime of the connection (and
closed on exit) so the compaction/migration swap, which takes the exclusive
lock, never runs while this connection is open. Concurrent readers do not
block each other; only the swap does.
"""
lock = _index_rwlock()
try:
with lock.read_lock(), get_vector_store() as store:
yield store
finally:
lock.close()
@contextmanager
def _exclude_readers():
"""Acquire exclusive index access, blocking until readers have drained.
The exclusive counterpart to ``read_store()``: a compaction or migration
must not run while any reader connection is open. Raises
:class:`filelock.Timeout` if active readers do not drain within
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
"""
lock = _index_rwlock()
try:
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
yield
finally:
lock.close()
@contextmanager
def write_store(embed_model_name: str | None = None):
"""Acquire the write lock and yield the vector store.
All mutating operations (upsert, delete, rebuild, compact) must go through
this context manager to serialise concurrent Celery writers.
Read paths use ``read_store()`` so they hold the shared read lock.
Pass ``embed_model_name`` whenever the operation may create the table so
the model name is recorded in the schema metadata for future mismatch checks.
"""
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
with (
FileLock(settings.LLM_INDEX_LOCK),
PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
embed_model_name=embed_model_name,
) as store,
):
yield store
def build_document_node(
document: Document,
*,
@@ -186,9 +130,6 @@ def build_document_node(
"document_type": document.document_type.name
if document.document_type
else None,
"filename": document.filename,
"storage_path": document.storage_path.name if document.storage_path else None,
"archive_serial_number": document.archive_serial_number,
"created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
@@ -201,11 +142,9 @@ def build_document_node(
# the token count and exceed embedding models with small context windows
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
doc = LlamaDocument(
id_=str(document.id),
text=text,
metadata=metadata,
excluded_embed_metadata_keys=list(metadata.keys()),
excluded_llm_metadata_keys=["document_id"],
)
chunk_size = chunk_size or get_rag_chunk_size()
parser = SimpleNodeParser(
@@ -215,33 +154,76 @@ def build_document_node(
return parser.get_nodes_from_documents([doc])
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
"""Return a VectorStoreIndex backed by ``store``.
``store`` is supplied by the caller's ``read_store()`` context so the shared
read lock and the connection stay alive for the whole retrieval.
def load_or_build_index(nodes=None):
"""
Load an existing VectorStoreIndex if present,
or build a new one using provided nodes if storage is empty.
"""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
embed_model = get_embedding_model(config)
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
return VectorStoreIndex.from_vector_store(
vector_store=store,
embed_model=embed_model,
)
storage_context = get_or_create_storage_context()
try:
return load_index_from_storage(storage_context=storage_context)
except ValueError as e:
logger.warning("Failed to load index from storage: %s", e)
if not nodes:
queue_llm_index_update_if_needed(
rebuild=vector_store_file_exists(),
reason="LLM index missing or invalid while loading.",
)
logger.info("No nodes provided for index creation.")
raise
return VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
def llm_index_exists() -> bool:
"""True when the index table exists on disk."""
with read_store() as store:
return store.table_exists()
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
"""
Removes existing documents from docstore for a given document from the index.
This is necessary because FAISS IndexFlatL2 is append-only.
"""
all_node_ids = list(index.docstore.docs.keys())
existing_nodes = [
node.node_id
for node in index.docstore.get_nodes(all_node_ids)
if node.metadata.get("document_id") == str(document.id)
]
for node_id in existing_nodes:
# Delete from docstore, FAISS IndexFlatL2 are append-only
index.docstore.delete_document(node_id)
# Also purge the FAISS position -> UUID mapping so subsequent similarity
# queries don't raise KeyError on ghost vector positions.
stale_keys = [
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
]
for key in stale_keys:
del index.index_struct.nodes_dict[key]
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
index.storage_context.index_store.add_index_struct(index.index_struct)
def vector_store_file_exists():
"""
Check if the vector store file exists in the LLM index directory.
"""
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
def get_rag_chunk_size() -> int:
return AIConfig().llm_embedding_chunk_size
def get_rag_context_size() -> int:
return AIConfig().llm_context_size
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
chunk_size = chunk_size or get_rag_chunk_size()
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
@@ -267,149 +249,123 @@ def get_rag_prompt_helper(
)
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
"""Embed ``nodes`` in place using ``embed_model``."""
from llama_index.core.schema import MetadataMode
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
def _document_id_filters(doc_ids):
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(doc_ids),
),
],
)
def update_llm_index(
*,
iter_wrapper: IterWrapper[Document] = identity,
rebuild=False,
) -> str:
"""Rebuild or incrementally update the LLM index."""
with write_store() as store:
try:
with _exclude_readers():
needs_reembed = store.check_and_run_migrations()
except Timeout:
logger.info(
"Skipping LLM index migration check: index readers are active; "
"will retry next run.",
)
needs_reembed = False
if needs_reembed:
logger.warning(
"LLM index migration requires re-embedding; forcing rebuild.",
)
rebuild = True
documents = Document.objects.all()
no_documents = not documents.exists()
"""
Rebuild or update the LLM index.
"""
from llama_index.core import VectorStoreIndex
# Fast exit before touching config: nothing to index and no existing index.
if no_documents and not rebuild and not llm_index_exists():
nodes = []
documents = Document.objects.all()
if not documents.exists():
logger.warning("No documents found to index.")
return "No documents found to index."
if not rebuild and not vector_store_file_exists():
return "No documents found to index."
config = AIConfig()
model_name = get_configured_model_name(config)
if not rebuild and llm_index_exists():
with read_store() as store:
config_mismatch = store.config_mismatch(model_name)
if config_mismatch:
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
if no_documents:
logger.warning("No documents found to index.")
chunk_size = config.llm_embedding_chunk_size
embed_model = get_embedding_model(config)
with write_store(embed_model_name=model_name) as store:
if rebuild or not store.table_exists():
with FileLock(_index_lock_path()):
if rebuild or not vector_store_file_exists():
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
logger.info("Rebuilding LLM index.")
store.drop_table()
import llama_index.core.settings as llama_settings
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=True)
for document in iter_wrapper(documents):
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.add(nodes)
document_nodes = build_document_node(document, chunk_size=chunk_size)
nodes.extend(document_nodes)
index = VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
show_progress=False,
)
msg = "LLM index rebuilt successfully."
else:
existing = store.get_modified_times()
changed = 0
# Update existing index
index = load_or_build_index()
existing_nodes: defaultdict[str, list] = defaultdict(list)
for node in index.docstore.docs.values():
doc_id = node.metadata.get("document_id")
if doc_id is not None:
existing_nodes[doc_id].append(node)
for document in iter_wrapper(documents):
doc_id = str(document.id)
if existing.get(doc_id) == document.modified.isoformat():
continue
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.upsert_document(doc_id, nodes)
changed += 1
msg = (
"LLM index updated successfully."
if changed
else "No changes detected in LLM index."
)
document_modified = document.modified.isoformat()
try:
with _exclude_readers():
store.compact()
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
if doc_id in existing_nodes:
doc_nodes = existing_nodes[doc_id]
node_modified = doc_nodes[0].metadata.get("modified")
if node_modified == document_modified:
continue
# Delete from docstore, FAISS IndexFlatL2 are append-only
for _ in doc_nodes:
remove_document_docstore_nodes(document, index)
nodes.extend(build_document_node(document, chunk_size=chunk_size))
if nodes:
msg = "LLM index updated successfully."
logger.info(
"Updating %d nodes in LLM index.",
len(nodes),
)
index.insert_nodes(nodes)
else:
msg = "No changes detected in LLM index."
logger.info(msg)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
return msg
def llm_index_add_or_update_document(document: Document):
"""Add or atomically replace a document's chunks in the index."""
config = AIConfig()
new_nodes = build_document_node(
document,
chunk_size=config.llm_embedding_chunk_size,
)
if new_nodes:
_embed_nodes(new_nodes, get_embedding_model(config))
"""
Adds or updates a document in the LLM index.
If the document already exists, it will be replaced.
"""
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
if not new_nodes:
logger.warning(
"No indexable content for document %s; skipping LLM index update.",
document.pk,
)
return
with write_store(embed_model_name=get_configured_model_name(config)) as store:
store.upsert_document(str(document.id), new_nodes)
with FileLock(_index_lock_path()):
index = load_or_build_index(nodes=new_nodes)
remove_document_docstore_nodes(document, index)
def llm_index_compact() -> None:
"""Compact the index immediately, rebuilding the table to reclaim space."""
with write_store() as store:
try:
with _exclude_readers():
store.compact(force=True)
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
index.insert_nodes(new_nodes)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_remove_document(document: Document):
"""Remove a document's chunks from the LLM index."""
with write_store() as store:
store.delete(str(document.id))
"""
Removes a document from the LLM index.
"""
with FileLock(_index_lock_path()):
index = load_or_build_index()
remove_document_docstore_nodes(document, index)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def truncate_content(
@@ -454,59 +410,77 @@ def query_similar_documents(
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
"""
Runs a similarity query and returns top-k similar Document objects.
"""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
if not llm_index_exists():
if not vector_store_file_exists():
queue_llm_index_update_if_needed(
rebuild=False,
reason="LLM index not found for similarity query.",
)
return []
config = AIConfig()
with FileLock(_index_lock_path()):
index = load_or_build_index()
from llama_index.core.retrievers import VectorIndexRetriever
# constrain only the node(s) that match the document IDs, if given
doc_node_ids = (
[
node.node_id
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in allowed_document_ids
]
if allowed_document_ids is not None
else None
)
if doc_node_ids is not None and not doc_node_ids:
return []
filters = (
_document_id_filters(allowed_document_ids)
if allowed_document_ids is not None
else None
)
from llama_index.core.retrievers import VectorIndexRetriever
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
# Hold the shared read lock for the whole retrieval so the connection is
# never open across a compaction swap. The retrieve() call generates a
# query embedding (a slow external request) and searches the vector store;
# no Django ORM access happens during it, so release the pooled DB
# connection for its duration. See #12976.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
filters=filters,
doc_ids=doc_node_ids,
)
with db_connection_released():
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
try:
results = retriever.retrieve(query_text)
except KeyError as e:
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
# append-only. Treat them as absent and return no results.
logger.debug(
"Skipping LLM similarity query for document %s due to a stale "
"FAISS position with no docstore node: %s",
document.pk,
e,
)
return []
retrieved_document_ids: list[int] = []
for node in results:
document_id = node.metadata.get("document_id")
if document_id is None:
continue
normalized = str(document_id)
if allowed_document_ids is not None and normalized not in allowed_document_ids:
normalized_document_id = str(document_id)
if (
allowed_document_ids is not None
and normalized_document_id not in allowed_document_ids
):
continue
try:
retrieved_document_ids.append(int(normalized))
except ValueError: # pragma: no cover
retrieved_document_ids.append(int(normalized_document_id))
except ValueError:
logger.warning(
"Skipping LLM index result with invalid document_id %r.",
document_id,

Some files were not shown because too many files have changed in this diff Show More