From fa0c4368d774eed9f2c34da6f63dccb24c4038c1 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Fri, 5 Jun 2026 06:46:45 -0700 Subject: [PATCH 01/29] Fix: Ensure checksum comparison is using SHA256 in file handling (#12939) --- src/documents/signals/handlers.py | 5 ++--- src/documents/tests/test_file_handling.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index a34a3acf9..f85763d5f 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -1,7 +1,6 @@ from __future__ import annotations import datetime -import hashlib import logging import shutil import traceback as _tb @@ -54,6 +53,7 @@ from documents.models import WorkflowTrigger from documents.permissions import get_objects_for_user_owner_aware from documents.plugins.helpers import DocumentsStatusManager from documents.templating.utils import convert_format_str_to_template_format +from documents.utils import compute_checksum from documents.workflows.actions import build_workflow_action_context from documents.workflows.actions import execute_email_action from documents.workflows.actions import execute_move_to_trash_action @@ -410,8 +410,7 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool: if checksum is None or not path.is_file(): return False - with path.open("rb") as f: - return hashlib.md5(f.read()).hexdigest() == checksum + return compute_checksum(path) == checksum def _filename_template_uses_custom_fields(doc: Document) -> bool: diff --git a/src/documents/tests/test_file_handling.py b/src/documents/tests/test_file_handling.py index dc0fbb74c..2ae04b063 100644 --- a/src/documents/tests/test_file_handling.py +++ b/src/documents/tests/test_file_handling.py @@ -221,8 +221,8 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase): doc = Document.objects.create( title="document", mime_type="application/pdf", - checksum=hashlib.md5(original_bytes).hexdigest(), - archive_checksum=hashlib.md5(archive_bytes).hexdigest(), + checksum=hashlib.sha256(original_bytes).hexdigest(), + archive_checksum=hashlib.sha256(archive_bytes).hexdigest(), filename="old/document.pdf", archive_filename="old/document.pdf", storage_path=old_storage_path, From 449fd97b1f052c333470a7ae47896d702d59e04b Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Fri, 5 Jun 2026 07:16:53 -0700 Subject: [PATCH 02/29] Fix (beta): respect disable state for suggest endpoint, require change perms (#12942) --- .../suggestions-dropdown.component.html | 4 ++-- .../suggestions-dropdown.component.spec.ts | 12 ++++++++++++ .../suggestions-dropdown.component.ts | 8 ++++++++ src/documents/views.py | 4 ++-- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.html b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.html index 7e1a29666..8fb900eaa 100644 --- a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.html +++ b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.html @@ -1,5 +1,5 @@
- diff --git a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.spec.ts b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.spec.ts index 801a56af3..863393ace 100644 --- a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.spec.ts +++ b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.spec.ts @@ -37,6 +37,18 @@ describe('SuggestionsDropdownComponent', () => { expect(component.getSuggestions.emit).toHaveBeenCalled() }) + it('should not emit getSuggestions when disabled', () => { + jest.spyOn(component.getSuggestions, 'emit') + component.disabled = true + component.suggestions = null + fixture.detectChanges() + + component.clickSuggest() + + expect(component.getSuggestions.emit).not.toHaveBeenCalled() + expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy() + }) + it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => { component.aiEnabled = true fixture.detectChanges() diff --git a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.ts b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.ts index b165f0a5e..6cf63f683 100644 --- a/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.ts +++ b/src-ui/src/app/components/common/suggestions-dropdown/suggestions-dropdown.component.ts @@ -47,6 +47,14 @@ export class SuggestionsDropdownComponent { addCorrespondent: EventEmitter = new EventEmitter() public clickSuggest(): void { + if ( + this.disabled || + this.loading || + (this.suggestions && !this.aiEnabled) + ) { + return + } + if (!this.suggestions) { this.getSuggestions.emit(this) } else { diff --git a/src/documents/views.py b/src/documents/views.py index 511429129..4fc0b3f51 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1400,7 +1400,7 @@ class DocumentViewSet( ) if request.user is not None and not has_perms_owner_aware( request.user, - "view_document", + "change_document", doc, ): return HttpResponseForbidden("Insufficient permissions") @@ -1460,7 +1460,7 @@ class DocumentViewSet( ) if request.user is not None and not has_perms_owner_aware( request.user, - "view_document", + "change_document", doc, ): return HttpResponseForbidden("Insufficient permissions") From a7cec673bb37db8071a3171aa631a0be458c29dc Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sat, 6 Jun 2026 16:00:03 -0700 Subject: [PATCH 03/29] Fix (beta): correct chat message bg color (#12955) --- src-ui/src/app/components/chat/chat/chat.component.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src-ui/src/app/components/chat/chat/chat.component.html b/src-ui/src/app/components/chat/chat/chat.component.html index c5cada978..78cd28ca3 100644 --- a/src-ui/src/app/components/chat/chat/chat.component.html +++ b/src-ui/src/app/components/chat/chat/chat.component.html @@ -8,7 +8,7 @@
@for (message of messages; track message) {
-
+
{{ message.content }} @if (message.isStreaming) { | } From 3d0b8343b9d315598f9675ee51b0f0bb7e9921bb Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sat, 6 Jun 2026 20:42:06 -0700 Subject: [PATCH 04/29] Fixhancement (beta): tasks dismiss all (#12949) --- .../admin/tasks/tasks.component.html | 3 ++ .../admin/tasks/tasks.component.spec.ts | 43 ++++++++++++++++++- .../components/admin/tasks/tasks.component.ts | 24 +++++++++++ src-ui/src/app/services/tasks.service.spec.ts | 21 +++++++++ src-ui/src/app/services/tasks.service.ts | 14 ++++++ src/documents/serialisers.py | 26 ++++++++++- src/documents/tests/test_api_tasks.py | 21 +++++++++ src/documents/views.py | 15 +++++-- 8 files changed, 160 insertions(+), 7 deletions(-) diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.html b/src-ui/src/app/components/admin/tasks/tasks.component.html index 116d35f89..b8e4f3ff5 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.html +++ b/src-ui/src/app/components/admin/tasks/tasks.component.html @@ -11,6 +11,9 @@ +
diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts b/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts index 962895295..a87ec49b0 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts +++ b/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts @@ -11,7 +11,7 @@ import { Router } from '@angular/router' import { RouterTestingModule } from '@angular/router/testing' import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap' import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' -import { throwError } from 'rxjs' +import { of, throwError } from 'rxjs' import { routes } from 'src/app/app-routing.module' import { PaperlessTask, @@ -295,6 +295,7 @@ describe('TasksComponent', () => { const headerText = header.nativeElement.textContent expect(headerText).toContain('Dismiss visible') + expect(headerText).toContain('Dismiss all') expect(headerText).toContain('Auto refresh') expect(headerText).not.toContain('All types') expect(headerText).not.toContain('All sources') @@ -495,6 +496,46 @@ describe('TasksComponent', () => { expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466])) }) + it('should support dismiss all tasks', () => { + let modal: NgbModalRef + modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1])) + const dismissSpy = jest + .spyOn(tasksService, 'dismissAllTasks') + .mockReturnValue(of({})) + const reloadPageSpy = jest + .spyOn(component as any, 'reloadPage') + .mockImplementation(() => undefined) + + component.dismissAllTasks() + + expect(modal).not.toBeUndefined() + expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?') + modal.componentInstance.confirmClicked.emit() + expect(dismissSpy).toHaveBeenCalled() + expect(reloadPageSpy).toHaveBeenCalledWith(false) + expect(component.selectedTasks.size).toBe(0) + }) + + it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => { + const error = new Error('dismiss all failed') + const toastSpy = jest.spyOn(toastService, 'showError') + const dismissSpy = jest + .spyOn(tasksService, 'dismissAllTasks') + .mockReturnValue(throwError(() => error)) + + let modal: NgbModalRef + modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1])) + + component.dismissAllTasks() + expect(modal).not.toBeUndefined() + + modal.componentInstance.confirmClicked.emit() + + expect(dismissSpy).toHaveBeenCalled() + expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error) + expect(modal.componentInstance.buttonsEnabled).toBe(true) + }) + it('should dismiss the currently visible scoped and filtered tasks', () => { component.setSection(TaskSection.InProgress) component.setTaskType(PaperlessTaskType.SanityCheck) diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.ts b/src-ui/src/app/components/admin/tasks/tasks.component.ts index 884ede0d6..ed72a401d 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.ts +++ b/src-ui/src/app/components/admin/tasks/tasks.component.ts @@ -334,6 +334,30 @@ export class TasksComponent } } + dismissAllTasks() { + let modal = this.modalService.open(ConfirmDialogComponent, { + backdrop: 'static', + }) + modal.componentInstance.title = $localize`Confirm Dismiss All` + modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?` + modal.componentInstance.btnClass = 'btn-warning' + modal.componentInstance.btnCaption = $localize`Dismiss` + modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => { + modal.componentInstance.buttonsEnabled = false + modal.close() + this.tasksService.dismissAllTasks().subscribe({ + next: () => { + this.reloadPage(false) + }, + error: (e) => { + this.toastService.showError($localize`Error dismissing tasks`, e) + modal.componentInstance.buttonsEnabled = true + }, + }) + this.clearSelection() + }) + } + expandTask(task: PaperlessTask) { this.expandedTask = this.expandedTask == task.id ? undefined : task.id } diff --git a/src-ui/src/app/services/tasks.service.spec.ts b/src-ui/src/app/services/tasks.service.spec.ts index 3cc35232d..1ae217543 100644 --- a/src-ui/src/app/services/tasks.service.spec.ts +++ b/src-ui/src/app/services/tasks.service.spec.ts @@ -80,6 +80,27 @@ describe('TasksService', () => { .flush({ count: 0, results: [] }) }) + it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => { + tasksService.dismissAllTasks().subscribe() + const req = httpTestingController.expectOne( + `${environment.apiBaseUrl}tasks/acknowledge/` + ) + expect(req.request.method).toEqual('POST') + expect(req.request.body).toEqual({ + all: true, + }) + req.flush([]) + // reload is then called + httpTestingController + .expectOne( + (req: HttpRequest) => + req.url === `${environment.apiBaseUrl}tasks/` && + req.params.get('acknowledged') === 'false' && + req.params.get('page_size') === '1000' + ) + .flush({ count: 0, results: [] }) + }) + it('groups mixed task types by status when reloading', () => { expect(tasksService.total).toEqual(0) const mockTasks = [ diff --git a/src-ui/src/app/services/tasks.service.ts b/src-ui/src/app/services/tasks.service.ts index 1eb5e0837..a3ae283ed 100644 --- a/src-ui/src/app/services/tasks.service.ts +++ b/src-ui/src/app/services/tasks.service.ts @@ -116,6 +116,20 @@ export class TasksService { ) } + public dismissAllTasks(): Observable { + return this.http + .post(`${this.baseUrl}tasks/acknowledge/`, { + all: true, + }) + .pipe( + first(), + takeUntil(this.unsubscribeNotifer), + tap(() => { + this.reload() + }) + ) + } + public cancelPending(): void { this.unsubscribeNotifer.next(true) } diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 1ff3798db..82c18b703 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -2632,18 +2632,25 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]): class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]): tasks = serializers.ListField( - required=True, + required=False, label="Tasks", write_only=True, child=serializers.IntegerField(), ) + all = serializers.BooleanField( + required=False, + default=False, + label="All", + write_only=True, + ) def _validate_task_id_list(self, tasks, name="tasks") -> None: if not isinstance(tasks, list): raise serializers.ValidationError(f"{name} must be a list") if not all(isinstance(i, int) for i in tasks): raise serializers.ValidationError(f"{name} must be a list of integers") - count = PaperlessTask.objects.filter(id__in=tasks).count() + queryset = self.context.get("queryset", PaperlessTask.objects.all()) + count = queryset.filter(id__in=tasks).count() if not count == len(tasks): raise serializers.ValidationError( f"Some tasks in {name} don't exist or were specified twice.", @@ -2653,6 +2660,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]): self._validate_task_id_list(tasks) return tasks + def validate(self, attrs): + acknowledge_all = attrs.get("all", False) + task_ids = attrs.get("tasks") + + if acknowledge_all and task_ids is not None: + raise serializers.ValidationError( + "Set either all or tasks, not both.", + ) + if not acknowledge_all and task_ids is None: + raise serializers.ValidationError( + "Either all must be true or tasks must be provided.", + ) + + return attrs + class ShareLinkSerializer(OwnedObjectSerializer): class Meta: diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index f9b6c4538..42ccbab5c 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -522,6 +522,27 @@ class TestAcknowledge: assert response.status_code == status.HTTP_200_OK assert response.data == {"result": 2} + def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None: + """POST acknowledge/ with all=true acknowledges all unacknowledged tasks.""" + unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False) + unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False) + acknowledged_task = PaperlessTaskFactory(acknowledged=True) + + response = admin_client.post( + ENDPOINT + "acknowledge/", + {"all": True}, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"result": 2} + unacknowledged_task1.refresh_from_db() + unacknowledged_task2.refresh_from_db() + acknowledged_task.refresh_from_db() + assert unacknowledged_task1.acknowledged + assert unacknowledged_task2.acknowledged + assert acknowledged_task.acknowledged + def test_acknowledged_tasks_excluded_from_unacked_filter( self, admin_client: APIClient, diff --git a/src/documents/views.py b/src/documents/views.py index 4fc0b3f51..ba4faa622 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -4033,7 +4033,7 @@ class _TasksViewSetSchema(AutoSchema): ), acknowledge=extend_schema( operation_id="acknowledge_tasks", - description="Acknowledge a list of tasks", + description="Acknowledge a list of tasks, or all visible unacknowledged tasks", request=AcknowledgeTasksViewSerializer, responses={ (200, "application/json"): inline_serializer( @@ -4170,10 +4170,17 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions], ) def acknowledge(self, request): - serializer = AcknowledgeTasksViewSerializer(data=request.data) + queryset = self.get_queryset() + serializer = AcknowledgeTasksViewSerializer( + data=request.data, + context={"queryset": queryset}, + ) serializer.is_valid(raise_exception=True) - task_ids = serializer.validated_data.get("tasks") - tasks = self.get_queryset().filter(id__in=task_ids) + if serializer.validated_data.get("all", False): + tasks = queryset.filter(acknowledged=False) + else: + task_ids = serializer.validated_data.get("tasks") + tasks = queryset.filter(id__in=task_ids) count = tasks.update(acknowledged=True) return Response({"result": count}) From eb292baa6953c3ba49db1b776c415d95276dcb37 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Sun, 7 Jun 2026 11:31:26 -0700 Subject: [PATCH 05/29] Enhancement (beta): Switch the AI vector store to LanceDB (#12944) Co-authored-by: Claude Opus 4.8 (1M context) Co-authored-by: shamoon --- pyproject.toml | 4 +- .../management/commands/document_llmindex.py | 6 +- .../test_management_document_llmindex.py | 36 + src/documents/tests/test_api_app_config.py | 8 +- src/paperless/settings/__init__.py | 4 +- src/paperless/views.py | 4 +- src/paperless_ai/ai_classifier.py | 20 +- src/paperless_ai/chat.py | 122 +-- src/paperless_ai/embedding.py | 62 +- src/paperless_ai/indexing.py | 411 ++++------ src/paperless_ai/tests/conftest.py | 27 +- src/paperless_ai/tests/test_ai_classifier.py | 6 +- src/paperless_ai/tests/test_ai_indexing.py | 702 +++++++----------- src/paperless_ai/tests/test_chat.py | 238 +++--- src/paperless_ai/tests/test_embedding.py | 102 ++- src/paperless_ai/tests/test_lazy_imports.py | 25 + src/paperless_ai/tests/test_vector_store.py | 417 +++++++++++ src/paperless_ai/vector_store.py | 333 +++++++++ uv.lock | 139 +++- 19 files changed, 1606 insertions(+), 1060 deletions(-) create mode 100644 src/documents/tests/management/test_management_document_llmindex.py create mode 100644 src/paperless_ai/tests/test_lazy_imports.py create mode 100644 src/paperless_ai/tests/test_vector_store.py create mode 100644 src/paperless_ai/vector_store.py diff --git a/pyproject.toml b/pyproject.toml index e78457e15..68e19cb9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ dependencies = [ "drf-spectacular~=0.28", "drf-spectacular-sidecar~=2026.5.1", "drf-writable-nested~=0.7.1", - "faiss-cpu>=1.10", "filelock~=3.29.0", "flower~=2.0.1", "gotenberg-client~=0.14.0", @@ -50,6 +49,7 @@ dependencies = [ "ijson>=3.2", "imap-tools~=1.13.0", "jinja2~=3.1.5", + "lancedb~=0.33.0", "langdetect~=1.0.9", "llama-index-core>=0.14.21", "llama-index-embeddings-huggingface>=0.6.1", @@ -57,12 +57,12 @@ dependencies = [ "llama-index-embeddings-openai-like>=0.2.2", "llama-index-llms-ollama>=0.9.1", "llama-index-llms-openai-like>=0.7.1", - "llama-index-vector-stores-faiss>=0.5.2", "nltk~=3.9.1", "ocrmypdf~=17.4.2", "openai>=2.32", "pathvalidate~=3.3.1", "pdf2image~=1.17.0", + "pyarrow>=16", "python-dateutil~=2.9.0", "python-dotenv~=1.2.1", "python-gnupg~=0.5.4", diff --git a/src/documents/management/commands/document_llmindex.py b/src/documents/management/commands/document_llmindex.py index 9823b1b87..7b34ca9a8 100644 --- a/src/documents/management/commands/document_llmindex.py +++ b/src/documents/management/commands/document_llmindex.py @@ -2,6 +2,7 @@ from typing import Any from documents.management.commands.base import PaperlessCommand from documents.tasks import llmindex_index +from paperless_ai.indexing import llm_index_compact class Command(PaperlessCommand): @@ -12,9 +13,12 @@ class Command(PaperlessCommand): def add_arguments(self, parser: Any) -> None: super().add_arguments(parser) - parser.add_argument("command", choices=["rebuild", "update"]) + parser.add_argument("command", choices=["rebuild", "update", "compact"]) def handle(self, *args: Any, **options: Any) -> None: + if options["command"] == "compact": + llm_index_compact() + return llmindex_index( rebuild=options["command"] == "rebuild", iter_wrapper=lambda docs: self.track( diff --git a/src/documents/tests/management/test_management_document_llmindex.py b/src/documents/tests/management/test_management_document_llmindex.py new file mode 100644 index 000000000..b8a05dd85 --- /dev/null +++ b/src/documents/tests/management/test_management_document_llmindex.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.core.management import call_command + +if TYPE_CHECKING: + from pytest_mock import MockerFixture + +_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact" +_INDEX = "documents.management.commands.document_llmindex.llmindex_index" + + +class TestDocumentLlmindexCommand: + def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None: + mock_compact = mocker.patch(_COMPACT) + call_command("document_llmindex", "compact") + mock_compact.assert_called_once_with() + + def test_rebuild_calls_llmindex_index_with_rebuild_true( + self, + mocker: MockerFixture, + ) -> None: + mock_index = mocker.patch(_INDEX) + call_command("document_llmindex", "rebuild") + mock_index.assert_called_once() + assert mock_index.call_args.kwargs["rebuild"] is True + + def test_update_calls_llmindex_index_with_rebuild_false( + self, + mocker: MockerFixture, + ) -> None: + mock_index = mocker.patch(_INDEX) + call_command("document_llmindex", "update") + mock_index.assert_called_once() + assert mock_index.call_args.kwargs["rebuild"] is False diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py index 2418236bd..9b94fff17 100644 --- a/src/documents/tests/test_api_app_config.py +++ b/src/documents/tests/test_api_app_config.py @@ -844,7 +844,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): with ( patch("documents.tasks.llmindex_index.apply_async") as mock_update, - patch("paperless.views.vector_store_file_exists") as mock_exists, + patch("paperless.views.llm_index_exists") as mock_exists, ): mock_exists.return_value = False self.client.patch( @@ -869,7 +869,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): with ( patch("documents.tasks.llmindex_index.apply_async") as mock_update, - patch("paperless.views.vector_store_file_exists") as mock_exists, + patch("paperless.views.llm_index_exists") as mock_exists, ): mock_exists.return_value = True self.client.patch( @@ -890,7 +890,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): with ( patch("documents.tasks.llmindex_index.apply_async") as mock_update, - patch("paperless.views.vector_store_file_exists") as mock_exists, + patch("paperless.views.llm_index_exists") as mock_exists, ): mock_exists.return_value = True self.client.patch( @@ -928,7 +928,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase): with ( patch("documents.tasks.llmindex_index.apply_async") as mock_update, - patch("paperless.views.vector_store_file_exists") as mock_exists, + patch("paperless.views.llm_index_exists") as mock_exists, ): mock_exists.return_value = True self.client.patch( diff --git a/src/paperless/settings/__init__.py b/src/paperless/settings/__init__.py index 96b2279a7..1352388f7 100644 --- a/src/paperless/settings/__init__.py +++ b/src/paperless/settings/__init__.py @@ -97,8 +97,7 @@ MODEL_FILE = get_path_from_env( DATA_DIR / "classification_model.pickle", ) LLM_INDEX_DIR = DATA_DIR / "llm_index" -LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock" -(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True) +LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock" LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log") @@ -644,6 +643,7 @@ LOGGING = { "kombu": {"handlers": ["file_celery"], "level": "DEBUG"}, "_granian": {"handlers": ["file_paperless"], "level": "DEBUG"}, "granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"}, + "httpx": {"level": "WARNING"}, }, } diff --git a/src/paperless/views.py b/src/paperless/views.py index 9ed4a2a87..2e3fb4d82 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer from paperless.serialisers import PaperlessAuthTokenSerializer from paperless.serialisers import ProfileSerializer from paperless.serialisers import UserSerializer -from paperless_ai.indexing import vector_store_file_exists +from paperless_ai.indexing import llm_index_exists class PaperlessObtainAuthTokenView(ObtainAuthToken): @@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]): or old_llm_context_size != new_llm_context_size ) rebuild_needed = new_ai_index_enabled and ( - not vector_store_file_exists() or embedding_config_changed + not llm_index_exists() or embedding_config_changed ) if rebuild_needed: diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index c3e27cd41..5420812eb 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -24,9 +24,14 @@ def get_language_name(language_code: str) -> str: def build_prompt_without_rag( document: Document, + config: AIConfig, ) -> str: filename = document.filename or "" - content = truncate_content(document.content[:4000] or "") + content = truncate_content( + document.content[:4000] or "", + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) return f""" You are a document classification assistant. @@ -49,10 +54,15 @@ def build_prompt_without_rag( def build_prompt_with_rag( document: Document, + config: AIConfig, user: User | None = None, ) -> str: - base_prompt = build_prompt_without_rag(document) - context = truncate_content(get_context_for_document(document, user)) + base_prompt = build_prompt_without_rag(document, config) + context = truncate_content( + get_context_for_document(document, user), + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) return f"""{base_prompt} @@ -130,9 +140,9 @@ def get_ai_document_classification( ai_config = AIConfig() prompt = ( - build_prompt_with_rag(document, user) + build_prompt_with_rag(document, ai_config, user) if ai_config.llm_embedding_backend - else build_prompt_without_rag(document) + else build_prompt_without_rag(document, ai_config) ) client = AIClient() diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index b2710c379..123771c50 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -3,7 +3,9 @@ import logging import sys from documents.models import Document +from paperless.config import AIConfig from paperless_ai.client import AIClient +from paperless_ai.indexing import _document_id_filters from paperless_ai.indexing import get_rag_prompt_helper from paperless_ai.indexing import load_or_build_index @@ -75,134 +77,54 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str ) -def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int): - from llama_index.core.base.base_retriever import BaseRetriever - from llama_index.core.schema import NodeWithScore - from llama_index.core.vector_stores import VectorStoreQuery - - class DocumentFilteredFaissRetriever(BaseRetriever): - def __init__(self): - super().__init__() - self._cached_query_str = None - self._cached_nodes = [] - - def _retrieve(self, query_bundle): - if query_bundle.query_str == self._cached_query_str: - return self._cached_nodes - - if query_bundle.embedding is None: - query_bundle.embedding = ( - index._embed_model.get_agg_embedding_from_queries( - query_bundle.embedding_strs, - ) - ) - - faiss_index = index.vector_store._faiss_index - max_top_k = faiss_index.ntotal - if max_top_k == 0: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = [] - return [] - - query_top_k = min(max(similarity_top_k, 1), max_top_k) - allowed_nodes: list[NodeWithScore] = [] - seen_node_ids: set[str] = set() - - while query_top_k <= max_top_k: - query_result = index.vector_store.query( - VectorStoreQuery( - query_embedding=query_bundle.embedding, - similarity_top_k=query_top_k, - ), - ) - - for vector_id, score in zip( - query_result.ids or [], - query_result.similarities or [], - strict=False, - ): - node_id = index.index_struct.nodes_dict.get(vector_id) - if node_id is None or node_id in seen_node_ids: - continue - - node = index.docstore.docs.get(node_id) - if node is None or node.metadata.get("document_id") not in doc_ids: - continue - - seen_node_ids.add(node_id) - allowed_nodes.append(NodeWithScore(node=node, score=score)) - - if len(allowed_nodes) >= similarity_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - if query_top_k == max_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - query_top_k = min(query_top_k * 2, max_top_k) - - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - return DocumentFilteredFaissRetriever() - - def stream_chat_with_documents(query_str: str, documents: list[Document]): try: yield from _stream_chat_with_documents(query_str, documents) except Exception as e: - logger.exception(f"Failed to stream document chat response: {e}", exc_info=True) + logger.exception("Failed to stream document chat response: %s", e) yield CHAT_ERROR_MESSAGE def _stream_chat_with_documents(query_str: str, documents: list[Document]): - client = AIClient() - index = load_or_build_index() - - doc_ids = [str(doc.pk) for doc in documents] - - # Filter only the node(s) that match the document IDs - nodes = [ - node - for node in index.docstore.docs.values() - if node.metadata.get("document_id") in doc_ids - ] - - if len(nodes) == 0: - logger.warning("No nodes found for the given documents.") + if not documents: yield CHAT_NO_CONTENT_MESSAGE return from llama_index.core.prompts import PromptTemplate from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.response_synthesizers import get_response_synthesizer + from llama_index.core.retrievers import VectorIndexRetriever - retriever = _get_document_filtered_retriever( - index, - set(doc_ids), - CHAT_RETRIEVER_TOP_K, + config = AIConfig() + index = load_or_build_index(config) + filters = _document_id_filters(str(doc.pk) for doc in documents) + + retriever = VectorIndexRetriever( + index=index, + similarity_top_k=CHAT_RETRIEVER_TOP_K, + filters=filters, ) top_nodes = retriever.retrieve(query_str) - if len(top_nodes) == 0: - logger.warning("Retriever returned no nodes for the given documents.") + if not top_nodes: + logger.warning("No nodes found for the given documents.") yield CHAT_NO_CONTENT_MESSAGE return + client = AIClient() + references = _get_document_references(documents, top_nodes) prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL) response_synthesizer = get_response_synthesizer( llm=client.llm, - prompt_helper=get_rag_prompt_helper(), + prompt_helper=get_rag_prompt_helper( + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ), text_qa_template=prompt_template, streaming=True, ) - query_engine = RetrieverQueryEngine.from_args( retriever=retriever, llm=client.llm, @@ -211,9 +133,7 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): ) logger.debug("Document chat query: %s", query_str) - response_stream = query_engine.query(query_str) - for chunk in response_stream.response_gen: yield chunk sys.stdout.flush() diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index 2695e9fb3..88ea80293 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -1,12 +1,9 @@ -import json import re from typing import TYPE_CHECKING from django.conf import settings if TYPE_CHECKING: - from pathlib import Path - from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document @@ -23,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}") HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+") -def get_embedding_model() -> "BaseEmbedding": - config = AIConfig() - +def get_embedding_model(config: AIConfig) -> "BaseEmbedding": match config.llm_embedding_backend: case LLMEmbeddingBackend.OPENAI_LIKE: from llama_index.embeddings.openai_like import OpenAILikeEmbedding @@ -95,41 +90,20 @@ def get_embedding_model() -> "BaseEmbedding": ) -def get_embedding_dim() -> int: - """ - Loads embedding dimension from meta.json if available, otherwise infers it - from a dummy embedding and stores it for future use. - """ - config = AIConfig() - default_model = { - LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small", - LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2", - LLMEmbeddingBackend.OLLAMA: "embeddinggemma", - }.get( +_DEFAULT_MODEL_NAMES = { + LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small", + LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2", + LLMEmbeddingBackend.OLLAMA: "embeddinggemma", +} + + +def get_configured_model_name(config: AIConfig) -> str: + """Return the canonical name of the currently configured embedding model.""" + default = _DEFAULT_MODEL_NAMES.get( config.llm_embedding_backend, "sentence-transformers/all-MiniLM-L6-v2", ) - model = config.llm_embedding_model or default_model - - meta_path: Path = settings.LLM_INDEX_DIR / "meta.json" - if meta_path.exists(): - with meta_path.open() as f: - meta = json.load(f) - if meta.get("embedding_model") != model: - raise RuntimeError( - f"Embedding model changed from {meta.get('embedding_model')} to {model}. " - "You must rebuild the index.", - ) - return meta["dim"] - - embedding_model = get_embedding_model() - test_embed = embedding_model.get_text_embedding("test") - dim = len(test_embed) - - with meta_path.open("w") as f: - json.dump({"embedding_model": model, "dim": dim}, f) - - return dim + return config.llm_embedding_model or default def _normalize_llm_index_text(text: str) -> str: @@ -138,15 +112,13 @@ def _normalize_llm_index_text(text: str) -> str: def build_llm_index_text(doc: Document) -> str: + # TODO: Filename, Storage Path, and Archive Serial Number are short structured + # values that could move to node.metadata (excluded from embeddings, visible to + # LLM via metadata prepend) — same pattern as title/tags/correspondent. Notes + # and Custom Fields should stay here: Notes can be long free text, Custom Fields + # are dynamic in count and best kept in the embedding. lines = [ - f"Title: {doc.title}", f"Filename: {doc.filename}", - f"Created: {doc.created}", - f"Added: {doc.added}", - f"Modified: {doc.modified}", - f"Tags: {', '.join(tag.name for tag in doc.tags.all())}", - f"Document Type: {doc.document_type.name if doc.document_type else ''}", - f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}", f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}", f"Archive Serial Number: {doc.archive_serial_number or ''}", f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}", diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 7ec1fdba3..dd96106a6 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -1,9 +1,7 @@ import logging -import shutil -from collections import defaultdict from collections.abc import Iterable +from contextlib import contextmanager from datetime import timedelta -from pathlib import Path from typing import TYPE_CHECKING from django.conf import settings @@ -16,35 +14,28 @@ from documents.utils import IterWrapper from documents.utils import identity from paperless.config import AIConfig from paperless_ai.embedding import build_llm_index_text -from paperless_ai.embedding import get_embedding_dim +from paperless_ai.embedding import get_configured_model_name from paperless_ai.embedding import get_embedding_model if TYPE_CHECKING: - from llama_index.core import VectorStoreIndex from llama_index.core.schema import BaseNode + from paperless_ai.vector_store import PaperlessLanceVectorStore + logger = logging.getLogger("paperless_ai.indexing") +LLM_INDEX_TABLE = "documents" + RAG_NUM_OUTPUT = 512 RAG_CHUNK_OVERLAP = 200 -def _index_lock_path() -> Path: - """Return the path used as the file lock for FAISS index mutations. - - The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a - rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock - while another worker still holds it. - """ - return settings.LLM_INDEX_LOCK - - def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool: # NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two # concurrent workers can both observe no running task and both enqueue a # full rebuild. This is wasteful but not data-corrupting — update_llm_index - # is itself protected by _index_lock_path(), so only one rebuild runs at a + # is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a # time and the second one is serialised after the first completes. from documents.tasks import llmindex_index @@ -71,46 +62,38 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool: return True -def get_or_create_storage_context(*, rebuild=False): - """ - Loads or creates the StorageContext (vector store, docstore, index store). - If rebuild=True, deletes and recreates everything. - """ - if rebuild: - shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True) - settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) +def get_vector_store() -> "PaperlessLanceVectorStore": + from paperless_ai.vector_store import PaperlessLanceVectorStore - if rebuild or not settings.LLM_INDEX_DIR.exists(): - import faiss - from llama_index.core import StorageContext - from llama_index.core.storage.docstore import SimpleDocumentStore - from llama_index.core.storage.index_store import SimpleIndexStore - from llama_index.vector_stores.faiss import FaissVectorStore - - settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) - embedding_dim = get_embedding_dim() - faiss_index = faiss.IndexFlatL2(embedding_dim) - vector_store = FaissVectorStore(faiss_index=faiss_index) - docstore = SimpleDocumentStore() - index_store = SimpleIndexStore() - else: - from llama_index.core import StorageContext - from llama_index.core.storage.docstore import SimpleDocumentStore - from llama_index.core.storage.index_store import SimpleIndexStore - from llama_index.vector_stores.faiss import FaissVectorStore - - vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR) - docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR) - index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR) - - return StorageContext.from_defaults( - docstore=docstore, - index_store=index_store, - vector_store=vector_store, - persist_dir=settings.LLM_INDEX_DIR, + settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) + return PaperlessLanceVectorStore( + uri=str(settings.LLM_INDEX_DIR), + table_name=LLM_INDEX_TABLE, ) +@contextmanager +def write_store(embed_model_name: str | None = None): + """Acquire the write lock and yield the vector store. + + All mutating operations (upsert, delete, rebuild, compact) must go through + this context manager to serialise concurrent Celery writers. + Read paths use ``get_vector_store()`` directly — no lock needed. + + Pass ``embed_model_name`` whenever the operation may create the table so + the model name is recorded in the schema metadata for future mismatch checks. + """ + from paperless_ai.vector_store import PaperlessLanceVectorStore + + settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) + with FileLock(settings.LLM_INDEX_LOCK): + yield PaperlessLanceVectorStore( + uri=str(settings.LLM_INDEX_DIR), + table_name=LLM_INDEX_TABLE, + embed_model_name=embed_model_name, + ) + + def build_document_node( document: Document, *, @@ -142,9 +125,11 @@ def build_document_node( # the token count and exceed embedding models with small context windows # (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048). doc = LlamaDocument( + id_=str(document.id), text=text, metadata=metadata, excluded_embed_metadata_keys=list(metadata.keys()), + excluded_llm_metadata_keys=["document_id"], ) chunk_size = chunk_size or get_rag_chunk_size() parser = SimpleNodeParser( @@ -154,76 +139,29 @@ def build_document_node( return parser.get_nodes_from_documents([doc]) -def load_or_build_index(nodes=None): - """ - Load an existing VectorStoreIndex if present, - or build a new one using provided nodes if storage is empty. - """ +def load_or_build_index(config: AIConfig): + """Return a VectorStoreIndex backed by the vector store.""" import llama_index.core.settings as llama_settings from llama_index.core import VectorStoreIndex - from llama_index.core import load_index_from_storage - embed_model = get_embedding_model() + embed_model = get_embedding_model(config) llama_settings.Settings.embed_model = embed_model - storage_context = get_or_create_storage_context() - try: - return load_index_from_storage(storage_context=storage_context) - except ValueError as e: - logger.warning("Failed to load index from storage: %s", e) - if not nodes: - queue_llm_index_update_if_needed( - rebuild=vector_store_file_exists(), - reason="LLM index missing or invalid while loading.", - ) - logger.info("No nodes provided for index creation.") - raise - return VectorStoreIndex( - nodes=nodes, - storage_context=storage_context, - embed_model=embed_model, - ) + vector_store = get_vector_store() + return VectorStoreIndex.from_vector_store( + vector_store=vector_store, + embed_model=embed_model, + ) -def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"): - """ - Removes existing documents from docstore for a given document from the index. - This is necessary because FAISS IndexFlatL2 is append-only. - """ - all_node_ids = list(index.docstore.docs.keys()) - existing_nodes = [ - node.node_id - for node in index.docstore.get_nodes(all_node_ids) - if node.metadata.get("document_id") == str(document.id) - ] - for node_id in existing_nodes: - # Delete from docstore, FAISS IndexFlatL2 are append-only - index.docstore.delete_document(node_id) - # Also purge the FAISS position -> UUID mapping so subsequent similarity - # queries don't raise KeyError on ghost vector positions. - stale_keys = [ - k for k, v in index.index_struct.nodes_dict.items() if v == node_id - ] - for key in stale_keys: - del index.index_struct.nodes_dict[key] - # Re-sync the mutated index_struct so persist() writes the updated nodes_dict. - index.storage_context.index_store.add_index_struct(index.index_struct) - - -def vector_store_file_exists(): - """ - Check if the vector store file exists in the LLM index directory. - """ - return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists() +def llm_index_exists() -> bool: + """True when the index table exists on disk.""" + return get_vector_store().table_exists() def get_rag_chunk_size() -> int: return AIConfig().llm_embedding_chunk_size -def get_rag_context_size() -> int: - return AIConfig().llm_context_size - - def get_rag_chunk_overlap(chunk_size: int | None = None) -> int: chunk_size = chunk_size or get_rag_chunk_size() return min(RAG_CHUNK_OVERLAP, chunk_size - 1) @@ -249,123 +187,125 @@ def get_rag_prompt_helper( ) +def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None: + """Embed ``nodes`` in place using ``embed_model``.""" + from llama_index.core.schema import MetadataMode + + texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes] + for node, emb in zip( + nodes, + embed_model.get_text_embedding_batch(texts), + strict=True, + ): + node.embedding = emb + + +def _document_id_filters(doc_ids): + """Return a MetadataFilters IN filter scoped to ``doc_ids``.""" + from llama_index.core.vector_stores.types import FilterOperator + from llama_index.core.vector_stores.types import MetadataFilter + from llama_index.core.vector_stores.types import MetadataFilters + + return MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=sorted(doc_ids), + ), + ], + ) + + def update_llm_index( *, iter_wrapper: IterWrapper[Document] = identity, rebuild=False, ) -> str: - """ - Rebuild or update the LLM index. - """ - from llama_index.core import VectorStoreIndex - - nodes = [] - + """Rebuild or incrementally update the LLM index.""" documents = Document.objects.all() - if not documents.exists(): + no_documents = not documents.exists() + + # Fast exit before touching config: nothing to index and no existing index. + if no_documents and not rebuild and not llm_index_exists(): logger.warning("No documents found to index.") - if not rebuild and not vector_store_file_exists(): - return "No documents found to index." + return "No documents found to index." config = AIConfig() + model_name = get_configured_model_name(config) + + if ( + not rebuild + and llm_index_exists() + and get_vector_store().config_mismatch(model_name) + ): + logger.warning("Embedding model changed; forcing LLM index rebuild.") + rebuild = True + + if no_documents: + logger.warning("No documents found to index.") + chunk_size = config.llm_embedding_chunk_size + embed_model = get_embedding_model(config) - with FileLock(_index_lock_path()): - if rebuild or not vector_store_file_exists(): - # remove meta.json to force re-detection of embedding dim + with write_store(embed_model_name=model_name) as store: + if rebuild or not store.table_exists(): (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True) - # Rebuild index from scratch logger.info("Rebuilding LLM index.") - import llama_index.core.settings as llama_settings - - embed_model = get_embedding_model() - llama_settings.Settings.embed_model = embed_model - storage_context = get_or_create_storage_context(rebuild=True) + store.drop_table() for document in iter_wrapper(documents): - document_nodes = build_document_node(document, chunk_size=chunk_size) - nodes.extend(document_nodes) - - index = VectorStoreIndex( - nodes=nodes, - storage_context=storage_context, - embed_model=embed_model, - show_progress=False, - ) + nodes = build_document_node(document, chunk_size=chunk_size) + _embed_nodes(nodes, embed_model) + store.add(nodes) msg = "LLM index rebuilt successfully." else: - # Update existing index - index = load_or_build_index() - existing_nodes: defaultdict[str, list] = defaultdict(list) - for node in index.docstore.docs.values(): - doc_id = node.metadata.get("document_id") - if doc_id is not None: - existing_nodes[doc_id].append(node) - + existing = store.get_modified_times() + changed = 0 for document in iter_wrapper(documents): doc_id = str(document.id) - document_modified = document.modified.isoformat() + if existing.get(doc_id) == document.modified.isoformat(): + continue + nodes = build_document_node(document, chunk_size=chunk_size) + _embed_nodes(nodes, embed_model) + store.upsert_document(doc_id, nodes) + changed += 1 + msg = ( + "LLM index updated successfully." + if changed + else "No changes detected in LLM index." + ) - if doc_id in existing_nodes: - doc_nodes = existing_nodes[doc_id] - node_modified = doc_nodes[0].metadata.get("modified") - - if node_modified == document_modified: - continue - - # Delete from docstore, FAISS IndexFlatL2 are append-only - for node in doc_nodes: - remove_document_docstore_nodes(document, index) - - nodes.extend(build_document_node(document, chunk_size=chunk_size)) - - if nodes: - msg = "LLM index updated successfully." - logger.info( - "Updating %d nodes in LLM index.", - len(nodes), - ) - index.insert_nodes(nodes) - else: - msg = "No changes detected in LLM index." - logger.info(msg) - - index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR) + store.ensure_document_id_scalar_index() + store.maybe_create_ann_index() + store.compact(retention_seconds=60 * 60) # 1 hour: safe for in-flight readers return msg def llm_index_add_or_update_document(document: Document): - """ - Adds or updates a document in the LLM index. - If the document already exists, it will be replaced. - """ - new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size()) - if not new_nodes: - logger.warning( - "No indexable content for document %s; skipping LLM index update.", - document.pk, - ) - return + """Add or atomically replace a document's chunks in the index.""" + config = AIConfig() + new_nodes = build_document_node( + document, + chunk_size=config.llm_embedding_chunk_size, + ) + if new_nodes: + _embed_nodes(new_nodes, get_embedding_model(config)) - with FileLock(_index_lock_path()): - index = load_or_build_index(nodes=new_nodes) + with write_store(embed_model_name=get_configured_model_name(config)) as store: + store.upsert_document(str(document.id), new_nodes) + store.ensure_document_id_scalar_index() - remove_document_docstore_nodes(document, index) - index.insert_nodes(new_nodes) - - index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR) +def llm_index_compact() -> None: + """Compact the index immediately, clearing all MVCC version history.""" + with write_store() as store: + store.compact(retention_seconds=0) def llm_index_remove_document(document: Document): - """ - Removes a document from the LLM index. - """ - with FileLock(_index_lock_path()): - index = load_or_build_index() - - remove_document_docstore_nodes(document, index) - - index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR) + """Remove a document's chunks from the LLM index.""" + with write_store() as store: + store.delete(str(document.id)) def truncate_content( @@ -410,77 +350,54 @@ def query_similar_documents( top_k: int = 5, document_ids: Iterable[int | str] | None = None, ) -> list[Document]: - """ - Runs a similarity query and returns top-k similar Document objects. - """ + """Return up to ``top_k`` Documents most similar to ``document``.""" allowed_document_ids = normalize_document_ids(document_ids) if allowed_document_ids is not None and not allowed_document_ids: return [] - if not vector_store_file_exists(): + if not llm_index_exists(): queue_llm_index_update_if_needed( rebuild=False, reason="LLM index not found for similarity query.", ) return [] - with FileLock(_index_lock_path()): - index = load_or_build_index() + config = AIConfig() - # constrain only the node(s) that match the document IDs, if given - doc_node_ids = ( - [ - node.node_id - for node in index.docstore.docs.values() - if node.metadata.get("document_id") in allowed_document_ids - ] - if allowed_document_ids is not None - else None - ) - if doc_node_ids is not None and not doc_node_ids: - return [] + from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.core.retrievers import VectorIndexRetriever + index = load_or_build_index(config) - retriever = VectorIndexRetriever( - index=index, - similarity_top_k=top_k, - doc_ids=doc_node_ids, - ) + filters = ( + _document_id_filters(allowed_document_ids) + if allowed_document_ids is not None + else None + ) - config = AIConfig() - query_text = truncate_content( - (document.title or "") + "\n" + (document.content or ""), - chunk_size=config.llm_embedding_chunk_size, - context_size=config.llm_context_size, - ) - try: - results = retriever.retrieve(query_text) - except KeyError as e: - # Ghost FAISS positions remain after deletion because IndexFlatL2 is - # append-only. Treat them as absent and return no results. - logger.debug( - "Skipping LLM similarity query for document %s due to a stale " - "FAISS position with no docstore node: %s", - document.pk, - e, - ) - return [] + retriever = VectorIndexRetriever( + index=index, + similarity_top_k=top_k, + filters=filters, + ) + + query_text = truncate_content( + (document.title or "") + "\n" + (document.content or ""), + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) + results = retriever.retrieve(query_text) retrieved_document_ids: list[int] = [] for node in results: document_id = node.metadata.get("document_id") if document_id is None: continue - normalized_document_id = str(document_id) - if ( - allowed_document_ids is not None - and normalized_document_id not in allowed_document_ids - ): + normalized = str(document_id) + if allowed_document_ids is not None and normalized not in allowed_document_ids: continue try: - retrieved_document_ids.append(int(normalized_document_id)) - except ValueError: + retrieved_document_ids.append(int(normalized)) + except ValueError: # pragma: no cover logger.warning( "Skipping LLM index result with invalid document_id %r.", document_id, diff --git a/src/paperless_ai/tests/conftest.py b/src/paperless_ai/tests/conftest.py index 2d71476c7..6a7abf7ec 100644 --- a/src/paperless_ai/tests/conftest.py +++ b/src/paperless_ai/tests/conftest.py @@ -1,10 +1,35 @@ from pathlib import Path import pytest +import pytest_mock +from llama_index.core.base.embeddings.base import BaseEmbedding from pytest_django.fixtures import SettingsWrapper @pytest.fixture -def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper): +def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path: settings.LLM_INDEX_DIR = tmp_path + settings.LLM_INDEX_LOCK = tmp_path / "index.lock" return tmp_path + + +class FakeEmbedding(BaseEmbedding): + async def _aget_query_embedding(self, query: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def _get_query_embedding(self, query: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def _get_text_embedding(self, text: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def get_query_embedding_dim(self) -> int: + return 384 + + +@pytest.fixture +def mock_embed_model(mocker: pytest_mock.MockerFixture) -> pytest_mock.MockType: + fake = FakeEmbedding() + mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake) + mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake) + return fake diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 97e18eb47..45822b14b 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -6,6 +6,7 @@ import pytest from django.test import override_settings from documents.models import Document +from paperless.config import AIConfig from paperless_ai.ai_classifier import build_localization_prompt from paperless_ai.ai_classifier import build_prompt_with_rag from paperless_ai.ai_classifier import build_prompt_without_rag @@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document): "paperless_ai.ai_classifier.get_context_for_document", return_value="Context from similar documents", ): - prompt = build_prompt_without_rag(mock_document) + config = AIConfig() + prompt = build_prompt_without_rag(mock_document, config) assert "Additional context from similar documents" not in prompt assert "for generated" not in prompt - prompt = build_prompt_with_rag(mock_document) + prompt = build_prompt_with_rag(mock_document, config) assert "Additional context from similar documents" in prompt prompt = build_localization_prompt( diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index 339d75ead..31e1f6bc8 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -5,11 +5,8 @@ from unittest.mock import patch import pytest import pytest_mock -from django.contrib.auth.models import User from django.test import override_settings from django.utils import timezone -from faker import Faker -from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import PaperlessTask @@ -19,10 +16,11 @@ from documents.tests.factories import DocumentFactory from documents.tests.factories import PaperlessTaskFactory from paperless.models import ApplicationConfiguration from paperless_ai import indexing +from paperless_ai.tests.conftest import FakeEmbedding @pytest.fixture -def real_document(db): +def real_document(db: None) -> Document: return Document.objects.create( title="Test Document", content="This is some test content.", @@ -30,44 +28,29 @@ def real_document(db): ) -@pytest.fixture -def mock_embed_model(): - fake = FakeEmbedding() - with ( - patch("paperless_ai.indexing.get_embedding_model") as mock_index, - patch( - "paperless_ai.embedding.get_embedding_model", - ) as mock_embedding, - ): - mock_index.return_value = fake - mock_embedding.return_value = fake - yield mock_index - - -class FakeEmbedding(BaseEmbedding): - # TODO: maybe a better way to do this? - def _aget_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_text_embedding(self, text: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def get_query_embedding_dim(self) -> int: - return 384 # Match your real FAISS config - - @pytest.mark.django_db -def test_build_document_node(real_document) -> None: +def test_build_document_node(real_document: Document) -> None: nodes = indexing.build_document_node(real_document) assert len(nodes) > 0 assert nodes[0].metadata["document_id"] == str(real_document.id) @pytest.mark.django_db -def test_build_document_node_excludes_metadata_from_embedding(real_document) -> None: +def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None: + """Every node produced by build_document_node must carry the paperless document id + as its ref_doc_id so that the LanceDB adapter's delete(str(doc.id)) works correctly.""" + nodes = indexing.build_document_node(real_document) + assert len(nodes) > 0, "Expected at least one node" + for node in nodes: + assert node.ref_doc_id == str(real_document.id), ( + f"Expected ref_doc_id={real_document.id!r}, got {node.ref_doc_id!r}" + ) + + +@pytest.mark.django_db +def test_build_document_node_excludes_metadata_from_embedding( + real_document: Document, +) -> None: """Metadata keys must not be prepended to the embedding text. build_llm_index_text already encodes all metadata in the body text, so @@ -87,7 +70,38 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) -> @pytest.mark.django_db -def test_build_document_node_uses_rag_chunk_settings(real_document) -> None: +def test_build_document_node_structured_fields_in_metadata( + real_document: Document, +) -> None: + """Structured fields must be in node.metadata so the LLM receives them via metadata prepend.""" + nodes = indexing.build_document_node(real_document) + assert len(nodes) > 0 + for node in nodes: + assert "title" in node.metadata + assert "tags" in node.metadata + assert "correspondent" in node.metadata + assert "document_type" in node.metadata + assert "created" in node.metadata + assert "added" in node.metadata + assert "modified" in node.metadata + + +@pytest.mark.django_db +def test_build_document_node_excludes_document_id_from_llm_context( + real_document: Document, +) -> None: + """document_id is an internal key and must not appear in LLM context text.""" + from llama_index.core.schema import MetadataMode + + nodes = indexing.build_document_node(real_document) + assert len(nodes) > 0 + for node in nodes: + assert "document_id" in node.excluded_llm_metadata_keys + assert "document_id" not in node.get_content(metadata_mode=MetadataMode.LLM) + + +@pytest.mark.django_db +def test_build_document_node_uses_rag_chunk_settings(real_document: Document) -> None: app_config, _ = ApplicationConfiguration.objects.get_or_create() app_config.llm_embedding_chunk_size = 512 app_config.save() @@ -118,9 +132,9 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None: @pytest.mark.django_db def test_update_llm_index( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, ) -> None: mock_config = MagicMock() mock_config.llm_embedding_chunk_size = 512 @@ -138,19 +152,18 @@ def test_update_llm_index( ai_config.assert_called_once() build_document_node.assert_called_once_with(real_document, chunk_size=512) - assert any(temp_llm_index_dir.glob("*.json")) @pytest.mark.django_db -def test_update_llm_index_removes_meta( - temp_llm_index_dir, - real_document, - mock_embed_model, +def test_update_llm_index_cleans_stale_meta_on_rebuild( + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, ) -> None: - # Pre-create a meta.json with incorrect data - (temp_llm_index_dir / "meta.json").write_text( - json.dumps({"embedding_model": "old", "dim": 1}), - ) + # A meta.json left over from the FAISS era (or written by older code) must be + # deleted on rebuild so stale artifacts don't accumulate on disk. + stale_meta = temp_llm_index_dir / "meta.json" + stale_meta.write_text(json.dumps({"embedding_model": "old", "dim": 1})) with patch("documents.models.Document.objects.all") as mock_all: mock_queryset = MagicMock() @@ -159,23 +172,52 @@ def test_update_llm_index_removes_meta( mock_all.return_value = mock_queryset indexing.update_llm_index(rebuild=True) - meta = json.loads((temp_llm_index_dir / "meta.json").read_text()) - from paperless.config import AIConfig - - config = AIConfig() - expected_model = config.llm_embedding_model or ( - "text-embedding-3-small" - if config.llm_embedding_backend == "openai-like" - else "sentence-transformers/all-MiniLM-L6-v2" + assert not stale_meta.exists(), ( + "update_llm_index(rebuild=True) must remove stale meta.json" ) - assert meta == {"embedding_model": expected_model, "dim": 384} + + +@pytest.mark.django_db +def test_update_llm_index_rebuilds_on_model_name_change( + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, +) -> None: + # Build initial index with model "model-a". + with patch("documents.models.Document.objects.all") as mock_all: + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + mock_queryset.__iter__.return_value = iter([real_document]) + mock_all.return_value = mock_queryset + with patch( + "paperless_ai.indexing.get_configured_model_name", + return_value="model-a", + ): + indexing.update_llm_index(rebuild=True) + + # Simulate config change to "model-b"; the incremental run must force a rebuild. + with patch("documents.models.Document.objects.all") as mock_all: + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + mock_queryset.__iter__.return_value = iter([real_document]) + mock_all.return_value = mock_queryset + with patch( + "paperless_ai.indexing.get_configured_model_name", + return_value="model-b", + ): + indexing.update_llm_index(rebuild=False) + + store = indexing.get_vector_store() + # Schema metadata only updates when the table is dropped and recreated, never on + # incremental writes -- so "model-b" here proves a full rebuild happened. + assert store.stored_model_name() == "model-b" @pytest.mark.django_db def test_update_llm_index_partial_update( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, ) -> None: doc2 = Document.objects.create( title="Test Document 2", @@ -210,131 +252,34 @@ def test_update_llm_index_partial_update( mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3]) mock_all.return_value = mock_queryset - # assert logs "Updating LLM index with %d new nodes and removing %d old nodes." - with patch("paperless_ai.indexing.logger") as mock_logger: - indexing.update_llm_index(rebuild=False) - mock_logger.info.assert_called_once_with( - "Updating %d nodes in LLM index.", - 2, - ) indexing.update_llm_index(rebuild=False) - assert any(temp_llm_index_dir.glob("*.json")) - - -def test_get_or_create_storage_context_raises_exception( - temp_llm_index_dir, - mock_embed_model, -) -> None: - with pytest.raises(Exception): - indexing.get_or_create_storage_context(rebuild=False) - - -@override_settings( - LLM_EMBEDDING_BACKEND="huggingface", -) -def test_load_or_build_index_builds_when_nodes_given( - temp_llm_index_dir, - real_document, - mock_embed_model, -) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "llama_index.core.VectorStoreIndex", - return_value=MagicMock(), - ) as mock_index_cls, - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ) as mock_storage, - ): - mock_storage.return_value.persist_dir = temp_llm_index_dir - indexing.load_or_build_index( - nodes=[indexing.build_document_node(real_document)], - ) - mock_index_cls.assert_called_once() - - -def test_load_or_build_index_raises_exception_when_no_nodes( - temp_llm_index_dir, - mock_embed_model, -) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ), - ): - with pytest.raises(Exception): - indexing.load_or_build_index() - - -@pytest.mark.django_db -def test_load_or_build_index_succeeds_when_nodes_given( - temp_llm_index_dir, - mock_embed_model, -) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "llama_index.core.VectorStoreIndex", - return_value=MagicMock(), - ) as mock_index_cls, - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ) as mock_storage, - ): - mock_storage.return_value.persist_dir = temp_llm_index_dir - indexing.load_or_build_index( - nodes=[MagicMock()], - ) - mock_index_cls.assert_called_once() + store = indexing.get_vector_store() + assert store.table_exists(), ( + "Expected the LanceDB table to exist after incremental update" + ) @pytest.mark.django_db def test_add_or_update_document_updates_existing_entry( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, ) -> None: indexing.update_llm_index(rebuild=True) indexing.llm_index_add_or_update_document(real_document) - assert any(temp_llm_index_dir.glob("*.json")) - - -@pytest.mark.django_db -def test_remove_document_deletes_node_from_docstore( - temp_llm_index_dir, - real_document, - mock_embed_model, -) -> None: - indexing.update_llm_index(rebuild=True) - index = indexing.load_or_build_index() - assert len(index.docstore.docs) == 1 - - indexing.llm_index_remove_document(real_document) - index = indexing.load_or_build_index() - assert len(index.docstore.docs) == 0 + store = indexing.get_vector_store() + assert store.table_exists(), ( + "Expected the LanceDB table to exist after add-or-update" + ) @pytest.mark.django_db def test_query_after_remove_does_not_raise_key_error( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: FakeEmbedding, ) -> None: indexing.update_llm_index(rebuild=True) @@ -352,8 +297,8 @@ def test_query_after_remove_does_not_raise_key_error( @pytest.mark.django_db def test_update_llm_index_no_documents( - temp_llm_index_dir, - mock_embed_model, + temp_llm_index_dir: Path, + mock_embed_model: FakeEmbedding, ) -> None: with patch("documents.models.Document.objects.all") as mock_all: mock_queryset = MagicMock() @@ -369,6 +314,22 @@ def test_update_llm_index_no_documents( ) +@pytest.mark.django_db +def test_update_no_documents_no_index_returns_early( + temp_llm_index_dir: Path, + mocker: pytest_mock.MockerFixture, +) -> None: + """update with no documents and no existing index must return early.""" + mock_qs = MagicMock() + mock_qs.exists.return_value = False + mock_qs.__iter__ = MagicMock(return_value=iter([])) + mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs) + + result = indexing.update_llm_index(rebuild=False) + + assert result == "No documents found to index." + + @pytest.mark.django_db def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None: # No existing tasks @@ -406,20 +367,17 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() - LLM_BACKEND="ollama", ) def test_query_similar_documents( - temp_llm_index_dir, - real_document, + temp_llm_index_dir: Path, + real_document: Document, ) -> None: with ( - patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage, patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index, patch( - "paperless_ai.indexing.vector_store_file_exists", + "paperless_ai.indexing.llm_index_exists", ) as mock_vector_store_exists, patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls, patch("paperless_ai.indexing.Document.objects.filter") as mock_filter, ): - mock_storage.return_value = MagicMock() - mock_storage.return_value.persist_dir = temp_llm_index_dir mock_vector_store_exists.return_value = True mock_index = MagicMock() @@ -453,12 +411,12 @@ def test_query_similar_documents( @pytest.mark.django_db def test_query_similar_documents_triggers_update_when_index_missing( - temp_llm_index_dir, - real_document, + temp_llm_index_dir: Path, + real_document: Document, ) -> None: with ( patch( - "paperless_ai.indexing.vector_store_file_exists", + "paperless_ai.indexing.llm_index_exists", return_value=False, ), patch( @@ -479,120 +437,13 @@ def test_query_similar_documents_triggers_update_when_index_missing( assert result == [] -@pytest.mark.django_db -def test_query_similar_documents_normalizes_and_post_filters_allowed_ids( - real_document, -) -> None: - real_document.owner = User.objects.create_user(username="rag-owner") - real_document.save() - private_owner = User.objects.create_user(username="rag-private-owner") - private_document = Document.objects.create( - title="Private similar document", - content="Similar private content that must not reach RAG.", - owner=private_owner, - added=timezone.now(), - ) - - with ( - patch( - "paperless_ai.indexing.vector_store_file_exists", - return_value=True, - ), - patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index, - patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls, - ): - allowed_node = MagicMock() - allowed_node.node_id = "allowed-node" - allowed_node.metadata = {"document_id": str(real_document.pk)} - private_node = MagicMock() - private_node.node_id = "private-node" - private_node.metadata = {"document_id": str(private_document.pk)} - - mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [allowed_node, private_node] - mock_load_or_build_index.return_value = mock_index - - mock_retriever = MagicMock() - mock_retriever.retrieve.return_value = [private_node, allowed_node] - mock_retriever_cls.return_value = mock_retriever - - result = indexing.query_similar_documents( - real_document, - top_k=2, - document_ids=[real_document.pk], - ) - - mock_retriever_cls.assert_called_once_with( - index=mock_index, - similarity_top_k=2, - doc_ids=["allowed-node"], - ) - assert result == [real_document] - assert private_document not in result - - -class TestUpdateLlmIndexStaleNodes: - """Tests that update_llm_index removes ALL nodes for a multi-chunk document.""" - - @pytest.mark.django_db - def test_incremental_update_removes_all_old_nodes_for_multi_chunk_document( - self, - temp_llm_index_dir, - mock_embed_model: MagicMock, - ) -> None: - """Ghost nodes from all chunks of a modified document must be removed. - - When a document is split into multiple chunks (chunk_size=1024), the - incremental update path must delete every old node, not just the last - one captured by a dict comprehension keyed on document_id. - """ - # Content long enough to produce at least two chunks at chunk_size=1024. - # Generate many paragraphs so the token count comfortably exceeds 1024. - fake = Faker() - long_content = "\n\n".join(fake.paragraph(nb_sentences=20) for _ in range(20)) - doc = DocumentFactory(content=long_content) - - # Build the initial index (rebuild=True) so it has multiple nodes - indexing.update_llm_index(rebuild=True) - - # Verify the initial index has more than one node for this document - initial_index = indexing.load_or_build_index() - initial_node_ids = [ - nid - for nid, node in initial_index.docstore.docs.items() - if node.metadata.get("document_id") == str(doc.id) - ] - assert len(initial_node_ids) > 1, ( - f"Expected multiple chunks but got {len(initial_node_ids)}; " - "increase long_content length" - ) - - # Simulate a modification so the incremental path treats it as changed. - # Use queryset.update() to bypass auto_now and actually change the DB value. - new_modified = timezone.now() - Document.objects.filter(pk=doc.pk).update(modified=new_modified) - - # Run incremental update (rebuild=False) with the modified document - indexing.update_llm_index(rebuild=False) - - # Reload the persisted index and check that no OLD node ids remain - updated_index = indexing.load_or_build_index() - remaining_old_node_ids = [ - nid for nid in initial_node_ids if nid in updated_index.docstore.docs - ] - assert remaining_old_node_ids == [], ( - f"Ghost nodes still present after incremental update: " - f"{remaining_old_node_ids}" - ) - - @pytest.mark.django_db def test_query_similar_documents_empty_allow_list_fails_closed( - real_document, + real_document: Document, ) -> None: with ( patch( - "paperless_ai.indexing.vector_store_file_exists", + "paperless_ai.indexing.llm_index_exists", return_value=True, ) as mock_vector_store_exists, patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index, @@ -610,27 +461,25 @@ def test_query_similar_documents_empty_allow_list_fails_closed( class TestUpdateLlmIndexEmptyDocumentSet: - """update_llm_index must persist an empty index when all documents are deleted. + """update_llm_index must clear the LanceDB table when all documents are deleted. - Without this, the stale on-disk FAISS vectors are never cleared and - subsequent similarity searches return phantom hits for document IDs that - no longer exist in the DB. + Without this, the stale vectors are never cleared and subsequent similarity + searches return phantom hits for document IDs that no longer exist in the DB. """ @pytest.mark.django_db def test_rebuild_clears_stale_index_when_no_documents_exist( self, temp_llm_index_dir: Path, - mock_embed_model: MagicMock, + mock_embed_model: FakeEmbedding, ) -> None: - """After deleting all documents, rebuild=True must persist an empty index. + """After deleting all documents, rebuild=True must produce a table with zero rows. Steps: 1. Build an index with one document so the on-disk state is non-empty. 2. Delete all documents from the DB. 3. Call update_llm_index(rebuild=True). - 4. Reload the index from disk. - 5. Assert the reloaded index has zero nodes (no phantom vectors). + 4. Open the LanceDB table directly and assert zero rows. """ # Step 1: create a document and build a non-empty index Document.objects.create( @@ -640,26 +489,23 @@ class TestUpdateLlmIndexEmptyDocumentSet: ) indexing.update_llm_index(rebuild=True) - initial_index = indexing.load_or_build_index() - assert len(initial_index.docstore.docs) > 0, ( - "Precondition failed: expected at least one node before deletion" + store = indexing.get_vector_store() + assert store.table_exists(), ( + "Precondition failed: expected the LanceDB table to exist before deletion" ) # Step 2: delete all documents Document.objects.all().delete() assert not Document.objects.exists() - # Step 3: rebuild with no documents + # Step 3: rebuild with no documents — drop_table is called so the table + # is removed (no rows to re-insert, so it stays absent). indexing.update_llm_index(rebuild=True) - # Step 4: reload the persisted index from disk - reloaded_index = indexing.load_or_build_index() - - # Step 5: phantom vectors must be gone - assert len(reloaded_index.docstore.docs) == 0, ( - f"Expected 0 nodes after clearing all documents, " - f"but found {len(reloaded_index.docstore.docs)}: " - f"{list(reloaded_index.docstore.docs.keys())}" + # Step 4: the table must be absent (no rows) — phantom vectors gone + store2 = indexing.get_vector_store() + assert not store2.table_exists(), ( + "Expected the LanceDB table to be absent after rebuilding with no documents" ) @@ -709,10 +555,14 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent: def test_returns_without_error_when_build_document_node_returns_empty( self, temp_llm_index_dir: Path, + mock_embed_model: MagicMock, mocker: pytest_mock.MockerFixture, ) -> None: - """When build_document_node returns [], the function must return without error - and must not call load_or_build_index at all.""" + """When build_document_node returns [], the function must return without error. + + The store's upsert_document treats an empty node list as a removal (no-op + delete), so load_or_build_index must not be called. + """ mocker.patch( "paperless_ai.indexing.build_document_node", return_value=[], @@ -720,6 +570,7 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent: mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index") doc = MagicMock(spec=Document) + doc.id = 42 # Must not raise indexing.llm_index_add_or_update_document(doc) @@ -727,172 +578,165 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent: @pytest.mark.django_db -class TestLlmIndexLocking: - """The FAISS index mutation functions must acquire the index lock before touching the index. +def test_llm_index_compact_uses_zero_retention( + temp_llm_index_dir: Path, + mocker: pytest_mock.MockerFixture, +) -> None: + """compact must use retention_seconds=0 to clear all MVCC history immediately.""" + mock_store = mocker.MagicMock() + mocker.patch( + "paperless_ai.indexing.write_store", + return_value=mocker.MagicMock( + __enter__=mocker.MagicMock(return_value=mock_store), + __exit__=mocker.MagicMock(return_value=False), + ), + ) - Without locking, two concurrent Celery workers can each load the same - on-disk index, make independent modifications, and the last writer silently - overwrites the first's changes. + indexing.llm_index_compact() + + mock_store.compact.assert_called_once_with(retention_seconds=0) + + +@pytest.mark.django_db +class TestLlmIndexLocking: + """Index mutation functions must go through write_store(), which holds the lock. + + Without locking, two concurrent Celery workers can open the same store, + make independent modifications, and trigger CommitConflictError. """ - def test_add_or_update_document_acquires_lock( + def test_add_or_update_document_uses_write_store( self, temp_llm_index_dir: Path, + mock_embed_model: FakeEmbedding, mocker: pytest_mock.MockerFixture, ) -> None: - """llm_index_add_or_update_document must enter the file lock before touching the index.""" - call_order: list[str] = [] - - mock_lock_instance = MagicMock() - mock_lock_instance.__enter__ = MagicMock( - side_effect=lambda *_: call_order.append("lock_acquired"), - ) - mock_lock_instance.__exit__ = MagicMock(return_value=False) - - mock_file_lock_cls = mocker.patch( - "paperless_ai.indexing.FileLock", - return_value=mock_lock_instance, - ) - - mock_load = mocker.patch( - "paperless_ai.indexing.load_or_build_index", - side_effect=lambda *_a, **_kw: ( - call_order.append("index_loaded") or MagicMock() + mock_store = MagicMock() + mocker.patch( + "paperless_ai.indexing.write_store", + return_value=mocker.MagicMock( + __enter__=mocker.MagicMock(return_value=mock_store), + __exit__=mocker.MagicMock(return_value=False), ), ) + mock_node = MagicMock() + mock_node.get_content.return_value = "fake node text" mocker.patch( "paperless_ai.indexing.build_document_node", - return_value=[MagicMock()], + return_value=[mock_node], ) - mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes") doc = MagicMock(spec=Document) + doc.id = 1 indexing.llm_index_add_or_update_document(doc) - mock_file_lock_cls.assert_called_once() - mock_lock_instance.__enter__.assert_called_once() - mock_load.assert_called_once() - assert call_order.index("lock_acquired") < call_order.index("index_loaded"), ( - "Lock must be acquired before the index is loaded" - ) + mock_store.upsert_document.assert_called_once() - def test_remove_document_acquires_lock( + def test_remove_document_uses_write_store( self, temp_llm_index_dir: Path, mocker: pytest_mock.MockerFixture, ) -> None: - """llm_index_remove_document must enter the file lock before loading the index.""" - call_order: list[str] = [] - - mock_lock_instance = MagicMock() - mock_lock_instance.__enter__ = MagicMock( - side_effect=lambda *_: call_order.append("lock_acquired"), - ) - mock_lock_instance.__exit__ = MagicMock(return_value=False) - - mock_file_lock_cls = mocker.patch( - "paperless_ai.indexing.FileLock", - return_value=mock_lock_instance, - ) - - mock_load = mocker.patch( - "paperless_ai.indexing.load_or_build_index", - side_effect=lambda *_a, **_kw: ( - call_order.append("index_loaded") or MagicMock() + mock_store = MagicMock() + mocker.patch( + "paperless_ai.indexing.write_store", + return_value=mocker.MagicMock( + __enter__=mocker.MagicMock(return_value=mock_store), + __exit__=mocker.MagicMock(return_value=False), ), ) - mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes") doc = MagicMock(spec=Document) + doc.id = 1 indexing.llm_index_remove_document(doc) - mock_file_lock_cls.assert_called_once() - mock_lock_instance.__enter__.assert_called_once() - mock_load.assert_called_once() - assert call_order.index("lock_acquired") < call_order.index("index_loaded"), ( - "Lock must be acquired before the index is loaded" - ) + mock_store.delete.assert_called_once_with("1") - def test_update_llm_index_rebuild_acquires_lock( + def test_update_llm_index_rebuild_uses_write_store( self, temp_llm_index_dir: Path, - mock_embed_model: MagicMock, + mock_embed_model: FakeEmbedding, mocker: pytest_mock.MockerFixture, ) -> None: - """update_llm_index must enter the file lock during the rebuild/persist cycle.""" - mock_lock_instance = MagicMock() - mock_lock_instance.__enter__ = MagicMock(return_value=None) - mock_lock_instance.__exit__ = MagicMock(return_value=False) - - mock_file_lock_cls = mocker.patch( - "paperless_ai.indexing.FileLock", - return_value=mock_lock_instance, + mock_store = MagicMock() + mocker.patch( + "paperless_ai.indexing.write_store", + return_value=mocker.MagicMock( + __enter__=mocker.MagicMock(return_value=mock_store), + __exit__=mocker.MagicMock(return_value=False), + ), ) - - # exists=True so the code reaches the lock; iterate over an empty - # queryset so VectorStoreIndex is called with no nodes (still exercises - # the lock path without needing heavy FAISS fixture data) mock_qs = MagicMock() mock_qs.exists.return_value = True mock_qs.__iter__ = MagicMock(return_value=iter([])) mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs) - mocker.patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ) indexing.update_llm_index(rebuild=True) - mock_file_lock_cls.assert_called_once() - mock_lock_instance.__enter__.assert_called_once() + mock_store.drop_table.assert_called_once() - def test_query_similar_documents_acquires_lock( + +@pytest.mark.django_db +@pytest.mark.django_db +class TestLanceDbIndexing: + def test_get_vector_store_roundtrip( self, temp_llm_index_dir: Path, - mocker: pytest_mock.MockerFixture, + mock_embed_model: FakeEmbedding, ) -> None: - """query_similar_documents must enter the file lock before loading the index.""" - call_order: list[str] = [] + from paperless_ai.vector_store import PaperlessLanceVectorStore - mock_lock_instance = MagicMock() - mock_lock_instance.__enter__ = MagicMock( - side_effect=lambda *_: call_order.append("lock_acquired"), - ) - mock_lock_instance.__exit__ = MagicMock(return_value=False) + store = indexing.get_vector_store() + assert isinstance(store, PaperlessLanceVectorStore) - mock_file_lock_cls = mocker.patch( - "paperless_ai.indexing.FileLock", - return_value=mock_lock_instance, - ) + def test_add_then_remove_document( + self, + temp_llm_index_dir: Path, + mock_embed_model: FakeEmbedding, + real_document: Document, + ) -> None: + indexing.llm_index_add_or_update_document(real_document) + store = indexing.get_vector_store() + table = store.client.open_table(indexing.LLM_INDEX_TABLE) + assert table.count_rows() >= 1 - mocker.patch( - "paperless_ai.indexing.vector_store_file_exists", - return_value=True, - ) + indexing.llm_index_remove_document(real_document) + assert store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows() == 0 - mock_index = MagicMock() - mock_index.docstore.docs = {} + def test_update_shrinks_chunks_without_orphans( + self, + temp_llm_index_dir: Path, + mock_embed_model: FakeEmbedding, + real_document: Document, + ) -> None: + real_document.content = "word " * 4000 # many chunks + real_document.save() + indexing.llm_index_add_or_update_document(real_document) + store = indexing.get_vector_store() + big = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows() - mocker.patch( - "paperless_ai.indexing.load_or_build_index", - side_effect=lambda *_a, **_kw: ( - call_order.append("index_loaded") or mock_index - ), - ) + real_document.content = "short" # one chunk + real_document.save() + indexing.llm_index_add_or_update_document(real_document) - mock_retriever = MagicMock() - mock_retriever.retrieve.return_value = [] - mocker.patch( - "llama_index.core.retrievers.VectorIndexRetriever", - return_value=mock_retriever, - ) + rows = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows() + assert rows < big + assert rows >= 1 - mocker.patch("paperless_ai.indexing.truncate_content", return_value="") - indexing.query_similar_documents(MagicMock(spec=Document)) +@pytest.mark.django_db +class TestQuerySimilarDocuments: + def test_query_similar_documents_respects_allowed_ids( + self, + temp_llm_index_dir: Path, + mock_embed_model: FakeEmbedding, + ) -> None: + a = DocumentFactory.create(content="alpha shared content here") + b = DocumentFactory.create(content="beta shared content here") + c = DocumentFactory.create(content="gamma shared content here") + for doc in (a, b, c): + indexing.llm_index_add_or_update_document(doc) - mock_file_lock_cls.assert_called() - mock_lock_instance.__enter__.assert_called() - assert call_order.index("lock_acquired") < call_order.index("index_loaded"), ( - "Lock must be acquired before the index is loaded" - ) + results = indexing.query_similar_documents(a, document_ids=[b.id]) + + assert all(doc.id == b.id for doc in results) diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index d72b22f32..af34914bb 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -5,9 +5,9 @@ from unittest.mock import patch import pytest from llama_index.core.schema import TextNode +from paperless_ai import chat from paperless_ai.chat import CHAT_ERROR_MESSAGE from paperless_ai.chat import CHAT_METADATA_DELIMITER -from paperless_ai.chat import _get_document_filtered_retriever from paperless_ai.chat import stream_chat_with_documents @@ -58,91 +58,6 @@ def assert_chat_output( } -def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None: - mock_index.index_struct.nodes_dict = { - str(vector_id): node.node_id for vector_id, node in enumerate(nodes) - } - mock_index.docstore.docs.get.side_effect = { - node.node_id: node for node in nodes - }.get - mock_index.vector_store._faiss_index.ntotal = len(nodes) - mock_index.vector_store.query.return_value = MagicMock( - ids=list(mock_index.index_struct.nodes_dict), - similarities=[0.1] * len(nodes), - ) - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - -def test_document_filtered_retriever_expands_filters_and_caches() -> None: - allowed_node1 = TextNode( - text="Allowed content 1.", - metadata={"document_id": "1", "title": "Allowed 1"}, - ) - allowed_node2 = TextNode( - text="Allowed content 2.", - metadata={"document_id": "2", "title": "Allowed 2"}, - ) - foreign_node = TextNode( - text="Foreign content.", - metadata={"document_id": "3", "title": "Foreign"}, - ) - missing_node = TextNode( - text="Missing content.", - metadata={"document_id": "1", "title": "Missing"}, - ) - - mock_index = MagicMock() - mock_index.index_struct.nodes_dict = { - "0": foreign_node.node_id, - "1": missing_node.node_id, - "2": allowed_node1.node_id, - "3": allowed_node2.node_id, - } - mock_index.docstore.docs.get.side_effect = { - allowed_node1.node_id: allowed_node1, - allowed_node2.node_id: allowed_node2, - foreign_node.node_id: foreign_node, - }.get - mock_index.vector_store._faiss_index.ntotal = 4 - mock_index.vector_store.query.side_effect = [ - MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]), - MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]), - ] - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - retriever = _get_document_filtered_retriever( - mock_index, - {"1", "2"}, - similarity_top_k=2, - ) - - nodes = retriever.retrieve("question") - cached_nodes = retriever.retrieve("question") - - assert [node.node.node_id for node in nodes] == [ - allowed_node1.node_id, - allowed_node2.node_id, - ] - assert cached_nodes == nodes - assert mock_index.vector_store.query.call_count == 2 - assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1 - - -def test_document_filtered_retriever_handles_empty_faiss_index() -> None: - mock_index = MagicMock() - mock_index.vector_store._faiss_index.ntotal = 0 - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - retriever = _get_document_filtered_retriever( - mock_index, - {"1"}, - similarity_top_k=2, - ) - - assert retriever.retrieve("question") == [] - mock_index.vector_store.query.assert_not_called() - - @pytest.mark.django_db def test_stream_chat_with_one_document_retrieval( mock_document, @@ -164,17 +79,31 @@ def test_stream_chat_with_one_document_retrieval( metadata={"document_id": str(mock_document.pk), "title": "Test Document"}, ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node] - add_vector_query_results(mock_index, [mock_node]) + # Simulate get_nodes returning nodes (content exists) + mock_index.vector_store.get_nodes.return_value = [mock_node] mock_load_index.return_value = mock_index + mock_retriever_instance = MagicMock() + mock_retriever_instance.retrieve.return_value = [ + MagicMock( + metadata={ + "document_id": str(mock_document.pk), + "title": "Test Document", + }, + ), + ] + mock_response_stream = MagicMock() mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) mock_query_engine = MagicMock() mock_query_engine_cls.return_value = mock_query_engine mock_query_engine.query.return_value = mock_response_stream - output = list(stream_chat_with_documents("What is this?", [mock_document])) + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=mock_retriever_instance, + ): + output = list(stream_chat_with_documents("What is this?", [mock_document])) mock_query_engine.query.assert_called_once_with("What is this?") patch_embed_nodes.assert_not_called() @@ -196,12 +125,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, ): - # Mock AIClient and LLM mock_client = MagicMock() mock_client_cls.return_value = mock_client mock_client.llm = MagicMock() - # Create two real TextNodes mock_node1 = TextNode( text="Content for doc 1.", metadata={"document_id": "1", "title": "Document 1"}, @@ -210,41 +137,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non text="Content for doc 2.", metadata={"document_id": "2", "title": "Document 2"}, ) - mock_duplicate_node = TextNode( - text="More content for doc 1.", - metadata={"document_id": "1", "title": "Document 1 Duplicate"}, - ) - mock_foreign_node = TextNode( - text="Content for doc 3.", - metadata={"document_id": "3", "title": "Document 3"}, - ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [ - mock_node1, - mock_node2, - mock_duplicate_node, - mock_foreign_node, - ] - add_vector_query_results( - mock_index, - [mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node], - ) + # Simulate get_nodes returning nodes (content exists) + mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2] mock_load_index.return_value = mock_index - # Mock response stream + mock_retriever_instance = MagicMock() + mock_retriever_instance.retrieve.return_value = [ + MagicMock(metadata={"document_id": "1", "title": "Document 1"}), + MagicMock(metadata={"document_id": "2", "title": "Document 2"}), + ] + mock_response_stream = MagicMock() mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) - # Mock RetrieverQueryEngine mock_query_engine = MagicMock() mock_query_engine_cls.return_value = mock_query_engine mock_query_engine.query.return_value = mock_response_stream - # Fake documents doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf") doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf") - output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=mock_retriever_instance, + ): + output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) mock_query_engine.query.assert_called_once_with("What's up?") patch_embed_nodes.assert_not_called() @@ -258,8 +176,16 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non ) +def test_stream_chat_empty_document_list() -> None: + with patch("paperless_ai.chat.load_or_build_index") as mock_load_index: + output = list(stream_chat_with_documents("Any info?", [])) + mock_load_index.assert_not_called() + assert output == ["Sorry, I couldn't find any content to answer your question."] + + def test_stream_chat_no_matching_nodes() -> None: with ( + patch("paperless_ai.chat.AIConfig"), patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, ): @@ -268,8 +194,8 @@ def test_stream_chat_no_matching_nodes() -> None: mock_client.llm = MagicMock() mock_index = MagicMock() - # No matching nodes - mock_index.docstore.docs.values.return_value = [] + # No matching nodes in the store + mock_index.vector_store.get_nodes.return_value = [] mock_load_index.return_value = mock_index output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) @@ -279,30 +205,88 @@ def test_stream_chat_no_matching_nodes() -> None: def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None: with ( + patch("paperless_ai.chat.AIConfig"), patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, - patch( - "paperless_ai.chat._get_document_filtered_retriever", - ) as mock_get_retriever, ): mock_client = MagicMock() mock_client_cls.return_value = mock_client mock_client.llm = MagicMock() - mock_node = TextNode( - text="This is node content.", - metadata={"document_id": "1", "title": "Test Document"}, - ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node] + # Nodes found so we get past the pre-check + mock_index.vector_store.get_nodes.return_value = [MagicMock()] mock_load_index.return_value = mock_index - mock_retriever = MagicMock() - mock_retriever.retrieve.side_effect = RuntimeError("private provider detail") - mock_get_retriever.return_value = mock_retriever + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + ) as mock_retriever_cls: + mock_retriever = MagicMock() + mock_retriever.retrieve.side_effect = RuntimeError( + "private provider detail", + ) + mock_retriever_cls.return_value = mock_retriever - output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) + output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) assert output == [CHAT_ERROR_MESSAGE] assert "Failed to stream document chat response" in caplog.text assert "private provider detail" in caplog.text + + +@pytest.mark.django_db +class TestStreamChatRetrieval: + def test_no_nodes_yields_no_content_message( + self, + temp_llm_index_dir, + mock_embed_model, + ) -> None: + from documents.tests.factories import DocumentFactory + + doc = DocumentFactory.create(content="hello world") + # Nothing indexed for this document yet. + out = list(chat.stream_chat_with_documents("question?", [doc])) + assert chat.CHAT_NO_CONTENT_MESSAGE in out + + def test_chat_filter_contains_only_requested_document_ids( + self, + temp_llm_index_dir, + mock_embed_model, + mocker, + ) -> None: + """The MetadataFilter passed to the retriever must be scoped to the + requested documents only — content from other indexed documents must + not be surfaced. + """ + from documents.tests.factories import DocumentFactory + from paperless_ai import indexing + + included = DocumentFactory.create(content="included document content") + excluded = DocumentFactory.create(content="excluded document content") + indexing.llm_index_add_or_update_document(included) + indexing.llm_index_add_or_update_document(excluded) + + # VectorIndexRetriever is imported inside _stream_chat_with_documents; + # patch it at the llama_index source so the lazy import picks it up. + captured_filters = [] + mock_retriever = mocker.MagicMock() + mock_retriever.retrieve.return_value = [] + + def capture_retriever(*args, **kwargs): + captured_filters.append(kwargs.get("filters")) + return mock_retriever + + mocker.patch("paperless_ai.chat.AIClient") + mocker.patch( + "llama_index.core.retrievers.VectorIndexRetriever", + side_effect=capture_retriever, + ) + + list(chat.stream_chat_with_documents("question?", [included])) + + assert captured_filters, "VectorIndexRetriever was never constructed" + filt = captured_filters[0] + assert filt is not None, "Retriever must receive a MetadataFilters" + filter_values = filt.filters[0].value + assert str(included.pk) in filter_values + assert str(excluded.pk) not in filter_values diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index 1dbd0ab99..251d3f90b 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -1,4 +1,3 @@ -import json from unittest.mock import ANY from unittest.mock import MagicMock from unittest.mock import patch @@ -10,7 +9,7 @@ from documents.models import Document from paperless.models import LLMEmbeddingBackend from paperless_ai.embedding import _normalize_llm_index_text from paperless_ai.embedding import build_llm_index_text -from paperless_ai.embedding import get_embedding_dim +from paperless_ai.embedding import get_configured_model_name from paperless_ai.embedding import get_embedding_model @@ -67,7 +66,7 @@ def test_get_embedding_model_openai(mock_ai_config): with patch( "llama_index.embeddings.openai_like.OpenAILikeEmbedding", ) as MockOpenAIEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOpenAIEmbedding.assert_called_once_with( model_name="text-embedding-3-small", api_key="test_api_key", @@ -88,7 +87,7 @@ def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config): with patch( "llama_index.embeddings.openai_like.OpenAILikeEmbedding", ) as MockOpenAIEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOpenAIEmbedding.assert_called_once_with( model_name="text-embedding-3-small", api_key="test_api_key", @@ -109,7 +108,7 @@ def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed( mock_ai_config.return_value.llm_allow_internal_endpoints = False with pytest.raises(ValueError, match="non-public address"): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) def test_get_embedding_model_huggingface(mock_ai_config): @@ -121,7 +120,7 @@ def test_get_embedding_model_huggingface(mock_ai_config): with patch( "llama_index.embeddings.huggingface.HuggingFaceEmbedding", ) as MockHuggingFaceEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockHuggingFaceEmbedding.assert_called_once_with( model_name="sentence-transformers/all-MiniLM-L6-v2", cache_folder=str(settings.DATA_DIR / "hf_cache"), @@ -137,7 +136,7 @@ def test_get_embedding_model_ollama(mock_ai_config): with patch( "llama_index.embeddings.ollama.OllamaEmbedding", ) as MockOllamaEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://test-url", @@ -155,7 +154,7 @@ def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config): with patch( "llama_index.embeddings.ollama.OllamaEmbedding", ) as MockOllamaEmbedding: - model = get_embedding_model() + model = get_embedding_model(mock_ai_config.return_value) MockOllamaEmbedding.assert_called_once_with( model_name="embeddinggemma", base_url="http://embedding-url", @@ -173,7 +172,7 @@ def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed( mock_ai_config.return_value.llm_allow_internal_endpoints = False with pytest.raises(ValueError, match="non-public address"): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) def test_get_embedding_model_invalid_backend(mock_ai_config): @@ -183,55 +182,37 @@ def test_get_embedding_model_invalid_backend(mock_ai_config): ValueError, match="Unsupported embedding backend: INVALID_BACKEND", ): - get_embedding_model() + get_embedding_model(mock_ai_config.return_value) -def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config): - mock_ai_config.return_value.llm_embedding_backend = "openai-like" - mock_ai_config.return_value.llm_embedding_model = None - - class DummyEmbedding: - def get_text_embedding(self, text): - return [0.0] * 7 - - with patch( - "paperless_ai.embedding.get_embedding_model", - return_value=DummyEmbedding(), - ) as mock_get: - dim = get_embedding_dim() - mock_get.assert_called_once() - - assert dim == 7 - meta = json.loads((temp_llm_index_dir / "meta.json").read_text()) - assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7} +@pytest.mark.parametrize( + ("backend", "expected_default"), + [ + (LLMEmbeddingBackend.OPENAI_LIKE, "text-embedding-3-small"), + (LLMEmbeddingBackend.HUGGINGFACE, "sentence-transformers/all-MiniLM-L6-v2"), + (LLMEmbeddingBackend.OLLAMA, "embeddinggemma"), + ], +) +def test_get_configured_model_name_falls_back_to_backend_default( + mock_ai_config, + backend, + expected_default, +): + """When no model is explicitly configured, each backend has a distinct default.""" + config = mock_ai_config.return_value + config.llm_embedding_backend = backend + config.llm_embedding_model = None + assert get_configured_model_name(config) == expected_default -def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config): - mock_ai_config.return_value.llm_embedding_backend = "openai-like" - mock_ai_config.return_value.llm_embedding_model = None - - (temp_llm_index_dir / "meta.json").write_text( - json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}), - ) - - with patch("paperless_ai.embedding.get_embedding_model") as mock_get: - assert get_embedding_dim() == 11 - mock_get.assert_not_called() - - -def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config): - mock_ai_config.return_value.llm_embedding_backend = "openai-like" - mock_ai_config.return_value.llm_embedding_model = None - - (temp_llm_index_dir / "meta.json").write_text( - json.dumps({"embedding_model": "old", "dim": 11}), - ) - - with pytest.raises( - RuntimeError, - match="Embedding model changed from old to text-embedding-3-small", - ): - get_embedding_dim() +def test_get_configured_model_name_explicit_overrides_default(mock_ai_config): + """An explicit model name overrides the backend default for all backends.""" + config = mock_ai_config.return_value + config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE + config.llm_embedding_model = "my-custom-model" + # The backend default for OPENAI_LIKE is "text-embedding-3-small", so if + # the explicit name was ignored we'd get the wrong result. + assert get_configured_model_name(config) == "my-custom-model" def test_build_llm_index_text(mock_document): @@ -243,12 +224,15 @@ def test_build_llm_index_text(mock_document): result = build_llm_index_text(mock_document) - assert "Title: Test Title" in result + # Structured fields live in node.metadata for LLM context — not body text + assert "Title: Test Title" not in result + assert "Created: 2023-01-01" not in result + assert "Tags: Tag1, Tag2" not in result + assert "Document Type: Invoice" not in result + assert "Correspondent: Test Correspondent" not in result + + # Fields without a metadata equivalent stay in body text assert "Filename: test_file.pdf" in result - assert "Created: 2023-01-01" in result - assert "Tags: Tag1, Tag2" in result - assert "Document Type: Invoice" in result - assert "Correspondent: Test Correspondent" in result assert "Notes: Note1,Note2" in result assert "Content:\n\nThis is the document content." in result assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result diff --git a/src/paperless_ai/tests/test_lazy_imports.py b/src/paperless_ai/tests/test_lazy_imports.py new file mode 100644 index 000000000..7418d2ef0 --- /dev/null +++ b/src/paperless_ai/tests/test_lazy_imports.py @@ -0,0 +1,25 @@ +import subprocess +import sys +from pathlib import Path + +_SRC_DIR = Path(__file__).parent.parent.parent + + +class TestLazyAiImports: + def test_importing_tasks_does_not_load_ai_libraries(self) -> None: + code = ( + "import os, django, sys\n" + "os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'paperless.settings')\n" + "django.setup()\n" + "import documents.tasks # noqa: F401\n" + "leaked = [m for m in ('lancedb', 'pyarrow', 'llama_index') " + "if m in sys.modules]\n" + "assert not leaked, f'AI libraries leaked into the light path: {leaked}'\n" + ) + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + cwd=_SRC_DIR, + ) + assert result.returncode == 0, result.stdout + result.stderr diff --git a/src/paperless_ai/tests/test_vector_store.py b/src/paperless_ai/tests/test_vector_store.py new file mode 100644 index 000000000..b409ed1c9 --- /dev/null +++ b/src/paperless_ai/tests/test_vector_store.py @@ -0,0 +1,417 @@ +from pathlib import Path + +import pytest +from llama_index.core.schema import NodeRelationship +from llama_index.core.schema import RelatedNodeInfo +from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import FilterOperator +from llama_index.core.vector_stores.types import MetadataFilter +from llama_index.core.vector_stores.types import MetadataFilters +from llama_index.core.vector_stores.types import VectorStoreQuery + +from paperless_ai.vector_store import PaperlessLanceVectorStore + +DIM = 8 + + +def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode: + node = TextNode(id_=node_id, text=text, metadata={"document_id": document_id}) + node.set_content(text) + node.embedding = [vec] * DIM + # Use relationships so ref_doc_id resolves correctly (it's a read-only property) + node.relationships = { + NodeRelationship.SOURCE: RelatedNodeInfo(node_id=document_id), + } + return node + + +class TestPaperlessLanceVectorStoreCrud: + @pytest.fixture + def store(self, tmp_path: Path) -> PaperlessLanceVectorStore: + return PaperlessLanceVectorStore(uri=str(tmp_path / "idx")) + + def test_add_then_query_returns_node( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "alpha", 0.1), _node("2-0", "2", "beta", 0.9)]) + + result = store.query( + VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1), + ) + + assert len(result.nodes) == 1 + assert result.nodes[0].metadata["document_id"] == "1" + + def test_query_empty_table_returns_empty_no_raise( + self, + store: PaperlessLanceVectorStore, + ) -> None: + result = store.query( + VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=5), + ) + assert result.nodes == [] + assert result.ids == [] + + def test_delete_removes_all_chunks_of_document( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "a", 0.1), _node("1-1", "1", "b", 0.2)]) + store.add([_node("2-0", "2", "c", 0.9)]) + + store.delete("1") + + assert store.client.open_table("documents").count_rows() == 1 + + def test_query_with_in_filter_scopes_results( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)]) + + result = store.query( + VectorStoreQuery( + query_embedding=[0.1] * DIM, + similarity_top_k=5, + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=["2"], + ), + ], + ), + ), + ) + + assert [n.metadata["document_id"] for n in result.nodes] == ["2"] + + def test_get_nodes_filter_returns_empty_cleanly( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "a", 0.1)]) + nodes = store.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=["999"], + ), + ], + ), + ) + assert nodes == [] + + def test_get_nodes_returns_empty_when_no_table( + self, + store: PaperlessLanceVectorStore, + ) -> None: + result = store.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=["1"], + ), + ], + ), + ) + assert result == [] + + def test_fresh_instance_filters_existing_table( + self, + tmp_path: Path, + ) -> None: + uri = str(tmp_path / "idx") + PaperlessLanceVectorStore(uri=uri).add( + [_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)], + ) + + reopened = PaperlessLanceVectorStore(uri=uri) + result = reopened.query( + VectorStoreQuery( + query_embedding=[0.1] * DIM, + similarity_top_k=5, + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=["1"], + ), + ], + ), + ), + ) + assert [n.metadata["document_id"] for n in result.nodes] == ["1"] + + def test_table_exists_and_drop( + self, + store: PaperlessLanceVectorStore, + ) -> None: + assert store.table_exists() is False + store.add([_node("1-0", "1", "a", 0.1)]) + assert store.table_exists() is True + assert store.vector_dim() == DIM + store.drop_table() + assert store.table_exists() is False + + def test_build_where_or_condition(self) -> None: + from llama_index.core.vector_stores.types import FilterCondition + + from paperless_ai.vector_store import _build_where + + where = _build_where( + MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.EQ, + value="1", + ), + MetadataFilter( + key="document_id", + operator=FilterOperator.EQ, + value="2", + ), + ], + condition=FilterCondition.OR, + ), + ) + assert where == "document_id = '1' OR document_id = '2'" + + +class TestPaperlessLanceVectorStoreUpsert: + @pytest.fixture + def store(self, tmp_path: Path) -> PaperlessLanceVectorStore: + s = PaperlessLanceVectorStore(uri=str(tmp_path / "idx")) + s.add( + [ + _node("1-0", "1", "old0", 0.1), + _node("1-1", "1", "old1", 0.2), + _node("1-2", "1", "old2", 0.3), + _node("2-0", "2", "keep", 0.9), + ], + ) + return s + + def test_upsert_prunes_stale_chunks_and_keeps_others( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.upsert_document( + "1", + [_node("1-0", "1", "new0", 0.1), _node("1-1", "1", "new1", 0.2)], + ) + + table = store.client.open_table("documents") + doc1 = sorted( + r["id"] for r in table.search().where("document_id = '1'").to_list() + ) + assert doc1 == ["1-0", "1-1"] # 1-2 pruned + assert table.count_rows() == 3 # 2 new doc1 + 1 doc2 + + def test_upsert_is_single_commit( + self, + store: PaperlessLanceVectorStore, + ) -> None: + table = store.client.open_table("documents") + before = table.version + store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)]) + assert store.client.open_table("documents").version == before + 1 + + def test_upsert_empty_nodes_removes_document( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.upsert_document("1", []) + + table = store.client.open_table("documents") + remaining = sorted(r["document_id"] for r in table.search().to_list()) + assert "1" not in remaining + assert "2" in remaining + + +class TestPaperlessLanceVectorStoreMaintenance: + @pytest.fixture + def store(self, tmp_path: Path) -> PaperlessLanceVectorStore: + return PaperlessLanceVectorStore(uri=str(tmp_path / "idx")) + + def test_maybe_create_ann_index_noop_below_threshold( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "a", 0.1)]) + # Threshold far above row count -> no index attempted, no error. + store.maybe_create_ann_index(min_rows=1000) + # Still queryable. + result = store.query( + VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1), + ) + assert len(result.nodes) == 1 + + def test_maybe_create_ann_index_non_divisible_dim_falls_back( + self, + store: PaperlessLanceVectorStore, + ) -> None: + # DIM=8 is not divisible by the PQ default sub-vectors; must not raise + # and must leave the table queryable (IVF_FLAT fallback or skipped). + for i in range(40): + store.add([_node(f"1-{i}", "1", f"t{i}", float(i))]) + store.maybe_create_ann_index(min_rows=10) + result = store.query( + VectorStoreQuery(query_embedding=[1.0] * DIM, similarity_top_k=3), + ) + assert len(result.nodes) == 3 + + def test_compact_reduces_to_single_version( + self, + store: PaperlessLanceVectorStore, + ) -> None: + for i in range(5): + store.add([_node(f"1-{i}", "1", f"t{i}", float(i))]) + assert len(store.client.open_table("documents").list_versions()) > 1 + store.compact(retention_seconds=0) + assert len(store.client.open_table("documents").list_versions()) == 1 + + def test_upsert_after_optimize_with_scalar_index( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add( + [ + _node("1-0", "1", "old0", 0.1), + _node("1-1", "1", "old1", 0.2), + _node("1-2", "1", "old2", 0.3), + _node("2-0", "2", "keep", 0.9), + ], + ) + store.ensure_document_id_scalar_index() + store.compact(retention_seconds=0) + + store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)]) + + table = store.client.open_table("documents") + doc1 = sorted( + r["id"] for r in table.search().where("document_id = '1'").to_list() + ) + assert doc1 == ["1-0"] + assert table.count_rows() == 2 + + def test_ensure_scalar_index_is_idempotent( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add([_node("1-0", "1", "text", 0.5)]) + store.ensure_document_id_scalar_index() + # Second call must not raise and must not replace the existing index. + store.ensure_document_id_scalar_index() + assert store._has_index_on("document_id") + + def test_ensure_scalar_index_noop_on_empty_store( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.ensure_document_id_scalar_index() # no table yet — must not raise + + +class TestConfigMismatch: + @pytest.fixture + def uri(self, tmp_path: Path) -> str: + return str(tmp_path / "idx") + + def test_stored_model_name_returns_none_when_no_table(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri) + assert store.stored_model_name() is None + + def test_model_name_stored_in_schema_after_add(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2") + store.add([_node("1-0", "1", "text", 0.1)]) + assert store.stored_model_name() == "all-MiniLM-L6-v2" + + def test_model_name_stored_in_schema_after_upsert(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri, embed_model_name="nomic-embed") + store.upsert_document("1", [_node("1-0", "1", "text", 0.1)]) + assert store.stored_model_name() == "nomic-embed" + + def test_model_name_persists_after_reopen(self, uri: str) -> None: + PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2").add( + [_node("1-0", "1", "text", 0.1)], + ) + reopened = PaperlessLanceVectorStore(uri=uri) + assert reopened.stored_model_name() == "all-MiniLM-L6-v2" + + def test_config_mismatch_returns_false_when_no_table(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri) + assert store.config_mismatch("any-model") is False + + def test_config_mismatch_returns_false_when_model_matches(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri, embed_model_name="all-MiniLM-L6-v2") + store.add([_node("1-0", "1", "text", 0.1)]) + assert store.config_mismatch("all-MiniLM-L6-v2") is False + + def test_config_mismatch_returns_true_when_model_differs(self, uri: str) -> None: + store = PaperlessLanceVectorStore(uri=uri, embed_model_name="old-model") + store.add([_node("1-0", "1", "text", 0.1)]) + assert store.config_mismatch("new-model") is True + + def test_config_mismatch_returns_false_when_no_metadata_stored( + self, + uri: str, + ) -> None: + # Tables created before model-name tracking was added have no schema metadata. + # Conservative default: assume compatible rather than force a rebuild. + store = PaperlessLanceVectorStore(uri=uri) + store.add([_node("1-0", "1", "text", 0.1)]) + assert store.config_mismatch("any-model") is False + + +class TestGetModifiedTimes: + @pytest.fixture + def store(self, tmp_path: Path) -> PaperlessLanceVectorStore: + return PaperlessLanceVectorStore(uri=str(tmp_path / "idx")) + + def _node_with_modified( + self, + node_id: str, + doc_id: str, + modified: str, + ) -> TextNode: + node = TextNode( + id_=node_id, + text="text", + metadata={"document_id": doc_id, "modified": modified}, + ) + node.embedding = [0.1] * DIM + node.relationships = { + NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc_id), + } + return node + + def test_empty_store_returns_empty_dict( + self, + store: PaperlessLanceVectorStore, + ) -> None: + assert store.get_modified_times() == {} + + def test_returns_one_entry_per_document( + self, + store: PaperlessLanceVectorStore, + ) -> None: + store.add( + [ + self._node_with_modified("1-0", "1", "2024-01-01T00:00:00"), + self._node_with_modified("1-1", "1", "2024-01-01T00:00:00"), + self._node_with_modified("2-0", "2", "2024-06-01T00:00:00"), + ], + ) + result = store.get_modified_times() + assert result == { + "1": "2024-01-01T00:00:00", + "2": "2024-06-01T00:00:00", + } diff --git a/src/paperless_ai/vector_store.py b/src/paperless_ai/vector_store.py new file mode 100644 index 000000000..0e731e5c9 --- /dev/null +++ b/src/paperless_ai/vector_store.py @@ -0,0 +1,333 @@ +import json +import logging +from collections.abc import Sequence +from typing import Any + +import lancedb +import pyarrow as pa +from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.schema import BaseNode +from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.core.vector_stores.types import FilterCondition +from llama_index.core.vector_stores.types import FilterOperator +from llama_index.core.vector_stores.types import MetadataFilters +from llama_index.core.vector_stores.types import VectorStoreQuery +from llama_index.core.vector_stores.types import VectorStoreQueryResult +from llama_index.core.vector_stores.utils import metadata_dict_to_node +from llama_index.core.vector_stores.utils import node_to_metadata_dict + +logger = logging.getLogger("paperless_ai.vector_store") + +DEFAULT_TABLE_NAME = "documents" + +# Below this many chunks, LanceDB's exact (brute-force) search is sufficient and +# faster than building an ANN index (per LanceDB guidance, ~100K vectors). +ANN_INDEX_MIN_ROWS = 100_000 +# IVF_PQ default; num_sub_vectors must evenly divide the embedding dimension. +ANN_PQ_SUB_VECTORS = 96 + + +def _escape(value: str) -> str: + return str(value).replace("'", "''") + + +def _build_where(filters: MetadataFilters | None) -> str | None: + """Translate the EQ / IN filters we use into a Lance SQL predicate on the + top-level ``document_id`` column.""" + if filters is None or not filters.filters: + return None + clauses: list[str] = [] + for f in filters.filters: + if f.operator == FilterOperator.IN: + vals = ",".join(f"'{_escape(v)}'" for v in f.value) + clauses.append(f"{f.key} IN ({vals})") + elif f.operator == FilterOperator.EQ: + clauses.append(f"{f.key} = '{_escape(f.value)}'") + else: # pragma: no cover - we only ever build EQ/IN filters + raise NotImplementedError(f"Unsupported filter operator: {f.operator}") + joiner = " OR " if filters.condition == FilterCondition.OR else " AND " + return joiner.join(clauses) + + +class PaperlessLanceVectorStore(BasePydanticVectorStore): + """A llama-index vector store backed directly by a LanceDB table. + + Stores one row per node with the node id, its document id (both as the + ``ref_doc_id`` delete key ``doc_id`` and a top-level filter column + ``document_id``), the embedding, and the serialised node (text + metadata) + as JSON. ``stores_text`` lets llama-index run off this store alone, with no + separate docstore or index store. + + Implemented surface of ``BasePydanticVectorStore`` + --------------------------------------------------- + Only the methods actively used by this codebase are implemented. + ``delete_nodes`` and the ``node_ids`` lookup path of ``get_nodes`` are + part of the llama-index interface contract and may be needed if a future + retriever or extension invokes them — add them then, with tests. + """ + + stores_text: bool = True + flat_metadata: bool = False + + _uri: str = PrivateAttr() + _table_name: str = PrivateAttr() + _embed_model_name: str | None = PrivateAttr() + _conn: Any = PrivateAttr() + _table: Any = PrivateAttr() + + def __init__( + self, + uri: str, + table_name: str = DEFAULT_TABLE_NAME, + embed_model_name: str | None = None, + ) -> None: + super().__init__(stores_text=True, flat_metadata=False) + self._uri = uri + self._table_name = table_name + self._embed_model_name = embed_model_name + self._conn = lancedb.connect(uri) + existing = self._conn.list_tables().tables + self._table = ( + self._conn.open_table(table_name) if table_name in existing else None + ) + + @property + def client(self) -> Any: + return self._conn + + def table_exists(self) -> bool: + return self._table is not None + + def vector_dim(self) -> int | None: + if self._table is None: + return None + return self._table.schema.field("vector").type.list_size + + def drop_table(self) -> None: + if self.table_exists(): + self._conn.drop_table(self._table_name) + self._table = None + + def stored_model_name(self) -> str | None: + """Return the embedding model name stored in table schema metadata, or None.""" + if self._table is None: + return None + meta = self._table.schema.metadata or {} + value = meta.get(b"embed_model") + return value.decode() if value else None + + def config_mismatch(self, model_name: str) -> bool: + """True when the stored model name differs from ``model_name``. + + Returns False when no table exists or when the table predates model-name + tracking (schema has no metadata) — conservative default avoids spurious + rebuilds on upgrade. + """ + stored = self.stored_model_name() + if stored is None: + return False + return stored != model_name + + @staticmethod + def _schema(dim: int, model_name: str | None = None) -> pa.Schema: + meta = {b"embed_model": model_name.encode()} if model_name else None + return pa.schema( + [ + pa.field("id", pa.string()), + pa.field("doc_id", pa.string()), + pa.field("document_id", pa.string()), + pa.field("modified", pa.string()), + pa.field("vector", pa.list_(pa.float32(), dim)), + pa.field("node_content", pa.string()), + ], + metadata=meta, + ) + + def _row(self, node: BaseNode) -> dict[str, Any]: + meta = node_to_metadata_dict( + node, + remove_text=False, + flat_metadata=self.flat_metadata, + ) + return { + "id": node.node_id, + "doc_id": node.ref_doc_id, + "document_id": str(node.metadata.get("document_id")), + "modified": str(node.metadata.get("modified", "")), + "vector": node.get_embedding(), + "node_content": json.dumps(meta), + } + + def _ensure_table(self, rows: list[dict[str, Any]], dim: int) -> bool: + """Create the table from ``rows`` if it does not exist yet. + + Returns True if the table was just created (caller can skip the + separate add/merge step), False if the table already existed. + """ + if self._table is not None: + return False + self._table = self._conn.create_table( + self._table_name, + rows, + schema=self._schema(dim, self._embed_model_name), + ) + return True + + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: + if not nodes: + return [] + rows = [self._row(node) for node in nodes] + dim = len(nodes[0].get_embedding()) + if not self._ensure_table(rows, dim): + self._table.add(rows) + return [node.node_id for node in nodes] + + def upsert_document(self, document_id: str, nodes: list[BaseNode]) -> list[str]: + """Atomically replace all stored chunks of ``document_id`` with ``nodes``. + + A single ``merge_insert`` commit: matching node ids are updated, new ids + inserted, and any existing rows for this document that are not in the new + set are deleted (``when_not_matched_by_source_delete``). This prunes stale + trailing chunks when an edit reduces a document's chunk count, with no + transient empty state for concurrent lock-free readers. + """ + if not nodes: + # No indexable content: remove any existing chunks for this document. + if self._table is not None: + self._table.delete(f"document_id = '{_escape(document_id)}'") + return [] + rows = [self._row(node) for node in nodes] + dim = len(nodes[0].get_embedding()) + if self._ensure_table(rows, dim): + return [node.node_id for node in nodes] + ( + self._table.merge_insert("id") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete( + f"document_id = '{_escape(document_id)}'", + ) + .execute(rows) + ) + return [node.node_id for node in nodes] + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + if self._table is not None: + self._table.delete(f"doc_id = '{_escape(ref_doc_id)}'") + + def _rows_to_nodes(self, rows: list[dict[str, Any]]) -> list[BaseNode]: + nodes: list[BaseNode] = [] + for row in rows: + node = metadata_dict_to_node(json.loads(row["node_content"])) + node.embedding = list(row["vector"]) + nodes.append(node) + return nodes + + def get_nodes( + self, + node_ids: list[str] | None = None, + filters: MetadataFilters | None = None, + **kwargs: Any, + ) -> list[BaseNode]: + if node_ids is not None: # pragma: no cover + # node_ids lookup is not implemented; see class docstring. + raise NotImplementedError( + "PaperlessLanceVectorStore does not support node_ids lookup", + ) + if self._table is None: + return [] + where = _build_where(filters) + query = self._table.search() + if where: + query = query.where(where) + return self._rows_to_nodes(query.to_list()) + + def query( + self, + query: VectorStoreQuery, + **kwargs: Any, + ) -> VectorStoreQueryResult: + if self._table is None: + return VectorStoreQueryResult(nodes=[], similarities=[], ids=[]) + top_k = query.similarity_top_k if query.similarity_top_k is not None else 10 + search = self._table.search(query.query_embedding).limit(top_k) + where = _build_where(query.filters) + if where: + search = search.where(where) + rows = search.to_list() + nodes = self._rows_to_nodes(rows) + # LanceDB returns an L2 distance (smaller = closer); map to a descending similarity. + sims = [1.0 / (1.0 + float(row["_distance"])) for row in rows] + ids = [row["id"] for row in rows] + return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids) + + def _has_index_on(self, column: str) -> bool: + return any(column in idx.columns for idx in self._table.list_indices()) + + def maybe_create_ann_index(self, min_rows: int = ANN_INDEX_MIN_ROWS) -> None: + """Best-effort: build an IVF index once the table is large enough. + + IVF_PQ is used when ``num_sub_vectors`` divides the embedding dimension, + otherwise IVF_FLAT (no divisor constraint). Any failure is logged and + leaves the table on exact search, which is always correct. + """ + if self._table is None: + return + rows = self._table.count_rows() + if rows < min_rows or self._has_index_on("vector"): + return + num_partitions = max(1, rows // 4096) + # Embedding dim from the schema's fixed-size list column. + dim = self._table.schema.field("vector").type.list_size + try: + if dim % ANN_PQ_SUB_VECTORS == 0: # pragma: no cover + self._table.create_index( + metric="l2", + num_partitions=num_partitions, + num_sub_vectors=ANN_PQ_SUB_VECTORS, + index_type="IVF_PQ", + ) + else: + self._table.create_index( + metric="l2", + num_partitions=num_partitions, + index_type="IVF_FLAT", + ) + except Exception as e: # pragma: no cover - depends on data/dim + logger.warning("Skipping ANN index creation: %s", e) + + def get_modified_times(self) -> dict[str, str]: + """Return {document_id: stored_modified_isoformat} for all indexed documents. + + One representative chunk per document is fetched; all chunks share the + same ``modified`` value so the first one seen is sufficient. + """ + if self._table is None: + return {} + result: dict[str, str] = {} + for row in self._table.search().select(["document_id", "modified"]).to_list(): + doc_id = str(row["document_id"]) + if doc_id not in result: + result[doc_id] = str(row["modified"] or "") + return result + + def ensure_document_id_scalar_index(self) -> None: + """Create a scalar index on the filter column (never on the merge key + ``id`` — see https://github.com/lancedb/lancedb/issues/3177). + No-op if the index already exists.""" + if self._table is None: + return + if self._has_index_on("document_id"): + return + try: + self._table.create_scalar_index("document_id") + except Exception as e: # pragma: no cover + logger.warning("Skipping document_id scalar index: %s", e) + + def compact(self, retention_seconds: int) -> None: + """Compact fragments and prune old MVCC versions in one call.""" + if self._table is None: + return + from datetime import timedelta + + self._table.optimize(cleanup_older_than=timedelta(seconds=retention_seconds)) diff --git a/uv.lock b/uv.lock index 691e54a9c..e6cad3a35 100644 --- a/uv.lock +++ b/uv.lock @@ -1200,23 +1200,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/8d/2bc5f5546ff2ccb3f7de06742853483ab75bf74f36a92254702f8baecc79/factory_boy-3.3.3-py2.py3-none-any.whl", hash = "sha256:1c39e3289f7e667c4285433f305f8d506efc2fe9c73aaea4151ebd5cdea394fc", size = 37036, upload-time = "2025-02-03T09:49:01.659Z" }, ] -[[package]] -name = "faiss-cpu" -version = "1.13.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c9/671f66f6b31ec48e5825d36435f0cb91189fa8bb6b50724029dbff4ca83c/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_arm64.whl", hash = "sha256:a9064eb34f8f64438dd5b95c8f03a780b1a3f0b99c46eeacb1f0b5d15fc02dc1", size = 3452776, upload-time = "2025-12-24T10:27:01.419Z" }, - { url = "https://files.pythonhosted.org/packages/5a/4a/97150aa1582fb9c2bca95bd8fc37f27d3b470acec6f0a6833844b21e4b40/faiss_cpu-1.13.2-cp310-abi3-macosx_14_0_x86_64.whl", hash = "sha256:c8d097884521e1ecaea6467aeebbf1aa56ee4a36350b48b2ca6b39366565c317", size = 7896434, upload-time = "2025-12-24T10:27:03.592Z" }, - { url = "https://files.pythonhosted.org/packages/0b/d0/0940575f059591ca31b63a881058adb16a387020af1709dcb7669460115c/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ee330a284042c2480f2e90450a10378fd95655d62220159b1408f59ee83ebf1", size = 11485825, upload-time = "2025-12-24T10:27:05.681Z" }, - { url = "https://files.pythonhosted.org/packages/e7/e1/a5acac02aa593809f0123539afe7b4aff61d1db149e7093239888c9053e1/faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ab88ee287c25a119213153d033f7dd64c3ccec466ace267395872f554b648cd7", size = 23845772, upload-time = "2025-12-24T10:27:08.194Z" }, - { url = "https://files.pythonhosted.org/packages/9c/7b/49dcaf354834ec457e85ca769d50bc9b5f3003fab7c94a9dcf08cf742793/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:85511129b34f890d19c98b82a0cd5ffb27d89d1cec2ee41d2621ee9f9ef8cf3f", size = 13477567, upload-time = "2025-12-24T10:27:10.822Z" }, - { url = "https://files.pythonhosted.org/packages/f7/6b/12bb4037921c38bb2c0b4cfc213ca7e04bbbebbfea89b0b5746248ce446e/faiss_cpu-1.13.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b32eb4065bac352b52a9f5ae07223567fab0a976c7d05017c01c45a1c24264f", size = 25102239, upload-time = "2025-12-24T10:27:13.476Z" }, -] - [[package]] name = "faker" version = "40.15.0" @@ -2069,6 +2052,55 @@ redis = [ { name = "redis", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] +[[package]] +name = "lance-namespace" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lance-namespace-urllib3-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/80/2b6eaa08c5e25915acaa6368a70211a25b5ba9d2d6006450e68a73936164/lance_namespace-0.8.0.tar.gz", hash = "sha256:c4a79ee221a3b2315c29863ad12d85fcf219a13158e26149d63e21dc4b4673a7", size = 10756, upload-time = "2026-06-01T08:47:10.183Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/bd/7b40a08fb132fab39a6caebf832fdf6b9befc71be9413beb9be0a9d927d4/lance_namespace-0.8.0-py3-none-any.whl", hash = "sha256:782cf9e332f46bf06836722dd98b53ca8495ad98bb541501ff6876c89b67ec90", size = 12579, upload-time = "2026-06-01T08:47:10.91Z" }, +] + +[[package]] +name = "lance-namespace-urllib3-client" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/37/06fcd5a8969381e0ba953d51990af8d331bdccbc62458bf2eed30d064573/lance_namespace_urllib3_client-0.8.0.tar.gz", hash = "sha256:4f060f05ebf3c04aeaeb0d2022cbe77648a3df290f02cd2c305e5797d0fc1fdd", size = 203710, upload-time = "2026-06-01T08:47:13.404Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/43/e280727feee958f303bc58d5fa912b07734a0831f756d841654d500c2c34/lance_namespace_urllib3_client-0.8.0-py3-none-any.whl", hash = "sha256:6734e341b726e5cc96a0cd257cef27eb9d03013f2d151526ee426cef8e63e228", size = 336669, upload-time = "2026-06-01T08:47:11.88Z" }, +] + +[[package]] +name = "lancedb" +version = "0.33.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecation", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "lance-namespace", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "overrides", marker = "(python_full_version < '3.12' and sys_platform == 'darwin') or (python_full_version < '3.12' and sys_platform == 'linux')" }, + { name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/2f/d5a4b2a5bb1f800936c76a6d8a4daf127a86fcab621eeb70b574a5adc774/lancedb-0.33.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:d4eaf6fa7c2eac619208f1d396f4de635ee0f535673067118a31c1181575c48b", size = 48338115, upload-time = "2026-05-28T20:37:55.88Z" }, + { url = "https://files.pythonhosted.org/packages/07/12/31787b93a856b2c31382c7771dc22fb05575b70b87c9efe454269f4f0948/lancedb-0.33.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c6c2402ed2744245ae76c4167c0461da0a7a80f1608e0ec491c1548ea2b4302", size = 51162262, upload-time = "2026-05-28T20:37:59.101Z" }, + { url = "https://files.pythonhosted.org/packages/49/b7/081cc29f8e06bf12191b99ab3fe702aceebdb0914476b821a8c0445cacc8/lancedb-0.33.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ebf1ffad811e6254a93931a79489ba1f21f48564bdfa06abae846f5fcaaf3e8", size = 54381368, upload-time = "2026-05-28T20:38:02.2Z" }, + { url = "https://files.pythonhosted.org/packages/1c/bd/e0f4bd621f10ecf96a801b0166e87799ed7ca5a9dbabcef9a6c766a58ef3/lancedb-0.33.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:13da39f80adfea59e5831fe64e4166b2d70a2f843e6507bf644c4fe4c350087c", size = 51188986, upload-time = "2026-05-28T20:38:05.375Z" }, + { url = "https://files.pythonhosted.org/packages/d9/1a/a8647a432ac6aa59cdce1fc061a7050ea4278bcab364539b78af2ecf72d2/lancedb-0.33.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:21b712825f0a00225e8974a41352c4ea84b0899ef8c23b17f672fadc38bd8346", size = 54440958, upload-time = "2026-05-28T20:38:08.474Z" }, +] + [[package]] name = "langdetect" version = "1.0.9" @@ -2280,18 +2312,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/0c/fdddaee5391d915d3d568d2d8dbdb7c95647e65bb94d4ddb31d47cef5daf/llama_index_llms_openai_like-0.7.2-py3-none-any.whl", hash = "sha256:1f45a7b1cec8fb3f5997684327ffe6c19f93e789c2fff35dc5522465850faf0b", size = 6602, upload-time = "2026-04-23T23:05:31.708Z" }, ] -[[package]] -name = "llama-index-vector-stores-faiss" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "llama-index-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7c/32/89a04e38fa9595b7116c61955d9a67085f0a5480738e9c14063e374724c2/llama_index_vector_stores_faiss-0.6.0.tar.gz", hash = "sha256:00bfeb6cb7571e0e856566cb4f10c89b415b6108f151d9ad48ee9c31da563f5e", size = 6045, upload-time = "2026-03-12T20:46:31.454Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/85/465b4f199075ae7773c181b2f98cf689f3107a8de031e7a9d4cd5e906446/llama_index_vector_stores_faiss-0.6.0-py3-none-any.whl", hash = "sha256:d4600c60ef5411d9e35ba573b4f416a5e13ea04c6f942c8e6f49f03f2feb4f3b", size = 7739, upload-time = "2026-03-12T20:46:30.736Z" }, -] - [[package]] name = "llama-index-workflows" version = "2.20.0" @@ -2872,6 +2892,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/c1/d6e64ccd0536bf616556f0cad2b6d94a8125f508d25cfd814b1d2db4e2f1/openai-2.32.0-py3-none-any.whl", hash = "sha256:4dcc9badeb4bf54ad0d187453742f290226d30150890b7890711bda4f32f192f", size = 1162570, upload-time = "2026-04-15T22:28:17.714Z" }, ] +[[package]] +name = "overrides" +version = "7.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/36/86/b585f53236dec60aba864e050778b25045f857e17f6e5ea0ae95fe80edd2/overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a", size = 22812, upload-time = "2024-01-27T21:01:33.423Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/ab/fc8290c6a4c722e5514d80f62b2dc4c4df1a68a41d1364e625c35990fcf3/overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49", size = 17832, upload-time = "2024-01-27T21:01:31.393Z" }, +] + [[package]] name = "packaging" version = "26.0" @@ -2912,7 +2941,6 @@ dependencies = [ { name = "drf-spectacular", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "drf-spectacular-sidecar", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "drf-writable-nested", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "faiss-cpu", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "flower", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "gotenberg-client", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -2920,6 +2948,7 @@ dependencies = [ { name = "ijson", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "imap-tools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "lancedb", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "langdetect", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "llama-index-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "llama-index-embeddings-huggingface", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -2927,12 +2956,12 @@ dependencies = [ { name = "llama-index-embeddings-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "llama-index-llms-ollama", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "llama-index-llms-openai-like", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "llama-index-vector-stores-faiss", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "nltk", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "ocrmypdf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pathvalidate", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pdf2image", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "python-gnupg", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -3062,7 +3091,6 @@ requires-dist = [ { name = "drf-spectacular", specifier = "~=0.28" }, { name = "drf-spectacular-sidecar", specifier = "~=2026.5.1" }, { name = "drf-writable-nested", specifier = "~=0.7.1" }, - { name = "faiss-cpu", specifier = ">=1.10" }, { name = "filelock", specifier = "~=3.29.0" }, { name = "flower", specifier = "~=2.0.1" }, { name = "gotenberg-client", specifier = "~=0.14.0" }, @@ -3071,6 +3099,7 @@ requires-dist = [ { name = "ijson", specifier = ">=3.2" }, { name = "imap-tools", specifier = "~=1.13.0" }, { name = "jinja2", specifier = "~=3.1.5" }, + { name = "lancedb", specifier = "~=0.33.0" }, { name = "langdetect", specifier = "~=1.0.9" }, { name = "llama-index-core", specifier = ">=0.14.21" }, { name = "llama-index-embeddings-huggingface", specifier = ">=0.6.1" }, @@ -3078,7 +3107,6 @@ requires-dist = [ { name = "llama-index-embeddings-openai-like", specifier = ">=0.2.2" }, { name = "llama-index-llms-ollama", specifier = ">=0.9.1" }, { name = "llama-index-llms-openai-like", specifier = ">=0.7.1" }, - { name = "llama-index-vector-stores-faiss", specifier = ">=0.5.2" }, { name = "mysqlclient", marker = "extra == 'mariadb'", specifier = "~=2.2.7" }, { name = "nltk", specifier = "~=3.9.1" }, { name = "ocrmypdf", specifier = "~=17.4.2" }, @@ -3090,6 +3118,7 @@ requires-dist = [ { name = "psycopg-c", marker = "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'postgres'", url = "https://github.com/paperless-ngx/builder/releases/download/psycopg-trixie-3.3.0/psycopg_c-3.3.0-cp312-cp312-linux_x86_64.whl" }, { name = "psycopg-c", marker = "(python_full_version != '3.12.*' and platform_machine == 'aarch64' and extra == 'postgres') or (python_full_version != '3.12.*' and platform_machine == 'x86_64' and extra == 'postgres') or (platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'postgres') or (sys_platform != 'linux' and extra == 'postgres')", specifier = "==3.3" }, { name = "psycopg-pool", marker = "extra == 'postgres'", specifier = "==3.3" }, + { name = "pyarrow", specifier = ">=16" }, { name = "python-dateutil", specifier = "~=2.9.0" }, { name = "python-dotenv", specifier = "~=1.2.1" }, { name = "python-gnupg", specifier = "~=0.5.4" }, @@ -3588,6 +3617,50 @@ version = "0.16.1" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/1d/c7/28220d37e041fe1df03e857fe48f768dcd30cd151480bf6f00da8713214a/py-ubjson-0.16.1.tar.gz", hash = "sha256:b9bfb8695a1c7e3632e800fb83c943bf67ed45ddd87cd0344851610c69a5a482", size = 50316, upload-time = "2020-04-18T15:05:57.698Z" } +[[package]] +name = "pyarrow" +version = "24.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/13/13e1069b351bdc3881266e11147ffccf687505dbb0ea74036237f5d454a5/pyarrow-24.0.0.tar.gz", hash = "sha256:85fe721a14dd823aca09127acbb06c3ca723efbd436c004f16bca601b04dcc83", size = 1180261, upload-time = "2026-04-21T10:51:25.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/c9/a47ab7ece0d86cbe6678418a0fbd1ac4bb493b9184a3891dfa0e7f287ae0/pyarrow-24.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b0e131f880cda8d04e076cee175a46fc0e8bc8b65c99c6c09dff6669335fde74", size = 35068898, upload-time = "2026-04-21T10:46:36.599Z" }, + { url = "https://files.pythonhosted.org/packages/d1/bc/8db86617a9a58008acf8913d6fed68ea2a46acb6de928db28d724c891a68/pyarrow-24.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:1b2fe7f9a5566401a0ef2571f197eb92358925c1f0c8dba305d6e43ea0871bb3", size = 36679915, upload-time = "2026-04-21T10:46:42.602Z" }, + { url = "https://files.pythonhosted.org/packages/eb/8e/fb178720400ef69db251eb4a9c3ccf4af269bc1feb5055529b8fc87170d1/pyarrow-24.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:0b3537c00fb8d384f15ac1e79b6eb6db04a16514c8c1d22e59a9b95c8ba42868", size = 45697931, upload-time = "2026-04-21T10:46:48.403Z" }, + { url = "https://files.pythonhosted.org/packages/f3/27/99c42abe8e21b44f4917f62631f3aa31404882a2c41d8a4cd5c110e13d52/pyarrow-24.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:14e31a3c9e35f1ab6356c6378f6f72830e6d2d5f1791df3774a7b097d18a6a1e", size = 48837449, upload-time = "2026-04-21T10:46:55.329Z" }, + { url = "https://files.pythonhosted.org/packages/36/b6/333749e2666e9032891125bf9c691146e92901bece62030ac1430e2e7c88/pyarrow-24.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7d9a514e73bc42711e6a35aaccf3587c520024fe0a25d830a1a8a27c15f4f57", size = 49395949, upload-time = "2026-04-21T10:47:01.869Z" }, + { url = "https://files.pythonhosted.org/packages/17/25/c5201706a2dd374e8ba6ee3fd7a8c89fb7ffc16eed5217a91fd2bd7f7626/pyarrow-24.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b196eb3f931862af3fa84c2a253514d859c08e0d8fe020e07be12e75a5a9780c", size = 51912986, upload-time = "2026-04-21T10:47:09.872Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a9/9686d9f07837f91f775e8932659192e02c74f9d8920524b480b85212cc68/pyarrow-24.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:6233c9ed9ab9d1db47de57d9753256d9dcffbf42db341576099f0fd9f6bf4810", size = 34981559, upload-time = "2026-04-21T10:47:22.17Z" }, + { url = "https://files.pythonhosted.org/packages/80/b6/0ddf0e9b6ead3474ab087ae598c76b031fc45532bf6a63f3a553440fb258/pyarrow-24.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:f7616236ec1bc2b15bfdec22a71ab38851c86f8f05ff64f379e1278cf20c634a", size = 36663654, upload-time = "2026-04-21T10:47:28.315Z" }, + { url = "https://files.pythonhosted.org/packages/7c/3b/926382efe8ce27ba729071d3566ade6dfb86bdf112f366000196b2f5780a/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:1617043b99bd33e5318ae18eb2919af09c71322ef1ca46566cdafc6e6712fb66", size = 45679394, upload-time = "2026-04-21T10:47:34.821Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/829f7d9dfd37c207206081d6dad474d81dde29952401f07f2ba507814818/pyarrow-24.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6165461f55ef6314f026de6638d661188e3455d3ec49834556a0ebbdbace18bb", size = 48863122, upload-time = "2026-04-21T10:47:42.056Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e8/f88ce625fe8babaae64e8db2d417c7653adb3019b08aae85c5ed787dc816/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3b13dedfe76a0ad2d1d859b0811b53827a4e9d93a0bcb05cf59333ab4980cc7e", size = 49376032, upload-time = "2026-04-21T10:47:48.967Z" }, + { url = "https://files.pythonhosted.org/packages/36/7a/82c363caa145fff88fb475da50d3bf52bb024f61917be5424c3392eaf878/pyarrow-24.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:25ea65d868eb04015cd18e6df2fbe98f07e5bda2abefabcb88fce39a947716f6", size = 51929490, upload-time = "2026-04-21T10:47:55.981Z" }, + { url = "https://files.pythonhosted.org/packages/6f/d3/a1abf004482026ddc17f4503db227787fa3cfe41ec5091ff20e4fea55e57/pyarrow-24.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:02b001b3ed4723caa44f6cd1af2d5c86aa2cf9971dacc2ffa55b21237713dfba", size = 34976759, upload-time = "2026-04-21T10:48:07.258Z" }, + { url = "https://files.pythonhosted.org/packages/4f/4a/34f0a36d28a2dd32225301b79daad44e243dc1a2bb77d43b60749be255c4/pyarrow-24.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:04920d6a71aabd08a0417709efce97d45ea8e6fb733d9ca9ecffb13c67839f68", size = 36658471, upload-time = "2026-04-21T10:48:13.347Z" }, + { url = "https://files.pythonhosted.org/packages/1f/78/543b94712ae8bb1a6023bcc1acf1a740fbff8286747c289cd9468fced2a5/pyarrow-24.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a964266397740257f16f7bb2e4f08a0c81454004beab8ff59dd531b73610e9f2", size = 45675981, upload-time = "2026-04-21T10:48:20.201Z" }, + { url = "https://files.pythonhosted.org/packages/84/9f/8fb7c222b100d314137fa40ec050de56cd8c6d957d1cfff685ce72f15b17/pyarrow-24.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6f066b179d68c413374294bc1735f68475457c933258df594443bb9d88ddc2a0", size = 48859172, upload-time = "2026-04-21T10:48:27.541Z" }, + { url = "https://files.pythonhosted.org/packages/a7/d3/1ea72538e6c8b3b475ed78d1049a2c518e655761ea50fe1171fc855fcab7/pyarrow-24.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1183baeb14c5f587b1ec52831e665718ce632caab84b7cd6b85fd44f96114495", size = 49385733, upload-time = "2026-04-21T10:48:34.7Z" }, + { url = "https://files.pythonhosted.org/packages/c3/be/c3d8b06a1ba35f2260f8e1f771abbee7d5e345c0937aab90675706b1690a/pyarrow-24.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:806f24b4085453c197a5078218d1ee08783ebbba271badd153d1ae22a3ee804f", size = 51934335, upload-time = "2026-04-21T10:48:42.099Z" }, + { url = "https://files.pythonhosted.org/packages/17/1a/cff3a59f80b5b1658549d46611b67163f65e0664431c076ad728bf9d5af4/pyarrow-24.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:1a4e45017efbf115032e4475ee876d525e0e36c742214fbe405332480ecd6275", size = 35238554, upload-time = "2026-04-21T10:48:48.526Z" }, + { url = "https://files.pythonhosted.org/packages/a8/99/cce0f42a327bfef2c420fb6078a3eb834826e5d6697bf3009fe11d2ad051/pyarrow-24.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:7986f1fa71cee060ad00758bcc79d3a93bab8559bf978fab9e53472a2e25a17b", size = 36782301, upload-time = "2026-04-21T10:48:55.181Z" }, + { url = "https://files.pythonhosted.org/packages/2a/66/8e560d5ff6793ca29aca213c53eec0dd482dd46cb93b2819e5aab52e4252/pyarrow-24.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:d3e0b61e8efb24ed38898e5cdc5fffa9124be480008d401a1f8071500494ae42", size = 45721929, upload-time = "2026-04-21T10:49:03.676Z" }, + { url = "https://files.pythonhosted.org/packages/27/0c/a26e25505d030716e078d9f16eb74973cbf0b33b672884e9f9da1c83b871/pyarrow-24.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:55a3bc1e3df3b5567b7d27ef551b2283f0c68a5e86f1cd56abc569da4f31335b", size = 48825365, upload-time = "2026-04-21T10:49:11.714Z" }, + { url = "https://files.pythonhosted.org/packages/5f/eb/771f9ecb0c65e73fe9dccdd1717901b9594f08c4515d000c7c62df573811/pyarrow-24.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:641f795b361874ac9da5294f8f443dfdbee355cf2bd9e3b8d97aaac2306b9b37", size = 49451819, upload-time = "2026-04-21T10:49:21.474Z" }, + { url = "https://files.pythonhosted.org/packages/48/da/61ae89a88732f5a785646f3ec6125dbb640fa98a540eb2b9889caa561403/pyarrow-24.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8adc8e6ce5fccf5dc707046ae4914fd537def529709cc0d285d37a7f9cd442ca", size = 51909252, upload-time = "2026-04-21T10:49:31.164Z" }, + { url = "https://files.pythonhosted.org/packages/ad/80/d022a34ff05d2cbedd8ccf841fc1f532ecfa9eb5ed1711b56d0e0ea71fc9/pyarrow-24.0.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:1cc9057f0319e26333b357e17f3c2c022f1a83739b48a88b25bfd5fa2dc18838", size = 35007997, upload-time = "2026-04-21T10:49:48.796Z" }, + { url = "https://files.pythonhosted.org/packages/1a/ff/f01485fda6f4e5d441afb8dd5e7681e4db18826c1e271852f5d3957d6a80/pyarrow-24.0.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:e6f1278ee4785b6db21229374a1c9e54ec7c549de5d1efc9630b6207de7e170b", size = 36678720, upload-time = "2026-04-21T10:49:55.858Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c2/2d2d5fea814237923f71b36495211f20b43a1576f9a4d6da7e751a64ec6f/pyarrow-24.0.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:adbbedc55506cbdabb830890444fb856bfb0060c46c6f8026c6c2f2cf86ae795", size = 45741852, upload-time = "2026-04-21T10:50:04.624Z" }, + { url = "https://files.pythonhosted.org/packages/8e/3a/28ba9c1c1ebdbb5f1b94dfebb46f207e52e6a554b7fe4132540fde29a3a0/pyarrow-24.0.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:ae8a1145af31d903fa9bb166824d7abe9b4681a000b0159c9fb99c11bc11ad26", size = 48889852, upload-time = "2026-04-21T10:50:12.293Z" }, + { url = "https://files.pythonhosted.org/packages/df/51/4a389acfd31dca009f8fb82d7f510bb4130f2b3a8e18cf00194d0687d8ac/pyarrow-24.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d7027eba1df3b2069e2e8d80f644fa0918b68c46432af3d088ddd390d063ecde", size = 49445207, upload-time = "2026-04-21T10:50:20.677Z" }, + { url = "https://files.pythonhosted.org/packages/19/4b/0bab2b23d2ae901b1b9a03c0efd4b2d070256f8ce3fc43f6e58c167b2081/pyarrow-24.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e56a1ffe9bf7b727432b89104cc0849c21582949dd7bdcb34f17b2001a351a76", size = 51954117, upload-time = "2026-04-21T10:50:29.14Z" }, + { url = "https://files.pythonhosted.org/packages/79/4f/46a49a63f43526da895b1a45bbb51d5baf8e4d77159f8528fc3e5490007f/pyarrow-24.0.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:418e48ce50a45a6a6c73c454677203a9c75c966cb1e92ca3370959185f197a05", size = 35250387, upload-time = "2026-04-21T10:50:35.552Z" }, + { url = "https://files.pythonhosted.org/packages/a0/da/d5e0cd5ef00796922404806d5f00325cdadc3441ce2c13fe7115f2df9a64/pyarrow-24.0.0-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:2f16197705a230a78270cdd4ea8a1d57e86b2fdcbc34a1f6aebc72e65c986f9a", size = 36797102, upload-time = "2026-04-21T10:50:42.417Z" }, + { url = "https://files.pythonhosted.org/packages/34/c7/5904145b0a593a05236c882933d439b5720f0a145381179063722fbfc123/pyarrow-24.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:fb24ac194bfc5e86839d7dcd52092ee31e5fe6733fe11f5e3b06ef0812b20072", size = 45745118, upload-time = "2026-04-21T10:50:49.324Z" }, + { url = "https://files.pythonhosted.org/packages/13/d3/cca42fe166d1c6e4d5b80e530b7949104d10e17508a90ae202dac205ce2a/pyarrow-24.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:9700ebd9a51f5895ce75ff4ac4b3c47a7d4b42bc618be8e713e5d56bacf5f931", size = 48844765, upload-time = "2026-04-21T10:50:55.579Z" }, + { url = "https://files.pythonhosted.org/packages/b0/49/942c3b79878ba928324d1e17c274ed84581db8c0a749b24bcf4cbdf15bd3/pyarrow-24.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d8ddd2768da81d3ee08cfea9b597f4abb4e8e1dc8ae7e204b608d23a0d3ab699", size = 49471890, upload-time = "2026-04-21T10:51:02.439Z" }, + { url = "https://files.pythonhosted.org/packages/76/97/ff71431000a75d84135a1ace5ca4ba11726a231a8007bbb320a4c54075d5/pyarrow-24.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:61a3d7eaa97a14768b542f3d284dc6400dd2470d9f080708b13cd46b6ae18136", size = 51932250, upload-time = "2026-04-21T10:51:10.576Z" }, +] + [[package]] name = "pyasn1" version = "0.6.3" From 6f8e39c2e0d877d7088a58939ea705530fa7e985 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sun, 7 Jun 2026 13:30:08 -0700 Subject: [PATCH 06/29] Fix: avoid unnecessary creating new PDF with pw removal workflow (#12948) --- src/documents/bulk_edit.py | 13 +++ src/documents/tests/test_bulk_edit.py | 126 ++++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 9 deletions(-) diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 848919cde..0cea9a3a1 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -904,6 +904,19 @@ def remove_password( doc.id, pair.source_doc.source_path, ) + try: + with pikepdf.open(source_path) as pdf: + if not pdf.is_encrypted: + logger.info( + "Skipping password removal for document %s because the " + "source PDF is not encrypted", + pair.root_doc.id, + ) + continue + except pikepdf.PasswordError: + # Password-protected PDFs need the supplied password below. + pass + with pikepdf.open(source_path, password=password) as pdf: filepath: Path = ( Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR)) diff --git a/src/documents/tests/test_bulk_edit.py b/src/documents/tests/test_bulk_edit.py index 8d0c893eb..010744af1 100644 --- a/src/documents/tests/test_bulk_edit.py +++ b/src/documents/tests/test_bulk_edit.py @@ -3,6 +3,7 @@ from datetime import date from pathlib import Path from unittest import mock +import pikepdf from django.contrib.auth.models import Group from django.contrib.auth.models import User from django.test import TestCase @@ -615,6 +616,18 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.img_doc.archive_filename = img_doc_archive self.img_doc.save() + @staticmethod + def mock_password_required_pdf( + mock_open: mock.Mock, + fake_pdf: mock.Mock, + ) -> None: + password_context = mock.MagicMock() + password_context.__enter__.return_value = fake_pdf + mock_open.side_effect = [ + pikepdf.PasswordError("password required"), + password_context, + ] + @mock.patch("documents.tasks.consume_file.s") def test_merge(self, mock_consume_file) -> None: """ @@ -1466,6 +1479,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf = mock.MagicMock() fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()] + fake_pdf.is_encrypted = True def save_side_effect(target_path): Path(target_path).write_bytes(b"new pdf content") @@ -1480,7 +1494,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - mock_open.assert_called_once_with(doc.source_path, password="secret") + self.assertEqual( + mock_open.call_args_list, + [ + mock.call(doc.source_path), + mock.call(doc.source_path, password="secret"), + ], + ) fake_pdf.remove_unreferenced_resources.assert_called_once() mock_update_document.assert_not_called() mock_consume_delay.assert_called_once() @@ -1494,6 +1514,33 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id) self.assertIsNotNone(task_kwargs["overrides"]) + @mock.patch("documents.tasks.consume_file.apply_async") + @mock.patch("documents.bulk_edit.tempfile.mkdtemp") + @mock.patch("pikepdf.open") + def test_remove_password_update_document_skips_unencrypted_pdf( + self, + mock_open, + mock_mkdtemp, + mock_consume_delay, + ) -> None: + doc = self.doc1 + fake_pdf = mock.MagicMock() + fake_pdf.is_encrypted = False + mock_open.return_value.__enter__.return_value = fake_pdf + + result = bulk_edit.remove_password( + [doc.id], + password="secret", + update_document=True, + ) + + self.assertEqual(result, "OK") + mock_open.assert_called_once_with(doc.source_path) + fake_pdf.remove_unreferenced_resources.assert_not_called() + fake_pdf.save.assert_not_called() + mock_mkdtemp.assert_not_called() + mock_consume_delay.assert_not_called() + @mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay") @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("documents.bulk_edit.tempfile.mkdtemp") @@ -1513,12 +1560,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_mkdtemp.return_value = str(temp_dir) fake_pdf = mock.MagicMock() + self.mock_password_required_pdf(mock_open, fake_pdf) def save_side_effect(target_path): Path(target_path).write_bytes(b"new pdf content") fake_pdf.save.side_effect = save_side_effect - mock_open.return_value.__enter__.return_value = fake_pdf result = bulk_edit.remove_password( [doc.id], @@ -1528,7 +1575,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - mock_open.assert_called_once_with(source_file, password="secret") + self.assertEqual( + mock_open.call_args_list, + [ + mock.call(source_file), + mock.call(source_file, password="secret"), + ], + ) mock_update_document.assert_not_called() mock_consume_delay.assert_called_once() @@ -1547,7 +1600,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): root_document=self.doc1, ) fake_pdf = mock.MagicMock() - mock_open.return_value.__enter__.return_value = fake_pdf + self.mock_password_required_pdf(mock_open, fake_pdf) result = bulk_edit.remove_password( [self.doc1.id], @@ -1557,7 +1610,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - mock_open.assert_called_once_with(self.doc1.source_path, password="secret") + self.assertEqual( + mock_open.call_args_list, + [ + mock.call(self.doc1.source_path), + mock.call(self.doc1.source_path, password="secret"), + ], + ) mock_consume_delay.assert_called_once() @mock.patch("documents.bulk_edit.chord") @@ -1580,12 +1639,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf = mock.MagicMock() fake_pdf.pages = [mock.Mock(), mock.Mock()] + self.mock_password_required_pdf(mock_open, fake_pdf) def save_side_effect(target_path: Path) -> None: target_path.write_bytes(b"password removed") fake_pdf.save.side_effect = save_side_effect - mock_open.return_value.__enter__.return_value = fake_pdf mock_group.return_value.delay.return_value = None user = User.objects.create(username="owner") @@ -1600,7 +1659,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - mock_open.assert_called_once_with(doc.source_path, password="secret") + self.assertEqual( + mock_open.call_args_list, + [ + mock.call(doc.source_path), + mock.call(doc.source_path, password="secret"), + ], + ) mock_consume_file.assert_called_once() call_kwargs = mock_consume_file.call_args.kwargs consumable_document = call_kwargs["input_doc"] @@ -1618,6 +1683,43 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_group.return_value.delay.assert_called_once() mock_chord.assert_not_called() + @mock.patch("documents.bulk_edit.delete") + @mock.patch("documents.bulk_edit.chord") + @mock.patch("documents.bulk_edit.group") + @mock.patch("documents.tasks.consume_file.s") + @mock.patch("documents.bulk_edit.tempfile.mkdtemp") + @mock.patch("pikepdf.open") + def test_remove_password_skips_unencrypted_pdf_without_queueing( + self, + mock_open: mock.Mock, + mock_mkdtemp: mock.Mock, + mock_consume_file: mock.Mock, + mock_group: mock.Mock, + mock_chord: mock.Mock, + mock_delete: mock.Mock, + ) -> None: + doc = self.doc2 + fake_pdf = mock.MagicMock() + fake_pdf.is_encrypted = False + mock_open.return_value.__enter__.return_value = fake_pdf + + result = bulk_edit.remove_password( + [doc.id], + password="secret", + update_document=False, + delete_original=True, + ) + + self.assertEqual(result, "OK") + mock_open.assert_called_once_with(doc.source_path) + fake_pdf.remove_unreferenced_resources.assert_not_called() + fake_pdf.save.assert_not_called() + mock_mkdtemp.assert_not_called() + mock_consume_file.assert_not_called() + mock_group.assert_not_called() + mock_chord.assert_not_called() + mock_delete.si.assert_not_called() + @mock.patch("documents.bulk_edit.delete") @mock.patch("documents.bulk_edit.chord") @mock.patch("documents.bulk_edit.group") @@ -1640,12 +1742,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf = mock.MagicMock() fake_pdf.pages = [mock.Mock(), mock.Mock()] + self.mock_password_required_pdf(mock_open, fake_pdf) def save_side_effect(target_path: Path) -> None: target_path.write_bytes(b"password removed") fake_pdf.save.side_effect = save_side_effect - mock_open.return_value.__enter__.return_value = fake_pdf mock_chord.return_value.delay.return_value = None result = bulk_edit.remove_password( @@ -1657,7 +1759,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - mock_open.assert_called_once_with(doc.source_path, password="secret") + self.assertEqual( + mock_open.call_args_list, + [ + mock.call(doc.source_path), + mock.call(doc.source_path, password="secret"), + ], + ) mock_consume_file.assert_called_once() mock_group.assert_not_called() mock_chord.assert_called_once() From c3459d8f6260d297bb586a2acc292c2b2ffb2078 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Sun, 7 Jun 2026 15:45:15 -0700 Subject: [PATCH 07/29] Fix (beta): move task filtering to backend fully (#12956) --- .../admin/tasks/tasks.component.html | 2 +- .../admin/tasks/tasks.component.spec.ts | 114 ++++++++++++- .../components/admin/tasks/tasks.component.ts | 108 +++++++++++- src-ui/src/app/data/paperless-task.ts | 7 + src-ui/src/app/services/tasks.service.spec.ts | 30 ++++ src-ui/src/app/services/tasks.service.ts | 14 +- src/documents/filters.py | 64 ++++++- src/documents/tests/test_api_tasks.py | 160 ++++++++++++++++++ src/documents/views.py | 59 ++++++- 9 files changed, 547 insertions(+), 11 deletions(-) diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.html b/src-ui/src/app/components/admin/tasks/tasks.component.html index b8e4f3ff5..e1d6bc900 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.html +++ b/src-ui/src/app/components/admin/tasks/tasks.component.html @@ -84,7 +84,7 @@
diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts b/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts index a87ec49b0..315dc2455 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts +++ b/src-ui/src/app/components/admin/tasks/tasks.component.spec.ts @@ -29,7 +29,11 @@ import { ToastService } from 'src/app/services/toast.service' import { environment } from 'src/environments/environment' import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component' import { PageHeaderComponent } from '../../common/page-header/page-header.component' -import { TasksComponent, TaskSection } from './tasks.component' +import { + TaskFilterTargetID, + TasksComponent, + TaskSection, +} from './tasks.component' const tasks: PaperlessTask[] = [ { @@ -154,6 +158,13 @@ const paginatedTasks: Results = { results: tasks, } +const sectionCountResponse = { + all: 7, + needs_attention: 2, + in_progress: 3, + completed: 2, +} + describe('TasksComponent', () => { let component: TasksComponent let fixture: ComponentFixture @@ -221,6 +232,15 @@ describe('TasksComponent', () => { req.params.get('page') === '1' ) .flush(paginatedTasks) + + httpTestingController + .expectOne( + (req) => + req.url === `${environment.apiBaseUrl}tasks/status_counts/` && + req.params.get('acknowledged') === 'false' && + !req.params.has('status') + ) + .flush(sectionCountResponse) }) it('should display task sections with counts', () => { @@ -328,6 +348,74 @@ describe('TasksComponent', () => { expect(pagination).not.toBeNull() }) + it('should apply the selected section to the server-side task query', () => { + component.setSection(TaskSection.NeedsAttention) + + const req = httpTestingController.expectOne( + (request) => + request.url === `${environment.apiBaseUrl}tasks/` && + request.params.get('page') === '1' && + request.params.get('page_size') === '25' && + request.params.get('acknowledged') === 'false' && + request.params.getAll('status').includes(PaperlessTaskStatus.Failure) && + request.params.getAll('status').includes(PaperlessTaskStatus.Revoked) + ) + + req.flush({ count: 2, results: [tasks[0], tasks[1]] }) + expect(component.totalTasks).toBe(2) + }) + + it('should apply task type and trigger source filters to the server-side task query', () => { + component.setTaskType(PaperlessTaskType.SanityCheck) + + httpTestingController + .expectOne( + (request) => + request.url === `${environment.apiBaseUrl}tasks/` && + request.params.get('page_size') === '25' && + request.params.get('task_type') === PaperlessTaskType.SanityCheck + ) + .flush({ count: 1, results: [tasks[6]] }) + + component.setTriggerSource(PaperlessTaskTriggerSource.System) + + httpTestingController + .expectOne( + (request) => + request.url === `${environment.apiBaseUrl}tasks/` && + request.params.get('page_size') === '25' && + request.params.get('task_type') === PaperlessTaskType.SanityCheck && + request.params.get('trigger_source') === + PaperlessTaskTriggerSource.System + ) + .flush({ count: 1, results: [tasks[6]] }) + }) + + it('should apply text filters to the server-side task query', () => { + component.filterText = 'invoice' + jest.advanceTimersByTime(150) + + httpTestingController + .expectOne( + (request) => + request.url === `${environment.apiBaseUrl}tasks/` && + request.params.get('page_size') === '25' && + request.params.get('name') === 'invoice' + ) + .flush({ count: 1, results: [tasks[0]] }) + + component.setFilterTarget(TaskFilterTargetID.Result) + + httpTestingController + .expectOne( + (request) => + request.url === `${environment.apiBaseUrl}tasks/` && + request.params.get('page_size') === '25' && + request.params.get('result') === 'invoice' + ) + .flush({ count: 0, results: [] }) + }) + it('should load a different task page when pagination changes', () => { component.setPage(2) @@ -351,6 +439,27 @@ describe('TasksComponent', () => { expect(component.pagedTasks).toEqual([tasks[0]]) }) + it('should not replace section counts with current-page counts', () => { + component.setPage(2) + + httpTestingController + .expectOne( + (req) => + req.url === `${environment.apiBaseUrl}tasks/` && + req.params.get('acknowledged') === 'false' && + req.params.get('page_size') === '25' && + req.params.get('page') === '2' + ) + .flush({ + count: 30, + results: [tasks[0]], + }) + + expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2) + expect(component.sectionCount(TaskSection.InProgress)).toBe(3) + expect(component.sectionCount(TaskSection.Completed)).toBe(2) + }) + it('should expose stable task type options and disable empty ones', () => { expect(component.taskTypeOptions.map((option) => option.value)).toContain( PaperlessTaskType.TrainClassifier @@ -714,6 +823,9 @@ describe('TasksComponent', () => { }) it('should keep clearing selection independent from resetting filters', () => { + component.resetFilter() + expect(component.filterText).toBe('') + component.setTaskType(PaperlessTaskType.ConsumeFile) component.toggleSelected(tasks[0]) expect(component.selectedTasks.size).toBe(1) diff --git a/src-ui/src/app/components/admin/tasks/tasks.component.ts b/src-ui/src/app/components/admin/tasks/tasks.component.ts index ed72a401d..276dc6a8f 100644 --- a/src-ui/src/app/components/admin/tasks/tasks.component.ts +++ b/src-ui/src/app/components/admin/tasks/tasks.component.ts @@ -40,7 +40,7 @@ export enum TaskSection { Completed = 'completed', } -enum TaskFilterTargetID { +export enum TaskFilterTargetID { Name, Result, } @@ -167,6 +167,12 @@ export class TasksComponent public readonly pageSize = 25 public page: number = 1 public totalTasks: number = 0 + public sectionCounts: Record = { + [TaskSection.All]: 0, + [TaskSection.NeedsAttention]: 0, + [TaskSection.InProgress]: 0, + [TaskSection.Completed]: 0, + } public pagedTasks: PaperlessTask[] = [] public selectedSection: TaskSection = TaskSection.All public selectedTaskType: PaperlessTaskType | null = null @@ -282,6 +288,7 @@ export class TasksComponent .subscribe((query) => { this._filterText = query this.clearSelection() + this.reloadPage(true) }) } @@ -470,9 +477,7 @@ export class TasksComponent } sectionCount(section: TaskSection): number { - return this.pagedTasks.filter((task) => - this.taskBelongsToSection(task, section) - ).length + return this.sectionCounts[section] } sectionShowsResults(section: TaskSection): boolean { @@ -482,16 +487,27 @@ export class TasksComponent setSection(section: TaskSection) { this.selectedSection = section this.clearSelection() + this.reloadPage(true) } setTaskType(taskType: PaperlessTaskType | null) { this.selectedTaskType = taskType this.clearSelection() + this.reloadPage(true) } setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) { this.selectedTriggerSource = triggerSource this.clearSelection() + this.reloadPage(true) + } + + setFilterTarget(filterTargetID: TaskFilterTargetID) { + this.filterTargetID = filterTargetID + if (this._filterText.length) { + this.clearSelection() + this.reloadPage(true) + } } taskTypeOptionCount(taskType: PaperlessTaskType | null): number { @@ -529,19 +545,32 @@ export class TasksComponent } public resetFilter() { + if (!this._filterText.length) { + return + } + this._filterText = '' + this.clearSelection() + this.reloadPage(true) } public resetFilters() { + const hadFilter = this.isFiltered this.selectedTaskType = null this.selectedTriggerSource = null - this.resetFilter() + this._filterText = '' this.clearSelection() + + if (hadFilter) { + this.reloadPage(true) + } } filterInputKeyup(event: KeyboardEvent) { if (event.key == 'Enter') { this._filterText = (event.target as HTMLInputElement).value + this.clearSelection() + this.reloadPage(true) } else if (event.key === 'Escape') { this.resetFilter() } @@ -630,19 +659,86 @@ export class TasksComponent ) } + private reloadSectionCounts() { + this.tasksService + .statusCounts(this.getParamsForSection(TaskSection.All)) + .pipe(first(), takeUntil(this.unsubscribeNotifier)) + .subscribe((counts) => { + this.sectionCounts[TaskSection.All] = counts.all + this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention + this.sectionCounts[TaskSection.InProgress] = counts.in_progress + this.sectionCounts[TaskSection.Completed] = counts.completed + }) + } + + private getParamsForSection( + section: TaskSection + ): Record { + const params: Record< + string, + string | number | boolean | readonly string[] + > = { + acknowledged: false, + } + + const statuses = this.statusesForSection(section) + if (statuses.length) { + params.status = statuses + } + + if (this.selectedTaskType !== null) { + params.task_type = this.selectedTaskType + } + + if (this.selectedTriggerSource !== null) { + params.trigger_source = this.selectedTriggerSource + } + + if (this._filterText.length) { + params[ + this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result' + ] = this._filterText + } + + return params + } + + private statusesForSection(section: TaskSection): PaperlessTaskStatus[] { + switch (section) { + case TaskSection.NeedsAttention: + return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked] + case TaskSection.InProgress: + return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started] + case TaskSection.Completed: + return [PaperlessTaskStatus.Success] + default: + return [] + } + } + private reloadPage(resetToFirstPage: boolean = false) { if (resetToFirstPage) { this.page = 1 } + this.reloadSectionCounts() + this.loading = true this.tasksService - .list(this.page, this.pageSize, { acknowledged: false }) + .list( + this.page, + this.pageSize, + this.getParamsForSection(this.selectedSection) + ) .pipe(first(), takeUntil(this.unsubscribeNotifier)) .subscribe({ next: (result) => { this.pagedTasks = result.results this.totalTasks = result.count + this.sectionCounts[TaskSection.All] = result.count + if (this.selectedSection !== TaskSection.All) { + this.sectionCounts[this.selectedSection] = result.count + } this.loading = false if ( this.page > 1 && diff --git a/src-ui/src/app/data/paperless-task.ts b/src-ui/src/app/data/paperless-task.ts index 53aba0edd..ca64918c4 100644 --- a/src-ui/src/app/data/paperless-task.ts +++ b/src-ui/src/app/data/paperless-task.ts @@ -64,3 +64,10 @@ export interface PaperlessTaskSummary { last_success: Date | null last_failure: Date | null } + +export interface PaperlessTaskStatusCounts { + all: number + needs_attention: number + in_progress: number + completed: number +} diff --git a/src-ui/src/app/services/tasks.service.spec.ts b/src-ui/src/app/services/tasks.service.spec.ts index 1ae217543..3412ae2ce 100644 --- a/src-ui/src/app/services/tasks.service.spec.ts +++ b/src-ui/src/app/services/tasks.service.spec.ts @@ -242,4 +242,34 @@ describe('TasksService', () => { task_id: 'abc-123', }) }) + + it('loads filtered task status counts', () => { + tasksService + .statusCounts({ + acknowledged: false, + task_type: PaperlessTaskType.ConsumeFile, + }) + .subscribe((res) => { + expect(res).toEqual({ + all: 10, + needs_attention: 2, + in_progress: 3, + completed: 5, + }) + }) + + const req = httpTestingController.expectOne( + (req: HttpRequest) => + req.url === `${environment.apiBaseUrl}tasks/status_counts/` && + req.params.get('acknowledged') === 'false' && + req.params.get('task_type') === PaperlessTaskType.ConsumeFile + ) + expect(req.request.method).toEqual('GET') + req.flush({ + all: 10, + needs_attention: 2, + in_progress: 3, + completed: 5, + }) + }) }) diff --git a/src-ui/src/app/services/tasks.service.ts b/src-ui/src/app/services/tasks.service.ts index a3ae283ed..404db589a 100644 --- a/src-ui/src/app/services/tasks.service.ts +++ b/src-ui/src/app/services/tasks.service.ts @@ -5,6 +5,7 @@ import { first, map, takeUntil, tap } from 'rxjs/operators' import { PaperlessTask, PaperlessTaskStatus, + PaperlessTaskStatusCounts, PaperlessTaskType, } from 'src/app/data/paperless-task' import { Results } from 'src/app/data/results' @@ -88,7 +89,7 @@ export class TasksService { public list( page: number, pageSize: number, - extraParams?: Record + extraParams?: Record ): Observable> { return this.http.get>( `${this.baseUrl}${this.endpoint}/`, @@ -102,6 +103,17 @@ export class TasksService { ) } + public statusCounts( + extraParams?: Record + ): Observable { + return this.http.get( + `${this.baseUrl}${this.endpoint}/status_counts/`, + { + params: extraParams, + } + ) + } + public dismissTasks(task_ids: Set): Observable { return this.http .post(`${this.baseUrl}tasks/acknowledge/`, { diff --git a/src/documents/filters.py b/src/documents/filters.py index ddc784204..39c6eb467 100644 --- a/src/documents/filters.py +++ b/src/documents/filters.py @@ -28,6 +28,7 @@ from django.db.models.functions import Cast from django.utils.translation import gettext_lazy as _ from django_filters import DateFilter from django_filters.rest_framework import BooleanFilter +from django_filters.rest_framework import CharFilter from django_filters.rest_framework import DateTimeFilter from django_filters.rest_framework import Filter from django_filters.rest_framework import FilterSet @@ -900,6 +901,16 @@ class ShareLinkBundleFilterSet(FilterSet): class PaperlessTaskFilterSet(FilterSet): + name = CharFilter( + method="filter_name", + label="Name", + ) + + result = CharFilter( + method="filter_result", + label="Result", + ) + task_type = MultipleChoiceFilter( choices=PaperlessTask.TaskType.choices, label="Task Type", @@ -939,7 +950,58 @@ class PaperlessTaskFilterSet(FilterSet): class Meta: model = PaperlessTask - fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"] + fields = [ + "task_type", + "trigger_source", + "status", + "acknowledged", + "owner", + "name", + "result", + ] + + def filter_name(self, queryset, name, value): + if not value: + return queryset + + matching_task_types = [ + task_type + for task_type, label in PaperlessTask.TaskType.choices + if value.lower() in str(label).lower() + ] + matching_trigger_sources = [ + trigger_source + for trigger_source, label in PaperlessTask.TriggerSource.choices + if value.lower() in str(label).lower() + ] + + return queryset.filter( + Q(input_data__filename__icontains=value) + | Q(task_type__in=matching_task_types) + | Q(trigger_source__in=matching_trigger_sources), + ) + + def filter_result(self, queryset, name, value): + if not value: + return queryset + + query = Q(result_data__reason__icontains=value) | Q( + result_data__error_message__icontains=value, + ) + + try: + numeric_value = int(value) + except (TypeError, ValueError): + pass + else: + query |= Q(result_data__document_id=numeric_value) | Q( + result_data__duplicate_of=numeric_value, + ) + + if "duplicate" in value.lower(): + query |= Q(result_data__duplicate_of__isnull=False) + + return queryset.filter(query) def filter_is_complete(self, queryset, name, value): if value: diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index 42ccbab5c..59767f9af 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -18,6 +18,7 @@ from guardian.shortcuts import assign_perm from rest_framework import status from rest_framework.test import APIClient +from documents.filters import PaperlessTaskFilterSet from documents.models import PaperlessTask from documents.tests.factories import DocumentFactory from documents.tests.factories import PaperlessTaskFactory @@ -169,6 +170,165 @@ class TestGetTasksV10: PaperlessTask.Status.STARTED, } + def test_filter_by_task_name(self, admin_client: APIClient) -> None: + """?name= searches task filenames, task types, and trigger sources.""" + filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"}) + type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK) + source_task = PaperlessTaskFactory( + trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME, + ) + PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"}) + + response = admin_client.get(ENDPOINT, {"name": "invoice"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data["count"] == 1 + assert response.data["results"][0]["task_id"] == filename_task.task_id + + response = admin_client.get(ENDPOINT, {"name": "sanity"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data["count"] == 1 + assert response.data["results"][0]["task_id"] == type_task.task_id + + response = admin_client.get(ENDPOINT, {"name": "email"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data["count"] == 1 + assert response.data["results"][0]["task_id"] == source_task.task_id + + def test_filter_by_task_result(self, admin_client: APIClient) -> None: + """?result= searches common structured task result messages.""" + reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"}) + error_task = PaperlessTaskFactory( + result_data={"error_message": "Duplicate detected"}, + ) + document_task = PaperlessTaskFactory(result_data={"document_id": 321}) + duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123}) + PaperlessTaskFactory(result_data={"reason": "unrelated"}) + + response = admin_client.get(ENDPOINT, {"result": "manual"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data["count"] == 1 + assert response.data["results"][0]["task_id"] == reason_task.task_id + + response = admin_client.get(ENDPOINT, {"result": "duplicate"}) + + assert response.status_code == status.HTTP_200_OK + returned_ids = {task["task_id"] for task in response.data["results"]} + assert returned_ids == {error_task.task_id, duplicate_task.task_id} + + response = admin_client.get(ENDPOINT, {"result": "321"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data["count"] == 1 + assert response.data["results"][0]["task_id"] == document_task.task_id + + def test_empty_task_name_and_result_filters(self) -> None: + """Empty name/result values leave the queryset unchanged.""" + PaperlessTaskFactory.create_batch(2) + queryset = PaperlessTask.objects.all() + filterset = PaperlessTaskFilterSet() + + assert filterset.filter_name(queryset, "name", "").count() == 2 + assert filterset.filter_result(queryset, "result", "").count() == 2 + + def test_status_counts_respects_filters(self, admin_client: APIClient) -> None: + """status_counts/ returns section counts for the filtered task queryset.""" + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.FAILURE, + input_data={"filename": "invoice-a.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.REVOKED, + input_data={"filename": "invoice-b.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.PENDING, + input_data={"filename": "invoice-c.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.STARTED, + input_data={"filename": "invoice-d.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.SUCCESS, + input_data={"filename": "invoice-e.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=True, + status=PaperlessTask.Status.SUCCESS, + input_data={"filename": "invoice-acknowledged.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.SUCCESS, + input_data={"filename": "unrelated.pdf"}, + ) + + response = admin_client.get( + f"{ENDPOINT}status_counts/", + {"acknowledged": "false", "name": "invoice"}, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "all": 5, + "needs_attention": 2, + "in_progress": 2, + "completed": 1, + } + + def test_status_counts_ignores_section_filters( + self, + admin_client: APIClient, + ) -> None: + """status_counts/ ignores status-like filters for the sections it counts.""" + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.FAILURE, + input_data={"filename": "invoice-a.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.PENDING, + input_data={"filename": "invoice-b.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.SUCCESS, + input_data={"filename": "invoice-c.pdf"}, + ) + PaperlessTaskFactory( + acknowledged=False, + status=PaperlessTask.Status.FAILURE, + input_data={"filename": "unrelated.pdf"}, + ) + + response = admin_client.get( + f"{ENDPOINT}status_counts/", + { + "acknowledged": "false", + "name": "invoice", + "status": PaperlessTask.Status.FAILURE, + "is_complete": "false", + }, + ) + + assert response.status_code == status.HTTP_200_OK + assert response.data == { + "all": 3, + "needs_attention": 1, + "in_progress": 1, + "completed": 1, + } + def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None: """Tasks are returned in descending date_created order (newest first).""" base = timezone.now() diff --git a/src/documents/views.py b/src/documents/views.py index ba4faa622..cbc4560d8 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -4011,7 +4011,7 @@ class RemoteVersionView(GenericAPIView[Any]): class _TasksViewSetSchema(AutoSchema): - _UNPAGINATED_ACTIONS = frozenset({"summary", "active"}) + _UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"}) def _get_paginator(self): if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS: @@ -4071,6 +4071,19 @@ class _TasksViewSetSchema(AutoSchema): ), ], ), + status_counts=extend_schema( + responses={ + 200: inline_serializer( + name="TaskStatusCounts", + fields={ + "all": serializers.IntegerField(), + "needs_attention": serializers.IntegerField(), + "in_progress": serializers.IntegerField(), + "completed": serializers.IntegerField(), + }, + ), + }, + ), active=extend_schema( description="Currently pending and running tasks (capped at 50).", responses={200: TaskSerializerV10(many=True)}, @@ -4124,6 +4137,7 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}), PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}), } + _STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"}) def get_serializer_class(self): # v9: use backwards-compatible serializer with old field names @@ -4164,6 +4178,21 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): queryset = queryset.filter(task_id=task_id) return queryset + def get_status_count_queryset(self): + """Apply task filters except the status dimensions represented by the counts.""" + query_params = self.request.query_params.copy() + for param in self._STATUS_COUNT_EXCLUDED_FILTERS: + query_params.pop(param, None) + + filterset = self.filterset_class( + data=query_params, + queryset=self.get_queryset(), + request=self.request, + ) + if not filterset.is_valid(): + raise ValidationError(filterset.errors) + return filterset.qs + @action( methods=["post"], detail=False, @@ -4233,6 +4262,34 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): serializer = TaskSummarySerializer(data, many=True) return Response(serializer.data) + @action(methods=["get"], detail=False) + def status_counts(self, request): + """Aggregated task counts for task UI sections.""" + queryset = self.get_status_count_queryset() + counts = queryset.aggregate( + all=Count("id"), + needs_attention=Count( + "id", + filter=Q( + status__in=[ + PaperlessTask.Status.FAILURE, + PaperlessTask.Status.REVOKED, + ], + ), + ), + in_progress=Count( + "id", + filter=Q( + status__in=[ + PaperlessTask.Status.PENDING, + PaperlessTask.Status.STARTED, + ], + ), + ), + completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)), + ) + return Response(counts) + @action(methods=["get"], detail=False) def active(self, request): """Currently pending and running tasks (capped at 50).""" From 8405f66e386559eb26a1affe5ed31575ea0a9f0c Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 9 Jun 2026 07:03:44 -0700 Subject: [PATCH 08/29] Fix (beta): fix re-ordering in merge dialog (#12967) --- .../merge-confirm-dialog.component.html | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src-ui/src/app/components/common/confirm-dialog/merge-confirm-dialog/merge-confirm-dialog.component.html b/src-ui/src/app/components/common/confirm-dialog/merge-confirm-dialog/merge-confirm-dialog.component.html index 90f555890..ab0951d54 100644 --- a/src-ui/src/app/components/common/confirm-dialog/merge-confirm-dialog/merge-confirm-dialog.component.html +++ b/src-ui/src/app/components/common/confirm-dialog/merge-confirm-dialog/merge-confirm-dialog.component.html @@ -9,8 +9,11 @@
    - @for (document of documents; track document.id) { + @for (documentID of documentIDs; track documentID) { + @let document = getDocument(documentID); + @if (document) {
  • @@ -27,6 +30,7 @@
  • + } }
From a5d6ff5f156c1f5d082127af7ee9fc7e39d5e173 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Wed, 10 Jun 2026 06:56:02 -0700 Subject: [PATCH 09/29] Fix: wrap long titles in delete confirm dialog (#12973) --- .../common/confirm-dialog/confirm-dialog.component.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src-ui/src/app/components/common/confirm-dialog/confirm-dialog.component.html b/src-ui/src/app/components/common/confirm-dialog/confirm-dialog.component.html index deee54402..437e7af94 100644 --- a/src-ui/src/app/components/common/confirm-dialog/confirm-dialog.component.html +++ b/src-ui/src/app/components/common/confirm-dialog/confirm-dialog.component.html @@ -5,10 +5,10 @@