mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-07-02 18:24:17 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a1e7c0614e | |||
| dac05107a7 | |||
| 89ce62d97d | |||
| 50f5d5f2e9 | |||
| 92b59eebfc | |||
| 59fd2ff9e8 |
@@ -61,7 +61,7 @@ def replace_with_symlinks(
|
|||||||
total_duplicates = 0
|
total_duplicates = 0
|
||||||
space_saved = 0
|
space_saved = 0
|
||||||
|
|
||||||
for file_hash, file_list in duplicate_groups.items():
|
for file_list in duplicate_groups.values():
|
||||||
# Keep the first file as the original, replace others with symlinks
|
# Keep the first file as the original, replace others with symlinks
|
||||||
original_file = file_list[0]
|
original_file = file_list[0]
|
||||||
duplicates = file_list[1:]
|
duplicates = file_list[1:]
|
||||||
|
|||||||
@@ -2068,13 +2068,6 @@ context by default.
|
|||||||
|
|
||||||
Defaults to 8192.
|
Defaults to 8192.
|
||||||
|
|
||||||
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
|
|
||||||
|
|
||||||
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
|
|
||||||
local or slow inference servers that need more time to generate responses.
|
|
||||||
|
|
||||||
Defaults to 120.
|
|
||||||
|
|
||||||
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
|
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
|
||||||
|
|
||||||
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
|
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
|
||||||
|
|||||||
+8
-2
@@ -42,6 +42,7 @@ dependencies = [
|
|||||||
"drf-spectacular~=0.28",
|
"drf-spectacular~=0.28",
|
||||||
"drf-spectacular-sidecar~=2026.5.1",
|
"drf-spectacular-sidecar~=2026.5.1",
|
||||||
"drf-writable-nested~=0.7.1",
|
"drf-writable-nested~=0.7.1",
|
||||||
|
"faiss-cpu>=1.10",
|
||||||
"filelock~=3.29.0",
|
"filelock~=3.29.0",
|
||||||
"flower~=2.0.1",
|
"flower~=2.0.1",
|
||||||
"gotenberg-client~=0.14.0",
|
"gotenberg-client~=0.14.0",
|
||||||
@@ -56,6 +57,7 @@ dependencies = [
|
|||||||
"llama-index-embeddings-openai-like>=0.2.2",
|
"llama-index-embeddings-openai-like>=0.2.2",
|
||||||
"llama-index-llms-ollama>=0.9.1",
|
"llama-index-llms-ollama>=0.9.1",
|
||||||
"llama-index-llms-openai-like>=0.7.1",
|
"llama-index-llms-openai-like>=0.7.1",
|
||||||
|
"llama-index-vector-stores-faiss>=0.5.2",
|
||||||
"nltk~=3.9.1",
|
"nltk~=3.9.1",
|
||||||
"ocrmypdf~=17.4.2",
|
"ocrmypdf~=17.4.2",
|
||||||
"openai>=2.32",
|
"openai>=2.32",
|
||||||
@@ -72,7 +74,6 @@ dependencies = [
|
|||||||
"scikit-learn~=1.8.0",
|
"scikit-learn~=1.8.0",
|
||||||
"sentence-transformers>=5.4.1",
|
"sentence-transformers>=5.4.1",
|
||||||
"setproctitle~=1.3.4",
|
"setproctitle~=1.3.4",
|
||||||
"sqlite-vec==0.1.9",
|
|
||||||
"tantivy~=0.26.0",
|
"tantivy~=0.26.0",
|
||||||
"tika-client~=0.11.0",
|
"tika-client~=0.11.0",
|
||||||
"torch~=2.11.0",
|
"torch~=2.11.0",
|
||||||
@@ -184,12 +185,16 @@ line-ending = "lf"
|
|||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
# https://docs.astral.sh/ruff/rules/
|
# https://docs.astral.sh/ruff/rules/
|
||||||
extend-select = [
|
extend-select = [
|
||||||
|
"B", # https://docs.astral.sh/ruff/rules/#flake8-bugbear-b
|
||||||
"COM", # https://docs.astral.sh/ruff/rules/#flake8-commas-com
|
"COM", # https://docs.astral.sh/ruff/rules/#flake8-commas-com
|
||||||
|
"DTZ", # https://docs.astral.sh/ruff/rules/#flake8-datetimez-dtz
|
||||||
|
"PERF", # https://docs.astral.sh/ruff/rules/#perflint-perf
|
||||||
|
"S324", # https://docs.astral.sh/ruff/rules/hashlib-insecure-hash-functions/
|
||||||
"DJ", # https://docs.astral.sh/ruff/rules/#flake8-django-dj
|
"DJ", # https://docs.astral.sh/ruff/rules/#flake8-django-dj
|
||||||
"EXE", # https://docs.astral.sh/ruff/rules/#flake8-executable-exe
|
"EXE", # https://docs.astral.sh/ruff/rules/#flake8-executable-exe
|
||||||
"FBT", # https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt
|
"FBT", # https://docs.astral.sh/ruff/rules/#flake8-boolean-trap-fbt
|
||||||
"FLY", # https://docs.astral.sh/ruff/rules/#flynt-fly
|
"FLY", # https://docs.astral.sh/ruff/rules/#flynt-fly
|
||||||
"G201", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
|
"G", # https://docs.astral.sh/ruff/rules/#flake8-logging-format-g
|
||||||
"I", # https://docs.astral.sh/ruff/rules/#isort-i
|
"I", # https://docs.astral.sh/ruff/rules/#isort-i
|
||||||
"ICN", # https://docs.astral.sh/ruff/rules/#flake8-import-conventions-icn
|
"ICN", # https://docs.astral.sh/ruff/rules/#flake8-import-conventions-icn
|
||||||
"INP", # https://docs.astral.sh/ruff/rules/#flake8-no-pep420-inp
|
"INP", # https://docs.astral.sh/ruff/rules/#flake8-no-pep420-inp
|
||||||
@@ -210,6 +215,7 @@ extend-select = [
|
|||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"DJ001",
|
"DJ001",
|
||||||
|
"G004", # f-strings in logging: accepted style in this codebase
|
||||||
"PLC0415",
|
"PLC0415",
|
||||||
"RUF012",
|
"RUF012",
|
||||||
"SIM105",
|
"SIM105",
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ module.exports = {
|
|||||||
'abstract-paperless-service',
|
'abstract-paperless-service',
|
||||||
],
|
],
|
||||||
transformIgnorePatterns: [
|
transformIgnorePatterns: [
|
||||||
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
|
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
|
||||||
],
|
],
|
||||||
moduleNameMapper: {
|
moduleNameMapper: {
|
||||||
...esmPreset.moduleNameMapper,
|
...esmPreset.moduleNameMapper,
|
||||||
|
|||||||
@@ -32,7 +32,6 @@
|
|||||||
"ngx-cookie-service": "^21.3.1",
|
"ngx-cookie-service": "^21.3.1",
|
||||||
"ngx-device-detector": "^11.0.0",
|
"ngx-device-detector": "^11.0.0",
|
||||||
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
|
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
|
||||||
"normalize-diacritics": "^5.0.0",
|
|
||||||
"pdfjs-dist": "^5.7.284",
|
"pdfjs-dist": "^5.7.284",
|
||||||
"rxjs": "^7.8.2",
|
"rxjs": "^7.8.2",
|
||||||
"tslib": "^2.8.1",
|
"tslib": "^2.8.1",
|
||||||
|
|||||||
Generated
-11
@@ -71,9 +71,6 @@ importers:
|
|||||||
ngx-ui-tour-ng-bootstrap:
|
ngx-ui-tour-ng-bootstrap:
|
||||||
specifier: ^18.0.0
|
specifier: ^18.0.0
|
||||||
version: 18.0.0(f910a33494d223bd6dd07ce1bf22a35e)
|
version: 18.0.0(f910a33494d223bd6dd07ce1bf22a35e)
|
||||||
normalize-diacritics:
|
|
||||||
specifier: ^5.0.0
|
|
||||||
version: 5.0.0
|
|
||||||
pdfjs-dist:
|
pdfjs-dist:
|
||||||
specifier: ^5.7.284
|
specifier: ^5.7.284
|
||||||
version: 5.7.284
|
version: 5.7.284
|
||||||
@@ -5519,10 +5516,6 @@ packages:
|
|||||||
engines: {node: ^20.17.0 || >=22.9.0}
|
engines: {node: ^20.17.0 || >=22.9.0}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
normalize-diacritics@5.0.0:
|
|
||||||
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
|
|
||||||
engines: {node: '>= 14.x', npm: '>= 6.x'}
|
|
||||||
|
|
||||||
normalize-path@3.0.0:
|
normalize-path@3.0.0:
|
||||||
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
|
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
|
||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
@@ -12938,10 +12931,6 @@ snapshots:
|
|||||||
dependencies:
|
dependencies:
|
||||||
abbrev: 4.0.0
|
abbrev: 4.0.0
|
||||||
|
|
||||||
normalize-diacritics@5.0.0:
|
|
||||||
dependencies:
|
|
||||||
tslib: 2.8.1
|
|
||||||
|
|
||||||
normalize-path@3.0.0: {}
|
normalize-path@3.0.0: {}
|
||||||
|
|
||||||
npm-bundled@5.0.0:
|
npm-bundled@5.0.0:
|
||||||
|
|||||||
@@ -11,9 +11,6 @@
|
|||||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
|
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
|
||||||
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
|
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
|
||||||
</button>
|
</button>
|
||||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
|
|
||||||
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
|
|
||||||
</button>
|
|
||||||
<div class="form-check form-switch mb-0 ms-2">
|
<div class="form-check form-switch mb-0 ms-2">
|
||||||
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
|
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
|
||||||
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
|
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
|
||||||
@@ -84,7 +81,7 @@
|
|||||||
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
|
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
|
||||||
<div class="dropdown-menu shadow" ngbDropdownMenu>
|
<div class="dropdown-menu shadow" ngbDropdownMenu>
|
||||||
@for (t of filterTargets; track t.id) {
|
@for (t of filterTargets; track t.id) {
|
||||||
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
|
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
|
||||||
}
|
}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
|
|||||||
import { RouterTestingModule } from '@angular/router/testing'
|
import { RouterTestingModule } from '@angular/router/testing'
|
||||||
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
|
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
|
||||||
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||||
import { of, throwError } from 'rxjs'
|
import { throwError } from 'rxjs'
|
||||||
import { routes } from 'src/app/app-routing.module'
|
import { routes } from 'src/app/app-routing.module'
|
||||||
import {
|
import {
|
||||||
PaperlessTask,
|
PaperlessTask,
|
||||||
@@ -29,11 +29,7 @@ import { ToastService } from 'src/app/services/toast.service'
|
|||||||
import { environment } from 'src/environments/environment'
|
import { environment } from 'src/environments/environment'
|
||||||
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
|
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
|
||||||
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
|
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
|
||||||
import {
|
import { TasksComponent, TaskSection } from './tasks.component'
|
||||||
TaskFilterTargetID,
|
|
||||||
TasksComponent,
|
|
||||||
TaskSection,
|
|
||||||
} from './tasks.component'
|
|
||||||
|
|
||||||
const tasks: PaperlessTask[] = [
|
const tasks: PaperlessTask[] = [
|
||||||
{
|
{
|
||||||
@@ -158,13 +154,6 @@ const paginatedTasks: Results<PaperlessTask> = {
|
|||||||
results: tasks,
|
results: tasks,
|
||||||
}
|
}
|
||||||
|
|
||||||
const sectionCountResponse = {
|
|
||||||
all: 7,
|
|
||||||
needs_attention: 2,
|
|
||||||
in_progress: 3,
|
|
||||||
completed: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
describe('TasksComponent', () => {
|
describe('TasksComponent', () => {
|
||||||
let component: TasksComponent
|
let component: TasksComponent
|
||||||
let fixture: ComponentFixture<TasksComponent>
|
let fixture: ComponentFixture<TasksComponent>
|
||||||
@@ -232,15 +221,6 @@ describe('TasksComponent', () => {
|
|||||||
req.params.get('page') === '1'
|
req.params.get('page') === '1'
|
||||||
)
|
)
|
||||||
.flush(paginatedTasks)
|
.flush(paginatedTasks)
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(req) =>
|
|
||||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
|
||||||
req.params.get('acknowledged') === 'false' &&
|
|
||||||
!req.params.has('status')
|
|
||||||
)
|
|
||||||
.flush(sectionCountResponse)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should display task sections with counts', () => {
|
it('should display task sections with counts', () => {
|
||||||
@@ -315,7 +295,6 @@ describe('TasksComponent', () => {
|
|||||||
const headerText = header.nativeElement.textContent
|
const headerText = header.nativeElement.textContent
|
||||||
|
|
||||||
expect(headerText).toContain('Dismiss visible')
|
expect(headerText).toContain('Dismiss visible')
|
||||||
expect(headerText).toContain('Dismiss all')
|
|
||||||
expect(headerText).toContain('Auto refresh')
|
expect(headerText).toContain('Auto refresh')
|
||||||
expect(headerText).not.toContain('All types')
|
expect(headerText).not.toContain('All types')
|
||||||
expect(headerText).not.toContain('All sources')
|
expect(headerText).not.toContain('All sources')
|
||||||
@@ -348,74 +327,6 @@ describe('TasksComponent', () => {
|
|||||||
expect(pagination).not.toBeNull()
|
expect(pagination).not.toBeNull()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should apply the selected section to the server-side task query', () => {
|
|
||||||
component.setSection(TaskSection.NeedsAttention)
|
|
||||||
|
|
||||||
const req = httpTestingController.expectOne(
|
|
||||||
(request) =>
|
|
||||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
request.params.get('page') === '1' &&
|
|
||||||
request.params.get('page_size') === '25' &&
|
|
||||||
request.params.get('acknowledged') === 'false' &&
|
|
||||||
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
|
|
||||||
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
|
|
||||||
)
|
|
||||||
|
|
||||||
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
|
|
||||||
expect(component.totalTasks).toBe(2)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should apply task type and trigger source filters to the server-side task query', () => {
|
|
||||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(request) =>
|
|
||||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
request.params.get('page_size') === '25' &&
|
|
||||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck
|
|
||||||
)
|
|
||||||
.flush({ count: 1, results: [tasks[6]] })
|
|
||||||
|
|
||||||
component.setTriggerSource(PaperlessTaskTriggerSource.System)
|
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(request) =>
|
|
||||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
request.params.get('page_size') === '25' &&
|
|
||||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
|
|
||||||
request.params.get('trigger_source') ===
|
|
||||||
PaperlessTaskTriggerSource.System
|
|
||||||
)
|
|
||||||
.flush({ count: 1, results: [tasks[6]] })
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should apply text filters to the server-side task query', () => {
|
|
||||||
component.filterText = 'invoice'
|
|
||||||
jest.advanceTimersByTime(150)
|
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(request) =>
|
|
||||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
request.params.get('page_size') === '25' &&
|
|
||||||
request.params.get('name') === 'invoice'
|
|
||||||
)
|
|
||||||
.flush({ count: 1, results: [tasks[0]] })
|
|
||||||
|
|
||||||
component.setFilterTarget(TaskFilterTargetID.Result)
|
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(request) =>
|
|
||||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
request.params.get('page_size') === '25' &&
|
|
||||||
request.params.get('result') === 'invoice'
|
|
||||||
)
|
|
||||||
.flush({ count: 0, results: [] })
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should load a different task page when pagination changes', () => {
|
it('should load a different task page when pagination changes', () => {
|
||||||
component.setPage(2)
|
component.setPage(2)
|
||||||
|
|
||||||
@@ -439,27 +350,6 @@ describe('TasksComponent', () => {
|
|||||||
expect(component.pagedTasks).toEqual([tasks[0]])
|
expect(component.pagedTasks).toEqual([tasks[0]])
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not replace section counts with current-page counts', () => {
|
|
||||||
component.setPage(2)
|
|
||||||
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(req) =>
|
|
||||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
req.params.get('acknowledged') === 'false' &&
|
|
||||||
req.params.get('page_size') === '25' &&
|
|
||||||
req.params.get('page') === '2'
|
|
||||||
)
|
|
||||||
.flush({
|
|
||||||
count: 30,
|
|
||||||
results: [tasks[0]],
|
|
||||||
})
|
|
||||||
|
|
||||||
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
|
|
||||||
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
|
|
||||||
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should expose stable task type options and disable empty ones', () => {
|
it('should expose stable task type options and disable empty ones', () => {
|
||||||
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
|
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
|
||||||
PaperlessTaskType.TrainClassifier
|
PaperlessTaskType.TrainClassifier
|
||||||
@@ -605,46 +495,6 @@ describe('TasksComponent', () => {
|
|||||||
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
|
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should support dismiss all tasks', () => {
|
|
||||||
let modal: NgbModalRef
|
|
||||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
|
||||||
const dismissSpy = jest
|
|
||||||
.spyOn(tasksService, 'dismissAllTasks')
|
|
||||||
.mockReturnValue(of({}))
|
|
||||||
const reloadPageSpy = jest
|
|
||||||
.spyOn(component as any, 'reloadPage')
|
|
||||||
.mockImplementation(() => undefined)
|
|
||||||
|
|
||||||
component.dismissAllTasks()
|
|
||||||
|
|
||||||
expect(modal).not.toBeUndefined()
|
|
||||||
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
|
|
||||||
modal.componentInstance.confirmClicked.emit()
|
|
||||||
expect(dismissSpy).toHaveBeenCalled()
|
|
||||||
expect(reloadPageSpy).toHaveBeenCalledWith(false)
|
|
||||||
expect(component.selectedTasks.size).toBe(0)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
|
|
||||||
const error = new Error('dismiss all failed')
|
|
||||||
const toastSpy = jest.spyOn(toastService, 'showError')
|
|
||||||
const dismissSpy = jest
|
|
||||||
.spyOn(tasksService, 'dismissAllTasks')
|
|
||||||
.mockReturnValue(throwError(() => error))
|
|
||||||
|
|
||||||
let modal: NgbModalRef
|
|
||||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
|
||||||
|
|
||||||
component.dismissAllTasks()
|
|
||||||
expect(modal).not.toBeUndefined()
|
|
||||||
|
|
||||||
modal.componentInstance.confirmClicked.emit()
|
|
||||||
|
|
||||||
expect(dismissSpy).toHaveBeenCalled()
|
|
||||||
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
|
|
||||||
expect(modal.componentInstance.buttonsEnabled).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should dismiss the currently visible scoped and filtered tasks', () => {
|
it('should dismiss the currently visible scoped and filtered tasks', () => {
|
||||||
component.setSection(TaskSection.InProgress)
|
component.setSection(TaskSection.InProgress)
|
||||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
component.setTaskType(PaperlessTaskType.SanityCheck)
|
||||||
@@ -823,9 +673,6 @@ describe('TasksComponent', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should keep clearing selection independent from resetting filters', () => {
|
it('should keep clearing selection independent from resetting filters', () => {
|
||||||
component.resetFilter()
|
|
||||||
expect(component.filterText).toBe('')
|
|
||||||
|
|
||||||
component.setTaskType(PaperlessTaskType.ConsumeFile)
|
component.setTaskType(PaperlessTaskType.ConsumeFile)
|
||||||
component.toggleSelected(tasks[0])
|
component.toggleSelected(tasks[0])
|
||||||
expect(component.selectedTasks.size).toBe(1)
|
expect(component.selectedTasks.size).toBe(1)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ export enum TaskSection {
|
|||||||
Completed = 'completed',
|
Completed = 'completed',
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum TaskFilterTargetID {
|
enum TaskFilterTargetID {
|
||||||
Name,
|
Name,
|
||||||
Result,
|
Result,
|
||||||
}
|
}
|
||||||
@@ -167,12 +167,6 @@ export class TasksComponent
|
|||||||
public readonly pageSize = 25
|
public readonly pageSize = 25
|
||||||
public page: number = 1
|
public page: number = 1
|
||||||
public totalTasks: number = 0
|
public totalTasks: number = 0
|
||||||
public sectionCounts: Record<TaskSection, number> = {
|
|
||||||
[TaskSection.All]: 0,
|
|
||||||
[TaskSection.NeedsAttention]: 0,
|
|
||||||
[TaskSection.InProgress]: 0,
|
|
||||||
[TaskSection.Completed]: 0,
|
|
||||||
}
|
|
||||||
public pagedTasks: PaperlessTask[] = []
|
public pagedTasks: PaperlessTask[] = []
|
||||||
public selectedSection: TaskSection = TaskSection.All
|
public selectedSection: TaskSection = TaskSection.All
|
||||||
public selectedTaskType: PaperlessTaskType | null = null
|
public selectedTaskType: PaperlessTaskType | null = null
|
||||||
@@ -288,7 +282,6 @@ export class TasksComponent
|
|||||||
.subscribe((query) => {
|
.subscribe((query) => {
|
||||||
this._filterText = query
|
this._filterText = query
|
||||||
this.clearSelection()
|
this.clearSelection()
|
||||||
this.reloadPage(true)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,30 +334,6 @@ export class TasksComponent
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dismissAllTasks() {
|
|
||||||
let modal = this.modalService.open(ConfirmDialogComponent, {
|
|
||||||
backdrop: 'static',
|
|
||||||
})
|
|
||||||
modal.componentInstance.title = $localize`Confirm Dismiss All`
|
|
||||||
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
|
|
||||||
modal.componentInstance.btnClass = 'btn-warning'
|
|
||||||
modal.componentInstance.btnCaption = $localize`Dismiss`
|
|
||||||
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
|
|
||||||
modal.componentInstance.buttonsEnabled = false
|
|
||||||
modal.close()
|
|
||||||
this.tasksService.dismissAllTasks().subscribe({
|
|
||||||
next: () => {
|
|
||||||
this.reloadPage(false)
|
|
||||||
},
|
|
||||||
error: (e) => {
|
|
||||||
this.toastService.showError($localize`Error dismissing tasks`, e)
|
|
||||||
modal.componentInstance.buttonsEnabled = true
|
|
||||||
},
|
|
||||||
})
|
|
||||||
this.clearSelection()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
expandTask(task: PaperlessTask) {
|
expandTask(task: PaperlessTask) {
|
||||||
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
|
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
|
||||||
}
|
}
|
||||||
@@ -477,7 +446,9 @@ export class TasksComponent
|
|||||||
}
|
}
|
||||||
|
|
||||||
sectionCount(section: TaskSection): number {
|
sectionCount(section: TaskSection): number {
|
||||||
return this.sectionCounts[section]
|
return this.pagedTasks.filter((task) =>
|
||||||
|
this.taskBelongsToSection(task, section)
|
||||||
|
).length
|
||||||
}
|
}
|
||||||
|
|
||||||
sectionShowsResults(section: TaskSection): boolean {
|
sectionShowsResults(section: TaskSection): boolean {
|
||||||
@@ -487,27 +458,16 @@ export class TasksComponent
|
|||||||
setSection(section: TaskSection) {
|
setSection(section: TaskSection) {
|
||||||
this.selectedSection = section
|
this.selectedSection = section
|
||||||
this.clearSelection()
|
this.clearSelection()
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setTaskType(taskType: PaperlessTaskType | null) {
|
setTaskType(taskType: PaperlessTaskType | null) {
|
||||||
this.selectedTaskType = taskType
|
this.selectedTaskType = taskType
|
||||||
this.clearSelection()
|
this.clearSelection()
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
|
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
|
||||||
this.selectedTriggerSource = triggerSource
|
this.selectedTriggerSource = triggerSource
|
||||||
this.clearSelection()
|
this.clearSelection()
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
setFilterTarget(filterTargetID: TaskFilterTargetID) {
|
|
||||||
this.filterTargetID = filterTargetID
|
|
||||||
if (this._filterText.length) {
|
|
||||||
this.clearSelection()
|
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
|
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
|
||||||
@@ -545,32 +505,19 @@ export class TasksComponent
|
|||||||
}
|
}
|
||||||
|
|
||||||
public resetFilter() {
|
public resetFilter() {
|
||||||
if (!this._filterText.length) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
this._filterText = ''
|
this._filterText = ''
|
||||||
this.clearSelection()
|
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public resetFilters() {
|
public resetFilters() {
|
||||||
const hadFilter = this.isFiltered
|
|
||||||
this.selectedTaskType = null
|
this.selectedTaskType = null
|
||||||
this.selectedTriggerSource = null
|
this.selectedTriggerSource = null
|
||||||
this._filterText = ''
|
this.resetFilter()
|
||||||
this.clearSelection()
|
this.clearSelection()
|
||||||
|
|
||||||
if (hadFilter) {
|
|
||||||
this.reloadPage(true)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filterInputKeyup(event: KeyboardEvent) {
|
filterInputKeyup(event: KeyboardEvent) {
|
||||||
if (event.key == 'Enter') {
|
if (event.key == 'Enter') {
|
||||||
this._filterText = (event.target as HTMLInputElement).value
|
this._filterText = (event.target as HTMLInputElement).value
|
||||||
this.clearSelection()
|
|
||||||
this.reloadPage(true)
|
|
||||||
} else if (event.key === 'Escape') {
|
} else if (event.key === 'Escape') {
|
||||||
this.resetFilter()
|
this.resetFilter()
|
||||||
}
|
}
|
||||||
@@ -659,86 +606,19 @@ export class TasksComponent
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
private reloadSectionCounts() {
|
|
||||||
this.tasksService
|
|
||||||
.statusCounts(this.getParamsForSection(TaskSection.All))
|
|
||||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
|
||||||
.subscribe((counts) => {
|
|
||||||
this.sectionCounts[TaskSection.All] = counts.all
|
|
||||||
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
|
|
||||||
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
|
|
||||||
this.sectionCounts[TaskSection.Completed] = counts.completed
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
private getParamsForSection(
|
|
||||||
section: TaskSection
|
|
||||||
): Record<string, string | number | boolean | readonly string[]> {
|
|
||||||
const params: Record<
|
|
||||||
string,
|
|
||||||
string | number | boolean | readonly string[]
|
|
||||||
> = {
|
|
||||||
acknowledged: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
const statuses = this.statusesForSection(section)
|
|
||||||
if (statuses.length) {
|
|
||||||
params.status = statuses
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.selectedTaskType !== null) {
|
|
||||||
params.task_type = this.selectedTaskType
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.selectedTriggerSource !== null) {
|
|
||||||
params.trigger_source = this.selectedTriggerSource
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this._filterText.length) {
|
|
||||||
params[
|
|
||||||
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
|
|
||||||
] = this._filterText
|
|
||||||
}
|
|
||||||
|
|
||||||
return params
|
|
||||||
}
|
|
||||||
|
|
||||||
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
|
|
||||||
switch (section) {
|
|
||||||
case TaskSection.NeedsAttention:
|
|
||||||
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
|
|
||||||
case TaskSection.InProgress:
|
|
||||||
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
|
|
||||||
case TaskSection.Completed:
|
|
||||||
return [PaperlessTaskStatus.Success]
|
|
||||||
default:
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private reloadPage(resetToFirstPage: boolean = false) {
|
private reloadPage(resetToFirstPage: boolean = false) {
|
||||||
if (resetToFirstPage) {
|
if (resetToFirstPage) {
|
||||||
this.page = 1
|
this.page = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
this.reloadSectionCounts()
|
|
||||||
|
|
||||||
this.loading = true
|
this.loading = true
|
||||||
this.tasksService
|
this.tasksService
|
||||||
.list(
|
.list(this.page, this.pageSize, { acknowledged: false })
|
||||||
this.page,
|
|
||||||
this.pageSize,
|
|
||||||
this.getParamsForSection(this.selectedSection)
|
|
||||||
)
|
|
||||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
||||||
.subscribe({
|
.subscribe({
|
||||||
next: (result) => {
|
next: (result) => {
|
||||||
this.pagedTasks = result.results
|
this.pagedTasks = result.results
|
||||||
this.totalTasks = result.count
|
this.totalTasks = result.count
|
||||||
this.sectionCounts[TaskSection.All] = result.count
|
|
||||||
if (this.selectedSection !== TaskSection.All) {
|
|
||||||
this.sectionCounts[this.selectedSection] = result.count
|
|
||||||
}
|
|
||||||
this.loading = false
|
this.loading = false
|
||||||
if (
|
if (
|
||||||
this.page > 1 &&
|
this.page > 1 &&
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
<div class="chat-messages font-monospace small">
|
<div class="chat-messages font-monospace small">
|
||||||
@for (message of messages; track message) {
|
@for (message of messages; track message) {
|
||||||
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
|
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
|
||||||
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
|
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
|
||||||
<span>
|
<span>
|
||||||
{{ message.content }}
|
{{ message.content }}
|
||||||
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
|
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
|
||||||
|
|||||||
@@ -188,14 +188,4 @@ describe('ChatComponent', () => {
|
|||||||
component.searchInputKeyDown(event)
|
component.searchInputKeyDown(event)
|
||||||
expect(component.sendMessage).toHaveBeenCalled()
|
expect(component.sendMessage).toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not send message on Enter key press while composing with IME', () => {
|
|
||||||
jest.spyOn(component, 'sendMessage')
|
|
||||||
const event = new KeyboardEvent('keydown', {
|
|
||||||
key: 'Enter',
|
|
||||||
isComposing: true,
|
|
||||||
})
|
|
||||||
component.searchInputKeyDown(event)
|
|
||||||
expect(component.sendMessage).not.toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,10 +155,7 @@ export class ChatComponent implements OnInit {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public searchInputKeyDown(event: KeyboardEvent) {
|
public searchInputKeyDown(event: KeyboardEvent) {
|
||||||
if (
|
if (event.key === 'Enter') {
|
||||||
event.key === 'Enter' &&
|
|
||||||
!(event.isComposing || event.keyCode === 229)
|
|
||||||
) {
|
|
||||||
event.preventDefault()
|
event.preventDefault()
|
||||||
this.sendMessage()
|
this.sendMessage()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,10 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="modal-body">
|
<div class="modal-body">
|
||||||
@if (messageBold) {
|
@if (messageBold) {
|
||||||
<p class="text-break"><b>{{messageBold}}</b></p>
|
<p><b>{{messageBold}}</b></p>
|
||||||
}
|
}
|
||||||
@if (message) {
|
@if (message) {
|
||||||
<p class="mb-0 text-break" [innerHTML]="message"></p>
|
<p class="mb-0" [innerHTML]="message"></p>
|
||||||
}
|
}
|
||||||
</div>
|
</div>
|
||||||
<div class="modal-footer">
|
<div class="modal-footer">
|
||||||
|
|||||||
+1
-5
@@ -9,11 +9,8 @@
|
|||||||
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
|
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
|
||||||
<ul class="list-group"
|
<ul class="list-group"
|
||||||
cdkDropList
|
cdkDropList
|
||||||
[cdkDropListData]="documentIDs"
|
|
||||||
(cdkDropListDropped)="onDrop($event)">
|
(cdkDropListDropped)="onDrop($event)">
|
||||||
@for (documentID of documentIDs; track documentID) {
|
@for (document of documents; track document.id) {
|
||||||
@let document = getDocument(documentID);
|
|
||||||
@if (document) {
|
|
||||||
<li class="list-group-item d-flex align-items-center" cdkDrag>
|
<li class="list-group-item d-flex align-items-center" cdkDrag>
|
||||||
<i-bs name="grip-vertical" class="me-2"></i-bs>
|
<i-bs name="grip-vertical" class="me-2"></i-bs>
|
||||||
<div class="d-flex flex-column">
|
<div class="d-flex flex-column">
|
||||||
@@ -30,7 +27,6 @@
|
|||||||
</small>
|
</small>
|
||||||
</div>
|
</div>
|
||||||
</li>
|
</li>
|
||||||
}
|
|
||||||
}
|
}
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
+3
-2
@@ -23,7 +23,6 @@ import {
|
|||||||
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
|
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
|
||||||
import { ToastService } from 'src/app/services/toast.service'
|
import { ToastService } from 'src/app/services/toast.service'
|
||||||
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
||||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
|
||||||
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
||||||
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
|
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
|
||||||
|
|
||||||
@@ -70,7 +69,9 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
|
|||||||
|
|
||||||
public get filteredFields(): CustomField[] {
|
public get filteredFields(): CustomField[] {
|
||||||
return this.unusedFields.filter(
|
return this.unusedFields.filter(
|
||||||
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
|
(f) =>
|
||||||
|
!this.filterText ||
|
||||||
|
f.name.toLowerCase().includes(this.filterText.toLowerCase())
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
-3
@@ -63,7 +63,6 @@
|
|||||||
[(ngModel)]="atom.value"
|
[(ngModel)]="atom.value"
|
||||||
[disabled]="disabled"
|
[disabled]="disabled"
|
||||||
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
|
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
|
||||||
[searchFn]="selectOptionSearchFn"
|
|
||||||
(mousedown)="$event.stopImmediatePropagation()"
|
(mousedown)="$event.stopImmediatePropagation()"
|
||||||
></ng-select>
|
></ng-select>
|
||||||
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
|
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
|
||||||
@@ -82,7 +81,6 @@
|
|||||||
[disabled]="disabled"
|
[disabled]="disabled"
|
||||||
bindLabel="name"
|
bindLabel="name"
|
||||||
bindValue="id"
|
bindValue="id"
|
||||||
[searchFn]="customFieldSearchFn"
|
|
||||||
(mousedown)="$event.stopImmediatePropagation()"
|
(mousedown)="$event.stopImmediatePropagation()"
|
||||||
></ng-select>
|
></ng-select>
|
||||||
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
|
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
|
||||||
@@ -127,7 +125,6 @@
|
|||||||
[(ngModel)]="atom.value"
|
[(ngModel)]="atom.value"
|
||||||
[disabled]="disabled"
|
[disabled]="disabled"
|
||||||
[multiple]="true"
|
[multiple]="true"
|
||||||
[searchFn]="selectOptionSearchFn"
|
|
||||||
(mousedown)="$event.stopImmediatePropagation()"
|
(mousedown)="$event.stopImmediatePropagation()"
|
||||||
></ng-select>
|
></ng-select>
|
||||||
}
|
}
|
||||||
|
|||||||
-9
@@ -36,7 +36,6 @@ import {
|
|||||||
CustomFieldQueryExpression,
|
CustomFieldQueryExpression,
|
||||||
} from 'src/app/utils/custom-field-query-element'
|
} from 'src/app/utils/custom-field-query-element'
|
||||||
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
||||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
|
||||||
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
||||||
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
|
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
|
||||||
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
|
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
|
||||||
@@ -282,14 +281,6 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
|
|||||||
|
|
||||||
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||||
|
|
||||||
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
|
|
||||||
matchesSearchText(field?.name, term)
|
|
||||||
|
|
||||||
public selectOptionSearchFn = (
|
|
||||||
term: string,
|
|
||||||
option: { id: string; label: string }
|
|
||||||
): boolean => matchesSearchText(option?.label, term)
|
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
super()
|
super()
|
||||||
this.selectionModel = new CustomFieldQueriesModel()
|
this.selectionModel = new CustomFieldQueriesModel()
|
||||||
|
|||||||
@@ -28,7 +28,6 @@
|
|||||||
[notFoundText]="notFoundText"
|
[notFoundText]="notFoundText"
|
||||||
[multiple]="multiple"
|
[multiple]="multiple"
|
||||||
[bindLabel]="bindLabel"
|
[bindLabel]="bindLabel"
|
||||||
[searchFn]="searchFn"
|
|
||||||
bindValue="id"
|
bindValue="id"
|
||||||
[virtualScroll]="items?.length > 100"
|
[virtualScroll]="items?.length > 100"
|
||||||
(change)="onChange(value)"
|
(change)="onChange(value)"
|
||||||
|
|||||||
@@ -112,15 +112,6 @@ describe('SelectComponent', () => {
|
|||||||
expect(createNewVal).toEqual('baz')
|
expect(createNewVal).toEqual('baz')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should search items by independent normalized terms', () => {
|
|
||||||
expect(
|
|
||||||
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
|
|
||||||
).toBeTruthy()
|
|
||||||
expect(
|
|
||||||
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
|
|
||||||
).toBeFalsy()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should clear search term on blur after delay', fakeAsync(() => {
|
it('should clear search term on blur after delay', fakeAsync(() => {
|
||||||
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
|
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
|
||||||
component.onBlur()
|
component.onBlur()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import {
|
|||||||
import { RouterModule } from '@angular/router'
|
import { RouterModule } from '@angular/router'
|
||||||
import { NgSelectModule } from '@ng-select/ng-select'
|
import { NgSelectModule } from '@ng-select/ng-select'
|
||||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
|
||||||
import { AbstractInputComponent } from '../abstract-input'
|
import { AbstractInputComponent } from '../abstract-input'
|
||||||
|
|
||||||
@Component({
|
@Component({
|
||||||
@@ -100,9 +99,6 @@ export class SelectComponent extends AbstractInputComponent<number> {
|
|||||||
@Input()
|
@Input()
|
||||||
bindLabel: string = 'name'
|
bindLabel: string = 'name'
|
||||||
|
|
||||||
public searchFn = (term: string, item: any): boolean =>
|
|
||||||
matchesSearchText(item?.[this.bindLabel], term)
|
|
||||||
|
|
||||||
@Input()
|
@Input()
|
||||||
showFilter: boolean = false
|
showFilter: boolean = false
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
[clearSearchOnAdd]="true"
|
[clearSearchOnAdd]="true"
|
||||||
[hideSelected]="tags.length > 0"
|
[hideSelected]="tags.length > 0"
|
||||||
[addTag]="allowCreate ? createTagRef : false"
|
[addTag]="allowCreate ? createTagRef : false"
|
||||||
[searchFn]="searchFn"
|
|
||||||
addTagText="Add tag"
|
addTagText="Add tag"
|
||||||
i18n-addTagText
|
i18n-addTagText
|
||||||
(add)="onAdd($event)"
|
(add)="onAdd($event)"
|
||||||
|
|||||||
@@ -171,15 +171,6 @@ describe('TagsComponent', () => {
|
|||||||
expect(component.getTag(4)).toBeUndefined()
|
expect(component.getTag(4)).toBeUndefined()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should search tags by independent normalized terms including parents', () => {
|
|
||||||
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
|
|
||||||
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
|
|
||||||
component.tags = [parent, child]
|
|
||||||
|
|
||||||
expect(component.searchFn('finance 26', child)).toBeTruthy()
|
|
||||||
expect(component.searchFn('finance receipt', child)).toBeFalsy()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should emit filtered documents', () => {
|
it('should emit filtered documents', () => {
|
||||||
component.value = [10]
|
component.value = [10]
|
||||||
component.tags = tags
|
component.tags = tags
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
|||||||
import { first, firstValueFrom, tap } from 'rxjs'
|
import { first, firstValueFrom, tap } from 'rxjs'
|
||||||
import { Tag } from 'src/app/data/tag'
|
import { Tag } from 'src/app/data/tag'
|
||||||
import { TagService } from 'src/app/services/rest/tag.service'
|
import { TagService } from 'src/app/services/rest/tag.service'
|
||||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
|
||||||
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
|
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
|
||||||
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
|
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
|
||||||
import { TagComponent } from '../../tag/tag.component'
|
import { TagComponent } from '../../tag/tag.component'
|
||||||
@@ -115,14 +114,6 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
|
|||||||
|
|
||||||
public createTagRef: (name) => void
|
public createTagRef: (name) => void
|
||||||
|
|
||||||
public searchFn = (term: string, tag: Tag): boolean =>
|
|
||||||
matchesSearchText(
|
|
||||||
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
|
|
||||||
.flat()
|
|
||||||
.join(' '),
|
|
||||||
term
|
|
||||||
)
|
|
||||||
|
|
||||||
getTag(id: number) {
|
getTag(id: number) {
|
||||||
if (this.tags) {
|
if (this.tags) {
|
||||||
return this.tags.find((tag) => tag.id == id)
|
return this.tags.find((tag) => tag.id == id)
|
||||||
|
|||||||
+2
-2
@@ -1,5 +1,5 @@
|
|||||||
<div class="btn-group">
|
<div class="btn-group">
|
||||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
|
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
|
||||||
@if (loading) {
|
@if (loading) {
|
||||||
<div class="spinner-border spinner-border-sm" role="status"></div>
|
<div class="spinner-border spinner-border-sm" role="status"></div>
|
||||||
} @else {
|
} @else {
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
@if (aiEnabled) {
|
@if (aiEnabled) {
|
||||||
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
|
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
|
||||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||||
<span class="visually-hidden" i18n>Show suggestions</span>
|
<span class="visually-hidden" i18n>Show suggestions</span>
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
|
|||||||
-12
@@ -37,18 +37,6 @@ describe('SuggestionsDropdownComponent', () => {
|
|||||||
expect(component.getSuggestions.emit).toHaveBeenCalled()
|
expect(component.getSuggestions.emit).toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should not emit getSuggestions when disabled', () => {
|
|
||||||
jest.spyOn(component.getSuggestions, 'emit')
|
|
||||||
component.disabled = true
|
|
||||||
component.suggestions = null
|
|
||||||
fixture.detectChanges()
|
|
||||||
|
|
||||||
component.clickSuggest()
|
|
||||||
|
|
||||||
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
|
|
||||||
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
|
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
|
||||||
component.aiEnabled = true
|
component.aiEnabled = true
|
||||||
fixture.detectChanges()
|
fixture.detectChanges()
|
||||||
|
|||||||
-8
@@ -47,14 +47,6 @@ export class SuggestionsDropdownComponent {
|
|||||||
addCorrespondent: EventEmitter<string> = new EventEmitter()
|
addCorrespondent: EventEmitter<string> = new EventEmitter()
|
||||||
|
|
||||||
public clickSuggest(): void {
|
public clickSuggest(): void {
|
||||||
if (
|
|
||||||
this.disabled ||
|
|
||||||
this.loading ||
|
|
||||||
(this.suggestions && !this.aiEnabled)
|
|
||||||
) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.suggestions) {
|
if (!this.suggestions) {
|
||||||
this.getSuggestions.emit(this)
|
this.getSuggestions.emit(this)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+1
-3
@@ -131,9 +131,7 @@
|
|||||||
@if (status.tasks.celery_status === 'OK') {
|
@if (status.tasks.celery_status === 'OK') {
|
||||||
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
|
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
|
||||||
} @else {
|
} @else {
|
||||||
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
|
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
|
||||||
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
|
|
||||||
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
|
|
||||||
}
|
}
|
||||||
</button>
|
</button>
|
||||||
<ng-template #celeryStatus>
|
<ng-template #celeryStatus>
|
||||||
|
|||||||
@@ -360,14 +360,6 @@ export const PaperlessConfigOptions: ConfigOption[] = [
|
|||||||
category: ConfigCategory.AI,
|
category: ConfigCategory.AI,
|
||||||
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
|
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
key: 'llm_request_timeout',
|
|
||||||
title: $localize`LLM Request Timeout`,
|
|
||||||
type: ConfigOptionType.Number,
|
|
||||||
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
|
|
||||||
category: ConfigCategory.AI,
|
|
||||||
note: $localize`Timeout in seconds for LLM requests.`,
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
export interface PaperlessConfig extends ObjectWithId {
|
export interface PaperlessConfig extends ObjectWithId {
|
||||||
@@ -409,5 +401,4 @@ export interface PaperlessConfig extends ObjectWithId {
|
|||||||
llm_api_key: string
|
llm_api_key: string
|
||||||
llm_endpoint: string
|
llm_endpoint: string
|
||||||
llm_output_language: string
|
llm_output_language: string
|
||||||
llm_request_timeout: number
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,10 +64,3 @@ export interface PaperlessTaskSummary {
|
|||||||
last_success: Date | null
|
last_success: Date | null
|
||||||
last_failure: Date | null
|
last_failure: Date | null
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface PaperlessTaskStatusCounts {
|
|
||||||
all: number
|
|
||||||
needs_attention: number
|
|
||||||
in_progress: number
|
|
||||||
completed: number
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import { Pipe, PipeTransform } from '@angular/core'
|
import { Pipe, PipeTransform } from '@angular/core'
|
||||||
import { MatchingModel } from '../data/matching-model'
|
import { MatchingModel } from '../data/matching-model'
|
||||||
import { matchesSearchText } from '../utils/text-search'
|
|
||||||
|
|
||||||
@Pipe({
|
@Pipe({
|
||||||
name: 'filter',
|
name: 'filter',
|
||||||
@@ -22,7 +21,9 @@ export class FilterPipe implements PipeTransform {
|
|||||||
typeof item[key] === 'string' || typeof item[key] === 'number'
|
typeof item[key] === 'string' || typeof item[key] === 'number'
|
||||||
)
|
)
|
||||||
return keys.some((key) => {
|
return keys.some((key) => {
|
||||||
return matchesSearchText(item[key], searchText)
|
return String(item[key])
|
||||||
|
.toLowerCase()
|
||||||
|
.includes(searchText.toLowerCase())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,27 +80,6 @@ describe('TasksService', () => {
|
|||||||
.flush({ count: 0, results: [] })
|
.flush({ count: 0, results: [] })
|
||||||
})
|
})
|
||||||
|
|
||||||
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
|
|
||||||
tasksService.dismissAllTasks().subscribe()
|
|
||||||
const req = httpTestingController.expectOne(
|
|
||||||
`${environment.apiBaseUrl}tasks/acknowledge/`
|
|
||||||
)
|
|
||||||
expect(req.request.method).toEqual('POST')
|
|
||||||
expect(req.request.body).toEqual({
|
|
||||||
all: true,
|
|
||||||
})
|
|
||||||
req.flush([])
|
|
||||||
// reload is then called
|
|
||||||
httpTestingController
|
|
||||||
.expectOne(
|
|
||||||
(req: HttpRequest<unknown>) =>
|
|
||||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
|
||||||
req.params.get('acknowledged') === 'false' &&
|
|
||||||
req.params.get('page_size') === '1000'
|
|
||||||
)
|
|
||||||
.flush({ count: 0, results: [] })
|
|
||||||
})
|
|
||||||
|
|
||||||
it('groups mixed task types by status when reloading', () => {
|
it('groups mixed task types by status when reloading', () => {
|
||||||
expect(tasksService.total).toEqual(0)
|
expect(tasksService.total).toEqual(0)
|
||||||
const mockTasks = [
|
const mockTasks = [
|
||||||
@@ -242,34 +221,4 @@ describe('TasksService', () => {
|
|||||||
task_id: 'abc-123',
|
task_id: 'abc-123',
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('loads filtered task status counts', () => {
|
|
||||||
tasksService
|
|
||||||
.statusCounts({
|
|
||||||
acknowledged: false,
|
|
||||||
task_type: PaperlessTaskType.ConsumeFile,
|
|
||||||
})
|
|
||||||
.subscribe((res) => {
|
|
||||||
expect(res).toEqual({
|
|
||||||
all: 10,
|
|
||||||
needs_attention: 2,
|
|
||||||
in_progress: 3,
|
|
||||||
completed: 5,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
const req = httpTestingController.expectOne(
|
|
||||||
(req: HttpRequest<unknown>) =>
|
|
||||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
|
||||||
req.params.get('acknowledged') === 'false' &&
|
|
||||||
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
|
|
||||||
)
|
|
||||||
expect(req.request.method).toEqual('GET')
|
|
||||||
req.flush({
|
|
||||||
all: 10,
|
|
||||||
needs_attention: 2,
|
|
||||||
in_progress: 3,
|
|
||||||
completed: 5,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
|
|||||||
import {
|
import {
|
||||||
PaperlessTask,
|
PaperlessTask,
|
||||||
PaperlessTaskStatus,
|
PaperlessTaskStatus,
|
||||||
PaperlessTaskStatusCounts,
|
|
||||||
PaperlessTaskType,
|
PaperlessTaskType,
|
||||||
} from 'src/app/data/paperless-task'
|
} from 'src/app/data/paperless-task'
|
||||||
import { Results } from 'src/app/data/results'
|
import { Results } from 'src/app/data/results'
|
||||||
@@ -89,7 +88,7 @@ export class TasksService {
|
|||||||
public list(
|
public list(
|
||||||
page: number,
|
page: number,
|
||||||
pageSize: number,
|
pageSize: number,
|
||||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
extraParams?: Record<string, string | number | boolean>
|
||||||
): Observable<Results<PaperlessTask>> {
|
): Observable<Results<PaperlessTask>> {
|
||||||
return this.http.get<Results<PaperlessTask>>(
|
return this.http.get<Results<PaperlessTask>>(
|
||||||
`${this.baseUrl}${this.endpoint}/`,
|
`${this.baseUrl}${this.endpoint}/`,
|
||||||
@@ -103,17 +102,6 @@ export class TasksService {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public statusCounts(
|
|
||||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
|
||||||
): Observable<PaperlessTaskStatusCounts> {
|
|
||||||
return this.http.get<PaperlessTaskStatusCounts>(
|
|
||||||
`${this.baseUrl}${this.endpoint}/status_counts/`,
|
|
||||||
{
|
|
||||||
params: extraParams,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
public dismissTasks(task_ids: Set<number>): Observable<any> {
|
public dismissTasks(task_ids: Set<number>): Observable<any> {
|
||||||
return this.http
|
return this.http
|
||||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
||||||
@@ -128,20 +116,6 @@ export class TasksService {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public dismissAllTasks(): Observable<any> {
|
|
||||||
return this.http
|
|
||||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
|
||||||
all: true,
|
|
||||||
})
|
|
||||||
.pipe(
|
|
||||||
first(),
|
|
||||||
takeUntil(this.unsubscribeNotifer),
|
|
||||||
tap(() => {
|
|
||||||
this.reload()
|
|
||||||
})
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
public cancelPending(): void {
|
public cancelPending(): void {
|
||||||
this.unsubscribeNotifer.next(true)
|
this.unsubscribeNotifer.next(true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
import { matchesSearchText } from './text-search'
|
|
||||||
|
|
||||||
describe('text search utilities', () => {
|
|
||||||
it('matches text accent-insensitively', () => {
|
|
||||||
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('matches all whitespace-separated search terms independently', () => {
|
|
||||||
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
|
|
||||||
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import { normalizeSync } from 'normalize-diacritics'
|
|
||||||
|
|
||||||
export type SearchTextValue =
|
|
||||||
| string
|
|
||||||
| number
|
|
||||||
| boolean
|
|
||||||
| bigint
|
|
||||||
| null
|
|
||||||
| undefined
|
|
||||||
|
|
||||||
export function normalizeSearchText(value: SearchTextValue): string {
|
|
||||||
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
|
|
||||||
}
|
|
||||||
|
|
||||||
export function matchesSearchText(
|
|
||||||
value: SearchTextValue,
|
|
||||||
searchText: SearchTextValue
|
|
||||||
): boolean {
|
|
||||||
const normalizedValue = normalizeSearchText(value)
|
|
||||||
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
|
|
||||||
|
|
||||||
return searchTerms.every((term) => normalizedValue.includes(term))
|
|
||||||
}
|
|
||||||
@@ -904,19 +904,6 @@ def remove_password(
|
|||||||
doc.id,
|
doc.id,
|
||||||
pair.source_doc.source_path,
|
pair.source_doc.source_path,
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
with pikepdf.open(source_path) as pdf:
|
|
||||||
if not pdf.is_encrypted:
|
|
||||||
logger.info(
|
|
||||||
"Skipping password removal for document %s because the "
|
|
||||||
"source PDF is not encrypted",
|
|
||||||
pair.root_doc.id,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
except pikepdf.PasswordError:
|
|
||||||
# Password-protected PDFs need the supplied password below.
|
|
||||||
pass
|
|
||||||
|
|
||||||
with pikepdf.open(source_path, password=password) as pdf:
|
with pikepdf.open(source_path, password=password) as pdf:
|
||||||
filepath: Path = (
|
filepath: Path = (
|
||||||
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
|
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
|
||||||
|
|||||||
@@ -834,8 +834,9 @@ class ConsumerPlugin(
|
|||||||
self.log.debug(f"Creation date from parse_date: {create_date}")
|
self.log.debug(f"Creation date from parse_date: {create_date}")
|
||||||
else:
|
else:
|
||||||
stats = Path(self.input_doc.original_file).stat()
|
stats = Path(self.input_doc.original_file).stat()
|
||||||
create_date = timezone.make_aware(
|
create_date = datetime.datetime.fromtimestamp(
|
||||||
datetime.datetime.fromtimestamp(stats.st_mtime),
|
stats.st_mtime,
|
||||||
|
tz=datetime.UTC,
|
||||||
)
|
)
|
||||||
self.log.debug(f"Creation date from st_mtime: {create_date}")
|
self.log.debug(f"Creation date from st_mtime: {create_date}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import datetime as dt
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -6,6 +5,7 @@ from pathlib import Path
|
|||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.utils import timezone
|
||||||
from pikepdf import Pdf
|
from pikepdf import Pdf
|
||||||
|
|
||||||
from documents.consumer import ConsumerError
|
from documents.consumer import ConsumerError
|
||||||
@@ -78,7 +78,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
|
|||||||
stats = staging.stat()
|
stats = staging.stat()
|
||||||
# if the file is older than the timeout, we don't consider
|
# if the file is older than the timeout, we don't consider
|
||||||
# it valid
|
# it valid
|
||||||
if (dt.datetime.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
|
if (timezone.now().timestamp() - stats.st_mtime) > TIMEOUT_SECONDS:
|
||||||
logger.warning("Outdated double sided staging file exists, deleting it")
|
logger.warning("Outdated double sided staging file exists, deleting it")
|
||||||
staging.unlink()
|
staging.unlink()
|
||||||
else:
|
else:
|
||||||
@@ -99,7 +99,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
|
|||||||
"two uploaded files don't belong to the same double-"
|
"two uploaded files don't belong to the same double-"
|
||||||
"sided scan. Please retry, starting with the odd "
|
"sided scan. Please retry, starting with the odd "
|
||||||
"numbered pages again.",
|
"numbered pages again.",
|
||||||
)
|
) from None
|
||||||
# Merged file has the same path, but without the
|
# Merged file has the same path, but without the
|
||||||
# double-sided subdir. Therefore, it is also in the
|
# double-sided subdir. Therefore, it is also in the
|
||||||
# consumption dir and will be picked up for processing
|
# consumption dir and will be picked up for processing
|
||||||
@@ -134,7 +134,7 @@ class CollatePlugin(NoCleanupPluginMixin, NoSetupPluginMixin, ConsumeTaskPlugin)
|
|||||||
shutil.move(pdf_file, staging)
|
shutil.move(pdf_file, staging)
|
||||||
# update access to modification time so we know if the file
|
# update access to modification time so we know if the file
|
||||||
# is outdated when another file gets uploaded
|
# is outdated when another file gets uploaded
|
||||||
timestamp = dt.datetime.now().timestamp()
|
timestamp = timezone.now().timestamp()
|
||||||
os.utime(staging, (timestamp, timestamp))
|
os.utime(staging, (timestamp, timestamp))
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got scan with odd numbered pages of double-sided scan, moved it to %s",
|
"Got scan with odd numbered pages of double-sided scan, moved it to %s",
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from django.db.models.functions import Cast
|
|||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django_filters import DateFilter
|
from django_filters import DateFilter
|
||||||
from django_filters.rest_framework import BooleanFilter
|
from django_filters.rest_framework import BooleanFilter
|
||||||
from django_filters.rest_framework import CharFilter
|
|
||||||
from django_filters.rest_framework import DateTimeFilter
|
from django_filters.rest_framework import DateTimeFilter
|
||||||
from django_filters.rest_framework import Filter
|
from django_filters.rest_framework import Filter
|
||||||
from django_filters.rest_framework import FilterSet
|
from django_filters.rest_framework import FilterSet
|
||||||
@@ -351,7 +350,7 @@ def handle_validation_prefix(func: Callable):
|
|||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except serializers.ValidationError as e:
|
except serializers.ValidationError as e:
|
||||||
raise serializers.ValidationError({validation_prefix: e.detail})
|
raise serializers.ValidationError({validation_prefix: e.detail}) from e
|
||||||
|
|
||||||
# Update the signature to include the validation_prefix argument
|
# Update the signature to include the validation_prefix argument
|
||||||
old_sig = inspect.signature(func)
|
old_sig = inspect.signature(func)
|
||||||
@@ -462,7 +461,7 @@ class CustomFieldQueryParser:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
{self._validation_prefix: [_("Value must be valid JSON.")]},
|
{self._validation_prefix: [_("Value must be valid JSON.")]},
|
||||||
)
|
) from None
|
||||||
return (
|
return (
|
||||||
self._parse_expr(expr, validation_prefix=self._validation_prefix),
|
self._parse_expr(expr, validation_prefix=self._validation_prefix),
|
||||||
self._annotations,
|
self._annotations,
|
||||||
@@ -590,7 +589,7 @@ class CustomFieldQueryParser:
|
|||||||
except CustomField.DoesNotExist:
|
except CustomField.DoesNotExist:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
|
[_("{name!r} is not a valid custom field.").format(name=id_or_name)],
|
||||||
)
|
) from None
|
||||||
self._custom_fields[custom_field.id] = custom_field
|
self._custom_fields[custom_field.id] = custom_field
|
||||||
self._custom_fields[custom_field.name] = custom_field
|
self._custom_fields[custom_field.name] = custom_field
|
||||||
return custom_field
|
return custom_field
|
||||||
@@ -901,16 +900,6 @@ class ShareLinkBundleFilterSet(FilterSet):
|
|||||||
|
|
||||||
|
|
||||||
class PaperlessTaskFilterSet(FilterSet):
|
class PaperlessTaskFilterSet(FilterSet):
|
||||||
name = CharFilter(
|
|
||||||
method="filter_name",
|
|
||||||
label="Name",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = CharFilter(
|
|
||||||
method="filter_result",
|
|
||||||
label="Result",
|
|
||||||
)
|
|
||||||
|
|
||||||
task_type = MultipleChoiceFilter(
|
task_type = MultipleChoiceFilter(
|
||||||
choices=PaperlessTask.TaskType.choices,
|
choices=PaperlessTask.TaskType.choices,
|
||||||
label="Task Type",
|
label="Task Type",
|
||||||
@@ -950,58 +939,7 @@ class PaperlessTaskFilterSet(FilterSet):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = PaperlessTask
|
model = PaperlessTask
|
||||||
fields = [
|
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
|
||||||
"task_type",
|
|
||||||
"trigger_source",
|
|
||||||
"status",
|
|
||||||
"acknowledged",
|
|
||||||
"owner",
|
|
||||||
"name",
|
|
||||||
"result",
|
|
||||||
]
|
|
||||||
|
|
||||||
def filter_name(self, queryset, name, value):
|
|
||||||
if not value:
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
matching_task_types = [
|
|
||||||
task_type
|
|
||||||
for task_type, label in PaperlessTask.TaskType.choices
|
|
||||||
if value.lower() in str(label).lower()
|
|
||||||
]
|
|
||||||
matching_trigger_sources = [
|
|
||||||
trigger_source
|
|
||||||
for trigger_source, label in PaperlessTask.TriggerSource.choices
|
|
||||||
if value.lower() in str(label).lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
return queryset.filter(
|
|
||||||
Q(input_data__filename__icontains=value)
|
|
||||||
| Q(task_type__in=matching_task_types)
|
|
||||||
| Q(trigger_source__in=matching_trigger_sources),
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter_result(self, queryset, name, value):
|
|
||||||
if not value:
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
query = Q(result_data__reason__icontains=value) | Q(
|
|
||||||
result_data__error_message__icontains=value,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
numeric_value = int(value)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
query |= Q(result_data__document_id=numeric_value) | Q(
|
|
||||||
result_data__duplicate_of=numeric_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
if "duplicate" in value.lower():
|
|
||||||
query |= Q(result_data__duplicate_of__isnull=False)
|
|
||||||
|
|
||||||
return queryset.filter(query)
|
|
||||||
|
|
||||||
def filter_is_complete(self, queryset, name, value):
|
def filter_is_complete(self, queryset, name, value):
|
||||||
if value:
|
if value:
|
||||||
@@ -1050,7 +988,7 @@ class DocumentsOrderingFilter(OrderingFilter):
|
|||||||
except CustomField.DoesNotExist:
|
except CustomField.DoesNotExist:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
|
{self.prefix + str(custom_field_id): [_("Custom field not found")]},
|
||||||
)
|
) from None
|
||||||
|
|
||||||
annotation = None
|
annotation = None
|
||||||
match field.data_type:
|
match field.data_type:
|
||||||
|
|||||||
@@ -169,10 +169,6 @@ class FileStabilityTracker:
|
|||||||
self._tracked.pop(path, None)
|
self._tracked.pop(path, None)
|
||||||
yield path
|
yield path
|
||||||
|
|
||||||
def is_tracking(self, path: Path) -> bool:
|
|
||||||
"""Check whether a path is currently being tracked for stability."""
|
|
||||||
return path.resolve() in self._tracked
|
|
||||||
|
|
||||||
def has_pending_files(self) -> bool:
|
def has_pending_files(self) -> bool:
|
||||||
"""Check if there are files waiting for stability check."""
|
"""Check if there are files waiting for stability check."""
|
||||||
return len(self._tracked) > 0
|
return len(self._tracked) > 0
|
||||||
@@ -374,16 +370,6 @@ class Command(BaseCommand):
|
|||||||
# Testing timeout in seconds
|
# Testing timeout in seconds
|
||||||
testing_timeout_s: Final[float] = 0.5
|
testing_timeout_s: Final[float] = 0.5
|
||||||
|
|
||||||
# How often to perform a full-glob rescan of the consume directory as a
|
|
||||||
# safety net. Each watchfiles watcher is torn down and recreated on every
|
|
||||||
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
|
|
||||||
# current directory contents as its baseline. A file that appears between
|
|
||||||
# one batch and the next watcher's baseline is therefore never reported and
|
|
||||||
# would sit in the consume directory forever. This periodic rescan re-injects
|
|
||||||
# such files into the stability tracker (see GH issue #13011). Not currently
|
|
||||||
# user-configurable; instances may override for testing.
|
|
||||||
rescan_interval_s: float = 300.0
|
|
||||||
|
|
||||||
def add_arguments(self, parser) -> None:
|
def add_arguments(self, parser) -> None:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"directory",
|
"directory",
|
||||||
@@ -439,7 +425,7 @@ class Command(BaseCommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Process existing files
|
# Process existing files
|
||||||
queued = self._process_existing_files(
|
self._process_existing_files(
|
||||||
directory=directory,
|
directory=directory,
|
||||||
recursive=recursive,
|
recursive=recursive,
|
||||||
subdirs_as_tags=subdirs_as_tags,
|
subdirs_as_tags=subdirs_as_tags,
|
||||||
@@ -459,7 +445,6 @@ class Command(BaseCommand):
|
|||||||
polling_interval=polling_interval,
|
polling_interval=polling_interval,
|
||||||
stability_delay=stability_delay,
|
stability_delay=stability_delay,
|
||||||
is_testing=is_testing,
|
is_testing=is_testing,
|
||||||
queued=queued,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Consumer exiting")
|
logger.debug("Consumer exiting")
|
||||||
@@ -471,18 +456,11 @@ class Command(BaseCommand):
|
|||||||
recursive: bool,
|
recursive: bool,
|
||||||
subdirs_as_tags: bool,
|
subdirs_as_tags: bool,
|
||||||
consumer_filter: ConsumerFilter,
|
consumer_filter: ConsumerFilter,
|
||||||
) -> set[Path]:
|
) -> None:
|
||||||
"""
|
"""Process any existing files in the consumption directory."""
|
||||||
Process any existing files in the consumption directory.
|
|
||||||
|
|
||||||
Returns the set of resolved paths that were queued, so the watch loop
|
|
||||||
can seed its in-flight set and avoid re-queuing them on the first
|
|
||||||
rescan before the consume tasks have removed them from disk.
|
|
||||||
"""
|
|
||||||
logger.info(f"Processing existing files in {directory}")
|
logger.info(f"Processing existing files in {directory}")
|
||||||
|
|
||||||
glob_pattern = "**/*" if recursive else "*"
|
glob_pattern = "**/*" if recursive else "*"
|
||||||
queued: set[Path] = set()
|
|
||||||
|
|
||||||
for filepath in directory.glob(glob_pattern):
|
for filepath in directory.glob(glob_pattern):
|
||||||
# Use filter to check if file should be processed
|
# Use filter to check if file should be processed
|
||||||
@@ -497,48 +475,6 @@ class Command(BaseCommand):
|
|||||||
consumption_dir=directory,
|
consumption_dir=directory,
|
||||||
subdirs_as_tags=subdirs_as_tags,
|
subdirs_as_tags=subdirs_as_tags,
|
||||||
)
|
)
|
||||||
queued.add(filepath.resolve())
|
|
||||||
|
|
||||||
return queued
|
|
||||||
|
|
||||||
def _rescan_existing_files(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
directory: Path,
|
|
||||||
recursive: bool,
|
|
||||||
consumer_filter: ConsumerFilter,
|
|
||||||
tracker: FileStabilityTracker,
|
|
||||||
queued: set[Path],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Re-inject on-disk files the watcher never reported into the tracker.
|
|
||||||
|
|
||||||
Acts as a safety net for files stranded by the watcher-recreation gap
|
|
||||||
(see ``rescan_interval_s``). Files already being tracked or already
|
|
||||||
queued and awaiting consumption are skipped, so a file is never queued
|
|
||||||
twice. Queued paths that have since left the directory are pruned so a
|
|
||||||
later file reusing the same name is not skipped forever.
|
|
||||||
"""
|
|
||||||
# Prune in-flight paths that have left the directory
|
|
||||||
for path in list(queued):
|
|
||||||
if not path.exists():
|
|
||||||
queued.discard(path)
|
|
||||||
|
|
||||||
glob_pattern = "**/*" if recursive else "*"
|
|
||||||
|
|
||||||
for filepath in directory.glob(glob_pattern):
|
|
||||||
if not filepath.is_file():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not consumer_filter(Change.added, str(filepath)):
|
|
||||||
continue
|
|
||||||
|
|
||||||
resolved = filepath.resolve()
|
|
||||||
if tracker.is_tracking(resolved) or resolved in queued:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.debug(f"Rescan found untracked file: {resolved}")
|
|
||||||
tracker.track(resolved, Change.added)
|
|
||||||
|
|
||||||
def _watch_directory(
|
def _watch_directory(
|
||||||
self,
|
self,
|
||||||
@@ -550,24 +486,11 @@ class Command(BaseCommand):
|
|||||||
polling_interval: float,
|
polling_interval: float,
|
||||||
stability_delay: float,
|
stability_delay: float,
|
||||||
is_testing: bool,
|
is_testing: bool,
|
||||||
queued: set[Path] | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Watch directory for changes and process stable files."""
|
"""Watch directory for changes and process stable files."""
|
||||||
use_polling = polling_interval > 0
|
use_polling = polling_interval > 0
|
||||||
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
|
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
|
||||||
|
|
||||||
# Resolved paths that have been queued and are awaiting consumption.
|
|
||||||
# Seeded from the startup scan so the first rescan does not re-queue
|
|
||||||
# files whose consume tasks have not yet removed them from disk.
|
|
||||||
queued = set() if queued is None else queued
|
|
||||||
|
|
||||||
# Full-glob safety net cadence (0 disables)
|
|
||||||
rescan_interval_s = self.rescan_interval_s
|
|
||||||
rescan_timeout_ms = (
|
|
||||||
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
|
|
||||||
)
|
|
||||||
last_rescan = monotonic()
|
|
||||||
|
|
||||||
if use_polling:
|
if use_polling:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Watching {directory} using polling (interval: {polling_interval}s)",
|
f"Watching {directory} using polling (interval: {polling_interval}s)",
|
||||||
@@ -582,20 +505,6 @@ class Command(BaseCommand):
|
|||||||
stability_timeout_ms = int(stability_delay * 1000)
|
stability_timeout_ms = int(stability_delay * 1000)
|
||||||
testing_timeout_ms = int(self.testing_timeout_s * 1000)
|
testing_timeout_ms = int(self.testing_timeout_s * 1000)
|
||||||
|
|
||||||
def cap_for_rescan(ms: int) -> int:
|
|
||||||
"""
|
|
||||||
Ensure the watch loop wakes often enough to run the rescan.
|
|
||||||
|
|
||||||
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
|
|
||||||
only run that often. A timeout of 0 means "wait indefinitely",
|
|
||||||
which would never wake to rescan; cap it at the rescan interval.
|
|
||||||
"""
|
|
||||||
if rescan_timeout_ms <= 0:
|
|
||||||
return ms
|
|
||||||
if ms <= 0:
|
|
||||||
return rescan_timeout_ms
|
|
||||||
return min(ms, rescan_timeout_ms)
|
|
||||||
|
|
||||||
# Calculate appropriate timeout for watch loop
|
# Calculate appropriate timeout for watch loop
|
||||||
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
|
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
|
||||||
# to ensure poll cycles can complete before timing out
|
# to ensure poll cycles can complete before timing out
|
||||||
@@ -613,8 +522,6 @@ class Command(BaseCommand):
|
|||||||
# Not testing, wait indefinitely for first event
|
# Not testing, wait indefinitely for first event
|
||||||
timeout_ms = 0
|
timeout_ms = 0
|
||||||
|
|
||||||
timeout_ms = cap_for_rescan(timeout_ms)
|
|
||||||
|
|
||||||
self.stop_flag.clear()
|
self.stop_flag.clear()
|
||||||
|
|
||||||
while not self.stop_flag.is_set():
|
while not self.stop_flag.is_set():
|
||||||
@@ -644,26 +551,10 @@ class Command(BaseCommand):
|
|||||||
consumption_dir=directory,
|
consumption_dir=directory,
|
||||||
subdirs_as_tags=subdirs_as_tags,
|
subdirs_as_tags=subdirs_as_tags,
|
||||||
)
|
)
|
||||||
# Remember it so the rescan does not re-queue it while
|
|
||||||
# the consume task has yet to remove it from disk
|
|
||||||
queued.add(stable_path)
|
|
||||||
|
|
||||||
# Exit watch loop to reconfigure timeout
|
# Exit watch loop to reconfigure timeout
|
||||||
break
|
break
|
||||||
|
|
||||||
# Periodic full-glob safety net for files the watcher missed
|
|
||||||
if rescan_timeout_ms > 0 and (
|
|
||||||
monotonic() - last_rescan >= rescan_interval_s
|
|
||||||
):
|
|
||||||
self._rescan_existing_files(
|
|
||||||
directory=directory,
|
|
||||||
recursive=recursive,
|
|
||||||
consumer_filter=consumer_filter,
|
|
||||||
tracker=tracker,
|
|
||||||
queued=queued,
|
|
||||||
)
|
|
||||||
last_rescan = monotonic()
|
|
||||||
|
|
||||||
# Determine next timeout
|
# Determine next timeout
|
||||||
if tracker.has_pending_files():
|
if tracker.has_pending_files():
|
||||||
# Check pending files at stability interval
|
# Check pending files at stability interval
|
||||||
@@ -681,8 +572,6 @@ class Command(BaseCommand):
|
|||||||
# No pending files, wait indefinitely
|
# No pending files, wait indefinitely
|
||||||
timeout_ms = 0
|
timeout_ms = 0
|
||||||
|
|
||||||
timeout_ms = cap_for_rescan(timeout_ms)
|
|
||||||
|
|
||||||
except KeyboardInterrupt: # pragma: nocover
|
except KeyboardInterrupt: # pragma: nocover
|
||||||
logger.info("Received interrupt, stopping consumer")
|
logger.info("Received interrupt, stopping consumer")
|
||||||
self.stop_flag.set()
|
self.stop_flag.set()
|
||||||
|
|||||||
@@ -480,7 +480,7 @@ class Command(CryptMixin, PaperlessCommand):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 3. Export files from each document
|
# 3. Export files from each document
|
||||||
for index, document_dict in enumerate(
|
for _, document_dict in enumerate(
|
||||||
self.track(
|
self.track(
|
||||||
document_manifest,
|
document_manifest,
|
||||||
description="Exporting documents...",
|
description="Exporting documents...",
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from typing import Any
|
|||||||
|
|
||||||
from documents.management.commands.base import PaperlessCommand
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.tasks import llmindex_index
|
from documents.tasks import llmindex_index
|
||||||
from paperless_ai.indexing import llm_index_compact
|
|
||||||
|
|
||||||
|
|
||||||
class Command(PaperlessCommand):
|
class Command(PaperlessCommand):
|
||||||
@@ -13,12 +12,9 @@ class Command(PaperlessCommand):
|
|||||||
|
|
||||||
def add_arguments(self, parser: Any) -> None:
|
def add_arguments(self, parser: Any) -> None:
|
||||||
super().add_arguments(parser)
|
super().add_arguments(parser)
|
||||||
parser.add_argument("command", choices=["rebuild", "update", "compact"])
|
parser.add_argument("command", choices=["rebuild", "update"])
|
||||||
|
|
||||||
def handle(self, *args: Any, **options: Any) -> None:
|
def handle(self, *args: Any, **options: Any) -> None:
|
||||||
if options["command"] == "compact":
|
|
||||||
llm_index_compact()
|
|
||||||
return
|
|
||||||
llmindex_index(
|
llmindex_index(
|
||||||
rebuild=options["command"] == "rebuild",
|
rebuild=options["command"] == "rebuild",
|
||||||
iter_wrapper=lambda docs: self.track(
|
iter_wrapper=lambda docs: self.track(
|
||||||
|
|||||||
@@ -133,11 +133,14 @@ def _build_suggestion_table(
|
|||||||
else:
|
else:
|
||||||
doc_cell = Text(f"{doc} [{doc.pk}]")
|
doc_cell = Text(f"{doc} [{doc.pk}]")
|
||||||
|
|
||||||
tag_parts: list[str] = []
|
tag_parts: list[str] = [
|
||||||
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name):
|
f"[green]+{tag.name}[/green]"
|
||||||
tag_parts.append(f"[green]+{tag.name}[/green]")
|
for tag in sorted(suggestion.tags_to_add, key=lambda t: t.name)
|
||||||
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name):
|
]
|
||||||
tag_parts.append(f"[red]-{tag.name}[/red]")
|
tag_parts.extend(
|
||||||
|
f"[red]-{tag.name}[/red]"
|
||||||
|
for tag in sorted(suggestion.tags_to_remove, key=lambda t: t.name)
|
||||||
|
)
|
||||||
tag_cell = Text.from_markup(", ".join(tag_parts)) if tag_parts else Text("-")
|
tag_cell = Text.from_markup(", ".join(tag_parts)) if tag_parts else Text("-")
|
||||||
|
|
||||||
table.add_row(
|
table.add_row(
|
||||||
|
|||||||
-63
@@ -1,63 +0,0 @@
|
|||||||
# Generated by Django 5.2.14 on 2026-06-04 15:31
|
|
||||||
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
replaces = [
|
|
||||||
("documents", "0003_remove_document_storage_type"),
|
|
||||||
("documents", "0004_workflowtrigger_filter_has_any_correspondents_and_more"),
|
|
||||||
("documents", "0005_alter_document_checksum_unique"),
|
|
||||||
]
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("documents", "0002_squashed"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.RemoveField(
|
|
||||||
model_name="document",
|
|
||||||
name="storage_type",
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="workflowtrigger",
|
|
||||||
name="filter_has_any_correspondents",
|
|
||||||
field=models.ManyToManyField(
|
|
||||||
blank=True,
|
|
||||||
related_name="workflowtriggers_has_any_correspondent",
|
|
||||||
to="documents.correspondent",
|
|
||||||
verbose_name="has one of these correspondents",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="workflowtrigger",
|
|
||||||
name="filter_has_any_document_types",
|
|
||||||
field=models.ManyToManyField(
|
|
||||||
blank=True,
|
|
||||||
related_name="workflowtriggers_has_any_document_type",
|
|
||||||
to="documents.documenttype",
|
|
||||||
verbose_name="has one of these document types",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="workflowtrigger",
|
|
||||||
name="filter_has_any_storage_paths",
|
|
||||||
field=models.ManyToManyField(
|
|
||||||
blank=True,
|
|
||||||
related_name="workflowtriggers_has_any_storage_path",
|
|
||||||
to="documents.storagepath",
|
|
||||||
verbose_name="has one of these storage paths",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="document",
|
|
||||||
name="checksum",
|
|
||||||
field=models.CharField(
|
|
||||||
editable=False,
|
|
||||||
help_text="The checksum of the original document.",
|
|
||||||
max_length=32,
|
|
||||||
verbose_name="checksum",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
-252
@@ -1,252 +0,0 @@
|
|||||||
# Generated by Django 5.2.14 on 2026-06-04 15:31
|
|
||||||
|
|
||||||
import django.db.models.deletion
|
|
||||||
import django.db.models.functions.text
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
replaces = [
|
|
||||||
("documents", "0008_workflowaction_passwords_alter_workflowaction_type"),
|
|
||||||
("documents", "0009_alter_document_content_length"),
|
|
||||||
("documents", "0010_optimize_integer_field_sizes"),
|
|
||||||
("documents", "0011_alter_workflowaction_type"),
|
|
||||||
("documents", "0012_document_root_document"),
|
|
||||||
]
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("documents", "0007_sharelinkbundle"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="workflowaction",
|
|
||||||
name="passwords",
|
|
||||||
field=models.JSONField(
|
|
||||||
blank=True,
|
|
||||||
help_text="Passwords to try when removing PDF protection. Separate with commas or new lines.",
|
|
||||||
null=True,
|
|
||||||
verbose_name="passwords",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="document",
|
|
||||||
name="content_length",
|
|
||||||
field=models.GeneratedField(
|
|
||||||
db_persist=True,
|
|
||||||
expression=django.db.models.functions.text.Length("content"),
|
|
||||||
help_text="Length of the content field in characters. Automatically maintained by the database for faster statistics computation.",
|
|
||||||
output_field=models.PositiveIntegerField(default=0),
|
|
||||||
serialize=False,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="correspondent",
|
|
||||||
name="matching_algorithm",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "None"),
|
|
||||||
(1, "Any word"),
|
|
||||||
(2, "All words"),
|
|
||||||
(3, "Exact match"),
|
|
||||||
(4, "Regular expression"),
|
|
||||||
(5, "Fuzzy word"),
|
|
||||||
(6, "Automatic"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="matching algorithm",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="documenttype",
|
|
||||||
name="matching_algorithm",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "None"),
|
|
||||||
(1, "Any word"),
|
|
||||||
(2, "All words"),
|
|
||||||
(3, "Exact match"),
|
|
||||||
(4, "Regular expression"),
|
|
||||||
(5, "Fuzzy word"),
|
|
||||||
(6, "Automatic"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="matching algorithm",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="savedviewfilterrule",
|
|
||||||
name="rule_type",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "title contains"),
|
|
||||||
(1, "content contains"),
|
|
||||||
(2, "ASN is"),
|
|
||||||
(3, "correspondent is"),
|
|
||||||
(4, "document type is"),
|
|
||||||
(5, "is in inbox"),
|
|
||||||
(6, "has tag"),
|
|
||||||
(7, "has any tag"),
|
|
||||||
(8, "created before"),
|
|
||||||
(9, "created after"),
|
|
||||||
(10, "created year is"),
|
|
||||||
(11, "created month is"),
|
|
||||||
(12, "created day is"),
|
|
||||||
(13, "added before"),
|
|
||||||
(14, "added after"),
|
|
||||||
(15, "modified before"),
|
|
||||||
(16, "modified after"),
|
|
||||||
(17, "does not have tag"),
|
|
||||||
(18, "does not have ASN"),
|
|
||||||
(19, "title or content contains"),
|
|
||||||
(20, "fulltext query"),
|
|
||||||
(21, "more like this"),
|
|
||||||
(22, "has tags in"),
|
|
||||||
(23, "ASN greater than"),
|
|
||||||
(24, "ASN less than"),
|
|
||||||
(25, "storage path is"),
|
|
||||||
(26, "has correspondent in"),
|
|
||||||
(27, "does not have correspondent in"),
|
|
||||||
(28, "has document type in"),
|
|
||||||
(29, "does not have document type in"),
|
|
||||||
(30, "has storage path in"),
|
|
||||||
(31, "does not have storage path in"),
|
|
||||||
(32, "owner is"),
|
|
||||||
(33, "has owner in"),
|
|
||||||
(34, "does not have owner"),
|
|
||||||
(35, "does not have owner in"),
|
|
||||||
(36, "has custom field value"),
|
|
||||||
(37, "is shared by me"),
|
|
||||||
(38, "has custom fields"),
|
|
||||||
(39, "has custom field in"),
|
|
||||||
(40, "does not have custom field in"),
|
|
||||||
(41, "does not have custom field"),
|
|
||||||
(42, "custom fields query"),
|
|
||||||
(43, "created to"),
|
|
||||||
(44, "created from"),
|
|
||||||
(45, "added to"),
|
|
||||||
(46, "added from"),
|
|
||||||
(47, "mime type is"),
|
|
||||||
],
|
|
||||||
verbose_name="rule type",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="storagepath",
|
|
||||||
name="matching_algorithm",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "None"),
|
|
||||||
(1, "Any word"),
|
|
||||||
(2, "All words"),
|
|
||||||
(3, "Exact match"),
|
|
||||||
(4, "Regular expression"),
|
|
||||||
(5, "Fuzzy word"),
|
|
||||||
(6, "Automatic"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="matching algorithm",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="tag",
|
|
||||||
name="matching_algorithm",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "None"),
|
|
||||||
(1, "Any word"),
|
|
||||||
(2, "All words"),
|
|
||||||
(3, "Exact match"),
|
|
||||||
(4, "Regular expression"),
|
|
||||||
(5, "Fuzzy word"),
|
|
||||||
(6, "Automatic"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="matching algorithm",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="workflowrun",
|
|
||||||
name="type",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(1, "Consumption Started"),
|
|
||||||
(2, "Document Added"),
|
|
||||||
(3, "Document Updated"),
|
|
||||||
(4, "Scheduled"),
|
|
||||||
],
|
|
||||||
null=True,
|
|
||||||
verbose_name="workflow trigger type",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="workflowtrigger",
|
|
||||||
name="matching_algorithm",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(0, "None"),
|
|
||||||
(1, "Any word"),
|
|
||||||
(2, "All words"),
|
|
||||||
(3, "Exact match"),
|
|
||||||
(4, "Regular expression"),
|
|
||||||
(5, "Fuzzy word"),
|
|
||||||
],
|
|
||||||
default=0,
|
|
||||||
verbose_name="matching algorithm",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="workflowtrigger",
|
|
||||||
name="type",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(1, "Consumption Started"),
|
|
||||||
(2, "Document Added"),
|
|
||||||
(3, "Document Updated"),
|
|
||||||
(4, "Scheduled"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="Workflow Trigger Type",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="workflowaction",
|
|
||||||
name="type",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
choices=[
|
|
||||||
(1, "Assignment"),
|
|
||||||
(2, "Removal"),
|
|
||||||
(3, "Email"),
|
|
||||||
(4, "Webhook"),
|
|
||||||
(5, "Password removal"),
|
|
||||||
(6, "Move to trash"),
|
|
||||||
],
|
|
||||||
default=1,
|
|
||||||
verbose_name="Workflow Action Type",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="document",
|
|
||||||
name="root_document",
|
|
||||||
field=models.ForeignKey(
|
|
||||||
blank=True,
|
|
||||||
null=True,
|
|
||||||
on_delete=django.db.models.deletion.CASCADE,
|
|
||||||
related_name="versions",
|
|
||||||
to="documents.document",
|
|
||||||
verbose_name="root document for this version",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="document",
|
|
||||||
name="version_label",
|
|
||||||
field=models.CharField(
|
|
||||||
blank=True,
|
|
||||||
help_text="Optional short label for a document version.",
|
|
||||||
max_length=64,
|
|
||||||
null=True,
|
|
||||||
verbose_name="version label",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@@ -369,7 +369,7 @@ class Document(SoftDeleteModel, ModelWithOwner): # type: ignore[django-manager-
|
|||||||
If the queryset already annotated ``effective_content``, that value is used.
|
If the queryset already annotated ``effective_content``, that value is used.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, "effective_content"):
|
if hasattr(self, "effective_content"):
|
||||||
return getattr(self, "effective_content")
|
return self.effective_content
|
||||||
|
|
||||||
if self.root_document_id is not None or self.pk is None:
|
if self.root_document_id is not None or self.pk is None:
|
||||||
return self.content
|
return self.content
|
||||||
@@ -1204,8 +1204,8 @@ class CustomFieldInstance(SoftDeleteModel):
|
|||||||
def get_value_field_name(cls, data_type: CustomField.FieldDataType):
|
def get_value_field_name(cls, data_type: CustomField.FieldDataType):
|
||||||
try:
|
try:
|
||||||
return cls.TYPE_TO_DATA_STORE_NAME_MAP[data_type]
|
return cls.TYPE_TO_DATA_STORE_NAME_MAP[data_type]
|
||||||
except KeyError: # pragma: no cover
|
except KeyError as exc: # pragma: no cover
|
||||||
raise NotImplementedError(data_type)
|
raise NotImplementedError(data_type) from exc
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value(self):
|
def value(self):
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ def run_convert(
|
|||||||
args += ["-define", "pdf:use-cropbox=true"] if use_cropbox else []
|
args += ["-define", "pdf:use-cropbox=true"] if use_cropbox else []
|
||||||
args += [str(input_file), str(output_file)]
|
args += [str(input_file), str(output_file)]
|
||||||
|
|
||||||
logger.debug("Execute: " + " ".join(args), extra={"group": logging_group})
|
logger.debug("Execute: %s", " ".join(args), extra={"group": logging_group})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
run_subprocess(args, environment, logger)
|
run_subprocess(args, environment, logger)
|
||||||
|
|||||||
@@ -67,8 +67,7 @@ class DateParserPluginBase(ABC):
|
|||||||
|
|
||||||
Subclasses can override this to release resources.
|
Subclasses can override this to release resources.
|
||||||
"""
|
"""
|
||||||
# Default implementation does nothing.
|
return None
|
||||||
# Returning None implies exceptions are propagated.
|
|
||||||
|
|
||||||
def _parse_string(
|
def _parse_string(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -8,15 +8,11 @@ from documents.search._backend import get_backend
|
|||||||
from documents.search._backend import reset_backend
|
from documents.search._backend import reset_backend
|
||||||
from documents.search._schema import needs_rebuild
|
from documents.search._schema import needs_rebuild
|
||||||
from documents.search._schema import wipe_index
|
from documents.search._schema import wipe_index
|
||||||
from documents.search._translate import InvalidDateQuery
|
|
||||||
from documents.search._translate import SearchQueryError
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InvalidDateQuery",
|
|
||||||
"SearchHit",
|
"SearchHit",
|
||||||
"SearchIndexLockError",
|
"SearchIndexLockError",
|
||||||
"SearchMode",
|
"SearchMode",
|
||||||
"SearchQueryError",
|
|
||||||
"TantivyBackend",
|
"TantivyBackend",
|
||||||
"TantivyRelevanceList",
|
"TantivyRelevanceList",
|
||||||
"WriteBatch",
|
"WriteBatch",
|
||||||
|
|||||||
@@ -195,12 +195,12 @@ class WriteBatch:
|
|||||||
try:
|
try:
|
||||||
self._lock.acquire(timeout=self._lock_timeout)
|
self._lock.acquire(timeout=self._lock_timeout)
|
||||||
break
|
break
|
||||||
except filelock.Timeout:
|
except filelock.Timeout as exc:
|
||||||
if attempt == _LOCK_RETRY_ATTEMPTS - 1:
|
if attempt == _LOCK_RETRY_ATTEMPTS - 1:
|
||||||
raise SearchIndexLockError(
|
raise SearchIndexLockError(
|
||||||
f"Could not acquire index lock after {_LOCK_RETRY_ATTEMPTS} "
|
f"Could not acquire index lock after {_LOCK_RETRY_ATTEMPTS} "
|
||||||
f"attempts (timeout={self._lock_timeout}s each)",
|
f"attempts (timeout={self._lock_timeout}s each)",
|
||||||
)
|
) from exc
|
||||||
sleep_s = random.uniform(
|
sleep_s = random.uniform(
|
||||||
0,
|
0,
|
||||||
min(_LOCK_BACKOFF_CAP, _LOCK_BACKOFF_BASE * (2**attempt)),
|
min(_LOCK_BACKOFF_CAP, _LOCK_BACKOFF_BASE * (2**attempt)),
|
||||||
@@ -651,7 +651,11 @@ class TantivyBackend:
|
|||||||
result_ids = cast("list[int]", searcher.fast_field_values("id", result_addrs))
|
result_ids = cast("list[int]", searcher.fast_field_values("id", result_addrs))
|
||||||
addr_by_id: dict[int, tuple[float, tantivy.DocAddress]] = {
|
addr_by_id: dict[int, tuple[float, tantivy.DocAddress]] = {
|
||||||
doc_id: (score, addr)
|
doc_id: (score, addr)
|
||||||
for (score, addr), doc_id in zip(batch_results.hits, result_ids)
|
for (score, addr), doc_id in zip(
|
||||||
|
batch_results.hits,
|
||||||
|
result_ids,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
snippet_generator = None
|
snippet_generator = None
|
||||||
@@ -866,24 +870,8 @@ class TantivyBackend:
|
|||||||
final_query = self._apply_permission_filter(mlt_query, user)
|
final_query = self._apply_permission_filter(mlt_query, user)
|
||||||
|
|
||||||
effective_limit = limit if limit is not None else searcher.num_docs
|
effective_limit = limit if limit is not None else searcher.num_docs
|
||||||
try:
|
# Fetch one extra to account for excluding the original document
|
||||||
# Fetch one extra to account for excluding the original document
|
results = searcher.search(final_query, limit=effective_limit + 1)
|
||||||
results = searcher.search(final_query, limit=effective_limit + 1)
|
|
||||||
except BaseException: # pragma: no cover
|
|
||||||
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
|
|
||||||
# soft-deleted documents (doc_freq can exceed the alive doc count),
|
|
||||||
# which only surfaces for the More Like This query. The panic crosses
|
|
||||||
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
|
|
||||||
# BaseException, not an Exception — so catch BaseException and degrade
|
|
||||||
# to "no similar documents" instead of bubbling a 500 to the client.
|
|
||||||
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
|
|
||||||
# Remove once the bundled tantivy includes that fix.
|
|
||||||
logger.warning(
|
|
||||||
"More Like This scoring panicked (likely stale tantivy segment "
|
|
||||||
"stats after deletions); returning no results. A search index "
|
|
||||||
"reindex will rebuild consistent statistics.",
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
addrs = [addr for _score, addr in results.hits]
|
addrs = [addr for _score, addr in results.hits]
|
||||||
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
|
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
|
||||||
|
|||||||
@@ -1,163 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import UTC
|
|
||||||
from datetime import date
|
|
||||||
from datetime import datetime
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from typing import Final
|
|
||||||
|
|
||||||
from dateutil.relativedelta import relativedelta
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from datetime import tzinfo
|
|
||||||
|
|
||||||
_DATE_ONLY_FIELDS = frozenset({"created"})
|
|
||||||
|
|
||||||
_TODAY: Final[str] = "today"
|
|
||||||
_YESTERDAY: Final[str] = "yesterday"
|
|
||||||
_PREVIOUS_WEEK: Final[str] = "previous week"
|
|
||||||
_THIS_MONTH: Final[str] = "this month"
|
|
||||||
_PREVIOUS_MONTH: Final[str] = "previous month"
|
|
||||||
_THIS_YEAR: Final[str] = "this year"
|
|
||||||
_PREVIOUS_YEAR: Final[str] = "previous year"
|
|
||||||
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
|
|
||||||
|
|
||||||
_DATE_KEYWORDS = frozenset(
|
|
||||||
{
|
|
||||||
_TODAY,
|
|
||||||
_YESTERDAY,
|
|
||||||
_PREVIOUS_WEEK,
|
|
||||||
_THIS_MONTH,
|
|
||||||
_PREVIOUS_MONTH,
|
|
||||||
_THIS_YEAR,
|
|
||||||
_PREVIOUS_YEAR,
|
|
||||||
_PREVIOUS_QUARTER,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt(dt: datetime) -> str:
|
|
||||||
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
|
|
||||||
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
|
|
||||||
|
|
||||||
|
|
||||||
def _iso_range(lo: datetime, hi: datetime) -> str:
|
|
||||||
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
|
|
||||||
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
|
||||||
|
|
||||||
|
|
||||||
def _quarter_start(d: date) -> date:
|
|
||||||
"""Return the first day of the calendar quarter containing ``d``."""
|
|
||||||
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def _midnight(d: date, tz: tzinfo) -> datetime:
|
|
||||||
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
|
|
||||||
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
|
||||||
|
|
||||||
|
|
||||||
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
|
|
||||||
"""
|
|
||||||
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
|
|
||||||
|
|
||||||
``tz`` only determines what "today" is; the caller decides how the returned
|
|
||||||
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
|
|
||||||
"""
|
|
||||||
today = datetime.now(tz).date()
|
|
||||||
if keyword == _TODAY:
|
|
||||||
return today, today + timedelta(days=1)
|
|
||||||
if keyword == _YESTERDAY:
|
|
||||||
return today - timedelta(days=1), today
|
|
||||||
if keyword == _PREVIOUS_WEEK:
|
|
||||||
this_monday = today - timedelta(days=today.weekday())
|
|
||||||
return this_monday - timedelta(weeks=1), this_monday
|
|
||||||
if keyword == _THIS_MONTH:
|
|
||||||
first = today.replace(day=1)
|
|
||||||
return first, first + relativedelta(months=1)
|
|
||||||
if keyword == _PREVIOUS_MONTH:
|
|
||||||
this_first = today.replace(day=1)
|
|
||||||
return this_first - relativedelta(months=1), this_first
|
|
||||||
if keyword == _THIS_YEAR:
|
|
||||||
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
|
|
||||||
if keyword == _PREVIOUS_YEAR:
|
|
||||||
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
|
|
||||||
if keyword == _PREVIOUS_QUARTER:
|
|
||||||
this_quarter = _quarter_start(today)
|
|
||||||
return this_quarter - relativedelta(months=3), this_quarter
|
|
||||||
raise ValueError(f"Unknown keyword: {keyword}")
|
|
||||||
|
|
||||||
|
|
||||||
def _date_only_range(keyword: str, tz: tzinfo) -> str:
|
|
||||||
"""
|
|
||||||
For `created` (DateField): use the local calendar date, converted to
|
|
||||||
midnight UTC boundaries. No offset arithmetic — date only.
|
|
||||||
"""
|
|
||||||
start, end = _keyword_bounds(keyword, tz)
|
|
||||||
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
|
|
||||||
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
|
|
||||||
return _iso_range(lo, hi)
|
|
||||||
|
|
||||||
|
|
||||||
def _datetime_range(keyword: str, tz: tzinfo) -> str:
|
|
||||||
"""
|
|
||||||
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
|
|
||||||
boundaries to UTC — full offset arithmetic required.
|
|
||||||
"""
|
|
||||||
start, end = _keyword_bounds(keyword, tz)
|
|
||||||
return _iso_range(_midnight(start, tz), _midnight(end, tz))
|
|
||||||
|
|
||||||
|
|
||||||
def _precision_bounds(digits: str) -> tuple[date, date] | None:
|
|
||||||
"""
|
|
||||||
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
|
|
||||||
|
|
||||||
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
|
|
||||||
Returns None for any unparsable or out-of-range value (e.g. month 23),
|
|
||||||
so callers can emit a no-match clause instead of erroring (Whoosh parity).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if len(digits) == 4:
|
|
||||||
year = int(digits)
|
|
||||||
return date(year, 1, 1), date(year + 1, 1, 1)
|
|
||||||
if len(digits) == 6:
|
|
||||||
year, month = int(digits[:4]), int(digits[4:6])
|
|
||||||
start = date(year, month, 1)
|
|
||||||
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
|
|
||||||
return start, end
|
|
||||||
if len(digits) == 8:
|
|
||||||
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
|
|
||||||
return start, start + timedelta(days=1)
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _utc_bounds_for_field(
|
|
||||||
field: str,
|
|
||||||
start: date,
|
|
||||||
end: date,
|
|
||||||
tz: tzinfo,
|
|
||||||
) -> tuple[datetime, datetime]:
|
|
||||||
"""
|
|
||||||
Convert calendar-date bounds to UTC datetimes per the field's storage type.
|
|
||||||
|
|
||||||
For DateField (``created``) the bounds are UTC midnight (no offset). For
|
|
||||||
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
|
|
||||||
converted to UTC, matching how each field is indexed.
|
|
||||||
"""
|
|
||||||
if field in _DATE_ONLY_FIELDS:
|
|
||||||
return (
|
|
||||||
datetime(start.year, start.month, start.day, tzinfo=UTC),
|
|
||||||
datetime(end.year, end.month, end.day, tzinfo=UTC),
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
|
|
||||||
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
|
|
||||||
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
|
|
||||||
lo, hi = _utc_bounds_for_field(field, start, end, tz)
|
|
||||||
return f"{field}:{_iso_range(lo, hi)}"
|
|
||||||
+409
-27
@@ -1,35 +1,88 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import UTC
|
from datetime import UTC
|
||||||
|
from datetime import date
|
||||||
|
from datetime import datetime
|
||||||
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
import regex
|
import regex
|
||||||
import tantivy
|
import tantivy
|
||||||
|
from dateutil.relativedelta import relativedelta
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
from documents.search._dates import (
|
|
||||||
_date_only_range, # noqa: F401 — re-exported for test imports
|
|
||||||
)
|
|
||||||
from documents.search._dates import (
|
|
||||||
_datetime_range, # noqa: F401 — re-exported for test imports
|
|
||||||
)
|
|
||||||
from documents.search._tokenizer import simple_search_tokens
|
from documents.search._tokenizer import simple_search_tokens
|
||||||
from documents.search._translate import SearchQueryError
|
|
||||||
from documents.search._translate import translate_query
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datetime import tzinfo
|
from datetime import tzinfo
|
||||||
|
|
||||||
from django.contrib.auth.base_user import AbstractBaseUser
|
from django.contrib.auth.base_user import AbstractBaseUser
|
||||||
|
|
||||||
logger = logging.getLogger("paperless.search")
|
|
||||||
|
|
||||||
# Maximum seconds any single regex substitution may run.
|
# Maximum seconds any single regex substitution may run.
|
||||||
# Prevents ReDoS on adversarial user-supplied query strings.
|
# Prevents ReDoS on adversarial user-supplied query strings.
|
||||||
_REGEX_TIMEOUT: Final[float] = 1.0
|
_REGEX_TIMEOUT: Final[float] = 1.0
|
||||||
|
|
||||||
|
_DATE_ONLY_FIELDS = frozenset({"created"})
|
||||||
|
|
||||||
|
_TODAY: Final[str] = "today"
|
||||||
|
_YESTERDAY: Final[str] = "yesterday"
|
||||||
|
_PREVIOUS_WEEK: Final[str] = "previous week"
|
||||||
|
_THIS_MONTH: Final[str] = "this month"
|
||||||
|
_PREVIOUS_MONTH: Final[str] = "previous month"
|
||||||
|
_THIS_YEAR: Final[str] = "this year"
|
||||||
|
_PREVIOUS_YEAR: Final[str] = "previous year"
|
||||||
|
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
|
||||||
|
|
||||||
|
_DATE_KEYWORDS = frozenset(
|
||||||
|
{
|
||||||
|
_TODAY,
|
||||||
|
_YESTERDAY,
|
||||||
|
_PREVIOUS_WEEK,
|
||||||
|
_THIS_MONTH,
|
||||||
|
_PREVIOUS_MONTH,
|
||||||
|
_THIS_YEAR,
|
||||||
|
_PREVIOUS_YEAR,
|
||||||
|
_PREVIOUS_QUARTER,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_DATE_KEYWORD_PATTERN = "|".join(
|
||||||
|
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
_FIELD_DATE_RE = regex.compile(
|
||||||
|
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
|
||||||
|
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
||||||
|
|
|
||||||
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
|
||||||
|
)""",
|
||||||
|
regex.IGNORECASE | regex.VERBOSE,
|
||||||
|
)
|
||||||
|
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
|
||||||
|
_RELATIVE_RANGE_RE = regex.compile(
|
||||||
|
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
|
||||||
|
regex.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
|
||||||
|
_WHOOSH_REL_RANGE_RE = regex.compile(
|
||||||
|
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
|
||||||
|
regex.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
|
||||||
|
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
|
||||||
|
_DATE8_RE = regex.compile(
|
||||||
|
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
|
||||||
|
)
|
||||||
|
_YEAR_RANGE_RE = regex.compile(
|
||||||
|
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
|
||||||
|
regex.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
|
||||||
|
# the NOT/MUST operators require no space between the operator and the term.
|
||||||
|
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
|
||||||
|
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
|
||||||
|
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
|
||||||
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
|
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
|
||||||
# Uses Unicode properties to cover all blocks including Extension B+ planes.
|
# Uses Unicode properties to cover all blocks including Extension B+ planes.
|
||||||
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
|
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
|
||||||
@@ -64,12 +117,305 @@ def _build_cjk_query(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt(dt: datetime) -> str:
|
||||||
|
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
|
||||||
|
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_range(lo: datetime, hi: datetime) -> str:
|
||||||
|
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
|
||||||
|
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _date_only_range(keyword: str, tz: tzinfo) -> str:
|
||||||
|
"""
|
||||||
|
For `created` (DateField): use the local calendar date, converted to
|
||||||
|
midnight UTC boundaries. No offset arithmetic — date only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
today = datetime.now(tz).date()
|
||||||
|
|
||||||
|
def _quarter_start(d: date) -> date:
|
||||||
|
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
||||||
|
|
||||||
|
if keyword == _TODAY:
|
||||||
|
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, lo + timedelta(days=1))
|
||||||
|
if keyword == _YESTERDAY:
|
||||||
|
y = today - timedelta(days=1)
|
||||||
|
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
|
||||||
|
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, hi)
|
||||||
|
if keyword == _PREVIOUS_WEEK:
|
||||||
|
this_mon = today - timedelta(days=today.weekday())
|
||||||
|
last_mon = this_mon - timedelta(weeks=1)
|
||||||
|
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
|
||||||
|
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, hi)
|
||||||
|
if keyword == _THIS_MONTH:
|
||||||
|
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
|
||||||
|
if today.month == 12:
|
||||||
|
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
|
||||||
|
else:
|
||||||
|
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, hi)
|
||||||
|
if keyword == _PREVIOUS_MONTH:
|
||||||
|
if today.month == 1:
|
||||||
|
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
|
||||||
|
else:
|
||||||
|
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
|
||||||
|
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, hi)
|
||||||
|
if keyword == _THIS_YEAR:
|
||||||
|
lo = datetime(today.year, 1, 1, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
|
||||||
|
if keyword == _PREVIOUS_YEAR:
|
||||||
|
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
|
||||||
|
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
|
||||||
|
if keyword == _PREVIOUS_QUARTER:
|
||||||
|
this_quarter = _quarter_start(today)
|
||||||
|
last_quarter = this_quarter - relativedelta(months=3)
|
||||||
|
lo = datetime(
|
||||||
|
last_quarter.year,
|
||||||
|
last_quarter.month,
|
||||||
|
last_quarter.day,
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
hi = datetime(
|
||||||
|
this_quarter.year,
|
||||||
|
this_quarter.month,
|
||||||
|
this_quarter.day,
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
return _iso_range(lo, hi)
|
||||||
|
raise ValueError(f"Unknown keyword: {keyword}")
|
||||||
|
|
||||||
|
|
||||||
|
def _datetime_range(keyword: str, tz: tzinfo) -> str:
|
||||||
|
"""
|
||||||
|
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
|
||||||
|
boundaries to UTC — full offset arithmetic required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
now_local = datetime.now(tz)
|
||||||
|
today = now_local.date()
|
||||||
|
|
||||||
|
def _midnight(d: date) -> datetime:
|
||||||
|
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||||
|
|
||||||
|
def _quarter_start(d: date) -> date:
|
||||||
|
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
||||||
|
|
||||||
|
if keyword == _TODAY:
|
||||||
|
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
|
||||||
|
if keyword == _YESTERDAY:
|
||||||
|
y = today - timedelta(days=1)
|
||||||
|
return _iso_range(_midnight(y), _midnight(today))
|
||||||
|
if keyword == _PREVIOUS_WEEK:
|
||||||
|
this_mon = today - timedelta(days=today.weekday())
|
||||||
|
last_mon = this_mon - timedelta(weeks=1)
|
||||||
|
return _iso_range(_midnight(last_mon), _midnight(this_mon))
|
||||||
|
if keyword == _THIS_MONTH:
|
||||||
|
first = today.replace(day=1)
|
||||||
|
if today.month == 12:
|
||||||
|
next_first = date(today.year + 1, 1, 1)
|
||||||
|
else:
|
||||||
|
next_first = date(today.year, today.month + 1, 1)
|
||||||
|
return _iso_range(_midnight(first), _midnight(next_first))
|
||||||
|
if keyword == _PREVIOUS_MONTH:
|
||||||
|
this_first = today.replace(day=1)
|
||||||
|
if today.month == 1:
|
||||||
|
last_first = date(today.year - 1, 12, 1)
|
||||||
|
else:
|
||||||
|
last_first = date(today.year, today.month - 1, 1)
|
||||||
|
return _iso_range(_midnight(last_first), _midnight(this_first))
|
||||||
|
if keyword == _THIS_YEAR:
|
||||||
|
return _iso_range(
|
||||||
|
_midnight(date(today.year, 1, 1)),
|
||||||
|
_midnight(date(today.year + 1, 1, 1)),
|
||||||
|
)
|
||||||
|
if keyword == _PREVIOUS_YEAR:
|
||||||
|
return _iso_range(
|
||||||
|
_midnight(date(today.year - 1, 1, 1)),
|
||||||
|
_midnight(date(today.year, 1, 1)),
|
||||||
|
)
|
||||||
|
if keyword == _PREVIOUS_QUARTER:
|
||||||
|
this_quarter = _quarter_start(today)
|
||||||
|
last_quarter = this_quarter - relativedelta(months=3)
|
||||||
|
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
|
||||||
|
raise ValueError(f"Unknown keyword: {keyword}")
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_compact_date(query: str) -> str:
|
||||||
|
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
|
||||||
|
|
||||||
|
def _sub(m: regex.Match[str]) -> str:
|
||||||
|
raw = m.group(1)
|
||||||
|
try:
|
||||||
|
dt = datetime(
|
||||||
|
int(raw[0:4]),
|
||||||
|
int(raw[4:6]),
|
||||||
|
int(raw[6:8]),
|
||||||
|
int(raw[8:10]),
|
||||||
|
int(raw[10:12]),
|
||||||
|
int(raw[12:14]),
|
||||||
|
tzinfo=UTC,
|
||||||
|
)
|
||||||
|
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
except ValueError:
|
||||||
|
return str(m.group(0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (compact date rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_relative_range(query: str) -> str:
|
||||||
|
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
|
||||||
|
|
||||||
|
def _sub(m: regex.Match[str]) -> str:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
def _offset(s: str | None) -> timedelta:
|
||||||
|
if not s:
|
||||||
|
return timedelta(0)
|
||||||
|
sign = 1 if s[0] == "+" else -1
|
||||||
|
n, unit = int(s[1:-1]), s[-1]
|
||||||
|
return (
|
||||||
|
sign
|
||||||
|
* {
|
||||||
|
"d": timedelta(days=n),
|
||||||
|
"h": timedelta(hours=n),
|
||||||
|
"m": timedelta(minutes=n),
|
||||||
|
}[unit]
|
||||||
|
)
|
||||||
|
|
||||||
|
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
|
||||||
|
if lo > hi:
|
||||||
|
lo, hi = hi, lo
|
||||||
|
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (relative range rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_whoosh_relative_range(query: str) -> str:
|
||||||
|
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
|
||||||
|
|
||||||
|
Supports: second, minute, hour, day, week, month, year (singular and plural).
|
||||||
|
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
def _sub(m: regex.Match[str]) -> str:
|
||||||
|
n = int(m.group("n"))
|
||||||
|
unit = m.group("unit").lower()
|
||||||
|
delta_map: dict[str, timedelta | relativedelta] = {
|
||||||
|
"second": timedelta(seconds=n),
|
||||||
|
"minute": timedelta(minutes=n),
|
||||||
|
"hour": timedelta(hours=n),
|
||||||
|
"day": timedelta(days=n),
|
||||||
|
"week": timedelta(weeks=n),
|
||||||
|
"month": relativedelta(months=n),
|
||||||
|
"year": relativedelta(years=n),
|
||||||
|
}
|
||||||
|
lo = now - delta_map[unit]
|
||||||
|
return f"[{_fmt(lo)} TO {_fmt(now)}]"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (Whoosh relative range rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
|
||||||
|
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
|
||||||
|
|
||||||
|
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
|
||||||
|
converted and won't spuriously match here.
|
||||||
|
|
||||||
|
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
|
||||||
|
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
|
||||||
|
midnight boundaries converted to UTC — matching the ``_datetime_range``
|
||||||
|
behaviour for keyword dates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _sub(m: regex.Match[str]) -> str:
|
||||||
|
field = m.group("field")
|
||||||
|
raw = m.group("date8")
|
||||||
|
try:
|
||||||
|
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
|
||||||
|
d = date(year, month, day)
|
||||||
|
if field in _DATE_ONLY_FIELDS:
|
||||||
|
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
|
||||||
|
hi = lo + timedelta(days=1)
|
||||||
|
else:
|
||||||
|
# DateTimeField: use local-timezone midnight → UTC
|
||||||
|
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||||
|
hi = datetime(
|
||||||
|
(d + timedelta(days=1)).year,
|
||||||
|
(d + timedelta(days=1)).month,
|
||||||
|
(d + timedelta(days=1)).day,
|
||||||
|
tzinfo=tz,
|
||||||
|
).astimezone(UTC)
|
||||||
|
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||||
|
except ValueError:
|
||||||
|
return m.group(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (8-digit date rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_year_range(query: str) -> str:
|
||||||
|
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
|
||||||
|
|
||||||
|
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
|
||||||
|
The upper bound is the start of the year after the end year (exclusive),
|
||||||
|
matching the Whoosh convention of treating year-only ranges as full-year spans.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _sub(m: regex.Match[str]) -> str:
|
||||||
|
field = m.group("field")
|
||||||
|
y1, y2 = int(m.group("y1")), int(m.group("y2"))
|
||||||
|
# Whoosh swaps a reversed range when both years are explicit
|
||||||
|
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
|
||||||
|
# range spans the intended years instead of matching nothing.
|
||||||
|
lo_year, hi_year = min(y1, y2), max(y1, y2)
|
||||||
|
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
|
||||||
|
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
|
||||||
|
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (year range rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
||||||
"""
|
"""
|
||||||
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
|
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
|
||||||
|
|
||||||
Delegates to ``translate_query`` which handles all date forms, comma
|
Performs the first stage of query preprocessing, converting various date
|
||||||
expansion, field aliasing, relative ranges, and operator normalization.
|
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
|
||||||
|
- Compact 14-digit dates (YYYYMMDDHHmmss)
|
||||||
|
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
|
||||||
|
- 8-digit dates with field awareness (created:20240115)
|
||||||
|
- Natural keywords (field:today, field:"previous quarter", etc.)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Raw user query string
|
query: Raw user query string
|
||||||
@@ -81,15 +427,35 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
|||||||
Note:
|
Note:
|
||||||
Bare keywords without field prefixes pass through unchanged.
|
Bare keywords without field prefixes pass through unchanged.
|
||||||
"""
|
"""
|
||||||
return translate_query(query, tz)
|
query = _rewrite_compact_date(query)
|
||||||
|
query = _rewrite_whoosh_relative_range(query)
|
||||||
|
query = _rewrite_year_range(query)
|
||||||
|
query = _rewrite_8digit_date(query, tz)
|
||||||
|
query = _rewrite_relative_range(query)
|
||||||
|
|
||||||
|
def _replace(m: regex.Match[str]) -> str:
|
||||||
|
field = m.group("field")
|
||||||
|
keyword = (m.group("quoted") or m.group("bare")).lower()
|
||||||
|
if field in _DATE_ONLY_FIELDS:
|
||||||
|
return f"{field}:{_date_only_range(keyword, tz)}"
|
||||||
|
return f"{field}:{_datetime_range(keyword, tz)}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (date keyword rewrite timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
def normalize_query(query: str) -> str:
|
def normalize_query(query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Normalize query syntax for better search behavior.
|
Normalize query syntax for better search behavior.
|
||||||
|
|
||||||
Delegates to ``translate_query`` which handles comma expansion, whitespace
|
Expands comma-separated field values to explicit AND clauses and
|
||||||
collapsing, operator normalization, and field aliasing.
|
collapses excessive whitespace for cleaner parsing:
|
||||||
|
- tag:foo,bar → tag:foo AND tag:bar
|
||||||
|
- multiple spaces → single spaces
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Query string after date rewriting
|
query: Query string after date rewriting
|
||||||
@@ -97,7 +463,31 @@ def normalize_query(query: str) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized query string ready for Tantivy parsing
|
Normalized query string ready for Tantivy parsing
|
||||||
"""
|
"""
|
||||||
return translate_query(query, UTC)
|
|
||||||
|
def _expand(m: regex.Match[str]) -> str:
|
||||||
|
field = m.group(1)
|
||||||
|
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
|
||||||
|
return " AND ".join(f"{field}:{v}" for v in values)
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = regex.sub(
|
||||||
|
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
|
||||||
|
_expand,
|
||||||
|
query,
|
||||||
|
timeout=_REGEX_TIMEOUT,
|
||||||
|
)
|
||||||
|
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
|
||||||
|
# Strip trailing dangling operators before Tantivy sees them.
|
||||||
|
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
|
||||||
|
# Replace " - " / " + " with a space: Tantivy requires no space between
|
||||||
|
# the operator and its operand (-term / +term), so spaces on both sides
|
||||||
|
# means this is a natural-language separator, not a query operator.
|
||||||
|
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
|
||||||
|
return query
|
||||||
|
except TimeoutError: # pragma: no cover
|
||||||
|
raise ValueError(
|
||||||
|
"Query too complex to process (normalization timed out)",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
def build_permission_filter(
|
def build_permission_filter(
|
||||||
@@ -217,16 +607,8 @@ def parse_user_query(
|
|||||||
as a post-search score filter, not during query construction.
|
as a post-search score filter, not during query construction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
query_str = rewrite_natural_date_keywords(raw_query, tz)
|
||||||
query_str = translate_query(raw_query, tz)
|
query_str = normalize_query(query_str)
|
||||||
except SearchQueryError:
|
|
||||||
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
|
|
||||||
# the view can return a 400 with a helpful message rather than falling
|
|
||||||
# back to the raw (still-invalid) query.
|
|
||||||
raise
|
|
||||||
except Exception: # pragma: no cover - defensive
|
|
||||||
logger.warning("Query translation failed; using raw query", exc_info=True)
|
|
||||||
query_str = raw_query
|
|
||||||
|
|
||||||
exact = index.parse_query(
|
exact = index.parse_query(
|
||||||
query_str,
|
query_str,
|
||||||
|
|||||||
@@ -1,566 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import UTC
|
|
||||||
from datetime import datetime
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
import regex
|
|
||||||
from dateutil.relativedelta import relativedelta
|
|
||||||
|
|
||||||
from documents.search._dates import _DATE_KEYWORDS
|
|
||||||
from documents.search._dates import _DATE_ONLY_FIELDS
|
|
||||||
from documents.search._dates import _date_only_range
|
|
||||||
from documents.search._dates import _datetime_range
|
|
||||||
from documents.search._dates import _field_range_from_dates
|
|
||||||
from documents.search._dates import _fmt
|
|
||||||
from documents.search._dates import _precision_bounds
|
|
||||||
from documents.search._dates import _utc_bounds_for_field
|
|
||||||
|
|
||||||
# Compiled regex that matches any known multi-word (or single-word) date keyword
|
|
||||||
# at the start of a match position, longest alternatives first so "previous week"
|
|
||||||
# wins over a hypothetical shorter "previous".
|
|
||||||
_KEYWORD_VALUE_RE = regex.compile(
|
|
||||||
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
|
|
||||||
regex.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from datetime import tzinfo
|
|
||||||
|
|
||||||
# TODO: this module translates date queries into Tantivy *string* syntax, which
|
|
||||||
# forces a workaround for something Tantivy's string parser cannot express on
|
|
||||||
# date fields: open-ended ranges use far-past/far-future string sentinels
|
|
||||||
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
|
|
||||||
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
|
|
||||||
# datetimes in range_query/term_query on Date fields. That support exists on
|
|
||||||
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
|
|
||||||
# it is blocked only on a published release > 0.26.0 and a dependency bump.
|
|
||||||
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
|
|
||||||
# no-match string sentinel.)
|
|
||||||
|
|
||||||
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
|
|
||||||
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
|
|
||||||
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
|
|
||||||
|
|
||||||
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
|
|
||||||
DATE_FIELDS = frozenset({"created", "modified", "added"})
|
|
||||||
|
|
||||||
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
|
|
||||||
# Preserved here so v2 queries using the old names continue to work without 400
|
|
||||||
# errors instead of silently failing. Applied by _render to non-date field tokens.
|
|
||||||
FIELD_ALIASES: dict[str, str] = {
|
|
||||||
"type": "document_type",
|
|
||||||
"type_id": "document_type_id",
|
|
||||||
"path": "storage_path",
|
|
||||||
"path_id": "storage_path_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
|
|
||||||
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
|
|
||||||
KNOWN_FIELDS = frozenset(
|
|
||||||
{
|
|
||||||
"title",
|
|
||||||
"content",
|
|
||||||
"correspondent",
|
|
||||||
"document_type",
|
|
||||||
"type", # v2 alias -> document_type
|
|
||||||
"storage_path",
|
|
||||||
"path", # v2 alias -> storage_path
|
|
||||||
"tag",
|
|
||||||
"tag_id",
|
|
||||||
"correspondent_id",
|
|
||||||
"document_type_id",
|
|
||||||
"type_id", # v2 alias -> document_type_id
|
|
||||||
"storage_path_id",
|
|
||||||
"path_id", # v2 alias -> storage_path_id
|
|
||||||
"owner_id",
|
|
||||||
"viewer_id",
|
|
||||||
"asn",
|
|
||||||
"page_count",
|
|
||||||
"num_notes",
|
|
||||||
"created",
|
|
||||||
"modified",
|
|
||||||
"added",
|
|
||||||
"original_filename",
|
|
||||||
"checksum",
|
|
||||||
"notes",
|
|
||||||
"custom_fields",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
|
|
||||||
|
|
||||||
# Matches the TO separator inside a range bracket. Handles three forms:
|
|
||||||
# middle: "lo TO hi" (either lo or hi may be empty)
|
|
||||||
# trailing: "lo TO" (open upper bound)
|
|
||||||
# leading: "TO hi" (open lower bound)
|
|
||||||
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
|
|
||||||
# and split on the whitespace-delimited " TO " / " to " separator.
|
|
||||||
_RANGE_RE = regex.compile(
|
|
||||||
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
|
|
||||||
r"|"
|
|
||||||
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
|
|
||||||
r"|"
|
|
||||||
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class FieldValue:
|
|
||||||
field: str
|
|
||||||
value: str
|
|
||||||
|
|
||||||
|
|
||||||
# Produced by the comma-resolution pass (not by scan()).
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class FieldValueList:
|
|
||||||
field: str
|
|
||||||
values: tuple[str, ...]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class FieldRange:
|
|
||||||
field: str
|
|
||||||
open: str
|
|
||||||
lo: str
|
|
||||||
hi: str
|
|
||||||
close: str
|
|
||||||
|
|
||||||
|
|
||||||
# Produced by the comma-resolution pass (not by scan()).
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class Comma:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class Passthrough:
|
|
||||||
raw: str
|
|
||||||
|
|
||||||
|
|
||||||
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
|
|
||||||
|
|
||||||
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
|
|
||||||
|
|
||||||
|
|
||||||
def scan(query: str) -> list[Token]:
|
|
||||||
"""
|
|
||||||
Tokenize a raw query into date/comma-aware tokens, leaving everything else
|
|
||||||
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
|
|
||||||
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
|
|
||||||
pass through verbatim on mismatch.
|
|
||||||
"""
|
|
||||||
tokens: list[Token] = []
|
|
||||||
buf: list[str] = [] # accumulates passthrough chars
|
|
||||||
i, n = 0, len(query)
|
|
||||||
while i < n:
|
|
||||||
matched = _match_field_token(query, i)
|
|
||||||
if matched is None:
|
|
||||||
buf.append(query[i])
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
token, i = matched
|
|
||||||
_flush(buf, tokens)
|
|
||||||
tokens.append(token)
|
|
||||||
i = _maybe_comma(query, i, tokens)
|
|
||||||
_flush(buf, tokens)
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def _flush(buf: list[str], tokens: list[Token]) -> None:
|
|
||||||
"""Emit any accumulated passthrough characters as a single token."""
|
|
||||||
if buf:
|
|
||||||
tokens.append(Passthrough("".join(buf)))
|
|
||||||
buf.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def _at_word_boundary(query: str, i: int) -> bool:
|
|
||||||
"""A field token may begin only at the start or after a non-word character."""
|
|
||||||
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
|
|
||||||
|
|
||||||
|
|
||||||
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
|
|
||||||
"""
|
|
||||||
If a known ``field:`` token starts at ``i``, consume it and return
|
|
||||||
``(token, end_index)``; otherwise return None so the caller treats the
|
|
||||||
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
|
|
||||||
and returns None when the range/value cannot be consumed.
|
|
||||||
"""
|
|
||||||
m = _FIELD_RE.match(query, i)
|
|
||||||
if m is None or m.group("field") not in KNOWN_FIELDS:
|
|
||||||
return None
|
|
||||||
if not _at_word_boundary(query, i):
|
|
||||||
return None
|
|
||||||
field = m.group("field")
|
|
||||||
j = m.end()
|
|
||||||
if j < len(query) and query[j] in "[{":
|
|
||||||
return _consume_range(query, j, field)
|
|
||||||
consumed = _consume_field_value(query, field, j)
|
|
||||||
if consumed is None:
|
|
||||||
return None
|
|
||||||
value, end = consumed
|
|
||||||
return FieldValue(field, value), end
|
|
||||||
|
|
||||||
|
|
||||||
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
|
|
||||||
"""
|
|
||||||
Consume a field value starting at ``start``: a multi-word date keyword phrase
|
|
||||||
(date fields only), or a bare/quoted value, then absorb any comma-joined
|
|
||||||
continuation that is not a clause separator. ``resolve_commas`` later splits a
|
|
||||||
multi-value field's joined value into a ``FieldValueList``; for other fields
|
|
||||||
the comma stays literal.
|
|
||||||
"""
|
|
||||||
n = len(query)
|
|
||||||
consumed = None
|
|
||||||
if field in DATE_FIELDS:
|
|
||||||
km = _KEYWORD_VALUE_RE.match(query, start)
|
|
||||||
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
|
|
||||||
consumed = (km.group(0), km.end())
|
|
||||||
if consumed is None:
|
|
||||||
consumed = _consume_value(query, start)
|
|
||||||
if consumed is None:
|
|
||||||
return None
|
|
||||||
value, k = consumed
|
|
||||||
while k < n and query[k] == ",":
|
|
||||||
if _looks_like_known_field(query, k + 1):
|
|
||||||
break # clause separator: left for _maybe_comma to emit a Comma()
|
|
||||||
more = _consume_value(query, k + 1)
|
|
||||||
if more is None:
|
|
||||||
break
|
|
||||||
value = f"{value},{more[0]}"
|
|
||||||
k = more[1]
|
|
||||||
return value, k
|
|
||||||
|
|
||||||
|
|
||||||
def _consume_range(
|
|
||||||
query: str,
|
|
||||||
start: int,
|
|
||||||
field: str,
|
|
||||||
) -> tuple[FieldRange, int] | None:
|
|
||||||
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
|
|
||||||
open_br = query[start]
|
|
||||||
close_br = _CLOSE[open_br]
|
|
||||||
end = query.find(close_br, start + 1)
|
|
||||||
if end == -1:
|
|
||||||
return None
|
|
||||||
inner = query[start + 1 : end]
|
|
||||||
m = _RANGE_RE.match(inner)
|
|
||||||
if m is not None:
|
|
||||||
if m.group("lo") is not None or m.group("hi") is not None:
|
|
||||||
# Middle form: "lo TO hi" (either may be empty string)
|
|
||||||
lo = (m.group("lo") or "").strip()
|
|
||||||
hi = (m.group("hi") or "").strip()
|
|
||||||
elif m.group("lo2") is not None:
|
|
||||||
# Trailing form: "lo TO"
|
|
||||||
lo = m.group("lo2").strip()
|
|
||||||
hi = ""
|
|
||||||
else:
|
|
||||||
# Leading form: "TO hi"
|
|
||||||
lo = ""
|
|
||||||
hi = (m.group("hi2") or "").strip()
|
|
||||||
else:
|
|
||||||
lo, hi = inner.strip(), ""
|
|
||||||
return FieldRange(field, open_br, lo, hi, close_br), end + 1
|
|
||||||
|
|
||||||
|
|
||||||
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
|
|
||||||
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
|
|
||||||
n = len(query)
|
|
||||||
if start >= n or query[start] in " \t":
|
|
||||||
return None
|
|
||||||
if query[start] in "\"'":
|
|
||||||
quote = query[start]
|
|
||||||
end = query.find(quote, start + 1)
|
|
||||||
if end == -1:
|
|
||||||
return None
|
|
||||||
return query[start : end + 1], end + 1
|
|
||||||
j = start
|
|
||||||
while j < n and query[j] not in " \t),":
|
|
||||||
j += 1
|
|
||||||
return query[start:j], j
|
|
||||||
|
|
||||||
|
|
||||||
def _looks_like_known_field(query: str, pos: int) -> bool:
|
|
||||||
"""True if a known ``field:`` token starts at ``pos``."""
|
|
||||||
m = _FIELD_RE.match(query, pos)
|
|
||||||
return bool(m and m.group("field") in KNOWN_FIELDS)
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_comma(query: str, i: int, tokens: list) -> int:
|
|
||||||
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
|
|
||||||
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
|
|
||||||
tokens.append(Comma())
|
|
||||||
return i + 1
|
|
||||||
return i
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_commas(tokens: list) -> list:
|
|
||||||
"""
|
|
||||||
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
|
|
||||||
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
|
|
||||||
the value-stop logic; this pass folds value-lists.)
|
|
||||||
"""
|
|
||||||
out: list = []
|
|
||||||
for tok in tokens:
|
|
||||||
if (
|
|
||||||
isinstance(tok, FieldValue)
|
|
||||||
and tok.field in MULTI_VALUE_FIELDS
|
|
||||||
and "," in tok.value
|
|
||||||
):
|
|
||||||
values = tuple(v for v in tok.value.split(",") if v)
|
|
||||||
out.append(FieldValueList(tok.field, values))
|
|
||||||
else:
|
|
||||||
out.append(tok)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class SearchQueryError(ValueError):
|
|
||||||
"""
|
|
||||||
Base for user-fixable search query errors.
|
|
||||||
|
|
||||||
Carries a message safe to surface to the user (no internal details). The view
|
|
||||||
layer catches this and returns an HTTP 400, so any future subclass (unknown
|
|
||||||
field, malformed range, wrapped parser errors) gets the same treatment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidDateQuery(SearchQueryError):
|
|
||||||
"""Raised when a date field value or range bound cannot be parsed."""
|
|
||||||
|
|
||||||
def __init__(self, field: str, value: str) -> None:
|
|
||||||
self.field = field
|
|
||||||
self.value = value
|
|
||||||
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
|
|
||||||
|
|
||||||
|
|
||||||
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
|
|
||||||
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
|
|
||||||
|
|
||||||
|
|
||||||
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
|
|
||||||
"""Translate a bare date-field value to a Tantivy range string."""
|
|
||||||
bare = value.strip("\"'").lower()
|
|
||||||
if bare in _DATE_KEYWORDS:
|
|
||||||
if field in _DATE_ONLY_FIELDS:
|
|
||||||
return f"{field}:{_date_only_range(bare, tz)}"
|
|
||||||
return f"{field}:{_datetime_range(bare, tz)}"
|
|
||||||
digits = value.replace("-", "")
|
|
||||||
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
|
|
||||||
bounds = _precision_bounds(digits)
|
|
||||||
if bounds is None:
|
|
||||||
raise InvalidDateQuery(field, value)
|
|
||||||
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
|
|
||||||
if regex.fullmatch(r"\d{14}", value):
|
|
||||||
try:
|
|
||||||
dt = datetime(
|
|
||||||
int(value[0:4]),
|
|
||||||
int(value[4:6]),
|
|
||||||
int(value[6:8]),
|
|
||||||
int(value[8:10]),
|
|
||||||
int(value[10:12]),
|
|
||||||
int(value[12:14]),
|
|
||||||
tzinfo=UTC,
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidDateQuery(field, value) from None
|
|
||||||
iso = _fmt(dt)
|
|
||||||
return f"{field}:[{iso} TO {iso}]"
|
|
||||||
# Unrecognized shape -> tell the user their date is malformed rather than
|
|
||||||
# silently matching nothing or emitting invalid Tantivy syntax.
|
|
||||||
raise InvalidDateQuery(field, value)
|
|
||||||
|
|
||||||
|
|
||||||
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
|
|
||||||
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
|
|
||||||
# exposes Query.range_query(..., None) on Date fields (see module TODO).
|
|
||||||
OPEN_LO = "0001-01-01T00:00:00Z"
|
|
||||||
OPEN_HI = "9999-12-31T23:59:59Z"
|
|
||||||
|
|
||||||
|
|
||||||
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
|
|
||||||
_NOW_COMPACT_RE = regex.compile(
|
|
||||||
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
|
|
||||||
regex.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
|
|
||||||
# Unit is singular or plural; sign prefix is mandatory.
|
|
||||||
_NOW_SPACED_RE = regex.compile(
|
|
||||||
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
|
|
||||||
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
|
|
||||||
regex.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_relative_bound(token: str) -> datetime | None:
|
|
||||||
"""
|
|
||||||
Resolve a relative bound token to an exact UTC instant, or return None.
|
|
||||||
|
|
||||||
Supported forms:
|
|
||||||
- ``now`` -> current UTC instant
|
|
||||||
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
|
|
||||||
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
|
|
||||||
"""
|
|
||||||
stripped = token.strip()
|
|
||||||
low = stripped.lower()
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
|
|
||||||
if low == "now":
|
|
||||||
return now
|
|
||||||
|
|
||||||
m = _NOW_COMPACT_RE.match(stripped)
|
|
||||||
if m:
|
|
||||||
sign = 1 if m.group("sign") == "+" else -1
|
|
||||||
n = int(m.group("n"))
|
|
||||||
unit = m.group("unit").lower()
|
|
||||||
delta = (
|
|
||||||
sign
|
|
||||||
* {
|
|
||||||
"d": timedelta(days=n),
|
|
||||||
"h": timedelta(hours=n),
|
|
||||||
"m": timedelta(minutes=n),
|
|
||||||
}[unit]
|
|
||||||
)
|
|
||||||
return now + delta
|
|
||||||
|
|
||||||
m = _NOW_SPACED_RE.match(stripped)
|
|
||||||
if m:
|
|
||||||
sign = 1 if m.group("sign") == "+" else -1
|
|
||||||
n = int(m.group("n"))
|
|
||||||
unit = m.group("unit").lower()
|
|
||||||
delta_map: dict[str, timedelta | relativedelta] = {
|
|
||||||
"second": timedelta(seconds=n),
|
|
||||||
"minute": timedelta(minutes=n),
|
|
||||||
"hour": timedelta(hours=n),
|
|
||||||
"day": timedelta(days=n),
|
|
||||||
"week": timedelta(weeks=n),
|
|
||||||
"month": relativedelta(months=n),
|
|
||||||
"year": relativedelta(years=n),
|
|
||||||
}
|
|
||||||
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _bound_datetimes(
|
|
||||||
field: str,
|
|
||||||
token: str,
|
|
||||||
tz: tzinfo,
|
|
||||||
) -> tuple[datetime, datetime] | None:
|
|
||||||
"""
|
|
||||||
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
|
|
||||||
None if the token is unparsable. ``now`` and relative offsets resolve to the
|
|
||||||
current instant (floor == ceil == that instant; no day-flooring).
|
|
||||||
"""
|
|
||||||
token = token.strip()
|
|
||||||
|
|
||||||
# Try relative/now forms first (before stripping hyphens which would mangle them).
|
|
||||||
rel = _resolve_relative_bound(token)
|
|
||||||
if rel is not None:
|
|
||||||
return rel, rel
|
|
||||||
|
|
||||||
# Full ISO datetime token (contains "T"): parse directly and return an exact
|
|
||||||
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
|
|
||||||
if "T" in token:
|
|
||||||
try:
|
|
||||||
dt = datetime.fromisoformat(token)
|
|
||||||
# Ensure timezone-aware UTC result.
|
|
||||||
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
|
|
||||||
return dt, dt
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
digits = token.replace("-", "")
|
|
||||||
bounds = _precision_bounds(digits)
|
|
||||||
if bounds is None:
|
|
||||||
return None
|
|
||||||
start, end = bounds
|
|
||||||
return _utc_bounds_for_field(field, start, end, tz)
|
|
||||||
|
|
||||||
|
|
||||||
def _render(tok: Token, tz: tzinfo) -> str:
|
|
||||||
"""Render a single token back to a Tantivy query string fragment."""
|
|
||||||
if isinstance(tok, Passthrough):
|
|
||||||
return tok.raw
|
|
||||||
if isinstance(tok, Comma):
|
|
||||||
return " AND "
|
|
||||||
if isinstance(tok, FieldValueList):
|
|
||||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
|
||||||
return " AND ".join(f"{field}:{v}" for v in tok.values)
|
|
||||||
if isinstance(tok, FieldValue):
|
|
||||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
|
||||||
if field in DATE_FIELDS:
|
|
||||||
return translate_scalar(field, tok.value, tz)
|
|
||||||
return f"{field}:{tok.value}"
|
|
||||||
if isinstance(tok, FieldRange):
|
|
||||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
|
||||||
if field in DATE_FIELDS:
|
|
||||||
return translate_range(field, tok.lo, tok.hi, tz)
|
|
||||||
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
|
|
||||||
return "" # pragma: no cover
|
|
||||||
|
|
||||||
|
|
||||||
# Post-render operator normalization patterns: collapse repeated whitespace and
|
|
||||||
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
|
|
||||||
_MULTI_SPACE_RE = regex.compile(r" {2,}")
|
|
||||||
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
|
|
||||||
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_operators(text: str) -> str:
|
|
||||||
"""
|
|
||||||
Collapse multiple spaces, strip trailing dangling operators, and replace
|
|
||||||
spaced operators (`` - `` / `` + ``) with a single space.
|
|
||||||
|
|
||||||
Applied only to Passthrough fragments (the rendered output is scanned for
|
|
||||||
operator artifacts outside bracketed ranges) via a post-render pass on the
|
|
||||||
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
|
|
||||||
while cleaning natural-language separators in the surrounding text.
|
|
||||||
"""
|
|
||||||
text = _MULTI_SPACE_RE.sub(" ", text)
|
|
||||||
text = _TRAILING_OP_RE.sub("", text).strip()
|
|
||||||
text = _SPACED_OP_RE.sub(" ", text).strip()
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def translate_query(raw: str, tz: tzinfo) -> str:
|
|
||||||
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
|
|
||||||
tokens = resolve_commas(scan(raw))
|
|
||||||
rendered = "".join(_render(t, tz) for t in tokens)
|
|
||||||
return _normalize_operators(rendered)
|
|
||||||
|
|
||||||
|
|
||||||
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
|
|
||||||
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
|
|
||||||
|
|
||||||
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
|
|
||||||
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
|
|
||||||
(swaps tokens before computing floor/ceil so the span is always correct).
|
|
||||||
"""
|
|
||||||
lo_s = lo.strip()
|
|
||||||
hi_s = hi.strip()
|
|
||||||
|
|
||||||
# Parse both bounds to (floor, ceil) pairs when present.
|
|
||||||
lo_pair: tuple[datetime, datetime] | None = None
|
|
||||||
hi_pair: tuple[datetime, datetime] | None = None
|
|
||||||
|
|
||||||
if lo_s:
|
|
||||||
lo_pair = _bound_datetimes(field, lo_s, tz)
|
|
||||||
if lo_pair is None:
|
|
||||||
raise InvalidDateQuery(field, lo_s)
|
|
||||||
if hi_s:
|
|
||||||
hi_pair = _bound_datetimes(field, hi_s, tz)
|
|
||||||
if hi_pair is None:
|
|
||||||
raise InvalidDateQuery(field, hi_s)
|
|
||||||
|
|
||||||
# Detect a reversed range: only swap when BOTH bounds are present.
|
|
||||||
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
|
|
||||||
lo_pair, hi_pair = hi_pair, lo_pair
|
|
||||||
|
|
||||||
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
|
|
||||||
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
|
|
||||||
|
|
||||||
return f"{field}:[{lo_iso} TO {hi_iso}]"
|
|
||||||
@@ -48,7 +48,6 @@ from rest_framework import serializers
|
|||||||
from rest_framework.exceptions import PermissionDenied
|
from rest_framework.exceptions import PermissionDenied
|
||||||
from rest_framework.fields import SerializerMethodField
|
from rest_framework.fields import SerializerMethodField
|
||||||
from rest_framework.filters import OrderingFilter
|
from rest_framework.filters import OrderingFilter
|
||||||
from rest_framework.utils import model_meta
|
|
||||||
|
|
||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
from auditlog.context import set_actor
|
from auditlog.context import set_actor
|
||||||
@@ -122,45 +121,6 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
|
|||||||
self.fields.pop(field_name)
|
self.fields.pop(field_name)
|
||||||
|
|
||||||
|
|
||||||
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
|
|
||||||
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
|
|
||||||
|
|
||||||
def _get_update_fields(self, validated_data) -> list[str]:
|
|
||||||
model_fields = {
|
|
||||||
field.name
|
|
||||||
for field in self.Meta.model._meta.concrete_fields
|
|
||||||
if field.name not in self.stale_update_excluded_fields
|
|
||||||
}
|
|
||||||
update_fields = [
|
|
||||||
field_name for field_name in validated_data if field_name in model_fields
|
|
||||||
]
|
|
||||||
if "modified" in model_fields and "modified" not in update_fields:
|
|
||||||
update_fields.append("modified")
|
|
||||||
return update_fields
|
|
||||||
|
|
||||||
def update(self, instance, validated_data):
|
|
||||||
serializers.raise_errors_on_nested_writes("update", self, validated_data)
|
|
||||||
info = model_meta.get_field_info(instance)
|
|
||||||
|
|
||||||
m2m_fields = []
|
|
||||||
for attr, value in validated_data.items():
|
|
||||||
if attr in info.relations and info.relations[attr].to_many:
|
|
||||||
m2m_fields.append((attr, value))
|
|
||||||
else:
|
|
||||||
setattr(instance, attr, value)
|
|
||||||
|
|
||||||
# File names are managed by post-save file handling. Saving only the
|
|
||||||
# serializer-updated fields prevents stale in-memory path values from
|
|
||||||
# overwriting a concurrent move.
|
|
||||||
instance.save(update_fields=self._get_update_fields(validated_data))
|
|
||||||
|
|
||||||
for attr, value in m2m_fields:
|
|
||||||
field = getattr(instance, attr)
|
|
||||||
field.set(value)
|
|
||||||
|
|
||||||
return instance
|
|
||||||
|
|
||||||
|
|
||||||
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
|
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
|
||||||
document_count = serializers.IntegerField(read_only=True)
|
document_count = serializers.IntegerField(read_only=True)
|
||||||
|
|
||||||
@@ -203,7 +163,7 @@ class MatchingModelSerializer(serializers.ModelSerializer[Any]):
|
|||||||
logger.debug(f"Invalid regular expression: {e!s}")
|
logger.debug(f"Invalid regular expression: {e!s}")
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"Invalid regular expression, see log for details.",
|
"Invalid regular expression, see log for details.",
|
||||||
)
|
) from None
|
||||||
return match
|
return match
|
||||||
|
|
||||||
|
|
||||||
@@ -907,7 +867,9 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
|
|||||||
try:
|
try:
|
||||||
value_int = int(data["value"])
|
value_int = int(data["value"])
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
raise serializers.ValidationError("Enter a valid integer.")
|
raise serializers.ValidationError(
|
||||||
|
"Enter a valid integer.",
|
||||||
|
) from None
|
||||||
# Keep values within the PostgreSQL integer range
|
# Keep values within the PostgreSQL integer range
|
||||||
MinValueValidator(-2147483648)(value_int)
|
MinValueValidator(-2147483648)(value_int)
|
||||||
MaxValueValidator(2147483647)(value_int)
|
MaxValueValidator(2147483647)(value_int)
|
||||||
@@ -939,7 +901,7 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer[CustomFieldInsta
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
f"Value must be an id of an element in {select_options}",
|
f"Value must be an id of an element in {select_options}",
|
||||||
)
|
) from None
|
||||||
elif field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
|
elif field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
|
||||||
if not (isinstance(data["value"], list) or data["value"] is None):
|
if not (isinstance(data["value"], list) or data["value"] is None):
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
@@ -1029,7 +991,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
|
|||||||
class DocumentSerializer(
|
class DocumentSerializer(
|
||||||
OwnedObjectSerializer,
|
OwnedObjectSerializer,
|
||||||
NestedUpdateMixin,
|
NestedUpdateMixin,
|
||||||
DocumentUpdateFieldsModelSerializer,
|
DynamicFieldsModelSerializer,
|
||||||
):
|
):
|
||||||
correspondent = CorrespondentField(allow_null=True)
|
correspondent = CorrespondentField(allow_null=True)
|
||||||
tags = TagsField(many=True)
|
tags = TagsField(many=True)
|
||||||
@@ -1130,7 +1092,7 @@ class DocumentSerializer(
|
|||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
doc = super().to_representation(instance)
|
doc = super().to_representation(instance)
|
||||||
if "content" in self.fields and hasattr(instance, "effective_content"):
|
if "content" in self.fields and hasattr(instance, "effective_content"):
|
||||||
doc["content"] = getattr(instance, "effective_content") or ""
|
doc["content"] = instance.effective_content or ""
|
||||||
if self.truncate_content and "content" in self.fields:
|
if self.truncate_content and "content" in self.fields:
|
||||||
doc["content"] = doc.get("content")[0:550]
|
doc["content"] = doc.get("content")[0:550]
|
||||||
return doc
|
return doc
|
||||||
@@ -1168,9 +1130,10 @@ class DocumentSerializer(
|
|||||||
return super().validate(attrs)
|
return super().validate(attrs)
|
||||||
|
|
||||||
def update(self, instance: Document, validated_data):
|
def update(self, instance: Document, validated_data):
|
||||||
|
if "created_date" in validated_data and "created" not in validated_data:
|
||||||
|
instance.created = validated_data.get("created_date")
|
||||||
|
instance.save()
|
||||||
if "created_date" in validated_data:
|
if "created_date" in validated_data:
|
||||||
if "created" not in validated_data:
|
|
||||||
validated_data["created"] = validated_data["created_date"]
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"created_date is deprecated, use created instead",
|
"created_date is deprecated, use created instead",
|
||||||
)
|
)
|
||||||
@@ -1240,13 +1203,11 @@ class DocumentSerializer(
|
|||||||
for tag in instance.tags.all()
|
for tag in instance.tags.all()
|
||||||
if tag not in inbox_tags_not_being_added
|
if tag not in inbox_tags_not_being_added
|
||||||
]
|
]
|
||||||
|
|
||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
with set_actor(self.user):
|
with set_actor(self.user):
|
||||||
super().update(instance, validated_data)
|
super().update(instance, validated_data)
|
||||||
else:
|
else:
|
||||||
super().update(instance, validated_data)
|
super().update(instance, validated_data)
|
||||||
|
|
||||||
# hard delete custom field instances that were soft deleted
|
# hard delete custom field instances that were soft deleted
|
||||||
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
|
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
|
||||||
return instance
|
return instance
|
||||||
@@ -1493,7 +1454,7 @@ class SavedViewSerializer(OwnedObjectSerializer):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
except serializers.ValidationError as exc:
|
except serializers.ValidationError as exc:
|
||||||
raise serializers.ValidationError({field_name: exc.detail})
|
raise serializers.ValidationError({field_name: exc.detail}) from exc
|
||||||
del normalized_data[field_name]
|
del normalized_data[field_name]
|
||||||
|
|
||||||
ret = super().to_internal_value(normalized_data)
|
ret = super().to_internal_value(normalized_data)
|
||||||
@@ -1797,7 +1758,7 @@ class BulkEditSerializer(
|
|||||||
logger.exception(f"Error validating custom fields: {e}")
|
logger.exception(f"Error validating custom fields: {e}")
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
f"{name} must be a list of integers or a dict of id:value pairs, see the log for details",
|
f"{name} must be a list of integers or a dict of id:value pairs, see the log for details",
|
||||||
)
|
) from None
|
||||||
elif not isinstance(custom_fields, list) or not all(
|
elif not isinstance(custom_fields, list) or not all(
|
||||||
isinstance(i, int) for i in ids
|
isinstance(i, int) for i in ids
|
||||||
):
|
):
|
||||||
@@ -1865,7 +1826,7 @@ class BulkEditSerializer(
|
|||||||
try:
|
try:
|
||||||
Tag.objects.get(id=tag_id)
|
Tag.objects.get(id=tag_id)
|
||||||
except Tag.DoesNotExist:
|
except Tag.DoesNotExist:
|
||||||
raise serializers.ValidationError("Tag does not exist")
|
raise serializers.ValidationError("Tag does not exist") from None
|
||||||
else:
|
else:
|
||||||
raise serializers.ValidationError("tag not specified")
|
raise serializers.ValidationError("tag not specified")
|
||||||
|
|
||||||
@@ -1878,7 +1839,9 @@ class BulkEditSerializer(
|
|||||||
try:
|
try:
|
||||||
DocumentType.objects.get(id=document_type_id)
|
DocumentType.objects.get(id=document_type_id)
|
||||||
except DocumentType.DoesNotExist:
|
except DocumentType.DoesNotExist:
|
||||||
raise serializers.ValidationError("Document type does not exist")
|
raise serializers.ValidationError(
|
||||||
|
"Document type does not exist",
|
||||||
|
) from None
|
||||||
else:
|
else:
|
||||||
raise serializers.ValidationError("document_type not specified")
|
raise serializers.ValidationError("document_type not specified")
|
||||||
|
|
||||||
@@ -1890,7 +1853,9 @@ class BulkEditSerializer(
|
|||||||
try:
|
try:
|
||||||
Correspondent.objects.get(id=correspondent_id)
|
Correspondent.objects.get(id=correspondent_id)
|
||||||
except Correspondent.DoesNotExist:
|
except Correspondent.DoesNotExist:
|
||||||
raise serializers.ValidationError("Correspondent does not exist")
|
raise serializers.ValidationError(
|
||||||
|
"Correspondent does not exist",
|
||||||
|
) from None
|
||||||
else:
|
else:
|
||||||
raise serializers.ValidationError("correspondent not specified")
|
raise serializers.ValidationError("correspondent not specified")
|
||||||
|
|
||||||
@@ -1904,7 +1869,7 @@ class BulkEditSerializer(
|
|||||||
except StoragePath.DoesNotExist:
|
except StoragePath.DoesNotExist:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"Storage path does not exist",
|
"Storage path does not exist",
|
||||||
)
|
) from None
|
||||||
else:
|
else:
|
||||||
raise serializers.ValidationError("storage path not specified")
|
raise serializers.ValidationError("storage path not specified")
|
||||||
|
|
||||||
@@ -1959,7 +1924,7 @@ class BulkEditSerializer(
|
|||||||
):
|
):
|
||||||
raise serializers.ValidationError("invalid rotation degrees")
|
raise serializers.ValidationError("invalid rotation degrees")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise serializers.ValidationError("invalid rotation degrees")
|
raise serializers.ValidationError("invalid rotation degrees") from None
|
||||||
|
|
||||||
def _validate_source_mode(self, parameters) -> None:
|
def _validate_source_mode(self, parameters) -> None:
|
||||||
source_mode = parameters.get(
|
source_mode = parameters.get(
|
||||||
@@ -1989,7 +1954,7 @@ class BulkEditSerializer(
|
|||||||
pages.append([int(doc)])
|
pages.append([int(doc)])
|
||||||
parameters["pages"] = pages
|
parameters["pages"] = pages
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise serializers.ValidationError("invalid pages specified")
|
raise serializers.ValidationError("invalid pages specified") from None
|
||||||
|
|
||||||
if "delete_originals" in parameters:
|
if "delete_originals" in parameters:
|
||||||
if not isinstance(parameters["delete_originals"], bool):
|
if not isinstance(parameters["delete_originals"], bool):
|
||||||
@@ -2259,14 +2224,14 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
|
|||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
_("Custom field id must be an integer: %(id)s")
|
_("Custom field id must be an integer: %(id)s")
|
||||||
% {"id": field_id},
|
% {"id": field_id},
|
||||||
)
|
) from None
|
||||||
try:
|
try:
|
||||||
field = CustomField.objects.get(id=field_id_int)
|
field = CustomField.objects.get(id=field_id_int)
|
||||||
except CustomField.DoesNotExist:
|
except CustomField.DoesNotExist:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
_("Custom field with id %(id)s does not exist")
|
_("Custom field with id %(id)s does not exist")
|
||||||
% {"id": field_id_int},
|
% {"id": field_id_int},
|
||||||
)
|
) from None
|
||||||
custom_field_serializer.validate(
|
custom_field_serializer.validate(
|
||||||
{
|
{
|
||||||
"field": field,
|
"field": field,
|
||||||
@@ -2283,7 +2248,7 @@ class PostDocumentSerializer(serializers.Serializer[dict[str, Any]]):
|
|||||||
_(
|
_(
|
||||||
"Custom fields must be a list of integers or an object mapping ids to values.",
|
"Custom fields must be a list of integers or an object mapping ids to values.",
|
||||||
),
|
),
|
||||||
)
|
) from None
|
||||||
if CustomField.objects.filter(id__in=ids).count() != len(set(ids)):
|
if CustomField.objects.filter(id__in=ids).count() != len(set(ids)):
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
_("Some custom fields don't exist or were specified twice."),
|
_("Some custom fields don't exist or were specified twice."),
|
||||||
@@ -2394,7 +2359,9 @@ class EmailSerializer(DocumentListSerializer):
|
|||||||
for address in address_list:
|
for address in address_list:
|
||||||
email_validator(address)
|
email_validator(address)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
raise serializers.ValidationError(f"Invalid email address: {address}")
|
raise serializers.ValidationError(
|
||||||
|
f"Invalid email address: {address}",
|
||||||
|
) from None
|
||||||
|
|
||||||
return ",".join(address_list)
|
return ",".join(address_list)
|
||||||
|
|
||||||
@@ -2673,25 +2640,18 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
|
|||||||
|
|
||||||
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
||||||
tasks = serializers.ListField(
|
tasks = serializers.ListField(
|
||||||
required=False,
|
required=True,
|
||||||
label="Tasks",
|
label="Tasks",
|
||||||
write_only=True,
|
write_only=True,
|
||||||
child=serializers.IntegerField(),
|
child=serializers.IntegerField(),
|
||||||
)
|
)
|
||||||
all = serializers.BooleanField(
|
|
||||||
required=False,
|
|
||||||
default=False,
|
|
||||||
label="All",
|
|
||||||
write_only=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _validate_task_id_list(self, tasks, name="tasks") -> None:
|
def _validate_task_id_list(self, tasks, name="tasks") -> None:
|
||||||
if not isinstance(tasks, list):
|
if not isinstance(tasks, list):
|
||||||
raise serializers.ValidationError(f"{name} must be a list")
|
raise serializers.ValidationError(f"{name} must be a list")
|
||||||
if not all(isinstance(i, int) for i in tasks):
|
if not all(isinstance(i, int) for i in tasks):
|
||||||
raise serializers.ValidationError(f"{name} must be a list of integers")
|
raise serializers.ValidationError(f"{name} must be a list of integers")
|
||||||
queryset = self.context.get("queryset", PaperlessTask.objects.all())
|
count = PaperlessTask.objects.filter(id__in=tasks).count()
|
||||||
count = queryset.filter(id__in=tasks).count()
|
|
||||||
if not count == len(tasks):
|
if not count == len(tasks):
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
f"Some tasks in {name} don't exist or were specified twice.",
|
f"Some tasks in {name} don't exist or were specified twice.",
|
||||||
@@ -2701,21 +2661,6 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
|||||||
self._validate_task_id_list(tasks)
|
self._validate_task_id_list(tasks)
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
def validate(self, attrs):
|
|
||||||
acknowledge_all = attrs.get("all", False)
|
|
||||||
task_ids = attrs.get("tasks")
|
|
||||||
|
|
||||||
if acknowledge_all and task_ids is not None:
|
|
||||||
raise serializers.ValidationError(
|
|
||||||
"Set either all or tasks, not both.",
|
|
||||||
)
|
|
||||||
if not acknowledge_all and task_ids is None:
|
|
||||||
raise serializers.ValidationError(
|
|
||||||
"Either all must be true or tasks must be provided.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return attrs
|
|
||||||
|
|
||||||
|
|
||||||
class ShareLinkSerializer(OwnedObjectSerializer):
|
class ShareLinkSerializer(OwnedObjectSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
@@ -2840,7 +2785,7 @@ class ShareLinkBundleSerializer(OwnedObjectSerializer):
|
|||||||
return share_link_bundle
|
return share_link_bundle
|
||||||
|
|
||||||
def get_document_count(self, obj: ShareLinkBundle) -> int:
|
def get_document_count(self, obj: ShareLinkBundle) -> int:
|
||||||
return getattr(obj, "document_total") or obj.documents.count()
|
return obj.document_total or obj.documents.count()
|
||||||
|
|
||||||
|
|
||||||
class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
|
class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
|
||||||
@@ -3188,7 +3133,7 @@ class WorkflowActionSerializer(serializers.ModelSerializer[WorkflowAction]):
|
|||||||
except (ValueError, KeyError) as e:
|
except (ValueError, KeyError) as e:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
{"assign_title": f'Invalid f-string detected: "{e.args[0]}"'},
|
{"assign_title": f'Invalid f-string detected: "{e.args[0]}"'},
|
||||||
)
|
) from None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
"type" in attrs
|
"type" in attrs
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import traceback as _tb
|
import traceback as _tb
|
||||||
@@ -15,7 +16,6 @@ from celery.signals import task_postrun
|
|||||||
from celery.signals import task_prerun
|
from celery.signals import task_prerun
|
||||||
from celery.signals import task_revoked
|
from celery.signals import task_revoked
|
||||||
from celery.signals import worker_process_init
|
from celery.signals import worker_process_init
|
||||||
from celery.signals import worker_process_shutdown
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
@@ -54,7 +54,6 @@ from documents.models import WorkflowTrigger
|
|||||||
from documents.permissions import get_objects_for_user_owner_aware
|
from documents.permissions import get_objects_for_user_owner_aware
|
||||||
from documents.plugins.helpers import DocumentsStatusManager
|
from documents.plugins.helpers import DocumentsStatusManager
|
||||||
from documents.templating.utils import convert_format_str_to_template_format
|
from documents.templating.utils import convert_format_str_to_template_format
|
||||||
from documents.utils import compute_checksum
|
|
||||||
from documents.workflows.actions import build_workflow_action_context
|
from documents.workflows.actions import build_workflow_action_context
|
||||||
from documents.workflows.actions import execute_email_action
|
from documents.workflows.actions import execute_email_action
|
||||||
from documents.workflows.actions import execute_move_to_trash_action
|
from documents.workflows.actions import execute_move_to_trash_action
|
||||||
@@ -411,7 +410,8 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
|
|||||||
if checksum is None or not path.is_file():
|
if checksum is None or not path.is_file():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return compute_checksum(path) == checksum
|
with path.open("rb") as f:
|
||||||
|
return hashlib.md5(f.read(), usedforsecurity=False).hexdigest() == checksum
|
||||||
|
|
||||||
|
|
||||||
def _filename_template_uses_custom_fields(doc: Document) -> bool:
|
def _filename_template_uses_custom_fields(doc: Document) -> bool:
|
||||||
@@ -1340,20 +1340,6 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
|
|||||||
conn.close_pool()
|
conn.close_pool()
|
||||||
|
|
||||||
|
|
||||||
@worker_process_shutdown.connect
|
|
||||||
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
|
|
||||||
"""
|
|
||||||
Close the DB connection pool when a Celery child process exits.
|
|
||||||
|
|
||||||
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
|
|
||||||
single task. Without closing the pool on shutdown, its connections linger
|
|
||||||
on the server until TCP keepalive reaps them, accumulating over time.
|
|
||||||
"""
|
|
||||||
for conn in connections.all(initialized_only=True):
|
|
||||||
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
|
|
||||||
conn.close_pool()
|
|
||||||
|
|
||||||
|
|
||||||
def add_or_update_document_in_llm_index(sender, document, **kwargs):
|
def add_or_update_document_in_llm_index(sender, document, **kwargs):
|
||||||
"""
|
"""
|
||||||
Add or update a document in the LLM index when it is created or updated.
|
Add or update a document in the LLM index when it is created or updated.
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from pathlib import PurePath
|
from pathlib import PurePath
|
||||||
|
|
||||||
@@ -37,12 +36,10 @@ class FilePathTemplate(Template):
|
|||||||
def clean_filepath(value: str) -> str:
|
def clean_filepath(value: str) -> str:
|
||||||
"""
|
"""
|
||||||
Clean up a filepath by:
|
Clean up a filepath by:
|
||||||
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
|
1. Removing newlines and carriage returns
|
||||||
2. Removing newlines and carriage returns
|
2. Removing extra spaces before and after forward slashes
|
||||||
3. Removing extra spaces before and after forward slashes
|
3. Preserving spaces in other parts of the path
|
||||||
4. Preserving spaces in other parts of the path
|
|
||||||
"""
|
"""
|
||||||
value = unicodedata.normalize("NFC", value)
|
|
||||||
value = value.replace("\n", "").replace("\r", "")
|
value = value.replace("\n", "").replace("\r", "")
|
||||||
value = re.sub(r"\s*/\s*", "/", value)
|
value = re.sub(r"\s*/\s*", "/", value)
|
||||||
|
|
||||||
@@ -184,17 +181,17 @@ def get_basic_metadata_context(
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"title": pathvalidate.sanitize_filename(
|
"title": pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize("NFC", document.title),
|
document.title,
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
),
|
),
|
||||||
"correspondent": pathvalidate.sanitize_filename(
|
"correspondent": pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize("NFC", document.correspondent.name),
|
document.correspondent.name,
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
)
|
)
|
||||||
if document.correspondent
|
if document.correspondent
|
||||||
else no_value_default,
|
else no_value_default,
|
||||||
"document_type": pathvalidate.sanitize_filename(
|
"document_type": pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize("NFC", document.document_type.name),
|
document.document_type.name,
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
)
|
)
|
||||||
if document.document_type
|
if document.document_type
|
||||||
@@ -205,10 +202,7 @@ def get_basic_metadata_context(
|
|||||||
"owner_username": document.owner.username
|
"owner_username": document.owner.username
|
||||||
if document.owner
|
if document.owner
|
||||||
else no_value_default,
|
else no_value_default,
|
||||||
"original_name": unicodedata.normalize(
|
"original_name": PurePath(document.original_filename).with_suffix("").name
|
||||||
"NFC",
|
|
||||||
PurePath(document.original_filename).with_suffix("").name,
|
|
||||||
)
|
|
||||||
if document.original_filename
|
if document.original_filename
|
||||||
else no_value_default,
|
else no_value_default,
|
||||||
"doc_pk": f"{document.pk:07}",
|
"doc_pk": f"{document.pk:07}",
|
||||||
@@ -275,12 +269,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
|
|||||||
return {
|
return {
|
||||||
"tag_list": pathvalidate.sanitize_filename(
|
"tag_list": pathvalidate.sanitize_filename(
|
||||||
",".join(
|
",".join(
|
||||||
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
|
sorted(tag.name for tag in tags),
|
||||||
),
|
),
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
),
|
),
|
||||||
# Assumed to be ordered, but a template could loop through to find what they want
|
# Assumed to be ordered, but a template could loop through to find what they want
|
||||||
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
|
"tag_name_list": [x.name for x in tags],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -307,7 +301,7 @@ def get_custom_fields_context(
|
|||||||
CustomField.FieldDataType.LONG_TEXT,
|
CustomField.FieldDataType.LONG_TEXT,
|
||||||
}:
|
}:
|
||||||
value = pathvalidate.sanitize_filename(
|
value = pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize("NFC", field_instance.value),
|
field_instance.value,
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
@@ -316,13 +310,10 @@ def get_custom_fields_context(
|
|||||||
):
|
):
|
||||||
options = field_instance.field.extra_data["select_options"]
|
options = field_instance.field.extra_data["select_options"]
|
||||||
value = pathvalidate.sanitize_filename(
|
value = pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize(
|
next(
|
||||||
"NFC",
|
option["label"]
|
||||||
next(
|
for option in options
|
||||||
option["label"]
|
if option["id"] == field_instance.value
|
||||||
for option in options
|
|
||||||
if option["id"] == field_instance.value
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
)
|
)
|
||||||
@@ -330,7 +321,7 @@ def get_custom_fields_context(
|
|||||||
value = field_instance.value
|
value = field_instance.value
|
||||||
field_data["custom_fields"][
|
field_data["custom_fields"][
|
||||||
pathvalidate.sanitize_filename(
|
pathvalidate.sanitize_filename(
|
||||||
unicodedata.normalize("NFC", field_instance.field.name),
|
field_instance.field.name,
|
||||||
replacement_text="-",
|
replacement_text="-",
|
||||||
)
|
)
|
||||||
] = {
|
] = {
|
||||||
|
|||||||
@@ -29,9 +29,7 @@ class SimpleCommand(PaperlessCommand):
|
|||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
items = list(range(5))
|
items = list(range(5))
|
||||||
results = []
|
results = [item * 2 for item in self.track(items, description="Processing...")]
|
||||||
for item in self.track(items, description="Processing..."):
|
|
||||||
results.append(item * 2)
|
|
||||||
self.stdout.write(f"Results: {results}")
|
self.stdout.write(f"Results: {results}")
|
||||||
|
|
||||||
|
|
||||||
@@ -57,13 +55,13 @@ class MultiprocessCommand(PaperlessCommand):
|
|||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
items = list(range(5))
|
items = list(range(5))
|
||||||
results = []
|
results = list(
|
||||||
for result in self.process_parallel(
|
self.process_parallel(
|
||||||
_double_value,
|
_double_value,
|
||||||
items,
|
items,
|
||||||
description="Processing...",
|
description="Processing...",
|
||||||
):
|
),
|
||||||
results.append(result)
|
)
|
||||||
successes = sum(1 for r in results if r.success)
|
successes = sum(1 for r in results if r.success)
|
||||||
self.stdout.write(f"Successes: {successes}")
|
self.stdout.write(f"Successes: {successes}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from django.core.management import call_command
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pytest_mock import MockerFixture
|
|
||||||
|
|
||||||
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
|
|
||||||
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
|
|
||||||
|
|
||||||
|
|
||||||
class TestDocumentLlmindexCommand:
|
|
||||||
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
|
|
||||||
mock_compact = mocker.patch(_COMPACT)
|
|
||||||
call_command("document_llmindex", "compact")
|
|
||||||
mock_compact.assert_called_once_with()
|
|
||||||
|
|
||||||
def test_rebuild_calls_llmindex_index_with_rebuild_true(
|
|
||||||
self,
|
|
||||||
mocker: MockerFixture,
|
|
||||||
) -> None:
|
|
||||||
mock_index = mocker.patch(_INDEX)
|
|
||||||
call_command("document_llmindex", "rebuild")
|
|
||||||
mock_index.assert_called_once()
|
|
||||||
assert mock_index.call_args.kwargs["rebuild"] is True
|
|
||||||
|
|
||||||
def test_update_calls_llmindex_index_with_rebuild_false(
|
|
||||||
self,
|
|
||||||
mocker: MockerFixture,
|
|
||||||
) -> None:
|
|
||||||
mock_index = mocker.patch(_INDEX)
|
|
||||||
call_command("document_llmindex", "update")
|
|
||||||
mock_index.assert_called_once()
|
|
||||||
assert mock_index.call_args.kwargs["rebuild"] is False
|
|
||||||
@@ -1,15 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import tempfile
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import tantivy
|
|
||||||
|
|
||||||
from documents.search._backend import TantivyBackend
|
from documents.search._backend import TantivyBackend
|
||||||
from documents.search._backend import reset_backend
|
from documents.search._backend import reset_backend
|
||||||
from documents.search._schema import build_schema
|
|
||||||
from documents.search._tokenizer import register_tokenizers
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
@@ -35,11 +31,3 @@ def backend() -> Generator[TantivyBackend, None, None]:
|
|||||||
finally:
|
finally:
|
||||||
b.close()
|
b.close()
|
||||||
reset_backend()
|
reset_backend()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def index() -> tantivy.Index:
|
|
||||||
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
|
|
||||||
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
|
|
||||||
register_tokenizers(idx, "english")
|
|
||||||
return idx
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import time_machine
|
|||||||
|
|
||||||
from documents.search._query import _date_only_range
|
from documents.search._query import _date_only_range
|
||||||
from documents.search._query import _datetime_range
|
from documents.search._query import _datetime_range
|
||||||
|
from documents.search._query import _rewrite_compact_date
|
||||||
from documents.search._query import build_permission_filter
|
from documents.search._query import build_permission_filter
|
||||||
from documents.search._query import normalize_query
|
from documents.search._query import normalize_query
|
||||||
from documents.search._query import parse_simple_text_highlight_query
|
from documents.search._query import parse_simple_text_highlight_query
|
||||||
@@ -20,7 +21,6 @@ from documents.search._query import parse_user_query
|
|||||||
from documents.search._query import rewrite_natural_date_keywords
|
from documents.search._query import rewrite_natural_date_keywords
|
||||||
from documents.search._schema import build_schema
|
from documents.search._schema import build_schema
|
||||||
from documents.search._tokenizer import register_tokenizers
|
from documents.search._tokenizer import register_tokenizers
|
||||||
from documents.search._translate import InvalidDateQuery
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from django.contrib.auth.base_user import AbstractBaseUser
|
from django.contrib.auth.base_user import AbstractBaseUser
|
||||||
@@ -405,14 +405,12 @@ class TestWhooshQueryRewriting:
|
|||||||
assert lo == "2023-12-01T05:00:00Z"
|
assert lo == "2023-12-01T05:00:00Z"
|
||||||
assert hi == "2023-12-02T05:00:00Z"
|
assert hi == "2023-12-02T05:00:00Z"
|
||||||
|
|
||||||
def test_8digit_invalid_date_raises(self) -> None:
|
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
|
||||||
# The translation pipeline raises InvalidDateQuery for unparsable dates
|
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
|
||||||
# (e.g. month=13) so the API can surface a 400 telling the user the date
|
|
||||||
# is malformed instead of silently returning zero results.
|
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
# Month=13 makes datetime() raise ValueError; the token must be left as-is
|
||||||
rewrite_natural_date_keywords("added:20231340", UTC)
|
assert _rewrite_compact_date("20231300120000") == "20231300120000"
|
||||||
assert exc_info.value.field == "added"
|
|
||||||
assert exc_info.value.value == "20231340"
|
|
||||||
|
|
||||||
|
|
||||||
class TestParseUserQuery:
|
class TestParseUserQuery:
|
||||||
@@ -465,67 +463,6 @@ class TestParseUserQuery:
|
|||||||
) -> None:
|
) -> None:
|
||||||
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
|
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"raw_query",
|
|
||||||
[
|
|
||||||
# Partial date scalar (year only)
|
|
||||||
pytest.param("created:2020", id="created_year_scalar"),
|
|
||||||
# 8-digit compact date range in brackets
|
|
||||||
pytest.param(
|
|
||||||
"created:[20200101 TO 20201231]",
|
|
||||||
id="created_8digit_bracket_range",
|
|
||||||
),
|
|
||||||
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
|
|
||||||
pytest.param(
|
|
||||||
"title:x,created:[2020 TO 2021]",
|
|
||||||
id="title_comma_created_range",
|
|
||||||
),
|
|
||||||
# Field alias: type -> document_type
|
|
||||||
pytest.param("type:invoice", id="type_alias"),
|
|
||||||
# Multi-word date keyword
|
|
||||||
pytest.param("created:previous week", id="created_previous_week"),
|
|
||||||
# Full ISO datetime range
|
|
||||||
pytest.param(
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
|
|
||||||
id="created_iso_range",
|
|
||||||
),
|
|
||||||
# Comma-separated ISO ranges (Whoosh v2 syntax)
|
|
||||||
pytest.param(
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
|
||||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
|
|
||||||
id="comma_iso_ranges",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_advanced_search_queries_do_not_raise(
|
|
||||||
self,
|
|
||||||
query_index: tantivy.Index,
|
|
||||||
raw_query: str,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
End-to-end: queries that the frontend sends must parse without raising.
|
|
||||||
|
|
||||||
This tests the full pipeline: translate_query -> tantivy parse_query.
|
|
||||||
Equivalent to asserting HTTP 200 (not 400) for each query form.
|
|
||||||
"""
|
|
||||||
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
|
|
||||||
assert isinstance(
|
|
||||||
parse_user_query(query_index, raw_query, UTC),
|
|
||||||
tantivy.Query,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_invalid_date_propagates_not_swallowed(
|
|
||||||
self,
|
|
||||||
query_index: tantivy.Index,
|
|
||||||
) -> None:
|
|
||||||
# parse_user_query falls back to the raw query on unexpected translation
|
|
||||||
# errors, but an InvalidDateQuery is intentional and must propagate so the
|
|
||||||
# view can return a 400 instead of silently parsing the raw (invalid) date.
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
parse_user_query(query_index, "created:202023", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "202023"
|
|
||||||
|
|
||||||
|
|
||||||
class TestYearRangeRewriting:
|
class TestYearRangeRewriting:
|
||||||
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
|
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
|
||||||
@@ -605,16 +542,11 @@ class TestYearRangeRewriting:
|
|||||||
assert rewrite_natural_date_keywords(original, UTC) == original
|
assert rewrite_natural_date_keywords(original, UTC) == original
|
||||||
|
|
||||||
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
|
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
|
||||||
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
|
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
|
||||||
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
|
|
||||||
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
|
|
||||||
# This is the correct and accepted behavior: old compact form becomes a
|
|
||||||
# proper Tantivy-parseable ISO range.
|
|
||||||
original = "created:[20200101 TO 20201231]"
|
original = "created:[20200101 TO 20201231]"
|
||||||
result = rewrite_natural_date_keywords(original, UTC)
|
result = rewrite_natural_date_keywords(original, UTC)
|
||||||
lo, hi = _range(result, "created")
|
assert "20200101" in result or "2020-01-01" in result
|
||||||
assert lo == "2020-01-01T00:00:00Z"
|
assert "20201231" in result or "2020-12-31" in result
|
||||||
assert hi == "2021-01-01T00:00:00Z"
|
|
||||||
|
|
||||||
|
|
||||||
class TestNonDateFieldsNotRewritten:
|
class TestNonDateFieldsNotRewritten:
|
||||||
@@ -674,16 +606,6 @@ class TestNormalizeQuery:
|
|||||||
def test_normalize_expands_comma_separated_tags(self) -> None:
|
def test_normalize_expands_comma_separated_tags(self) -> None:
|
||||||
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
|
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
|
||||||
|
|
||||||
def test_normalize_comma_between_range_expressions(self) -> None:
|
|
||||||
# Comma-separated field range expressions (Whoosh v2 syntax) must be
|
|
||||||
# converted to AND so Tantivy does not receive an invalid comma.
|
|
||||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
assert normalize_query(q) == (
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
" AND "
|
|
||||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_normalize_expands_three_values(self) -> None:
|
def test_normalize_expands_three_values(self) -> None:
|
||||||
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
|
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
|
||||||
|
|
||||||
|
|||||||
@@ -1,742 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import UTC
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import time_machine
|
|
||||||
|
|
||||||
from documents.search._dates import _precision_bounds
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import tantivy
|
|
||||||
from documents.search._query import _FIELD_BOOSTS
|
|
||||||
from documents.search._query import DEFAULT_SEARCH_FIELDS
|
|
||||||
from documents.search._translate import OPEN_HI
|
|
||||||
from documents.search._translate import OPEN_LO
|
|
||||||
from documents.search._translate import Comma
|
|
||||||
from documents.search._translate import FieldRange
|
|
||||||
from documents.search._translate import FieldValue
|
|
||||||
from documents.search._translate import FieldValueList
|
|
||||||
from documents.search._translate import InvalidDateQuery
|
|
||||||
from documents.search._translate import Passthrough
|
|
||||||
from documents.search._translate import resolve_commas
|
|
||||||
from documents.search._translate import scan
|
|
||||||
from documents.search._translate import translate_query
|
|
||||||
from documents.search._translate import translate_range
|
|
||||||
from documents.search._translate import translate_scalar
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestPrecisionBounds:
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("digits", "expected"),
|
|
||||||
[
|
|
||||||
("2020", ((2020, 1, 1), (2021, 1, 1))),
|
|
||||||
("202003", ((2020, 3, 1), (2020, 4, 1))),
|
|
||||||
("202012", ((2020, 12, 1), (2021, 1, 1))),
|
|
||||||
("20200115", ((2020, 1, 15), (2020, 1, 16))),
|
|
||||||
("20201231", ((2020, 12, 31), (2021, 1, 1))),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_valid(self, digits, expected):
|
|
||||||
lo, hi = _precision_bounds(digits)
|
|
||||||
assert (lo.year, lo.month, lo.day) == expected[0]
|
|
||||||
assert (hi.year, hi.month, hi.day) == expected[1]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
|
|
||||||
def test_invalid_returns_none(self, digits):
|
|
||||||
assert _precision_bounds(digits) is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestScan:
|
|
||||||
def test_plain_words_are_passthrough(self):
|
|
||||||
assert scan("bank statement") == [Passthrough("bank statement")]
|
|
||||||
|
|
||||||
def test_field_value(self):
|
|
||||||
assert scan("created:2020") == [FieldValue("created", "2020")]
|
|
||||||
|
|
||||||
def test_field_value_in_boolean(self):
|
|
||||||
toks = scan("created:2020 OR foo")
|
|
||||||
assert toks == [
|
|
||||||
FieldValue("created", "2020"),
|
|
||||||
Passthrough(" OR foo"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_field_value_in_parens(self):
|
|
||||||
toks = scan("(created:2020 OR foo)")
|
|
||||||
assert toks == [
|
|
||||||
Passthrough("("),
|
|
||||||
FieldValue("created", "2020"),
|
|
||||||
Passthrough(" OR foo)"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_quoted_value(self):
|
|
||||||
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
|
|
||||||
|
|
||||||
def test_field_range(self):
|
|
||||||
assert scan("created:[2020 TO 2021]") == [
|
|
||||||
FieldRange("created", "[", "2020", "2021", "]"),
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("query", "expected"),
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
"created:[2020 to]",
|
|
||||||
FieldRange("created", "[", "2020", "", "]"),
|
|
||||||
id="open_upper",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"created:[to 2020]",
|
|
||||||
FieldRange("created", "[", "", "2020", "]"),
|
|
||||||
id="open_lower",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_open_range(self, query, expected):
|
|
||||||
assert scan(query) == [expected]
|
|
||||||
|
|
||||||
def test_comma_inside_range_not_split(self):
|
|
||||||
# No depth-0 comma here; the whole thing is one range token.
|
|
||||||
toks = scan("created:[2020 TO 2021]")
|
|
||||||
assert len(toks) == 1
|
|
||||||
|
|
||||||
# --- Edge-case / regression tests (scan must never raise) ---
|
|
||||||
|
|
||||||
def test_url_is_passthrough(self):
|
|
||||||
# "http" is not a known field; the whole URL must pass through verbatim.
|
|
||||||
assert scan("http://example.com") == [Passthrough("http://example.com")]
|
|
||||||
|
|
||||||
def test_unterminated_quote_is_passthrough(self):
|
|
||||||
# title is a known field but the quoted value has no closing quote;
|
|
||||||
# _consume_value returns None so the whole string falls into passthrough.
|
|
||||||
assert scan('title:"abc') == [Passthrough('title:"abc')]
|
|
||||||
|
|
||||||
def test_unterminated_bracket_is_passthrough(self):
|
|
||||||
# created is a known field but the range bracket is never closed;
|
|
||||||
# _consume_range returns None so the whole string falls into passthrough.
|
|
||||||
assert scan("created:[2020") == [Passthrough("created:[2020")]
|
|
||||||
|
|
||||||
def test_empty_value_at_end_is_passthrough(self):
|
|
||||||
# created is a known field but there is no value after the colon
|
|
||||||
# (_consume_value returns None for start >= n), so passthrough.
|
|
||||||
assert scan("created:") == [Passthrough("created:")]
|
|
||||||
|
|
||||||
def test_value_containing_colon(self):
|
|
||||||
# The bare-word value reader stops at whitespace/paren, not at colon,
|
|
||||||
# so "2020:30" is consumed as a single value token.
|
|
||||||
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
|
|
||||||
|
|
||||||
def test_comma_followed_by_unconsumable_value_stops(self):
|
|
||||||
# A comma followed by whitespace is neither a value-list continuation nor a
|
|
||||||
# clause separator: the value stops and the comma stays as passthrough.
|
|
||||||
assert scan("tag:foo, bar") == [
|
|
||||||
FieldValue("tag", "foo"),
|
|
||||||
Passthrough(", bar"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_bracket_without_to_is_open_upper_bound(self):
|
|
||||||
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
|
|
||||||
assert scan("created:[2020]") == [
|
|
||||||
FieldRange("created", "[", "2020", "", "]"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_known_field_name_midword_is_passthrough(self):
|
|
||||||
# A known field name embedded mid-word is not a field token (the
|
|
||||||
# word-boundary guard); the whole run stays passthrough.
|
|
||||||
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestCommaResolution:
|
|
||||||
def test_value_list_multi_value_field(self):
|
|
||||||
toks = resolve_commas(scan("tag:foo,bar"))
|
|
||||||
assert toks == [FieldValueList("tag", ("foo", "bar"))]
|
|
||||||
|
|
||||||
def test_value_list_three(self):
|
|
||||||
toks = resolve_commas(scan("tag_id:1,2,3"))
|
|
||||||
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
|
|
||||||
|
|
||||||
def test_text_field_comma_is_literal(self):
|
|
||||||
# correspondent is not multi-value: comma stays inside the value.
|
|
||||||
toks = resolve_commas(scan("correspondent:foo,bar"))
|
|
||||||
assert toks == [FieldValue("correspondent", "foo,bar")]
|
|
||||||
|
|
||||||
def test_clause_separator_before_known_field(self):
|
|
||||||
toks = resolve_commas(scan("tag:foo,type:bar"))
|
|
||||||
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
|
|
||||||
|
|
||||||
def test_clause_separator_after_range(self):
|
|
||||||
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
|
|
||||||
assert toks == [
|
|
||||||
FieldRange("created", "[", "2020", "2021", "]"),
|
|
||||||
Comma(),
|
|
||||||
FieldRange("added", "[", "2022", "2023", "]"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_clause_separator_after_quote(self):
|
|
||||||
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
|
|
||||||
assert toks == [
|
|
||||||
FieldValue("correspondent", '"A B"'),
|
|
||||||
Comma(),
|
|
||||||
FieldRange("created", "[", "2020", "2021", "]"),
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_url_comma_is_literal_passthrough(self):
|
|
||||||
toks = resolve_commas(scan("http://example.com/a,b"))
|
|
||||||
assert toks == [Passthrough("http://example.com/a,b")]
|
|
||||||
|
|
||||||
def test_non_multi_value_comma_is_literal(self):
|
|
||||||
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
|
|
||||||
toks = resolve_commas(scan("title:10,20"))
|
|
||||||
assert toks == [FieldValue("title", "10,20")]
|
|
||||||
|
|
||||||
def test_clause_separator_before_known_date_field(self):
|
|
||||||
# The comma between a bare value and a known date field acts as a
|
|
||||||
# clause separator; both sides survive as distinct tokens.
|
|
||||||
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
|
|
||||||
assert toks == [
|
|
||||||
FieldValue("correspondent", "foo"),
|
|
||||||
Comma(),
|
|
||||||
FieldRange("created", "[", "2020", "2021", "]"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestTranslateScalar:
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("field", "value", "expected"),
|
|
||||||
[
|
|
||||||
(
|
|
||||||
"created",
|
|
||||||
"2020",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"created",
|
|
||||||
"202003",
|
|
||||||
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"created",
|
|
||||||
"20200115",
|
|
||||||
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"created",
|
|
||||||
"2020-01-15",
|
|
||||||
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"created",
|
|
||||||
"2020-03",
|
|
||||||
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
|
|
||||||
assert translate_scalar(field, value, UTC) == expected
|
|
||||||
|
|
||||||
def test_invalid_date_raises(self) -> None:
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_scalar("created", "202023", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "202023"
|
|
||||||
|
|
||||||
def test_keyword_delegates(self) -> None:
|
|
||||||
# keyword path produces a range; just assert it is a created range
|
|
||||||
out = translate_scalar("created", "today", UTC)
|
|
||||||
assert out.startswith("created:[") and out.endswith("]")
|
|
||||||
|
|
||||||
def test_14digit_compact_datetime(self) -> None:
|
|
||||||
out = translate_scalar("created", "20240115120000", UTC)
|
|
||||||
assert "20240115120000" not in out
|
|
||||||
assert out.startswith("created:")
|
|
||||||
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
|
|
||||||
|
|
||||||
def test_14digit_invalid_month_raises(self) -> None:
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_scalar("created", "20231300120000", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "20231300120000"
|
|
||||||
|
|
||||||
def test_unrecognized_value_raises(self) -> None:
|
|
||||||
# A value that is not a keyword, digits, ISO date, or compact timestamp
|
|
||||||
# raises rather than producing invalid Tantivy syntax or silently matching
|
|
||||||
# nothing.
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_scalar("created", "garbage", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "garbage"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestTranslateRange:
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("lo", "hi", "expected"),
|
|
||||||
[
|
|
||||||
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
|
|
||||||
(
|
|
||||||
"202001",
|
|
||||||
"202006",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"20200101",
|
|
||||||
"20201231",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"2020-01-01",
|
|
||||||
"2020-12-31",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_absolute_ranges(self, lo, hi, expected):
|
|
||||||
assert translate_range("created", lo, hi, UTC) == expected
|
|
||||||
|
|
||||||
def test_reversed_swaps(self):
|
|
||||||
assert translate_range("created", "2009", "2005", UTC) == (
|
|
||||||
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_open_upper(self):
|
|
||||||
out = translate_range("created", "2020", "", UTC)
|
|
||||||
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
|
|
||||||
|
|
||||||
def test_open_lower(self):
|
|
||||||
out = translate_range("created", "", "2020", UTC)
|
|
||||||
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
|
|
||||||
|
|
||||||
def test_invalid_bound_raises(self):
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_range("created", "202023", "2025", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "202023"
|
|
||||||
|
|
||||||
def test_invalid_high_bound_raises(self):
|
|
||||||
# Low bound parses, high bound does not -> raise on the high bound.
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_range("created", "2020", "garbage", UTC)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "garbage"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestTranslateQuery:
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("raw", "expected"),
|
|
||||||
[
|
|
||||||
(
|
|
||||||
"created:2020",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
("tag:foo,bar", "tag:foo AND tag:bar"),
|
|
||||||
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
|
|
||||||
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
|
|
||||||
(
|
|
||||||
"created:[2020 TO 2021],added:[2022 TO 2023]",
|
|
||||||
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
|
|
||||||
" AND "
|
|
||||||
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
|
|
||||||
),
|
|
||||||
# correspondent is not multi-value: comma stays literal inside the value
|
|
||||||
("correspondent:foo,bar", "correspondent:foo,bar"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_golden(self, raw: str, expected: str) -> None:
|
|
||||||
assert translate_query(raw, UTC) == expected
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"raw",
|
|
||||||
[
|
|
||||||
"created:2020",
|
|
||||||
"created:202003",
|
|
||||||
"created:[20200101 TO 20201231]",
|
|
||||||
"created:[2020-01-01 TO 2020-12-31]",
|
|
||||||
"created:[2020 to]",
|
|
||||||
"created:[to 2020]",
|
|
||||||
"title:x,created:[2020 TO 2021]",
|
|
||||||
"created:2020 OR foo",
|
|
||||||
"(created:2020 OR invoice)",
|
|
||||||
"tag:foo,type:bar",
|
|
||||||
"bank statement",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
|
|
||||||
translated = translate_query(raw, UTC)
|
|
||||||
# Must not raise:
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestFieldAliasing:
|
|
||||||
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
|
|
||||||
|
|
||||||
def test_type_alias(self) -> None:
|
|
||||||
assert translate_query("type:invoice", UTC) == "document_type:invoice"
|
|
||||||
|
|
||||||
def test_path_alias(self) -> None:
|
|
||||||
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
|
|
||||||
|
|
||||||
def test_type_id_alias(self) -> None:
|
|
||||||
assert translate_query("type_id:5", UTC) == "document_type_id:5"
|
|
||||||
|
|
||||||
def test_path_id_alias(self) -> None:
|
|
||||||
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
|
|
||||||
|
|
||||||
def test_clause_separator_plus_alias(self) -> None:
|
|
||||||
# Comma between known fields acts as AND separator; alias still applied.
|
|
||||||
assert (
|
|
||||||
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_type_range_alias(self) -> None:
|
|
||||||
# type is not a date field; range passes through verbatim with alias applied.
|
|
||||||
assert (
|
|
||||||
translate_query("type:[2020 TO 2021]", UTC)
|
|
||||||
== "document_type:[2020 TO 2021]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
|
|
||||||
# Translated output must be accepted by the real Tantivy parser.
|
|
||||||
translated = translate_query("type:invoice", UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
|
|
||||||
translated = translate_query("path:foo", UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
|
|
||||||
# Freeze time so relative-date tests are deterministic.
|
|
||||||
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestRelativeRanges:
|
|
||||||
"""Relative date-range tokens resolved against a frozen clock."""
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_minus_7_days_to_now(self) -> None:
|
|
||||||
assert translate_query("added:[-7 days to now]", UTC) == (
|
|
||||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_minus_1_week_to_now(self) -> None:
|
|
||||||
assert translate_query("added:[-1 week to now]", UTC) == (
|
|
||||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_minus_1_month_to_now(self) -> None:
|
|
||||||
assert translate_query("created:[-1 month to now]", UTC) == (
|
|
||||||
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_minus_1_year_to_now(self) -> None:
|
|
||||||
assert translate_query("modified:[-1 year to now]", UTC) == (
|
|
||||||
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_minus_3_hours_to_now(self) -> None:
|
|
||||||
assert translate_query("added:[-3 hours to now]", UTC) == (
|
|
||||||
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_uppercase_units(self) -> None:
|
|
||||||
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
|
|
||||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_now_minus_7d_compact(self) -> None:
|
|
||||||
assert translate_query("added:[now-7d TO now]", UTC) == (
|
|
||||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_reversed_range_swapped(self) -> None:
|
|
||||||
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
|
|
||||||
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
|
|
||||||
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"raw",
|
|
||||||
[
|
|
||||||
"added:[-7 days to now]",
|
|
||||||
"added:[-1 week to now]",
|
|
||||||
"created:[-1 month to now]",
|
|
||||||
"modified:[-1 year to now]",
|
|
||||||
"added:[-3 hours to now]",
|
|
||||||
"added:[now-7d TO now]",
|
|
||||||
"added:[now+1h TO now-1h]",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
|
|
||||||
translated = translate_query(raw, UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestOperatorNormalization:
|
|
||||||
"""Post-render operator normalization in translate_query."""
|
|
||||||
|
|
||||||
def test_spaced_dash_removed(self) -> None:
|
|
||||||
assert (
|
|
||||||
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_spaced_dash_simple(self) -> None:
|
|
||||||
assert translate_query("bar - baz", UTC) == "bar baz"
|
|
||||||
|
|
||||||
def test_trailing_operator_stripped(self) -> None:
|
|
||||||
assert translate_query("foo -", UTC) == "foo"
|
|
||||||
|
|
||||||
def test_date_range_preserved(self) -> None:
|
|
||||||
out = translate_query("created:[2020 TO 2021]", UTC)
|
|
||||||
# Must not corrupt the ISO range
|
|
||||||
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
|
|
||||||
|
|
||||||
def test_date_scalar_with_or(self) -> None:
|
|
||||||
out = translate_query("created:2020 OR foo", UTC)
|
|
||||||
# The created scalar becomes a range; " OR foo" passes through verbatim.
|
|
||||||
assert out.startswith("created:[")
|
|
||||||
assert "OR foo" in out
|
|
||||||
|
|
||||||
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
|
|
||||||
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
|
|
||||||
translated = translate_query("foo -", UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestMultiWordDateKeywords:
|
|
||||||
"""scan() must consume multi-word date keywords as a single value."""
|
|
||||||
|
|
||||||
def test_scan_previous_week_as_single_token(self) -> None:
|
|
||||||
# "created:previous week" must produce one FieldValue with value "previous week",
|
|
||||||
# not FieldValue("created","previous") + Passthrough(" week").
|
|
||||||
toks = scan("created:previous week")
|
|
||||||
assert toks == [FieldValue("created", "previous week")]
|
|
||||||
|
|
||||||
def test_scan_this_month_as_single_token(self) -> None:
|
|
||||||
toks = scan("added:this month")
|
|
||||||
assert toks == [FieldValue("added", "this month")]
|
|
||||||
|
|
||||||
def test_scan_previous_month_as_single_token(self) -> None:
|
|
||||||
toks = scan("created:previous month")
|
|
||||||
assert toks == [FieldValue("created", "previous month")]
|
|
||||||
|
|
||||||
def test_scan_this_year_as_single_token(self) -> None:
|
|
||||||
toks = scan("added:this year")
|
|
||||||
assert toks == [FieldValue("added", "this year")]
|
|
||||||
|
|
||||||
def test_scan_previous_year_as_single_token(self) -> None:
|
|
||||||
toks = scan("created:previous year")
|
|
||||||
assert toks == [FieldValue("created", "previous year")]
|
|
||||||
|
|
||||||
def test_scan_previous_quarter_as_single_token(self) -> None:
|
|
||||||
toks = scan("created:previous quarter")
|
|
||||||
assert toks == [FieldValue("created", "previous quarter")]
|
|
||||||
|
|
||||||
def test_quoted_multi_word_keyword_still_works(self) -> None:
|
|
||||||
# The quoted form must continue to work as before.
|
|
||||||
toks = scan('created:"previous week"')
|
|
||||||
assert toks == [FieldValue("created", '"previous week"')]
|
|
||||||
|
|
||||||
def test_non_date_field_not_affected(self) -> None:
|
|
||||||
# "previous" stops at the space for non-date fields; " week" passes through.
|
|
||||||
toks = scan("correspondent:previous week")
|
|
||||||
assert toks == [
|
|
||||||
FieldValue("correspondent", "previous"),
|
|
||||||
Passthrough(" week"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestKeywordDateResolution:
|
|
||||||
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
|
|
||||||
|
|
||||||
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
|
|
||||||
quarter and year rollovers are all exercised by a single anchor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# created is a DateField: bounds are UTC midnight, no timezone offset.
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("keyword", "expected"),
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
"today",
|
|
||||||
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
|
|
||||||
id="today",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"yesterday",
|
|
||||||
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
|
|
||||||
id="yesterday",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous week",
|
|
||||||
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
|
|
||||||
id="previous-week",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"this month",
|
|
||||||
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
|
|
||||||
id="this-month",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous month",
|
|
||||||
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
|
|
||||||
id="previous-month",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"this year",
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
|
|
||||||
id="this-year",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous year",
|
|
||||||
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
|
|
||||||
id="previous-year",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous quarter",
|
|
||||||
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
|
|
||||||
id="previous-quarter",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_date_only_field_keyword_ranges(
|
|
||||||
self,
|
|
||||||
keyword: str,
|
|
||||||
expected: str,
|
|
||||||
) -> None:
|
|
||||||
assert translate_query(f"created:{keyword}", UTC) == expected
|
|
||||||
|
|
||||||
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
|
|
||||||
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
|
|
||||||
# before, so this also exercises the local-midnight offset path.
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("keyword", "expected"),
|
|
||||||
[
|
|
||||||
pytest.param(
|
|
||||||
"today",
|
|
||||||
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
|
|
||||||
id="today",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"yesterday",
|
|
||||||
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
|
|
||||||
id="yesterday",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous week",
|
|
||||||
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
|
|
||||||
id="previous-week",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"this month",
|
|
||||||
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
|
|
||||||
id="this-month",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous month",
|
|
||||||
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
|
|
||||||
id="previous-month",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"this year",
|
|
||||||
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
|
|
||||||
id="this-year",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous year",
|
|
||||||
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
|
|
||||||
id="previous-year",
|
|
||||||
),
|
|
||||||
pytest.param(
|
|
||||||
"previous quarter",
|
|
||||||
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
|
|
||||||
id="previous-quarter",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
|
||||||
def test_datetime_field_keyword_ranges_local_tz(
|
|
||||||
self,
|
|
||||||
keyword: str,
|
|
||||||
expected: str,
|
|
||||||
) -> None:
|
|
||||||
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.search
|
|
||||||
class TestISODatetimeBounds:
|
|
||||||
"""Full ISO datetime tokens in range bounds must be parsed directly."""
|
|
||||||
|
|
||||||
def test_translate_range_iso_bounds_passthrough(self) -> None:
|
|
||||||
# Already-ISO datetime bounds must pass through as-is (exact instant).
|
|
||||||
result = translate_range(
|
|
||||||
"created",
|
|
||||||
"2020-01-01T00:00:00Z",
|
|
||||||
"2021-01-01T00:00:00Z",
|
|
||||||
UTC,
|
|
||||||
)
|
|
||||||
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
|
|
||||||
|
|
||||||
def test_translate_query_iso_range_preserved(self) -> None:
|
|
||||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
assert translate_query(q, UTC) == q
|
|
||||||
|
|
||||||
def test_translate_query_comma_separated_iso_ranges(self) -> None:
|
|
||||||
q = (
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
|
||||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
)
|
|
||||||
result = translate_query(q, UTC)
|
|
||||||
assert result == (
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
" AND "
|
|
||||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_invalid_iso_datetime_raises(self) -> None:
|
|
||||||
# A token with "T" that is not valid ISO datetime -> raise.
|
|
||||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
|
||||||
translate_range(
|
|
||||||
"created",
|
|
||||||
"2020-01-01T99:00:00Z",
|
|
||||||
"2021-01-01T00:00:00Z",
|
|
||||||
UTC,
|
|
||||||
)
|
|
||||||
assert exc_info.value.field == "created"
|
|
||||||
assert exc_info.value.value == "2020-01-01T99:00:00Z"
|
|
||||||
|
|
||||||
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
|
|
||||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
translated = translate_query(q, UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
|
|
||||||
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
|
|
||||||
q = (
|
|
||||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
|
||||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
|
||||||
)
|
|
||||||
translated = translate_query(q, UTC)
|
|
||||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
|
||||||
@@ -82,7 +82,6 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
"llm_api_key": None,
|
"llm_api_key": None,
|
||||||
"llm_endpoint": None,
|
"llm_endpoint": None,
|
||||||
"llm_output_language": None,
|
"llm_output_language": None,
|
||||||
"llm_request_timeout": None,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -845,7 +844,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||||
):
|
):
|
||||||
mock_exists.return_value = False
|
mock_exists.return_value = False
|
||||||
self.client.patch(
|
self.client.patch(
|
||||||
@@ -870,7 +869,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||||
):
|
):
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
self.client.patch(
|
self.client.patch(
|
||||||
@@ -891,7 +890,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||||
):
|
):
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
self.client.patch(
|
self.client.patch(
|
||||||
@@ -929,7 +928,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||||
):
|
):
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
self.client.patch(
|
self.client.patch(
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import zipfile
|
|||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.utils import timezone
|
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.test import APITestCase
|
from rest_framework.test import APITestCase
|
||||||
|
|
||||||
@@ -33,21 +32,21 @@ class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase):
|
|||||||
filename="docA.pdf",
|
filename="docA.pdf",
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
checksum="B",
|
checksum="B",
|
||||||
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
|
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
self.doc2b = Document.objects.create(
|
self.doc2b = Document.objects.create(
|
||||||
title="document A",
|
title="document A",
|
||||||
filename="docA2.pdf",
|
filename="docA2.pdf",
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
checksum="D",
|
checksum="D",
|
||||||
created=timezone.make_aware(datetime.datetime(2021, 1, 1)),
|
created=datetime.datetime(2021, 1, 1, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
self.doc3 = Document.objects.create(
|
self.doc3 = Document.objects.create(
|
||||||
title="document B",
|
title="document B",
|
||||||
filename="docB.jpg",
|
filename="docB.jpg",
|
||||||
mime_type="image/jpeg",
|
mime_type="image/jpeg",
|
||||||
checksum="C",
|
checksum="C",
|
||||||
created=timezone.make_aware(datetime.datetime(2020, 3, 21)),
|
created=datetime.datetime(2020, 3, 21, tzinfo=datetime.UTC),
|
||||||
archive_filename="docB.pdf",
|
archive_filename="docB.pdf",
|
||||||
archive_checksum="D",
|
archive_checksum="D",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
from datetime import date
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import ANY
|
from unittest.mock import ANY
|
||||||
|
|
||||||
@@ -456,7 +456,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
date_value = date.today()
|
date_value = datetime.datetime.now(tz=datetime.UTC).date()
|
||||||
|
|
||||||
resp = self.client.patch(
|
resp = self.client.patch(
|
||||||
f"/api/documents/{doc.id}/",
|
f"/api/documents/{doc.id}/",
|
||||||
@@ -618,7 +618,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase):
|
|||||||
data_type=CustomField.FieldDataType.DATE,
|
data_type=CustomField.FieldDataType.DATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
date_value = date.today()
|
date_value = datetime.datetime.now(tz=datetime.UTC).date()
|
||||||
|
|
||||||
resp = self.client.patch(
|
resp = self.client.patch(
|
||||||
f"/api/documents/{doc.id}/",
|
f"/api/documents/{doc.id}/",
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
|
|||||||
created=date(2023, 1, 1),
|
created=date(2023, 1, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0)
|
created_datetime = datetime.datetime(2023, 2, 1, 12, 0, 0, tzinfo=datetime.UTC)
|
||||||
response = self.client.patch(
|
response = self.client.patch(
|
||||||
f"/api/documents/{doc.pk}/",
|
f"/api/documents/{doc.pk}/",
|
||||||
{"created": created_datetime},
|
{"created": created_datetime},
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
import unicodedata
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import celery.result
|
|
||||||
import pytest
|
|
||||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from documents.data_models import ConsumableDocument
|
|
||||||
from documents.data_models import DocumentMetadataOverrides
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def consume_file_mock():
|
|
||||||
with mock.patch("documents.tasks.consume_file.apply_async") as m:
|
|
||||||
m.return_value = celery.result.AsyncResult(id="test-task-id")
|
|
||||||
yield m
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def directories(tmp_path, settings, _media_settings):
|
|
||||||
scratch = tmp_path / "scratch"
|
|
||||||
scratch.mkdir()
|
|
||||||
settings.SCRATCH_DIR = scratch
|
|
||||||
return scratch
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
|
||||||
class TestPostDocumentNFCNormalization:
|
|
||||||
def test_nfd_filename_normalized_to_nfc(
|
|
||||||
self,
|
|
||||||
admin_client,
|
|
||||||
consume_file_mock: mock.MagicMock,
|
|
||||||
directories,
|
|
||||||
):
|
|
||||||
"""Uploaded file with NFD filename must have its name stored as NFC."""
|
|
||||||
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
|
|
||||||
|
|
||||||
# Verify our test strings actually differ at the byte level
|
|
||||||
assert nfd != nfc
|
|
||||||
|
|
||||||
uploaded = SimpleUploadedFile(
|
|
||||||
nfd,
|
|
||||||
b"%PDF-1.4 test",
|
|
||||||
content_type="application/pdf",
|
|
||||||
)
|
|
||||||
response = admin_client.post(
|
|
||||||
"/api/documents/post_document/",
|
|
||||||
{"document": uploaded},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
|
||||||
input_doc: ConsumableDocument = task_kwargs["input_doc"]
|
|
||||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
|
||||||
|
|
||||||
# The temp file on disk must have an NFC name
|
|
||||||
assert input_doc.original_file.name == nfc, (
|
|
||||||
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
|
|
||||||
)
|
|
||||||
# The override filename stored for later use must also be NFC
|
|
||||||
assert overrides.filename == nfc, (
|
|
||||||
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
|
|
||||||
)
|
|
||||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
|
||||||
|
|
||||||
def test_already_nfc_filename_unchanged(
|
|
||||||
self,
|
|
||||||
admin_client,
|
|
||||||
consume_file_mock: mock.MagicMock,
|
|
||||||
directories,
|
|
||||||
):
|
|
||||||
"""Uploaded file with already-NFC filename must pass through unchanged."""
|
|
||||||
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
|
|
||||||
|
|
||||||
uploaded = SimpleUploadedFile(
|
|
||||||
nfc,
|
|
||||||
b"%PDF-1.4 test",
|
|
||||||
content_type="application/pdf",
|
|
||||||
)
|
|
||||||
response = admin_client.post(
|
|
||||||
"/api/documents/post_document/",
|
|
||||||
{"document": uploaded},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
|
||||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
|
||||||
|
|
||||||
assert overrides.filename == nfc
|
|
||||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
|
||||||
@@ -700,7 +700,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
pk=3,
|
pk=3,
|
||||||
checksum="C",
|
checksum="C",
|
||||||
# specific time zone aware date
|
# specific time zone aware date
|
||||||
added=timezone.make_aware(datetime.datetime(2023, 12, 1)),
|
added=datetime.datetime(2023, 12, 1, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
# refresh doc instance to ensure we operate on date objects that Django uses
|
# refresh doc instance to ensure we operate on date objects that Django uses
|
||||||
# Django converts dates to UTC
|
# Django converts dates to UTC
|
||||||
@@ -725,11 +725,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
GIVEN:
|
GIVEN:
|
||||||
- One document added right now
|
- One document added right now
|
||||||
WHEN:
|
WHEN:
|
||||||
- Query with an invalid added date
|
- Query with invalid added date
|
||||||
THEN:
|
THEN:
|
||||||
- 400 Bad Request with a message naming the malformed date, so the
|
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
|
||||||
user knows their date is invalid rather than silently getting zero
|
|
||||||
results
|
|
||||||
"""
|
"""
|
||||||
d1 = Document.objects.create(
|
d1 = Document.objects.create(
|
||||||
title="invoice",
|
title="invoice",
|
||||||
@@ -742,9 +740,8 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
|
|
||||||
response = self.client.get("/api/documents/?query=added:invalid-date")
|
response = self.client.get("/api/documents/?query=added:invalid-date")
|
||||||
|
|
||||||
# An unparsable date is reported as a malformed query, not silently empty.
|
# Tantivy rejects unparsable field queries with a 400
|
||||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
self.assertIn("invalid-date", str(response.data["query"]))
|
|
||||||
|
|
||||||
@override_settings(
|
@override_settings(
|
||||||
TIME_ZONE="UTC",
|
TIME_ZONE="UTC",
|
||||||
@@ -997,25 +994,25 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
title="invoice",
|
title="invoice",
|
||||||
content="the thing i bought at a shop and paid with bank account",
|
content="the thing i bought at a shop and paid with bank account",
|
||||||
created=datetime.date(2018, 1, 1),
|
created=datetime.date(2018, 1, 1),
|
||||||
added=timezone.make_aware(datetime.datetime(2018, 1, 1)),
|
added=datetime.datetime(2018, 1, 1, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
d2 = DocumentFactory(
|
d2 = DocumentFactory(
|
||||||
title="bank statement 1",
|
title="bank statement 1",
|
||||||
content="things i paid for in august",
|
content="things i paid for in august",
|
||||||
created=datetime.date(2019, 3, 4),
|
created=datetime.date(2019, 3, 4),
|
||||||
added=timezone.make_aware(datetime.datetime(2019, 3, 4)),
|
added=datetime.datetime(2019, 3, 4, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
d3 = DocumentFactory(
|
d3 = DocumentFactory(
|
||||||
title="bank statement 3",
|
title="bank statement 3",
|
||||||
content="things i paid for in september",
|
content="things i paid for in september",
|
||||||
created=datetime.date(2020, 7, 9),
|
created=datetime.date(2020, 7, 9),
|
||||||
added=timezone.make_aware(datetime.datetime(2020, 7, 9)),
|
added=datetime.datetime(2020, 7, 9, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
d4 = DocumentFactory(
|
d4 = DocumentFactory(
|
||||||
title="Quarterly Report",
|
title="Quarterly Report",
|
||||||
content="quarterly revenue profit margin earnings growth",
|
content="quarterly revenue profit margin earnings growth",
|
||||||
created=datetime.date(2021, 11, 30),
|
created=datetime.date(2021, 11, 30),
|
||||||
added=timezone.make_aware(datetime.datetime(2021, 11, 30)),
|
added=datetime.datetime(2021, 11, 30, tzinfo=datetime.UTC),
|
||||||
)
|
)
|
||||||
backend = get_backend()
|
backend = get_backend()
|
||||||
backend.add_or_update(d1)
|
backend.add_or_update(d1)
|
||||||
@@ -1134,7 +1131,7 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d4.tags.add(t2)
|
d4.tags.add(t2)
|
||||||
d5 = Document.objects.create(
|
d5 = Document.objects.create(
|
||||||
checksum="5",
|
checksum="5",
|
||||||
added=timezone.make_aware(datetime.datetime(2020, 7, 13)),
|
added=datetime.datetime(2020, 7, 13, tzinfo=datetime.UTC),
|
||||||
content="test",
|
content="test",
|
||||||
original_filename="doc5.pdf",
|
original_filename="doc5.pdf",
|
||||||
)
|
)
|
||||||
@@ -1244,14 +1241,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d4.id,
|
d4.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&created__date__lt="
|
"&created__date__lt="
|
||||||
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertNotIn(
|
self.assertNotIn(
|
||||||
d4.id,
|
d4.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&created__date__gt="
|
"&created__date__gt="
|
||||||
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1259,14 +1260,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d4.id,
|
d4.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&created__date__lt="
|
"&created__date__lt="
|
||||||
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
d4.id,
|
d4.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&created__date__gt="
|
"&created__date__gt="
|
||||||
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1274,14 +1279,18 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d5.id,
|
d5.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&added__date__lt="
|
"&added__date__lt="
|
||||||
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.assertNotIn(
|
self.assertNotIn(
|
||||||
d5.id,
|
d5.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&added__date__gt="
|
"&added__date__gt="
|
||||||
+ datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 9, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1289,7 +1298,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d5.id,
|
d5.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&added__date__lt="
|
"&added__date__lt="
|
||||||
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1297,7 +1308,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
|||||||
d5.id,
|
d5.id,
|
||||||
search_query(
|
search_query(
|
||||||
"&added__date__gt="
|
"&added__date__gt="
|
||||||
+ datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d"),
|
+ datetime.datetime(2020, 1, 2, tzinfo=datetime.UTC).strftime(
|
||||||
|
"%Y-%m-%d",
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -216,77 +216,6 @@ class TestSystemStatus(APITestCase):
|
|||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
||||||
|
|
||||||
@mock.patch("celery.app.control.Inspect.ping")
|
|
||||||
def test_system_status_celery_ping_none(self, mock_ping) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Celery ping returns no worker responses
|
|
||||||
WHEN:
|
|
||||||
- The user requests the system status
|
|
||||||
THEN:
|
|
||||||
- The response contains a warning celery status
|
|
||||||
"""
|
|
||||||
mock_ping.return_value = None
|
|
||||||
self.client.force_login(self.user)
|
|
||||||
response = self.client.get(self.ENDPOINT)
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
|
|
||||||
self.assertEqual(
|
|
||||||
response.data["tasks"]["celery_error"],
|
|
||||||
"No celery workers responded to ping. This may be temporary.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@mock.patch("celery.app.control.Inspect.ping")
|
|
||||||
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Celery ping returns an unexpected worker response
|
|
||||||
WHEN:
|
|
||||||
- The user requests the system status
|
|
||||||
THEN:
|
|
||||||
- The response contains a warning celery status
|
|
||||||
"""
|
|
||||||
self.client.force_login(self.user)
|
|
||||||
for ping_response in (
|
|
||||||
{"hostname": {"ok": "not-pong"}},
|
|
||||||
{"hostname": {}},
|
|
||||||
{"hostname": "pong"},
|
|
||||||
):
|
|
||||||
with self.subTest(ping_response=ping_response):
|
|
||||||
mock_ping.return_value = ping_response
|
|
||||||
response = self.client.get(self.ENDPOINT)
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
|
|
||||||
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
|
|
||||||
self.assertEqual(
|
|
||||||
response.data["tasks"]["celery_error"],
|
|
||||||
"Celery worker responded unexpectedly.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@mock.patch("documents.views.sleep")
|
|
||||||
@mock.patch("celery.app.control.Inspect.ping")
|
|
||||||
def test_system_status_celery_ping_retry_success(
|
|
||||||
self,
|
|
||||||
mock_ping,
|
|
||||||
mock_sleep,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Celery ping fails once but succeeds on retry
|
|
||||||
WHEN:
|
|
||||||
- The user requests the system status
|
|
||||||
THEN:
|
|
||||||
- The response contains an OK celery status
|
|
||||||
"""
|
|
||||||
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
|
|
||||||
self.client.force_login(self.user)
|
|
||||||
response = self.client.get(self.ENDPOINT)
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
|
||||||
self.assertIsNone(response.data["tasks"]["celery_error"])
|
|
||||||
self.assertEqual(mock_ping.call_count, 2)
|
|
||||||
mock_sleep.assert_called_once_with(0.25)
|
|
||||||
|
|
||||||
@mock.patch("documents.search.get_backend")
|
@mock.patch("documents.search.get_backend")
|
||||||
def test_system_status_index_ok(self, mock_get_backend) -> None:
|
def test_system_status_index_ok(self, mock_get_backend) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from guardian.shortcuts import assign_perm
|
|||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.test import APIClient
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
from documents.filters import PaperlessTaskFilterSet
|
|
||||||
from documents.models import PaperlessTask
|
from documents.models import PaperlessTask
|
||||||
from documents.tests.factories import DocumentFactory
|
from documents.tests.factories import DocumentFactory
|
||||||
from documents.tests.factories import PaperlessTaskFactory
|
from documents.tests.factories import PaperlessTaskFactory
|
||||||
@@ -170,165 +169,6 @@ class TestGetTasksV10:
|
|||||||
PaperlessTask.Status.STARTED,
|
PaperlessTask.Status.STARTED,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
|
|
||||||
"""?name= searches task filenames, task types, and trigger sources."""
|
|
||||||
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
|
|
||||||
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
|
|
||||||
source_task = PaperlessTaskFactory(
|
|
||||||
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"name": "invoice"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data["count"] == 1
|
|
||||||
assert response.data["results"][0]["task_id"] == filename_task.task_id
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"name": "sanity"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data["count"] == 1
|
|
||||||
assert response.data["results"][0]["task_id"] == type_task.task_id
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"name": "email"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data["count"] == 1
|
|
||||||
assert response.data["results"][0]["task_id"] == source_task.task_id
|
|
||||||
|
|
||||||
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
|
|
||||||
"""?result= searches common structured task result messages."""
|
|
||||||
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
|
|
||||||
error_task = PaperlessTaskFactory(
|
|
||||||
result_data={"error_message": "Duplicate detected"},
|
|
||||||
)
|
|
||||||
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
|
|
||||||
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
|
|
||||||
PaperlessTaskFactory(result_data={"reason": "unrelated"})
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"result": "manual"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data["count"] == 1
|
|
||||||
assert response.data["results"][0]["task_id"] == reason_task.task_id
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
returned_ids = {task["task_id"] for task in response.data["results"]}
|
|
||||||
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
|
|
||||||
|
|
||||||
response = admin_client.get(ENDPOINT, {"result": "321"})
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data["count"] == 1
|
|
||||||
assert response.data["results"][0]["task_id"] == document_task.task_id
|
|
||||||
|
|
||||||
def test_empty_task_name_and_result_filters(self) -> None:
|
|
||||||
"""Empty name/result values leave the queryset unchanged."""
|
|
||||||
PaperlessTaskFactory.create_batch(2)
|
|
||||||
queryset = PaperlessTask.objects.all()
|
|
||||||
filterset = PaperlessTaskFilterSet()
|
|
||||||
|
|
||||||
assert filterset.filter_name(queryset, "name", "").count() == 2
|
|
||||||
assert filterset.filter_result(queryset, "result", "").count() == 2
|
|
||||||
|
|
||||||
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
|
|
||||||
"""status_counts/ returns section counts for the filtered task queryset."""
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.FAILURE,
|
|
||||||
input_data={"filename": "invoice-a.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.REVOKED,
|
|
||||||
input_data={"filename": "invoice-b.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.PENDING,
|
|
||||||
input_data={"filename": "invoice-c.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.STARTED,
|
|
||||||
input_data={"filename": "invoice-d.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.SUCCESS,
|
|
||||||
input_data={"filename": "invoice-e.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=True,
|
|
||||||
status=PaperlessTask.Status.SUCCESS,
|
|
||||||
input_data={"filename": "invoice-acknowledged.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.SUCCESS,
|
|
||||||
input_data={"filename": "unrelated.pdf"},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = admin_client.get(
|
|
||||||
f"{ENDPOINT}status_counts/",
|
|
||||||
{"acknowledged": "false", "name": "invoice"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data == {
|
|
||||||
"all": 5,
|
|
||||||
"needs_attention": 2,
|
|
||||||
"in_progress": 2,
|
|
||||||
"completed": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_status_counts_ignores_section_filters(
|
|
||||||
self,
|
|
||||||
admin_client: APIClient,
|
|
||||||
) -> None:
|
|
||||||
"""status_counts/ ignores status-like filters for the sections it counts."""
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.FAILURE,
|
|
||||||
input_data={"filename": "invoice-a.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.PENDING,
|
|
||||||
input_data={"filename": "invoice-b.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.SUCCESS,
|
|
||||||
input_data={"filename": "invoice-c.pdf"},
|
|
||||||
)
|
|
||||||
PaperlessTaskFactory(
|
|
||||||
acknowledged=False,
|
|
||||||
status=PaperlessTask.Status.FAILURE,
|
|
||||||
input_data={"filename": "unrelated.pdf"},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = admin_client.get(
|
|
||||||
f"{ENDPOINT}status_counts/",
|
|
||||||
{
|
|
||||||
"acknowledged": "false",
|
|
||||||
"name": "invoice",
|
|
||||||
"status": PaperlessTask.Status.FAILURE,
|
|
||||||
"is_complete": "false",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data == {
|
|
||||||
"all": 3,
|
|
||||||
"needs_attention": 1,
|
|
||||||
"in_progress": 1,
|
|
||||||
"completed": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
|
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
|
||||||
"""Tasks are returned in descending date_created order (newest first)."""
|
"""Tasks are returned in descending date_created order (newest first)."""
|
||||||
base = timezone.now()
|
base = timezone.now()
|
||||||
@@ -682,27 +522,6 @@ class TestAcknowledge:
|
|||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.data == {"result": 2}
|
assert response.data == {"result": 2}
|
||||||
|
|
||||||
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
|
|
||||||
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
|
|
||||||
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
|
|
||||||
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
|
|
||||||
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
|
|
||||||
|
|
||||||
response = admin_client.post(
|
|
||||||
ENDPOINT + "acknowledge/",
|
|
||||||
{"all": True},
|
|
||||||
format="json",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
assert response.data == {"result": 2}
|
|
||||||
unacknowledged_task1.refresh_from_db()
|
|
||||||
unacknowledged_task2.refresh_from_db()
|
|
||||||
acknowledged_task.refresh_from_db()
|
|
||||||
assert unacknowledged_task1.acknowledged
|
|
||||||
assert unacknowledged_task2.acknowledged
|
|
||||||
assert acknowledged_task.acknowledged
|
|
||||||
|
|
||||||
def test_acknowledged_tasks_excluded_from_unacked_filter(
|
def test_acknowledged_tasks_excluded_from_unacked_filter(
|
||||||
self,
|
self,
|
||||||
admin_client: APIClient,
|
admin_client: APIClient,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from datetime import date
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pikepdf
|
|
||||||
from django.contrib.auth.models import Group
|
from django.contrib.auth.models import Group
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
@@ -616,18 +615,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
self.img_doc.archive_filename = img_doc_archive
|
self.img_doc.archive_filename = img_doc_archive
|
||||||
self.img_doc.save()
|
self.img_doc.save()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mock_password_required_pdf(
|
|
||||||
mock_open: mock.Mock,
|
|
||||||
fake_pdf: mock.Mock,
|
|
||||||
) -> None:
|
|
||||||
password_context = mock.MagicMock()
|
|
||||||
password_context.__enter__.return_value = fake_pdf
|
|
||||||
mock_open.side_effect = [
|
|
||||||
pikepdf.PasswordError("password required"),
|
|
||||||
password_context,
|
|
||||||
]
|
|
||||||
|
|
||||||
@mock.patch("documents.tasks.consume_file.s")
|
@mock.patch("documents.tasks.consume_file.s")
|
||||||
def test_merge(self, mock_consume_file) -> None:
|
def test_merge(self, mock_consume_file) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -777,7 +764,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
sig.set.return_value.apply_async.side_effect = Exception("boom")
|
sig.set.return_value.apply_async.side_effect = Exception("boom")
|
||||||
mock_consume_file.return_value = sig
|
mock_consume_file.return_value = sig
|
||||||
|
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaisesRegex(Exception, "boom"):
|
||||||
bulk_edit.merge(doc_ids, delete_originals=True)
|
bulk_edit.merge(doc_ids, delete_originals=True)
|
||||||
|
|
||||||
self.doc1.refresh_from_db()
|
self.doc1.refresh_from_db()
|
||||||
@@ -1060,6 +1047,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
for call, expected_id in zip(
|
for call, expected_id in zip(
|
||||||
mock_consume_delay.call_args_list,
|
mock_consume_delay.call_args_list,
|
||||||
doc_ids,
|
doc_ids,
|
||||||
|
strict=False,
|
||||||
):
|
):
|
||||||
task_kwargs = call.kwargs["kwargs"]
|
task_kwargs = call.kwargs["kwargs"]
|
||||||
self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id)
|
self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id)
|
||||||
@@ -1318,7 +1306,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
sig.apply_async.side_effect = Exception("boom")
|
sig.apply_async.side_effect = Exception("boom")
|
||||||
mock_chord.return_value = sig
|
mock_chord.return_value = sig
|
||||||
|
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaisesRegex(Exception, "boom"):
|
||||||
bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
|
bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
|
||||||
|
|
||||||
self.doc2.refresh_from_db()
|
self.doc2.refresh_from_db()
|
||||||
@@ -1430,7 +1418,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
{"page": 9999}, # invalid page, forces error during PDF load
|
{"page": 9999}, # invalid page, forces error during PDF load
|
||||||
]
|
]
|
||||||
with self.assertLogs("paperless.bulk_edit", level="ERROR"):
|
with self.assertLogs("paperless.bulk_edit", level="ERROR"):
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(ValueError):
|
||||||
bulk_edit.edit_pdf(doc_ids, operations)
|
bulk_edit.edit_pdf(doc_ids, operations)
|
||||||
mock_group.assert_not_called()
|
mock_group.assert_not_called()
|
||||||
mock_consume_file.assert_not_called()
|
mock_consume_file.assert_not_called()
|
||||||
@@ -1479,7 +1467,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
fake_pdf = mock.MagicMock()
|
fake_pdf = mock.MagicMock()
|
||||||
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
|
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
|
||||||
fake_pdf.is_encrypted = True
|
|
||||||
|
|
||||||
def save_side_effect(target_path):
|
def save_side_effect(target_path):
|
||||||
Path(target_path).write_bytes(b"new pdf content")
|
Path(target_path).write_bytes(b"new pdf content")
|
||||||
@@ -1494,13 +1481,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
self.assertEqual(result, "OK")
|
||||||
self.assertEqual(
|
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||||
mock_open.call_args_list,
|
|
||||||
[
|
|
||||||
mock.call(doc.source_path),
|
|
||||||
mock.call(doc.source_path, password="secret"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
fake_pdf.remove_unreferenced_resources.assert_called_once()
|
fake_pdf.remove_unreferenced_resources.assert_called_once()
|
||||||
mock_update_document.assert_not_called()
|
mock_update_document.assert_not_called()
|
||||||
mock_consume_delay.assert_called_once()
|
mock_consume_delay.assert_called_once()
|
||||||
@@ -1514,33 +1495,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
|
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
|
||||||
self.assertIsNotNone(task_kwargs["overrides"])
|
self.assertIsNotNone(task_kwargs["overrides"])
|
||||||
|
|
||||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
|
||||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
|
||||||
@mock.patch("pikepdf.open")
|
|
||||||
def test_remove_password_update_document_skips_unencrypted_pdf(
|
|
||||||
self,
|
|
||||||
mock_open,
|
|
||||||
mock_mkdtemp,
|
|
||||||
mock_consume_delay,
|
|
||||||
) -> None:
|
|
||||||
doc = self.doc1
|
|
||||||
fake_pdf = mock.MagicMock()
|
|
||||||
fake_pdf.is_encrypted = False
|
|
||||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
|
||||||
|
|
||||||
result = bulk_edit.remove_password(
|
|
||||||
[doc.id],
|
|
||||||
password="secret",
|
|
||||||
update_document=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
|
||||||
mock_open.assert_called_once_with(doc.source_path)
|
|
||||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
|
||||||
fake_pdf.save.assert_not_called()
|
|
||||||
mock_mkdtemp.assert_not_called()
|
|
||||||
mock_consume_delay.assert_not_called()
|
|
||||||
|
|
||||||
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
|
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
|
||||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
@mock.patch("documents.tasks.consume_file.apply_async")
|
||||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||||
@@ -1560,12 +1514,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
mock_mkdtemp.return_value = str(temp_dir)
|
mock_mkdtemp.return_value = str(temp_dir)
|
||||||
|
|
||||||
fake_pdf = mock.MagicMock()
|
fake_pdf = mock.MagicMock()
|
||||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
|
||||||
|
|
||||||
def save_side_effect(target_path):
|
def save_side_effect(target_path):
|
||||||
Path(target_path).write_bytes(b"new pdf content")
|
Path(target_path).write_bytes(b"new pdf content")
|
||||||
|
|
||||||
fake_pdf.save.side_effect = save_side_effect
|
fake_pdf.save.side_effect = save_side_effect
|
||||||
|
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||||
|
|
||||||
result = bulk_edit.remove_password(
|
result = bulk_edit.remove_password(
|
||||||
[doc.id],
|
[doc.id],
|
||||||
@@ -1575,13 +1529,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
self.assertEqual(result, "OK")
|
||||||
self.assertEqual(
|
mock_open.assert_called_once_with(source_file, password="secret")
|
||||||
mock_open.call_args_list,
|
|
||||||
[
|
|
||||||
mock.call(source_file),
|
|
||||||
mock.call(source_file, password="secret"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
mock_update_document.assert_not_called()
|
mock_update_document.assert_not_called()
|
||||||
mock_consume_delay.assert_called_once()
|
mock_consume_delay.assert_called_once()
|
||||||
|
|
||||||
@@ -1600,7 +1548,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
root_document=self.doc1,
|
root_document=self.doc1,
|
||||||
)
|
)
|
||||||
fake_pdf = mock.MagicMock()
|
fake_pdf = mock.MagicMock()
|
||||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||||
|
|
||||||
result = bulk_edit.remove_password(
|
result = bulk_edit.remove_password(
|
||||||
[self.doc1.id],
|
[self.doc1.id],
|
||||||
@@ -1610,13 +1558,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
self.assertEqual(result, "OK")
|
||||||
self.assertEqual(
|
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
|
||||||
mock_open.call_args_list,
|
|
||||||
[
|
|
||||||
mock.call(self.doc1.source_path),
|
|
||||||
mock.call(self.doc1.source_path, password="secret"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
mock_consume_delay.assert_called_once()
|
mock_consume_delay.assert_called_once()
|
||||||
|
|
||||||
@mock.patch("documents.bulk_edit.chord")
|
@mock.patch("documents.bulk_edit.chord")
|
||||||
@@ -1639,12 +1581,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
fake_pdf = mock.MagicMock()
|
fake_pdf = mock.MagicMock()
|
||||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
|
||||||
|
|
||||||
def save_side_effect(target_path: Path) -> None:
|
def save_side_effect(target_path: Path) -> None:
|
||||||
target_path.write_bytes(b"password removed")
|
target_path.write_bytes(b"password removed")
|
||||||
|
|
||||||
fake_pdf.save.side_effect = save_side_effect
|
fake_pdf.save.side_effect = save_side_effect
|
||||||
|
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||||
mock_group.return_value.delay.return_value = None
|
mock_group.return_value.delay.return_value = None
|
||||||
|
|
||||||
user = User.objects.create(username="owner")
|
user = User.objects.create(username="owner")
|
||||||
@@ -1659,13 +1601,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
self.assertEqual(result, "OK")
|
||||||
self.assertEqual(
|
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||||
mock_open.call_args_list,
|
|
||||||
[
|
|
||||||
mock.call(doc.source_path),
|
|
||||||
mock.call(doc.source_path, password="secret"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
mock_consume_file.assert_called_once()
|
mock_consume_file.assert_called_once()
|
||||||
call_kwargs = mock_consume_file.call_args.kwargs
|
call_kwargs = mock_consume_file.call_args.kwargs
|
||||||
consumable_document = call_kwargs["input_doc"]
|
consumable_document = call_kwargs["input_doc"]
|
||||||
@@ -1683,43 +1619,6 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
mock_group.return_value.delay.assert_called_once()
|
mock_group.return_value.delay.assert_called_once()
|
||||||
mock_chord.assert_not_called()
|
mock_chord.assert_not_called()
|
||||||
|
|
||||||
@mock.patch("documents.bulk_edit.delete")
|
|
||||||
@mock.patch("documents.bulk_edit.chord")
|
|
||||||
@mock.patch("documents.bulk_edit.group")
|
|
||||||
@mock.patch("documents.tasks.consume_file.s")
|
|
||||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
|
||||||
@mock.patch("pikepdf.open")
|
|
||||||
def test_remove_password_skips_unencrypted_pdf_without_queueing(
|
|
||||||
self,
|
|
||||||
mock_open: mock.Mock,
|
|
||||||
mock_mkdtemp: mock.Mock,
|
|
||||||
mock_consume_file: mock.Mock,
|
|
||||||
mock_group: mock.Mock,
|
|
||||||
mock_chord: mock.Mock,
|
|
||||||
mock_delete: mock.Mock,
|
|
||||||
) -> None:
|
|
||||||
doc = self.doc2
|
|
||||||
fake_pdf = mock.MagicMock()
|
|
||||||
fake_pdf.is_encrypted = False
|
|
||||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
|
||||||
|
|
||||||
result = bulk_edit.remove_password(
|
|
||||||
[doc.id],
|
|
||||||
password="secret",
|
|
||||||
update_document=False,
|
|
||||||
delete_original=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
|
||||||
mock_open.assert_called_once_with(doc.source_path)
|
|
||||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
|
||||||
fake_pdf.save.assert_not_called()
|
|
||||||
mock_mkdtemp.assert_not_called()
|
|
||||||
mock_consume_file.assert_not_called()
|
|
||||||
mock_group.assert_not_called()
|
|
||||||
mock_chord.assert_not_called()
|
|
||||||
mock_delete.si.assert_not_called()
|
|
||||||
|
|
||||||
@mock.patch("documents.bulk_edit.delete")
|
@mock.patch("documents.bulk_edit.delete")
|
||||||
@mock.patch("documents.bulk_edit.chord")
|
@mock.patch("documents.bulk_edit.chord")
|
||||||
@mock.patch("documents.bulk_edit.group")
|
@mock.patch("documents.bulk_edit.group")
|
||||||
@@ -1742,12 +1641,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
fake_pdf = mock.MagicMock()
|
fake_pdf = mock.MagicMock()
|
||||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
|
||||||
|
|
||||||
def save_side_effect(target_path: Path) -> None:
|
def save_side_effect(target_path: Path) -> None:
|
||||||
target_path.write_bytes(b"password removed")
|
target_path.write_bytes(b"password removed")
|
||||||
|
|
||||||
fake_pdf.save.side_effect = save_side_effect
|
fake_pdf.save.side_effect = save_side_effect
|
||||||
|
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||||
mock_chord.return_value.delay.return_value = None
|
mock_chord.return_value.delay.return_value = None
|
||||||
|
|
||||||
result = bulk_edit.remove_password(
|
result = bulk_edit.remove_password(
|
||||||
@@ -1759,13 +1658,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(result, "OK")
|
self.assertEqual(result, "OK")
|
||||||
self.assertEqual(
|
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||||
mock_open.call_args_list,
|
|
||||||
[
|
|
||||||
mock.call(doc.source_path),
|
|
||||||
mock.call(doc.source_path, password="secret"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
mock_consume_file.assert_called_once()
|
mock_consume_file.assert_called_once()
|
||||||
mock_group.assert_not_called()
|
mock_group.assert_not_called()
|
||||||
mock_chord.assert_called_once()
|
mock_chord.assert_called_once()
|
||||||
|
|||||||
@@ -782,8 +782,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
load_classifier(raise_exception=True)
|
load_classifier(raise_exception=True)
|
||||||
|
|
||||||
Path(settings.MODEL_FILE).touch()
|
Path(settings.MODEL_FILE).touch()
|
||||||
mock_load.side_effect = Exception()
|
mock_load.side_effect = RuntimeError()
|
||||||
with self.assertRaises(Exception):
|
with self.assertRaises(RuntimeError):
|
||||||
load_classifier(raise_exception=True)
|
load_classifier(raise_exception=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
def create_staging_file(self, src="double-sided-odd.pdf", datetime=None) -> None:
|
def create_staging_file(self, src="double-sided-odd.pdf", datetime=None) -> None:
|
||||||
shutil.copy(self.SAMPLE_DIR / src, self.staging_file)
|
shutil.copy(self.SAMPLE_DIR / src, self.staging_file)
|
||||||
if datetime is None:
|
if datetime is None:
|
||||||
datetime = dt.datetime.now()
|
datetime = dt.datetime.now(tz=dt.UTC)
|
||||||
os.utime(str(self.staging_file), (datetime.timestamp(),) * 2)
|
os.utime(str(self.staging_file), (datetime.timestamp(),) * 2)
|
||||||
|
|
||||||
def test_odd_numbered_moved_to_staging(self) -> None:
|
def test_odd_numbered_moved_to_staging(self) -> None:
|
||||||
@@ -79,8 +79,8 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
|
|
||||||
self.assertIsFile(self.staging_file)
|
self.assertIsFile(self.staging_file)
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime),
|
dt.datetime.fromtimestamp(self.staging_file.stat().st_mtime, tz=dt.UTC),
|
||||||
dt.datetime.now(),
|
dt.datetime.now(tz=dt.UTC),
|
||||||
delta=dt.timedelta(seconds=5),
|
delta=dt.timedelta(seconds=5),
|
||||||
)
|
)
|
||||||
self.assertIn("Received odd numbered pages", msg["reason"])
|
self.assertIn("Received odd numbered pages", msg["reason"])
|
||||||
@@ -124,7 +124,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.create_staging_file(
|
self.create_staging_file(
|
||||||
datetime=dt.datetime.now()
|
datetime=dt.datetime.now(tz=dt.UTC)
|
||||||
- dt.timedelta(minutes=TIMEOUT_MINUTES, seconds=1),
|
- dt.timedelta(minutes=TIMEOUT_MINUTES, seconds=1),
|
||||||
)
|
)
|
||||||
msg = self.consume_file("double-sided-odd.pdf")
|
msg = self.consume_file("double-sided-odd.pdf")
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from django.contrib.auth.models import User
|
|||||||
from django.db import DatabaseError
|
from django.db import DatabaseError
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from documents.file_handling import create_source_path_directory
|
from documents.file_handling import create_source_path_directory
|
||||||
from documents.file_handling import delete_empty_directories
|
from documents.file_handling import delete_empty_directories
|
||||||
@@ -24,7 +23,6 @@ from documents.models import CustomFieldInstance
|
|||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import DocumentType
|
from documents.models import DocumentType
|
||||||
from documents.models import StoragePath
|
from documents.models import StoragePath
|
||||||
from documents.serialisers import DocumentSerializer
|
|
||||||
from documents.tasks import empty_trash
|
from documents.tasks import empty_trash
|
||||||
from documents.tests.factories import DocumentFactory
|
from documents.tests.factories import DocumentFactory
|
||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
@@ -222,8 +220,11 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
doc = Document.objects.create(
|
doc = Document.objects.create(
|
||||||
title="document",
|
title="document",
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
checksum=hashlib.sha256(original_bytes).hexdigest(),
|
checksum=hashlib.md5(original_bytes, usedforsecurity=False).hexdigest(),
|
||||||
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
|
archive_checksum=hashlib.md5(
|
||||||
|
archive_bytes,
|
||||||
|
usedforsecurity=False,
|
||||||
|
).hexdigest(),
|
||||||
filename="old/document.pdf",
|
filename="old/document.pdf",
|
||||||
archive_filename="old/document.pdf",
|
archive_filename="old/document.pdf",
|
||||||
storage_path=old_storage_path,
|
storage_path=old_storage_path,
|
||||||
@@ -252,46 +253,6 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
|
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
|
||||||
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
|
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
|
||||||
|
|
||||||
@override_settings(FILENAME_FORMAT="{title}")
|
|
||||||
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
|
|
||||||
old_path = settings.ORIGINALS_DIR / "original.pdf"
|
|
||||||
old_path.touch()
|
|
||||||
doc = Document.objects.create(
|
|
||||||
title="original",
|
|
||||||
mime_type="application/pdf",
|
|
||||||
checksum=hashlib.sha256(b"").hexdigest(),
|
|
||||||
filename="original.pdf",
|
|
||||||
)
|
|
||||||
|
|
||||||
first_instance = Document.objects.get(pk=doc.pk)
|
|
||||||
stale_instance = Document.objects.get(pk=doc.pk)
|
|
||||||
|
|
||||||
serializer = DocumentSerializer(
|
|
||||||
first_instance,
|
|
||||||
data={"title": "first"},
|
|
||||||
partial=True,
|
|
||||||
)
|
|
||||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
|
||||||
serializer.save()
|
|
||||||
|
|
||||||
doc.refresh_from_db()
|
|
||||||
self.assertEqual(doc.filename, "first.pdf")
|
|
||||||
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
|
|
||||||
|
|
||||||
serializer = DocumentSerializer(
|
|
||||||
stale_instance,
|
|
||||||
data={"title": "second"},
|
|
||||||
partial=True,
|
|
||||||
)
|
|
||||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
|
||||||
serializer.save()
|
|
||||||
|
|
||||||
doc.refresh_from_db()
|
|
||||||
self.assertEqual(doc.filename, "second.pdf")
|
|
||||||
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
|
|
||||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
|
|
||||||
self.assertIsNotFile(old_path)
|
|
||||||
|
|
||||||
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
|
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
|
||||||
def test_document_delete(self) -> None:
|
def test_document_delete(self) -> None:
|
||||||
document = Document()
|
document = Document()
|
||||||
@@ -452,7 +413,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
FILENAME_FORMAT="{created_year}-{created_month}-{created_day}",
|
FILENAME_FORMAT="{created_year}-{created_month}-{created_day}",
|
||||||
)
|
)
|
||||||
def test_created_year_month_day(self) -> None:
|
def test_created_year_month_day(self) -> None:
|
||||||
d1 = timezone.make_aware(datetime.datetime(2020, 3, 6, 1, 1, 1))
|
d1 = datetime.datetime(2020, 3, 6, 1, 1, 1, tzinfo=datetime.UTC)
|
||||||
doc1 = Document.objects.create(
|
doc1 = Document.objects.create(
|
||||||
title="doc1",
|
title="doc1",
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
@@ -469,7 +430,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
FILENAME_FORMAT="{added_year}-{added_month}-{added_day}",
|
FILENAME_FORMAT="{added_year}-{added_month}-{added_day}",
|
||||||
)
|
)
|
||||||
def test_added_year_month_day(self) -> None:
|
def test_added_year_month_day(self) -> None:
|
||||||
d1 = timezone.make_aware(datetime.datetime(1232, 1, 9, 1, 1, 1))
|
d1 = datetime.datetime(1232, 1, 9, 1, 1, 1, tzinfo=datetime.UTC)
|
||||||
doc1 = Document.objects.create(
|
doc1 = Document.objects.create(
|
||||||
title="doc1",
|
title="doc1",
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
@@ -482,7 +443,7 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
|
|
||||||
self.assertEqual(generate_filename(doc1), expected_filename)
|
self.assertEqual(generate_filename(doc1), expected_filename)
|
||||||
|
|
||||||
doc1.added = timezone.make_aware(datetime.datetime(2020, 11, 16, 1, 1, 1))
|
doc1.added = datetime.datetime(2020, 11, 16, 1, 1, 1, tzinfo=datetime.UTC)
|
||||||
|
|
||||||
self.assertEqual(generate_filename(doc1), Path("2020-11-16.pdf"))
|
self.assertEqual(generate_filename(doc1), Path("2020-11-16.pdf"))
|
||||||
|
|
||||||
@@ -1266,7 +1227,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
def test_short_names_added(self) -> None:
|
def test_short_names_added(self) -> None:
|
||||||
doc = Document.objects.create(
|
doc = Document.objects.create(
|
||||||
title="The Title",
|
title="The Title",
|
||||||
added=timezone.make_aware(datetime.datetime(1984, 8, 21, 7, 36, 51, 153)),
|
added=datetime.datetime(1984, 8, 21, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1505,7 +1466,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc_a = Document.objects.create(
|
doc_a = Document.objects.create(
|
||||||
title="Does Matter",
|
title="Does Matter",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1577,7 +1538,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc = Document.objects.create(
|
doc = Document.objects.create(
|
||||||
title="scan_017562",
|
title="scan_017562",
|
||||||
created=datetime.date(2025, 7, 2),
|
created=datetime.date(2025, 7, 2),
|
||||||
added=timezone.make_aware(datetime.datetime(2026, 3, 3, 11, 53, 16)),
|
added=datetime.datetime(2026, 3, 3, 11, 53, 16, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
checksum="test-checksum",
|
checksum="test-checksum",
|
||||||
storage_path=sp,
|
storage_path=sp,
|
||||||
@@ -1606,7 +1567,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc_a = Document.objects.create(
|
doc_a = Document.objects.create(
|
||||||
title="Does Matter",
|
title="Does Matter",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1641,7 +1602,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc_a = Document.objects.create(
|
doc_a = Document.objects.create(
|
||||||
title="Does Matter",
|
title="Does Matter",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1673,7 +1634,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc_a = Document.objects.create(
|
doc_a = Document.objects.create(
|
||||||
title="Some Title",
|
title="Some Title",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1778,7 +1739,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc_a = Document.objects.create(
|
doc_a = Document.objects.create(
|
||||||
title="Some Title",
|
title="Some Title",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
@@ -1792,8 +1753,15 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
CustomFieldInstance.objects.create(
|
CustomFieldInstance.objects.create(
|
||||||
document=doc_a,
|
document=doc_a,
|
||||||
field=CustomField.objects.get(name="Invoice Date"),
|
field=CustomField.objects.get(name="Invoice Date"),
|
||||||
value_date=timezone.make_aware(
|
value_date=datetime.datetime(
|
||||||
datetime.datetime(2024, 10, 1, 7, 36, 51, 153),
|
2024,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
7,
|
||||||
|
36,
|
||||||
|
51,
|
||||||
|
153,
|
||||||
|
tzinfo=datetime.UTC,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1833,7 +1801,7 @@ class TestFilenameGeneration(DirectoriesMixin, TestCase):
|
|||||||
doc = Document.objects.create(
|
doc = Document.objects.create(
|
||||||
title="Some Title! With @ Special # Characters",
|
title="Some Title! With @ Special # Characters",
|
||||||
created=datetime.date(2020, 6, 25),
|
created=datetime.date(2020, 6, 25),
|
||||||
added=timezone.make_aware(datetime.datetime(2024, 10, 1, 7, 36, 51, 153)),
|
added=datetime.datetime(2024, 10, 1, 7, 36, 51, 153, tzinfo=datetime.UTC),
|
||||||
mime_type="application/pdf",
|
mime_type="application/pdf",
|
||||||
pk=2,
|
pk=2,
|
||||||
checksum="2",
|
checksum="2",
|
||||||
|
|||||||
@@ -1,187 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
|
|
||||||
|
|
||||||
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
|
|
||||||
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
|
|
||||||
All paths produced by the templating system must be NFC-normalized.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from documents.file_handling import generate_filename
|
|
||||||
from documents.models import CustomField
|
|
||||||
from documents.models import CustomFieldInstance
|
|
||||||
from documents.tests.factories import CorrespondentFactory
|
|
||||||
from documents.tests.factories import DocumentFactory
|
|
||||||
from documents.tests.factories import StoragePathFactory
|
|
||||||
from documents.tests.factories import TagFactory
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
|
||||||
class TestGenerateFilenameNFCNormalization:
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"raw,display",
|
|
||||||
[
|
|
||||||
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
|
|
||||||
(unicodedata.normalize("NFD", "Café"), "Café"),
|
|
||||||
(unicodedata.normalize("NFD", "naïve"), "naïve"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
|
|
||||||
"""NFD title must produce NFC path bytes."""
|
|
||||||
settings.FILENAME_FORMAT = "{{ title }}"
|
|
||||||
nfc = unicodedata.normalize("NFC", display)
|
|
||||||
assert raw != nfc # confirm byte-level difference
|
|
||||||
|
|
||||||
doc = DocumentFactory(title=raw, mime_type="application/pdf")
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result) == f"{nfc}.pdf"
|
|
||||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_correspondent_normalized_to_nfc(self, settings):
|
|
||||||
"""NFD correspondent name must produce NFC path component."""
|
|
||||||
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
|
|
||||||
nfd = unicodedata.normalize("NFD", "Müller")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Müller")
|
|
||||||
|
|
||||||
correspondent = CorrespondentFactory(name=nfd)
|
|
||||||
doc = DocumentFactory(
|
|
||||||
title="invoice",
|
|
||||||
correspondent=correspondent,
|
|
||||||
mime_type="application/pdf",
|
|
||||||
)
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result) == f"{nfc}/invoice.pdf"
|
|
||||||
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_storage_path_normalized_to_nfc(self, settings):
|
|
||||||
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
|
|
||||||
settings.FILENAME_FORMAT = None
|
|
||||||
nfd = unicodedata.normalize("NFD", "Büro")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Büro")
|
|
||||||
|
|
||||||
# StoragePath.path is used directly as the format/template string.
|
|
||||||
# Literal NFD characters in the template must survive rendering as NFC.
|
|
||||||
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
|
|
||||||
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
|
|
||||||
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
|
|
||||||
settings.FILENAME_FORMAT = "{{ document.title }}"
|
|
||||||
nfd = unicodedata.normalize("NFD", "Café")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Café")
|
|
||||||
|
|
||||||
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result) == f"{nfc}.pdf"
|
|
||||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
|
||||||
class TestContextBuilderNFCNormalization:
|
|
||||||
"""
|
|
||||||
Defense-in-depth: context builder functions must NFC-normalize string inputs
|
|
||||||
before passing them to sanitize_filename(). Task 1 already normalizes the
|
|
||||||
final rendered path via clean_filepath(), so these tests may already pass;
|
|
||||||
they exist as regression guards for the context-builder layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
|
|
||||||
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
|
|
||||||
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
|
|
||||||
nfd = unicodedata.normalize("NFD", "Büro")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Büro")
|
|
||||||
assert nfd != nfc # confirm they differ at byte level
|
|
||||||
|
|
||||||
tag = TagFactory(name=nfd)
|
|
||||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
|
||||||
doc.tags.set([tag])
|
|
||||||
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_original_name_normalized_to_nfc(self, settings):
|
|
||||||
settings.FILENAME_FORMAT = "{{ original_name }}"
|
|
||||||
nfd = unicodedata.normalize("NFD", "Rechnung März")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Rechnung März")
|
|
||||||
|
|
||||||
doc = DocumentFactory(
|
|
||||||
original_filename=f"{nfd}.pdf",
|
|
||||||
mime_type="application/pdf",
|
|
||||||
)
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_custom_field_string_value_normalized(self, settings):
|
|
||||||
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
|
|
||||||
settings.FILENAME_FORMAT = (
|
|
||||||
"{{ custom_fields['Location']['value'] }}/{{ title }}"
|
|
||||||
)
|
|
||||||
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
|
|
||||||
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
|
|
||||||
assert nfd_value != nfc_value
|
|
||||||
|
|
||||||
doc = DocumentFactory(title="report", mime_type="application/pdf")
|
|
||||||
cf = CustomField.objects.create(
|
|
||||||
name="Location",
|
|
||||||
data_type=CustomField.FieldDataType.STRING,
|
|
||||||
)
|
|
||||||
CustomFieldInstance.objects.create(
|
|
||||||
document=doc,
|
|
||||||
field=cf,
|
|
||||||
value_text=nfd_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
|
|
||||||
|
|
||||||
def test_nfd_custom_field_name_normalized_as_key(self, settings):
|
|
||||||
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
|
|
||||||
nfd_name = unicodedata.normalize("NFD", "Größe")
|
|
||||||
nfc_name = unicodedata.normalize("NFC", "Größe")
|
|
||||||
assert nfd_name != nfc_name
|
|
||||||
|
|
||||||
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
|
|
||||||
|
|
||||||
doc = DocumentFactory(title="letter", mime_type="application/pdf")
|
|
||||||
cf = CustomField.objects.create(
|
|
||||||
name=nfd_name,
|
|
||||||
data_type=CustomField.FieldDataType.STRING,
|
|
||||||
)
|
|
||||||
CustomFieldInstance.objects.create(
|
|
||||||
document=doc,
|
|
||||||
field=cf,
|
|
||||||
value_text="Berlin",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
# If field name key is NFC-normalized, the template condition succeeds
|
|
||||||
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
|
|
||||||
assert str(result) == "Berlin/letter.pdf"
|
|
||||||
|
|
||||||
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
|
|
||||||
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
|
|
||||||
settings.FILENAME_FORMAT = (
|
|
||||||
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
|
|
||||||
)
|
|
||||||
nfd = unicodedata.normalize("NFD", "Büro")
|
|
||||||
nfc = unicodedata.normalize("NFC", "Büro")
|
|
||||||
assert nfd != nfc # confirm byte-level difference
|
|
||||||
|
|
||||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
|
||||||
doc.tags.add(TagFactory(name=nfd))
|
|
||||||
result = generate_filename(doc)
|
|
||||||
|
|
||||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
|
||||||
@@ -684,7 +684,6 @@ class ConsumerThread(Thread):
|
|||||||
subdirs_as_tags: bool = False,
|
subdirs_as_tags: bool = False,
|
||||||
polling_interval: float = 0,
|
polling_interval: float = 0,
|
||||||
stability_delay: float = 0.1,
|
stability_delay: float = 0.1,
|
||||||
rescan_interval: float | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.consumption_dir = consumption_dir
|
self.consumption_dir = consumption_dir
|
||||||
@@ -694,8 +693,6 @@ class ConsumerThread(Thread):
|
|||||||
self.polling_interval = polling_interval
|
self.polling_interval = polling_interval
|
||||||
self.stability_delay = stability_delay
|
self.stability_delay = stability_delay
|
||||||
self.cmd = Command()
|
self.cmd = Command()
|
||||||
if rescan_interval is not None:
|
|
||||||
self.cmd.rescan_interval_s = rescan_interval
|
|
||||||
self.cmd.stop_flag.clear()
|
self.cmd.stop_flag.clear()
|
||||||
# Non-daemon ensures finally block runs and connections are closed
|
# Non-daemon ensures finally block runs and connections are closed
|
||||||
self.daemon = False
|
self.daemon = False
|
||||||
@@ -1055,200 +1052,3 @@ class TestCommandWatchEdgeCases:
|
|||||||
thread.stop_and_wait(timeout=5.0)
|
thread.stop_and_wait(timeout=5.0)
|
||||||
# Clean up any Tags created by the thread
|
# Clean up any Tags created by the thread
|
||||||
Tag.objects.all().delete()
|
Tag.objects.all().delete()
|
||||||
|
|
||||||
|
|
||||||
class TestRescanExistingFiles:
|
|
||||||
"""
|
|
||||||
Unit tests for the rescan safety net.
|
|
||||||
|
|
||||||
Each ``watch()`` recreation silently adopts the current directory contents
|
|
||||||
as its baseline, so a file appearing between one batch and the next
|
|
||||||
watcher's baseline is never reported and would sit in the consume directory
|
|
||||||
forever. ``_rescan_existing_files`` re-injects such files into the
|
|
||||||
stability tracker as a periodic safety net (see GH issue #13011).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def pdf_only_filter(self) -> ConsumerFilter:
|
|
||||||
return ConsumerFilter(
|
|
||||||
supported_extensions=frozenset({".pdf"}),
|
|
||||||
ignore_patterns=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _rescan(
|
|
||||||
self,
|
|
||||||
directory: Path,
|
|
||||||
consumer_filter: ConsumerFilter,
|
|
||||||
tracker: FileStabilityTracker,
|
|
||||||
queued: set[Path],
|
|
||||||
*,
|
|
||||||
recursive: bool = False,
|
|
||||||
) -> None:
|
|
||||||
Command()._rescan_existing_files(
|
|
||||||
directory=directory,
|
|
||||||
recursive=recursive,
|
|
||||||
consumer_filter=consumer_filter,
|
|
||||||
tracker=tracker,
|
|
||||||
queued=queued,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_tracks_stranded_file(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""A supported on-disk file the watcher never reported gets tracked."""
|
|
||||||
target = consumption_dir / "stranded.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
|
||||||
|
|
||||||
assert tracker.is_tracking(target) is True
|
|
||||||
assert tracker.pending_count == 1
|
|
||||||
|
|
||||||
def test_skips_already_tracked_file(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""A file already being tracked by the watcher is not double-tracked."""
|
|
||||||
target = consumption_dir / "tracked.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
tracker.track(target, Change.added)
|
|
||||||
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
|
||||||
|
|
||||||
assert tracker.pending_count == 1
|
|
||||||
|
|
||||||
def test_skips_queued_file(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""A file already queued and awaiting consumption is not re-tracked."""
|
|
||||||
target = consumption_dir / "inflight.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
queued = {target.resolve()}
|
|
||||||
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
|
|
||||||
|
|
||||||
assert tracker.pending_count == 0
|
|
||||||
|
|
||||||
def test_prunes_vanished_queued_paths(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""Queued paths no longer on disk are dropped so the name can recur."""
|
|
||||||
gone = (consumption_dir / "gone.pdf").resolve()
|
|
||||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
queued = {gone}
|
|
||||||
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
|
|
||||||
|
|
||||||
assert gone not in queued
|
|
||||||
|
|
||||||
def test_skips_unsupported_extension(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""Files filtered out by the consumer filter are not tracked."""
|
|
||||||
(consumption_dir / "notes.xyz").write_bytes(b"content")
|
|
||||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
|
||||||
|
|
||||||
assert tracker.pending_count == 0
|
|
||||||
|
|
||||||
def test_recursive_respects_flag(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
pdf_only_filter: ConsumerFilter,
|
|
||||||
) -> None:
|
|
||||||
"""Nested files are only found when recursive scanning is enabled."""
|
|
||||||
subdir = consumption_dir / "nested"
|
|
||||||
subdir.mkdir()
|
|
||||||
target = subdir / "deep.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
|
|
||||||
shallow = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
|
|
||||||
assert shallow.pending_count == 0
|
|
||||||
|
|
||||||
deep = FileStabilityTracker(stability_delay=0.1)
|
|
||||||
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
|
|
||||||
assert deep.is_tracking(target) is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessExistingFilesQueued:
|
|
||||||
"""Tests that startup processing reports which paths it queued."""
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("mock_supported_extensions")
|
|
||||||
def test_returns_queued_paths(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
mock_consume_file_delay: MagicMock,
|
|
||||||
settings: SettingsWrapper,
|
|
||||||
) -> None:
|
|
||||||
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
|
|
||||||
target = consumption_dir / "document.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
settings.CONSUMER_IGNORE_PATTERNS = []
|
|
||||||
|
|
||||||
queued = Command()._process_existing_files(
|
|
||||||
directory=consumption_dir,
|
|
||||||
recursive=False,
|
|
||||||
subdirs_as_tags=False,
|
|
||||||
consumer_filter=ConsumerFilter(ignore_patterns=[]),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert target.resolve() in queued
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.management
|
|
||||||
@pytest.mark.django_db
|
|
||||||
class TestCommandRescanRecovery:
|
|
||||||
"""End-to-end test that the rescan recovers files the watcher misses."""
|
|
||||||
|
|
||||||
def test_rescan_consumes_file_the_watcher_never_reports(
|
|
||||||
self,
|
|
||||||
consumption_dir: Path,
|
|
||||||
sample_pdf: Path,
|
|
||||||
mock_consume_file_delay: MagicMock,
|
|
||||||
start_consumer: Callable[..., ConsumerThread],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Isolate the rescan path: a long polling interval guarantees the
|
|
||||||
watcher cannot report the file within the test window, so only the
|
|
||||||
periodic rescan can consume it.
|
|
||||||
"""
|
|
||||||
# poll interval far longer than the test window -> watcher stays silent
|
|
||||||
thread = start_consumer(
|
|
||||||
polling_interval=30.0,
|
|
||||||
stability_delay=0.1,
|
|
||||||
rescan_interval=0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
# created after startup, so _process_existing_files did not see it
|
|
||||||
target = consumption_dir / "stranded.pdf"
|
|
||||||
shutil.copy(sample_pdf, target)
|
|
||||||
|
|
||||||
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
|
|
||||||
|
|
||||||
if thread.exception:
|
|
||||||
raise thread.exception
|
|
||||||
|
|
||||||
mock_consume_file_delay.apply_async.assert_called()
|
|
||||||
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
|
|
||||||
"input_doc"
|
|
||||||
]
|
|
||||||
assert call_args.original_file.name == "stranded.pdf"
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from documents.signals.handlers import update_llm_suggestions_cache
|
|||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
from documents.tests.utils import read_streaming_response
|
from documents.tests.utils import read_streaming_response
|
||||||
from paperless.models import ApplicationConfiguration
|
from paperless.models import ApplicationConfiguration
|
||||||
from paperless_ai.exceptions import LLMTimeoutError
|
|
||||||
|
|
||||||
|
|
||||||
class TestViews(DirectoriesMixin, TestCase):
|
class TestViews(DirectoriesMixin, TestCase):
|
||||||
@@ -244,7 +243,7 @@ class TestViews(DirectoriesMixin, TestCase):
|
|||||||
"change": {"users": [], "groups": []},
|
"change": {"users": [], "groups": []},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
assert False, f"Unexpected tag found: {tag['name']}"
|
raise AssertionError(f"Unexpected tag found: {tag['name']}")
|
||||||
|
|
||||||
def test_list_no_n_plus_1_queries(self) -> None:
|
def test_list_no_n_plus_1_queries(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -477,33 +476,6 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
|||||||
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
|
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("documents.views.get_ai_document_classification")
|
|
||||||
@override_settings(
|
|
||||||
AI_ENABLED=True,
|
|
||||||
LLM_BACKEND="openai-like",
|
|
||||||
)
|
|
||||||
def test_ai_suggestions_with_llm_timeout(
|
|
||||||
self,
|
|
||||||
mock_get_ai_classification,
|
|
||||||
) -> None:
|
|
||||||
mock_get_ai_classification.side_effect = LLMTimeoutError()
|
|
||||||
|
|
||||||
self.client.force_login(user=self.user)
|
|
||||||
response = self.client.get(
|
|
||||||
f"/api/documents/{self.document.pk}/ai_suggestions/",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
|
|
||||||
self.assertEqual(
|
|
||||||
response.json(),
|
|
||||||
{
|
|
||||||
"ai": ["AI backend request timed out."],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertIsNone(
|
|
||||||
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_invalidate_suggestions_cache(self) -> None:
|
def test_invalidate_suggestions_cache(self) -> None:
|
||||||
self.client.force_login(user=self.user)
|
self.client.force_login(user=self.user)
|
||||||
suggestions = {
|
suggestions = {
|
||||||
|
|||||||
@@ -2760,7 +2760,14 @@ class TestWorkflows(
|
|||||||
doc = Document.objects.create(
|
doc = Document.objects.create(
|
||||||
title="test",
|
title="test",
|
||||||
)
|
)
|
||||||
self.assertRaises(Exception, document_matches_workflow, doc, w, 99)
|
self.assertRaisesRegex(
|
||||||
|
Exception,
|
||||||
|
"not yet supported",
|
||||||
|
document_matches_workflow,
|
||||||
|
doc,
|
||||||
|
w,
|
||||||
|
99,
|
||||||
|
)
|
||||||
|
|
||||||
def test_removal_action_document_updated_workflow(self) -> None:
|
def test_removal_action_document_updated_workflow(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -129,11 +129,12 @@ def util_call_with_backoff(
|
|||||||
status_codes.append(cause_exec.response.status_code)
|
status_codes.append(cause_exec.response.status_code)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"HTTP Exception for {cause_exec.request.url} - {cause_exec}",
|
f"HTTP Exception for {cause_exec.request.url} - {cause_exec}",
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
warnings.warn(f"Unexpected error: {e}")
|
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
warnings.warn(f"Unexpected error: {e}")
|
warnings.warn(f"Unexpected error: {e}", stacklevel=2)
|
||||||
|
|
||||||
retry_count = retry_count + 1
|
retry_count = retry_count + 1
|
||||||
|
|
||||||
|
|||||||
+50
-154
@@ -7,12 +7,11 @@ import tempfile
|
|||||||
import zipfile
|
import zipfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from datetime import UTC
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from time import mktime
|
|
||||||
from time import sleep
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@@ -61,7 +60,6 @@ from django.http import StreamingHttpResponse
|
|||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.utils.timezone import make_aware
|
|
||||||
from django.utils.translation import get_language
|
from django.utils.translation import get_language
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.views import View
|
from django.views import View
|
||||||
@@ -241,7 +239,6 @@ from paperless.serialisers import UserSerializer
|
|||||||
from paperless.views import StandardPagination
|
from paperless.views import StandardPagination
|
||||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless_ai.chat import stream_chat_with_documents
|
from paperless_ai.chat import stream_chat_with_documents
|
||||||
from paperless_ai.exceptions import LLMTimeoutError
|
|
||||||
from paperless_ai.matching import extract_unmatched_names
|
from paperless_ai.matching import extract_unmatched_names
|
||||||
from paperless_ai.matching import match_correspondents_by_name
|
from paperless_ai.matching import match_correspondents_by_name
|
||||||
from paperless_ai.matching import match_document_types_by_name
|
from paperless_ai.matching import match_document_types_by_name
|
||||||
@@ -287,7 +284,7 @@ def _get_more_like_id(query_params: dict[str, Any], user: User | None) -> int:
|
|||||||
pk=more_like_doc_id,
|
pk=more_like_doc_id,
|
||||||
)
|
)
|
||||||
except (TypeError, ValueError, Document.DoesNotExist):
|
except (TypeError, ValueError, Document.DoesNotExist):
|
||||||
raise PermissionDenied(_("Invalid more_like_id"))
|
raise PermissionDenied(_("Invalid more_like_id")) from None
|
||||||
|
|
||||||
if user and not has_perms_owner_aware(
|
if user and not has_perms_owner_aware(
|
||||||
user,
|
user,
|
||||||
@@ -1103,7 +1100,7 @@ class DocumentViewSet(
|
|||||||
"root_document",
|
"root_document",
|
||||||
).get(pk=pk)
|
).get(pk=pk)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
root_doc = get_root_document(doc)
|
root_doc = get_root_document(doc)
|
||||||
if request.user is not None and not has_perms_owner_aware(
|
if request.user is not None and not has_perms_owner_aware(
|
||||||
@@ -1266,7 +1263,7 @@ class DocumentViewSet(
|
|||||||
"root_document",
|
"root_document",
|
||||||
).get(id=pk)
|
).get(id=pk)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
root_doc = get_root_document(
|
root_doc = get_root_document(
|
||||||
request_doc,
|
request_doc,
|
||||||
@@ -1402,7 +1399,7 @@ class DocumentViewSet(
|
|||||||
)
|
)
|
||||||
if request.user is not None and not has_perms_owner_aware(
|
if request.user is not None and not has_perms_owner_aware(
|
||||||
request.user,
|
request.user,
|
||||||
"change_document",
|
"view_document",
|
||||||
doc,
|
doc,
|
||||||
):
|
):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
@@ -1462,7 +1459,7 @@ class DocumentViewSet(
|
|||||||
)
|
)
|
||||||
if request.user is not None and not has_perms_owner_aware(
|
if request.user is not None and not has_perms_owner_aware(
|
||||||
request.user,
|
request.user,
|
||||||
"change_document",
|
"view_document",
|
||||||
doc,
|
doc,
|
||||||
):
|
):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
@@ -1508,20 +1505,8 @@ class DocumentViewSet(
|
|||||||
"document %s: %s",
|
"document %s: %s",
|
||||||
doc.pk,
|
doc.pk,
|
||||||
exc,
|
exc,
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
|
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
|
||||||
except LLMTimeoutError as exc:
|
|
||||||
logger.exception(
|
|
||||||
"AI backend timed out while generating suggestions for document %s: %s",
|
|
||||||
doc.pk,
|
|
||||||
exc,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return Response(
|
|
||||||
{"ai": [_("AI backend request timed out.")]},
|
|
||||||
status=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
||||||
)
|
|
||||||
|
|
||||||
matched_tags = match_tags_by_name(
|
matched_tags = match_tags_by_name(
|
||||||
llm_suggestions.get("tags", []),
|
llm_suggestions.get("tags", []),
|
||||||
@@ -1593,7 +1578,7 @@ class DocumentViewSet(
|
|||||||
disposition="inline",
|
disposition="inline",
|
||||||
)
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
@action(methods=["get"], detail=True, filter_backends=[])
|
@action(methods=["get"], detail=True, filter_backends=[])
|
||||||
@method_decorator(cache_control(no_cache=True))
|
@method_decorator(cache_control(no_cache=True))
|
||||||
@@ -1618,14 +1603,14 @@ class DocumentViewSet(
|
|||||||
|
|
||||||
return FileResponse(handle, content_type="image/webp")
|
return FileResponse(handle, content_type="image/webp")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
@action(methods=["get"], detail=True)
|
@action(methods=["get"], detail=True)
|
||||||
def download(self, request, pk=None):
|
def download(self, request, pk=None):
|
||||||
try:
|
try:
|
||||||
return self.file_response(pk, request, "attachment")
|
return self.file_response(pk, request, "attachment")
|
||||||
except (FileNotFoundError, Document.DoesNotExist):
|
except (FileNotFoundError, Document.DoesNotExist):
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
@action(
|
@action(
|
||||||
methods=["get", "post", "delete"],
|
methods=["get", "post", "delete"],
|
||||||
@@ -1650,7 +1635,7 @@ class DocumentViewSet(
|
|||||||
):
|
):
|
||||||
return HttpResponseForbidden("Insufficient permissions to view notes")
|
return HttpResponseForbidden("Insufficient permissions to view notes")
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
serializer = self.get_serializer(doc)
|
serializer = self.get_serializer(doc)
|
||||||
|
|
||||||
@@ -1721,7 +1706,7 @@ class DocumentViewSet(
|
|||||||
try:
|
try:
|
||||||
note_id_int = int(note_id)
|
note_id_int = int(note_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValidationError({"id": "A valid integer is required."})
|
raise ValidationError({"id": "A valid integer is required."}) from None
|
||||||
note = get_object_or_404(Note, id=note_id_int, document=doc)
|
note = get_object_or_404(Note, id=note_id_int, document=doc)
|
||||||
if settings.AUDIT_LOG_ENABLED:
|
if settings.AUDIT_LOG_ENABLED:
|
||||||
LogEntry.objects.log_create(
|
LogEntry.objects.log_create(
|
||||||
@@ -1765,7 +1750,7 @@ class DocumentViewSet(
|
|||||||
"Insufficient permissions to add share link",
|
"Insufficient permissions to add share link",
|
||||||
)
|
)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
now = timezone.now()
|
now = timezone.now()
|
||||||
@@ -1793,7 +1778,7 @@ class DocumentViewSet(
|
|||||||
"Insufficient permissions",
|
"Insufficient permissions",
|
||||||
)
|
)
|
||||||
except Document.DoesNotExist: # pragma: no cover
|
except Document.DoesNotExist: # pragma: no cover
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
# documents
|
# documents
|
||||||
entries = [
|
entries = [
|
||||||
@@ -1814,28 +1799,28 @@ class DocumentViewSet(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# custom fields
|
# custom fields
|
||||||
for entry in LogEntry.objects.get_for_objects(
|
entries.extend(
|
||||||
doc.custom_fields.all(),
|
{
|
||||||
).select_related("actor"):
|
"id": entry.id,
|
||||||
entries.append(
|
"timestamp": entry.timestamp,
|
||||||
{
|
"action": entry.get_action_display(),
|
||||||
"id": entry.id,
|
"changes": {
|
||||||
"timestamp": entry.timestamp,
|
"custom_fields": {
|
||||||
"action": entry.get_action_display(),
|
"type": "custom_field",
|
||||||
"changes": {
|
"field": str(entry.object_repr).split(":")[0].strip(),
|
||||||
"custom_fields": {
|
"value": str(entry.object_repr).split(":")[1].strip(),
|
||||||
"type": "custom_field",
|
|
||||||
"field": str(entry.object_repr).split(":")[0].strip(),
|
|
||||||
"value": str(entry.object_repr).split(":")[1].strip(),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"actor": (
|
|
||||||
{"id": entry.actor.id, "username": entry.actor.username}
|
|
||||||
if entry.actor
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
)
|
"actor": (
|
||||||
|
{"id": entry.actor.id, "username": entry.actor.username}
|
||||||
|
if entry.actor
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for entry in LogEntry.objects.get_for_objects(
|
||||||
|
doc.custom_fields.all(),
|
||||||
|
).select_related("actor")
|
||||||
|
)
|
||||||
|
|
||||||
return Response(sorted(entries, key=lambda x: x["timestamp"], reverse=True))
|
return Response(sorted(entries, key=lambda x: x["timestamp"], reverse=True))
|
||||||
|
|
||||||
@@ -1943,13 +1928,13 @@ class DocumentViewSet(
|
|||||||
):
|
):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
doc_name, doc_data = serializer.validated_data.get("document")
|
doc_name, doc_data = serializer.validated_data.get("document")
|
||||||
version_label = serializer.validated_data.get("version_label")
|
version_label = serializer.validated_data.get("version_label")
|
||||||
|
|
||||||
t = int(mktime(datetime.now().timetuple()))
|
t = int(timezone.now().timestamp())
|
||||||
|
|
||||||
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -1994,7 +1979,7 @@ class DocumentViewSet(
|
|||||||
"root_document",
|
"root_document",
|
||||||
).get(pk=pk)
|
).get(pk=pk)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
return get_root_document(root_doc)
|
return get_root_document(root_doc)
|
||||||
|
|
||||||
def _get_version_doc_for_root(self, root_doc: Document, version_id) -> Document:
|
def _get_version_doc_for_root(self, root_doc: Document, version_id) -> Document:
|
||||||
@@ -2003,7 +1988,7 @@ class DocumentViewSet(
|
|||||||
pk=version_id,
|
pk=version_id,
|
||||||
)
|
)
|
||||||
except Document.DoesNotExist:
|
except Document.DoesNotExist:
|
||||||
raise Http404
|
raise Http404 from None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
version_doc.id != root_doc.id
|
version_doc.id != root_doc.id
|
||||||
@@ -2289,7 +2274,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
|||||||
return super().list(request)
|
return super().list(request)
|
||||||
|
|
||||||
from documents.search import SearchHit
|
from documents.search import SearchHit
|
||||||
from documents.search import SearchQueryError
|
|
||||||
from documents.search import TantivyBackend
|
from documents.search import TantivyBackend
|
||||||
from documents.search import TantivyRelevanceList
|
from documents.search import TantivyRelevanceList
|
||||||
from documents.search import get_backend
|
from documents.search import get_backend
|
||||||
@@ -2482,11 +2466,6 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
|||||||
return HttpResponseForbidden(_("Insufficient permissions."))
|
return HttpResponseForbidden(_("Insufficient permissions."))
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
raise
|
raise
|
||||||
except SearchQueryError as e:
|
|
||||||
# User-fixable query error (e.g. an unparsable date): surface the
|
|
||||||
# specific message so the user can correct it, rather than a generic
|
|
||||||
# 400 or silently empty results.
|
|
||||||
raise ValidationError({"query": [str(e)]}) from e
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"An error occurred listing search results: {e!s}")
|
logger.warning(f"An error occurred listing search results: {e!s}")
|
||||||
return HttpResponseBadRequest(
|
return HttpResponseBadRequest(
|
||||||
@@ -2564,7 +2543,7 @@ class LogViewSet(ViewSet):
|
|||||||
try:
|
try:
|
||||||
limit = int(limit_param)
|
limit = int(limit_param)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
raise ValidationError({"limit": "Must be a positive integer"})
|
raise ValidationError({"limit": "Must be a positive integer"}) from None
|
||||||
if limit < 1:
|
if limit < 1:
|
||||||
raise ValidationError({"limit": "Must be a positive integer"})
|
raise ValidationError({"limit": "Must be a positive integer"})
|
||||||
else:
|
else:
|
||||||
@@ -3145,7 +3124,6 @@ class PostDocumentView(GenericAPIView[Any]):
|
|||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
|
|
||||||
doc_name, doc_data = serializer.validated_data.get("document")
|
doc_name, doc_data = serializer.validated_data.get("document")
|
||||||
doc_name = normalize("NFC", doc_name)
|
|
||||||
correspondent_id = serializer.validated_data.get("correspondent")
|
correspondent_id = serializer.validated_data.get("correspondent")
|
||||||
document_type_id = serializer.validated_data.get("document_type")
|
document_type_id = serializer.validated_data.get("document_type")
|
||||||
storage_path_id = serializer.validated_data.get("storage_path")
|
storage_path_id = serializer.validated_data.get("storage_path")
|
||||||
@@ -3156,7 +3134,7 @@ class PostDocumentView(GenericAPIView[Any]):
|
|||||||
cf = serializer.validated_data.get("custom_fields")
|
cf = serializer.validated_data.get("custom_fields")
|
||||||
from_webui = serializer.validated_data.get("from_webui")
|
from_webui = serializer.validated_data.get("from_webui")
|
||||||
|
|
||||||
t = int(mktime(datetime.now().timetuple()))
|
t = int(timezone.now().timestamp())
|
||||||
|
|
||||||
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -4031,7 +4009,7 @@ class RemoteVersionView(GenericAPIView[Any]):
|
|||||||
|
|
||||||
|
|
||||||
class _TasksViewSetSchema(AutoSchema):
|
class _TasksViewSetSchema(AutoSchema):
|
||||||
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
|
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
|
||||||
|
|
||||||
def _get_paginator(self):
|
def _get_paginator(self):
|
||||||
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
|
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
|
||||||
@@ -4053,7 +4031,7 @@ class _TasksViewSetSchema(AutoSchema):
|
|||||||
),
|
),
|
||||||
acknowledge=extend_schema(
|
acknowledge=extend_schema(
|
||||||
operation_id="acknowledge_tasks",
|
operation_id="acknowledge_tasks",
|
||||||
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
|
description="Acknowledge a list of tasks",
|
||||||
request=AcknowledgeTasksViewSerializer,
|
request=AcknowledgeTasksViewSerializer,
|
||||||
responses={
|
responses={
|
||||||
(200, "application/json"): inline_serializer(
|
(200, "application/json"): inline_serializer(
|
||||||
@@ -4091,19 +4069,6 @@ class _TasksViewSetSchema(AutoSchema):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
status_counts=extend_schema(
|
|
||||||
responses={
|
|
||||||
200: inline_serializer(
|
|
||||||
name="TaskStatusCounts",
|
|
||||||
fields={
|
|
||||||
"all": serializers.IntegerField(),
|
|
||||||
"needs_attention": serializers.IntegerField(),
|
|
||||||
"in_progress": serializers.IntegerField(),
|
|
||||||
"completed": serializers.IntegerField(),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
active=extend_schema(
|
active=extend_schema(
|
||||||
description="Currently pending and running tasks (capped at 50).",
|
description="Currently pending and running tasks (capped at 50).",
|
||||||
responses={200: TaskSerializerV10(many=True)},
|
responses={200: TaskSerializerV10(many=True)},
|
||||||
@@ -4157,7 +4122,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
|||||||
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
|
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
|
||||||
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
|
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
|
||||||
}
|
}
|
||||||
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
|
|
||||||
|
|
||||||
def get_serializer_class(self):
|
def get_serializer_class(self):
|
||||||
# v9: use backwards-compatible serializer with old field names
|
# v9: use backwards-compatible serializer with old field names
|
||||||
@@ -4198,38 +4162,16 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
|||||||
queryset = queryset.filter(task_id=task_id)
|
queryset = queryset.filter(task_id=task_id)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def get_status_count_queryset(self):
|
|
||||||
"""Apply task filters except the status dimensions represented by the counts."""
|
|
||||||
query_params = self.request.query_params.copy()
|
|
||||||
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
|
|
||||||
query_params.pop(param, None)
|
|
||||||
|
|
||||||
filterset = self.filterset_class(
|
|
||||||
data=query_params,
|
|
||||||
queryset=self.get_queryset(),
|
|
||||||
request=self.request,
|
|
||||||
)
|
|
||||||
if not filterset.is_valid():
|
|
||||||
raise ValidationError(filterset.errors)
|
|
||||||
return filterset.qs
|
|
||||||
|
|
||||||
@action(
|
@action(
|
||||||
methods=["post"],
|
methods=["post"],
|
||||||
detail=False,
|
detail=False,
|
||||||
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
|
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
|
||||||
)
|
)
|
||||||
def acknowledge(self, request):
|
def acknowledge(self, request):
|
||||||
queryset = self.get_queryset()
|
serializer = AcknowledgeTasksViewSerializer(data=request.data)
|
||||||
serializer = AcknowledgeTasksViewSerializer(
|
|
||||||
data=request.data,
|
|
||||||
context={"queryset": queryset},
|
|
||||||
)
|
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
if serializer.validated_data.get("all", False):
|
task_ids = serializer.validated_data.get("tasks")
|
||||||
tasks = queryset.filter(acknowledged=False)
|
tasks = self.get_queryset().filter(id__in=task_ids)
|
||||||
else:
|
|
||||||
task_ids = serializer.validated_data.get("tasks")
|
|
||||||
tasks = queryset.filter(id__in=task_ids)
|
|
||||||
count = tasks.update(acknowledged=True)
|
count = tasks.update(acknowledged=True)
|
||||||
return Response({"result": count})
|
return Response({"result": count})
|
||||||
|
|
||||||
@@ -4282,34 +4224,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
|||||||
serializer = TaskSummarySerializer(data, many=True)
|
serializer = TaskSummarySerializer(data, many=True)
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
@action(methods=["get"], detail=False)
|
|
||||||
def status_counts(self, request):
|
|
||||||
"""Aggregated task counts for task UI sections."""
|
|
||||||
queryset = self.get_status_count_queryset()
|
|
||||||
counts = queryset.aggregate(
|
|
||||||
all=Count("id"),
|
|
||||||
needs_attention=Count(
|
|
||||||
"id",
|
|
||||||
filter=Q(
|
|
||||||
status__in=[
|
|
||||||
PaperlessTask.Status.FAILURE,
|
|
||||||
PaperlessTask.Status.REVOKED,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
in_progress=Count(
|
|
||||||
"id",
|
|
||||||
filter=Q(
|
|
||||||
status__in=[
|
|
||||||
PaperlessTask.Status.PENDING,
|
|
||||||
PaperlessTask.Status.STARTED,
|
|
||||||
],
|
|
||||||
),
|
|
||||||
),
|
|
||||||
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
|
|
||||||
)
|
|
||||||
return Response(counts)
|
|
||||||
|
|
||||||
@action(methods=["get"], detail=False)
|
@action(methods=["get"], detail=False)
|
||||||
def active(self, request):
|
def active(self, request):
|
||||||
"""Currently pending and running tasks (capped at 50)."""
|
"""Currently pending and running tasks (capped at 50)."""
|
||||||
@@ -5009,29 +4923,11 @@ class SystemStatusView(PassUserMixin):
|
|||||||
celery_error = None
|
celery_error = None
|
||||||
celery_url = None
|
celery_url = None
|
||||||
try:
|
try:
|
||||||
celery_ping = None
|
celery_ping = celery_app.control.inspect().ping()
|
||||||
for ping_attempt in range(3):
|
celery_url = next(iter(celery_ping.keys()))
|
||||||
celery_ping = celery_app.control.inspect().ping()
|
first_worker_ping = celery_ping[celery_url]
|
||||||
if celery_ping:
|
if first_worker_ping["ok"] == "pong":
|
||||||
break
|
celery_active = "OK"
|
||||||
if ping_attempt < 2:
|
|
||||||
sleep(0.25)
|
|
||||||
|
|
||||||
if not celery_ping:
|
|
||||||
celery_active = "WARNING"
|
|
||||||
celery_error = (
|
|
||||||
"No celery workers responded to ping. This may be temporary."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
celery_url, first_worker_ping = next(iter(celery_ping.items()))
|
|
||||||
if (
|
|
||||||
isinstance(first_worker_ping, dict)
|
|
||||||
and first_worker_ping.get("ok") == "pong"
|
|
||||||
):
|
|
||||||
celery_active = "OK"
|
|
||||||
else:
|
|
||||||
celery_active = "WARNING"
|
|
||||||
celery_error = "Celery worker responded unexpectedly."
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
celery_active = "ERROR"
|
celery_active = "ERROR"
|
||||||
logger.exception(
|
logger.exception(
|
||||||
@@ -5050,7 +4946,7 @@ class SystemStatusView(PassUserMixin):
|
|||||||
index_dir = settings.INDEX_DIR
|
index_dir = settings.INDEX_DIR
|
||||||
mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()]
|
mtimes = [p.stat().st_mtime for p in index_dir.iterdir() if p.is_file()]
|
||||||
index_last_modified = (
|
index_last_modified = (
|
||||||
make_aware(datetime.fromtimestamp(max(mtimes))) if mtimes else None
|
datetime.fromtimestamp(max(mtimes), tz=UTC) if mtimes else None
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
index_status = "ERROR"
|
index_status = "ERROR"
|
||||||
|
|||||||
+14
-13
@@ -84,10 +84,11 @@ def binaries_check(app_configs: Any, **kwargs: Any) -> list[Error]:
|
|||||||
|
|
||||||
binaries = (settings.CONVERT_BINARY, "tesseract", "gs")
|
binaries = (settings.CONVERT_BINARY, "tesseract", "gs")
|
||||||
|
|
||||||
check_messages = []
|
check_messages = [
|
||||||
for binary in binaries:
|
Warning(error.format(binary), hint)
|
||||||
if shutil.which(binary) is None:
|
for binary in binaries
|
||||||
check_messages.append(Warning(error.format(binary), hint))
|
if shutil.which(binary) is None
|
||||||
|
]
|
||||||
|
|
||||||
return check_messages
|
return check_messages
|
||||||
|
|
||||||
@@ -383,14 +384,14 @@ def check_default_language_available(app_configs: Any, **kwargs: Any) -> list[Er
|
|||||||
|
|
||||||
specified_langs = [x.strip() for x in settings.OCR_LANGUAGE.split("+")]
|
specified_langs = [x.strip() for x in settings.OCR_LANGUAGE.split("+")]
|
||||||
|
|
||||||
for lang in specified_langs:
|
errs.extend(
|
||||||
if lang not in installed_langs:
|
Error(
|
||||||
errs.append(
|
f"The selected ocr language {lang} is "
|
||||||
Error(
|
f"not installed. Paperless cannot OCR your documents "
|
||||||
f"The selected ocr language {lang} is "
|
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
|
||||||
f"not installed. Paperless cannot OCR your documents "
|
)
|
||||||
f"without it. Please fix PAPERLESS_OCR_LANGUAGE.",
|
for lang in specified_langs
|
||||||
),
|
if lang not in installed_langs
|
||||||
)
|
)
|
||||||
|
|
||||||
return errs
|
return errs
|
||||||
|
|||||||
@@ -197,7 +197,6 @@ class AIConfig(BaseConfig):
|
|||||||
llm_embedding_endpoint: str = dataclasses.field(init=False)
|
llm_embedding_endpoint: str = dataclasses.field(init=False)
|
||||||
llm_embedding_chunk_size: int = dataclasses.field(init=False)
|
llm_embedding_chunk_size: int = dataclasses.field(init=False)
|
||||||
llm_context_size: int = dataclasses.field(init=False)
|
llm_context_size: int = dataclasses.field(init=False)
|
||||||
llm_request_timeout: int = dataclasses.field(init=False)
|
|
||||||
llm_backend: str = dataclasses.field(init=False)
|
llm_backend: str = dataclasses.field(init=False)
|
||||||
llm_model: str = dataclasses.field(init=False)
|
llm_model: str = dataclasses.field(init=False)
|
||||||
llm_api_key: str = dataclasses.field(init=False)
|
llm_api_key: str = dataclasses.field(init=False)
|
||||||
@@ -222,9 +221,6 @@ class AIConfig(BaseConfig):
|
|||||||
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
|
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
|
||||||
)
|
)
|
||||||
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
|
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
|
||||||
self.llm_request_timeout = (
|
|
||||||
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
|
|
||||||
)
|
|
||||||
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
|
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
|
||||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||||
|
|||||||
-365
@@ -1,365 +0,0 @@
|
|||||||
# Generated by Django 5.2.14 on 2026-06-04 15:30
|
|
||||||
|
|
||||||
import django.core.validators
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
def _create_singleton(apps, schema_editor):
|
|
||||||
settings_model = apps.get_model("paperless", "ApplicationConfiguration")
|
|
||||||
settings_model.objects.create()
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
replaces = [
|
|
||||||
("paperless", "0001_initial"),
|
|
||||||
("paperless", "0002_applicationconfiguration_app_logo_and_more"),
|
|
||||||
("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
|
|
||||||
("paperless", "0004_applicationconfiguration_barcode_asn_prefix_and_more"),
|
|
||||||
("paperless", "0005_applicationconfiguration_ai_enabled_and_more"),
|
|
||||||
("paperless", "0006_applicationconfiguration_barcode_tag_split"),
|
|
||||||
]
|
|
||||||
|
|
||||||
dependencies = []
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="ApplicationConfiguration",
|
|
||||||
fields=[
|
|
||||||
(
|
|
||||||
"id",
|
|
||||||
models.AutoField(
|
|
||||||
auto_created=True,
|
|
||||||
primary_key=True,
|
|
||||||
serialize=False,
|
|
||||||
verbose_name="ID",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"output_type",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("pdf", "pdf"),
|
|
||||||
("pdfa", "pdfa"),
|
|
||||||
("pdfa-1", "pdfa-1"),
|
|
||||||
("pdfa-2", "pdfa-2"),
|
|
||||||
("pdfa-3", "pdfa-3"),
|
|
||||||
],
|
|
||||||
max_length=8,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the output PDF type",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"pages",
|
|
||||||
models.PositiveIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Do OCR from page 1 to this value",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"language",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=32,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Do OCR using these languages",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"mode",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("skip", "skip"),
|
|
||||||
("redo", "redo"),
|
|
||||||
("force", "force"),
|
|
||||||
("skip_noarchive", "skip_noarchive"),
|
|
||||||
],
|
|
||||||
max_length=16,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the OCR mode",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"skip_archive_file",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("never", "never"),
|
|
||||||
("with_text", "with_text"),
|
|
||||||
("always", "always"),
|
|
||||||
],
|
|
||||||
max_length=16,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Controls the generation of an archive file",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"image_dpi",
|
|
||||||
models.PositiveIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets image DPI fallback value",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"unpaper_clean",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("clean", "clean"),
|
|
||||||
("clean-final", "clean-final"),
|
|
||||||
("none", "none"),
|
|
||||||
],
|
|
||||||
max_length=16,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Controls the unpaper cleaning",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"deskew",
|
|
||||||
models.BooleanField(null=True, verbose_name="Enables deskew"),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"rotate_pages",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables page rotation",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"rotate_pages_threshold",
|
|
||||||
models.FloatField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(0.0)],
|
|
||||||
verbose_name="Sets the threshold for rotation of pages",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"max_image_pixels",
|
|
||||||
models.FloatField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(0.0)],
|
|
||||||
verbose_name="Sets the maximum image size for decompression",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"color_conversion_strategy",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("LeaveColorUnchanged", "LeaveColorUnchanged"),
|
|
||||||
("RGB", "RGB"),
|
|
||||||
("UseDeviceIndependentColor", "UseDeviceIndependentColor"),
|
|
||||||
("Gray", "Gray"),
|
|
||||||
("CMYK", "CMYK"),
|
|
||||||
],
|
|
||||||
max_length=32,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the Ghostscript color conversion strategy",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"user_args",
|
|
||||||
models.JSONField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Adds additional user arguments for OCRMyPDF",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"app_logo",
|
|
||||||
models.FileField(
|
|
||||||
blank=True,
|
|
||||||
null=True,
|
|
||||||
upload_to="logo/",
|
|
||||||
validators=[
|
|
||||||
django.core.validators.FileExtensionValidator(
|
|
||||||
allowed_extensions=["jpg", "png", "gif", "svg"],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
verbose_name="Application logo",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"app_title",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=48,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Application title",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_asn_prefix",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=32,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the ASN barcode prefix",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_dpi",
|
|
||||||
models.PositiveIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the barcode DPI",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_enable_asn",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables ASN barcode",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_enable_tag",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables tag barcode",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_enable_tiff_support",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables barcode TIFF support",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_max_pages",
|
|
||||||
models.PositiveIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the maximum pages for barcode",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_retain_split_pages",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Retains split pages",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_string",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=32,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the barcode string",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_tag_mapping",
|
|
||||||
models.JSONField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the tag barcode mapping",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_upscale",
|
|
||||||
models.FloatField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1.0)],
|
|
||||||
verbose_name="Sets the barcode upscale factor",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcodes_enabled",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables barcode scanning",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"ai_enabled",
|
|
||||||
models.BooleanField(
|
|
||||||
default=False,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables AI features",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_api_key",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=1024,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM API key",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_backend",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("openai-like", "OpenAI-compatible"),
|
|
||||||
("ollama", "Ollama"),
|
|
||||||
],
|
|
||||||
max_length=128,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM backend",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_embedding_backend",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("openai-like", "OpenAI-compatible"),
|
|
||||||
("huggingface", "Huggingface"),
|
|
||||||
],
|
|
||||||
max_length=128,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM embedding backend",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_embedding_model",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=128,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM embedding model",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_endpoint",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=256,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM endpoint, optional",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"llm_model",
|
|
||||||
models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=128,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM model",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"barcode_tag_split",
|
|
||||||
models.BooleanField(
|
|
||||||
null=True,
|
|
||||||
verbose_name="Enables splitting on tag barcodes",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
options={
|
|
||||||
"verbose_name": "paperless application settings",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
migrations.RunPython(
|
|
||||||
code=_create_singleton,
|
|
||||||
reverse_code=migrations.RunPython.noop,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
-94
@@ -1,94 +0,0 @@
|
|||||||
# Generated by Django 5.2.14 on 2026-06-04 15:19
|
|
||||||
|
|
||||||
import django.core.validators
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
replaces = [
|
|
||||||
("paperless", "0009_alter_applicationconfiguration_options"),
|
|
||||||
("paperless", "0010_alter_applicationconfiguration_llm_embedding_backend"),
|
|
||||||
("paperless", "0011_applicationconfiguration_llm_embedding_chunk_size"),
|
|
||||||
("paperless", "0012_applicationconfiguration_llm_output_language"),
|
|
||||||
("paperless", "0013_applicationconfiguration_llm_request_timeout"),
|
|
||||||
]
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("paperless", "0008_replace_skip_archive_file"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AlterModelOptions(
|
|
||||||
name="applicationconfiguration",
|
|
||||||
options={
|
|
||||||
"permissions": [
|
|
||||||
("view_global_statistics", "Can view global object counts"),
|
|
||||||
("view_system_monitoring", "Can view system status information"),
|
|
||||||
],
|
|
||||||
"verbose_name": "paperless application settings",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_embedding_backend",
|
|
||||||
field=models.CharField(
|
|
||||||
blank=True,
|
|
||||||
choices=[
|
|
||||||
("openai-like", "OpenAI-compatible"),
|
|
||||||
("huggingface", "Huggingface"),
|
|
||||||
("ollama", "Ollama"),
|
|
||||||
],
|
|
||||||
max_length=128,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM embedding backend",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_embedding_endpoint",
|
|
||||||
field=models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=256,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM embedding endpoint, optional",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_embedding_chunk_size",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the LLM embedding chunk size",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_context_size",
|
|
||||||
field=models.PositiveIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the LLM context size",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_output_language",
|
|
||||||
field=models.CharField(
|
|
||||||
blank=True,
|
|
||||||
max_length=32,
|
|
||||||
null=True,
|
|
||||||
verbose_name="Sets the LLM output language",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_request_timeout",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the LLM request timeout in seconds",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
# Generated by Django 5.2.14 on 2026-06-14 14:22
|
|
||||||
|
|
||||||
import django.core.validators
|
|
||||||
from django.db import migrations
|
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
dependencies = [
|
|
||||||
("paperless", "0012_applicationconfiguration_llm_output_language"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AddField(
|
|
||||||
model_name="applicationconfiguration",
|
|
||||||
name="llm_request_timeout",
|
|
||||||
field=models.PositiveSmallIntegerField(
|
|
||||||
null=True,
|
|
||||||
validators=[django.core.validators.MinValueValidator(1)],
|
|
||||||
verbose_name="Sets the LLM request timeout in seconds",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@@ -366,12 +366,6 @@ class ApplicationConfiguration(AbstractSingletonModel):
|
|||||||
max_length=32,
|
max_length=32,
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_request_timeout = models.PositiveSmallIntegerField(
|
|
||||||
verbose_name=_("Sets the LLM timeout in seconds"),
|
|
||||||
null=True,
|
|
||||||
validators=[MinValueValidator(1)],
|
|
||||||
)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _("paperless application settings")
|
verbose_name = _("paperless application settings")
|
||||||
permissions = [
|
permissions = [
|
||||||
|
|||||||
@@ -649,11 +649,10 @@ class MailDocumentParser:
|
|||||||
if data["bcc"]:
|
if data["bcc"]:
|
||||||
data["bcc_label"] = "BCC"
|
data["bcc_label"] = "BCC"
|
||||||
|
|
||||||
att = []
|
att = [
|
||||||
for a in mail.attachments:
|
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})"
|
||||||
att.append(
|
for a in mail.attachments
|
||||||
f"{a.filename} ({naturalsize(a.size, binary=True, format='%.2f')})",
|
]
|
||||||
)
|
|
||||||
data["attachments"] = clean_html(", ".join(att))
|
data["attachments"] = clean_html(", ".join(att))
|
||||||
if data["attachments"]:
|
if data["attachments"]:
|
||||||
data["attachments_label"] = "Attachments"
|
data["attachments_label"] = "Attachments"
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from PIL import Image
|
|||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
from PIL import ImageFont
|
from PIL import ImageFont
|
||||||
|
|
||||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
|
||||||
from paperless.version import __full_version_str__
|
from paperless.version import __full_version_str__
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -184,7 +183,7 @@ class TextDocumentParser:
|
|||||||
documents.parsers.ParseError
|
documents.parsers.ParseError
|
||||||
If the file cannot be read.
|
If the file cannot be read.
|
||||||
"""
|
"""
|
||||||
self._text = read_file_handle_unicode_errors(document_path, log=logger)
|
self._text = self._read_text(document_path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Result accessors
|
# Result accessors
|
||||||
@@ -296,3 +295,30 @@ class TextDocumentParser:
|
|||||||
Always ``[]`` — plain text files carry no structured metadata.
|
Always ``[]`` — plain text files carry no structured metadata.
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Private helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _read_text(self, filepath: Path) -> str:
|
||||||
|
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filepath:
|
||||||
|
Path to the file to read.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
File content as a string.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return filepath.read_text(encoding="utf-8")
|
||||||
|
except UnicodeDecodeError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Unicode error reading %s, replacing bad bytes: %s",
|
||||||
|
filepath,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ share implementation.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import codecs
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -115,7 +114,7 @@ def read_file_handle_unicode_errors(
|
|||||||
filepath: Path,
|
filepath: Path,
|
||||||
log: logging.Logger | None = None,
|
log: logging.Logger | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
|
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -128,27 +127,15 @@ def read_file_handle_unicode_errors(
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str
|
str
|
||||||
File content as a string, with NUL bytes removed so the result is
|
File content as a string, with any invalid UTF-8 sequences replaced
|
||||||
safe to store in PostgreSQL text fields.
|
by the Unicode replacement character.
|
||||||
"""
|
"""
|
||||||
_log = log or logger
|
_log = log or logger
|
||||||
raw = filepath.read_bytes()
|
|
||||||
|
|
||||||
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
|
|
||||||
encoding = "utf-16"
|
|
||||||
elif raw.startswith(codecs.BOM_UTF8):
|
|
||||||
encoding = "utf-8-sig"
|
|
||||||
else:
|
|
||||||
encoding = "utf-8"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text = raw.decode(encoding)
|
return filepath.read_text(encoding="utf-8")
|
||||||
except UnicodeDecodeError as e:
|
except UnicodeDecodeError as e:
|
||||||
_log.warning("Unicode error during text reading, continuing: %s", e)
|
_log.warning("Unicode error during text reading, continuing: %s", e)
|
||||||
text = raw.decode("utf-8", errors="replace")
|
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||||
|
|
||||||
# PostgreSQL rejects NUL (0x00) bytes in text fields
|
|
||||||
return text.replace("\x00", "")
|
|
||||||
|
|
||||||
|
|
||||||
def get_page_count_for_pdf(
|
def get_page_count_for_pdf(
|
||||||
|
|||||||
@@ -97,14 +97,8 @@ MODEL_FILE = get_path_from_env(
|
|||||||
DATA_DIR / "classification_model.pickle",
|
DATA_DIR / "classification_model.pickle",
|
||||||
)
|
)
|
||||||
LLM_INDEX_DIR = DATA_DIR / "llm_index"
|
LLM_INDEX_DIR = DATA_DIR / "llm_index"
|
||||||
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
|
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
|
||||||
# Cross-process read/write lock guarding the LLM index compaction/migration
|
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
|
||||||
# file swap. Readers hold it shared; the swap takes it exclusively so it never
|
|
||||||
# runs while a reader connection is open. Must be a SQLite (.db) file.
|
|
||||||
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
|
|
||||||
# Seconds the compaction swap waits for active readers to drain before skipping
|
|
||||||
# this cycle (it is a maintenance operation; the next run retries).
|
|
||||||
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
|
|
||||||
|
|
||||||
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
|
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
|
||||||
|
|
||||||
@@ -650,7 +644,6 @@ LOGGING = {
|
|||||||
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
|
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
|
||||||
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||||
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||||
"httpx": {"level": "WARNING"},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1206,9 +1199,6 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
|
|||||||
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
|
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
|
||||||
if LLM_CONTEXT_SIZE < 1:
|
if LLM_CONTEXT_SIZE < 1:
|
||||||
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
|
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
|
||||||
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
|
|
||||||
if LLM_REQUEST_TIMEOUT < 1:
|
|
||||||
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
|
|
||||||
LLM_BACKEND = get_choice_from_env(
|
LLM_BACKEND = get_choice_from_env(
|
||||||
"PAPERLESS_AI_LLM_BACKEND",
|
"PAPERLESS_AI_LLM_BACKEND",
|
||||||
{"ollama", "openai-like"},
|
{"ollama", "openai-like"},
|
||||||
|
|||||||
@@ -252,9 +252,6 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
|
|||||||
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
|
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
|
||||||
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
|
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
|
||||||
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
|
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
|
||||||
# Validate pooled connections so a connection closed server-side
|
|
||||||
# is replaced rather than handed out as "the connection is closed".
|
|
||||||
"CONN_HEALTH_CHECKS": True,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
base_options = {
|
base_options = {
|
||||||
@@ -334,7 +331,7 @@ def parse_dateparser_languages(languages: str | None) -> list[str]:
|
|||||||
language_list = languages.split("+") if languages else []
|
language_list = languages.split("+") if languages else []
|
||||||
# There is an unfixed issue in zh-Hant and zh-Hans locales in the dateparser lib.
|
# There is an unfixed issue in zh-Hant and zh-Hans locales in the dateparser lib.
|
||||||
# See: https://github.com/scrapinghub/dateparser/issues/875
|
# See: https://github.com/scrapinghub/dateparser/issues/875
|
||||||
for index, language in enumerate(language_list):
|
for _, language in enumerate(language_list):
|
||||||
if language.startswith("zh-") and "zh" not in language_list:
|
if language.startswith("zh-") and "zh" not in language_list:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Chinese locale detected: {language}. dateparser might fail to parse"
|
f"Chinese locale detected: {language}. dateparser might fail to parse"
|
||||||
|
|||||||
@@ -398,7 +398,6 @@ class TestParseDbSettings:
|
|||||||
{
|
{
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.postgresql",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
"CONN_HEALTH_CHECKS": True,
|
|
||||||
"HOST": "localhost",
|
"HOST": "localhost",
|
||||||
"NAME": "paperless",
|
"NAME": "paperless",
|
||||||
"USER": "paperless",
|
"USER": "paperless",
|
||||||
@@ -427,7 +426,6 @@ class TestParseDbSettings:
|
|||||||
{
|
{
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.postgresql",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
"CONN_HEALTH_CHECKS": True,
|
|
||||||
"HOST": "paperless-db-host",
|
"HOST": "paperless-db-host",
|
||||||
"PORT": 1111,
|
"PORT": 1111,
|
||||||
"NAME": "customdb",
|
"NAME": "customdb",
|
||||||
@@ -457,7 +455,6 @@ class TestParseDbSettings:
|
|||||||
{
|
{
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.postgresql",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
"CONN_HEALTH_CHECKS": True,
|
|
||||||
"HOST": "pghost",
|
"HOST": "pghost",
|
||||||
"NAME": "paperless",
|
"NAME": "paperless",
|
||||||
"USER": "paperless",
|
"USER": "paperless",
|
||||||
@@ -488,7 +485,6 @@ class TestParseDbSettings:
|
|||||||
{
|
{
|
||||||
"default": {
|
"default": {
|
||||||
"ENGINE": "django.db.backends.postgresql",
|
"ENGINE": "django.db.backends.postgresql",
|
||||||
"CONN_HEALTH_CHECKS": True,
|
|
||||||
"HOST": "pghost",
|
"HOST": "pghost",
|
||||||
"NAME": "paperless",
|
"NAME": "paperless",
|
||||||
"USER": "paperless",
|
"USER": "paperless",
|
||||||
|
|||||||
@@ -2,50 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import codecs
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from paperless.parsers.utils import is_tagged_pdf
|
from paperless.parsers.utils import is_tagged_pdf
|
||||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
|
||||||
|
|
||||||
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
|
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
|
||||||
|
|
||||||
|
|
||||||
class TestReadFileHandleUnicodeErrors:
|
|
||||||
def test_plain_utf8(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "plain.txt"
|
|
||||||
f.write_bytes(b"hello world")
|
|
||||||
assert read_file_handle_unicode_errors(f) == "hello world"
|
|
||||||
|
|
||||||
def test_utf8_bom(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "bom.txt"
|
|
||||||
f.write_bytes(codecs.BOM_UTF8 + b"hello")
|
|
||||||
assert read_file_handle_unicode_errors(f) == "hello"
|
|
||||||
|
|
||||||
def test_utf16_le(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "utf16le.txt"
|
|
||||||
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
|
|
||||||
assert read_file_handle_unicode_errors(f) == "hello"
|
|
||||||
|
|
||||||
def test_utf16_be(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "utf16be.txt"
|
|
||||||
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
|
|
||||||
assert read_file_handle_unicode_errors(f) == "hello"
|
|
||||||
|
|
||||||
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "null-bytes.txt"
|
|
||||||
f.write_bytes(b"foo\x00bar")
|
|
||||||
assert read_file_handle_unicode_errors(f) == "foobar"
|
|
||||||
|
|
||||||
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
|
|
||||||
f = tmp_path / "bad.txt"
|
|
||||||
f.write_bytes(b"ok\x80\x81bad")
|
|
||||||
result = read_file_handle_unicode_errors(f)
|
|
||||||
assert "ok" in result
|
|
||||||
assert "bad" in result
|
|
||||||
assert "\x00" not in result
|
|
||||||
|
|
||||||
|
|
||||||
class TestIsTaggedPdf:
|
class TestIsTaggedPdf:
|
||||||
def test_tagged_pdf_returns_true(self) -> None:
|
def test_tagged_pdf_returns_true(self) -> None:
|
||||||
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
|
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ def reject_dangerous_svg(file: UploadedFile) -> None:
|
|||||||
tree = etree.parse(file, parser)
|
tree = etree.parse(file, parser)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
except etree.XMLSyntaxError:
|
except etree.XMLSyntaxError:
|
||||||
raise ValidationError("Invalid SVG file.")
|
raise ValidationError("Invalid SVG file.") from None
|
||||||
|
|
||||||
for element in root.iter():
|
for element in root.iter():
|
||||||
tag: str = etree.QName(element.tag).localname.lower()
|
tag: str = etree.QName(element.tag).localname.lower()
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
|
|||||||
from paperless.serialisers import PaperlessAuthTokenSerializer
|
from paperless.serialisers import PaperlessAuthTokenSerializer
|
||||||
from paperless.serialisers import ProfileSerializer
|
from paperless.serialisers import ProfileSerializer
|
||||||
from paperless.serialisers import UserSerializer
|
from paperless.serialisers import UserSerializer
|
||||||
from paperless_ai.indexing import llm_index_exists
|
from paperless_ai.indexing import vector_store_file_exists
|
||||||
|
|
||||||
|
|
||||||
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||||
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
|
|||||||
or old_llm_context_size != new_llm_context_size
|
or old_llm_context_size != new_llm_context_size
|
||||||
)
|
)
|
||||||
rebuild_needed = new_ai_index_enabled and (
|
rebuild_needed = new_ai_index_enabled and (
|
||||||
not llm_index_exists() or embedding_config_changed
|
not vector_store_file_exists() or embedding_config_changed
|
||||||
)
|
)
|
||||||
|
|
||||||
if rebuild_needed:
|
if rebuild_needed:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from documents.models import Document
|
|||||||
from documents.permissions import get_objects_for_user_owner_aware
|
from documents.permissions import get_objects_for_user_owner_aware
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
from paperless_ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
from paperless_ai.db import db_connection_released
|
|
||||||
from paperless_ai.indexing import query_similar_documents
|
from paperless_ai.indexing import query_similar_documents
|
||||||
from paperless_ai.indexing import truncate_content
|
from paperless_ai.indexing import truncate_content
|
||||||
|
|
||||||
@@ -25,14 +24,9 @@ def get_language_name(language_code: str) -> str:
|
|||||||
|
|
||||||
def build_prompt_without_rag(
|
def build_prompt_without_rag(
|
||||||
document: Document,
|
document: Document,
|
||||||
config: AIConfig,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
filename = document.filename or ""
|
filename = document.filename or ""
|
||||||
content = truncate_content(
|
content = truncate_content(document.content[:4000] or "")
|
||||||
document.content[:4000] or "",
|
|
||||||
chunk_size=config.llm_embedding_chunk_size,
|
|
||||||
context_size=config.llm_context_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
You are a document classification assistant.
|
You are a document classification assistant.
|
||||||
@@ -55,15 +49,10 @@ def build_prompt_without_rag(
|
|||||||
|
|
||||||
def build_prompt_with_rag(
|
def build_prompt_with_rag(
|
||||||
document: Document,
|
document: Document,
|
||||||
config: AIConfig,
|
|
||||||
user: User | None = None,
|
user: User | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
base_prompt = build_prompt_without_rag(document, config)
|
base_prompt = build_prompt_without_rag(document)
|
||||||
context = truncate_content(
|
context = truncate_content(get_context_for_document(document, user))
|
||||||
get_context_for_document(document, user),
|
|
||||||
chunk_size=config.llm_embedding_chunk_size,
|
|
||||||
context_size=config.llm_context_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"""{base_prompt}
|
return f"""{base_prompt}
|
||||||
|
|
||||||
@@ -141,29 +130,26 @@ def get_ai_document_classification(
|
|||||||
ai_config = AIConfig()
|
ai_config = AIConfig()
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
build_prompt_with_rag(document, ai_config, user)
|
build_prompt_with_rag(document, user)
|
||||||
if ai_config.llm_embedding_backend
|
if ai_config.llm_embedding_backend
|
||||||
else build_prompt_without_rag(document, ai_config)
|
else build_prompt_without_rag(document)
|
||||||
)
|
)
|
||||||
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
# Hand the pooled DB connection back while the (slow) LLM query runs so it
|
result = client.run_llm_query(prompt)
|
||||||
# is not pinned for the call's duration; see paperless_ai.db and #12976.
|
suggestions = parse_ai_response(result)
|
||||||
with db_connection_released():
|
if output_language:
|
||||||
result = client.run_llm_query(prompt)
|
localized = client.run_llm_query(
|
||||||
suggestions = parse_ai_response(result)
|
build_localization_prompt(suggestions, output_language),
|
||||||
if output_language:
|
)
|
||||||
localized = client.run_llm_query(
|
localized_suggestions = parse_ai_response(localized)
|
||||||
build_localization_prompt(suggestions, output_language),
|
suggestions = {
|
||||||
)
|
**suggestions,
|
||||||
localized_suggestions = parse_ai_response(localized)
|
"title": localized_suggestions["title"] or suggestions["title"],
|
||||||
suggestions = {
|
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||||
**suggestions,
|
"document_types": localized_suggestions["document_types"]
|
||||||
"title": localized_suggestions["title"] or suggestions["title"],
|
or suggestions["document_types"],
|
||||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
"storage_paths": localized_suggestions["storage_paths"]
|
||||||
"document_types": localized_suggestions["document_types"]
|
or suggestions["storage_paths"],
|
||||||
or suggestions["document_types"],
|
}
|
||||||
"storage_paths": localized_suggestions["storage_paths"]
|
|
||||||
or suggestions["storage_paths"],
|
|
||||||
}
|
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|||||||
+122
-56
@@ -3,13 +3,9 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.config import AIConfig
|
|
||||||
from paperless_ai.client import AIClient
|
from paperless_ai.client import AIClient
|
||||||
from paperless_ai.db import db_connection_released
|
|
||||||
from paperless_ai.indexing import _document_id_filters
|
|
||||||
from paperless_ai.indexing import get_rag_prompt_helper
|
from paperless_ai.indexing import get_rag_prompt_helper
|
||||||
from paperless_ai.indexing import load_or_build_index
|
from paperless_ai.indexing import load_or_build_index
|
||||||
from paperless_ai.indexing import read_store
|
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.chat")
|
logger = logging.getLogger("paperless_ai.chat")
|
||||||
|
|
||||||
@@ -79,6 +75,82 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
|
||||||
|
from llama_index.core.base.base_retriever import BaseRetriever
|
||||||
|
from llama_index.core.schema import NodeWithScore
|
||||||
|
from llama_index.core.vector_stores import VectorStoreQuery
|
||||||
|
|
||||||
|
class DocumentFilteredFaissRetriever(BaseRetriever):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._cached_query_str = None
|
||||||
|
self._cached_nodes = []
|
||||||
|
|
||||||
|
def _retrieve(self, query_bundle):
|
||||||
|
if query_bundle.query_str == self._cached_query_str:
|
||||||
|
return self._cached_nodes
|
||||||
|
|
||||||
|
if query_bundle.embedding is None:
|
||||||
|
query_bundle.embedding = (
|
||||||
|
index._embed_model.get_agg_embedding_from_queries(
|
||||||
|
query_bundle.embedding_strs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
faiss_index = index.vector_store._faiss_index
|
||||||
|
max_top_k = faiss_index.ntotal
|
||||||
|
if max_top_k == 0:
|
||||||
|
self._cached_query_str = query_bundle.query_str
|
||||||
|
self._cached_nodes = []
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_top_k = min(max(similarity_top_k, 1), max_top_k)
|
||||||
|
allowed_nodes: list[NodeWithScore] = []
|
||||||
|
seen_node_ids: set[str] = set()
|
||||||
|
|
||||||
|
while query_top_k <= max_top_k:
|
||||||
|
query_result = index.vector_store.query(
|
||||||
|
VectorStoreQuery(
|
||||||
|
query_embedding=query_bundle.embedding,
|
||||||
|
similarity_top_k=query_top_k,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector_id, score in zip(
|
||||||
|
query_result.ids or [],
|
||||||
|
query_result.similarities or [],
|
||||||
|
strict=False,
|
||||||
|
):
|
||||||
|
node_id = index.index_struct.nodes_dict.get(vector_id)
|
||||||
|
if node_id is None or node_id in seen_node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node = index.docstore.docs.get(node_id)
|
||||||
|
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_node_ids.add(node_id)
|
||||||
|
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||||
|
|
||||||
|
if len(allowed_nodes) >= similarity_top_k:
|
||||||
|
self._cached_query_str = query_bundle.query_str
|
||||||
|
self._cached_nodes = allowed_nodes
|
||||||
|
return allowed_nodes
|
||||||
|
|
||||||
|
if query_top_k == max_top_k:
|
||||||
|
self._cached_query_str = query_bundle.query_str
|
||||||
|
self._cached_nodes = allowed_nodes
|
||||||
|
return allowed_nodes
|
||||||
|
|
||||||
|
query_top_k = min(query_top_k * 2, max_top_k)
|
||||||
|
|
||||||
|
self._cached_query_str = query_bundle.query_str
|
||||||
|
self._cached_nodes = allowed_nodes
|
||||||
|
return allowed_nodes
|
||||||
|
|
||||||
|
return DocumentFilteredFaissRetriever()
|
||||||
|
|
||||||
|
|
||||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||||
try:
|
try:
|
||||||
yield from _stream_chat_with_documents(query_str, documents)
|
yield from _stream_chat_with_documents(query_str, documents)
|
||||||
@@ -88,69 +160,63 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
|||||||
|
|
||||||
|
|
||||||
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||||
if not documents:
|
client = AIClient()
|
||||||
|
index = load_or_build_index()
|
||||||
|
|
||||||
|
doc_ids = [str(doc.pk) for doc in documents]
|
||||||
|
|
||||||
|
# Filter only the node(s) that match the document IDs
|
||||||
|
nodes = [
|
||||||
|
node
|
||||||
|
for node in index.docstore.docs.values()
|
||||||
|
if node.metadata.get("document_id") in doc_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(nodes) == 0:
|
||||||
|
logger.warning("No nodes found for the given documents.")
|
||||||
yield CHAT_NO_CONTENT_MESSAGE
|
yield CHAT_NO_CONTENT_MESSAGE
|
||||||
return
|
return
|
||||||
|
|
||||||
from llama_index.core.prompts import PromptTemplate
|
from llama_index.core.prompts import PromptTemplate
|
||||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||||
from llama_index.core.retrievers import VectorIndexRetriever
|
|
||||||
|
|
||||||
config = AIConfig()
|
retriever = _get_document_filtered_retriever(
|
||||||
filters = _document_id_filters(str(doc.pk) for doc in documents)
|
index,
|
||||||
|
set(doc_ids),
|
||||||
|
CHAT_RETRIEVER_TOP_K,
|
||||||
|
)
|
||||||
|
|
||||||
# Hold the shared read lock for the whole operation: the query engine
|
top_nodes = retriever.retrieve(query_str)
|
||||||
# retrieves from the vector store again during synthesis, so the connection
|
if len(top_nodes) == 0:
|
||||||
# must stay open (and the swap must not run) until the stream finishes.
|
logger.warning("Retriever returned no nodes for the given documents.")
|
||||||
with read_store() as store:
|
yield CHAT_NO_CONTENT_MESSAGE
|
||||||
index = load_or_build_index(config, store)
|
return
|
||||||
retriever = VectorIndexRetriever(
|
|
||||||
index=index,
|
|
||||||
similarity_top_k=CHAT_RETRIEVER_TOP_K,
|
|
||||||
filters=filters,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Slow query-embedding + vector search; no Django ORM access happens
|
references = _get_document_references(documents, top_nodes)
|
||||||
# during it, so release the pooled DB connection for its duration. See
|
|
||||||
# #12976.
|
|
||||||
with db_connection_released():
|
|
||||||
top_nodes = retriever.retrieve(query_str)
|
|
||||||
if not top_nodes:
|
|
||||||
logger.warning("No nodes found for the given documents.")
|
|
||||||
yield CHAT_NO_CONTENT_MESSAGE
|
|
||||||
return
|
|
||||||
|
|
||||||
client = AIClient()
|
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||||
|
response_synthesizer = get_response_synthesizer(
|
||||||
|
llm=client.llm,
|
||||||
|
prompt_helper=get_rag_prompt_helper(),
|
||||||
|
text_qa_template=prompt_template,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
references = _get_document_references(documents, top_nodes)
|
query_engine = RetrieverQueryEngine.from_args(
|
||||||
|
retriever=retriever,
|
||||||
|
llm=client.llm,
|
||||||
|
response_synthesizer=response_synthesizer,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
logger.debug("Document chat query: %s", query_str)
|
||||||
response_synthesizer = get_response_synthesizer(
|
|
||||||
llm=client.llm,
|
|
||||||
prompt_helper=get_rag_prompt_helper(
|
|
||||||
chunk_size=config.llm_embedding_chunk_size,
|
|
||||||
context_size=config.llm_context_size,
|
|
||||||
),
|
|
||||||
text_qa_template=prompt_template,
|
|
||||||
streaming=True,
|
|
||||||
)
|
|
||||||
query_engine = RetrieverQueryEngine.from_args(
|
|
||||||
retriever=retriever,
|
|
||||||
llm=client.llm,
|
|
||||||
response_synthesizer=response_synthesizer,
|
|
||||||
streaming=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Document chat query: %s", query_str)
|
response_stream = query_engine.query(query_str)
|
||||||
# Release the pooled DB connection for the slow streaming LLM response
|
|
||||||
# so it is not pinned for the whole stream; see paperless_ai.db and
|
|
||||||
# #12976.
|
|
||||||
with db_connection_released():
|
|
||||||
response_stream = query_engine.query(query_str)
|
|
||||||
for chunk in response_stream.response_gen:
|
|
||||||
yield chunk
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
if references:
|
for chunk in response_stream.response_gen:
|
||||||
yield _format_chat_metadata_trailer(references)
|
yield chunk
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
if references:
|
||||||
|
yield _format_chat_metadata_trailer(references)
|
||||||
|
|||||||
+28
-49
@@ -1,14 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Iterator
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from paperless.models import LLMBackend
|
from paperless.models import LLMBackend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core.llms import ChatMessage
|
||||||
from llama_index.llms.ollama import Ollama
|
from llama_index.llms.ollama import Ollama
|
||||||
from llama_index.llms.openai_like import OpenAILike
|
from llama_index.llms.openai_like import OpenAILike
|
||||||
|
|
||||||
@@ -19,7 +16,6 @@ from paperless.network import create_pinned_async_httpx_client
|
|||||||
from paperless.network import create_pinned_httpx_client
|
from paperless.network import create_pinned_httpx_client
|
||||||
from paperless.network import validate_outbound_http_url
|
from paperless.network import validate_outbound_http_url
|
||||||
from paperless_ai.base_model import DocumentClassifierSchema
|
from paperless_ai.base_model import DocumentClassifierSchema
|
||||||
from paperless_ai.exceptions import LLMTimeoutError
|
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.client")
|
logger = logging.getLogger("paperless_ai.client")
|
||||||
|
|
||||||
@@ -65,16 +61,16 @@ class AIClient:
|
|||||||
model=self.settings.llm_model or "llama3.1",
|
model=self.settings.llm_model or "llama3.1",
|
||||||
base_url=endpoint,
|
base_url=endpoint,
|
||||||
context_window=self.settings.llm_context_size,
|
context_window=self.settings.llm_context_size,
|
||||||
request_timeout=self.settings.llm_request_timeout,
|
request_timeout=120,
|
||||||
system_prompt=LLM_SYSTEM_PROMPT,
|
system_prompt=LLM_SYSTEM_PROMPT,
|
||||||
client=Client(
|
client=Client(
|
||||||
host=endpoint,
|
host=endpoint,
|
||||||
timeout=self.settings.llm_request_timeout,
|
timeout=120,
|
||||||
transport=transport,
|
transport=transport,
|
||||||
),
|
),
|
||||||
async_client=AsyncClient(
|
async_client=AsyncClient(
|
||||||
host=endpoint,
|
host=endpoint,
|
||||||
timeout=self.settings.llm_request_timeout,
|
timeout=120,
|
||||||
transport=async_transport,
|
transport=async_transport,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -88,18 +84,15 @@ class AIClient:
|
|||||||
http_client = create_pinned_httpx_client(
|
http_client = create_pinned_httpx_client(
|
||||||
endpoint,
|
endpoint,
|
||||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||||
timeout=self.settings.llm_request_timeout,
|
|
||||||
)
|
)
|
||||||
async_http_client = create_pinned_async_httpx_client(
|
async_http_client = create_pinned_async_httpx_client(
|
||||||
endpoint,
|
endpoint,
|
||||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||||
timeout=self.settings.llm_request_timeout,
|
|
||||||
)
|
)
|
||||||
return OpenAILike(
|
return OpenAILike(
|
||||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||||
api_base=endpoint,
|
api_base=endpoint,
|
||||||
api_key=self.settings.llm_api_key,
|
api_key=self.settings.llm_api_key,
|
||||||
timeout=self.settings.llm_request_timeout,
|
|
||||||
is_chat_model=True,
|
is_chat_model=True,
|
||||||
is_function_calling_model=True,
|
is_function_calling_model=True,
|
||||||
system_prompt=LLM_SYSTEM_PROMPT,
|
system_prompt=LLM_SYSTEM_PROMPT,
|
||||||
@@ -120,12 +113,11 @@ class AIClient:
|
|||||||
|
|
||||||
user_msg = ChatMessage(role="user", content=prompt)
|
user_msg = ChatMessage(role="user", content=prompt)
|
||||||
if self.settings.llm_backend == LLMBackend.OLLAMA:
|
if self.settings.llm_backend == LLMBackend.OLLAMA:
|
||||||
with self._normalize_timeouts():
|
result = self.llm.chat(
|
||||||
result = self.llm.chat(
|
[user_msg],
|
||||||
[user_msg],
|
format=DocumentClassifierSchema.model_json_schema(),
|
||||||
format=DocumentClassifierSchema.model_json_schema(),
|
think=False,
|
||||||
think=False,
|
)
|
||||||
)
|
|
||||||
logger.debug("LLM query result: %s", result)
|
logger.debug("LLM query result: %s", result)
|
||||||
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
|
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
|
||||||
return parsed.model_dump()
|
return parsed.model_dump()
|
||||||
@@ -133,39 +125,26 @@ class AIClient:
|
|||||||
from llama_index.core.program.function_program import get_function_tool
|
from llama_index.core.program.function_program import get_function_tool
|
||||||
|
|
||||||
tool = get_function_tool(DocumentClassifierSchema)
|
tool = get_function_tool(DocumentClassifierSchema)
|
||||||
with self._normalize_timeouts():
|
result = self.llm.chat_with_tools(
|
||||||
result = self.llm.chat_with_tools(
|
tools=[tool],
|
||||||
tools=[tool],
|
user_msg=user_msg,
|
||||||
user_msg=user_msg,
|
chat_history=[],
|
||||||
chat_history=[],
|
allow_parallel_tool_calls=True,
|
||||||
allow_parallel_tool_calls=True,
|
)
|
||||||
tool_required=True,
|
tool_calls = self.llm.get_tool_calls_from_response(
|
||||||
)
|
result,
|
||||||
tool_calls = self.llm.get_tool_calls_from_response(
|
error_on_no_tool_call=True,
|
||||||
result,
|
)
|
||||||
error_on_no_tool_call=True,
|
|
||||||
)
|
|
||||||
logger.debug("LLM query result: %s", tool_calls)
|
logger.debug("LLM query result: %s", tool_calls)
|
||||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||||
return parsed.model_dump()
|
return parsed.model_dump()
|
||||||
|
|
||||||
@contextmanager
|
def run_chat(self, messages: list["ChatMessage"]) -> str:
|
||||||
def _normalize_timeouts(self) -> Iterator[None]:
|
logger.debug(
|
||||||
try:
|
"Running chat query against %s with model %s",
|
||||||
yield
|
self.settings.llm_backend,
|
||||||
except httpx.TimeoutException as exc:
|
self.settings.llm_model,
|
||||||
raise LLMTimeoutError from exc
|
)
|
||||||
except Exception as exc:
|
result = self.llm.chat(messages)
|
||||||
if self._is_openai_timeout(exc):
|
logger.debug("Chat result: %s", result)
|
||||||
raise LLMTimeoutError from exc
|
return result
|
||||||
raise
|
|
||||||
|
|
||||||
def _is_openai_timeout(self, exc: Exception) -> bool:
|
|
||||||
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Keep OpenAI imports out of module import paths and only load the SDK
|
|
||||||
# when translating an error from an OpenAI-backed request.
|
|
||||||
from openai import APITimeoutError
|
|
||||||
|
|
||||||
return isinstance(exc, APITimeoutError)
|
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from django.db import connections
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def db_connection_released():
|
|
||||||
"""
|
|
||||||
Return any checked-out DB connections to the pool for the duration of the
|
|
||||||
wrapped block.
|
|
||||||
|
|
||||||
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
|
|
||||||
or a streaming response (``chat``). Django keeps the request's database
|
|
||||||
connection checked out for the entire request/response, so a blocking LLM
|
|
||||||
call - which can take many seconds - pins a pooled connection the whole
|
|
||||||
time. With connection pooling enabled, enough concurrent AI requests check
|
|
||||||
out every slot and all other requests then fail with
|
|
||||||
``psycopg_pool.PoolTimeout`` (see issue #12976).
|
|
||||||
|
|
||||||
No Django ORM access happens during the LLM call, so we hand the connection
|
|
||||||
back to the pool first; Django transparently re-checks-out a connection on
|
|
||||||
the next ORM use after the block.
|
|
||||||
"""
|
|
||||||
connections.close_all()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
connections.close_all()
|
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
@@ -20,7 +23,9 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
|
|||||||
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
|
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
def get_embedding_model() -> "BaseEmbedding":
|
||||||
|
config = AIConfig()
|
||||||
|
|
||||||
match config.llm_embedding_backend:
|
match config.llm_embedding_backend:
|
||||||
case LLMEmbeddingBackend.OPENAI_LIKE:
|
case LLMEmbeddingBackend.OPENAI_LIKE:
|
||||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||||
@@ -32,18 +37,15 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
|||||||
http_client = create_pinned_httpx_client(
|
http_client = create_pinned_httpx_client(
|
||||||
endpoint,
|
endpoint,
|
||||||
allow_internal=config.llm_allow_internal_endpoints,
|
allow_internal=config.llm_allow_internal_endpoints,
|
||||||
timeout=config.llm_request_timeout,
|
|
||||||
)
|
)
|
||||||
async_http_client = create_pinned_async_httpx_client(
|
async_http_client = create_pinned_async_httpx_client(
|
||||||
endpoint,
|
endpoint,
|
||||||
allow_internal=config.llm_allow_internal_endpoints,
|
allow_internal=config.llm_allow_internal_endpoints,
|
||||||
timeout=config.llm_request_timeout,
|
|
||||||
)
|
)
|
||||||
return OpenAILikeEmbedding(
|
return OpenAILikeEmbedding(
|
||||||
model_name=config.llm_embedding_model or "text-embedding-3-small",
|
model_name=config.llm_embedding_model or "text-embedding-3-small",
|
||||||
api_key=config.llm_api_key,
|
api_key=config.llm_api_key,
|
||||||
api_base=endpoint,
|
api_base=endpoint,
|
||||||
timeout=config.llm_request_timeout,
|
|
||||||
http_client=http_client,
|
http_client=http_client,
|
||||||
async_http_client=async_http_client,
|
async_http_client=async_http_client,
|
||||||
)
|
)
|
||||||
@@ -76,14 +78,12 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
|||||||
)
|
)
|
||||||
embedding._client = Client(
|
embedding._client = Client(
|
||||||
host=endpoint,
|
host=endpoint,
|
||||||
timeout=config.llm_request_timeout,
|
|
||||||
transport=PinnedHostHTTPTransport(
|
transport=PinnedHostHTTPTransport(
|
||||||
allow_internal=config.llm_allow_internal_endpoints,
|
allow_internal=config.llm_allow_internal_endpoints,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
embedding._async_client = AsyncClient(
|
embedding._async_client = AsyncClient(
|
||||||
host=endpoint,
|
host=endpoint,
|
||||||
timeout=config.llm_request_timeout,
|
|
||||||
transport=PinnedHostAsyncHTTPTransport(
|
transport=PinnedHostAsyncHTTPTransport(
|
||||||
allow_internal=config.llm_allow_internal_endpoints,
|
allow_internal=config.llm_allow_internal_endpoints,
|
||||||
),
|
),
|
||||||
@@ -95,24 +95,41 @@ def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_MODEL_NAMES = {
|
def get_embedding_dim() -> int:
|
||||||
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
"""
|
||||||
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
Loads embedding dimension from meta.json if available, otherwise infers it
|
||||||
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
from a dummy embedding and stores it for future use.
|
||||||
}
|
"""
|
||||||
|
config = AIConfig()
|
||||||
|
default_model = {
|
||||||
def get_configured_model_name(config: AIConfig) -> str:
|
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
||||||
"""Return the canonical name of the currently configured embedding model."""
|
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
# dict.get(key, default) overload resolution fails for TextChoices keys in some
|
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
||||||
# type checkers; use `or` fallback to avoid the ambiguity.
|
}.get(
|
||||||
default = (
|
config.llm_embedding_backend,
|
||||||
_DEFAULT_MODEL_NAMES.get(
|
"sentence-transformers/all-MiniLM-L6-v2",
|
||||||
config.llm_embedding_backend,
|
|
||||||
)
|
|
||||||
or "sentence-transformers/all-MiniLM-L6-v2"
|
|
||||||
)
|
)
|
||||||
return config.llm_embedding_model or default
|
model = config.llm_embedding_model or default_model
|
||||||
|
|
||||||
|
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
|
||||||
|
if meta_path.exists():
|
||||||
|
with meta_path.open() as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
if meta.get("embedding_model") != model:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
|
||||||
|
"You must rebuild the index.",
|
||||||
|
)
|
||||||
|
return meta["dim"]
|
||||||
|
|
||||||
|
embedding_model = get_embedding_model()
|
||||||
|
test_embed = embedding_model.get_text_embedding("test")
|
||||||
|
dim = len(test_embed)
|
||||||
|
|
||||||
|
with meta_path.open("w") as f:
|
||||||
|
json.dump({"embedding_model": model, "dim": dim}, f)
|
||||||
|
|
||||||
|
return dim
|
||||||
|
|
||||||
|
|
||||||
def _normalize_llm_index_text(text: str) -> str:
|
def _normalize_llm_index_text(text: str) -> str:
|
||||||
@@ -121,16 +138,24 @@ def _normalize_llm_index_text(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def build_llm_index_text(doc: Document) -> str:
|
def build_llm_index_text(doc: Document) -> str:
|
||||||
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
|
|
||||||
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
|
|
||||||
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
|
|
||||||
# text, Custom Fields are dynamic in count and best kept in the embedding.
|
|
||||||
lines = [
|
lines = [
|
||||||
|
f"Title: {doc.title}",
|
||||||
|
f"Filename: {doc.filename}",
|
||||||
|
f"Created: {doc.created}",
|
||||||
|
f"Added: {doc.added}",
|
||||||
|
f"Modified: {doc.modified}",
|
||||||
|
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
||||||
|
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
||||||
|
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
||||||
|
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
||||||
|
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
||||||
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
||||||
]
|
]
|
||||||
|
|
||||||
for instance in doc.custom_fields.all():
|
lines.extend(
|
||||||
lines.append(f"Custom Field - {instance.field.name}: {instance}")
|
f"Custom Field - {instance.field.name}: {instance}"
|
||||||
|
for instance in doc.custom_fields.all()
|
||||||
|
)
|
||||||
|
|
||||||
lines.append("\nContent:\n")
|
lines.append("\nContent:\n")
|
||||||
lines.append(doc.content or "")
|
lines.append(doc.content or "")
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
class LLMTimeoutError(Exception):
|
|
||||||
pass
|
|
||||||
+243
-269
@@ -1,30 +1,28 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from contextlib import contextmanager
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from filelock import ReadWriteLock
|
|
||||||
from filelock import Timeout
|
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.models import PaperlessTask
|
from documents.models import PaperlessTask
|
||||||
from documents.utils import IterWrapper
|
from documents.utils import IterWrapper
|
||||||
from documents.utils import identity
|
from documents.utils import identity
|
||||||
from paperless.config import AIConfig
|
from paperless.config import AIConfig
|
||||||
from paperless_ai.db import db_connection_released
|
|
||||||
from paperless_ai.embedding import build_llm_index_text
|
from paperless_ai.embedding import build_llm_index_text
|
||||||
from paperless_ai.embedding import get_configured_model_name
|
from paperless_ai.embedding import get_embedding_dim
|
||||||
from paperless_ai.embedding import get_embedding_model
|
from paperless_ai.embedding import get_embedding_model
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
from llama_index.core.schema import BaseNode
|
from llama_index.core.schema import BaseNode
|
||||||
|
|
||||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("paperless_ai.indexing")
|
logger = logging.getLogger("paperless_ai.indexing")
|
||||||
|
|
||||||
@@ -32,11 +30,21 @@ RAG_NUM_OUTPUT = 512
|
|||||||
RAG_CHUNK_OVERLAP = 200
|
RAG_CHUNK_OVERLAP = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _index_lock_path() -> Path:
|
||||||
|
"""Return the path used as the file lock for FAISS index mutations.
|
||||||
|
|
||||||
|
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
|
||||||
|
rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock
|
||||||
|
while another worker still holds it.
|
||||||
|
"""
|
||||||
|
return settings.LLM_INDEX_LOCK
|
||||||
|
|
||||||
|
|
||||||
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||||
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
|
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
|
||||||
# concurrent workers can both observe no running task and both enqueue a
|
# concurrent workers can both observe no running task and both enqueue a
|
||||||
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
|
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
|
||||||
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
|
# is itself protected by _index_lock_path(), so only one rebuild runs at a
|
||||||
# time and the second one is serialised after the first completes.
|
# time and the second one is serialised after the first completes.
|
||||||
from documents.tasks import llmindex_index
|
from documents.tasks import llmindex_index
|
||||||
|
|
||||||
@@ -63,110 +71,46 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
|
def get_or_create_storage_context(*, rebuild=False):
|
||||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
"""
|
||||||
|
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||||
|
If rebuild=True, deletes and recreates everything.
|
||||||
|
"""
|
||||||
|
if rebuild:
|
||||||
|
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||||
|
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||||
return PaperlessSqliteVecVectorStore(
|
import faiss
|
||||||
uri=str(settings.LLM_INDEX_DIR),
|
from llama_index.core import StorageContext
|
||||||
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
|
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
|
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
embedding_dim = get_embedding_dim()
|
||||||
|
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||||
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
docstore = SimpleDocumentStore()
|
||||||
|
index_store = SimpleIndexStore()
|
||||||
|
else:
|
||||||
|
from llama_index.core import StorageContext
|
||||||
|
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||||
|
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
|
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
|
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
|
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||||
|
|
||||||
|
return StorageContext.from_defaults(
|
||||||
|
docstore=docstore,
|
||||||
|
index_store=index_store,
|
||||||
|
vector_store=vector_store,
|
||||||
|
persist_dir=settings.LLM_INDEX_DIR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# --- LLM index locking ---------------------------------------------------
|
|
||||||
#
|
|
||||||
# Two locks guard the index; they answer different questions and are NOT
|
|
||||||
# interchangeable:
|
|
||||||
#
|
|
||||||
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
|
|
||||||
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
|
|
||||||
# Taken by write_store(). Readers never take it, so it never blocks reads.
|
|
||||||
#
|
|
||||||
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
|
|
||||||
# compaction/migration file swap. read_store() takes it SHARED (readers run
|
|
||||||
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
|
|
||||||
# the database file is never replaced while a reader connection is open (that
|
|
||||||
# would alias the old WAL onto the new file and corrupt it).
|
|
||||||
#
|
|
||||||
# | vs another writer | vs a reader
|
|
||||||
# -----------------+-------------------+----------------------------
|
|
||||||
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
|
|
||||||
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
|
|
||||||
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
|
|
||||||
#
|
|
||||||
# They can't be merged into one ReadWriteLock: a normal write must exclude other
|
|
||||||
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
|
|
||||||
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
|
|
||||||
# the swap needs to exclude readers.
|
|
||||||
def _index_rwlock() -> ReadWriteLock:
|
|
||||||
"""Return a fresh read/write lock instance for the index swap.
|
|
||||||
|
|
||||||
``is_singleton=False`` so reads and the swap always coordinate through
|
|
||||||
SQLite (the actual cross-process case) rather than hitting the in-process
|
|
||||||
reentrant-upgrade guard; callers must ``close()`` it (the context managers
|
|
||||||
below do).
|
|
||||||
"""
|
|
||||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def read_store():
|
|
||||||
"""Acquire the shared read lock and yield the vector store for a read.
|
|
||||||
|
|
||||||
The shared lock is held for the whole lifetime of the connection (and
|
|
||||||
closed on exit) so the compaction/migration swap, which takes the exclusive
|
|
||||||
lock, never runs while this connection is open. Concurrent readers do not
|
|
||||||
block each other; only the swap does.
|
|
||||||
"""
|
|
||||||
lock = _index_rwlock()
|
|
||||||
try:
|
|
||||||
with lock.read_lock(), get_vector_store() as store:
|
|
||||||
yield store
|
|
||||||
finally:
|
|
||||||
lock.close()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _exclude_readers():
|
|
||||||
"""Acquire exclusive index access, blocking until readers have drained.
|
|
||||||
|
|
||||||
The exclusive counterpart to ``read_store()``: a compaction or migration
|
|
||||||
must not run while any reader connection is open. Raises
|
|
||||||
:class:`filelock.Timeout` if active readers do not drain within
|
|
||||||
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
|
|
||||||
"""
|
|
||||||
lock = _index_rwlock()
|
|
||||||
try:
|
|
||||||
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
lock.close()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def write_store(embed_model_name: str | None = None):
|
|
||||||
"""Acquire the write lock and yield the vector store.
|
|
||||||
|
|
||||||
All mutating operations (upsert, delete, rebuild, compact) must go through
|
|
||||||
this context manager to serialise concurrent Celery writers.
|
|
||||||
Read paths use ``read_store()`` so they hold the shared read lock.
|
|
||||||
|
|
||||||
Pass ``embed_model_name`` whenever the operation may create the table so
|
|
||||||
the model name is recorded in the schema metadata for future mismatch checks.
|
|
||||||
"""
|
|
||||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
|
||||||
|
|
||||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
with (
|
|
||||||
FileLock(settings.LLM_INDEX_LOCK),
|
|
||||||
PaperlessSqliteVecVectorStore(
|
|
||||||
uri=str(settings.LLM_INDEX_DIR),
|
|
||||||
embed_model_name=embed_model_name,
|
|
||||||
) as store,
|
|
||||||
):
|
|
||||||
yield store
|
|
||||||
|
|
||||||
|
|
||||||
def build_document_node(
|
def build_document_node(
|
||||||
document: Document,
|
document: Document,
|
||||||
*,
|
*,
|
||||||
@@ -186,9 +130,6 @@ def build_document_node(
|
|||||||
"document_type": document.document_type.name
|
"document_type": document.document_type.name
|
||||||
if document.document_type
|
if document.document_type
|
||||||
else None,
|
else None,
|
||||||
"filename": document.filename,
|
|
||||||
"storage_path": document.storage_path.name if document.storage_path else None,
|
|
||||||
"archive_serial_number": document.archive_serial_number,
|
|
||||||
"created": document.created.isoformat() if document.created else None,
|
"created": document.created.isoformat() if document.created else None,
|
||||||
"added": document.added.isoformat() if document.added else None,
|
"added": document.added.isoformat() if document.added else None,
|
||||||
"modified": document.modified.isoformat(),
|
"modified": document.modified.isoformat(),
|
||||||
@@ -201,11 +142,9 @@ def build_document_node(
|
|||||||
# the token count and exceed embedding models with small context windows
|
# the token count and exceed embedding models with small context windows
|
||||||
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||||
doc = LlamaDocument(
|
doc = LlamaDocument(
|
||||||
id_=str(document.id),
|
|
||||||
text=text,
|
text=text,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
excluded_embed_metadata_keys=list(metadata.keys()),
|
excluded_embed_metadata_keys=list(metadata.keys()),
|
||||||
excluded_llm_metadata_keys=["document_id"],
|
|
||||||
)
|
)
|
||||||
chunk_size = chunk_size or get_rag_chunk_size()
|
chunk_size = chunk_size or get_rag_chunk_size()
|
||||||
parser = SimpleNodeParser(
|
parser = SimpleNodeParser(
|
||||||
@@ -215,33 +154,76 @@ def build_document_node(
|
|||||||
return parser.get_nodes_from_documents([doc])
|
return parser.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
|
||||||
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
|
def load_or_build_index(nodes=None):
|
||||||
"""Return a VectorStoreIndex backed by ``store``.
|
"""
|
||||||
|
Load an existing VectorStoreIndex if present,
|
||||||
``store`` is supplied by the caller's ``read_store()`` context so the shared
|
or build a new one using provided nodes if storage is empty.
|
||||||
read lock and the connection stay alive for the whole retrieval.
|
|
||||||
"""
|
"""
|
||||||
import llama_index.core.settings as llama_settings
|
import llama_index.core.settings as llama_settings
|
||||||
from llama_index.core import VectorStoreIndex
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
embed_model = get_embedding_model(config)
|
embed_model = get_embedding_model()
|
||||||
llama_settings.Settings.embed_model = embed_model
|
llama_settings.Settings.embed_model = embed_model
|
||||||
return VectorStoreIndex.from_vector_store(
|
storage_context = get_or_create_storage_context()
|
||||||
vector_store=store,
|
try:
|
||||||
embed_model=embed_model,
|
return load_index_from_storage(storage_context=storage_context)
|
||||||
)
|
except ValueError as e:
|
||||||
|
logger.warning("Failed to load index from storage: %s", e)
|
||||||
|
if not nodes:
|
||||||
|
queue_llm_index_update_if_needed(
|
||||||
|
rebuild=vector_store_file_exists(),
|
||||||
|
reason="LLM index missing or invalid while loading.",
|
||||||
|
)
|
||||||
|
logger.info("No nodes provided for index creation.")
|
||||||
|
raise
|
||||||
|
return VectorStoreIndex(
|
||||||
|
nodes=nodes,
|
||||||
|
storage_context=storage_context,
|
||||||
|
embed_model=embed_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def llm_index_exists() -> bool:
|
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
|
||||||
"""True when the index table exists on disk."""
|
"""
|
||||||
with read_store() as store:
|
Removes existing documents from docstore for a given document from the index.
|
||||||
return store.table_exists()
|
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||||
|
"""
|
||||||
|
all_node_ids = list(index.docstore.docs.keys())
|
||||||
|
existing_nodes = [
|
||||||
|
node.node_id
|
||||||
|
for node in index.docstore.get_nodes(all_node_ids)
|
||||||
|
if node.metadata.get("document_id") == str(document.id)
|
||||||
|
]
|
||||||
|
for node_id in existing_nodes:
|
||||||
|
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||||
|
index.docstore.delete_document(node_id)
|
||||||
|
# Also purge the FAISS position -> UUID mapping so subsequent similarity
|
||||||
|
# queries don't raise KeyError on ghost vector positions.
|
||||||
|
stale_keys = [
|
||||||
|
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
|
||||||
|
]
|
||||||
|
for key in stale_keys:
|
||||||
|
del index.index_struct.nodes_dict[key]
|
||||||
|
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
|
||||||
|
index.storage_context.index_store.add_index_struct(index.index_struct)
|
||||||
|
|
||||||
|
|
||||||
|
def vector_store_file_exists():
|
||||||
|
"""
|
||||||
|
Check if the vector store file exists in the LLM index directory.
|
||||||
|
"""
|
||||||
|
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||||
|
|
||||||
|
|
||||||
def get_rag_chunk_size() -> int:
|
def get_rag_chunk_size() -> int:
|
||||||
return AIConfig().llm_embedding_chunk_size
|
return AIConfig().llm_embedding_chunk_size
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_context_size() -> int:
|
||||||
|
return AIConfig().llm_context_size
|
||||||
|
|
||||||
|
|
||||||
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
|
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
|
||||||
chunk_size = chunk_size or get_rag_chunk_size()
|
chunk_size = chunk_size or get_rag_chunk_size()
|
||||||
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
|
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
|
||||||
@@ -267,149 +249,123 @@ def get_rag_prompt_helper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
|
|
||||||
"""Embed ``nodes`` in place using ``embed_model``."""
|
|
||||||
from llama_index.core.schema import MetadataMode
|
|
||||||
|
|
||||||
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
|
|
||||||
for node, emb in zip(
|
|
||||||
nodes,
|
|
||||||
embed_model.get_text_embedding_batch(texts),
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
node.embedding = emb
|
|
||||||
|
|
||||||
|
|
||||||
def _document_id_filters(doc_ids):
|
|
||||||
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
|
|
||||||
from llama_index.core.vector_stores.types import FilterOperator
|
|
||||||
from llama_index.core.vector_stores.types import MetadataFilter
|
|
||||||
from llama_index.core.vector_stores.types import MetadataFilters
|
|
||||||
|
|
||||||
return MetadataFilters(
|
|
||||||
filters=[
|
|
||||||
MetadataFilter(
|
|
||||||
key="document_id",
|
|
||||||
operator=FilterOperator.IN,
|
|
||||||
value=sorted(doc_ids),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_llm_index(
|
def update_llm_index(
|
||||||
*,
|
*,
|
||||||
iter_wrapper: IterWrapper[Document] = identity,
|
iter_wrapper: IterWrapper[Document] = identity,
|
||||||
rebuild=False,
|
rebuild=False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Rebuild or incrementally update the LLM index."""
|
"""
|
||||||
with write_store() as store:
|
Rebuild or update the LLM index.
|
||||||
try:
|
"""
|
||||||
with _exclude_readers():
|
from llama_index.core import VectorStoreIndex
|
||||||
needs_reembed = store.check_and_run_migrations()
|
|
||||||
except Timeout:
|
|
||||||
logger.info(
|
|
||||||
"Skipping LLM index migration check: index readers are active; "
|
|
||||||
"will retry next run.",
|
|
||||||
)
|
|
||||||
needs_reembed = False
|
|
||||||
if needs_reembed:
|
|
||||||
logger.warning(
|
|
||||||
"LLM index migration requires re-embedding; forcing rebuild.",
|
|
||||||
)
|
|
||||||
rebuild = True
|
|
||||||
documents = Document.objects.all()
|
|
||||||
no_documents = not documents.exists()
|
|
||||||
|
|
||||||
# Fast exit before touching config: nothing to index and no existing index.
|
nodes = []
|
||||||
if no_documents and not rebuild and not llm_index_exists():
|
|
||||||
|
documents = Document.objects.all()
|
||||||
|
if not documents.exists():
|
||||||
logger.warning("No documents found to index.")
|
logger.warning("No documents found to index.")
|
||||||
return "No documents found to index."
|
if not rebuild and not vector_store_file_exists():
|
||||||
|
return "No documents found to index."
|
||||||
|
|
||||||
config = AIConfig()
|
config = AIConfig()
|
||||||
model_name = get_configured_model_name(config)
|
|
||||||
|
|
||||||
if not rebuild and llm_index_exists():
|
|
||||||
with read_store() as store:
|
|
||||||
config_mismatch = store.config_mismatch(model_name)
|
|
||||||
if config_mismatch:
|
|
||||||
logger.warning("Embedding model changed; forcing LLM index rebuild.")
|
|
||||||
rebuild = True
|
|
||||||
|
|
||||||
if no_documents:
|
|
||||||
logger.warning("No documents found to index.")
|
|
||||||
|
|
||||||
chunk_size = config.llm_embedding_chunk_size
|
chunk_size = config.llm_embedding_chunk_size
|
||||||
embed_model = get_embedding_model(config)
|
|
||||||
|
|
||||||
with write_store(embed_model_name=model_name) as store:
|
with FileLock(_index_lock_path()):
|
||||||
if rebuild or not store.table_exists():
|
if rebuild or not vector_store_file_exists():
|
||||||
|
# remove meta.json to force re-detection of embedding dim
|
||||||
|
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||||
|
# Rebuild index from scratch
|
||||||
logger.info("Rebuilding LLM index.")
|
logger.info("Rebuilding LLM index.")
|
||||||
store.drop_table()
|
import llama_index.core.settings as llama_settings
|
||||||
|
|
||||||
|
embed_model = get_embedding_model()
|
||||||
|
llama_settings.Settings.embed_model = embed_model
|
||||||
|
storage_context = get_or_create_storage_context(rebuild=True)
|
||||||
for document in iter_wrapper(documents):
|
for document in iter_wrapper(documents):
|
||||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
document_nodes = build_document_node(document, chunk_size=chunk_size)
|
||||||
_embed_nodes(nodes, embed_model)
|
nodes.extend(document_nodes)
|
||||||
store.add(nodes)
|
|
||||||
|
index = VectorStoreIndex(
|
||||||
|
nodes=nodes,
|
||||||
|
storage_context=storage_context,
|
||||||
|
embed_model=embed_model,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
msg = "LLM index rebuilt successfully."
|
msg = "LLM index rebuilt successfully."
|
||||||
else:
|
else:
|
||||||
existing = store.get_modified_times()
|
# Update existing index
|
||||||
changed = 0
|
index = load_or_build_index()
|
||||||
|
existing_nodes: defaultdict[str, list] = defaultdict(list)
|
||||||
|
for node in index.docstore.docs.values():
|
||||||
|
doc_id = node.metadata.get("document_id")
|
||||||
|
if doc_id is not None:
|
||||||
|
existing_nodes[doc_id].append(node)
|
||||||
|
|
||||||
for document in iter_wrapper(documents):
|
for document in iter_wrapper(documents):
|
||||||
doc_id = str(document.id)
|
doc_id = str(document.id)
|
||||||
if existing.get(doc_id) == document.modified.isoformat():
|
document_modified = document.modified.isoformat()
|
||||||
continue
|
|
||||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
|
||||||
_embed_nodes(nodes, embed_model)
|
|
||||||
store.upsert_document(doc_id, nodes)
|
|
||||||
changed += 1
|
|
||||||
msg = (
|
|
||||||
"LLM index updated successfully."
|
|
||||||
if changed
|
|
||||||
else "No changes detected in LLM index."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
if doc_id in existing_nodes:
|
||||||
with _exclude_readers():
|
doc_nodes = existing_nodes[doc_id]
|
||||||
store.compact()
|
node_modified = doc_nodes[0].metadata.get("modified")
|
||||||
except Timeout:
|
|
||||||
logger.info(
|
if node_modified == document_modified:
|
||||||
"Skipping LLM index compaction: index readers are active; "
|
continue
|
||||||
"will retry next run.",
|
|
||||||
)
|
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||||
|
for _ in doc_nodes:
|
||||||
|
remove_document_docstore_nodes(document, index)
|
||||||
|
|
||||||
|
nodes.extend(build_document_node(document, chunk_size=chunk_size))
|
||||||
|
|
||||||
|
if nodes:
|
||||||
|
msg = "LLM index updated successfully."
|
||||||
|
logger.info(
|
||||||
|
"Updating %d nodes in LLM index.",
|
||||||
|
len(nodes),
|
||||||
|
)
|
||||||
|
index.insert_nodes(nodes)
|
||||||
|
else:
|
||||||
|
msg = "No changes detected in LLM index."
|
||||||
|
logger.info(msg)
|
||||||
|
|
||||||
|
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def llm_index_add_or_update_document(document: Document):
|
def llm_index_add_or_update_document(document: Document):
|
||||||
"""Add or atomically replace a document's chunks in the index."""
|
"""
|
||||||
config = AIConfig()
|
Adds or updates a document in the LLM index.
|
||||||
new_nodes = build_document_node(
|
If the document already exists, it will be replaced.
|
||||||
document,
|
"""
|
||||||
chunk_size=config.llm_embedding_chunk_size,
|
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
|
||||||
)
|
if not new_nodes:
|
||||||
if new_nodes:
|
logger.warning(
|
||||||
_embed_nodes(new_nodes, get_embedding_model(config))
|
"No indexable content for document %s; skipping LLM index update.",
|
||||||
|
document.pk,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
with write_store(embed_model_name=get_configured_model_name(config)) as store:
|
with FileLock(_index_lock_path()):
|
||||||
store.upsert_document(str(document.id), new_nodes)
|
index = load_or_build_index(nodes=new_nodes)
|
||||||
|
|
||||||
|
remove_document_docstore_nodes(document, index)
|
||||||
|
|
||||||
def llm_index_compact() -> None:
|
index.insert_nodes(new_nodes)
|
||||||
"""Compact the index immediately, rebuilding the table to reclaim space."""
|
|
||||||
with write_store() as store:
|
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||||
try:
|
|
||||||
with _exclude_readers():
|
|
||||||
store.compact(force=True)
|
|
||||||
except Timeout:
|
|
||||||
logger.info(
|
|
||||||
"Skipping LLM index compaction: index readers are active; "
|
|
||||||
"will retry next run.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def llm_index_remove_document(document: Document):
|
def llm_index_remove_document(document: Document):
|
||||||
"""Remove a document's chunks from the LLM index."""
|
"""
|
||||||
with write_store() as store:
|
Removes a document from the LLM index.
|
||||||
store.delete(str(document.id))
|
"""
|
||||||
|
with FileLock(_index_lock_path()):
|
||||||
|
index = load_or_build_index()
|
||||||
|
|
||||||
|
remove_document_docstore_nodes(document, index)
|
||||||
|
|
||||||
|
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||||
|
|
||||||
|
|
||||||
def truncate_content(
|
def truncate_content(
|
||||||
@@ -454,59 +410,77 @@ def query_similar_documents(
|
|||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
document_ids: Iterable[int | str] | None = None,
|
document_ids: Iterable[int | str] | None = None,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
"""
|
||||||
|
Runs a similarity query and returns top-k similar Document objects.
|
||||||
|
"""
|
||||||
allowed_document_ids = normalize_document_ids(document_ids)
|
allowed_document_ids = normalize_document_ids(document_ids)
|
||||||
if allowed_document_ids is not None and not allowed_document_ids:
|
if allowed_document_ids is not None and not allowed_document_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not llm_index_exists():
|
if not vector_store_file_exists():
|
||||||
queue_llm_index_update_if_needed(
|
queue_llm_index_update_if_needed(
|
||||||
rebuild=False,
|
rebuild=False,
|
||||||
reason="LLM index not found for similarity query.",
|
reason="LLM index not found for similarity query.",
|
||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
config = AIConfig()
|
with FileLock(_index_lock_path()):
|
||||||
|
index = load_or_build_index()
|
||||||
|
|
||||||
from llama_index.core.retrievers import VectorIndexRetriever
|
# constrain only the node(s) that match the document IDs, if given
|
||||||
|
doc_node_ids = (
|
||||||
|
[
|
||||||
|
node.node_id
|
||||||
|
for node in index.docstore.docs.values()
|
||||||
|
if node.metadata.get("document_id") in allowed_document_ids
|
||||||
|
]
|
||||||
|
if allowed_document_ids is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if doc_node_ids is not None and not doc_node_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
filters = (
|
from llama_index.core.retrievers import VectorIndexRetriever
|
||||||
_document_id_filters(allowed_document_ids)
|
|
||||||
if allowed_document_ids is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
query_text = truncate_content(
|
|
||||||
(document.title or "") + "\n" + (document.content or ""),
|
|
||||||
chunk_size=config.llm_embedding_chunk_size,
|
|
||||||
context_size=config.llm_context_size,
|
|
||||||
)
|
|
||||||
# Hold the shared read lock for the whole retrieval so the connection is
|
|
||||||
# never open across a compaction swap. The retrieve() call generates a
|
|
||||||
# query embedding (a slow external request) and searches the vector store;
|
|
||||||
# no Django ORM access happens during it, so release the pooled DB
|
|
||||||
# connection for its duration. See #12976.
|
|
||||||
with read_store() as store:
|
|
||||||
index = load_or_build_index(config, store)
|
|
||||||
retriever = VectorIndexRetriever(
|
retriever = VectorIndexRetriever(
|
||||||
index=index,
|
index=index,
|
||||||
similarity_top_k=top_k,
|
similarity_top_k=top_k,
|
||||||
filters=filters,
|
doc_ids=doc_node_ids,
|
||||||
)
|
)
|
||||||
with db_connection_released():
|
|
||||||
|
config = AIConfig()
|
||||||
|
query_text = truncate_content(
|
||||||
|
(document.title or "") + "\n" + (document.content or ""),
|
||||||
|
chunk_size=config.llm_embedding_chunk_size,
|
||||||
|
context_size=config.llm_context_size,
|
||||||
|
)
|
||||||
|
try:
|
||||||
results = retriever.retrieve(query_text)
|
results = retriever.retrieve(query_text)
|
||||||
|
except KeyError as e:
|
||||||
|
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
|
||||||
|
# append-only. Treat them as absent and return no results.
|
||||||
|
logger.debug(
|
||||||
|
"Skipping LLM similarity query for document %s due to a stale "
|
||||||
|
"FAISS position with no docstore node: %s",
|
||||||
|
document.pk,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
retrieved_document_ids: list[int] = []
|
retrieved_document_ids: list[int] = []
|
||||||
for node in results:
|
for node in results:
|
||||||
document_id = node.metadata.get("document_id")
|
document_id = node.metadata.get("document_id")
|
||||||
if document_id is None:
|
if document_id is None:
|
||||||
continue
|
continue
|
||||||
normalized = str(document_id)
|
normalized_document_id = str(document_id)
|
||||||
if allowed_document_ids is not None and normalized not in allowed_document_ids:
|
if (
|
||||||
|
allowed_document_ids is not None
|
||||||
|
and normalized_document_id not in allowed_document_ids
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
retrieved_document_ids.append(int(normalized))
|
retrieved_document_ids.append(int(normalized_document_id))
|
||||||
except ValueError: # pragma: no cover
|
except ValueError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping LLM index result with invalid document_id %r.",
|
"Skipping LLM index result with invalid document_id %r.",
|
||||||
document_id,
|
document_id,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user