mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-29 00:34:17 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f4a871b8f | |||
| 29f9475818 | |||
| d06f66b618 | |||
| f3f55e3866 | |||
| 24b81c15f6 | |||
| 5202b0880e | |||
| 7ed58f9664 | |||
| 43eb3295ce | |||
| e0ba4cfada | |||
| 73062bd5ab | |||
| a020f64d08 | |||
| 11fb09e4f4 | |||
| 8ed4bf2011 | |||
| 92c016ce47 | |||
| fb3816486c | |||
| 4394403beb | |||
| f188d308eb | |||
| a5d6ff5f15 | |||
| 8405f66e38 | |||
| c3459d8f62 | |||
| 6f8e39c2e0 | |||
| eb292baa69 | |||
| 3d0b8343b9 | |||
| a7cec673bb | |||
| 449fd97b1f | |||
| fa0c4368d7 |
+1
-2
@@ -42,7 +42,6 @@ dependencies = [
|
||||
"drf-spectacular~=0.28",
|
||||
"drf-spectacular-sidecar~=2026.5.1",
|
||||
"drf-writable-nested~=0.7.1",
|
||||
"faiss-cpu>=1.10",
|
||||
"filelock~=3.29.0",
|
||||
"flower~=2.0.1",
|
||||
"gotenberg-client~=0.14.0",
|
||||
@@ -57,7 +56,6 @@ dependencies = [
|
||||
"llama-index-embeddings-openai-like>=0.2.2",
|
||||
"llama-index-llms-ollama>=0.9.1",
|
||||
"llama-index-llms-openai-like>=0.7.1",
|
||||
"llama-index-vector-stores-faiss>=0.5.2",
|
||||
"nltk~=3.9.1",
|
||||
"ocrmypdf~=17.4.2",
|
||||
"openai>=2.32",
|
||||
@@ -74,6 +72,7 @@ dependencies = [
|
||||
"scikit-learn~=1.8.0",
|
||||
"sentence-transformers>=5.4.1",
|
||||
"setproctitle~=1.3.4",
|
||||
"sqlite-vec==0.1.9",
|
||||
"tantivy~=0.26.0",
|
||||
"tika-client~=0.11.0",
|
||||
"torch~=2.11.0",
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
|
||||
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
|
||||
</button>
|
||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
|
||||
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
|
||||
</button>
|
||||
<div class="form-check form-switch mb-0 ms-2">
|
||||
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
|
||||
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
|
||||
@@ -81,7 +84,7 @@
|
||||
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
|
||||
<div class="dropdown-menu shadow" ngbDropdownMenu>
|
||||
@for (t of filterTargets; track t.id) {
|
||||
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
|
||||
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
|
||||
import { RouterTestingModule } from '@angular/router/testing'
|
||||
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { throwError } from 'rxjs'
|
||||
import { of, throwError } from 'rxjs'
|
||||
import { routes } from 'src/app/app-routing.module'
|
||||
import {
|
||||
PaperlessTask,
|
||||
@@ -29,7 +29,11 @@ import { ToastService } from 'src/app/services/toast.service'
|
||||
import { environment } from 'src/environments/environment'
|
||||
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
|
||||
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
|
||||
import { TasksComponent, TaskSection } from './tasks.component'
|
||||
import {
|
||||
TaskFilterTargetID,
|
||||
TasksComponent,
|
||||
TaskSection,
|
||||
} from './tasks.component'
|
||||
|
||||
const tasks: PaperlessTask[] = [
|
||||
{
|
||||
@@ -154,6 +158,13 @@ const paginatedTasks: Results<PaperlessTask> = {
|
||||
results: tasks,
|
||||
}
|
||||
|
||||
const sectionCountResponse = {
|
||||
all: 7,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 2,
|
||||
}
|
||||
|
||||
describe('TasksComponent', () => {
|
||||
let component: TasksComponent
|
||||
let fixture: ComponentFixture<TasksComponent>
|
||||
@@ -221,6 +232,15 @@ describe('TasksComponent', () => {
|
||||
req.params.get('page') === '1'
|
||||
)
|
||||
.flush(paginatedTasks)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
!req.params.has('status')
|
||||
)
|
||||
.flush(sectionCountResponse)
|
||||
})
|
||||
|
||||
it('should display task sections with counts', () => {
|
||||
@@ -295,6 +315,7 @@ describe('TasksComponent', () => {
|
||||
const headerText = header.nativeElement.textContent
|
||||
|
||||
expect(headerText).toContain('Dismiss visible')
|
||||
expect(headerText).toContain('Dismiss all')
|
||||
expect(headerText).toContain('Auto refresh')
|
||||
expect(headerText).not.toContain('All types')
|
||||
expect(headerText).not.toContain('All sources')
|
||||
@@ -327,6 +348,74 @@ describe('TasksComponent', () => {
|
||||
expect(pagination).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should apply the selected section to the server-side task query', () => {
|
||||
component.setSection(TaskSection.NeedsAttention)
|
||||
|
||||
const req = httpTestingController.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page') === '1' &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('acknowledged') === 'false' &&
|
||||
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
|
||||
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
|
||||
)
|
||||
|
||||
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
|
||||
expect(component.totalTasks).toBe(2)
|
||||
})
|
||||
|
||||
it('should apply task type and trigger source filters to the server-side task query', () => {
|
||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[6]] })
|
||||
|
||||
component.setTriggerSource(PaperlessTaskTriggerSource.System)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
|
||||
request.params.get('trigger_source') ===
|
||||
PaperlessTaskTriggerSource.System
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[6]] })
|
||||
})
|
||||
|
||||
it('should apply text filters to the server-side task query', () => {
|
||||
component.filterText = 'invoice'
|
||||
jest.advanceTimersByTime(150)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('name') === 'invoice'
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[0]] })
|
||||
|
||||
component.setFilterTarget(TaskFilterTargetID.Result)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('result') === 'invoice'
|
||||
)
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('should load a different task page when pagination changes', () => {
|
||||
component.setPage(2)
|
||||
|
||||
@@ -350,6 +439,27 @@ describe('TasksComponent', () => {
|
||||
expect(component.pagedTasks).toEqual([tasks[0]])
|
||||
})
|
||||
|
||||
it('should not replace section counts with current-page counts', () => {
|
||||
component.setPage(2)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('page_size') === '25' &&
|
||||
req.params.get('page') === '2'
|
||||
)
|
||||
.flush({
|
||||
count: 30,
|
||||
results: [tasks[0]],
|
||||
})
|
||||
|
||||
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
|
||||
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
|
||||
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
|
||||
})
|
||||
|
||||
it('should expose stable task type options and disable empty ones', () => {
|
||||
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
|
||||
PaperlessTaskType.TrainClassifier
|
||||
@@ -495,6 +605,46 @@ describe('TasksComponent', () => {
|
||||
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
|
||||
})
|
||||
|
||||
it('should support dismiss all tasks', () => {
|
||||
let modal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
||||
const dismissSpy = jest
|
||||
.spyOn(tasksService, 'dismissAllTasks')
|
||||
.mockReturnValue(of({}))
|
||||
const reloadPageSpy = jest
|
||||
.spyOn(component as any, 'reloadPage')
|
||||
.mockImplementation(() => undefined)
|
||||
|
||||
component.dismissAllTasks()
|
||||
|
||||
expect(modal).not.toBeUndefined()
|
||||
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
|
||||
modal.componentInstance.confirmClicked.emit()
|
||||
expect(dismissSpy).toHaveBeenCalled()
|
||||
expect(reloadPageSpy).toHaveBeenCalledWith(false)
|
||||
expect(component.selectedTasks.size).toBe(0)
|
||||
})
|
||||
|
||||
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
|
||||
const error = new Error('dismiss all failed')
|
||||
const toastSpy = jest.spyOn(toastService, 'showError')
|
||||
const dismissSpy = jest
|
||||
.spyOn(tasksService, 'dismissAllTasks')
|
||||
.mockReturnValue(throwError(() => error))
|
||||
|
||||
let modal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
||||
|
||||
component.dismissAllTasks()
|
||||
expect(modal).not.toBeUndefined()
|
||||
|
||||
modal.componentInstance.confirmClicked.emit()
|
||||
|
||||
expect(dismissSpy).toHaveBeenCalled()
|
||||
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
|
||||
expect(modal.componentInstance.buttonsEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('should dismiss the currently visible scoped and filtered tasks', () => {
|
||||
component.setSection(TaskSection.InProgress)
|
||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
||||
@@ -673,6 +823,9 @@ describe('TasksComponent', () => {
|
||||
})
|
||||
|
||||
it('should keep clearing selection independent from resetting filters', () => {
|
||||
component.resetFilter()
|
||||
expect(component.filterText).toBe('')
|
||||
|
||||
component.setTaskType(PaperlessTaskType.ConsumeFile)
|
||||
component.toggleSelected(tasks[0])
|
||||
expect(component.selectedTasks.size).toBe(1)
|
||||
|
||||
@@ -40,7 +40,7 @@ export enum TaskSection {
|
||||
Completed = 'completed',
|
||||
}
|
||||
|
||||
enum TaskFilterTargetID {
|
||||
export enum TaskFilterTargetID {
|
||||
Name,
|
||||
Result,
|
||||
}
|
||||
@@ -167,6 +167,12 @@ export class TasksComponent
|
||||
public readonly pageSize = 25
|
||||
public page: number = 1
|
||||
public totalTasks: number = 0
|
||||
public sectionCounts: Record<TaskSection, number> = {
|
||||
[TaskSection.All]: 0,
|
||||
[TaskSection.NeedsAttention]: 0,
|
||||
[TaskSection.InProgress]: 0,
|
||||
[TaskSection.Completed]: 0,
|
||||
}
|
||||
public pagedTasks: PaperlessTask[] = []
|
||||
public selectedSection: TaskSection = TaskSection.All
|
||||
public selectedTaskType: PaperlessTaskType | null = null
|
||||
@@ -282,6 +288,7 @@ export class TasksComponent
|
||||
.subscribe((query) => {
|
||||
this._filterText = query
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -334,6 +341,30 @@ export class TasksComponent
|
||||
}
|
||||
}
|
||||
|
||||
dismissAllTasks() {
|
||||
let modal = this.modalService.open(ConfirmDialogComponent, {
|
||||
backdrop: 'static',
|
||||
})
|
||||
modal.componentInstance.title = $localize`Confirm Dismiss All`
|
||||
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
|
||||
modal.componentInstance.btnClass = 'btn-warning'
|
||||
modal.componentInstance.btnCaption = $localize`Dismiss`
|
||||
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
|
||||
modal.componentInstance.buttonsEnabled = false
|
||||
modal.close()
|
||||
this.tasksService.dismissAllTasks().subscribe({
|
||||
next: () => {
|
||||
this.reloadPage(false)
|
||||
},
|
||||
error: (e) => {
|
||||
this.toastService.showError($localize`Error dismissing tasks`, e)
|
||||
modal.componentInstance.buttonsEnabled = true
|
||||
},
|
||||
})
|
||||
this.clearSelection()
|
||||
})
|
||||
}
|
||||
|
||||
expandTask(task: PaperlessTask) {
|
||||
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
|
||||
}
|
||||
@@ -446,9 +477,7 @@ export class TasksComponent
|
||||
}
|
||||
|
||||
sectionCount(section: TaskSection): number {
|
||||
return this.pagedTasks.filter((task) =>
|
||||
this.taskBelongsToSection(task, section)
|
||||
).length
|
||||
return this.sectionCounts[section]
|
||||
}
|
||||
|
||||
sectionShowsResults(section: TaskSection): boolean {
|
||||
@@ -458,16 +487,27 @@ export class TasksComponent
|
||||
setSection(section: TaskSection) {
|
||||
this.selectedSection = section
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setTaskType(taskType: PaperlessTaskType | null) {
|
||||
this.selectedTaskType = taskType
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
|
||||
this.selectedTriggerSource = triggerSource
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setFilterTarget(filterTargetID: TaskFilterTargetID) {
|
||||
this.filterTargetID = filterTargetID
|
||||
if (this._filterText.length) {
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
}
|
||||
|
||||
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
|
||||
@@ -505,19 +545,32 @@ export class TasksComponent
|
||||
}
|
||||
|
||||
public resetFilter() {
|
||||
if (!this._filterText.length) {
|
||||
return
|
||||
}
|
||||
|
||||
this._filterText = ''
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
public resetFilters() {
|
||||
const hadFilter = this.isFiltered
|
||||
this.selectedTaskType = null
|
||||
this.selectedTriggerSource = null
|
||||
this.resetFilter()
|
||||
this._filterText = ''
|
||||
this.clearSelection()
|
||||
|
||||
if (hadFilter) {
|
||||
this.reloadPage(true)
|
||||
}
|
||||
}
|
||||
|
||||
filterInputKeyup(event: KeyboardEvent) {
|
||||
if (event.key == 'Enter') {
|
||||
this._filterText = (event.target as HTMLInputElement).value
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
} else if (event.key === 'Escape') {
|
||||
this.resetFilter()
|
||||
}
|
||||
@@ -606,19 +659,86 @@ export class TasksComponent
|
||||
)
|
||||
}
|
||||
|
||||
private reloadSectionCounts() {
|
||||
this.tasksService
|
||||
.statusCounts(this.getParamsForSection(TaskSection.All))
|
||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
||||
.subscribe((counts) => {
|
||||
this.sectionCounts[TaskSection.All] = counts.all
|
||||
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
|
||||
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
|
||||
this.sectionCounts[TaskSection.Completed] = counts.completed
|
||||
})
|
||||
}
|
||||
|
||||
private getParamsForSection(
|
||||
section: TaskSection
|
||||
): Record<string, string | number | boolean | readonly string[]> {
|
||||
const params: Record<
|
||||
string,
|
||||
string | number | boolean | readonly string[]
|
||||
> = {
|
||||
acknowledged: false,
|
||||
}
|
||||
|
||||
const statuses = this.statusesForSection(section)
|
||||
if (statuses.length) {
|
||||
params.status = statuses
|
||||
}
|
||||
|
||||
if (this.selectedTaskType !== null) {
|
||||
params.task_type = this.selectedTaskType
|
||||
}
|
||||
|
||||
if (this.selectedTriggerSource !== null) {
|
||||
params.trigger_source = this.selectedTriggerSource
|
||||
}
|
||||
|
||||
if (this._filterText.length) {
|
||||
params[
|
||||
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
|
||||
] = this._filterText
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
|
||||
switch (section) {
|
||||
case TaskSection.NeedsAttention:
|
||||
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
|
||||
case TaskSection.InProgress:
|
||||
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
|
||||
case TaskSection.Completed:
|
||||
return [PaperlessTaskStatus.Success]
|
||||
default:
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
private reloadPage(resetToFirstPage: boolean = false) {
|
||||
if (resetToFirstPage) {
|
||||
this.page = 1
|
||||
}
|
||||
|
||||
this.reloadSectionCounts()
|
||||
|
||||
this.loading = true
|
||||
this.tasksService
|
||||
.list(this.page, this.pageSize, { acknowledged: false })
|
||||
.list(
|
||||
this.page,
|
||||
this.pageSize,
|
||||
this.getParamsForSection(this.selectedSection)
|
||||
)
|
||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
||||
.subscribe({
|
||||
next: (result) => {
|
||||
this.pagedTasks = result.results
|
||||
this.totalTasks = result.count
|
||||
this.sectionCounts[TaskSection.All] = result.count
|
||||
if (this.selectedSection !== TaskSection.All) {
|
||||
this.sectionCounts[this.selectedSection] = result.count
|
||||
}
|
||||
this.loading = false
|
||||
if (
|
||||
this.page > 1 &&
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<div class="chat-messages font-monospace small">
|
||||
@for (message of messages; track message) {
|
||||
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
|
||||
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
|
||||
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
|
||||
<span>
|
||||
{{ message.content }}
|
||||
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
|
||||
|
||||
@@ -188,4 +188,14 @@ describe('ChatComponent', () => {
|
||||
component.searchInputKeyDown(event)
|
||||
expect(component.sendMessage).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not send message on Enter key press while composing with IME', () => {
|
||||
jest.spyOn(component, 'sendMessage')
|
||||
const event = new KeyboardEvent('keydown', {
|
||||
key: 'Enter',
|
||||
isComposing: true,
|
||||
})
|
||||
component.searchInputKeyDown(event)
|
||||
expect(component.sendMessage).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -155,7 +155,10 @@ export class ChatComponent implements OnInit {
|
||||
}
|
||||
|
||||
public searchInputKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter') {
|
||||
if (
|
||||
event.key === 'Enter' &&
|
||||
!(event.isComposing || event.keyCode === 229)
|
||||
) {
|
||||
event.preventDefault()
|
||||
this.sendMessage()
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
@if (messageBold) {
|
||||
<p><b>{{messageBold}}</b></p>
|
||||
<p class="text-break"><b>{{messageBold}}</b></p>
|
||||
}
|
||||
@if (message) {
|
||||
<p class="mb-0" [innerHTML]="message"></p>
|
||||
<p class="mb-0 text-break" [innerHTML]="message"></p>
|
||||
}
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
|
||||
+5
-1
@@ -9,8 +9,11 @@
|
||||
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
|
||||
<ul class="list-group"
|
||||
cdkDropList
|
||||
[cdkDropListData]="documentIDs"
|
||||
(cdkDropListDropped)="onDrop($event)">
|
||||
@for (document of documents; track document.id) {
|
||||
@for (documentID of documentIDs; track documentID) {
|
||||
@let document = getDocument(documentID);
|
||||
@if (document) {
|
||||
<li class="list-group-item d-flex align-items-center" cdkDrag>
|
||||
<i-bs name="grip-vertical" class="me-2"></i-bs>
|
||||
<div class="d-flex flex-column">
|
||||
@@ -27,6 +30,7 @@
|
||||
</small>
|
||||
</div>
|
||||
</li>
|
||||
}
|
||||
}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
<div class="btn-group">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
|
||||
@if (loading) {
|
||||
<div class="spinner-border spinner-border-sm" role="status"></div>
|
||||
} @else {
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
@if (aiEnabled) {
|
||||
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||
<span class="visually-hidden" i18n>Show suggestions</span>
|
||||
</button>
|
||||
|
||||
|
||||
+12
@@ -37,6 +37,18 @@ describe('SuggestionsDropdownComponent', () => {
|
||||
expect(component.getSuggestions.emit).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not emit getSuggestions when disabled', () => {
|
||||
jest.spyOn(component.getSuggestions, 'emit')
|
||||
component.disabled = true
|
||||
component.suggestions = null
|
||||
fixture.detectChanges()
|
||||
|
||||
component.clickSuggest()
|
||||
|
||||
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
|
||||
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
|
||||
component.aiEnabled = true
|
||||
fixture.detectChanges()
|
||||
|
||||
+8
@@ -47,6 +47,14 @@ export class SuggestionsDropdownComponent {
|
||||
addCorrespondent: EventEmitter<string> = new EventEmitter()
|
||||
|
||||
public clickSuggest(): void {
|
||||
if (
|
||||
this.disabled ||
|
||||
this.loading ||
|
||||
(this.suggestions && !this.aiEnabled)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!this.suggestions) {
|
||||
this.getSuggestions.emit(this)
|
||||
} else {
|
||||
|
||||
@@ -64,3 +64,10 @@ export interface PaperlessTaskSummary {
|
||||
last_success: Date | null
|
||||
last_failure: Date | null
|
||||
}
|
||||
|
||||
export interface PaperlessTaskStatusCounts {
|
||||
all: number
|
||||
needs_attention: number
|
||||
in_progress: number
|
||||
completed: number
|
||||
}
|
||||
|
||||
@@ -80,6 +80,27 @@ describe('TasksService', () => {
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
|
||||
tasksService.dismissAllTasks().subscribe()
|
||||
const req = httpTestingController.expectOne(
|
||||
`${environment.apiBaseUrl}tasks/acknowledge/`
|
||||
)
|
||||
expect(req.request.method).toEqual('POST')
|
||||
expect(req.request.body).toEqual({
|
||||
all: true,
|
||||
})
|
||||
req.flush([])
|
||||
// reload is then called
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req: HttpRequest<unknown>) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('page_size') === '1000'
|
||||
)
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('groups mixed task types by status when reloading', () => {
|
||||
expect(tasksService.total).toEqual(0)
|
||||
const mockTasks = [
|
||||
@@ -221,4 +242,34 @@ describe('TasksService', () => {
|
||||
task_id: 'abc-123',
|
||||
})
|
||||
})
|
||||
|
||||
it('loads filtered task status counts', () => {
|
||||
tasksService
|
||||
.statusCounts({
|
||||
acknowledged: false,
|
||||
task_type: PaperlessTaskType.ConsumeFile,
|
||||
})
|
||||
.subscribe((res) => {
|
||||
expect(res).toEqual({
|
||||
all: 10,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 5,
|
||||
})
|
||||
})
|
||||
|
||||
const req = httpTestingController.expectOne(
|
||||
(req: HttpRequest<unknown>) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
|
||||
)
|
||||
expect(req.request.method).toEqual('GET')
|
||||
req.flush({
|
||||
all: 10,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 5,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
|
||||
import {
|
||||
PaperlessTask,
|
||||
PaperlessTaskStatus,
|
||||
PaperlessTaskStatusCounts,
|
||||
PaperlessTaskType,
|
||||
} from 'src/app/data/paperless-task'
|
||||
import { Results } from 'src/app/data/results'
|
||||
@@ -88,7 +89,7 @@ export class TasksService {
|
||||
public list(
|
||||
page: number,
|
||||
pageSize: number,
|
||||
extraParams?: Record<string, string | number | boolean>
|
||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
||||
): Observable<Results<PaperlessTask>> {
|
||||
return this.http.get<Results<PaperlessTask>>(
|
||||
`${this.baseUrl}${this.endpoint}/`,
|
||||
@@ -102,6 +103,17 @@ export class TasksService {
|
||||
)
|
||||
}
|
||||
|
||||
public statusCounts(
|
||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
||||
): Observable<PaperlessTaskStatusCounts> {
|
||||
return this.http.get<PaperlessTaskStatusCounts>(
|
||||
`${this.baseUrl}${this.endpoint}/status_counts/`,
|
||||
{
|
||||
params: extraParams,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
public dismissTasks(task_ids: Set<number>): Observable<any> {
|
||||
return this.http
|
||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
||||
@@ -116,6 +128,20 @@ export class TasksService {
|
||||
)
|
||||
}
|
||||
|
||||
public dismissAllTasks(): Observable<any> {
|
||||
return this.http
|
||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
||||
all: true,
|
||||
})
|
||||
.pipe(
|
||||
first(),
|
||||
takeUntil(this.unsubscribeNotifer),
|
||||
tap(() => {
|
||||
this.reload()
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
public cancelPending(): void {
|
||||
this.unsubscribeNotifer.next(true)
|
||||
}
|
||||
|
||||
@@ -904,6 +904,19 @@ def remove_password(
|
||||
doc.id,
|
||||
pair.source_doc.source_path,
|
||||
)
|
||||
try:
|
||||
with pikepdf.open(source_path) as pdf:
|
||||
if not pdf.is_encrypted:
|
||||
logger.info(
|
||||
"Skipping password removal for document %s because the "
|
||||
"source PDF is not encrypted",
|
||||
pair.root_doc.id,
|
||||
)
|
||||
continue
|
||||
except pikepdf.PasswordError:
|
||||
# Password-protected PDFs need the supplied password below.
|
||||
pass
|
||||
|
||||
with pikepdf.open(source_path, password=password) as pdf:
|
||||
filepath: Path = (
|
||||
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
|
||||
|
||||
@@ -28,6 +28,7 @@ from django.db.models.functions import Cast
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters import DateFilter
|
||||
from django_filters.rest_framework import BooleanFilter
|
||||
from django_filters.rest_framework import CharFilter
|
||||
from django_filters.rest_framework import DateTimeFilter
|
||||
from django_filters.rest_framework import Filter
|
||||
from django_filters.rest_framework import FilterSet
|
||||
@@ -900,6 +901,16 @@ class ShareLinkBundleFilterSet(FilterSet):
|
||||
|
||||
|
||||
class PaperlessTaskFilterSet(FilterSet):
|
||||
name = CharFilter(
|
||||
method="filter_name",
|
||||
label="Name",
|
||||
)
|
||||
|
||||
result = CharFilter(
|
||||
method="filter_result",
|
||||
label="Result",
|
||||
)
|
||||
|
||||
task_type = MultipleChoiceFilter(
|
||||
choices=PaperlessTask.TaskType.choices,
|
||||
label="Task Type",
|
||||
@@ -939,7 +950,58 @@ class PaperlessTaskFilterSet(FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
|
||||
fields = [
|
||||
"task_type",
|
||||
"trigger_source",
|
||||
"status",
|
||||
"acknowledged",
|
||||
"owner",
|
||||
"name",
|
||||
"result",
|
||||
]
|
||||
|
||||
def filter_name(self, queryset, name, value):
|
||||
if not value:
|
||||
return queryset
|
||||
|
||||
matching_task_types = [
|
||||
task_type
|
||||
for task_type, label in PaperlessTask.TaskType.choices
|
||||
if value.lower() in str(label).lower()
|
||||
]
|
||||
matching_trigger_sources = [
|
||||
trigger_source
|
||||
for trigger_source, label in PaperlessTask.TriggerSource.choices
|
||||
if value.lower() in str(label).lower()
|
||||
]
|
||||
|
||||
return queryset.filter(
|
||||
Q(input_data__filename__icontains=value)
|
||||
| Q(task_type__in=matching_task_types)
|
||||
| Q(trigger_source__in=matching_trigger_sources),
|
||||
)
|
||||
|
||||
def filter_result(self, queryset, name, value):
|
||||
if not value:
|
||||
return queryset
|
||||
|
||||
query = Q(result_data__reason__icontains=value) | Q(
|
||||
result_data__error_message__icontains=value,
|
||||
)
|
||||
|
||||
try:
|
||||
numeric_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
query |= Q(result_data__document_id=numeric_value) | Q(
|
||||
result_data__duplicate_of=numeric_value,
|
||||
)
|
||||
|
||||
if "duplicate" in value.lower():
|
||||
query |= Q(result_data__duplicate_of__isnull=False)
|
||||
|
||||
return queryset.filter(query)
|
||||
|
||||
def filter_is_complete(self, queryset, name, value):
|
||||
if value:
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any
|
||||
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.tasks import llmindex_index
|
||||
from paperless_ai.indexing import llm_index_compact
|
||||
|
||||
|
||||
class Command(PaperlessCommand):
|
||||
@@ -12,9 +13,12 @@ class Command(PaperlessCommand):
|
||||
|
||||
def add_arguments(self, parser: Any) -> None:
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument("command", choices=["rebuild", "update"])
|
||||
parser.add_argument("command", choices=["rebuild", "update", "compact"])
|
||||
|
||||
def handle(self, *args: Any, **options: Any) -> None:
|
||||
if options["command"] == "compact":
|
||||
llm_index_compact()
|
||||
return
|
||||
llmindex_index(
|
||||
rebuild=options["command"] == "rebuild",
|
||||
iter_wrapper=lambda docs: self.track(
|
||||
|
||||
@@ -48,6 +48,7 @@ from rest_framework import serializers
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework.utils import model_meta
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.context import set_actor
|
||||
@@ -121,6 +122,45 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
|
||||
self.fields.pop(field_name)
|
||||
|
||||
|
||||
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
|
||||
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
|
||||
|
||||
def _get_update_fields(self, validated_data) -> list[str]:
|
||||
model_fields = {
|
||||
field.name
|
||||
for field in self.Meta.model._meta.concrete_fields
|
||||
if field.name not in self.stale_update_excluded_fields
|
||||
}
|
||||
update_fields = [
|
||||
field_name for field_name in validated_data if field_name in model_fields
|
||||
]
|
||||
if "modified" in model_fields and "modified" not in update_fields:
|
||||
update_fields.append("modified")
|
||||
return update_fields
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
serializers.raise_errors_on_nested_writes("update", self, validated_data)
|
||||
info = model_meta.get_field_info(instance)
|
||||
|
||||
m2m_fields = []
|
||||
for attr, value in validated_data.items():
|
||||
if attr in info.relations and info.relations[attr].to_many:
|
||||
m2m_fields.append((attr, value))
|
||||
else:
|
||||
setattr(instance, attr, value)
|
||||
|
||||
# File names are managed by post-save file handling. Saving only the
|
||||
# serializer-updated fields prevents stale in-memory path values from
|
||||
# overwriting a concurrent move.
|
||||
instance.save(update_fields=self._get_update_fields(validated_data))
|
||||
|
||||
for attr, value in m2m_fields:
|
||||
field = getattr(instance, attr)
|
||||
field.set(value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
|
||||
document_count = serializers.IntegerField(read_only=True)
|
||||
|
||||
@@ -989,7 +1029,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
|
||||
class DocumentSerializer(
|
||||
OwnedObjectSerializer,
|
||||
NestedUpdateMixin,
|
||||
DynamicFieldsModelSerializer,
|
||||
DocumentUpdateFieldsModelSerializer,
|
||||
):
|
||||
correspondent = CorrespondentField(allow_null=True)
|
||||
tags = TagsField(many=True)
|
||||
@@ -1128,10 +1168,9 @@ class DocumentSerializer(
|
||||
return super().validate(attrs)
|
||||
|
||||
def update(self, instance: Document, validated_data):
|
||||
if "created_date" in validated_data and "created" not in validated_data:
|
||||
instance.created = validated_data.get("created_date")
|
||||
instance.save()
|
||||
if "created_date" in validated_data:
|
||||
if "created" not in validated_data:
|
||||
validated_data["created"] = validated_data["created_date"]
|
||||
logger.warning(
|
||||
"created_date is deprecated, use created instead",
|
||||
)
|
||||
@@ -1201,11 +1240,13 @@ class DocumentSerializer(
|
||||
for tag in instance.tags.all()
|
||||
if tag not in inbox_tags_not_being_added
|
||||
]
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
with set_actor(self.user):
|
||||
super().update(instance, validated_data)
|
||||
else:
|
||||
super().update(instance, validated_data)
|
||||
|
||||
# hard delete custom field instances that were soft deleted
|
||||
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
|
||||
return instance
|
||||
@@ -2632,18 +2673,25 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
|
||||
|
||||
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
||||
tasks = serializers.ListField(
|
||||
required=True,
|
||||
required=False,
|
||||
label="Tasks",
|
||||
write_only=True,
|
||||
child=serializers.IntegerField(),
|
||||
)
|
||||
all = serializers.BooleanField(
|
||||
required=False,
|
||||
default=False,
|
||||
label="All",
|
||||
write_only=True,
|
||||
)
|
||||
|
||||
def _validate_task_id_list(self, tasks, name="tasks") -> None:
|
||||
if not isinstance(tasks, list):
|
||||
raise serializers.ValidationError(f"{name} must be a list")
|
||||
if not all(isinstance(i, int) for i in tasks):
|
||||
raise serializers.ValidationError(f"{name} must be a list of integers")
|
||||
count = PaperlessTask.objects.filter(id__in=tasks).count()
|
||||
queryset = self.context.get("queryset", PaperlessTask.objects.all())
|
||||
count = queryset.filter(id__in=tasks).count()
|
||||
if not count == len(tasks):
|
||||
raise serializers.ValidationError(
|
||||
f"Some tasks in {name} don't exist or were specified twice.",
|
||||
@@ -2653,6 +2701,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
||||
self._validate_task_id_list(tasks)
|
||||
return tasks
|
||||
|
||||
def validate(self, attrs):
|
||||
acknowledge_all = attrs.get("all", False)
|
||||
task_ids = attrs.get("tasks")
|
||||
|
||||
if acknowledge_all and task_ids is not None:
|
||||
raise serializers.ValidationError(
|
||||
"Set either all or tasks, not both.",
|
||||
)
|
||||
if not acknowledge_all and task_ids is None:
|
||||
raise serializers.ValidationError(
|
||||
"Either all must be true or tasks must be provided.",
|
||||
)
|
||||
|
||||
return attrs
|
||||
|
||||
|
||||
class ShareLinkSerializer(OwnedObjectSerializer):
|
||||
class Meta:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import hashlib
|
||||
import logging
|
||||
import shutil
|
||||
import traceback as _tb
|
||||
@@ -16,6 +15,7 @@ from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.signals import task_revoked
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_process_shutdown
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
@@ -54,6 +54,7 @@ from documents.models import WorkflowTrigger
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.plugins.helpers import DocumentsStatusManager
|
||||
from documents.templating.utils import convert_format_str_to_template_format
|
||||
from documents.utils import compute_checksum
|
||||
from documents.workflows.actions import build_workflow_action_context
|
||||
from documents.workflows.actions import execute_email_action
|
||||
from documents.workflows.actions import execute_move_to_trash_action
|
||||
@@ -410,8 +411,7 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
|
||||
if checksum is None or not path.is_file():
|
||||
return False
|
||||
|
||||
with path.open("rb") as f:
|
||||
return hashlib.md5(f.read()).hexdigest() == checksum
|
||||
return compute_checksum(path) == checksum
|
||||
|
||||
|
||||
def _filename_template_uses_custom_fields(doc: Document) -> bool:
|
||||
@@ -1340,6 +1340,20 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
|
||||
conn.close_pool()
|
||||
|
||||
|
||||
@worker_process_shutdown.connect
|
||||
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
|
||||
"""
|
||||
Close the DB connection pool when a Celery child process exits.
|
||||
|
||||
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
|
||||
single task. Without closing the pool on shutdown, its connections linger
|
||||
on the server until TCP keepalive reaps them, accumulating over time.
|
||||
"""
|
||||
for conn in connections.all(initialized_only=True):
|
||||
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
|
||||
conn.close_pool()
|
||||
|
||||
|
||||
def add_or_update_document_in_llm_index(sender, document, **kwargs):
|
||||
"""
|
||||
Add or update a document in the LLM index when it is created or updated.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from collections.abc import Iterable
|
||||
from pathlib import PurePath
|
||||
|
||||
@@ -36,10 +37,12 @@ class FilePathTemplate(Template):
|
||||
def clean_filepath(value: str) -> str:
|
||||
"""
|
||||
Clean up a filepath by:
|
||||
1. Removing newlines and carriage returns
|
||||
2. Removing extra spaces before and after forward slashes
|
||||
3. Preserving spaces in other parts of the path
|
||||
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
|
||||
2. Removing newlines and carriage returns
|
||||
3. Removing extra spaces before and after forward slashes
|
||||
4. Preserving spaces in other parts of the path
|
||||
"""
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
value = value.replace("\n", "").replace("\r", "")
|
||||
value = re.sub(r"\s*/\s*", "/", value)
|
||||
|
||||
@@ -181,17 +184,17 @@ def get_basic_metadata_context(
|
||||
"""
|
||||
return {
|
||||
"title": pathvalidate.sanitize_filename(
|
||||
document.title,
|
||||
unicodedata.normalize("NFC", document.title),
|
||||
replacement_text="-",
|
||||
),
|
||||
"correspondent": pathvalidate.sanitize_filename(
|
||||
document.correspondent.name,
|
||||
unicodedata.normalize("NFC", document.correspondent.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
if document.correspondent
|
||||
else no_value_default,
|
||||
"document_type": pathvalidate.sanitize_filename(
|
||||
document.document_type.name,
|
||||
unicodedata.normalize("NFC", document.document_type.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
if document.document_type
|
||||
@@ -202,7 +205,10 @@ def get_basic_metadata_context(
|
||||
"owner_username": document.owner.username
|
||||
if document.owner
|
||||
else no_value_default,
|
||||
"original_name": PurePath(document.original_filename).with_suffix("").name
|
||||
"original_name": unicodedata.normalize(
|
||||
"NFC",
|
||||
PurePath(document.original_filename).with_suffix("").name,
|
||||
)
|
||||
if document.original_filename
|
||||
else no_value_default,
|
||||
"doc_pk": f"{document.pk:07}",
|
||||
@@ -269,12 +275,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
"tag_list": pathvalidate.sanitize_filename(
|
||||
",".join(
|
||||
sorted(tag.name for tag in tags),
|
||||
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
|
||||
),
|
||||
replacement_text="-",
|
||||
),
|
||||
# Assumed to be ordered, but a template could loop through to find what they want
|
||||
"tag_name_list": [x.name for x in tags],
|
||||
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
|
||||
}
|
||||
|
||||
|
||||
@@ -301,7 +307,7 @@ def get_custom_fields_context(
|
||||
CustomField.FieldDataType.LONG_TEXT,
|
||||
}:
|
||||
value = pathvalidate.sanitize_filename(
|
||||
field_instance.value,
|
||||
unicodedata.normalize("NFC", field_instance.value),
|
||||
replacement_text="-",
|
||||
)
|
||||
elif (
|
||||
@@ -310,10 +316,13 @@ def get_custom_fields_context(
|
||||
):
|
||||
options = field_instance.field.extra_data["select_options"]
|
||||
value = pathvalidate.sanitize_filename(
|
||||
next(
|
||||
option["label"]
|
||||
for option in options
|
||||
if option["id"] == field_instance.value
|
||||
unicodedata.normalize(
|
||||
"NFC",
|
||||
next(
|
||||
option["label"]
|
||||
for option in options
|
||||
if option["id"] == field_instance.value
|
||||
),
|
||||
),
|
||||
replacement_text="-",
|
||||
)
|
||||
@@ -321,7 +330,7 @@ def get_custom_fields_context(
|
||||
value = field_instance.value
|
||||
field_data["custom_fields"][
|
||||
pathvalidate.sanitize_filename(
|
||||
field_instance.field.name,
|
||||
unicodedata.normalize("NFC", field_instance.field.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
] = {
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.core.management import call_command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
|
||||
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
|
||||
|
||||
|
||||
class TestDocumentLlmindexCommand:
|
||||
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
|
||||
mock_compact = mocker.patch(_COMPACT)
|
||||
call_command("document_llmindex", "compact")
|
||||
mock_compact.assert_called_once_with()
|
||||
|
||||
def test_rebuild_calls_llmindex_index_with_rebuild_true(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_index = mocker.patch(_INDEX)
|
||||
call_command("document_llmindex", "rebuild")
|
||||
mock_index.assert_called_once()
|
||||
assert mock_index.call_args.kwargs["rebuild"] is True
|
||||
|
||||
def test_update_calls_llmindex_index_with_rebuild_false(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_index = mocker.patch(_INDEX)
|
||||
call_command("document_llmindex", "update")
|
||||
mock_index.assert_called_once()
|
||||
assert mock_index.call_args.kwargs["rebuild"] is False
|
||||
@@ -844,7 +844,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = False
|
||||
self.client.patch(
|
||||
@@ -869,7 +869,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
@@ -890,7 +890,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
@@ -928,7 +928,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import unicodedata
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import celery.result
|
||||
import pytest
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from documents.data_models import ConsumableDocument
|
||||
from documents.data_models import DocumentMetadataOverrides
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def consume_file_mock():
|
||||
with mock.patch("documents.tasks.consume_file.apply_async") as m:
|
||||
m.return_value = celery.result.AsyncResult(id="test-task-id")
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def directories(tmp_path, settings, _media_settings):
|
||||
scratch = tmp_path / "scratch"
|
||||
scratch.mkdir()
|
||||
settings.SCRATCH_DIR = scratch
|
||||
return scratch
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestPostDocumentNFCNormalization:
|
||||
def test_nfd_filename_normalized_to_nfc(
|
||||
self,
|
||||
admin_client,
|
||||
consume_file_mock: mock.MagicMock,
|
||||
directories,
|
||||
):
|
||||
"""Uploaded file with NFD filename must have its name stored as NFC."""
|
||||
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
|
||||
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
|
||||
|
||||
# Verify our test strings actually differ at the byte level
|
||||
assert nfd != nfc
|
||||
|
||||
uploaded = SimpleUploadedFile(
|
||||
nfd,
|
||||
b"%PDF-1.4 test",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
response = admin_client.post(
|
||||
"/api/documents/post_document/",
|
||||
{"document": uploaded},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
||||
input_doc: ConsumableDocument = task_kwargs["input_doc"]
|
||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
||||
|
||||
# The temp file on disk must have an NFC name
|
||||
assert input_doc.original_file.name == nfc, (
|
||||
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
|
||||
)
|
||||
# The override filename stored for later use must also be NFC
|
||||
assert overrides.filename == nfc, (
|
||||
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
|
||||
)
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
|
||||
def test_already_nfc_filename_unchanged(
|
||||
self,
|
||||
admin_client,
|
||||
consume_file_mock: mock.MagicMock,
|
||||
directories,
|
||||
):
|
||||
"""Uploaded file with already-NFC filename must pass through unchanged."""
|
||||
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
|
||||
|
||||
uploaded = SimpleUploadedFile(
|
||||
nfc,
|
||||
b"%PDF-1.4 test",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
response = admin_client.post(
|
||||
"/api/documents/post_document/",
|
||||
{"document": uploaded},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
||||
|
||||
assert overrides.filename == nfc
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
@@ -18,6 +18,7 @@ from guardian.shortcuts import assign_perm
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from documents.filters import PaperlessTaskFilterSet
|
||||
from documents.models import PaperlessTask
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import PaperlessTaskFactory
|
||||
@@ -169,6 +170,165 @@ class TestGetTasksV10:
|
||||
PaperlessTask.Status.STARTED,
|
||||
}
|
||||
|
||||
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
|
||||
"""?name= searches task filenames, task types, and trigger sources."""
|
||||
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
|
||||
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
|
||||
source_task = PaperlessTaskFactory(
|
||||
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
|
||||
)
|
||||
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "invoice"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == filename_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "sanity"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == type_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "email"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == source_task.task_id
|
||||
|
||||
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
|
||||
"""?result= searches common structured task result messages."""
|
||||
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
|
||||
error_task = PaperlessTaskFactory(
|
||||
result_data={"error_message": "Duplicate detected"},
|
||||
)
|
||||
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
|
||||
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
|
||||
PaperlessTaskFactory(result_data={"reason": "unrelated"})
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "manual"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == reason_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
returned_ids = {task["task_id"] for task in response.data["results"]}
|
||||
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "321"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == document_task.task_id
|
||||
|
||||
def test_empty_task_name_and_result_filters(self) -> None:
|
||||
"""Empty name/result values leave the queryset unchanged."""
|
||||
PaperlessTaskFactory.create_batch(2)
|
||||
queryset = PaperlessTask.objects.all()
|
||||
filterset = PaperlessTaskFilterSet()
|
||||
|
||||
assert filterset.filter_name(queryset, "name", "").count() == 2
|
||||
assert filterset.filter_result(queryset, "result", "").count() == 2
|
||||
|
||||
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
|
||||
"""status_counts/ returns section counts for the filtered task queryset."""
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "invoice-a.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.REVOKED,
|
||||
input_data={"filename": "invoice-b.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.PENDING,
|
||||
input_data={"filename": "invoice-c.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
input_data={"filename": "invoice-d.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-e.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=True,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-acknowledged.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "unrelated.pdf"},
|
||||
)
|
||||
|
||||
response = admin_client.get(
|
||||
f"{ENDPOINT}status_counts/",
|
||||
{"acknowledged": "false", "name": "invoice"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {
|
||||
"all": 5,
|
||||
"needs_attention": 2,
|
||||
"in_progress": 2,
|
||||
"completed": 1,
|
||||
}
|
||||
|
||||
def test_status_counts_ignores_section_filters(
|
||||
self,
|
||||
admin_client: APIClient,
|
||||
) -> None:
|
||||
"""status_counts/ ignores status-like filters for the sections it counts."""
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "invoice-a.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.PENDING,
|
||||
input_data={"filename": "invoice-b.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-c.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "unrelated.pdf"},
|
||||
)
|
||||
|
||||
response = admin_client.get(
|
||||
f"{ENDPOINT}status_counts/",
|
||||
{
|
||||
"acknowledged": "false",
|
||||
"name": "invoice",
|
||||
"status": PaperlessTask.Status.FAILURE,
|
||||
"is_complete": "false",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {
|
||||
"all": 3,
|
||||
"needs_attention": 1,
|
||||
"in_progress": 1,
|
||||
"completed": 1,
|
||||
}
|
||||
|
||||
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
|
||||
"""Tasks are returned in descending date_created order (newest first)."""
|
||||
base = timezone.now()
|
||||
@@ -522,6 +682,27 @@ class TestAcknowledge:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {"result": 2}
|
||||
|
||||
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
|
||||
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
|
||||
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
|
||||
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
|
||||
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
|
||||
|
||||
response = admin_client.post(
|
||||
ENDPOINT + "acknowledge/",
|
||||
{"all": True},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {"result": 2}
|
||||
unacknowledged_task1.refresh_from_db()
|
||||
unacknowledged_task2.refresh_from_db()
|
||||
acknowledged_task.refresh_from_db()
|
||||
assert unacknowledged_task1.acknowledged
|
||||
assert unacknowledged_task2.acknowledged
|
||||
assert acknowledged_task.acknowledged
|
||||
|
||||
def test_acknowledged_tasks_excluded_from_unacked_filter(
|
||||
self,
|
||||
admin_client: APIClient,
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pikepdf
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import TestCase
|
||||
@@ -615,6 +616,18 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
self.img_doc.archive_filename = img_doc_archive
|
||||
self.img_doc.save()
|
||||
|
||||
@staticmethod
|
||||
def mock_password_required_pdf(
|
||||
mock_open: mock.Mock,
|
||||
fake_pdf: mock.Mock,
|
||||
) -> None:
|
||||
password_context = mock.MagicMock()
|
||||
password_context.__enter__.return_value = fake_pdf
|
||||
mock_open.side_effect = [
|
||||
pikepdf.PasswordError("password required"),
|
||||
password_context,
|
||||
]
|
||||
|
||||
@mock.patch("documents.tasks.consume_file.s")
|
||||
def test_merge(self, mock_consume_file) -> None:
|
||||
"""
|
||||
@@ -1466,6 +1479,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
|
||||
fake_pdf.is_encrypted = True
|
||||
|
||||
def save_side_effect(target_path):
|
||||
Path(target_path).write_bytes(b"new pdf content")
|
||||
@@ -1480,7 +1494,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
fake_pdf.remove_unreferenced_resources.assert_called_once()
|
||||
mock_update_document.assert_not_called()
|
||||
mock_consume_delay.assert_called_once()
|
||||
@@ -1494,6 +1514,33 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
|
||||
self.assertIsNotNone(task_kwargs["overrides"])
|
||||
|
||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@mock.patch("pikepdf.open")
|
||||
def test_remove_password_update_document_skips_unencrypted_pdf(
|
||||
self,
|
||||
mock_open,
|
||||
mock_mkdtemp,
|
||||
mock_consume_delay,
|
||||
) -> None:
|
||||
doc = self.doc1
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.is_encrypted = False
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
password="secret",
|
||||
update_document=True,
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path)
|
||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
||||
fake_pdf.save.assert_not_called()
|
||||
mock_mkdtemp.assert_not_called()
|
||||
mock_consume_delay.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
|
||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@@ -1513,12 +1560,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
mock_mkdtemp.return_value = str(temp_dir)
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path):
|
||||
Path(target_path).write_bytes(b"new pdf content")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
@@ -1528,7 +1575,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(source_file, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(source_file),
|
||||
mock.call(source_file, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_update_document.assert_not_called()
|
||||
mock_consume_delay.assert_called_once()
|
||||
|
||||
@@ -1547,7 +1600,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
root_document=self.doc1,
|
||||
)
|
||||
fake_pdf = mock.MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[self.doc1.id],
|
||||
@@ -1557,7 +1610,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(self.doc1.source_path),
|
||||
mock.call(self.doc1.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_delay.assert_called_once()
|
||||
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@@ -1580,12 +1639,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path: Path) -> None:
|
||||
target_path.write_bytes(b"password removed")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
mock_group.return_value.delay.return_value = None
|
||||
|
||||
user = User.objects.create(username="owner")
|
||||
@@ -1600,7 +1659,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_file.assert_called_once()
|
||||
call_kwargs = mock_consume_file.call_args.kwargs
|
||||
consumable_document = call_kwargs["input_doc"]
|
||||
@@ -1618,6 +1683,43 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
mock_group.return_value.delay.assert_called_once()
|
||||
mock_chord.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.delete")
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@mock.patch("documents.bulk_edit.group")
|
||||
@mock.patch("documents.tasks.consume_file.s")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@mock.patch("pikepdf.open")
|
||||
def test_remove_password_skips_unencrypted_pdf_without_queueing(
|
||||
self,
|
||||
mock_open: mock.Mock,
|
||||
mock_mkdtemp: mock.Mock,
|
||||
mock_consume_file: mock.Mock,
|
||||
mock_group: mock.Mock,
|
||||
mock_chord: mock.Mock,
|
||||
mock_delete: mock.Mock,
|
||||
) -> None:
|
||||
doc = self.doc2
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.is_encrypted = False
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
password="secret",
|
||||
update_document=False,
|
||||
delete_original=True,
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path)
|
||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
||||
fake_pdf.save.assert_not_called()
|
||||
mock_mkdtemp.assert_not_called()
|
||||
mock_consume_file.assert_not_called()
|
||||
mock_group.assert_not_called()
|
||||
mock_chord.assert_not_called()
|
||||
mock_delete.si.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.delete")
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@mock.patch("documents.bulk_edit.group")
|
||||
@@ -1640,12 +1742,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path: Path) -> None:
|
||||
target_path.write_bytes(b"password removed")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
mock_chord.return_value.delay.return_value = None
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
@@ -1657,7 +1759,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_file.assert_called_once()
|
||||
mock_group.assert_not_called()
|
||||
mock_chord.assert_called_once()
|
||||
|
||||
@@ -24,6 +24,7 @@ from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.serialisers import DocumentSerializer
|
||||
from documents.tasks import empty_trash
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
@@ -221,8 +222,8 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
doc = Document.objects.create(
|
||||
title="document",
|
||||
mime_type="application/pdf",
|
||||
checksum=hashlib.md5(original_bytes).hexdigest(),
|
||||
archive_checksum=hashlib.md5(archive_bytes).hexdigest(),
|
||||
checksum=hashlib.sha256(original_bytes).hexdigest(),
|
||||
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
|
||||
filename="old/document.pdf",
|
||||
archive_filename="old/document.pdf",
|
||||
storage_path=old_storage_path,
|
||||
@@ -251,6 +252,46 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
|
||||
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
|
||||
|
||||
@override_settings(FILENAME_FORMAT="{title}")
|
||||
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
|
||||
old_path = settings.ORIGINALS_DIR / "original.pdf"
|
||||
old_path.touch()
|
||||
doc = Document.objects.create(
|
||||
title="original",
|
||||
mime_type="application/pdf",
|
||||
checksum=hashlib.sha256(b"").hexdigest(),
|
||||
filename="original.pdf",
|
||||
)
|
||||
|
||||
first_instance = Document.objects.get(pk=doc.pk)
|
||||
stale_instance = Document.objects.get(pk=doc.pk)
|
||||
|
||||
serializer = DocumentSerializer(
|
||||
first_instance,
|
||||
data={"title": "first"},
|
||||
partial=True,
|
||||
)
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
serializer.save()
|
||||
|
||||
doc.refresh_from_db()
|
||||
self.assertEqual(doc.filename, "first.pdf")
|
||||
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
|
||||
|
||||
serializer = DocumentSerializer(
|
||||
stale_instance,
|
||||
data={"title": "second"},
|
||||
partial=True,
|
||||
)
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
serializer.save()
|
||||
|
||||
doc.refresh_from_db()
|
||||
self.assertEqual(doc.filename, "second.pdf")
|
||||
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
|
||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
|
||||
self.assertIsNotFile(old_path)
|
||||
|
||||
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
|
||||
def test_document_delete(self) -> None:
|
||||
document = Document()
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
|
||||
|
||||
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
|
||||
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
|
||||
All paths produced by the templating system must be NFC-normalized.
|
||||
"""
|
||||
|
||||
import unicodedata
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.file_handling import generate_filename
|
||||
from documents.models import CustomField
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.tests.factories import CorrespondentFactory
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import StoragePathFactory
|
||||
from documents.tests.factories import TagFactory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateFilenameNFCNormalization:
|
||||
@pytest.mark.parametrize(
|
||||
"raw,display",
|
||||
[
|
||||
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
|
||||
(unicodedata.normalize("NFD", "Café"), "Café"),
|
||||
(unicodedata.normalize("NFD", "naïve"), "naïve"),
|
||||
],
|
||||
)
|
||||
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
|
||||
"""NFD title must produce NFC path bytes."""
|
||||
settings.FILENAME_FORMAT = "{{ title }}"
|
||||
nfc = unicodedata.normalize("NFC", display)
|
||||
assert raw != nfc # confirm byte-level difference
|
||||
|
||||
doc = DocumentFactory(title=raw, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}.pdf"
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
def test_nfd_correspondent_normalized_to_nfc(self, settings):
|
||||
"""NFD correspondent name must produce NFC path component."""
|
||||
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Müller")
|
||||
nfc = unicodedata.normalize("NFC", "Müller")
|
||||
|
||||
correspondent = CorrespondentFactory(name=nfd)
|
||||
doc = DocumentFactory(
|
||||
title="invoice",
|
||||
correspondent=correspondent,
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}/invoice.pdf"
|
||||
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
|
||||
|
||||
def test_nfd_storage_path_normalized_to_nfc(self, settings):
|
||||
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
|
||||
settings.FILENAME_FORMAT = None
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
|
||||
# StoragePath.path is used directly as the format/template string.
|
||||
# Literal NFD characters in the template must survive rendering as NFC.
|
||||
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
|
||||
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
|
||||
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
|
||||
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
|
||||
settings.FILENAME_FORMAT = "{{ document.title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Café")
|
||||
nfc = unicodedata.normalize("NFC", "Café")
|
||||
|
||||
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}.pdf"
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestContextBuilderNFCNormalization:
|
||||
"""
|
||||
Defense-in-depth: context builder functions must NFC-normalize string inputs
|
||||
before passing them to sanitize_filename(). Task 1 already normalizes the
|
||||
final rendered path via clean_filepath(), so these tests may already pass;
|
||||
they exist as regression guards for the context-builder layer.
|
||||
"""
|
||||
|
||||
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
|
||||
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
|
||||
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
assert nfd != nfc # confirm they differ at byte level
|
||||
|
||||
tag = TagFactory(name=nfd)
|
||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
||||
doc.tags.set([tag])
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
|
||||
def test_nfd_original_name_normalized_to_nfc(self, settings):
|
||||
settings.FILENAME_FORMAT = "{{ original_name }}"
|
||||
nfd = unicodedata.normalize("NFD", "Rechnung März")
|
||||
nfc = unicodedata.normalize("NFC", "Rechnung März")
|
||||
|
||||
doc = DocumentFactory(
|
||||
original_filename=f"{nfd}.pdf",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
def test_nfd_custom_field_string_value_normalized(self, settings):
|
||||
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
|
||||
settings.FILENAME_FORMAT = (
|
||||
"{{ custom_fields['Location']['value'] }}/{{ title }}"
|
||||
)
|
||||
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
|
||||
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
|
||||
assert nfd_value != nfc_value
|
||||
|
||||
doc = DocumentFactory(title="report", mime_type="application/pdf")
|
||||
cf = CustomField.objects.create(
|
||||
name="Location",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomFieldInstance.objects.create(
|
||||
document=doc,
|
||||
field=cf,
|
||||
value_text=nfd_value,
|
||||
)
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
|
||||
|
||||
def test_nfd_custom_field_name_normalized_as_key(self, settings):
|
||||
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
|
||||
nfd_name = unicodedata.normalize("NFD", "Größe")
|
||||
nfc_name = unicodedata.normalize("NFC", "Größe")
|
||||
assert nfd_name != nfc_name
|
||||
|
||||
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
|
||||
|
||||
doc = DocumentFactory(title="letter", mime_type="application/pdf")
|
||||
cf = CustomField.objects.create(
|
||||
name=nfd_name,
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomFieldInstance.objects.create(
|
||||
document=doc,
|
||||
field=cf,
|
||||
value_text="Berlin",
|
||||
)
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
# If field name key is NFC-normalized, the template condition succeeds
|
||||
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
|
||||
assert str(result) == "Berlin/letter.pdf"
|
||||
|
||||
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
|
||||
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
|
||||
settings.FILENAME_FORMAT = (
|
||||
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
|
||||
)
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
assert nfd != nfc # confirm byte-level difference
|
||||
|
||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
||||
doc.tags.add(TagFactory(name=nfd))
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
@@ -368,6 +368,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
None,
|
||||
hints=None,
|
||||
)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@@ -399,6 +400,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
"de-de",
|
||||
hints=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
@@ -438,6 +440,7 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
self.document,
|
||||
self.user,
|
||||
"fr-fr",
|
||||
hints=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
get_llm_suggestion_cache(
|
||||
|
||||
+80
-7
@@ -245,6 +245,7 @@ from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||
@@ -1400,7 +1401,7 @@ class DocumentViewSet(
|
||||
)
|
||||
if request.user is not None and not has_perms_owner_aware(
|
||||
request.user,
|
||||
"view_document",
|
||||
"change_document",
|
||||
doc,
|
||||
):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
@@ -1460,7 +1461,7 @@ class DocumentViewSet(
|
||||
)
|
||||
if request.user is not None and not has_perms_owner_aware(
|
||||
request.user,
|
||||
"view_document",
|
||||
"change_document",
|
||||
doc,
|
||||
):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
@@ -1494,11 +1495,14 @@ class DocumentViewSet(
|
||||
refresh_suggestions_cache(doc.pk)
|
||||
return Response(cached_llm_suggestions.suggestions)
|
||||
|
||||
hints = get_taxonomy_hints_for_document(doc, request.user)
|
||||
|
||||
try:
|
||||
llm_suggestions = get_ai_document_classification(
|
||||
doc,
|
||||
request.user,
|
||||
output_language,
|
||||
hints=hints,
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.exception(
|
||||
@@ -1513,18 +1517,22 @@ class DocumentViewSet(
|
||||
matched_tags = match_tags_by_name(
|
||||
llm_suggestions.get("tags", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["tags"]) if hints else None,
|
||||
)
|
||||
matched_correspondents = match_correspondents_by_name(
|
||||
llm_suggestions.get("correspondents", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["correspondents"]) if hints else None,
|
||||
)
|
||||
matched_types = match_document_types_by_name(
|
||||
llm_suggestions.get("document_types", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["document_types"]) if hints else None,
|
||||
)
|
||||
matched_paths = match_storage_paths_by_name(
|
||||
llm_suggestions.get("storage_paths", []),
|
||||
request.user,
|
||||
hinted_names=set(hints["storage_paths"]) if hints else None,
|
||||
)
|
||||
|
||||
resp_data = {
|
||||
@@ -3126,6 +3134,7 @@ class PostDocumentView(GenericAPIView[Any]):
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
doc_name, doc_data = serializer.validated_data.get("document")
|
||||
doc_name = normalize("NFC", doc_name)
|
||||
correspondent_id = serializer.validated_data.get("correspondent")
|
||||
document_type_id = serializer.validated_data.get("document_type")
|
||||
storage_path_id = serializer.validated_data.get("storage_path")
|
||||
@@ -4011,7 +4020,7 @@ class RemoteVersionView(GenericAPIView[Any]):
|
||||
|
||||
|
||||
class _TasksViewSetSchema(AutoSchema):
|
||||
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
|
||||
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
|
||||
|
||||
def _get_paginator(self):
|
||||
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
|
||||
@@ -4033,7 +4042,7 @@ class _TasksViewSetSchema(AutoSchema):
|
||||
),
|
||||
acknowledge=extend_schema(
|
||||
operation_id="acknowledge_tasks",
|
||||
description="Acknowledge a list of tasks",
|
||||
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
|
||||
request=AcknowledgeTasksViewSerializer,
|
||||
responses={
|
||||
(200, "application/json"): inline_serializer(
|
||||
@@ -4071,6 +4080,19 @@ class _TasksViewSetSchema(AutoSchema):
|
||||
),
|
||||
],
|
||||
),
|
||||
status_counts=extend_schema(
|
||||
responses={
|
||||
200: inline_serializer(
|
||||
name="TaskStatusCounts",
|
||||
fields={
|
||||
"all": serializers.IntegerField(),
|
||||
"needs_attention": serializers.IntegerField(),
|
||||
"in_progress": serializers.IntegerField(),
|
||||
"completed": serializers.IntegerField(),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
active=extend_schema(
|
||||
description="Currently pending and running tasks (capped at 50).",
|
||||
responses={200: TaskSerializerV10(many=True)},
|
||||
@@ -4124,6 +4146,7 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
|
||||
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
|
||||
}
|
||||
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
|
||||
|
||||
def get_serializer_class(self):
|
||||
# v9: use backwards-compatible serializer with old field names
|
||||
@@ -4164,16 +4187,38 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
queryset = queryset.filter(task_id=task_id)
|
||||
return queryset
|
||||
|
||||
def get_status_count_queryset(self):
|
||||
"""Apply task filters except the status dimensions represented by the counts."""
|
||||
query_params = self.request.query_params.copy()
|
||||
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
|
||||
query_params.pop(param, None)
|
||||
|
||||
filterset = self.filterset_class(
|
||||
data=query_params,
|
||||
queryset=self.get_queryset(),
|
||||
request=self.request,
|
||||
)
|
||||
if not filterset.is_valid():
|
||||
raise ValidationError(filterset.errors)
|
||||
return filterset.qs
|
||||
|
||||
@action(
|
||||
methods=["post"],
|
||||
detail=False,
|
||||
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
|
||||
)
|
||||
def acknowledge(self, request):
|
||||
serializer = AcknowledgeTasksViewSerializer(data=request.data)
|
||||
queryset = self.get_queryset()
|
||||
serializer = AcknowledgeTasksViewSerializer(
|
||||
data=request.data,
|
||||
context={"queryset": queryset},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
task_ids = serializer.validated_data.get("tasks")
|
||||
tasks = self.get_queryset().filter(id__in=task_ids)
|
||||
if serializer.validated_data.get("all", False):
|
||||
tasks = queryset.filter(acknowledged=False)
|
||||
else:
|
||||
task_ids = serializer.validated_data.get("tasks")
|
||||
tasks = queryset.filter(id__in=task_ids)
|
||||
count = tasks.update(acknowledged=True)
|
||||
return Response({"result": count})
|
||||
|
||||
@@ -4226,6 +4271,34 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
serializer = TaskSummarySerializer(data, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
@action(methods=["get"], detail=False)
|
||||
def status_counts(self, request):
|
||||
"""Aggregated task counts for task UI sections."""
|
||||
queryset = self.get_status_count_queryset()
|
||||
counts = queryset.aggregate(
|
||||
all=Count("id"),
|
||||
needs_attention=Count(
|
||||
"id",
|
||||
filter=Q(
|
||||
status__in=[
|
||||
PaperlessTask.Status.FAILURE,
|
||||
PaperlessTask.Status.REVOKED,
|
||||
],
|
||||
),
|
||||
),
|
||||
in_progress=Count(
|
||||
"id",
|
||||
filter=Q(
|
||||
status__in=[
|
||||
PaperlessTask.Status.PENDING,
|
||||
PaperlessTask.Status.STARTED,
|
||||
],
|
||||
),
|
||||
),
|
||||
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
|
||||
)
|
||||
return Response(counts)
|
||||
|
||||
@action(methods=["get"], detail=False)
|
||||
def active(self, request):
|
||||
"""Currently pending and running tasks (capped at 50)."""
|
||||
|
||||
@@ -20,6 +20,7 @@ from PIL import Image
|
||||
from PIL import ImageDraw
|
||||
from PIL import ImageFont
|
||||
|
||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
||||
from paperless.version import __full_version_str__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -183,7 +184,7 @@ class TextDocumentParser:
|
||||
documents.parsers.ParseError
|
||||
If the file cannot be read.
|
||||
"""
|
||||
self._text = self._read_text(document_path)
|
||||
self._text = read_file_handle_unicode_errors(document_path, log=logger)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Result accessors
|
||||
@@ -295,30 +296,3 @@ class TextDocumentParser:
|
||||
Always ``[]`` — plain text files carry no structured metadata.
|
||||
"""
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _read_text(self, filepath: Path) -> str:
|
||||
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filepath:
|
||||
Path to the file to read.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File content as a string.
|
||||
"""
|
||||
try:
|
||||
return filepath.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
logger.warning(
|
||||
"Unicode error reading %s, replacing bad bytes: %s",
|
||||
filepath,
|
||||
exc,
|
||||
)
|
||||
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||
|
||||
@@ -8,6 +8,7 @@ share implementation.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
@@ -114,7 +115,7 @@ def read_file_handle_unicode_errors(
|
||||
filepath: Path,
|
||||
log: logging.Logger | None = None,
|
||||
) -> str:
|
||||
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
|
||||
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -127,15 +128,27 @@ def read_file_handle_unicode_errors(
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File content as a string, with any invalid UTF-8 sequences replaced
|
||||
by the Unicode replacement character.
|
||||
File content as a string, with NUL bytes removed so the result is
|
||||
safe to store in PostgreSQL text fields.
|
||||
"""
|
||||
_log = log or logger
|
||||
raw = filepath.read_bytes()
|
||||
|
||||
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
|
||||
encoding = "utf-16"
|
||||
elif raw.startswith(codecs.BOM_UTF8):
|
||||
encoding = "utf-8-sig"
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
try:
|
||||
return filepath.read_text(encoding="utf-8")
|
||||
text = raw.decode(encoding)
|
||||
except UnicodeDecodeError as e:
|
||||
_log.warning("Unicode error during text reading, continuing: %s", e)
|
||||
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
|
||||
# PostgreSQL rejects NUL (0x00) bytes in text fields
|
||||
return text.replace("\x00", "")
|
||||
|
||||
|
||||
def get_page_count_for_pdf(
|
||||
|
||||
@@ -97,8 +97,14 @@ MODEL_FILE = get_path_from_env(
|
||||
DATA_DIR / "classification_model.pickle",
|
||||
)
|
||||
LLM_INDEX_DIR = DATA_DIR / "llm_index"
|
||||
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
|
||||
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
|
||||
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
|
||||
# Cross-process read/write lock guarding the LLM index compaction/migration
|
||||
# file swap. Readers hold it shared; the swap takes it exclusively so it never
|
||||
# runs while a reader connection is open. Must be a SQLite (.db) file.
|
||||
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
|
||||
# Seconds the compaction swap waits for active readers to drain before skipping
|
||||
# this cycle (it is a maintenance operation; the next run retries).
|
||||
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
|
||||
|
||||
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
|
||||
|
||||
@@ -644,6 +650,7 @@ LOGGING = {
|
||||
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
|
||||
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"httpx": {"level": "WARNING"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -252,6 +252,9 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
|
||||
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
|
||||
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
|
||||
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
|
||||
# Validate pooled connections so a connection closed server-side
|
||||
# is replaced rather than handed out as "the connection is closed".
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
}
|
||||
|
||||
base_options = {
|
||||
|
||||
@@ -398,6 +398,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "localhost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
@@ -426,6 +427,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "paperless-db-host",
|
||||
"PORT": 1111,
|
||||
"NAME": "customdb",
|
||||
@@ -455,6 +457,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "pghost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
@@ -485,6 +488,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "pghost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
|
||||
@@ -2,13 +2,50 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
from pathlib import Path
|
||||
|
||||
from paperless.parsers.utils import is_tagged_pdf
|
||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
||||
|
||||
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
|
||||
|
||||
|
||||
class TestReadFileHandleUnicodeErrors:
|
||||
def test_plain_utf8(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "plain.txt"
|
||||
f.write_bytes(b"hello world")
|
||||
assert read_file_handle_unicode_errors(f) == "hello world"
|
||||
|
||||
def test_utf8_bom(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "bom.txt"
|
||||
f.write_bytes(codecs.BOM_UTF8 + b"hello")
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_utf16_le(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "utf16le.txt"
|
||||
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_utf16_be(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "utf16be.txt"
|
||||
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "null-bytes.txt"
|
||||
f.write_bytes(b"foo\x00bar")
|
||||
assert read_file_handle_unicode_errors(f) == "foobar"
|
||||
|
||||
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "bad.txt"
|
||||
f.write_bytes(b"ok\x80\x81bad")
|
||||
result = read_file_handle_unicode_errors(f)
|
||||
assert "ok" in result
|
||||
assert "bad" in result
|
||||
assert "\x00" not in result
|
||||
|
||||
|
||||
class TestIsTaggedPdf:
|
||||
def test_tagged_pdf_returns_true(self) -> None:
|
||||
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
|
||||
|
||||
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import PaperlessAuthTokenSerializer
|
||||
from paperless.serialisers import ProfileSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
from paperless_ai.indexing import vector_store_file_exists
|
||||
from paperless_ai.indexing import llm_index_exists
|
||||
|
||||
|
||||
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
|
||||
or old_llm_context_size != new_llm_context_size
|
||||
)
|
||||
rebuild_needed = new_ai_index_enabled and (
|
||||
not vector_store_file_exists() or embedding_config_changed
|
||||
not llm_index_exists() or embedding_config_changed
|
||||
)
|
||||
|
||||
if rebuild_needed:
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
from paperless_ai.indexing import visible_document_ids_for_user
|
||||
from paperless_ai.taxonomy import format_hints_for_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from paperless_ai.taxonomy import TaxonomyHints
|
||||
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
@@ -24,14 +30,26 @@ def get_language_name(language_code: str) -> str:
|
||||
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
content = truncate_content(
|
||||
document.content[:4000] or "",
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
hints_block = format_hints_for_prompt(hints) if hints else ""
|
||||
# Splice the block (if any) immediately before the "Analyze ..." instruction.
|
||||
# When there is no block this expands to nothing, so the prompt is identical
|
||||
# to the pre-hints baseline.
|
||||
hints_section = f"{hints_block}\n\n " if hints_block else ""
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
|
||||
Analyze the following document and extract the following information:
|
||||
{hints_section}Analyze the following document and extract the following information:
|
||||
- A short descriptive title
|
||||
- Tags that reflect the content
|
||||
- Names of people or organizations mentioned
|
||||
@@ -49,10 +67,16 @@ def build_prompt_without_rag(
|
||||
|
||||
def build_prompt_with_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
user: User | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
base_prompt = build_prompt_without_rag(document, config, hints=hints)
|
||||
context = truncate_content(
|
||||
get_context_for_document(document, user),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
return f"""{base_prompt}
|
||||
|
||||
@@ -85,20 +109,7 @@ def get_context_for_document(
|
||||
user: User | None = None,
|
||||
max_docs: int = 5,
|
||||
) -> str:
|
||||
visible_documents = (
|
||||
get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
if user
|
||||
else None
|
||||
)
|
||||
visible_document_ids = (
|
||||
list(visible_documents.values_list("pk", flat=True))
|
||||
if visible_documents is not None
|
||||
else None
|
||||
)
|
||||
visible_document_ids = visible_document_ids_for_user(user)
|
||||
similar_docs = query_similar_documents(
|
||||
document=doc,
|
||||
document_ids=visible_document_ids,
|
||||
@@ -126,30 +137,34 @@ def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
output_language: str | None = None,
|
||||
hints: "TaxonomyHints | None" = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, user)
|
||||
build_prompt_with_rag(document, ai_config, user, hints=hints)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document)
|
||||
else build_prompt_without_rag(document, ai_config, hints=hints)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query(prompt)
|
||||
suggestions = parse_ai_response(result)
|
||||
if output_language:
|
||||
localized = client.run_llm_query(
|
||||
build_localization_prompt(suggestions, output_language),
|
||||
)
|
||||
localized_suggestions = parse_ai_response(localized)
|
||||
suggestions = {
|
||||
**suggestions,
|
||||
"title": localized_suggestions["title"] or suggestions["title"],
|
||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||
"document_types": localized_suggestions["document_types"]
|
||||
or suggestions["document_types"],
|
||||
"storage_paths": localized_suggestions["storage_paths"]
|
||||
or suggestions["storage_paths"],
|
||||
}
|
||||
# Hand the pooled DB connection back while the (slow) LLM query runs so it
|
||||
# is not pinned for the call's duration; see paperless_ai.db and #12976.
|
||||
with db_connection_released():
|
||||
result = client.run_llm_query(prompt)
|
||||
suggestions = parse_ai_response(result)
|
||||
if output_language:
|
||||
localized = client.run_llm_query(
|
||||
build_localization_prompt(suggestions, output_language),
|
||||
)
|
||||
localized_suggestions = parse_ai_response(localized)
|
||||
suggestions = {
|
||||
**suggestions,
|
||||
"title": localized_suggestions["title"] or suggestions["title"],
|
||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||
"document_types": localized_suggestions["document_types"]
|
||||
or suggestions["document_types"],
|
||||
"storage_paths": localized_suggestions["storage_paths"]
|
||||
or suggestions["storage_paths"],
|
||||
}
|
||||
return suggestions
|
||||
|
||||
+57
-123
@@ -3,9 +3,13 @@ import logging
|
||||
import sys
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import _document_id_filters
|
||||
from paperless_ai.indexing import get_rag_prompt_helper
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
from paperless_ai.indexing import read_store
|
||||
|
||||
logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
@@ -75,148 +79,78 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
|
||||
)
|
||||
|
||||
|
||||
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.vector_stores import VectorStoreQuery
|
||||
|
||||
class DocumentFilteredFaissRetriever(BaseRetriever):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._cached_query_str = None
|
||||
self._cached_nodes = []
|
||||
|
||||
def _retrieve(self, query_bundle):
|
||||
if query_bundle.query_str == self._cached_query_str:
|
||||
return self._cached_nodes
|
||||
|
||||
if query_bundle.embedding is None:
|
||||
query_bundle.embedding = (
|
||||
index._embed_model.get_agg_embedding_from_queries(
|
||||
query_bundle.embedding_strs,
|
||||
)
|
||||
)
|
||||
|
||||
faiss_index = index.vector_store._faiss_index
|
||||
max_top_k = faiss_index.ntotal
|
||||
if max_top_k == 0:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
query_top_k = min(max(similarity_top_k, 1), max_top_k)
|
||||
allowed_nodes: list[NodeWithScore] = []
|
||||
seen_node_ids: set[str] = set()
|
||||
|
||||
while query_top_k <= max_top_k:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
|
||||
for vector_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
node_id = index.index_struct.nodes_dict.get(vector_id)
|
||||
if node_id is None or node_id in seen_node_ids:
|
||||
continue
|
||||
|
||||
node = index.docstore.docs.get(node_id)
|
||||
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||
continue
|
||||
|
||||
seen_node_ids.add(node_id)
|
||||
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||
|
||||
if len(allowed_nodes) >= similarity_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
if query_top_k == max_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
query_top_k = min(query_top_k * 2, max_top_k)
|
||||
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
return DocumentFilteredFaissRetriever()
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
try:
|
||||
yield from _stream_chat_with_documents(query_str, documents)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to stream document chat response: {e}", exc_info=True)
|
||||
logger.exception("Failed to stream document chat response: %s", e)
|
||||
yield CHAT_ERROR_MESSAGE
|
||||
|
||||
|
||||
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [str(doc.pk) for doc in documents]
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
if not documents:
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
index,
|
||||
set(doc_ids),
|
||||
CHAT_RETRIEVER_TOP_K,
|
||||
)
|
||||
config = AIConfig()
|
||||
filters = _document_id_filters(str(doc.pk) for doc in documents)
|
||||
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if len(top_nodes) == 0:
|
||||
logger.warning("Retriever returned no nodes for the given documents.")
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
# Hold the shared read lock for the whole operation: the query engine
|
||||
# retrieves from the vector store again during synthesis, so the connection
|
||||
# must stay open (and the swap must not run) until the stream finishes.
|
||||
with read_store() as store:
|
||||
index = load_or_build_index(config, store)
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=CHAT_RETRIEVER_TOP_K,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
# Slow query-embedding + vector search; no Django ORM access happens
|
||||
# during it, so release the pooled DB connection for its duration. See
|
||||
# #12976.
|
||||
with db_connection_released():
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if not top_nodes:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
client = AIClient()
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
response_synthesizer=response_synthesizer,
|
||||
streaming=True,
|
||||
)
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
response_synthesizer=response_synthesizer,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
response_stream = query_engine.query(query_str)
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
# Release the pooled DB connection for the slow streaming LLM response
|
||||
# so it is not pinned for the whole stream; see paperless_ai.db and
|
||||
# #12976.
|
||||
with db_connection_released():
|
||||
response_stream = query_engine.query(query_str)
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
if references:
|
||||
yield _format_chat_metadata_trailer(references)
|
||||
if references:
|
||||
yield _format_chat_metadata_trailer(references)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connections
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_connection_released():
|
||||
"""
|
||||
Return any checked-out DB connections to the pool for the duration of the
|
||||
wrapped block.
|
||||
|
||||
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
|
||||
or a streaming response (``chat``). Django keeps the request's database
|
||||
connection checked out for the entire request/response, so a blocking LLM
|
||||
call - which can take many seconds - pins a pooled connection the whole
|
||||
time. With connection pooling enabled, enough concurrent AI requests check
|
||||
out every slot and all other requests then fail with
|
||||
``psycopg_pool.PoolTimeout`` (see issue #12976).
|
||||
|
||||
No Django ORM access happens during the LLM call, so we hand the connection
|
||||
back to the pool first; Django transparently re-checks-out a connection on
|
||||
the next ORM use after the block.
|
||||
"""
|
||||
connections.close_all()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
connections.close_all()
|
||||
@@ -1,12 +1,9 @@
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
@@ -23,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
|
||||
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
|
||||
|
||||
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
config = AIConfig()
|
||||
|
||||
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI_LIKE:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
@@ -95,41 +90,24 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_dim() -> int:
|
||||
"""
|
||||
Loads embedding dimension from meta.json if available, otherwise infers it
|
||||
from a dummy embedding and stores it for future use.
|
||||
"""
|
||||
config = AIConfig()
|
||||
default_model = {
|
||||
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
||||
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
||||
}.get(
|
||||
config.llm_embedding_backend,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
_DEFAULT_MODEL_NAMES = {
|
||||
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
||||
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
||||
}
|
||||
|
||||
|
||||
def get_configured_model_name(config: AIConfig) -> str:
|
||||
"""Return the canonical name of the currently configured embedding model."""
|
||||
# dict.get(key, default) overload resolution fails for TextChoices keys in some
|
||||
# type checkers; use `or` fallback to avoid the ambiguity.
|
||||
default = (
|
||||
_DEFAULT_MODEL_NAMES.get(
|
||||
config.llm_embedding_backend,
|
||||
)
|
||||
or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
model = config.llm_embedding_model or default_model
|
||||
|
||||
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
|
||||
if meta_path.exists():
|
||||
with meta_path.open() as f:
|
||||
meta = json.load(f)
|
||||
if meta.get("embedding_model") != model:
|
||||
raise RuntimeError(
|
||||
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
|
||||
"You must rebuild the index.",
|
||||
)
|
||||
return meta["dim"]
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
test_embed = embedding_model.get_text_embedding("test")
|
||||
dim = len(test_embed)
|
||||
|
||||
with meta_path.open("w") as f:
|
||||
json.dump({"embedding_model": model, "dim": dim}, f)
|
||||
|
||||
return dim
|
||||
return config.llm_embedding_model or default
|
||||
|
||||
|
||||
def _normalize_llm_index_text(text: str) -> str:
|
||||
@@ -138,17 +116,11 @@ def _normalize_llm_index_text(text: str) -> str:
|
||||
|
||||
|
||||
def build_llm_index_text(doc: Document) -> str:
|
||||
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
|
||||
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
|
||||
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
|
||||
# text, Custom Fields are dynamic in count and best kept in the embedding.
|
||||
lines = [
|
||||
f"Title: {doc.title}",
|
||||
f"Filename: {doc.filename}",
|
||||
f"Created: {doc.created}",
|
||||
f"Added: {doc.added}",
|
||||
f"Modified: {doc.modified}",
|
||||
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
||||
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
||||
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
||||
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
||||
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
||||
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
||||
]
|
||||
|
||||
|
||||
+313
-246
@@ -1,27 +1,32 @@
|
||||
import logging
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils import timezone
|
||||
from filelock import FileLock
|
||||
from filelock import ReadWriteLock
|
||||
from filelock import Timeout
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.utils import IterWrapper
|
||||
from documents.utils import identity
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_configured_model_name
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("paperless_ai.indexing")
|
||||
@@ -30,21 +35,11 @@ RAG_NUM_OUTPUT = 512
|
||||
RAG_CHUNK_OVERLAP = 200
|
||||
|
||||
|
||||
def _index_lock_path() -> Path:
|
||||
"""Return the path used as the file lock for FAISS index mutations.
|
||||
|
||||
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
|
||||
rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock
|
||||
while another worker still holds it.
|
||||
"""
|
||||
return settings.LLM_INDEX_LOCK
|
||||
|
||||
|
||||
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
|
||||
# concurrent workers can both observe no running task and both enqueue a
|
||||
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
|
||||
# is itself protected by _index_lock_path(), so only one rebuild runs at a
|
||||
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
|
||||
# time and the second one is serialised after the first completes.
|
||||
from documents.tasks import llmindex_index
|
||||
|
||||
@@ -71,46 +66,110 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
import faiss
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
else:
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
return StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
vector_store=vector_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return PaperlessSqliteVecVectorStore(
|
||||
uri=str(settings.LLM_INDEX_DIR),
|
||||
)
|
||||
|
||||
|
||||
# --- LLM index locking ---------------------------------------------------
|
||||
#
|
||||
# Two locks guard the index; they answer different questions and are NOT
|
||||
# interchangeable:
|
||||
#
|
||||
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
|
||||
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
|
||||
# Taken by write_store(). Readers never take it, so it never blocks reads.
|
||||
#
|
||||
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
|
||||
# compaction/migration file swap. read_store() takes it SHARED (readers run
|
||||
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
|
||||
# the database file is never replaced while a reader connection is open (that
|
||||
# would alias the old WAL onto the new file and corrupt it).
|
||||
#
|
||||
# | vs another writer | vs a reader
|
||||
# -----------------+-------------------+----------------------------
|
||||
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
|
||||
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
|
||||
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
|
||||
#
|
||||
# They can't be merged into one ReadWriteLock: a normal write must exclude other
|
||||
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
|
||||
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
|
||||
# the swap needs to exclude readers.
|
||||
def _index_rwlock() -> ReadWriteLock:
|
||||
"""Return a fresh read/write lock instance for the index swap.
|
||||
|
||||
``is_singleton=False`` so reads and the swap always coordinate through
|
||||
SQLite (the actual cross-process case) rather than hitting the in-process
|
||||
reentrant-upgrade guard; callers must ``close()`` it (the context managers
|
||||
below do).
|
||||
"""
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def read_store():
|
||||
"""Acquire the shared read lock and yield the vector store for a read.
|
||||
|
||||
The shared lock is held for the whole lifetime of the connection (and
|
||||
closed on exit) so the compaction/migration swap, which takes the exclusive
|
||||
lock, never runs while this connection is open. Concurrent readers do not
|
||||
block each other; only the swap does.
|
||||
"""
|
||||
lock = _index_rwlock()
|
||||
try:
|
||||
with lock.read_lock(), get_vector_store() as store:
|
||||
yield store
|
||||
finally:
|
||||
lock.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _exclude_readers():
|
||||
"""Acquire exclusive index access, blocking until readers have drained.
|
||||
|
||||
The exclusive counterpart to ``read_store()``: a compaction or migration
|
||||
must not run while any reader connection is open. Raises
|
||||
:class:`filelock.Timeout` if active readers do not drain within
|
||||
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
|
||||
"""
|
||||
lock = _index_rwlock()
|
||||
try:
|
||||
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
|
||||
yield
|
||||
finally:
|
||||
lock.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def write_store(embed_model_name: str | None = None):
|
||||
"""Acquire the write lock and yield the vector store.
|
||||
|
||||
All mutating operations (upsert, delete, rebuild, compact) must go through
|
||||
this context manager to serialise concurrent Celery writers.
|
||||
Read paths use ``read_store()`` so they hold the shared read lock.
|
||||
|
||||
Pass ``embed_model_name`` whenever the operation may create the table so
|
||||
the model name is recorded in the schema metadata for future mismatch checks.
|
||||
"""
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with (
|
||||
FileLock(settings.LLM_INDEX_LOCK),
|
||||
PaperlessSqliteVecVectorStore(
|
||||
uri=str(settings.LLM_INDEX_DIR),
|
||||
embed_model_name=embed_model_name,
|
||||
) as store,
|
||||
):
|
||||
yield store
|
||||
|
||||
|
||||
def build_document_node(
|
||||
document: Document,
|
||||
*,
|
||||
@@ -130,6 +189,9 @@ def build_document_node(
|
||||
"document_type": document.document_type.name
|
||||
if document.document_type
|
||||
else None,
|
||||
"filename": document.filename,
|
||||
"storage_path": document.storage_path.name if document.storage_path else None,
|
||||
"archive_serial_number": document.archive_serial_number,
|
||||
"created": document.created.isoformat() if document.created else None,
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
"modified": document.modified.isoformat(),
|
||||
@@ -142,9 +204,11 @@ def build_document_node(
|
||||
# the token count and exceed embedding models with small context windows
|
||||
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||
doc = LlamaDocument(
|
||||
id_=str(document.id),
|
||||
text=text,
|
||||
metadata=metadata,
|
||||
excluded_embed_metadata_keys=list(metadata.keys()),
|
||||
excluded_llm_metadata_keys=["document_id"],
|
||||
)
|
||||
chunk_size = chunk_size or get_rag_chunk_size()
|
||||
parser = SimpleNodeParser(
|
||||
@@ -154,76 +218,33 @@ def build_document_node(
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(nodes=None):
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
|
||||
"""Return a VectorStoreIndex backed by ``store``.
|
||||
|
||||
``store`` is supplied by the caller's ``read_store()`` context so the shared
|
||||
read lock and the connection stay alive for the whole retrieval.
|
||||
"""
|
||||
import llama_index.core.settings as llama_settings
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
embed_model = get_embedding_model(config)
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
try:
|
||||
return load_index_from_storage(storage_context=storage_context)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to load index from storage: %s", e)
|
||||
if not nodes:
|
||||
queue_llm_index_update_if_needed(
|
||||
rebuild=vector_store_file_exists(),
|
||||
reason="LLM index missing or invalid while loading.",
|
||||
)
|
||||
logger.info("No nodes provided for index creation.")
|
||||
raise
|
||||
return VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return VectorStoreIndex.from_vector_store(
|
||||
vector_store=store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
|
||||
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
|
||||
"""
|
||||
Removes existing documents from docstore for a given document from the index.
|
||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||
"""
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = [
|
||||
node.node_id
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
if node.metadata.get("document_id") == str(document.id)
|
||||
]
|
||||
for node_id in existing_nodes:
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node_id)
|
||||
# Also purge the FAISS position -> UUID mapping so subsequent similarity
|
||||
# queries don't raise KeyError on ghost vector positions.
|
||||
stale_keys = [
|
||||
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
|
||||
]
|
||||
for key in stale_keys:
|
||||
del index.index_struct.nodes_dict[key]
|
||||
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
|
||||
index.storage_context.index_store.add_index_struct(index.index_struct)
|
||||
|
||||
|
||||
def vector_store_file_exists():
|
||||
"""
|
||||
Check if the vector store file exists in the LLM index directory.
|
||||
"""
|
||||
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||
def llm_index_exists() -> bool:
|
||||
"""True when the index table exists on disk."""
|
||||
with read_store() as store:
|
||||
return store.table_exists()
|
||||
|
||||
|
||||
def get_rag_chunk_size() -> int:
|
||||
return AIConfig().llm_embedding_chunk_size
|
||||
|
||||
|
||||
def get_rag_context_size() -> int:
|
||||
return AIConfig().llm_context_size
|
||||
|
||||
|
||||
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
|
||||
chunk_size = chunk_size or get_rag_chunk_size()
|
||||
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
|
||||
@@ -249,123 +270,149 @@ def get_rag_prompt_helper(
|
||||
)
|
||||
|
||||
|
||||
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
|
||||
"""Embed ``nodes`` in place using ``embed_model``."""
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
|
||||
for node, emb in zip(
|
||||
nodes,
|
||||
embed_model.get_text_embedding_batch(texts),
|
||||
strict=True,
|
||||
):
|
||||
node.embedding = emb
|
||||
|
||||
|
||||
def _document_id_filters(doc_ids):
|
||||
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
|
||||
return MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=sorted(doc_ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def update_llm_index(
|
||||
*,
|
||||
iter_wrapper: IterWrapper[Document] = identity,
|
||||
rebuild=False,
|
||||
) -> str:
|
||||
"""
|
||||
Rebuild or update the LLM index.
|
||||
"""
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
nodes = []
|
||||
|
||||
"""Rebuild or incrementally update the LLM index."""
|
||||
with write_store() as store:
|
||||
try:
|
||||
with _exclude_readers():
|
||||
needs_reembed = store.check_and_run_migrations()
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index migration check: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
needs_reembed = False
|
||||
if needs_reembed:
|
||||
logger.warning(
|
||||
"LLM index migration requires re-embedding; forcing rebuild.",
|
||||
)
|
||||
rebuild = True
|
||||
documents = Document.objects.all()
|
||||
if not documents.exists():
|
||||
no_documents = not documents.exists()
|
||||
|
||||
# Fast exit before touching config: nothing to index and no existing index.
|
||||
if no_documents and not rebuild and not llm_index_exists():
|
||||
logger.warning("No documents found to index.")
|
||||
if not rebuild and not vector_store_file_exists():
|
||||
return "No documents found to index."
|
||||
return "No documents found to index."
|
||||
|
||||
config = AIConfig()
|
||||
model_name = get_configured_model_name(config)
|
||||
|
||||
if not rebuild and llm_index_exists():
|
||||
with read_store() as store:
|
||||
config_mismatch = store.config_mismatch(model_name)
|
||||
if config_mismatch:
|
||||
logger.warning("Embedding model changed; forcing LLM index rebuild.")
|
||||
rebuild = True
|
||||
|
||||
if no_documents:
|
||||
logger.warning("No documents found to index.")
|
||||
|
||||
chunk_size = config.llm_embedding_chunk_size
|
||||
embed_model = get_embedding_model(config)
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
if rebuild or not vector_store_file_exists():
|
||||
# remove meta.json to force re-detection of embedding dim
|
||||
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||
# Rebuild index from scratch
|
||||
with write_store(embed_model_name=model_name) as store:
|
||||
if rebuild or not store.table_exists():
|
||||
logger.info("Rebuilding LLM index.")
|
||||
import llama_index.core.settings as llama_settings
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=True)
|
||||
store.drop_table()
|
||||
for document in iter_wrapper(documents):
|
||||
document_nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
nodes.extend(document_nodes)
|
||||
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
show_progress=False,
|
||||
)
|
||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
_embed_nodes(nodes, embed_model)
|
||||
store.add(nodes)
|
||||
msg = "LLM index rebuilt successfully."
|
||||
else:
|
||||
# Update existing index
|
||||
index = load_or_build_index()
|
||||
existing_nodes: defaultdict[str, list] = defaultdict(list)
|
||||
for node in index.docstore.docs.values():
|
||||
doc_id = node.metadata.get("document_id")
|
||||
if doc_id is not None:
|
||||
existing_nodes[doc_id].append(node)
|
||||
|
||||
existing = store.get_modified_times()
|
||||
changed = 0
|
||||
for document in iter_wrapper(documents):
|
||||
doc_id = str(document.id)
|
||||
document_modified = document.modified.isoformat()
|
||||
if existing.get(doc_id) == document.modified.isoformat():
|
||||
continue
|
||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
_embed_nodes(nodes, embed_model)
|
||||
store.upsert_document(doc_id, nodes)
|
||||
changed += 1
|
||||
msg = (
|
||||
"LLM index updated successfully."
|
||||
if changed
|
||||
else "No changes detected in LLM index."
|
||||
)
|
||||
|
||||
if doc_id in existing_nodes:
|
||||
doc_nodes = existing_nodes[doc_id]
|
||||
node_modified = doc_nodes[0].metadata.get("modified")
|
||||
|
||||
if node_modified == document_modified:
|
||||
continue
|
||||
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
for node in doc_nodes:
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
nodes.extend(build_document_node(document, chunk_size=chunk_size))
|
||||
|
||||
if nodes:
|
||||
msg = "LLM index updated successfully."
|
||||
logger.info(
|
||||
"Updating %d nodes in LLM index.",
|
||||
len(nodes),
|
||||
)
|
||||
index.insert_nodes(nodes)
|
||||
else:
|
||||
msg = "No changes detected in LLM index."
|
||||
logger.info(msg)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
try:
|
||||
with _exclude_readers():
|
||||
store.compact()
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index compaction: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
"""
|
||||
Adds or updates a document in the LLM index.
|
||||
If the document already exists, it will be replaced.
|
||||
"""
|
||||
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
|
||||
if not new_nodes:
|
||||
logger.warning(
|
||||
"No indexable content for document %s; skipping LLM index update.",
|
||||
document.pk,
|
||||
)
|
||||
return
|
||||
"""Add or atomically replace a document's chunks in the index."""
|
||||
config = AIConfig()
|
||||
new_nodes = build_document_node(
|
||||
document,
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
)
|
||||
if new_nodes:
|
||||
_embed_nodes(new_nodes, get_embedding_model(config))
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index(nodes=new_nodes)
|
||||
with write_store(embed_model_name=get_configured_model_name(config)) as store:
|
||||
store.upsert_document(str(document.id), new_nodes)
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.insert_nodes(new_nodes)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
def llm_index_compact() -> None:
|
||||
"""Compact the index immediately, rebuilding the table to reclaim space."""
|
||||
with write_store() as store:
|
||||
try:
|
||||
with _exclude_readers():
|
||||
store.compact(force=True)
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index compaction: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
|
||||
|
||||
def llm_index_remove_document(document: Document):
|
||||
"""
|
||||
Removes a document from the LLM index.
|
||||
"""
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index()
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
"""Remove a document's chunks from the LLM index."""
|
||||
with write_store() as store:
|
||||
store.delete(str(document.id))
|
||||
|
||||
|
||||
def truncate_content(
|
||||
@@ -405,82 +452,102 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
|
||||
return {str(document_id) for document_id in document_ids}
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
def visible_document_ids_for_user(user: User | None) -> list[int] | None:
|
||||
"""Return the pks of documents ``user`` may view, or ``None`` for no filter.
|
||||
|
||||
Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used
|
||||
by both the similarity-context and taxonomy-hints paths to scope RAG
|
||||
neighbours to documents the requesting user is allowed to see.
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
if user is None:
|
||||
return None
|
||||
visible_documents = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
return list(visible_documents.values_list("pk", flat=True))
|
||||
|
||||
|
||||
def retrieve_similar_nodes(
|
||||
document: Document,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
top_k: int = 5,
|
||||
) -> list["NodeWithScore"]:
|
||||
"""Run ANN retrieval and return the raw NodeWithScore results.
|
||||
|
||||
Returns ``[]`` when the allow-list normalizes to empty, or when no index
|
||||
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
|
||||
embedding request, so it runs inside ``db_connection_released()`` to avoid
|
||||
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
|
||||
and the taxonomy-hints path go through here, so they share that behavior.
|
||||
"""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
if allowed_document_ids is not None and not allowed_document_ids:
|
||||
return []
|
||||
|
||||
if not vector_store_file_exists():
|
||||
if not llm_index_exists():
|
||||
queue_llm_index_update_if_needed(
|
||||
rebuild=False,
|
||||
reason="LLM index not found for similarity query.",
|
||||
)
|
||||
return []
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index()
|
||||
config = AIConfig()
|
||||
|
||||
# constrain only the node(s) that match the document IDs, if given
|
||||
doc_node_ids = (
|
||||
[
|
||||
node.node_id
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in allowed_document_ids
|
||||
]
|
||||
if allowed_document_ids is not None
|
||||
else None
|
||||
)
|
||||
if doc_node_ids is not None and not doc_node_ids:
|
||||
return []
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
filters = (
|
||||
_document_id_filters(allowed_document_ids)
|
||||
if allowed_document_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
# Hold the shared read lock for the whole retrieval so the connection is
|
||||
# never open across a compaction swap. The retrieve() call generates a
|
||||
# query embedding (a slow external request) and searches the vector store;
|
||||
# no Django ORM access happens during it, so release the pooled DB
|
||||
# connection for its duration. See #12976.
|
||||
with read_store() as store:
|
||||
index = load_or_build_index(config, store)
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=top_k,
|
||||
doc_ids=doc_node_ids,
|
||||
filters=filters,
|
||||
)
|
||||
with db_connection_released():
|
||||
return retriever.retrieve(query_text)
|
||||
|
||||
config = AIConfig()
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
try:
|
||||
results = retriever.retrieve(query_text)
|
||||
except KeyError as e:
|
||||
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
|
||||
# append-only. Treat them as absent and return no results.
|
||||
logger.debug(
|
||||
"Skipping LLM similarity query for document %s due to a stale "
|
||||
"FAISS position with no docstore node: %s",
|
||||
document.pk,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
results = retrieve_similar_nodes(
|
||||
document=document,
|
||||
document_ids=allowed_document_ids,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
retrieved_document_ids: list[int] = []
|
||||
for node in results:
|
||||
document_id = node.metadata.get("document_id")
|
||||
if document_id is None:
|
||||
continue
|
||||
normalized_document_id = str(document_id)
|
||||
if (
|
||||
allowed_document_ids is not None
|
||||
and normalized_document_id not in allowed_document_ids
|
||||
):
|
||||
normalized = str(document_id)
|
||||
if allowed_document_ids is not None and normalized not in allowed_document_ids:
|
||||
continue
|
||||
try:
|
||||
retrieved_document_ids.append(int(normalized_document_id))
|
||||
except ValueError:
|
||||
retrieved_document_ids.append(int(normalized))
|
||||
except ValueError: # pragma: no cover
|
||||
logger.warning(
|
||||
"Skipping LLM index result with invalid document_id %r.",
|
||||
document_id,
|
||||
|
||||
@@ -15,40 +15,56 @@ MATCH_THRESHOLD = 0.8
|
||||
logger = logging.getLogger("paperless_ai.matching")
|
||||
|
||||
|
||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
||||
def match_tags_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[Tag]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_tag"],
|
||||
Tag,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
|
||||
def match_correspondents_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[Correspondent]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_correspondent"],
|
||||
Correspondent,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
|
||||
def match_document_types_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[DocumentType]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_documenttype"],
|
||||
DocumentType,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
|
||||
def match_storage_paths_by_name(
|
||||
names: list[str],
|
||||
user: User,
|
||||
hinted_names: set[str] | None = None,
|
||||
) -> list[StoragePath]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_storagepath"],
|
||||
StoragePath,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
return _match_names_to_queryset(names, queryset, "name", hinted_names)
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
@@ -58,10 +74,18 @@ def _normalize(s: str) -> str:
|
||||
return s
|
||||
|
||||
|
||||
def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
def _match_names_to_queryset(
|
||||
names: list[str],
|
||||
queryset,
|
||||
attr: str,
|
||||
hinted_names: set[str] | None = None,
|
||||
):
|
||||
results = []
|
||||
objects = list(queryset)
|
||||
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
|
||||
normalized_hints = (
|
||||
{_normalize(name) for name in hinted_names} if hinted_names else set()
|
||||
)
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
@@ -76,6 +100,11 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
results.append(matched)
|
||||
continue
|
||||
|
||||
# A hinted name that didn't exact-match came from existing taxonomy
|
||||
# verbatim; do not fuzzy-map it onto a different object.
|
||||
if target in normalized_hints:
|
||||
continue
|
||||
|
||||
# Fuzzy match fallback
|
||||
matches = difflib.get_close_matches(
|
||||
target,
|
||||
@@ -88,8 +117,6 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
matched = objects.pop(index)
|
||||
object_names.pop(index)
|
||||
results.append(matched)
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypedDict
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.indexing import retrieve_similar_nodes
|
||||
from paperless_ai.indexing import visible_document_ids_for_user
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
|
||||
class TaxonomyHints(TypedDict):
|
||||
tags: list[str]
|
||||
document_types: list[str]
|
||||
correspondents: list[str]
|
||||
storage_paths: list[str]
|
||||
|
||||
|
||||
def build_taxonomy_hints_from_nodes(
|
||||
nodes: list["NodeWithScore"],
|
||||
) -> TaxonomyHints:
|
||||
"""Collect the unique, sorted taxonomy names carried on retrieved nodes.
|
||||
|
||||
Reads ``tags`` (a list), ``document_type``, ``correspondent``, and
|
||||
``storage_path`` from each node's metadata. Empty / ``None`` values and
|
||||
missing keys are skipped. The result is naturally bounded by the retrieval
|
||||
``top_k``, so no cap is applied.
|
||||
"""
|
||||
tags: set[str] = set()
|
||||
document_types: set[str] = set()
|
||||
correspondents: set[str] = set()
|
||||
storage_paths: set[str] = set()
|
||||
|
||||
for node in nodes:
|
||||
metadata = node.metadata or {}
|
||||
|
||||
for tag in metadata.get("tags") or []:
|
||||
if tag:
|
||||
tags.add(tag)
|
||||
|
||||
document_type = metadata.get("document_type")
|
||||
if document_type:
|
||||
document_types.add(document_type)
|
||||
|
||||
correspondent = metadata.get("correspondent")
|
||||
if correspondent:
|
||||
correspondents.add(correspondent)
|
||||
|
||||
storage_path = metadata.get("storage_path")
|
||||
if storage_path:
|
||||
storage_paths.add(storage_path)
|
||||
|
||||
return TaxonomyHints(
|
||||
tags=sorted(tags),
|
||||
document_types=sorted(document_types),
|
||||
correspondents=sorted(correspondents),
|
||||
storage_paths=sorted(storage_paths),
|
||||
)
|
||||
|
||||
|
||||
_HINT_INSTRUCTION = (
|
||||
"Prefer existing names from these lists verbatim. Only propose a new value "
|
||||
"if none of the existing names fits."
|
||||
)
|
||||
|
||||
|
||||
def format_hints_for_prompt(hints: TaxonomyHints) -> str:
|
||||
"""Render non-empty hint categories as labelled blocks plus one instruction.
|
||||
|
||||
Returns "" when every category is empty, so callers can treat the result
|
||||
the same as no hints at all.
|
||||
"""
|
||||
# Literal-key access keeps this TypedDict-safe for mypy; the order here is
|
||||
# the order the blocks appear in the prompt.
|
||||
labelled_values: list[tuple[str, list[str]]] = [
|
||||
("Available tags", hints["tags"]),
|
||||
("Available document types", hints["document_types"]),
|
||||
("Available correspondents", hints["correspondents"]),
|
||||
("Available storage paths", hints["storage_paths"]),
|
||||
]
|
||||
blocks: list[str] = []
|
||||
for label, values in labelled_values:
|
||||
if values:
|
||||
listing = "\n".join(f"- {value}" for value in values)
|
||||
blocks.append(f"{label}:\n{listing}")
|
||||
|
||||
if not blocks:
|
||||
return ""
|
||||
|
||||
return "\n\n".join([*blocks, _HINT_INSTRUCTION])
|
||||
|
||||
|
||||
def get_taxonomy_hints_for_document(
|
||||
document: Document,
|
||||
user: User | None,
|
||||
) -> TaxonomyHints | None:
|
||||
"""Build taxonomy hints from a document's RAG neighbours.
|
||||
|
||||
Returns ``None`` when no embedding backend is configured (the gate) so the
|
||||
caller's prompt and matching are identical to today. Otherwise returns a
|
||||
``TaxonomyHints`` -- possibly all-empty when no similar documents exist.
|
||||
Applies the same owner-aware visible-document filter as
|
||||
``get_context_for_document``.
|
||||
"""
|
||||
if not AIConfig().llm_embedding_backend:
|
||||
return None
|
||||
|
||||
nodes = retrieve_similar_nodes(
|
||||
document=document,
|
||||
document_ids=visible_document_ids_for_user(user),
|
||||
)
|
||||
return build_taxonomy_hints_from_nodes(nodes)
|
||||
@@ -1,10 +1,36 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
|
||||
settings.LLM_INDEX_DIR = tmp_path
|
||||
settings.LLM_INDEX_LOCK = tmp_path / "index.lock"
|
||||
settings.LLM_INDEX_RWLOCK = tmp_path / "llmindex.rwlock.db"
|
||||
return tmp_path
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
async def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model(mocker: pytest_mock.MockerFixture) -> pytest_mock.MockType:
|
||||
fake = FakeEmbedding()
|
||||
mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake)
|
||||
mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake)
|
||||
return fake
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.ai_classifier import build_localization_prompt
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
@@ -211,11 +215,12 @@ def test_prompt_with_without_rag(mock_document):
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
config = AIConfig()
|
||||
prompt = build_prompt_without_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" not in prompt
|
||||
assert "for generated" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
prompt = build_prompt_with_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
|
||||
prompt = build_localization_prompt(
|
||||
@@ -259,3 +264,111 @@ def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestPromptHints:
|
||||
@pytest.fixture
|
||||
def config(self) -> AIConfig:
|
||||
# build_prompt_* only read these two numeric settings off config;
|
||||
# a stand-in avoids constructing a DB-backed AIConfig.
|
||||
return cast(
|
||||
"AIConfig",
|
||||
SimpleNamespace(llm_embedding_chunk_size=1000, llm_context_size=8000),
|
||||
)
|
||||
|
||||
def test_without_rag_includes_hints_block(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_without_rag(mock_document, config, hints=hints)
|
||||
assert "Available tags:" in prompt
|
||||
assert "- Bloodwork" in prompt
|
||||
assert "Prefer existing names from these lists verbatim" in prompt
|
||||
|
||||
def test_without_rag_none_matches_baseline(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
) -> None:
|
||||
baseline = build_prompt_without_rag(mock_document, config)
|
||||
with_none = build_prompt_without_rag(mock_document, config, hints=None)
|
||||
assert with_none == baseline
|
||||
assert "Available tags:" not in with_none
|
||||
|
||||
def test_with_rag_includes_context_and_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
config: AIConfig,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="TITLE: Neighbour\nsome context",
|
||||
)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
prompt = build_prompt_with_rag(mock_document, config, user=None, hints=hints)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
assert "Available tags:" in prompt
|
||||
|
||||
def test_classification_forwards_hints(
|
||||
self,
|
||||
mock_document: MagicMock,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.AIConfig",
|
||||
return_value=SimpleNamespace(
|
||||
llm_embedding_backend=None,
|
||||
llm_embedding_chunk_size=1000,
|
||||
llm_context_size=8000,
|
||||
),
|
||||
)
|
||||
build = mocker.patch(
|
||||
"paperless_ai.ai_classifier.build_prompt_without_rag",
|
||||
return_value="PROMPT",
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
mock_client.run_llm_query.return_value = {
|
||||
"title": "t",
|
||||
"tags": [],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
}
|
||||
mocker.patch("paperless_ai.ai_classifier.AIClient", return_value=mock_client)
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
result = get_ai_document_classification(
|
||||
mock_document,
|
||||
user=None,
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
_, build_kwargs = build.call_args
|
||||
assert build_kwargs["hints"] == hints
|
||||
assert set(result.keys()) == {
|
||||
"title",
|
||||
"tags",
|
||||
"correspondents",
|
||||
"document_types",
|
||||
"storage_paths",
|
||||
"dates",
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from faker import Faker
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
@@ -19,10 +17,12 @@ from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import PaperlessTaskFactory
|
||||
from paperless.models import ApplicationConfiguration
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.tests.conftest import FakeEmbedding
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_document(db):
|
||||
def real_document(db: None) -> Document:
|
||||
return Document.objects.create(
|
||||
title="Test Document",
|
||||
content="This is some test content.",
|
||||
@@ -30,44 +30,39 @@ def real_document(db):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document) -> None:
|
||||
def test_build_document_node(real_document: Document) -> None:
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
assert nodes[0].metadata["filename"] == real_document.filename
|
||||
assert nodes[0].metadata["storage_path"] == (
|
||||
real_document.storage_path.name if real_document.storage_path else None
|
||||
)
|
||||
assert (
|
||||
nodes[0].metadata["archive_serial_number"]
|
||||
== real_document.archive_serial_number
|
||||
)
|
||||
assert "filename" in nodes[0].excluded_embed_metadata_keys
|
||||
assert "filename" not in nodes[0].excluded_llm_metadata_keys
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_metadata_from_embedding(real_document) -> None:
|
||||
def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None:
|
||||
"""Every node produced by build_document_node must carry the paperless document id
|
||||
as its ref_doc_id so that the vector store's delete(str(doc.id)) works correctly."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0, "Expected at least one node"
|
||||
for node in nodes:
|
||||
assert node.ref_doc_id == str(real_document.id), (
|
||||
f"Expected ref_doc_id={real_document.id!r}, got {node.ref_doc_id!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_metadata_from_embedding(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""Metadata keys must not be prepended to the embedding text.
|
||||
|
||||
build_llm_index_text already encodes all metadata in the body text, so
|
||||
@@ -75,8 +70,6 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
|
||||
double the token count and exceed embedding models with small context
|
||||
windows (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||
"""
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
for node in nodes:
|
||||
embed_text = node.get_content(metadata_mode=MetadataMode.EMBED)
|
||||
@@ -87,7 +80,36 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
|
||||
def test_build_document_node_structured_fields_in_metadata(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""Structured fields must be in node.metadata so the LLM receives them via metadata prepend."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
for node in nodes:
|
||||
assert "title" in node.metadata
|
||||
assert "tags" in node.metadata
|
||||
assert "correspondent" in node.metadata
|
||||
assert "document_type" in node.metadata
|
||||
assert "created" in node.metadata
|
||||
assert "added" in node.metadata
|
||||
assert "modified" in node.metadata
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_document_id_from_llm_context(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""document_id is an internal key and must not appear in LLM context text."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
for node in nodes:
|
||||
assert "document_id" in node.excluded_llm_metadata_keys
|
||||
assert "document_id" not in node.get_content(metadata_mode=MetadataMode.LLM)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_uses_rag_chunk_settings(real_document: Document) -> None:
|
||||
app_config, _ = ApplicationConfiguration.objects.get_or_create()
|
||||
app_config.llm_embedding_chunk_size = 512
|
||||
app_config.save()
|
||||
@@ -118,9 +140,9 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None:
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
mock_config = MagicMock()
|
||||
mock_config.llm_embedding_chunk_size = 512
|
||||
@@ -138,44 +160,49 @@ def test_update_llm_index(
|
||||
|
||||
ai_config.assert_called_once()
|
||||
build_document_node.assert_called_once_with(real_document, chunk_size=512)
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_removes_meta(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
def test_update_llm_index_rebuilds_on_model_name_change(
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
# Pre-create a meta.json with incorrect data
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 1}),
|
||||
)
|
||||
|
||||
# Build initial index with model "model-a".
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
with patch(
|
||||
"paperless_ai.indexing.get_configured_model_name",
|
||||
return_value="model-a",
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
from paperless.config import AIConfig
|
||||
# Simulate config change to "model-b"; the incremental run must force a rebuild.
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
with patch(
|
||||
"paperless_ai.indexing.get_configured_model_name",
|
||||
return_value="model-b",
|
||||
):
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
config = AIConfig()
|
||||
expected_model = config.llm_embedding_model or (
|
||||
"text-embedding-3-small"
|
||||
if config.llm_embedding_backend == "openai-like"
|
||||
else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
assert meta == {"embedding_model": expected_model, "dim": 384}
|
||||
with indexing.get_vector_store() as store:
|
||||
# Schema metadata only updates when the table is dropped and recreated, never
|
||||
# on incremental writes -- so "model-b" here proves a full rebuild happened.
|
||||
assert store.stored_model_name() == "model-b"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_partial_update(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
doc2 = Document.objects.create(
|
||||
title="Test Document 2",
|
||||
@@ -210,131 +237,34 @@ def test_update_llm_index_partial_update(
|
||||
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Updating %d nodes in LLM index.",
|
||||
2,
|
||||
)
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_get_or_create_storage_context_raises_exception(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with pytest.raises(Exception):
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Expected the vector store table to exist after incremental update"
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_succeeds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[MagicMock()],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_add_or_update_document_updates_existing_entry(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_remove_document_deletes_node_from_docstore(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 0
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Expected the vector store table to exist after add-or-update"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_after_remove_does_not_raise_key_error(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
@@ -352,8 +282,8 @@ def test_query_after_remove_does_not_raise_key_error(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_no_documents(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
@@ -369,6 +299,22 @@ def test_update_llm_index_no_documents(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_no_documents_no_index_returns_early(
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""update with no documents and no existing index must return early."""
|
||||
mock_qs = MagicMock()
|
||||
mock_qs.exists.return_value = False
|
||||
mock_qs.__iter__ = MagicMock(return_value=iter([]))
|
||||
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
|
||||
|
||||
result = indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert result == "No documents found to index."
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
|
||||
# No existing tasks
|
||||
@@ -406,20 +352,17 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
) as mock_vector_store_exists,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
mock_vector_store_exists.return_value = True
|
||||
|
||||
mock_index = MagicMock()
|
||||
@@ -453,12 +396,12 @@ def test_query_similar_documents(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
@@ -479,120 +422,13 @@ def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_normalizes_and_post_filters_allowed_ids(
|
||||
real_document,
|
||||
) -> None:
|
||||
real_document.owner = User.objects.create_user(username="rag-owner")
|
||||
real_document.save()
|
||||
private_owner = User.objects.create_user(username="rag-private-owner")
|
||||
private_document = Document.objects.create(
|
||||
title="Private similar document",
|
||||
content="Similar private content that must not reach RAG.",
|
||||
owner=private_owner,
|
||||
added=timezone.now(),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
return_value=True,
|
||||
),
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
):
|
||||
allowed_node = MagicMock()
|
||||
allowed_node.node_id = "allowed-node"
|
||||
allowed_node.metadata = {"document_id": str(real_document.pk)}
|
||||
private_node = MagicMock()
|
||||
private_node.node_id = "private-node"
|
||||
private_node.metadata = {"document_id": str(private_document.pk)}
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [allowed_node, private_node]
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [private_node, allowed_node]
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
result = indexing.query_similar_documents(
|
||||
real_document,
|
||||
top_k=2,
|
||||
document_ids=[real_document.pk],
|
||||
)
|
||||
|
||||
mock_retriever_cls.assert_called_once_with(
|
||||
index=mock_index,
|
||||
similarity_top_k=2,
|
||||
doc_ids=["allowed-node"],
|
||||
)
|
||||
assert result == [real_document]
|
||||
assert private_document not in result
|
||||
|
||||
|
||||
class TestUpdateLlmIndexStaleNodes:
|
||||
"""Tests that update_llm_index removes ALL nodes for a multi-chunk document."""
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_incremental_update_removes_all_old_nodes_for_multi_chunk_document(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
"""Ghost nodes from all chunks of a modified document must be removed.
|
||||
|
||||
When a document is split into multiple chunks (chunk_size=1024), the
|
||||
incremental update path must delete every old node, not just the last
|
||||
one captured by a dict comprehension keyed on document_id.
|
||||
"""
|
||||
# Content long enough to produce at least two chunks at chunk_size=1024.
|
||||
# Generate many paragraphs so the token count comfortably exceeds 1024.
|
||||
fake = Faker()
|
||||
long_content = "\n\n".join(fake.paragraph(nb_sentences=20) for _ in range(20))
|
||||
doc = DocumentFactory(content=long_content)
|
||||
|
||||
# Build the initial index (rebuild=True) so it has multiple nodes
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# Verify the initial index has more than one node for this document
|
||||
initial_index = indexing.load_or_build_index()
|
||||
initial_node_ids = [
|
||||
nid
|
||||
for nid, node in initial_index.docstore.docs.items()
|
||||
if node.metadata.get("document_id") == str(doc.id)
|
||||
]
|
||||
assert len(initial_node_ids) > 1, (
|
||||
f"Expected multiple chunks but got {len(initial_node_ids)}; "
|
||||
"increase long_content length"
|
||||
)
|
||||
|
||||
# Simulate a modification so the incremental path treats it as changed.
|
||||
# Use queryset.update() to bypass auto_now and actually change the DB value.
|
||||
new_modified = timezone.now()
|
||||
Document.objects.filter(pk=doc.pk).update(modified=new_modified)
|
||||
|
||||
# Run incremental update (rebuild=False) with the modified document
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
# Reload the persisted index and check that no OLD node ids remain
|
||||
updated_index = indexing.load_or_build_index()
|
||||
remaining_old_node_ids = [
|
||||
nid for nid in initial_node_ids if nid in updated_index.docstore.docs
|
||||
]
|
||||
assert remaining_old_node_ids == [], (
|
||||
f"Ghost nodes still present after incremental update: "
|
||||
f"{remaining_old_node_ids}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_empty_allow_list_fails_closed(
|
||||
real_document,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
return_value=True,
|
||||
) as mock_vector_store_exists,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
@@ -610,27 +446,25 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
|
||||
|
||||
|
||||
class TestUpdateLlmIndexEmptyDocumentSet:
|
||||
"""update_llm_index must persist an empty index when all documents are deleted.
|
||||
"""update_llm_index must clear the vector store table when all documents are deleted.
|
||||
|
||||
Without this, the stale on-disk FAISS vectors are never cleared and
|
||||
subsequent similarity searches return phantom hits for document IDs that
|
||||
no longer exist in the DB.
|
||||
Without this, the stale vectors are never cleared and subsequent similarity
|
||||
searches return phantom hits for document IDs that no longer exist in the DB.
|
||||
"""
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_rebuild_clears_stale_index_when_no_documents_exist(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
"""After deleting all documents, rebuild=True must persist an empty index.
|
||||
"""After deleting all documents, rebuild=True must produce a table with zero rows.
|
||||
|
||||
Steps:
|
||||
1. Build an index with one document so the on-disk state is non-empty.
|
||||
2. Delete all documents from the DB.
|
||||
3. Call update_llm_index(rebuild=True).
|
||||
4. Reload the index from disk.
|
||||
5. Assert the reloaded index has zero nodes (no phantom vectors).
|
||||
4. Open the LanceDB table directly and assert zero rows.
|
||||
"""
|
||||
# Step 1: create a document and build a non-empty index
|
||||
Document.objects.create(
|
||||
@@ -640,27 +474,26 @@ class TestUpdateLlmIndexEmptyDocumentSet:
|
||||
)
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
initial_index = indexing.load_or_build_index()
|
||||
assert len(initial_index.docstore.docs) > 0, (
|
||||
"Precondition failed: expected at least one node before deletion"
|
||||
)
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Precondition failed: expected the vector store table to exist "
|
||||
"before deletion"
|
||||
)
|
||||
|
||||
# Step 2: delete all documents
|
||||
Document.objects.all().delete()
|
||||
assert not Document.objects.exists()
|
||||
|
||||
# Step 3: rebuild with no documents
|
||||
# Step 3: rebuild with no documents — drop_table is called so the table
|
||||
# is removed (no rows to re-insert, so it stays absent).
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# Step 4: reload the persisted index from disk
|
||||
reloaded_index = indexing.load_or_build_index()
|
||||
|
||||
# Step 5: phantom vectors must be gone
|
||||
assert len(reloaded_index.docstore.docs) == 0, (
|
||||
f"Expected 0 nodes after clearing all documents, "
|
||||
f"but found {len(reloaded_index.docstore.docs)}: "
|
||||
f"{list(reloaded_index.docstore.docs.keys())}"
|
||||
)
|
||||
# Step 4: the table must be absent (no rows) — phantom vectors gone
|
||||
with indexing.get_vector_store() as store2:
|
||||
assert not store2.table_exists(), (
|
||||
"Expected the vector store table to be absent after rebuilding "
|
||||
"with no documents"
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentUpdatedSignalTriggersLlmReindex:
|
||||
@@ -709,10 +542,14 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
def test_returns_without_error_when_build_document_node_returns_empty(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""When build_document_node returns [], the function must return without error
|
||||
and must not call load_or_build_index at all."""
|
||||
"""When build_document_node returns [], the function must return without error.
|
||||
|
||||
The store's upsert_document treats an empty node list as a removal (no-op
|
||||
delete), so load_or_build_index must not be called.
|
||||
"""
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.build_document_node",
|
||||
return_value=[],
|
||||
@@ -720,6 +557,7 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 42
|
||||
# Must not raise
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
@@ -727,172 +565,220 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestLlmIndexLocking:
|
||||
"""The FAISS index mutation functions must acquire the index lock before touching the index.
|
||||
def test_llm_index_compact_uses_force(
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""compact must use force=True to rebuild the table and reclaim space immediately."""
|
||||
mock_store = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
Without locking, two concurrent Celery workers can each load the same
|
||||
on-disk index, make independent modifications, and the last writer silently
|
||||
overwrites the first's changes.
|
||||
indexing.llm_index_compact()
|
||||
|
||||
mock_store.compact.assert_called_once_with(force=True)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestLlmIndexLocking:
|
||||
"""Index mutation functions must go through write_store(), which holds the lock.
|
||||
|
||||
Without locking, two concurrent Celery workers can open the same store,
|
||||
make independent modifications, and trigger CommitConflictError.
|
||||
"""
|
||||
|
||||
def test_add_or_update_document_acquires_lock(
|
||||
def test_add_or_update_document_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""llm_index_add_or_update_document must enter the file lock before touching the index."""
|
||||
call_order: list[str] = []
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
|
||||
mock_load = mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or MagicMock()
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
mock_node = MagicMock()
|
||||
mock_node.get_content.return_value = "fake node text"
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.build_document_node",
|
||||
return_value=[MagicMock()],
|
||||
return_value=[mock_node],
|
||||
)
|
||||
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 1
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_load.assert_called_once()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
)
|
||||
mock_store.upsert_document.assert_called_once()
|
||||
|
||||
def test_remove_document_acquires_lock(
|
||||
def test_remove_document_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""llm_index_remove_document must enter the file lock before loading the index."""
|
||||
call_order: list[str] = []
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
|
||||
mock_load = mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or MagicMock()
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 1
|
||||
indexing.llm_index_remove_document(doc)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_load.assert_called_once()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
)
|
||||
mock_store.delete.assert_called_once_with("1")
|
||||
|
||||
def test_update_llm_index_rebuild_acquires_lock(
|
||||
def test_update_llm_index_rebuild_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""update_llm_index must enter the file lock during the rebuild/persist cycle."""
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(return_value=None)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
# exists=True so the code reaches the lock; iterate over an empty
|
||||
# queryset so VectorStoreIndex is called with no nodes (still exercises
|
||||
# the lock path without needing heavy FAISS fixture data)
|
||||
mock_qs = MagicMock()
|
||||
mock_qs.exists.return_value = True
|
||||
mock_qs.__iter__ = MagicMock(return_value=iter([]))
|
||||
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_store.drop_table.assert_called_once()
|
||||
|
||||
def test_query_similar_documents_acquires_lock(
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.django_db
|
||||
class TestVectorStoreIndexing:
|
||||
def test_get_vector_store_roundtrip(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
with indexing.get_vector_store() as store:
|
||||
assert isinstance(store, PaperlessSqliteVecVectorStore)
|
||||
|
||||
def test_add_then_remove_document(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists()
|
||||
count_sql = "SELECT count(*) FROM documents"
|
||||
assert store.client.execute(count_sql).fetchone()[0] >= 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
assert store.client.execute(count_sql).fetchone()[0] == 0
|
||||
|
||||
def test_update_shrinks_chunks_without_orphans(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
real_document.content = "word " * 4000 # many chunks
|
||||
real_document.save()
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
count_sql = "SELECT count(*) FROM documents"
|
||||
with indexing.get_vector_store() as store:
|
||||
big = store.client.execute(count_sql).fetchone()[0]
|
||||
|
||||
real_document.content = "short" # one chunk
|
||||
real_document.save()
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
rows = store.client.execute(count_sql).fetchone()[0]
|
||||
assert rows < big
|
||||
assert rows >= 1
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestQuerySimilarDocuments:
|
||||
def test_query_similar_documents_respects_allowed_ids(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
a = DocumentFactory.create(content="alpha shared content here")
|
||||
b = DocumentFactory.create(content="beta shared content here")
|
||||
c = DocumentFactory.create(content="gamma shared content here")
|
||||
for doc in (a, b, c):
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
results = indexing.query_similar_documents(a, document_ids=[b.id])
|
||||
|
||||
assert all(doc.id == b.id for doc in results)
|
||||
|
||||
|
||||
class TestRetrieveSimilarNodes:
|
||||
@pytest.mark.django_db
|
||||
def test_returns_raw_nodes_from_retriever(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""query_similar_documents must enter the file lock before loading the index."""
|
||||
call_order: list[str] = []
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs = {}
|
||||
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or mock_index
|
||||
),
|
||||
)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
|
||||
mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
node1 = SimpleNamespace(metadata={"document_id": "1"})
|
||||
node2 = SimpleNamespace(metadata={"document_id": "2"})
|
||||
retriever = mocker.MagicMock()
|
||||
retriever.retrieve.return_value = [node1, node2]
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever,
|
||||
return_value=retriever,
|
||||
)
|
||||
|
||||
mocker.patch("paperless_ai.indexing.truncate_content", return_value="")
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
|
||||
|
||||
indexing.query_similar_documents(MagicMock(spec=Document))
|
||||
assert result == [node1, node2]
|
||||
|
||||
mock_file_lock_cls.assert_called()
|
||||
mock_lock_instance.__enter__.assert_called()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
@pytest.mark.django_db
|
||||
def test_empty_allow_list_fails_closed(
|
||||
self,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
|
||||
|
||||
assert result == []
|
||||
load.assert_not_called()
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queues_update_when_index_missing(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
|
||||
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
|
||||
|
||||
assert result == []
|
||||
queue.assert_called_once_with(
|
||||
rebuild=False,
|
||||
reason="LLM index not found for similarity query.",
|
||||
)
|
||||
|
||||
+110
-130
@@ -3,19 +3,20 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core import settings as llama_settings
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from paperless_ai import chat
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.chat import CHAT_ERROR_MESSAGE
|
||||
from paperless_ai.chat import CHAT_METADATA_DELIMITER
|
||||
from paperless_ai.chat import _get_document_filtered_retriever
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_model():
|
||||
from llama_index.core import settings as llama_settings
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
||||
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
|
||||
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
|
||||
yield
|
||||
@@ -58,91 +59,6 @@ def assert_chat_output(
|
||||
}
|
||||
|
||||
|
||||
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
node.node_id: node for node in nodes
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = len(nodes)
|
||||
mock_index.vector_store.query.return_value = MagicMock(
|
||||
ids=list(mock_index.index_struct.nodes_dict),
|
||||
similarities=[0.1] * len(nodes),
|
||||
)
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
|
||||
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
allowed_node1 = TextNode(
|
||||
text="Allowed content 1.",
|
||||
metadata={"document_id": "1", "title": "Allowed 1"},
|
||||
)
|
||||
allowed_node2 = TextNode(
|
||||
text="Allowed content 2.",
|
||||
metadata={"document_id": "2", "title": "Allowed 2"},
|
||||
)
|
||||
foreign_node = TextNode(
|
||||
text="Foreign content.",
|
||||
metadata={"document_id": "3", "title": "Foreign"},
|
||||
)
|
||||
missing_node = TextNode(
|
||||
text="Missing content.",
|
||||
metadata={"document_id": "1", "title": "Missing"},
|
||||
)
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
"0": foreign_node.node_id,
|
||||
"1": missing_node.node_id,
|
||||
"2": allowed_node1.node_id,
|
||||
"3": allowed_node2.node_id,
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
allowed_node1.node_id: allowed_node1,
|
||||
allowed_node2.node_id: allowed_node2,
|
||||
foreign_node.node_id: foreign_node,
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = 4
|
||||
mock_index.vector_store.query.side_effect = [
|
||||
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
|
||||
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
|
||||
]
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1", "2"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
nodes = retriever.retrieve("question")
|
||||
cached_nodes = retriever.retrieve("question")
|
||||
|
||||
assert [node.node.node_id for node in nodes] == [
|
||||
allowed_node1.node_id,
|
||||
allowed_node2.node_id,
|
||||
]
|
||||
assert cached_nodes == nodes
|
||||
assert mock_index.vector_store.query.call_count == 2
|
||||
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
|
||||
|
||||
|
||||
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
|
||||
mock_index = MagicMock()
|
||||
mock_index.vector_store._faiss_index.ntotal = 0
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
assert retriever.retrieve("question") == []
|
||||
mock_index.vector_store.query.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_stream_chat_with_one_document_retrieval(
|
||||
mock_document,
|
||||
@@ -164,17 +80,31 @@ def test_stream_chat_with_one_document_retrieval(
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
add_vector_query_results(mock_index, [mock_node])
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(
|
||||
metadata={
|
||||
"document_id": str(mock_document.pk),
|
||||
"title": "Test Document",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What is this?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -196,12 +126,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
@@ -210,41 +138,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [
|
||||
mock_node1,
|
||||
mock_node2,
|
||||
mock_duplicate_node,
|
||||
mock_foreign_node,
|
||||
]
|
||||
add_vector_query_results(
|
||||
mock_index,
|
||||
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
|
||||
)
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
# Mock response stream
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
|
||||
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What's up?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -258,8 +177,16 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_empty_document_list() -> None:
|
||||
with patch("paperless_ai.chat.load_or_build_index") as mock_load_index:
|
||||
output = list(stream_chat_with_documents("Any info?", []))
|
||||
mock_load_index.assert_not_called()
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes() -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
@@ -268,8 +195,8 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
# No matching nodes in the store
|
||||
mock_index.vector_store.get_nodes.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
@@ -279,30 +206,83 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
|
||||
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat._get_document_filtered_retriever",
|
||||
) as mock_get_retriever,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": "1", "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
# Nodes found so we get past the pre-check
|
||||
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
|
||||
mock_get_retriever.return_value = mock_retriever
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
) as mock_retriever_cls:
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError(
|
||||
"private provider detail",
|
||||
)
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == [CHAT_ERROR_MESSAGE]
|
||||
assert "Failed to stream document chat response" in caplog.text
|
||||
assert "private provider detail" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestStreamChatRetrieval:
|
||||
def test_no_nodes_yields_no_content_message(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
doc = DocumentFactory.create(content="hello world")
|
||||
# Nothing indexed for this document yet.
|
||||
out = list(chat.stream_chat_with_documents("question?", [doc]))
|
||||
assert chat.CHAT_NO_CONTENT_MESSAGE in out
|
||||
|
||||
def test_chat_filter_contains_only_requested_document_ids(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
mocker,
|
||||
) -> None:
|
||||
"""The MetadataFilter passed to the retriever must be scoped to the
|
||||
requested documents only — content from other indexed documents must
|
||||
not be surfaced.
|
||||
"""
|
||||
included = DocumentFactory.create(content="included document content")
|
||||
excluded = DocumentFactory.create(content="excluded document content")
|
||||
indexing.llm_index_add_or_update_document(included)
|
||||
indexing.llm_index_add_or_update_document(excluded)
|
||||
|
||||
# VectorIndexRetriever is imported inside _stream_chat_with_documents;
|
||||
# patch it at the llama_index source so the lazy import picks it up.
|
||||
captured_filters = []
|
||||
mock_retriever = mocker.MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
|
||||
def capture_retriever(*args, **kwargs):
|
||||
captured_filters.append(kwargs.get("filters"))
|
||||
return mock_retriever
|
||||
|
||||
mocker.patch("paperless_ai.chat.AIClient")
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
side_effect=capture_retriever,
|
||||
)
|
||||
|
||||
list(chat.stream_chat_with_documents("question?", [included]))
|
||||
|
||||
assert captured_filters, "VectorIndexRetriever was never constructed"
|
||||
filt = captured_filters[0]
|
||||
assert filt is not None, "Retriever must receive a MetadataFilters"
|
||||
filter_values = filt.filters[0].value
|
||||
assert str(included.pk) in filter_values
|
||||
assert str(excluded.pk) not in filter_values
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
@@ -10,7 +9,7 @@ from documents.models import Document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless_ai.embedding import _normalize_llm_index_text
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_configured_model_name
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
|
||||
@@ -67,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
@@ -88,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
@@ -109,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
@@ -121,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
cache_folder=str(settings.DATA_DIR / "hf_cache"),
|
||||
@@ -137,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.ollama.OllamaEmbedding",
|
||||
) as MockOllamaEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOllamaEmbedding.assert_called_once_with(
|
||||
model_name="embeddinggemma",
|
||||
base_url="http://test-url",
|
||||
@@ -155,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
|
||||
with patch(
|
||||
"llama_index.embeddings.ollama.OllamaEmbedding",
|
||||
) as MockOllamaEmbedding:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(mock_ai_config.return_value)
|
||||
MockOllamaEmbedding.assert_called_once_with(
|
||||
model_name="embeddinggemma",
|
||||
base_url="http://embedding-url",
|
||||
@@ -173,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
@@ -183,55 +182,37 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
ValueError,
|
||||
match="Unsupported embedding backend: INVALID_BACKEND",
|
||||
):
|
||||
get_embedding_model()
|
||||
get_embedding_model(mock_ai_config.return_value)
|
||||
|
||||
|
||||
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
class DummyEmbedding:
|
||||
def get_text_embedding(self, text):
|
||||
return [0.0] * 7
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
return_value=DummyEmbedding(),
|
||||
) as mock_get:
|
||||
dim = get_embedding_dim()
|
||||
mock_get.assert_called_once()
|
||||
|
||||
assert dim == 7
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
|
||||
@pytest.mark.parametrize(
|
||||
("backend", "expected_default"),
|
||||
[
|
||||
(LLMEmbeddingBackend.OPENAI_LIKE, "text-embedding-3-small"),
|
||||
(LLMEmbeddingBackend.HUGGINGFACE, "sentence-transformers/all-MiniLM-L6-v2"),
|
||||
(LLMEmbeddingBackend.OLLAMA, "embeddinggemma"),
|
||||
],
|
||||
)
|
||||
def test_get_configured_model_name_falls_back_to_backend_default(
|
||||
mock_ai_config,
|
||||
backend,
|
||||
expected_default,
|
||||
):
|
||||
"""When no model is explicitly configured, each backend has a distinct default."""
|
||||
config = mock_ai_config.return_value
|
||||
config.llm_embedding_backend = backend
|
||||
config.llm_embedding_model = None
|
||||
assert get_configured_model_name(config) == expected_default
|
||||
|
||||
|
||||
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
|
||||
)
|
||||
|
||||
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
|
||||
assert get_embedding_dim() == 11
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 11}),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Embedding model changed from old to text-embedding-3-small",
|
||||
):
|
||||
get_embedding_dim()
|
||||
def test_get_configured_model_name_explicit_overrides_default(mock_ai_config):
|
||||
"""An explicit model name overrides the backend default for all backends."""
|
||||
config = mock_ai_config.return_value
|
||||
config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
config.llm_embedding_model = "my-custom-model"
|
||||
# The backend default for OPENAI_LIKE is "text-embedding-3-small", so if
|
||||
# the explicit name was ignored we'd get the wrong result.
|
||||
assert get_configured_model_name(config) == "my-custom-model"
|
||||
|
||||
|
||||
def test_build_llm_index_text(mock_document):
|
||||
@@ -243,12 +224,17 @@ def test_build_llm_index_text(mock_document):
|
||||
|
||||
result = build_llm_index_text(mock_document)
|
||||
|
||||
assert "Title: Test Title" in result
|
||||
assert "Filename: test_file.pdf" in result
|
||||
assert "Created: 2023-01-01" in result
|
||||
assert "Tags: Tag1, Tag2" in result
|
||||
assert "Document Type: Invoice" in result
|
||||
assert "Correspondent: Test Correspondent" in result
|
||||
# Structured fields live in node.metadata for LLM context -- not body text
|
||||
assert "Title: Test Title" not in result
|
||||
assert "Created: 2023-01-01" not in result
|
||||
assert "Tags: Tag1, Tag2" not in result
|
||||
assert "Document Type: Invoice" not in result
|
||||
assert "Correspondent: Test Correspondent" not in result
|
||||
assert "Filename:" not in result
|
||||
assert "Storage Path:" not in result
|
||||
assert "Archive Serial Number:" not in result
|
||||
|
||||
# Fields without a metadata equivalent stay in body text
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
from filelock import ReadWriteLock
|
||||
from llama_index.core.schema import TextNode
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
DIM = 8
|
||||
|
||||
|
||||
def _node(node_id: str, document_id: str, *, seed: float = 0.0) -> TextNode:
|
||||
node = TextNode(
|
||||
id_=node_id,
|
||||
text="chunk",
|
||||
metadata={"document_id": document_id, "modified": "2026-06-01T00:00:00"},
|
||||
)
|
||||
node.relationships = {}
|
||||
node.embedding = [seed + i / 100 for i in range(DIM)]
|
||||
return node
|
||||
|
||||
|
||||
def _seed_bloated_index(index_dir: Path) -> None:
|
||||
"""Create an index whose cumulative inserts far exceed live rows."""
|
||||
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
|
||||
store.add([_node(f"d{j}", str(j), seed=float(j)) for j in range(20)])
|
||||
for cycle in range(6):
|
||||
for j in range(20):
|
||||
store.upsert_document(
|
||||
str(j),
|
||||
[_node(f"d{j}-c{cycle}", str(j), seed=float(j))],
|
||||
)
|
||||
store.client.close()
|
||||
|
||||
|
||||
def _bloat_ratio(index_dir: Path) -> float:
|
||||
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
|
||||
live = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
|
||||
row = store.client.execute(
|
||||
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
|
||||
).fetchone()
|
||||
total = int(row["value"]) if row else live
|
||||
store.client.close()
|
||||
return total / max(live, 1)
|
||||
|
||||
|
||||
def _integrity_ok(index_dir: Path) -> bool:
|
||||
store = PaperlessSqliteVecVectorStore(uri=str(index_dir))
|
||||
result = store.client.execute("PRAGMA integrity_check").fetchone()[0]
|
||||
rows = store.client.execute("SELECT count(*) FROM documents").fetchone()[0]
|
||||
store.client.close()
|
||||
return result == "ok" and rows == 20
|
||||
|
||||
|
||||
def _reader_lock() -> ReadWriteLock:
|
||||
# A distinct instance simulates a reader in another process: it coordinates
|
||||
# with the production lock purely through SQLite, never reentrant upgrade.
|
||||
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
|
||||
|
||||
|
||||
class TestCompactionLock:
|
||||
def test_compaction_skips_when_a_reader_holds_the_lock(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
settings: SettingsWrapper,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
_seed_bloated_index(temp_llm_index_dir)
|
||||
settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 0.3
|
||||
|
||||
lock = _reader_lock()
|
||||
with lock.read_lock(), caplog.at_level(logging.INFO):
|
||||
indexing.llm_index_compact() # must not raise
|
||||
lock.close()
|
||||
|
||||
# Swap was skipped: bloat remains, nothing corrupted, data intact.
|
||||
assert _integrity_ok(temp_llm_index_dir)
|
||||
assert _bloat_ratio(temp_llm_index_dir) > 2
|
||||
assert "Skipping LLM index compaction" in caplog.text
|
||||
|
||||
def test_compaction_runs_when_no_reader_holds_the_lock(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
) -> None:
|
||||
_seed_bloated_index(temp_llm_index_dir)
|
||||
assert _bloat_ratio(temp_llm_index_dir) > 2
|
||||
|
||||
indexing.llm_index_compact()
|
||||
|
||||
assert _bloat_ratio(temp_llm_index_dir) == pytest.approx(1.0)
|
||||
assert _integrity_ok(temp_llm_index_dir)
|
||||
|
||||
def test_normal_write_is_not_gated_by_the_compaction_lock(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
) -> None:
|
||||
"""A held exclusive lock must not block ordinary writes (WAL handles them)."""
|
||||
_seed_bloated_index(temp_llm_index_dir)
|
||||
done = threading.Event()
|
||||
|
||||
def remove() -> None:
|
||||
indexing.llm_index_remove_document(MagicMock(id=999))
|
||||
done.set()
|
||||
|
||||
holder = _reader_lock()
|
||||
with holder.write_lock():
|
||||
t = threading.Thread(target=remove)
|
||||
t.start()
|
||||
finished = done.wait(timeout=5)
|
||||
t.join(timeout=2)
|
||||
holder.close()
|
||||
assert finished, "a normal write blocked on the compaction lock"
|
||||
|
||||
|
||||
class TestReadStore:
|
||||
def test_closes_connection_on_exit(self, temp_llm_index_dir: Path) -> None:
|
||||
with indexing.read_store() as store:
|
||||
conn = store.client
|
||||
assert conn.execute("SELECT 1").fetchone()[0] == 1
|
||||
with pytest.raises(sqlite3.ProgrammingError):
|
||||
conn.execute("SELECT 1")
|
||||
|
||||
def test_concurrent_readers_do_not_block(self, temp_llm_index_dir: Path) -> None:
|
||||
_seed_bloated_index(temp_llm_index_dir)
|
||||
with indexing.read_store() as a, indexing.read_store() as b:
|
||||
assert a.table_exists()
|
||||
assert b.table_exists()
|
||||
@@ -0,0 +1,25 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_SRC_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class TestLazyAiImports:
|
||||
def test_importing_tasks_does_not_load_ai_libraries(self) -> None:
|
||||
code = (
|
||||
"import os, django, sys\n"
|
||||
"os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'paperless.settings')\n"
|
||||
"django.setup()\n"
|
||||
"import documents.tasks # noqa: F401\n"
|
||||
"leaked = [m for m in ('lancedb', 'pyarrow', 'llama_index', 'sqlite_vec') "
|
||||
"if m in sys.modules]\n"
|
||||
"assert not leaked, f'AI libraries leaked into the light path: {leaked}'\n"
|
||||
)
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-c", code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=_SRC_DIR,
|
||||
)
|
||||
assert result.returncode == 0, result.stdout + result.stderr
|
||||
@@ -1,12 +1,15 @@
|
||||
import difflib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.tests.factories import TagFactory
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
@@ -87,6 +90,95 @@ class TestAIMatching(TestCase):
|
||||
self.assertEqual(result[1].name, "Test Tag 2")
|
||||
|
||||
|
||||
class TestHintedMatching:
|
||||
def test_hinted_verbatim_skips_fuzzy(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
spy = mocker.spy(difflib, "get_close_matches")
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwork"],
|
||||
user=None,
|
||||
hinted_names={"Bloodwork"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
spy.assert_not_called()
|
||||
|
||||
def test_unhinted_name_still_fuzzy_matches(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
|
||||
# "Bloodwrok" is a typo not in hints -> fuzzy still maps it to Bloodwork.
|
||||
result = match_tags_by_name(
|
||||
["Bloodwrok"],
|
||||
user=None,
|
||||
hinted_names={"Taxes"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
|
||||
def test_hinted_name_with_whitespace_exact_matches(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
spy = mocker.spy(difflib, "get_close_matches")
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwork "],
|
||||
user=None,
|
||||
hinted_names={"Bloodwork"},
|
||||
)
|
||||
|
||||
assert [t.name for t in result] == ["Bloodwork"]
|
||||
spy.assert_not_called()
|
||||
|
||||
def test_hinted_name_absent_from_queryset_is_skipped_not_fuzzed(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
# A hint with no exact object must not fall through to fuzzy.
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Bloodwork")],
|
||||
)
|
||||
|
||||
result = match_tags_by_name(
|
||||
["Bloodwrok"],
|
||||
user=None,
|
||||
hinted_names={"Bloodwrok"},
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_backward_compatible_without_kwarg(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.matching.get_objects_for_user_owner_aware",
|
||||
return_value=[TagFactory.build(name="Test Tag 1")],
|
||||
)
|
||||
|
||||
result = match_tags_by_name(["Test Tag 1", "Nonexistent"], user=None)
|
||||
|
||||
assert [t.name for t in result] == ["Test Tag 1"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestExtractUnmatchedNamesNormalization:
|
||||
def test_punctuated_name_already_matched_is_not_returned_as_unmatched(
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest_mock
|
||||
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from paperless_ai.taxonomy import TaxonomyHints
|
||||
from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes
|
||||
from paperless_ai.taxonomy import format_hints_for_prompt
|
||||
from paperless_ai.taxonomy import get_taxonomy_hints_for_document
|
||||
|
||||
|
||||
def make_node(**metadata: object) -> SimpleNamespace:
|
||||
"""A stand-in for NodeWithScore: only ``.metadata`` is accessed."""
|
||||
return SimpleNamespace(metadata=metadata)
|
||||
|
||||
|
||||
class TestBuildTaxonomyHintsFromNodes:
|
||||
def test_returns_all_four_keys(self) -> None:
|
||||
hints = build_taxonomy_hints_from_nodes([])
|
||||
assert set(hints.keys()) == {
|
||||
"tags",
|
||||
"document_types",
|
||||
"correspondents",
|
||||
"storage_paths",
|
||||
}
|
||||
|
||||
def test_collects_and_sorts_values(self) -> None:
|
||||
nodes = [
|
||||
make_node(
|
||||
tags=["Taxes", "Bloodwork"],
|
||||
document_type="Invoice",
|
||||
correspondent="IRS",
|
||||
storage_path="Financial",
|
||||
),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Bloodwork", "Taxes"]
|
||||
assert hints["document_types"] == ["Invoice"]
|
||||
assert hints["correspondents"] == ["IRS"]
|
||||
assert hints["storage_paths"] == ["Financial"]
|
||||
|
||||
def test_deduplicates_across_nodes(self) -> None:
|
||||
nodes = [
|
||||
make_node(tags=["Taxes"], document_type="Invoice"),
|
||||
make_node(tags=["Taxes", "Medical"], document_type="Invoice"),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Medical", "Taxes"]
|
||||
assert hints["document_types"] == ["Invoice"]
|
||||
|
||||
def test_none_values_skipped(self) -> None:
|
||||
nodes = [
|
||||
make_node(
|
||||
tags=["Taxes", None, ""],
|
||||
document_type=None,
|
||||
correspondent=None,
|
||||
storage_path=None,
|
||||
),
|
||||
]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["tags"] == ["Taxes"]
|
||||
assert hints["document_types"] == []
|
||||
assert hints["correspondents"] == []
|
||||
assert hints["storage_paths"] == []
|
||||
|
||||
def test_missing_storage_path_key_handled(self) -> None:
|
||||
# Pre-enrichment nodes have no storage_path key at all.
|
||||
nodes = [make_node(tags=["Taxes"], document_type="Invoice")]
|
||||
hints = build_taxonomy_hints_from_nodes(nodes)
|
||||
assert hints["storage_paths"] == []
|
||||
|
||||
def test_empty_node_list_all_empty(self) -> None:
|
||||
hints = build_taxonomy_hints_from_nodes([])
|
||||
assert hints == {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
def test_output_stable_across_calls(self) -> None:
|
||||
nodes = [make_node(tags=["b", "a", "c"])]
|
||||
assert build_taxonomy_hints_from_nodes(
|
||||
nodes,
|
||||
) == build_taxonomy_hints_from_nodes(nodes)
|
||||
|
||||
|
||||
class TestFormatHintsForPrompt:
|
||||
def test_all_blocks_present_when_all_categories_nonempty(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": ["IRS"],
|
||||
"storage_paths": ["Financial"],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert "Available tags:" in result
|
||||
assert "Available document types:" in result
|
||||
assert "Available correspondents:" in result
|
||||
assert "Available storage paths:" in result
|
||||
assert "- Bloodwork" in result
|
||||
|
||||
def test_empty_category_produces_no_block(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert "Available tags:" in result
|
||||
assert "Available document types:" not in result
|
||||
assert "Available correspondents:" not in result
|
||||
assert "Available storage paths:" not in result
|
||||
|
||||
def test_all_empty_produces_empty_string(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
assert format_hints_for_prompt(hints) == ""
|
||||
|
||||
def test_instruction_line_appears_once(self) -> None:
|
||||
hints: TaxonomyHints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
result = format_hints_for_prompt(hints)
|
||||
assert result.count("Prefer existing names from these lists verbatim") == 1
|
||||
|
||||
|
||||
class TestGetTaxonomyHintsForDocument:
|
||||
def test_returns_none_when_embedding_backend_off(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend=None),
|
||||
)
|
||||
retrieve = mocker.patch("paperless_ai.taxonomy.retrieve_similar_nodes")
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result is None
|
||||
retrieve.assert_not_called()
|
||||
|
||||
def test_passes_owner_aware_ids_when_user_present(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.visible_document_ids_for_user",
|
||||
return_value=[1, 2, 3],
|
||||
)
|
||||
retrieve = mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[],
|
||||
)
|
||||
document = DocumentFactory.build()
|
||||
user = mocker.MagicMock()
|
||||
|
||||
get_taxonomy_hints_for_document(document, user=user)
|
||||
|
||||
retrieve.assert_called_once_with(
|
||||
document=document,
|
||||
document_ids=[1, 2, 3],
|
||||
)
|
||||
|
||||
def test_returns_populated_hints_when_nodes_found(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[make_node(tags=["Taxes"], document_type="Invoice")],
|
||||
)
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result == {
|
||||
"tags": ["Taxes"],
|
||||
"document_types": ["Invoice"],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
|
||||
def test_returns_empty_hints_not_none_when_no_nodes(
|
||||
self,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.AIConfig",
|
||||
return_value=SimpleNamespace(llm_embedding_backend="huggingface"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.taxonomy.retrieve_similar_nodes",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = get_taxonomy_hints_for_document(DocumentFactory.build(), user=None)
|
||||
|
||||
assert result == {
|
||||
"tags": [],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
@@ -0,0 +1,606 @@
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
|
||||
from paperless_ai.vector_store import DB_FILENAME
|
||||
from paperless_ai.vector_store import DEFAULT_TABLE_NAME
|
||||
from paperless_ai.vector_store import MIGRATIONS
|
||||
from paperless_ai.vector_store import SCHEMA_VERSION
|
||||
from paperless_ai.vector_store import Migration
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
from paperless_ai.vector_store import _build_where
|
||||
|
||||
DIM = 16
|
||||
|
||||
|
||||
def make_node(
|
||||
node_id: str,
|
||||
document_id: str,
|
||||
*,
|
||||
modified: str = "2026-06-10T00:00:00",
|
||||
seed: float = 0.0,
|
||||
text: str = "some text",
|
||||
) -> TextNode:
|
||||
node = TextNode(
|
||||
id_=node_id,
|
||||
text=text,
|
||||
metadata={"document_id": document_id, "modified": modified},
|
||||
)
|
||||
node.relationships = {}
|
||||
node.embedding = [seed + i / 100 for i in range(DIM)]
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path: Path) -> Generator[PaperlessSqliteVecVectorStore, None, None]:
|
||||
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store:
|
||||
yield store
|
||||
|
||||
|
||||
def _query(
|
||||
store: PaperlessSqliteVecVectorStore,
|
||||
embedding: list[float],
|
||||
top_k: int = 5,
|
||||
filters=None,
|
||||
):
|
||||
return store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=embedding,
|
||||
similarity_top_k=top_k,
|
||||
filters=filters,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _eq_filter(key: str, value: str):
|
||||
return MetadataFilters(
|
||||
filters=[MetadataFilter(key=key, operator=FilterOperator.EQ, value=value)],
|
||||
)
|
||||
|
||||
|
||||
def _in_filter(document_ids: list[str]):
|
||||
return MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=document_ids,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestCrud:
|
||||
def test_add_then_query_returns_node(self, store) -> None:
|
||||
node = make_node("n1", "1")
|
||||
assert store.add([node]) == ["n1"]
|
||||
result = _query(store, node.embedding, top_k=1)
|
||||
assert result.ids == ["n1"]
|
||||
assert result.nodes[0].metadata["document_id"] == "1"
|
||||
# cosine distance of the identical vector is 0 -> similarity 1
|
||||
assert result.similarities[0] == pytest.approx(1.0)
|
||||
|
||||
def test_query_empty_store_returns_empty_no_raise(self, store) -> None:
|
||||
result = _query(store, [0.0] * DIM)
|
||||
assert result.ids == [] and result.nodes == [] and result.similarities == []
|
||||
|
||||
def test_add_empty_list_is_noop(self, store) -> None:
|
||||
assert store.add([]) == []
|
||||
assert not store.table_exists()
|
||||
|
||||
def test_delete_removes_all_chunks_of_document(self, store) -> None:
|
||||
store.add([make_node("a1", "1"), make_node("a2", "1"), make_node("b1", "2")])
|
||||
store.delete("1")
|
||||
result = _query(store, [0.0] * DIM, top_k=10)
|
||||
assert result.ids == ["b1"]
|
||||
|
||||
def test_query_with_in_filter_scopes_results(self, store) -> None:
|
||||
store.add(
|
||||
[
|
||||
make_node("a1", "1", seed=0.0),
|
||||
make_node("b1", "2", seed=1.0),
|
||||
make_node("c1", "3", seed=2.0),
|
||||
],
|
||||
)
|
||||
result = _query(store, [0.0] * DIM, top_k=10, filters=_in_filter(["2", "3"]))
|
||||
assert sorted(result.ids) == ["b1", "c1"]
|
||||
|
||||
def test_query_respects_top_k_with_filter(self, store) -> None:
|
||||
# k semantics: global top-k even with IN filters (document_id is a
|
||||
# metadata column, not a partition key -- see design doc).
|
||||
store.add(
|
||||
[make_node(f"n{i}", str(i % 4), seed=float(i)) for i in range(12)],
|
||||
)
|
||||
result = _query(
|
||||
store,
|
||||
[0.0] * DIM,
|
||||
top_k=3,
|
||||
filters=_in_filter(["0", "1", "2", "3"]),
|
||||
)
|
||||
assert len(result.ids) == 3
|
||||
assert result.similarities == sorted(result.similarities, reverse=True)
|
||||
|
||||
def test_get_nodes_filter_and_empty_paths(self, store) -> None:
|
||||
assert store.get_nodes(filters=_in_filter(["1"])) == [] # no table yet
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2")])
|
||||
nodes = store.get_nodes(filters=_in_filter(["1"]))
|
||||
assert [n.node_id for n in nodes] == ["a1"]
|
||||
assert nodes[0].embedding is not None
|
||||
assert store.get_nodes(filters=_in_filter(["999"])) == []
|
||||
|
||||
def test_query_with_eq_filter_scopes_results(self, store) -> None:
|
||||
store.add(
|
||||
[
|
||||
make_node("a1", "1", seed=0.0),
|
||||
make_node("b1", "2", seed=1.0),
|
||||
make_node("c1", "3", seed=2.0),
|
||||
],
|
||||
)
|
||||
result = _query(
|
||||
store,
|
||||
[0.0] * DIM,
|
||||
top_k=10,
|
||||
filters=_eq_filter("document_id", "2"),
|
||||
)
|
||||
assert result.ids == ["b1"]
|
||||
|
||||
def test_get_nodes_node_ids_not_implemented(self, store) -> None:
|
||||
with pytest.raises(NotImplementedError):
|
||||
store.get_nodes(node_ids=["x"])
|
||||
|
||||
def test_fresh_instance_sees_existing_table(self, store, tmp_path: Path) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
|
||||
assert reopened.table_exists()
|
||||
assert reopened.vector_dim() == DIM
|
||||
assert _query(reopened, [0.0] * DIM, top_k=1).ids == ["a1"]
|
||||
|
||||
def test_table_exists_and_drop(self, store) -> None:
|
||||
assert not store.table_exists()
|
||||
store.add([make_node("a1", "1")])
|
||||
assert store.table_exists()
|
||||
store.drop_table()
|
||||
assert not store.table_exists()
|
||||
assert store.vector_dim() is None
|
||||
|
||||
|
||||
class TestBuildWhere:
|
||||
def test_fails_closed_when_no_filter_is_translatable(self) -> None:
|
||||
# A nested MetadataFilters is not a MetadataFilter, so it is skipped.
|
||||
# With no translatable clauses, the function must fail closed rather
|
||||
# than emit "()" (invalid SQL) and never widen document access.
|
||||
nested = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.EQ,
|
||||
value="1",
|
||||
),
|
||||
],
|
||||
)
|
||||
where, params = _build_where(MetadataFilters(filters=[nested]))
|
||||
assert where == "1 = 0"
|
||||
assert params == []
|
||||
|
||||
def test_query_with_untranslatable_filter_returns_no_rows(self, store) -> None:
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2")])
|
||||
nested = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.EQ,
|
||||
value="1",
|
||||
),
|
||||
],
|
||||
)
|
||||
filters = MetadataFilters(filters=[nested])
|
||||
# Must not raise (no "WHERE ()") and must return nothing (fail closed).
|
||||
assert _query(store, [0.0] * DIM, top_k=5, filters=filters).ids == []
|
||||
assert store.get_nodes(filters=filters) == []
|
||||
|
||||
|
||||
class TestUpsert:
|
||||
def test_upsert_replaces_and_prunes_stale_chunks(self, store) -> None:
|
||||
store.add(
|
||||
[make_node("d1c1", "1"), make_node("d1c2", "1"), make_node("d2c1", "2")],
|
||||
)
|
||||
store.upsert_document("1", [make_node("d1new", "1")])
|
||||
result = _query(store, [0.0] * DIM, top_k=10)
|
||||
assert sorted(result.ids) == ["d1new", "d2c1"]
|
||||
|
||||
def test_upsert_creates_table_when_missing(self, store) -> None:
|
||||
store.upsert_document("1", [make_node("a1", "1")])
|
||||
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
|
||||
|
||||
def test_upsert_empty_nodes_removes_document(self, store) -> None:
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2")])
|
||||
store.upsert_document("1", [])
|
||||
assert _query(store, [0.0] * DIM, top_k=10).ids == ["b1"]
|
||||
|
||||
def test_upsert_is_atomic_for_concurrent_readers(
|
||||
self,
|
||||
store,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""A second connection must never observe document 1 half-replaced."""
|
||||
store.add([make_node("a1", "1"), make_node("a2", "1")])
|
||||
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reader:
|
||||
store.upsert_document("1", [make_node("a3", "1")])
|
||||
ids = [n.node_id for n in reader.get_nodes(filters=_in_filter(["1"]))]
|
||||
assert ids == ["a3"]
|
||||
|
||||
|
||||
class TestMetadataCoercion:
|
||||
def test_none_metadata_values_become_empty_strings(self, store) -> None:
|
||||
node = make_node("a1", "1")
|
||||
node.metadata["modified"] = None
|
||||
store.add([node]) # must not raise (vec0 rejects NULL metadata)
|
||||
assert store.get_modified_times() == {"1": ""}
|
||||
|
||||
|
||||
class TestModelNameTracking:
|
||||
def test_stored_model_name_none_without_table(self, tmp_path: Path) -> None:
|
||||
with PaperlessSqliteVecVectorStore(
|
||||
uri=str(tmp_path),
|
||||
embed_model_name="model-a",
|
||||
) as store:
|
||||
assert store.stored_model_name() is None
|
||||
|
||||
def test_model_name_stored_after_add_and_persists(self, tmp_path: Path) -> None:
|
||||
with PaperlessSqliteVecVectorStore(
|
||||
uri=str(tmp_path),
|
||||
embed_model_name="model-a",
|
||||
) as store:
|
||||
store.add([make_node("a1", "1")])
|
||||
assert store.stored_model_name() == "model-a"
|
||||
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as reopened:
|
||||
assert reopened.stored_model_name() == "model-a"
|
||||
|
||||
def test_config_mismatch_semantics(self, tmp_path: Path) -> None:
|
||||
with PaperlessSqliteVecVectorStore(
|
||||
uri=str(tmp_path),
|
||||
embed_model_name="model-a",
|
||||
) as store:
|
||||
assert not store.config_mismatch("anything") # no table yet
|
||||
store.add([make_node("a1", "1")])
|
||||
assert not store.config_mismatch("model-a")
|
||||
assert store.config_mismatch("model-b")
|
||||
|
||||
def test_config_mismatch_false_when_table_predates_tracking(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
with PaperlessSqliteVecVectorStore(uri=str(tmp_path)) as store: # no model name
|
||||
store.add([make_node("a1", "1")])
|
||||
assert not store.config_mismatch("model-a")
|
||||
|
||||
|
||||
class TestGetModifiedTimes:
|
||||
def test_empty_store_returns_empty_dict(self, store) -> None:
|
||||
assert store.get_modified_times() == {}
|
||||
|
||||
def test_returns_one_entry_per_document(self, store) -> None:
|
||||
store.add(
|
||||
[
|
||||
make_node("a1", "1", modified="2026-01-01T00:00:00"),
|
||||
make_node("a2", "1", modified="2026-01-01T00:00:00"),
|
||||
make_node("b1", "2", modified="2026-02-02T00:00:00"),
|
||||
],
|
||||
)
|
||||
assert store.get_modified_times() == {
|
||||
"1": "2026-01-01T00:00:00",
|
||||
"2": "2026-02-02T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
class TestCompact:
|
||||
def _bloat_ratio(self, store) -> float:
|
||||
live = store.client.execute(
|
||||
"SELECT count(*) FROM documents",
|
||||
).fetchone()[0]
|
||||
# vec0 0.1.9 does not accumulate deleted rows in the _rowids shadow
|
||||
# table, so we track cumulative inserts in index_meta instead.
|
||||
row = store.client.execute(
|
||||
"SELECT value FROM index_meta WHERE key = 'total_inserts'",
|
||||
).fetchone()
|
||||
total = int(row["value"]) if row else live
|
||||
return total / max(live, 1)
|
||||
|
||||
def _churn(self, store, cycles: int) -> None:
|
||||
for i in range(cycles):
|
||||
store.upsert_document(
|
||||
"1",
|
||||
[make_node(f"gen{i}-{j}", "1", seed=float(j)) for j in range(20)],
|
||||
)
|
||||
|
||||
def test_compact_noop_below_threshold(self, store) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
store.compact()
|
||||
assert _query(store, [0.0] * DIM, top_k=1).ids == ["a1"]
|
||||
|
||||
def test_force_compact_preserves_rows_and_metadata(self, store) -> None:
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2", seed=3.0)])
|
||||
self._churn(store, 5)
|
||||
before = {
|
||||
n.node_id: n.metadata
|
||||
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
|
||||
}
|
||||
store.compact(force=True)
|
||||
after = {
|
||||
n.node_id: n.metadata
|
||||
for n in store.get_nodes(filters=_in_filter(["1", "2"]))
|
||||
}
|
||||
assert after == before
|
||||
assert self._bloat_ratio(store) == pytest.approx(1.0)
|
||||
# store remains fully usable after the rebuild; use a seed far from all
|
||||
# existing nodes (gen4-0..gen4-19 have seeds 0..19) so cosine KNN is
|
||||
# unambiguous at top_k=1.
|
||||
store.upsert_document("3", [make_node("c1", "3", seed=100.0)])
|
||||
assert "c1" in _query(store, [100.0] * DIM, top_k=1).ids
|
||||
|
||||
def test_auto_compact_triggers_on_churn(self, store) -> None:
|
||||
store.add([make_node(f"s{j}", "1", seed=float(j)) for j in range(20)])
|
||||
self._churn(store, 5)
|
||||
assert self._bloat_ratio(store) > 2
|
||||
store.compact()
|
||||
assert self._bloat_ratio(store) == pytest.approx(1.0)
|
||||
|
||||
def test_compact_on_missing_table_is_noop(self, store) -> None:
|
||||
store.compact()
|
||||
store.compact(force=True)
|
||||
|
||||
def test_failed_compact_removes_temp_wal_and_shm(
|
||||
self,
|
||||
store,
|
||||
tmp_path: Path,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""A compact() that raises mid-rebuild must leave no .compact* files.
|
||||
|
||||
Normally the sole connection's close() checkpoints the temp WAL away,
|
||||
but a concurrent reader keeps -wal/-shm alive, so the cleanup must
|
||||
unlink them explicitly (as the structural-migration path does).
|
||||
"""
|
||||
store.add([make_node("a1", "1")])
|
||||
compact_path = str(tmp_path / DB_FILENAME) + ".compact"
|
||||
held: list[sqlite3.Connection] = []
|
||||
|
||||
def boom(conn: sqlite3.Connection, dim: int) -> None:
|
||||
# Hold an extra connection so close() of the rebuild connection is
|
||||
# not the last one -> the temp -wal/-shm survive the checkpoint.
|
||||
extra = sqlite3.connect(compact_path)
|
||||
extra.execute("SELECT 1").fetchall()
|
||||
held.append(extra)
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(
|
||||
PaperlessSqliteVecVectorStore,
|
||||
"_create_vec_table",
|
||||
staticmethod(boom),
|
||||
)
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
store.compact(force=True)
|
||||
assert sorted(p.name for p in tmp_path.glob("*.compact*")) == []
|
||||
finally:
|
||||
for c in held:
|
||||
c.close()
|
||||
|
||||
|
||||
class TestDbFile:
|
||||
def test_single_db_file_in_index_dir(self, store, tmp_path: Path) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
assert (tmp_path / DB_FILENAME).exists()
|
||||
|
||||
def test_wal_mode_enabled(self, store) -> None:
|
||||
assert (
|
||||
store.client.execute("PRAGMA journal_mode").fetchone()[0].lower() == "wal"
|
||||
)
|
||||
|
||||
|
||||
class TestMigrations:
|
||||
"""Tests for the schema migration machinery."""
|
||||
|
||||
def _schema_version(self, store: PaperlessSqliteVecVectorStore) -> int | None:
|
||||
row = store.client.execute(
|
||||
"SELECT value FROM index_meta WHERE key = 'schema_version'",
|
||||
).fetchone()
|
||||
return int(row[0]) if row else None
|
||||
|
||||
def test_new_table_records_schema_version(self, store) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
assert self._schema_version(store) == SCHEMA_VERSION
|
||||
|
||||
def test_check_migrations_no_table_returns_false(self, store) -> None:
|
||||
assert store.check_and_run_migrations() is False
|
||||
|
||||
def test_check_migrations_current_version_returns_false(self, store) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
assert store.check_and_run_migrations() is False
|
||||
|
||||
def test_reembed_migration_returns_true(self, store, tmp_path: Path) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
migration = Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
kind="re-embed",
|
||||
description="test re-embed",
|
||||
)
|
||||
MIGRATIONS.append(migration)
|
||||
try:
|
||||
from paperless_ai import vector_store as vs_mod
|
||||
|
||||
original = vs_mod.SCHEMA_VERSION
|
||||
vs_mod.SCHEMA_VERSION = 2
|
||||
result = store.check_and_run_migrations()
|
||||
finally:
|
||||
MIGRATIONS.remove(migration)
|
||||
vs_mod.SCHEMA_VERSION = original
|
||||
assert result is True
|
||||
|
||||
def test_structural_migration_copies_rows_and_updates_version(
|
||||
self,
|
||||
store,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2")])
|
||||
|
||||
def apply(
|
||||
src: sqlite3.Connection,
|
||||
dst: sqlite3.Connection,
|
||||
dim: int,
|
||||
) -> None:
|
||||
dst.execute( # nosemgrep
|
||||
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
|
||||
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
|
||||
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
|
||||
")",
|
||||
)
|
||||
dst.execute(
|
||||
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
|
||||
(str(dim),),
|
||||
)
|
||||
rows = src.execute(
|
||||
"SELECT id, document_id, modified, node_content, embedding "
|
||||
f"FROM {DEFAULT_TABLE_NAME}",
|
||||
).fetchall()
|
||||
dst.execute("BEGIN IMMEDIATE")
|
||||
dst.executemany(
|
||||
f"INSERT INTO {DEFAULT_TABLE_NAME} "
|
||||
"(id, document_id, modified, node_content, embedding) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
(
|
||||
r["id"],
|
||||
r["document_id"],
|
||||
r["modified"],
|
||||
r["node_content"],
|
||||
bytes(r["embedding"]),
|
||||
)
|
||||
for r in rows
|
||||
],
|
||||
)
|
||||
dst.execute(
|
||||
"INSERT INTO index_meta (key, value) VALUES ('total_inserts', ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
|
||||
(str(len(rows)),),
|
||||
)
|
||||
dst.execute("COMMIT")
|
||||
|
||||
migration = Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
kind="structural",
|
||||
description="test structural",
|
||||
apply=apply,
|
||||
)
|
||||
MIGRATIONS.append(migration)
|
||||
try:
|
||||
from paperless_ai import vector_store as vs_mod
|
||||
|
||||
original = vs_mod.SCHEMA_VERSION
|
||||
vs_mod.SCHEMA_VERSION = 2
|
||||
result = store.check_and_run_migrations()
|
||||
finally:
|
||||
MIGRATIONS.remove(migration)
|
||||
vs_mod.SCHEMA_VERSION = original
|
||||
|
||||
assert result is False
|
||||
assert self._schema_version(store) == 2
|
||||
ids = {n.node_id for n in store.get_nodes()}
|
||||
assert ids == {"a1", "b1"}
|
||||
|
||||
def test_compact_preserves_schema_version(self, store) -> None:
|
||||
store.add([make_node("a1", "1")])
|
||||
assert self._schema_version(store) == SCHEMA_VERSION
|
||||
store.compact(force=True)
|
||||
assert self._schema_version(store) == SCHEMA_VERSION
|
||||
|
||||
def test_stop_at_reembed_boundary(self, store) -> None:
|
||||
# Registry: structural v2, re-embed v3, structural v4.
|
||||
# Only v2 should apply; the re-embed boundary must stop execution
|
||||
# before v4 runs, and the stored version must stay at 2.
|
||||
store.add([make_node("a1", "1"), make_node("b1", "2")])
|
||||
|
||||
def copy_apply(
|
||||
src: sqlite3.Connection,
|
||||
dst: sqlite3.Connection,
|
||||
dim: int,
|
||||
) -> None:
|
||||
dst.execute( # nosemgrep
|
||||
f"CREATE VIRTUAL TABLE {DEFAULT_TABLE_NAME} USING vec0("
|
||||
"id TEXT PRIMARY KEY, document_id TEXT, modified TEXT,"
|
||||
f" +node_content TEXT, embedding float[{dim}] distance_metric=cosine"
|
||||
")",
|
||||
)
|
||||
dst.execute(
|
||||
"INSERT INTO index_meta (key, value) VALUES ('dim', ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
|
||||
(str(dim),),
|
||||
)
|
||||
rows = src.execute(
|
||||
"SELECT id, document_id, modified, node_content, embedding "
|
||||
f"FROM {DEFAULT_TABLE_NAME}",
|
||||
).fetchall()
|
||||
dst.execute("BEGIN IMMEDIATE")
|
||||
dst.executemany(
|
||||
f"INSERT INTO {DEFAULT_TABLE_NAME} "
|
||||
"(id, document_id, modified, node_content, embedding) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
(
|
||||
r["id"],
|
||||
r["document_id"],
|
||||
r["modified"],
|
||||
r["node_content"],
|
||||
bytes(r["embedding"]),
|
||||
)
|
||||
for r in rows
|
||||
],
|
||||
)
|
||||
dst.execute("COMMIT")
|
||||
|
||||
migrations = [
|
||||
Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
kind="structural",
|
||||
description="v2 structural",
|
||||
apply=copy_apply,
|
||||
),
|
||||
Migration(
|
||||
from_version=2,
|
||||
to_version=3,
|
||||
kind="re-embed",
|
||||
description="v3 re-embed boundary",
|
||||
),
|
||||
Migration(
|
||||
from_version=3,
|
||||
to_version=4,
|
||||
kind="structural",
|
||||
description="v4 structural - must not run",
|
||||
apply=copy_apply,
|
||||
),
|
||||
]
|
||||
MIGRATIONS.extend(migrations)
|
||||
try:
|
||||
from paperless_ai import vector_store as vs_mod
|
||||
|
||||
original = vs_mod.SCHEMA_VERSION
|
||||
vs_mod.SCHEMA_VERSION = 4
|
||||
result = store.check_and_run_migrations()
|
||||
finally:
|
||||
for m in migrations:
|
||||
MIGRATIONS.remove(m)
|
||||
vs_mod.SCHEMA_VERSION = original
|
||||
|
||||
assert result is True
|
||||
assert self._schema_version(store) == 2
|
||||
@@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from documents.models import Document
|
||||
from documents.tests.factories import DocumentFactory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestSuggestionsHintWiring:
|
||||
@pytest.fixture
|
||||
def document(self) -> Document:
|
||||
return DocumentFactory() # type: ignore[return-value]
|
||||
|
||||
@pytest.fixture
|
||||
def api_client(self, admin_user: User) -> APIClient:
|
||||
client = APIClient()
|
||||
client.force_authenticate(user=admin_user)
|
||||
return client
|
||||
|
||||
def test_hints_passed_to_classifier_and_matchers(
|
||||
self,
|
||||
api_client: APIClient,
|
||||
document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
hints = {
|
||||
"tags": ["Bloodwork"],
|
||||
"document_types": [],
|
||||
"correspondents": [],
|
||||
"storage_paths": [],
|
||||
}
|
||||
mocker.patch(
|
||||
"documents.views.get_taxonomy_hints_for_document",
|
||||
return_value=hints,
|
||||
)
|
||||
mocker.patch(
|
||||
"documents.views.AIConfig",
|
||||
return_value=SimpleNamespace(
|
||||
ai_enabled=True,
|
||||
llm_backend="ollama",
|
||||
llm_output_language=None,
|
||||
),
|
||||
)
|
||||
# No cached suggestion -> the view reaches the classifier path.
|
||||
mocker.patch(
|
||||
"documents.views.get_llm_suggestion_cache",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch("documents.views.set_llm_suggestions_cache")
|
||||
classify = mocker.patch(
|
||||
"documents.views.get_ai_document_classification",
|
||||
return_value={
|
||||
"title": "Doc",
|
||||
"tags": ["Bloodwork"],
|
||||
"correspondents": [],
|
||||
"document_types": [],
|
||||
"storage_paths": [],
|
||||
"dates": [],
|
||||
},
|
||||
)
|
||||
match_tags = mocker.patch(
|
||||
"documents.views.match_tags_by_name",
|
||||
return_value=[],
|
||||
)
|
||||
mocker.patch("documents.views.match_correspondents_by_name", return_value=[])
|
||||
mocker.patch("documents.views.match_document_types_by_name", return_value=[])
|
||||
mocker.patch("documents.views.match_storage_paths_by_name", return_value=[])
|
||||
|
||||
response = api_client.get(f"/api/documents/{document.pk}/ai_suggestions/")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert classify.call_args.kwargs["hints"] == hints
|
||||
assert match_tags.call_args.kwargs["hinted_names"] == {"Bloodwork"}
|
||||
@@ -0,0 +1,604 @@
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
import sqlite_vec
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.core.vector_stores.types import FilterCondition
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryResult
|
||||
from llama_index.core.vector_stores.utils import metadata_dict_to_node
|
||||
from llama_index.core.vector_stores.utils import node_to_metadata_dict
|
||||
|
||||
logger = logging.getLogger("paperless_ai.vector_store")
|
||||
|
||||
DB_FILENAME = "llmindex.db"
|
||||
DEFAULT_TABLE_NAME = "documents"
|
||||
|
||||
# Current schema version. Written to index_meta at table creation and bumped
|
||||
# whenever a Migration is added to MIGRATIONS. check_and_run_migrations() uses
|
||||
# this to decide which migrations to run on an existing store.
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
# compact(): rebuild when the cumulative rowid count exceeds this multiple of
|
||||
# the live row count. DELETEs on vec0 tables never reclaim space (upstream
|
||||
# asg017/sqlite-vec#54), so per-document re-index churn grows the file until
|
||||
# a rebuild copies the live rows into a fresh table.
|
||||
COMPACT_BLOAT_RATIO = 2.0
|
||||
|
||||
# Filterable vec0 metadata columns. _build_where() only ever receives filter
|
||||
# keys we construct ourselves, but allowlisting keeps SQL identifiers safe by
|
||||
# construction.
|
||||
_FILTER_COLUMNS = frozenset({"document_id", "modified"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class Migration:
|
||||
"""A schema migration for the sqlite-vec vector store.
|
||||
|
||||
kind="structural": rows are copied into a new-schema file with no
|
||||
re-embedding needed. Supply ``apply(src_conn, dst_conn, dim)`` which
|
||||
must create the vec0 table in ``dst_conn``, copy all rows from
|
||||
``src_conn``, and write ``dim`` / ``embed_model`` / ``total_inserts`` to
|
||||
``dst_conn``'s ``index_meta``. ``schema_version`` is written by the
|
||||
migration runner after ``apply`` returns.
|
||||
|
||||
kind="re-embed": the new schema requires fresh embeddings.
|
||||
``check_and_run_migrations()`` returns True when it encounters one of
|
||||
these so the caller can force a full rebuild (which recreates the table
|
||||
at the current SCHEMA_VERSION).
|
||||
"""
|
||||
|
||||
from_version: int
|
||||
to_version: int
|
||||
kind: Literal["structural", "re-embed"]
|
||||
description: str
|
||||
apply: Callable[[sqlite3.Connection, sqlite3.Connection, int], None] | None = field(
|
||||
default=None,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
|
||||
# Registry of all schema migrations in order. Empty at v1 -- this is the
|
||||
# baseline. Add entries here (and bump SCHEMA_VERSION) when the schema changes.
|
||||
MIGRATIONS: list[Migration] = []
|
||||
|
||||
|
||||
def _pack(embedding: Sequence[float]) -> bytes:
|
||||
return struct.pack(f"{len(embedding)}f", *embedding)
|
||||
|
||||
|
||||
def _unpack(blob: bytes) -> list[float]:
|
||||
return list(struct.unpack(f"{len(blob) // 4}f", blob))
|
||||
|
||||
|
||||
def _build_where(filters: MetadataFilters | None) -> tuple[str, list[str]]:
|
||||
"""Translate the EQ / IN filters we use into a parameterized SQL clause
|
||||
on vec0 metadata columns. Returns ("", []) when there is nothing to filter.
|
||||
"""
|
||||
if filters is None or not filters.filters:
|
||||
return "", []
|
||||
clauses: list[str] = []
|
||||
params: list[str] = []
|
||||
for f in filters.filters:
|
||||
# filters.filters is Union[MetadataFilter, ExactMatchFilter, MetadataFilters];
|
||||
# we only build MetadataFilter entries, so skip anything else at runtime.
|
||||
if not isinstance(f, MetadataFilter):
|
||||
continue
|
||||
if f.key not in _FILTER_COLUMNS: # pragma: no cover - we build the keys
|
||||
raise NotImplementedError(f"Unsupported filter column: {f.key}")
|
||||
if f.operator == FilterOperator.IN:
|
||||
values = [str(v) for v in f.value] # type: ignore[union-attr] # value is list when operator is IN
|
||||
if not values: # pragma: no cover
|
||||
clauses.append("1 = 0")
|
||||
continue
|
||||
placeholders = ",".join("?" for _ in values)
|
||||
clauses.append(f"{f.key} IN ({placeholders})")
|
||||
params.extend(values)
|
||||
elif f.operator == FilterOperator.EQ:
|
||||
clauses.append(f"{f.key} = ?")
|
||||
params.append(str(f.value))
|
||||
else: # pragma: no cover - we only ever build EQ/IN filters
|
||||
raise NotImplementedError(f"Unsupported filter operator: {f.operator}")
|
||||
if not clauses:
|
||||
# Filters were requested but none could be translated. Fail closed
|
||||
# rather than emit "()" (invalid SQL): filters scope document access,
|
||||
# so an empty translation must match no rows, never widen the scope.
|
||||
return "1 = 0", []
|
||||
joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
|
||||
return "(" + joiner.join(clauses) + ")", params
|
||||
|
||||
|
||||
class PaperlessSqliteVecVectorStore(BasePydanticVectorStore):
|
||||
"""A llama-index vector store backed by a sqlite-vec vec0 table.
|
||||
|
||||
Stores one row per node: the node id (TEXT primary key), its document id
|
||||
(metadata column, used for EQ/IN filtering and per-document delete), the
|
||||
document's modified timestamp, the embedding (float32, cosine metric), and
|
||||
the serialized node (text + metadata) as JSON in an auxiliary column.
|
||||
``stores_text`` lets llama-index run off this store alone, with no
|
||||
separate docstore or index store.
|
||||
|
||||
Everything lives in one SQLite database file (``DB_FILENAME``) inside the
|
||||
directory given as ``uri`` (kept as a directory for compatibility with the
|
||||
previous LanceDB layout). WAL mode allows readers in other processes to
|
||||
proceed while the (FileLock-serialized) writer holds a transaction.
|
||||
|
||||
Implemented surface of ``BasePydanticVectorStore``
|
||||
---------------------------------------------------
|
||||
Only the methods actively used by this codebase are implemented.
|
||||
``delete_nodes`` and the ``node_ids`` lookup path of ``get_nodes`` are
|
||||
part of the llama-index interface contract and may be needed if a future
|
||||
retriever or extension invokes them — add them then, with tests.
|
||||
"""
|
||||
|
||||
stores_text: bool = True
|
||||
flat_metadata: bool = False
|
||||
|
||||
_uri: str = PrivateAttr()
|
||||
_embed_model_name: str | None = PrivateAttr()
|
||||
_conn: Any = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
embed_model_name: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(stores_text=True, flat_metadata=False)
|
||||
self._uri = uri
|
||||
self._embed_model_name = embed_model_name
|
||||
self._conn = self._open_connection(str(Path(uri) / DB_FILENAME))
|
||||
|
||||
@staticmethod
|
||||
def _open_connection(db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(
|
||||
db_path,
|
||||
timeout=30,
|
||||
isolation_level=None, # autocommit; explicit transactions below
|
||||
)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.enable_load_extension(True) # noqa: FBT003
|
||||
sqlite_vec.load(conn)
|
||||
conn.enable_load_extension(False) # noqa: FBT003
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS index_meta (key TEXT PRIMARY KEY, value TEXT)",
|
||||
)
|
||||
return conn
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
return self._conn
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the underlying SQLite connection (idempotent)."""
|
||||
self._conn.close()
|
||||
|
||||
def __enter__(self) -> "PaperlessSqliteVecVectorStore":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
# Deterministically release the connection (and its WAL/SHM handles) so
|
||||
# it is never left open across a compaction/migration file swap.
|
||||
self.close()
|
||||
|
||||
@contextmanager
|
||||
def _transaction(self) -> Iterator[None]:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
yield
|
||||
except BaseException: # pragma: no cover
|
||||
self._conn.execute("ROLLBACK")
|
||||
raise
|
||||
else:
|
||||
self._conn.execute("COMMIT")
|
||||
|
||||
def _meta_get(self, key: str) -> str | None:
|
||||
row = self._conn.execute(
|
||||
"SELECT value FROM index_meta WHERE key = ?",
|
||||
(key,),
|
||||
).fetchone()
|
||||
return row["value"] if row else None
|
||||
|
||||
@staticmethod
|
||||
def _meta_set_on(conn: sqlite3.Connection, key: str, value: str) -> None:
|
||||
conn.execute(
|
||||
"INSERT INTO index_meta (key, value) VALUES (?, ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET value = excluded.value",
|
||||
(key, value),
|
||||
)
|
||||
|
||||
def _meta_set(self, key: str, value: str) -> None:
|
||||
self._meta_set_on(self._conn, key, value)
|
||||
|
||||
def table_exists(self) -> bool:
|
||||
return (
|
||||
self._conn.execute(
|
||||
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
|
||||
(DEFAULT_TABLE_NAME,),
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
def vector_dim(self) -> int | None:
|
||||
if not self.table_exists():
|
||||
return None
|
||||
value = self._meta_get("dim")
|
||||
return int(value) if value else None
|
||||
|
||||
def drop_table(self) -> None:
|
||||
self._conn.execute("DROP TABLE IF EXISTS " + DEFAULT_TABLE_NAME)
|
||||
self._conn.execute("DELETE FROM index_meta")
|
||||
|
||||
def stored_model_name(self) -> str | None:
|
||||
"""Return the embedding model name recorded at table creation, or None."""
|
||||
if not self.table_exists():
|
||||
return None
|
||||
return self._meta_get("embed_model")
|
||||
|
||||
def config_mismatch(self, model_name: str) -> bool:
|
||||
"""True when the stored model name differs from ``model_name``.
|
||||
|
||||
Returns False when no table exists or when the table predates
|
||||
model-name tracking — conservative default avoids spurious rebuilds.
|
||||
"""
|
||||
stored = self.stored_model_name()
|
||||
if stored is None:
|
||||
return False
|
||||
return stored != model_name
|
||||
|
||||
@staticmethod
|
||||
def _create_vec_table(conn: sqlite3.Connection, dim: int) -> None:
|
||||
# document_id is deliberately a metadata column, NOT a partition key:
|
||||
# partition keys change KNN `k` to per-partition semantics under IN
|
||||
# filters (asg017/sqlite-vec#142); metadata columns give a correct
|
||||
# global top-k.
|
||||
conn.execute( # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query
|
||||
"CREATE VIRTUAL TABLE "
|
||||
+ DEFAULT_TABLE_NAME
|
||||
+ " USING vec0("
|
||||
+ "id TEXT PRIMARY KEY,"
|
||||
+ " document_id TEXT,"
|
||||
+ " modified TEXT,"
|
||||
+ " +node_content TEXT,"
|
||||
+ " embedding float["
|
||||
+ str(int(dim))
|
||||
+ "] distance_metric=cosine"
|
||||
+ ")",
|
||||
)
|
||||
|
||||
def _create_table(self, dim: int) -> None:
|
||||
self._create_vec_table(self._conn, dim)
|
||||
self._meta_set("dim", str(dim))
|
||||
self._meta_set("schema_version", str(SCHEMA_VERSION))
|
||||
if self._embed_model_name:
|
||||
self._meta_set("embed_model", self._embed_model_name)
|
||||
|
||||
def _ensure_table(self, dim: int) -> None:
|
||||
if not self.table_exists():
|
||||
self._create_table(dim)
|
||||
|
||||
def _row(self, node: BaseNode) -> tuple[str, str, str, str, bytes]:
|
||||
meta = node_to_metadata_dict(
|
||||
node,
|
||||
remove_text=False,
|
||||
flat_metadata=self.flat_metadata,
|
||||
)
|
||||
# vec0 metadata columns reject NULL (asg017/sqlite-vec#141): coerce
|
||||
# every value to a string, with "" as the absent sentinel.
|
||||
document_id = node.ref_doc_id or node.metadata.get("document_id")
|
||||
return (
|
||||
node.node_id,
|
||||
str(document_id or ""),
|
||||
str(node.metadata.get("modified") or ""),
|
||||
json.dumps(meta),
|
||||
_pack(node.get_embedding()),
|
||||
)
|
||||
|
||||
_INSERT = (
|
||||
"INSERT INTO "
|
||||
+ DEFAULT_TABLE_NAME
|
||||
+ " (id, document_id, modified, node_content, embedding) VALUES (?, ?, ?, ?, ?)"
|
||||
)
|
||||
|
||||
def _increment_total_inserts(self, count: int) -> None:
|
||||
"""Increment the cumulative insert counter stored in index_meta.
|
||||
|
||||
This counter never decreases (DELETEs do not decrement it) and is
|
||||
used by compact() to estimate the bloat ratio: when total_inserts /
|
||||
live_rows exceeds COMPACT_BLOAT_RATIO the table has accumulated
|
||||
enough deleted-but-not-freed rows to warrant a rebuild.
|
||||
"""
|
||||
current = int(self._meta_get("total_inserts") or "0")
|
||||
self._meta_set("total_inserts", str(current + count))
|
||||
|
||||
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
|
||||
if not nodes:
|
||||
return []
|
||||
rows = [self._row(node) for node in nodes]
|
||||
with self._transaction():
|
||||
self._ensure_table(len(nodes[0].get_embedding()))
|
||||
self._conn.executemany(self._INSERT, rows)
|
||||
self._increment_total_inserts(len(rows))
|
||||
return [node.node_id for node in nodes]
|
||||
|
||||
def upsert_document(self, document_id: str, nodes: list[BaseNode]) -> list[str]:
|
||||
"""Atomically replace all stored chunks of ``document_id`` with ``nodes``.
|
||||
|
||||
One transaction deletes the document's existing rows and inserts the
|
||||
new set (vec0's INSERT OR REPLACE is broken upstream, #259, so
|
||||
delete+insert it is). WAL readers in other processes see either the
|
||||
old or the new chunk set, never a partial state.
|
||||
"""
|
||||
rows = [self._row(node) for node in nodes]
|
||||
with self._transaction():
|
||||
if nodes:
|
||||
self._ensure_table(len(nodes[0].get_embedding()))
|
||||
if self.table_exists():
|
||||
self._conn.execute(
|
||||
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
|
||||
(str(document_id),),
|
||||
)
|
||||
if rows:
|
||||
self._conn.executemany(self._INSERT, rows)
|
||||
self._increment_total_inserts(len(rows))
|
||||
return [node.node_id for node in nodes]
|
||||
|
||||
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
|
||||
if self.table_exists():
|
||||
with self._transaction():
|
||||
self._conn.execute(
|
||||
"DELETE FROM " + DEFAULT_TABLE_NAME + " WHERE document_id = ?",
|
||||
(str(ref_doc_id),),
|
||||
)
|
||||
|
||||
def _rows_to_nodes(self, rows: list[sqlite3.Row]) -> list[BaseNode]:
|
||||
nodes: list[BaseNode] = []
|
||||
for row in rows:
|
||||
node = metadata_dict_to_node(json.loads(row["node_content"]))
|
||||
node.embedding = _unpack(row["embedding"])
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def get_nodes(
|
||||
self,
|
||||
node_ids: list[str] | None = None,
|
||||
filters: MetadataFilters | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseNode]:
|
||||
if node_ids is not None: # pragma: no cover
|
||||
# node_ids lookup is not implemented; see class docstring.
|
||||
raise NotImplementedError(
|
||||
"PaperlessSqliteVecVectorStore does not support node_ids lookup",
|
||||
)
|
||||
if not self.table_exists():
|
||||
return []
|
||||
where, params = _build_where(filters)
|
||||
sql = "SELECT node_content, embedding FROM " + DEFAULT_TABLE_NAME
|
||||
if where:
|
||||
sql += " WHERE " + where
|
||||
return self._rows_to_nodes(self._conn.execute(sql, params).fetchall())
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: VectorStoreQuery,
|
||||
**kwargs: Any,
|
||||
) -> VectorStoreQueryResult:
|
||||
if not self.table_exists():
|
||||
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
|
||||
if query.query_embedding is None: # pragma: no cover
|
||||
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
|
||||
top_k = query.similarity_top_k if query.similarity_top_k is not None else 10
|
||||
where, params = _build_where(query.filters)
|
||||
sql = (
|
||||
"SELECT id, node_content, embedding, distance FROM "
|
||||
+ DEFAULT_TABLE_NAME
|
||||
+ " WHERE embedding MATCH ? AND k = ?"
|
||||
)
|
||||
if where:
|
||||
sql += " AND " + where
|
||||
rows = self._conn.execute(
|
||||
sql,
|
||||
[_pack(query.query_embedding), top_k, *params],
|
||||
).fetchall()
|
||||
# vec0 returns rows distance-sorted ascending; slice defensively in
|
||||
# case future schema changes alter k semantics (e.g. partition keys
|
||||
# return k rows per partition).
|
||||
rows = rows[:top_k]
|
||||
nodes = self._rows_to_nodes(rows)
|
||||
# Cosine distance in [0, 2]; map to a descending similarity.
|
||||
# vec0 returns None distance when the query embedding is the zero vector
|
||||
# (no meaningful cosine angle); treat that as maximum distance (1.0) so
|
||||
# the row is included but ranked last.
|
||||
sims = [
|
||||
1.0 - float(row["distance"] if row["distance"] is not None else 1.0)
|
||||
for row in rows
|
||||
]
|
||||
ids = [row["id"] for row in rows]
|
||||
return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)
|
||||
|
||||
def get_modified_times(self) -> dict[str, str]:
|
||||
"""Return {document_id: stored_modified_isoformat} for all indexed documents.
|
||||
|
||||
All chunks of a document share the same ``modified`` value, so the
|
||||
first row seen per document is sufficient.
|
||||
"""
|
||||
if not self.table_exists():
|
||||
return {}
|
||||
result: dict[str, str] = {}
|
||||
for row in self._conn.execute(
|
||||
"SELECT document_id, modified FROM " + DEFAULT_TABLE_NAME,
|
||||
):
|
||||
doc_id = str(row["document_id"])
|
||||
if doc_id not in result:
|
||||
result[doc_id] = str(row["modified"] or "")
|
||||
return result
|
||||
|
||||
def compact(self, *, force: bool = False) -> None:
|
||||
"""Rebuild the database file to reclaim space left behind by DELETEs.
|
||||
|
||||
vec0 DELETE only invalidates rows; the vector data stays in the file
|
||||
forever (asg017/sqlite-vec#54), and per-document re-indexing is a
|
||||
delete+insert. The cumulative insert counter in ``index_meta`` tracks
|
||||
total rows ever written; when that exceeds ``COMPACT_BLOAT_RATIO`` x
|
||||
the live row count (or when forced), live rows are copied into a fresh
|
||||
database file and swapped in via ``os.replace``.
|
||||
|
||||
Note: ``ALTER TABLE ... RENAME TO`` on vec0 virtual tables does NOT
|
||||
rename the shadow tables (sqlite-vec upstream limitation), so
|
||||
an in-place rename-based rebuild is not safe. The file-swap approach
|
||||
is the maintainer-endorsed workaround (asg017/sqlite-vec#205).
|
||||
"""
|
||||
if not self.table_exists():
|
||||
return
|
||||
live = self._conn.execute(
|
||||
"SELECT count(*) FROM " + DEFAULT_TABLE_NAME,
|
||||
).fetchone()[0]
|
||||
total = int(self._meta_get("total_inserts") or str(live))
|
||||
if not force and total <= max(live, 1) * COMPACT_BLOAT_RATIO:
|
||||
return
|
||||
dim = self.vector_dim()
|
||||
if dim is None: # pragma: no cover - dim is written at creation
|
||||
logger.warning("Skipping compact: no stored vector dimension")
|
||||
return
|
||||
logger.info(
|
||||
"Compacting LLM index (%d live rows, %d cumulative inserts)",
|
||||
live,
|
||||
total,
|
||||
)
|
||||
db_path = str(Path(self._uri) / DB_FILENAME)
|
||||
compact_path = db_path + ".compact"
|
||||
|
||||
# Copy all live rows into a fresh database file.
|
||||
new_conn = self._open_connection(compact_path)
|
||||
try:
|
||||
self._create_vec_table(new_conn, dim)
|
||||
self._meta_set_on(new_conn, "dim", str(dim))
|
||||
for key in ("embed_model", "schema_version"):
|
||||
value = self._meta_get(key)
|
||||
if value is not None:
|
||||
self._meta_set_on(new_conn, key, value)
|
||||
rows = self._conn.execute(
|
||||
"SELECT id, document_id, modified, node_content, embedding "
|
||||
"FROM " + DEFAULT_TABLE_NAME,
|
||||
).fetchall()
|
||||
new_conn.execute("BEGIN IMMEDIATE")
|
||||
new_conn.executemany(
|
||||
self._INSERT,
|
||||
[
|
||||
(
|
||||
r["id"],
|
||||
r["document_id"],
|
||||
r["modified"],
|
||||
r["node_content"],
|
||||
bytes(r["embedding"]),
|
||||
)
|
||||
for r in rows
|
||||
],
|
||||
)
|
||||
# Reset the cumulative counter: after compact, total_inserts == live.
|
||||
self._meta_set_on(new_conn, "total_inserts", str(live))
|
||||
new_conn.execute("COMMIT")
|
||||
except BaseException:
|
||||
new_conn.close()
|
||||
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
|
||||
Path(p).unlink(missing_ok=True)
|
||||
raise
|
||||
new_conn.close()
|
||||
self._swap_in_compact(compact_path, db_path)
|
||||
|
||||
def _swap_in_compact(self, compact_path: str, db_path: str) -> None:
|
||||
"""Atomically replace the live database with the compacted copy."""
|
||||
self._conn.close()
|
||||
for suffix in ["-wal", "-shm"]:
|
||||
stale = Path(compact_path + suffix)
|
||||
if stale.exists(): # pragma: no cover
|
||||
stale.unlink()
|
||||
Path(compact_path).replace(db_path)
|
||||
self._conn = self._open_connection(db_path)
|
||||
|
||||
def check_and_run_migrations(self) -> bool:
|
||||
"""Apply any pending schema migrations to the store.
|
||||
|
||||
Structural migrations copy live rows into a new-schema file with no
|
||||
re-embedding. Re-embed migrations cannot be applied automatically;
|
||||
this method returns True when one is encountered so the caller can
|
||||
force a full rebuild (which recreates the table at SCHEMA_VERSION).
|
||||
|
||||
Must be called under the write FileLock. No-op when the table does
|
||||
not exist or is already at SCHEMA_VERSION.
|
||||
"""
|
||||
if not self.table_exists():
|
||||
return False
|
||||
|
||||
raw = self._meta_get("schema_version")
|
||||
current = int(raw) if raw is not None else SCHEMA_VERSION
|
||||
if current >= SCHEMA_VERSION:
|
||||
return False
|
||||
|
||||
pending = sorted(
|
||||
[m for m in MIGRATIONS if current <= m.from_version < SCHEMA_VERSION],
|
||||
key=lambda m: m.from_version,
|
||||
)
|
||||
|
||||
for migration in pending:
|
||||
if migration.kind == "re-embed":
|
||||
logger.warning(
|
||||
"LLM index schema v%d -> v%d requires re-embedding (%s); "
|
||||
"forcing full rebuild.",
|
||||
migration.from_version,
|
||||
migration.to_version,
|
||||
migration.description,
|
||||
)
|
||||
return True
|
||||
logger.info(
|
||||
"Running structural LLM index migration v%d -> v%d: %s",
|
||||
migration.from_version,
|
||||
migration.to_version,
|
||||
migration.description,
|
||||
)
|
||||
self._run_structural_migration(migration)
|
||||
|
||||
return False
|
||||
|
||||
def _run_structural_migration(self, migration: Migration) -> None:
|
||||
"""Execute a structural migration using the same file-swap as compact()."""
|
||||
assert migration.apply is not None, "structural migration must have apply()"
|
||||
dim = self.vector_dim()
|
||||
if dim is None: # pragma: no cover
|
||||
raise RuntimeError("Cannot migrate: no stored vector dimension")
|
||||
db_path = str(Path(self._uri) / DB_FILENAME)
|
||||
compact_path = db_path + ".compact"
|
||||
new_conn = self._open_connection(compact_path)
|
||||
try:
|
||||
migration.apply(self._conn, new_conn, dim)
|
||||
self._meta_set_on(new_conn, "schema_version", str(migration.to_version))
|
||||
except BaseException: # pragma: no cover
|
||||
new_conn.close()
|
||||
for p in [compact_path, compact_path + "-wal", compact_path + "-shm"]:
|
||||
Path(p).unlink(missing_ok=True)
|
||||
raise
|
||||
new_conn.close()
|
||||
self._swap_in_compact(compact_path, db_path)
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import ssl
|
||||
import tempfile
|
||||
import traceback
|
||||
import unicodedata
|
||||
from datetime import date
|
||||
from datetime import timedelta
|
||||
from fnmatch import fnmatch
|
||||
@@ -496,10 +497,10 @@ class MailAccountHandler(LoggingMixin):
|
||||
rule: MailRule,
|
||||
) -> str | None:
|
||||
if rule.assign_title_from == MailRule.TitleSource.FROM_SUBJECT:
|
||||
return message.subject
|
||||
return unicodedata.normalize("NFC", message.subject)
|
||||
|
||||
elif rule.assign_title_from == MailRule.TitleSource.FROM_FILENAME:
|
||||
return Path(att.filename).stem
|
||||
return unicodedata.normalize("NFC", Path(att.filename).stem)
|
||||
|
||||
elif rule.assign_title_from == MailRule.TitleSource.NONE:
|
||||
return None
|
||||
@@ -866,7 +867,9 @@ class MailAccountHandler(LoggingMixin):
|
||||
),
|
||||
)
|
||||
|
||||
attachment_name = pathvalidate.sanitize_filename(att.filename)
|
||||
attachment_name = pathvalidate.sanitize_filename(
|
||||
unicodedata.normalize("NFC", att.filename),
|
||||
)
|
||||
if attachment_name:
|
||||
temp_filename = temp_dir / attachment_name
|
||||
else: # pragma: no cover
|
||||
@@ -882,7 +885,7 @@ class MailAccountHandler(LoggingMixin):
|
||||
)
|
||||
doc_overrides = DocumentMetadataOverrides(
|
||||
title=title,
|
||||
filename=pathvalidate.sanitize_filename(att.filename),
|
||||
filename=attachment_name,
|
||||
correspondent_id=correspondent.id if correspondent else None,
|
||||
document_type_id=doc_type.id if doc_type else None,
|
||||
tag_ids=tag_ids,
|
||||
@@ -988,7 +991,9 @@ class MailAccountHandler(LoggingMixin):
|
||||
)
|
||||
doc_overrides = DocumentMetadataOverrides(
|
||||
title=message.subject,
|
||||
filename=pathvalidate.sanitize_filename(f"{message.subject}.eml"),
|
||||
filename=pathvalidate.sanitize_filename(
|
||||
unicodedata.normalize("NFC", f"{message.subject}.eml"),
|
||||
),
|
||||
correspondent_id=correspondent.id if correspondent else None,
|
||||
document_type_id=doc_type.id if doc_type else None,
|
||||
tag_ids=tag_ids,
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Tests that mail attachment filenames and EML subject filenames are
|
||||
normalized to NFC Unicode before being stored as document overrides.
|
||||
|
||||
Filenames from MIME headers can arrive in NFD form (e.g. from macOS Mail),
|
||||
and must be normalized to NFC so filenames are consistent regardless of the
|
||||
sending client.
|
||||
"""
|
||||
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.tests.utils import remove_dirs
|
||||
from documents.tests.utils import setup_directories
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.tests.factories import MailAccountFactory
|
||||
from paperless_mail.tests.test_mail import MessageBuilder
|
||||
from paperless_mail.tests.test_mail import _AttachmentDef
|
||||
from paperless_mail.tests.test_mail import fake_magic_from_buffer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def directories(settings):
|
||||
dirs = setup_directories()
|
||||
yield dirs
|
||||
remove_dirs(dirs)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def queue_consumption_tasks_mock():
|
||||
with mock.patch("paperless_mail.mail.queue_consumption_tasks") as m:
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mail_account(db):
|
||||
return MailAccountFactory()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def attachment_rule(mail_account):
|
||||
rule = MailRule(
|
||||
name="attachment rule",
|
||||
account=mail_account,
|
||||
assign_title_from=MailRule.TitleSource.FROM_FILENAME,
|
||||
consumption_scope=MailRule.ConsumptionScope.ATTACHMENTS_ONLY,
|
||||
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
|
||||
)
|
||||
rule.save()
|
||||
return rule
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def eml_rule(mail_account):
|
||||
rule = MailRule(
|
||||
name="eml rule",
|
||||
account=mail_account,
|
||||
assign_title_from=MailRule.TitleSource.FROM_SUBJECT,
|
||||
consumption_scope=MailRule.ConsumptionScope.EML_ONLY,
|
||||
attachment_type=MailRule.AttachmentProcessing.ATTACHMENTS_ONLY,
|
||||
)
|
||||
rule.save()
|
||||
return rule
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def message_builder():
|
||||
return MessageBuilder()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@mock.patch("paperless_mail.mail.magic.from_buffer", fake_magic_from_buffer)
|
||||
class TestMailNFCNormalization:
|
||||
"""Attachment filenames and EML subject filenames must be NFC-normalized."""
|
||||
|
||||
def test_attachment_nfd_filename_normalized_to_nfc(
|
||||
self,
|
||||
directories,
|
||||
queue_consumption_tasks_mock,
|
||||
attachment_rule,
|
||||
mail_account_handler,
|
||||
message_builder,
|
||||
):
|
||||
"""Attachment filename arriving as NFD must be stored as NFC in both
|
||||
the overrides and the temp file written to disk.
|
||||
"""
|
||||
nfd_filename = unicodedata.normalize("NFD", "Rechnung März.pdf")
|
||||
nfc_filename = unicodedata.normalize("NFC", "Rechnung März.pdf")
|
||||
|
||||
# Confirm the fixture is actually NFD (not already NFC)
|
||||
assert unicodedata.is_normalized("NFD", nfd_filename)
|
||||
assert not unicodedata.is_normalized("NFC", nfd_filename)
|
||||
|
||||
message = message_builder.create_message(
|
||||
subject="Test invoice",
|
||||
from_="sender@example.com",
|
||||
attachments=[
|
||||
_AttachmentDef(filename=nfd_filename, content=b"%PDF-1.4 test"),
|
||||
],
|
||||
)
|
||||
|
||||
result = mail_account_handler._handle_message(message, attachment_rule)
|
||||
|
||||
assert result == 1
|
||||
queue_consumption_tasks_mock.assert_called_once()
|
||||
|
||||
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
|
||||
consume_tasks = call_kwargs["consume_tasks"]
|
||||
assert len(consume_tasks) == 1
|
||||
|
||||
overrides = consume_tasks[0].kwargs["overrides"]
|
||||
assert overrides.filename == nfc_filename
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
assert unicodedata.is_normalized("NFC", overrides.title)
|
||||
|
||||
input_doc = consume_tasks[0].kwargs["input_doc"]
|
||||
original_file = Path(input_doc.original_file)
|
||||
assert original_file.exists()
|
||||
assert original_file.name == nfc_filename
|
||||
|
||||
def test_eml_subject_filename_nfc(
|
||||
self,
|
||||
directories,
|
||||
queue_consumption_tasks_mock,
|
||||
eml_rule,
|
||||
mail_account_handler,
|
||||
message_builder,
|
||||
):
|
||||
"""EML filename derived from subject arriving as NFD must be stored as NFC."""
|
||||
nfd_subject = unicodedata.normalize("NFD", "Rechnung März 2024")
|
||||
nfc_expected_filename = unicodedata.normalize("NFC", "Rechnung März 2024.eml")
|
||||
|
||||
# Confirm the fixture is actually NFD
|
||||
assert unicodedata.is_normalized("NFD", nfd_subject)
|
||||
|
||||
message = message_builder.create_message(
|
||||
subject=nfd_subject,
|
||||
from_="sender@example.com",
|
||||
attachments=0,
|
||||
)
|
||||
|
||||
mail_account_handler._handle_message(message, eml_rule)
|
||||
|
||||
queue_consumption_tasks_mock.assert_called_once()
|
||||
|
||||
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
|
||||
consume_tasks = call_kwargs["consume_tasks"]
|
||||
assert len(consume_tasks) == 1
|
||||
|
||||
overrides = consume_tasks[0].kwargs["overrides"]
|
||||
assert overrides.filename == nfc_expected_filename
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
|
||||
def test_already_nfc_attachment_filename_unchanged(
|
||||
self,
|
||||
directories,
|
||||
queue_consumption_tasks_mock,
|
||||
attachment_rule,
|
||||
mail_account_handler,
|
||||
message_builder,
|
||||
):
|
||||
"""An attachment filename already in NFC must pass through unchanged."""
|
||||
nfc_filename = "Invoice_2024.pdf"
|
||||
assert unicodedata.is_normalized("NFC", nfc_filename)
|
||||
|
||||
message = message_builder.create_message(
|
||||
subject="Invoice",
|
||||
from_="sender@example.com",
|
||||
attachments=[
|
||||
_AttachmentDef(filename=nfc_filename, content=b"%PDF-1.4 test"),
|
||||
],
|
||||
)
|
||||
|
||||
mail_account_handler._handle_message(message, attachment_rule)
|
||||
|
||||
call_kwargs = queue_consumption_tasks_mock.call_args.kwargs
|
||||
consume_tasks = call_kwargs["consume_tasks"]
|
||||
overrides = consume_tasks[0].kwargs["overrides"]
|
||||
assert overrides.filename == nfc_filename
|
||||
@@ -1200,23 +1200,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/27/8d/2bc5f5546ff2ccb3f7de06742853483ab75bf74f36a92254702f8baecc79/factory_boy-3.3.3-py2.py3-none-any.whl", hash = "sha256:1c39e3289f7e667c4285433f305f8d506efc2fe9c73aaea4151ebd5cdea394fc", size = 37036, upload-time = "2025-02-03T09:49:01.659Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "faiss-cpu"
|
||||
version = "1.13.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/c9/671f66f6b31ec48e5825d36435f0cb91189fa8bb6b50724029dbff4ca83c/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_arm64.whl", hash = "sha256:a9064eb34f8f64438dd5b95c8f03a780b1a3f0b99c46eeacb1f0b5d15fc02dc1", size = 3452776, upload-time = "2025-12-24T10:27:01.419Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/4a/97150aa1582fb9c2bca95bd8fc37f27d3b470acec6f0a6833844b21e4b40/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_x86_64.whl", hash = "sha256:c8d097884521e1ecaea6467aeebbf1aa56ee4a36350b48b2ca6b39366565c317", size = 7896434, upload-time = "2025-12-24T10:27:03.592Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0b/d0/0940575f059591ca31b63a881058adb16a387020af1709dcb7669460115c/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ee330a284042c2480f2e90450a10378fd95655d62220159b1408f59ee83ebf1", size = 11485825, upload-time = "2025-12-24T10:27:05.681Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/e1/a5acac02aa593809f0123539afe7b4aff61d1db149e7093239888c9053e1/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab88ee287c25a119213153d033f7dd64c3ccec466ace267395872f554b648cd7", size = 23845772, upload-time = "2025-12-24T10:27:08.194Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/7b/49dcaf354834ec457e85ca769d50bc9b5f3003fab7c94a9dcf08cf742793/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:85511129b34f890d19c98b82a0cd5ffb27d89d1cec2ee41d2621ee9f9ef8cf3f", size = 13477567, upload-time = "2025-12-24T10:27:10.822Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/6b/12bb4037921c38bb2c0b4cfc213ca7e04bbbebbfea89b0b5746248ce446e/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b32eb4065bac352b52a9f5ae07223567fab0a976c7d05017c01c45a1c24264f", size = 25102239, upload-time = "2025-12-24T10:27:13.476Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "faker"
|
||||
version = "40.15.0"
|
||||
@@ -2280,18 +2263,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/0c/fdddaee5391d915d3d568d2d8dbdb7c95647e65bb94d4ddb31d47cef5daf/llama_index_llms_openai_like-0.7.2-py3-none-any.whl", hash = "sha256:1f45a7b1cec8fb3f5997684327ffe6c19f93e789c2fff35dc5522465850faf0b", size = 6602, upload-time = "2026-04-23T23:05:31.708Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "llama-index-vector-stores-faiss"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "llama-index-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7c/32/89a04e38fa9595b7116c61955d9a67085f0a5480738e9c14063e374724c2/llama_index_vector_stores_faiss-0.6.0.tar.gz", hash = "sha256:00bfeb6cb7571e0e856566cb4f10c89b415b6108f151d9ad48ee9c31da563f5e", size = 6045, upload-time = "2026-03-12T20:46:31.454Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/85/465b4f199075ae7773c181b2f98cf689f3107a8de031e7a9d4cd5e906446/llama_index_vector_stores_faiss-0.6.0-py3-none-any.whl", hash = "sha256:d4600c60ef5411d9e35ba573b4f416a5e13ea04c6f942c8e6f49f03f2feb4f3b", size = 7739, upload-time = "2026-03-12T20:46:30.736Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "llama-index-workflows"
|
||||
version = "2.20.0"
|
||||
@@ -2912,7 +2883,6 @@ dependencies = [
|
||||
{ name = "drf-spectacular", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "drf-spectacular-sidecar", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "drf-writable-nested", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "faiss-cpu", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "flower", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "gotenberg-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -2927,7 +2897,6 @@ dependencies = [
|
||||
{ name = "llama-index-embeddings-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "llama-index-llms-ollama", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "llama-index-llms-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "llama-index-vector-stores-faiss", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "nltk", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "ocrmypdf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -2944,6 +2913,7 @@ dependencies = [
|
||||
{ name = "scikit-learn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sentence-transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "setproctitle", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "sqlite-vec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tantivy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tika-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" },
|
||||
@@ -3062,7 +3032,6 @@ requires-dist = [
|
||||
{ name = "drf-spectacular", specifier = "~=0.28" },
|
||||
{ name = "drf-spectacular-sidecar", specifier = "~=2026.5.1" },
|
||||
{ name = "drf-writable-nested", specifier = "~=0.7.1" },
|
||||
{ name = "faiss-cpu", specifier = ">=1.10" },
|
||||
{ name = "filelock", specifier = "~=3.29.0" },
|
||||
{ name = "flower", specifier = "~=2.0.1" },
|
||||
{ name = "gotenberg-client", specifier = "~=0.14.0" },
|
||||
@@ -3078,7 +3047,6 @@ requires-dist = [
|
||||
{ name = "llama-index-embeddings-openai-like", specifier = ">=0.2.2" },
|
||||
{ name = "llama-index-llms-ollama", specifier = ">=0.9.1" },
|
||||
{ name = "llama-index-llms-openai-like", specifier = ">=0.7.1" },
|
||||
{ name = "llama-index-vector-stores-faiss", specifier = ">=0.5.2" },
|
||||
{ name = "mysqlclient", marker = "extra == 'mariadb'", specifier = "~=2.2.7" },
|
||||
{ name = "nltk", specifier = "~=3.9.1" },
|
||||
{ name = "ocrmypdf", specifier = "~=17.4.2" },
|
||||
@@ -3101,6 +3069,7 @@ requires-dist = [
|
||||
{ name = "scikit-learn", specifier = "~=1.8.0" },
|
||||
{ name = "sentence-transformers", specifier = ">=5.4.1" },
|
||||
{ name = "setproctitle", specifier = "~=1.3.4" },
|
||||
{ name = "sqlite-vec", specifier = "==0.1.9" },
|
||||
{ name = "tantivy", specifier = "~=0.26.0" },
|
||||
{ name = "tika-client", specifier = "~=0.11.0" },
|
||||
{ name = "torch", specifier = "~=2.11.0", index = "https://download.pytorch.org/whl/cpu" },
|
||||
@@ -4699,6 +4668,17 @@ asyncio = [
|
||||
{ name = "greenlet", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlite-vec"
|
||||
version = "0.1.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/85/9fad0045d8e7c8df3e0fa5a56c630e8e15ad6e5ca2e6106fceb666aa6638/sqlite_vec-0.1.9-py3-none-macosx_10_6_x86_64.whl", hash = "sha256:1b62a7f0a060d9475575d4e599bbf94a13d85af896bc1ce86ee80d1b5b48e5fb", size = 131171, upload-time = "2026-03-31T08:02:31.717Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/3d/3677e0cd2f92e5ebc43cd29fbf565b75582bff1ccfa0b8327c7508e1084f/sqlite_vec-0.1.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1d52e30513bae4cc9778ddbf6145610434081be4c3afe57cd877893bad9f6b6c", size = 165434, upload-time = "2026-03-31T08:02:32.712Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/d4/f2b936d3bdc38eadcbd2a87875815db36430fab0363182ba5d12cd8e0b51/sqlite_vec-0.1.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e921e592f24a5f9a18f590b6ddd530eb637e2d474e3b1972f9bbeb773aa3cb9", size = 160076, upload-time = "2026-03-31T08:02:33.796Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/ad/6afd073b0f817b3e03f9e37ad626ae341805891f23c74b5292818f49ac63/sqlite_vec-0.1.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux1_x86_64.whl", hash = "sha256:1515727990b49e79bcaf75fdee2ffc7d461f8b66905013231251f1c8938e7786", size = 163388, upload-time = "2026-03-31T08:02:34.888Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparse"
|
||||
version = "0.5.5"
|
||||
|
||||
Reference in New Issue
Block a user