From e86c37f717284507e6d35adbf0d45df463495751 Mon Sep 17 00:00:00 2001 From: stumpylog <797416+stumpylog@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:59:02 -0700 Subject: [PATCH] feat(tasks): extend and harden the task system redesign MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TaskType: add EMPTY_TRASH, CHECK_WORKFLOWS, CLEANUP_SHARE_LINKS; remove INDEX_REBUILD (no backing task — beat schedule uses index_optimize) - TRACKED_TASKS: wire up all nine task types including the three new ones and llmindex_index / process_mail_accounts - Add task_revoked_handler so cancelled/expired tasks are marked REVOKED - Fix double-write: task_postrun_handler no longer overwrites result_data when status is already FAILURE (task_failure_handler owns that write) - v9 serialiser: map EMAIL_CONSUME and FOLDER_CONSUME to AUTO_TASK - views: scope task list to owner for regular users, admins see all; validate ?days= query param and return 400 on bad input - tests: add test_list_admin_sees_all_tasks; rename/fix test_parses_duplicate_string (duplicates produce SUCCESS, not FAILURE); use PaperlessTaskFactory in modified tests Co-Authored-By: Claude Sonnet 4.6 --- .../migrations/0019_task_system_redesign.py | 4 +- src/documents/models.py | 4 +- src/documents/serialisers.py | 3 + src/documents/signals/handlers.py | 75 +++++++++++++++---- src/documents/tests/test_api_tasks.py | 30 ++++++-- src/documents/tests/test_task_signals.py | 5 +- src/documents/views.py | 16 +++- 7 files changed, 106 insertions(+), 31 deletions(-) diff --git a/src/documents/migrations/0019_task_system_redesign.py b/src/documents/migrations/0019_task_system_redesign.py index 790c26650..1651a60fc 100644 --- a/src/documents/migrations/0019_task_system_redesign.py +++ b/src/documents/migrations/0019_task_system_redesign.py @@ -60,9 +60,11 @@ class Migration(migrations.Migration): ("train_classifier", "Train Classifier"), ("sanity_check", "Sanity Check"), ("index_optimize", "Index Optimize"), - ("index_rebuild", "Index Rebuild"), ("mail_fetch", "Mail Fetch"), ("llm_index", "LLM Index"), + ("empty_trash", "Empty Trash"), + ("check_workflows", "Check Workflows"), + ("cleanup_share_links", "Cleanup Share Links"), ], db_index=True, help_text="The kind of work being performed", diff --git a/src/documents/models.py b/src/documents/models.py index 8a082c76e..63bac5cff 100644 --- a/src/documents/models.py +++ b/src/documents/models.py @@ -683,9 +683,11 @@ class PaperlessTask(ModelWithOwner): TRAIN_CLASSIFIER = "train_classifier", _("Train Classifier") SANITY_CHECK = "sanity_check", _("Sanity Check") INDEX_OPTIMIZE = "index_optimize", _("Index Optimize") - INDEX_REBUILD = "index_rebuild", _("Index Rebuild") MAIL_FETCH = "mail_fetch", _("Mail Fetch") LLM_INDEX = "llm_index", _("LLM Index") + EMPTY_TRASH = "empty_trash", _("Empty Trash") + CHECK_WORKFLOWS = "check_workflows", _("Check Workflows") + CLEANUP_SHARE_LINKS = "cleanup_share_links", _("Cleanup Share Links") COMPLETE_STATUSES = ( Status.SUCCESS, diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 353fb9700..81a5925b0 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -2532,6 +2532,9 @@ class TaskSerializerV9(serializers.ModelSerializer): _TRIGGER_SOURCE_TO_V9_TYPE = { PaperlessTask.TriggerSource.SCHEDULED: "SCHEDULED_TASK", PaperlessTask.TriggerSource.SYSTEM: "AUTO_TASK", + # Email and folder-consumer documents are system-initiated, not manually triggered + PaperlessTask.TriggerSource.EMAIL_CONSUME: "AUTO_TASK", + PaperlessTask.TriggerSource.FOLDER_CONSUME: "AUTO_TASK", } def get_type(self, obj: PaperlessTask) -> str: diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 9fd79cb45..0a5e96755 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -14,6 +14,7 @@ from celery.signals import before_task_publish from celery.signals import task_failure from celery.signals import task_postrun from celery.signals import task_prerun +from celery.signals import task_revoked from celery.signals import worker_process_init from django.conf import settings from django.contrib.auth.models import Group @@ -1008,6 +1009,9 @@ TRACKED_TASKS: dict[str, PaperlessTask.TaskType] = { "documents.tasks.sanity_check": PaperlessTask.TaskType.SANITY_CHECK, "documents.tasks.index_optimize": PaperlessTask.TaskType.INDEX_OPTIMIZE, "documents.tasks.llmindex_index": PaperlessTask.TaskType.LLM_INDEX, + "documents.tasks.empty_trash": PaperlessTask.TaskType.EMPTY_TRASH, + "documents.tasks.check_scheduled_workflows": PaperlessTask.TaskType.CHECK_WORKFLOWS, + "documents.tasks.cleanup_expired_share_link_bundles": PaperlessTask.TaskType.CLEANUP_SHARE_LINKS, "paperless_mail.tasks.process_mail_accounts": PaperlessTask.TaskType.MAIL_FETCH, } @@ -1218,6 +1222,11 @@ def task_postrun_handler( """ Records task completion and result data. + task_failure also fires when a task raises an exception, and it writes + richer structured error data. To avoid a race where this handler + overwrites that data, result_data and result_message are left untouched + when the final state is FAILURE — task_failure_handler owns those fields. + https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun """ if task_id is None: @@ -1227,14 +1236,6 @@ def task_postrun_handler( new_status = _CELERY_STATE_TO_STATUS.get(state, PaperlessTask.Status.FAILURE) - result_data: dict | None = None - result_message: str | None = None - if isinstance(retval, dict): - result_data = retval - elif isinstance(retval, str): - result_message = retval - result_data = _parse_consume_result(retval) - now = timezone.now() task_instance = PaperlessTask.objects.filter(task_id=task_id).first() if task_instance is None: @@ -1249,14 +1250,23 @@ def task_postrun_handler( task_instance.date_started - task_instance.date_created ).total_seconds() - PaperlessTask.objects.filter(task_id=task_id).update( - status=new_status, - result_data=result_data, - result_message=result_message, - date_done=now, - duration_seconds=duration_seconds, - wait_time_seconds=wait_time_seconds, - ) + update_fields: dict = { + "status": new_status, + "date_done": now, + "duration_seconds": duration_seconds, + "wait_time_seconds": wait_time_seconds, + } + + # Only write result data for non-failure outcomes; task_failure_handler + # owns result_data/result_message for FAILURE states. + if new_status != PaperlessTask.Status.FAILURE: + if isinstance(retval, dict): + update_fields["result_data"] = retval + elif isinstance(retval, str): + update_fields["result_message"] = retval + update_fields["result_data"] = _parse_consume_result(retval) + + PaperlessTask.objects.filter(task_id=task_id).update(**update_fields) except Exception: logger.exception("Updating PaperlessTask failed") @@ -1298,6 +1308,39 @@ def task_failure_handler( logger.exception("Updating PaperlessTask on failure failed") +@task_revoked.connect +def task_revoked_handler( + sender=None, + request=None, + *, + terminated: bool = False, + signum=None, + expired: bool = False, + **kwargs, +) -> None: + """ + Marks the task REVOKED when it is cancelled before or during execution. + + This fires for tasks revoked while still queued (before task_prerun) as + well as for tasks terminated mid-run. task_postrun does NOT fire for + pre-start revocations, so this handler is the only way to move those + records out of PENDING. + + https://docs.celeryq.dev/en/stable/userguide/signals.html#task-revoked + """ + task_id = request.id if request else None + if task_id is None: + return + try: + close_old_connections() + PaperlessTask.objects.filter(task_id=task_id).update( + status=PaperlessTask.Status.REVOKED, + date_done=timezone.now(), + ) + except Exception: + logger.exception("Updating PaperlessTask on revocation failed") + + @worker_process_init.connect def close_connection_pool_on_worker_init(**kwargs) -> None: """ diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index b7fa7d06a..cef09a5c2 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -163,12 +163,12 @@ class TestGetTasksV10: ids = [t["task_id"] for t in response.data] assert ids == [t3.task_id, t2.task_id, t1.task_id] - def test_list_is_owner_aware( + def test_list_scoped_to_own_tasks_for_regular_user( self, admin_user: User, regular_user: User, ) -> None: - """The task list only shows tasks the user owns or that are unowned.""" + """Regular users see only tasks they own; tasks owned by others or unowned are hidden.""" regular_user.user_permissions.add( Permission.objects.get(codename="view_paperlesstask"), ) @@ -177,17 +177,31 @@ class TestGetTasksV10: client.force_authenticate(user=regular_user) client.credentials(HTTP_ACCEPT=ACCEPT_V10) - PaperlessTaskFactory(owner=admin_user) - shared_task = PaperlessTaskFactory() + PaperlessTaskFactory(owner=admin_user) # other user — not visible + PaperlessTaskFactory() # unowned (system task) — not visible own_task = PaperlessTaskFactory(owner=regular_user) response = client.get(ENDPOINT) assert response.status_code == status.HTTP_200_OK - assert len(response.data) == 2 - returned_task_ids = {t["task_id"] for t in response.data} - assert shared_task.task_id in returned_task_ids - assert own_task.task_id in returned_task_ids + assert len(response.data) == 1 + assert response.data[0]["task_id"] == own_task.task_id + + def test_list_admin_sees_all_tasks( + self, + admin_client: APIClient, + admin_user: User, + regular_user: User, + ) -> None: + """Admin users see all tasks regardless of owner.""" + PaperlessTaskFactory(owner=admin_user) + PaperlessTaskFactory() # unowned system task + PaperlessTaskFactory(owner=regular_user) + + response = admin_client.get(ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 3 @pytest.mark.django_db() diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index 4030a96ce..62321dfa6 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -212,14 +212,15 @@ class TestTaskPostrunHandler: assert task.result_data["document_id"] == 42 assert task.result_message == "New document id 42 created" - def test_parses_legacy_duplicate_string(self): + def test_parses_duplicate_string(self): + """Duplicate detection returns a string with SUCCESS state (StopConsumeTaskError is caught and returned, not raised).""" task = self._started_task() from documents.signals.handlers import task_postrun_handler task_postrun_handler( task_id=task.task_id, retval="It is a duplicate of some document (#99).", - state="FAILURE", + state="SUCCESS", ) task.refresh_from_db() assert task.result_data["duplicate_of"] == 99 diff --git a/src/documents/views.py b/src/documents/views.py index 61c95a4dd..fc1b55acf 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -3775,7 +3775,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): filter_backends = ( DjangoFilterBackend, OrderingFilter, - ObjectOwnedOrGrantedPermissionsFilter, ) filterset_class = PaperlessTaskFilterSet ordering_fields = [ @@ -3809,7 +3808,12 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): return TaskSerializerV10 def get_queryset(self): - queryset = PaperlessTask.objects.all() + # Staff see all tasks; regular users see only tasks they own. + # Unowned tasks (system/scheduled) are admin-only. + if self.request.user.is_staff: + queryset = PaperlessTask.objects.all() + else: + queryset = PaperlessTask.objects.filter(owner=self.request.user) # v9 backwards compat: map old query params to new field names if self.request.version and int(self.request.version) < 10: task_name = self.request.query_params.get("task_name") @@ -3859,7 +3863,13 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]): @action(methods=["get"], detail=False) def summary(self, request): """Aggregated task statistics per task_type over the last N days (default 30).""" - days = int(request.query_params.get("days", 30)) + try: + days = max(1, int(request.query_params.get("days", 30))) + except (TypeError, ValueError): + return Response( + {"days": "Must be a positive integer."}, + status=status.HTTP_400_BAD_REQUEST, + ) cutoff = timezone.now() - timedelta(days=days) queryset = self.get_queryset().filter(date_created__gte=cutoff)