Compare commits

..

14 Commits

Author SHA1 Message Date
shamoon
1775846483 Dont allow these ones to get all 2026-03-13 23:51:54 -07:00
shamoon
13671b7d85 Add API all/filters support 2026-03-13 23:51:45 -07:00
shamoon
0bb7d755ab Sonar 2026-03-13 07:17:37 -07:00
shamoon
e4d43175af Fix 2026-03-13 07:17:36 -07:00
shamoon
04945ff3f7 Update api.md 2026-03-13 07:17:36 -07:00
shamoon
7b430e27c6 Frontend use all option for bulk edit objects instead of sending IDs 2026-03-13 07:17:00 -07:00
shamoon
b329581111 Support all for BulkEditObjectsView 2026-03-13 07:16:59 -07:00
shamoon
84e8caf25f Use a backend display_count to fix nested tag thing 2026-03-13 07:16:59 -07:00
shamoon
97602f79fb Not even optional 2026-03-13 07:16:58 -07:00
shamoon
568be982cf Remove this stuff now 2026-03-13 07:16:58 -07:00
shamoon
d753b698db tests 2026-03-13 07:16:58 -07:00
shamoon
eabd11546a Only fetch all IDs on demand 2026-03-13 07:16:57 -07:00
shamoon
43072b7a74 Backend tests 2026-03-13 07:16:57 -07:00
shamoon
1c65a1bb0e Backend deprecate all to only api v < 10 2026-03-13 07:16:56 -07:00
17 changed files with 542 additions and 89 deletions

View File

@@ -437,3 +437,6 @@ Initial API version.
moved from the bulk edit endpoint to their own individual endpoints. Using these methods via
the bulk edit endpoint is still supported for compatibility with versions < 10 until support
for API v9 is dropped.
- The `all` parameter of list endpoints is now deprecated and will be removed in a future version.
- The bulk edit objects endpoint now supports `all` and `filters` parameters to avoid having to send
large lists of object IDs for operations affecting many objects.

View File

