feat(tasks): extend and harden the task system redesign

- TaskType: add EMPTY_TRASH, CHECK_WORKFLOWS, CLEANUP_SHARE_LINKS;
  remove INDEX_REBUILD (no backing task — beat schedule uses index_optimize)
- TRACKED_TASKS: wire up all nine task types including the three new ones
  and llmindex_index / process_mail_accounts
- Add task_revoked_handler so cancelled/expired tasks are marked REVOKED
- Fix double-write: task_postrun_handler no longer overwrites result_data
  when status is already FAILURE (task_failure_handler owns that write)
- v9 serialiser: map EMAIL_CONSUME and FOLDER_CONSUME to AUTO_TASK
- views: scope task list to owner for regular users, admins see all;
  validate ?days= query param and return 400 on bad input
- tests: add test_list_admin_sees_all_tasks; rename/fix
  test_parses_duplicate_string (duplicates produce SUCCESS, not FAILURE);
  use PaperlessTaskFactory in modified tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-04-15 13:59:02 -07:00
parent 979d8a67f0
commit e86c37f717
7 changed files with 106 additions and 31 deletions
@@ -60,9 +60,11 @@ class Migration(migrations.Migration):
("train_classifier", "Train Classifier"),
("sanity_check", "Sanity Check"),
("index_optimize", "Index Optimize"),
("index_rebuild", "Index Rebuild"),
("mail_fetch", "Mail Fetch"),
("llm_index", "LLM Index"),
("empty_trash", "Empty Trash"),
("check_workflows", "Check Workflows"),
("cleanup_share_links", "Cleanup Share Links"),
],
db_index=True,
help_text="The kind of work being performed",
+3 -1
View File
@@ -683,9 +683,11 @@ class PaperlessTask(ModelWithOwner):
TRAIN_CLASSIFIER = "train_classifier", _("Train Classifier")
SANITY_CHECK = "sanity_check", _("Sanity Check")
INDEX_OPTIMIZE = "index_optimize", _("Index Optimize")
INDEX_REBUILD = "index_rebuild", _("Index Rebuild")
MAIL_FETCH = "mail_fetch", _("Mail Fetch")
LLM_INDEX = "llm_index", _("LLM Index")
EMPTY_TRASH = "empty_trash", _("Empty Trash")
CHECK_WORKFLOWS = "check_workflows", _("Check Workflows")
CLEANUP_SHARE_LINKS = "cleanup_share_links", _("Cleanup Share Links")
COMPLETE_STATUSES = (
Status.SUCCESS,
+3
View File
@@ -2532,6 +2532,9 @@ class TaskSerializerV9(serializers.ModelSerializer):
_TRIGGER_SOURCE_TO_V9_TYPE = {
PaperlessTask.TriggerSource.SCHEDULED: "SCHEDULED_TASK",
PaperlessTask.TriggerSource.SYSTEM: "AUTO_TASK",
# Email and folder-consumer documents are system-initiated, not manually triggered
PaperlessTask.TriggerSource.EMAIL_CONSUME: "AUTO_TASK",
PaperlessTask.TriggerSource.FOLDER_CONSUME: "AUTO_TASK",
}
def get_type(self, obj: PaperlessTask) -> str:
+59 -16
View File
@@ -14,6 +14,7 @@ from celery.signals import before_task_publish
from celery.signals import task_failure
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import worker_process_init
from django.conf import settings
from django.contrib.auth.models import Group
@@ -1008,6 +1009,9 @@ TRACKED_TASKS: dict[str, PaperlessTask.TaskType] = {
"documents.tasks.sanity_check": PaperlessTask.TaskType.SANITY_CHECK,
"documents.tasks.index_optimize": PaperlessTask.TaskType.INDEX_OPTIMIZE,
"documents.tasks.llmindex_index": PaperlessTask.TaskType.LLM_INDEX,
"documents.tasks.empty_trash": PaperlessTask.TaskType.EMPTY_TRASH,
"documents.tasks.check_scheduled_workflows": PaperlessTask.TaskType.CHECK_WORKFLOWS,
"documents.tasks.cleanup_expired_share_link_bundles": PaperlessTask.TaskType.CLEANUP_SHARE_LINKS,
"paperless_mail.tasks.process_mail_accounts": PaperlessTask.TaskType.MAIL_FETCH,
}
@@ -1218,6 +1222,11 @@ def task_postrun_handler(
"""
Records task completion and result data.
task_failure also fires when a task raises an exception, and it writes
richer structured error data. To avoid a race where this handler
overwrites that data, result_data and result_message are left untouched
when the final state is FAILURE — task_failure_handler owns those fields.
https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun
"""
if task_id is None:
@@ -1227,14 +1236,6 @@ def task_postrun_handler(
new_status = _CELERY_STATE_TO_STATUS.get(state, PaperlessTask.Status.FAILURE)
result_data: dict | None = None
result_message: str | None = None
if isinstance(retval, dict):
result_data = retval
elif isinstance(retval, str):
result_message = retval
result_data = _parse_consume_result(retval)
now = timezone.now()
task_instance = PaperlessTask.objects.filter(task_id=task_id).first()
if task_instance is None:
@@ -1249,14 +1250,23 @@ def task_postrun_handler(
task_instance.date_started - task_instance.date_created
).total_seconds()
PaperlessTask.objects.filter(task_id=task_id).update(
status=new_status,
result_data=result_data,
result_message=result_message,
date_done=now,
duration_seconds=duration_seconds,
wait_time_seconds=wait_time_seconds,
)
update_fields: dict = {
"status": new_status,
"date_done": now,
"duration_seconds": duration_seconds,
"wait_time_seconds": wait_time_seconds,
}
# Only write result data for non-failure outcomes; task_failure_handler
# owns result_data/result_message for FAILURE states.
if new_status != PaperlessTask.Status.FAILURE:
if isinstance(retval, dict):
update_fields["result_data"] = retval
elif isinstance(retval, str):
update_fields["result_message"] = retval
update_fields["result_data"] = _parse_consume_result(retval)
PaperlessTask.objects.filter(task_id=task_id).update(**update_fields)
except Exception:
logger.exception("Updating PaperlessTask failed")
@@ -1298,6 +1308,39 @@ def task_failure_handler(
logger.exception("Updating PaperlessTask on failure failed")
@task_revoked.connect
def task_revoked_handler(
sender=None,
request=None,
*,
terminated: bool = False,
signum=None,
expired: bool = False,
**kwargs,
) -> None:
"""
Marks the task REVOKED when it is cancelled before or during execution.
This fires for tasks revoked while still queued (before task_prerun) as
well as for tasks terminated mid-run. task_postrun does NOT fire for
pre-start revocations, so this handler is the only way to move those
records out of PENDING.
https://docs.celeryq.dev/en/stable/userguide/signals.html#task-revoked
"""
task_id = request.id if request else None
if task_id is None:
return
try:
close_old_connections()
PaperlessTask.objects.filter(task_id=task_id).update(
status=PaperlessTask.Status.REVOKED,
date_done=timezone.now(),
)
except Exception:
logger.exception("Updating PaperlessTask on revocation failed")
@worker_process_init.connect
def close_connection_pool_on_worker_init(**kwargs) -> None:
"""
+22 -8
View File
@@ -163,12 +163,12 @@ class TestGetTasksV10:
ids = [t["task_id"] for t in response.data]
assert ids == [t3.task_id, t2.task_id, t1.task_id]
def test_list_is_owner_aware(
def test_list_scoped_to_own_tasks_for_regular_user(
self,
admin_user: User,
regular_user: User,
) -> None:
"""The task list only shows tasks the user owns or that are unowned."""
"""Regular users see only tasks they own; tasks owned by others or unowned are hidden."""
regular_user.user_permissions.add(
Permission.objects.get(codename="view_paperlesstask"),
)
@@ -177,17 +177,31 @@ class TestGetTasksV10:
client.force_authenticate(user=regular_user)
client.credentials(HTTP_ACCEPT=ACCEPT_V10)
PaperlessTaskFactory(owner=admin_user)
shared_task = PaperlessTaskFactory()
PaperlessTaskFactory(owner=admin_user) # other user — not visible
PaperlessTaskFactory() # unowned (system task) — not visible
own_task = PaperlessTaskFactory(owner=regular_user)
response = client.get(ENDPOINT)
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 2
returned_task_ids = {t["task_id"] for t in response.data}
assert shared_task.task_id in returned_task_ids
assert own_task.task_id in returned_task_ids
assert len(response.data) == 1
assert response.data[0]["task_id"] == own_task.task_id
def test_list_admin_sees_all_tasks(
self,
admin_client: APIClient,
admin_user: User,
regular_user: User,
) -> None:
"""Admin users see all tasks regardless of owner."""
PaperlessTaskFactory(owner=admin_user)
PaperlessTaskFactory() # unowned system task
PaperlessTaskFactory(owner=regular_user)
response = admin_client.get(ENDPOINT)
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 3
@pytest.mark.django_db()
+3 -2
View File
@@ -212,14 +212,15 @@ class TestTaskPostrunHandler:
assert task.result_data["document_id"] == 42
assert task.result_message == "New document id 42 created"
def test_parses_legacy_duplicate_string(self):
def test_parses_duplicate_string(self):
"""Duplicate detection returns a string with SUCCESS state (StopConsumeTaskError is caught and returned, not raised)."""
task = self._started_task()
from documents.signals.handlers import task_postrun_handler
task_postrun_handler(
task_id=task.task_id,
retval="It is a duplicate of some document (#99).",
state="FAILURE",
state="SUCCESS",
)
task.refresh_from_db()
assert task.result_data["duplicate_of"] == 99
+13 -3
View File
@@ -3775,7 +3775,6 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = PaperlessTaskFilterSet
ordering_fields = [
@@ -3809,7 +3808,12 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
return TaskSerializerV10
def get_queryset(self):
queryset = PaperlessTask.objects.all()
# Staff see all tasks; regular users see only tasks they own.
# Unowned tasks (system/scheduled) are admin-only.
if self.request.user.is_staff:
queryset = PaperlessTask.objects.all()
else:
queryset = PaperlessTask.objects.filter(owner=self.request.user)
# v9 backwards compat: map old query params to new field names
if self.request.version and int(self.request.version) < 10:
task_name = self.request.query_params.get("task_name")
@@ -3859,7 +3863,13 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
@action(methods=["get"], detail=False)
def summary(self, request):
"""Aggregated task statistics per task_type over the last N days (default 30)."""
days = int(request.query_params.get("days", 30))
try:
days = max(1, int(request.query_params.get("days", 30)))
except (TypeError, ValueError):
return Response(
{"days": "Must be a positive integer."},
status=status.HTTP_400_BAD_REQUEST,
)
cutoff = timezone.now() - timedelta(days=days)
queryset = self.get_queryset().filter(date_created__gte=cutoff)