@@ -9,8 +9,8 @@
<div ngbDropdown class="btn-group flex-fill d-sm-none">
<button class="btn btn-sm btn-outline-primary" id="dropdownSelectMobile" ngbDropdownToggle>
<i-bs name="text-indent-left"></i-bs><div class="d-none d-sm-inline ms-1"><ng-container i18n>Select</ng-container></div>
@if (activeManagementList.selectedObjects.size > 0) {
<pngx-clearable-badge [selected]="activeManagementList.selectedObjects.size > 0" [number]="activeManagementList.selectedObjects.size" (cleared)="activeManagementList.selectNone()"></pngx-clearable-badge><span class="visually-hidden">selected</span>
@if (activeManagementList.hasSelection) {
<pngx-clearable-badge [selected]="activeManagementList.hasSelection" [number]="activeManagementList.selectedCount" (cleared)="activeManagementList.selectNone()"></pngx-clearable-badge><span class="visually-hidden">selected</span>
}
</button>
<div ngbDropdownMenu aria-labelledby="dropdownSelectMobile" class="shadow">
@@ -25,7 +25,7 @@
<span class="input-group-text border-0" i18n>Select:</span>
</div>
<div class="btn-group btn-group-sm flex-nowrap">
@if (activeManagementList.selectedObjects.size > 0) {
@if (activeManagementList.hasSelection) {
<button class="btn btn-sm btn-outline-secondary" (click)="activeManagementList.selectNone()">
<i-bs name="slash-circle" class="me-1"></i-bs><ng-container i18n>None</ng-container>
</button>
@@ -40,11 +40,11 @@
</div>
<button type="button" class="btn btn-sm btn-outline-primary" (click)="activeManagementList.setPermissions()"
[disabled]="!activeManagementList.userCanBulkEdit(PermissionAction.Change) || activeManagementList.selectedObjects.size === 0">
[disabled]="!activeManagementList.userCanBulkEdit(PermissionAction.Change) || !activeManagementList.hasSelection">
<i-bs name="person-fill-lock" class="me-1"></i-bs><ng-container i18n>Permissions</ng-container>
</button>
<button type="button" class="btn btn-sm btn-outline-danger" (click)="activeManagementList.delete()"
[disabled]="!activeManagementList.userCanBulkEdit(PermissionAction.Delete) || activeManagementList.selectedObjects.size === 0">
[disabled]="!activeManagementList.userCanBulkEdit(PermissionAction.Delete) || !activeManagementList.hasSelection">
<i-bs name="trash" class="me-1"></i-bs><ng-container i18n>Delete</ng-container>
</button>
<button type="button" class="btn btn-sm btn-outline-primary ms-md-5" (click)="activeManagementList.openCreateDialog()"

View File

@@ -65,8 +65,8 @@
@if (displayCollectionSize > 0) {
<div>
<ng-container i18n>{displayCollectionSize, plural, =1 {One {{typeName}}} other {{{displayCollectionSize || 0}} total {{typeNamePlural}}}}</ng-container>
@if (selectedObjects.size > 0) {
&nbsp;({{selectedObjects.size}} selected)
@if (hasSelection) {
&nbsp;({{selectedCount}} selected)
}
</div>
}

View File

@@ -117,7 +117,6 @@ describe('ManagementListComponent', () => {
: tags
return of({
count: results.length,
all: results.map((o) => o.id),
results,
})
}
@@ -231,11 +230,11 @@ describe('ManagementListComponent', () => {
expect(reloadSpy).toHaveBeenCalled()
})
it('should use API count for pagination and all ids for displayed total', fakeAsync(() => {
it('should use API count for pagination and nested ids for displayed total', fakeAsync(() => {
jest.spyOn(tagService, 'listFiltered').mockReturnValueOnce(
of({
count: 1,
all: [1, 2, 3],
display_count: 3,
results: tags.slice(0, 1),
})
)
@@ -315,13 +314,17 @@ describe('ManagementListComponent', () => {
expect(component.togggleAll).toBe(false)
})
it('selectAll should use all IDs when collection size exists', () => {
;(component as any).allIDs = [1, 2, 3, 4]
component.collectionSize = 4
it('selectAll should activate all-selection mode', () => {
;(tagService.listFiltered as jest.Mock).mockClear()
component.collectionSize = tags.length
component.selectAll()
expect(component.selectedObjects).toEqual(new Set([1, 2, 3, 4]))
expect(tagService.listFiltered).not.toHaveBeenCalled()
expect(component.selectedObjects).toEqual(new Set(tags.map((t) => t.id)))
expect((component as any).allSelectionActive).toBe(true)
expect(component.hasSelection).toBe(true)
expect(component.selectedCount).toBe(tags.length)
expect(component.togggleAll).toBe(true)
})
@@ -395,6 +398,33 @@ describe('ManagementListComponent', () => {
expect(successToastSpy).toHaveBeenCalled()
})
it('should support bulk edit permissions for all filtered items', () => {
const bulkEditPermsSpy = jest
.spyOn(tagService, 'bulk_edit_objects')
.mockReturnValue(of('OK'))
component.selectAll()
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
fixture.detectChanges()
component.setPermissions()
expect(modal).not.toBeUndefined()
modal.componentInstance.confirmClicked.emit({
permissions: {},
merge: true,
})
expect(bulkEditPermsSpy).toHaveBeenCalledWith(
[],
BulkEditObjectOperation.SetPermissions,
{},
true,
true,
{ is_root: true }
)
})
it('should support bulk delete objects', () => {
const bulkEditSpy = jest.spyOn(tagService, 'bulk_edit_objects')
component.toggleSelected(tags[0])
@@ -415,7 +445,11 @@ describe('ManagementListComponent', () => {
modal.componentInstance.confirmClicked.emit(null)
expect(bulkEditSpy).toHaveBeenCalledWith(
Array.from(selected),
BulkEditObjectOperation.Delete
BulkEditObjectOperation.Delete,
null,
null,
false,
null
)
expect(errorToastSpy).toHaveBeenCalled()
@@ -426,6 +460,29 @@ describe('ManagementListComponent', () => {
expect(successToastSpy).toHaveBeenCalled()
})
it('should support bulk delete for all filtered items', () => {
const bulkEditSpy = jest
.spyOn(tagService, 'bulk_edit_objects')
.mockReturnValue(of('OK'))
component.selectAll()
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
fixture.detectChanges()
component.delete()
expect(modal).not.toBeUndefined()
modal.componentInstance.confirmClicked.emit(null)
expect(bulkEditSpy).toHaveBeenCalledWith(
[],
BulkEditObjectOperation.Delete,
null,
null,
true,
{ is_root: true }
)
})
it('should disallow bulk permissions or delete objects if no global perms', () => {
jest.spyOn(permissionsService, 'currentUserCan').mockReturnValue(false)
expect(component.userCanBulkEdit(PermissionAction.Delete)).toBeFalsy()

View File

@@ -90,7 +90,8 @@ export abstract class ManagementListComponent<T extends MatchingModel>
public data: T[] = []
private unfilteredData: T[] = []
private allIDs: number[] = []
private currentExtraParams: { [key: string]: any } = null
private allSelectionActive = false
public page = 1
@@ -107,6 +108,16 @@ export abstract class ManagementListComponent<T extends MatchingModel>
public selectedObjects: Set<number> = new Set()
public togggleAll: boolean = false
public get hasSelection(): boolean {
return this.selectedObjects.size > 0 || this.allSelectionActive
}
public get selectedCount(): number {
return this.allSelectionActive
? this.displayCollectionSize
: this.selectedObjects.size
}
ngOnInit(): void {
this.reloadData()
@@ -150,11 +161,11 @@ export abstract class ManagementListComponent<T extends MatchingModel>
}
protected getCollectionSize(results: Results<T>): number {
return results.all?.length ?? results.count
return results.count
}
protected getDisplayCollectionSize(results: Results<T>): number {
return this.getCollectionSize(results)
return results.display_count ?? this.getCollectionSize(results)
}
getDocumentCount(object: MatchingModel): number {
@@ -171,6 +182,7 @@ export abstract class ManagementListComponent<T extends MatchingModel>
reloadData(extraParams: { [key: string]: any } = null) {
this.loading = true
this.currentExtraParams = extraParams
this.clearSelection()
this.service
.listFiltered(
@@ -189,7 +201,6 @@ export abstract class ManagementListComponent<T extends MatchingModel>
this.data = this.filterData(c.results)
this.collectionSize = this.getCollectionSize(c)
this.displayCollectionSize = this.getDisplayCollectionSize(c)
this.allIDs = c.all
}),
delay(100)
)
@@ -346,7 +357,16 @@ export abstract class ManagementListComponent<T extends MatchingModel>
return objects.map((o) => o.id)
}
private getBulkEditFilters(): { [key: string]: any } {
const filters = { ...this.currentExtraParams }
if (this._nameFilter?.length) {
filters['name__icontains'] = this._nameFilter
}
return filters
}
clearSelection() {
this.allSelectionActive = false
this.togggleAll = false
this.selectedObjects.clear()
}
@@ -356,6 +376,7 @@ export abstract class ManagementListComponent<T extends MatchingModel>
}
selectPage() {
this.allSelectionActive = false
this.selectedObjects = new Set(this.getSelectableIDs(this.data))
this.togggleAll = this.areAllPageItemsSelected()
}
@@ -365,11 +386,16 @@ export abstract class ManagementListComponent<T extends MatchingModel>
this.clearSelection()
return
}
this.selectedObjects = new Set(this.allIDs)
this.allSelectionActive = true
this.selectedObjects = new Set(this.getSelectableIDs(this.data))
this.togggleAll = this.areAllPageItemsSelected()
}
toggleSelected(object) {
if (this.allSelectionActive) {
this.allSelectionActive = false
}
this.selectedObjects.has(object.id)
? this.selectedObjects.delete(object.id)
: this.selectedObjects.add(object.id)
@@ -377,6 +403,9 @@ export abstract class ManagementListComponent<T extends MatchingModel>
}
protected areAllPageItemsSelected(): boolean {
if (this.allSelectionActive) {
return this.data.length > 0
}
const ids = this.getSelectableIDs(this.data)
return ids.length > 0 && ids.every((id) => this.selectedObjects.has(id))
}
@@ -390,10 +419,12 @@ export abstract class ManagementListComponent<T extends MatchingModel>
modal.componentInstance.buttonsEnabled = false
this.service
.bulk_edit_objects(
Array.from(this.selectedObjects),
this.allSelectionActive ? [] : Array.from(this.selectedObjects),
BulkEditObjectOperation.SetPermissions,
permissions,
merge
merge,
this.allSelectionActive,
this.allSelectionActive ? this.getBulkEditFilters() : null
)
.subscribe({
next: () => {
@@ -428,8 +459,12 @@ export abstract class ManagementListComponent<T extends MatchingModel>
modal.componentInstance.buttonsEnabled = false
this.service
.bulk_edit_objects(
Array.from(this.selectedObjects),
BulkEditObjectOperation.Delete
this.allSelectionActive ? [] : Array.from(this.selectedObjects),
BulkEditObjectOperation.Delete,
null,
null,
this.allSelectionActive,
this.allSelectionActive ? this.getBulkEditFilters() : null
)
.subscribe({
next: () => {

View File

@@ -41,7 +41,6 @@ describe('TagListComponent', () => {
listFilteredSpy = jest.spyOn(tagService, 'listFiltered').mockReturnValue(
of({
count: 3,
all: [1, 2, 3],
results: [
{
id: 1,

View File

@@ -9,7 +9,6 @@ import {
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { TagEditDialogComponent } from 'src/app/components/common/edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
import { FILTER_HAS_TAGS_ALL } from 'src/app/data/filter-rule-type'
import { Results } from 'src/app/data/results'
import { Tag } from 'src/app/data/tag'
import { IfPermissionsDirective } from 'src/app/directives/if-permissions.directive'
import { SortableDirective } from 'src/app/directives/sortable.directive'
@@ -77,16 +76,6 @@ export class TagListComponent extends ManagementListComponent<Tag> {
return data.filter((tag) => !tag.parent || !availableIds.has(tag.parent))
}
protected override getCollectionSize(results: Results<Tag>): number {
// Tag list pages are requested with is_root=true (when unfiltered), so
// pagination must follow root count even though `all` includes descendants
return results.count
}
protected override getDisplayCollectionSize(results: Results<Tag>): number {
return super.getCollectionSize(results)
}
protected override getSelectableIDs(tags: Tag[]): number[] {
const ids: number[] = []
for (const tag of tags.filter(Boolean)) {

View File

@@ -3,9 +3,9 @@ import { Document } from './document'
export interface Results<T> {
count: number
results: T[]
display_count?: number
all: number[]
results: T[]
}
export interface SelectionDataItem {

View File

@@ -96,6 +96,30 @@ export const commonAbstractNameFilterPaperlessServiceTests = (
})
req.flush([])
})
test('should call appropriate api endpoint for bulk delete on all filtered objects', () => {
subscription = service
.bulk_edit_objects(
[],
BulkEditObjectOperation.Delete,
null,
null,
true,
{ name__icontains: 'hello' }
)
.subscribe()
const req = httpTestingController.expectOne(
`${environment.apiBaseUrl}bulk_edit_objects/`
)
expect(req.request.method).toEqual('POST')
expect(req.request.body).toEqual({
object_type: endpoint,
operation: BulkEditObjectOperation.Delete,
all: true,
filters: { name__icontains: 'hello' },
})
req.flush([])
})
})
beforeEach(() => {

View File

@@ -37,13 +37,22 @@ export abstract class AbstractNameFilterService<
objects: Array<number>,
operation: BulkEditObjectOperation,
permissions: { owner: number; set_permissions: PermissionsObject } = null,
merge: boolean = null
merge: boolean = null,
all: boolean = false,
filters: { [key: string]: any } = null
): Observable<string> {
const params = {
objects,
const params: any = {
object_type: this.resourceName,
operation,
}
if (all) {
params['all'] = true
if (filters) {
params['filters'] = filters
}
} else {
params['objects'] = objects
}
if (operation === BulkEditObjectOperation.SetPermissions) {
params['owner'] = permissions?.owner
params['permissions'] = permissions?.set_permissions

View File

@@ -375,6 +375,26 @@ class DelayedQuery:
]
return self._manual_hits_cache
def get_result_ids(self) -> list[int]:
"""
Return all matching document IDs for the current query and ordering.
"""
if self._manual_sort_requested():
return [hit["id"] for hit in self._manual_hits()]
q, mask, suggested_correction = self._get_query()
self.suggested_correction = suggested_correction
sortedby, reverse = self._get_query_sortedby()
results = self.searcher.search(
q,
mask=mask,
filter=MappedDocIdSet(self.filter_queryset, self.searcher.ixreader),
limit=None,
sortedby=sortedby,
reverse=reverse,
)
return [hit["id"] for hit in results]
def __getitem__(self, item):
if item.start in self.saved_results:
return self.saved_results[item.start]

View File

@@ -1540,6 +1540,41 @@ class DocumentListSerializer(serializers.Serializer):
return documents
class DocumentSelectionSerializer(DocumentListSerializer):
documents = serializers.ListField(
required=False,
label="Documents",
write_only=True,
child=serializers.IntegerField(),
)
all = serializers.BooleanField(
default=False,
required=False,
write_only=True,
)
filters = serializers.DictField(
required=False,
allow_empty=True,
write_only=True,
)
def validate(self, attrs):
if attrs.get("all", False):
attrs.setdefault("documents", [])
return attrs
if "documents" not in attrs:
raise serializers.ValidationError(
"documents is required unless all is true.",
)
documents = attrs["documents"]
self._validate_document_id_list(documents)
return attrs
class SourceModeValidationMixin:
def validate_source_mode(self, source_mode: str) -> str:
if source_mode not in bulk_edit.SourceModeChoices.__dict__.values():
@@ -1547,7 +1582,7 @@ class SourceModeValidationMixin:
return source_mode
class RotateDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin):
class RotateDocumentsSerializer(DocumentSelectionSerializer, SourceModeValidationMixin):
degrees = serializers.IntegerField(required=True)
source_mode = serializers.CharField(
required=False,
@@ -1630,17 +1665,17 @@ class RemovePasswordDocumentsSerializer(
)
class DeleteDocumentsSerializer(DocumentListSerializer):
class DeleteDocumentsSerializer(DocumentSelectionSerializer):
pass
class ReprocessDocumentsSerializer(DocumentListSerializer):
class ReprocessDocumentsSerializer(DocumentSelectionSerializer):
pass
class BulkEditSerializer(
SerializerWithPerms,
DocumentListSerializer,
DocumentSelectionSerializer,
SetPermissionsMixin,
SourceModeValidationMixin,
):
@@ -1955,6 +1990,19 @@ class BulkEditSerializer(
raise serializers.ValidationError("password must be a string")
def validate(self, attrs):
attrs = super().validate(attrs)
if attrs.get("all", False) and attrs["method"] in [
bulk_edit.merge,
bulk_edit.split,
bulk_edit.delete_pages,
bulk_edit.edit_pdf,
bulk_edit.remove_password,
]:
raise serializers.ValidationError(
"This method does not support all=true.",
)
method = attrs["method"]
parameters = attrs["parameters"]
@@ -2212,7 +2260,7 @@ class DocumentVersionLabelSerializer(serializers.Serializer):
return normalized or None
class BulkDownloadSerializer(DocumentListSerializer):
class BulkDownloadSerializer(DocumentSelectionSerializer):
content = serializers.ChoiceField(
choices=["archive", "originals", "both"],
default="archive",
@@ -2571,13 +2619,25 @@ class ShareLinkBundleSerializer(OwnedObjectSerializer):
class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
objects = serializers.ListField(
required=True,
allow_empty=False,
required=False,
allow_empty=True,
label="Objects",
write_only=True,
child=serializers.IntegerField(),
)
all = serializers.BooleanField(
default=False,
required=False,
write_only=True,
)
filters = serializers.DictField(
required=False,
allow_empty=True,
write_only=True,
)
object_type = serializers.ChoiceField(
choices=[
"tags",
@@ -2650,10 +2710,20 @@ class BulkEditObjectsSerializer(SerializerWithPerms, SetPermissionsMixin):
def validate(self, attrs):
object_type = attrs["object_type"]
objects = attrs["objects"]
objects = attrs.get("objects")
apply_to_all = attrs.get("all", False)
operation = attrs.get("operation")
self._validate_objects(objects, object_type)
if apply_to_all:
attrs.setdefault("objects", [])
else:
if objects is None:
raise serializers.ValidationError(
"objects is required unless all is true.",
)
if len(objects) == 0:
raise serializers.ValidationError("objects must not be empty")
self._validate_objects(objects, object_type)
if operation == "set_permissions":
permissions = attrs.get("permissions")

View File

@@ -1119,21 +1119,19 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
[u1_doc1.id],
)
def test_pagination_all(self) -> None:
def test_pagination_results(self) -> None:
"""
GIVEN:
- A set of 50 documents
WHEN:
- API request for document filtering
THEN:
- Results are paginated (25 items) and response["all"] returns all ids (50 items)
- Results are paginated (25 items) and count reflects all results (50 items)
"""
t = Tag.objects.create(name="tag")
docs = []
for i in range(50):
d = Document.objects.create(checksum=i, content=f"test{i}")
d.tags.add(t)
docs.append(d)
response = self.client.get(
f"/api/documents/?tags__id__in={t.id}",
@@ -1141,7 +1139,32 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.data["results"]
self.assertEqual(len(results), 25)
self.assertEqual(len(response.data["all"]), 50)
self.assertEqual(response.data["count"], 50)
self.assertNotIn("all", response.data)
def test_pagination_all_for_api_version_9(self) -> None:
"""
GIVEN:
- A set of documents matching a filter
WHEN:
- API request uses legacy version 9
THEN:
- Response includes "all" for backward compatibility
"""
t = Tag.objects.create(name="tag")
docs = []
for i in range(4):
d = Document.objects.create(checksum=i, content=f"test{i}")
d.tags.add(t)
docs.append(d)
response = self.client.get(
f"/api/documents/?tags__id__in={t.id}",
headers={"Accept": "application/json; version=9"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("all", response.data)
self.assertCountEqual(response.data["all"], [d.id for d in docs])
def test_list_with_include_selection_data(self) -> None:

View File

@@ -145,6 +145,22 @@ class TestApiObjects(DirectoriesMixin, APITestCase):
response.data["last_correspondence"],
)
def test_paginated_objects_include_all_only_for_legacy_version(self) -> None:
response_v10 = self.client.get("/api/correspondents/")
self.assertEqual(response_v10.status_code, status.HTTP_200_OK)
self.assertNotIn("all", response_v10.data)
response_v9 = self.client.get(
"/api/correspondents/",
headers={"Accept": "application/json; version=9"},
)
self.assertEqual(response_v9.status_code, status.HTTP_200_OK)
self.assertIn("all", response_v9.data)
self.assertCountEqual(
response_v9.data["all"],
[self.c1.id, self.c2.id, self.c3.id],
)
class TestApiStoragePaths(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/storage_paths/"
@@ -774,6 +790,62 @@ class TestBulkEditObjects(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(StoragePath.objects.count(), 0)
def test_bulk_objects_delete_all_filtered(self) -> None:
"""
GIVEN:
- Existing objects that can be filtered by name
WHEN:
- bulk_edit_objects API endpoint is called with all=true and filters
THEN:
- Matching objects are deleted without passing explicit IDs
"""
Correspondent.objects.create(name="c2")
response = self.client.post(
"/api/bulk_edit_objects/",
json.dumps(
{
"all": True,
"filters": {"name__icontains": "c"},
"object_type": "correspondents",
"operation": "delete",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(Correspondent.objects.count(), 0)
def test_bulk_objects_delete_all_filtered_tags_includes_descendants(self) -> None:
"""
GIVEN:
- Root tag with descendants
WHEN:
- bulk_edit_objects API endpoint is called with all=true
THEN:
- Root tags and descendants are deleted
"""
parent = Tag.objects.create(name="parent")
child = Tag.objects.create(name="child", tn_parent=parent)
response = self.client.post(
"/api/bulk_edit_objects/",
json.dumps(
{
"all": True,
"filters": {"is_root": True},
"object_type": "tags",
"operation": "delete",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(Tag.objects.filter(id=parent.id).exists())
self.assertFalse(Tag.objects.filter(id=child.id).exists())
def test_bulk_edit_object_permissions_insufficient_global_perms(self) -> None:
"""
GIVEN:
@@ -861,3 +933,40 @@ class TestBulkEditObjects(APITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, b"Insufficient permissions")
def test_bulk_edit_all_filtered_permissions_insufficient_object_perms(
self,
) -> None:
"""
GIVEN:
- Filter-matching objects include one that the user cannot edit
WHEN:
- bulk_edit_objects API endpoint is called with all=true
THEN:
- Operation applies only to editable objects
"""
self.t2.owner = User.objects.get(username="temp_admin")
self.t2.save()
self.user1.user_permissions.add(
*Permission.objects.filter(codename="delete_tag"),
)
self.user1.save()
self.client.force_authenticate(user=self.user1)
response = self.client.post(
"/api/bulk_edit_objects/",
json.dumps(
{
"all": True,
"filters": {"name__icontains": "t"},
"object_type": "tags",
"operation": "delete",
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(Tag.objects.filter(id=self.t2.id).exists())
self.assertFalse(Tag.objects.filter(id=self.t1.id).exists())

View File

@@ -68,26 +68,48 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
results = response.data["results"]
self.assertEqual(response.data["count"], 3)
self.assertEqual(len(results), 3)
self.assertCountEqual(response.data["all"], [d1.id, d2.id, d3.id])
response = self.client.get("/api/documents/?query=september")
results = response.data["results"]
self.assertEqual(response.data["count"], 1)
self.assertEqual(len(results), 1)
self.assertCountEqual(response.data["all"], [d3.id])
self.assertEqual(results[0]["original_file_name"], "someepdf.pdf")
response = self.client.get("/api/documents/?query=statement")
results = response.data["results"]
self.assertEqual(response.data["count"], 2)
self.assertEqual(len(results), 2)
self.assertCountEqual(response.data["all"], [d2.id, d3.id])
response = self.client.get("/api/documents/?query=sfegdfg")
results = response.data["results"]
self.assertEqual(response.data["count"], 0)
self.assertEqual(len(results), 0)
self.assertCountEqual(response.data["all"], [])
def test_search_returns_all_for_api_version_9(self) -> None:
d1 = Document.objects.create(
title="invoice",
content="bank payment",
checksum="A",
pk=1,
)
d2 = Document.objects.create(
title="bank statement",
content="bank transfer",
checksum="B",
pk=2,
)
with AsyncWriter(index.open_index()) as writer:
index.update_document(writer, d1)
index.update_document(writer, d2)
response = self.client.get(
"/api/documents/?query=bank",
headers={"Accept": "application/json; version=9"},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("all", response.data)
self.assertCountEqual(response.data["all"], [d1.id, d2.id])
def test_search_with_include_selection_data(self) -> None:
correspondent = Correspondent.objects.create(name="c1")

View File

@@ -555,7 +555,9 @@ class TagViewSet(PermissionsAwareDocumentCountMixin, ModelViewSet):
page = self.paginate_queryset(queryset)
serializer = self.get_serializer(page, many=True)
response = self.get_paginated_response(serializer.data)
if descendant_pks:
response.data["display_count"] = len(children_source)
api_version = int(request.version or settings.REST_FRAMEWORK["DEFAULT_VERSION"])
if descendant_pks and api_version < 10:
# Include children in the "all" field, if needed
response.data["all"] = [tag.pk for tag in children_source]
return response
@@ -2084,7 +2086,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
),
),
):
result_ids = response.data.get("all", [])
result_ids = queryset.get_result_ids()
response.data["selection_data"] = (
self._get_selection_data_for_queryset(
Document.objects.filter(pk__in=result_ids),
@@ -2213,7 +2215,36 @@ class SavedViewViewSet(BulkPermissionMixin, PassUserMixin, ModelViewSet):
ordering_fields = ("name",)
class DocumentOperationPermissionMixin(PassUserMixin):
class DocumentSelectionMixin:
def _resolve_document_ids(
self,
*,
user: User,
validated_data: dict[str, Any],
permission_codename: str = "view_document",
) -> list[int]:
if not validated_data.get("all", False):
# if all is not true, just pass through the provided document ids
return validated_data["documents"]
# otherwise, reconstruct the document list based on the provided filters
filters = validated_data.get("filters") or {}
permitted_documents = get_objects_for_user_owner_aware(
user,
permission_codename,
Document,
)
return list(
DocumentFilterSet(
data=filters,
queryset=permitted_documents,
)
.qs.distinct()
.values_list("pk", flat=True),
)
class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin):
permission_classes = (IsAuthenticated,)
parser_classes = (parsers.JSONParser,)
METHOD_NAMES_REQUIRING_USER = {
@@ -2307,8 +2338,15 @@ class DocumentOperationPermissionMixin(PassUserMixin):
validated_data: dict[str, Any],
operation_label: str,
):
documents = validated_data["documents"]
parameters = {k: v for k, v in validated_data.items() if k != "documents"}
documents = self._resolve_document_ids(
user=self.request.user,
validated_data=validated_data,
)
parameters = {
k: v
for k, v in validated_data.items()
if k not in {"documents", "all", "filters"}
}
user = self.request.user
if method.__name__ in self.METHOD_NAMES_REQUIRING_USER:
@@ -2396,7 +2434,10 @@ class BulkEditView(DocumentOperationPermissionMixin):
user = self.request.user
method = serializer.validated_data.get("method")
parameters = serializer.validated_data.get("parameters")
documents = serializer.validated_data.get("documents")
documents = self._resolve_document_ids(
user=user,
validated_data=serializer.validated_data,
)
if method.__name__ in self.METHOD_NAMES_REQUIRING_USER:
parameters["user"] = user
if not self._has_document_permissions(
@@ -3240,7 +3281,7 @@ class StatisticsView(GenericAPIView):
)
class BulkDownloadView(GenericAPIView):
class BulkDownloadView(DocumentSelectionMixin, GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = BulkDownloadSerializer
parser_classes = (parsers.JSONParser,)
@@ -3249,7 +3290,10 @@ class BulkDownloadView(GenericAPIView):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
ids = serializer.validated_data.get("documents")
ids = self._resolve_document_ids(
user=request.user,
validated_data=serializer.validated_data,
)
documents = Document.objects.filter(pk__in=ids)
compression = serializer.validated_data.get("compression")
content = serializer.validated_data.get("content")
@@ -3879,20 +3923,55 @@ class BulkEditObjectsView(PassUserMixin):
user = self.request.user
object_type = serializer.validated_data.get("object_type")
object_ids = serializer.validated_data.get("objects")
apply_to_all = serializer.validated_data.get("all")
object_class = serializer.get_object_class(object_type)
operation = serializer.validated_data.get("operation")
model_name = object_class._meta.model_name
perm_codename = (
f"change_{model_name}"
if operation == "set_permissions"
else f"delete_{model_name}"
)
objs = object_class.objects.select_related("owner").filter(pk__in=object_ids)
if apply_to_all:
# Support all to avoid sending large lists of ids for bulk operations, with optional filters
filters = serializer.validated_data.get("filters") or {}
filterset_class = {
"tags": TagFilterSet,
"correspondents": CorrespondentFilterSet,
"document_types": DocumentTypeFilterSet,
"storage_paths": StoragePathFilterSet,
}[object_type]
user_permitted_objects = get_objects_for_user_owner_aware(
user,
perm_codename,
object_class,
)
objs = filterset_class(
data=filters,
queryset=user_permitted_objects,
).qs
if object_type == "tags":
editable_ids = set(user_permitted_objects.values_list("pk", flat=True))
all_ids = set(objs.values_list("pk", flat=True))
for tag in objs:
all_ids.update(
descendant.pk
for descendant in tag.get_descendants()
if descendant.pk in editable_ids
)
objs = object_class.objects.filter(pk__in=all_ids)
objs = objs.select_related("owner")
object_ids = list(objs.values_list("pk", flat=True))
else:
objs = object_class.objects.select_related("owner").filter(
pk__in=object_ids,
)
if not user.is_superuser:
model_name = object_class._meta.model_name
perm = (
f"documents.change_{model_name}"
if operation == "set_permissions"
else f"documents.delete_{model_name}"
)
perm = f"documents.{perm_codename}"
has_perms = user.has_perm(perm) and all(
(obj.owner == user or obj.owner is None) for obj in objs
has_perms_owner_aware(user, perm_codename, obj) for obj in objs
)
if not has_perms:

View File

@@ -9,6 +9,7 @@ from allauth.mfa.recovery_codes.internal.flows import auto_generate_recovery_cod
from allauth.mfa.totp.internal import auth as totp_auth
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialAccount
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.contrib.staticfiles.storage import staticfiles_storage
@@ -56,17 +57,27 @@ class StandardPagination(PageNumberPagination):
page_size_query_param = "page_size"
max_page_size = 100000
def _get_api_version(self) -> int:
request = getattr(self, "request", None)
default_version = settings.REST_FRAMEWORK["DEFAULT_VERSION"]
return int(request.version if request else default_version)
def _should_include_all(self) -> bool:
# TODO: remove legacy `all` support when API v9 is dropped.
return self._get_api_version() < 10
def get_paginated_response(self, data):
response_data = [
("count", self.page.paginator.count),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
]
if self._should_include_all():
response_data.append(("all", self.get_all_result_ids()))
response_data.append(("results", data))
return Response(
OrderedDict(
[
("count", self.page.paginator.count),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("all", self.get_all_result_ids()),
("results", data),
],
),
OrderedDict(response_data),
)
def get_all_result_ids(self):
@@ -87,11 +98,14 @@ class StandardPagination(PageNumberPagination):
def get_paginated_response_schema(self, schema):
response_schema = super().get_paginated_response_schema(schema)
response_schema["properties"]["all"] = {
"type": "array",
"example": "[1, 2, 3]",
"items": {"type": "integer"},
}
if self._should_include_all():
response_schema["properties"]["all"] = {
"type": "array",
"example": "[1, 2, 3]",
"items": {"type": "integer"},
}
else:
response_schema["properties"].pop("all", None)
return response_schema