From c2f02851da5417817a21b6d237b5a215a7fcec22 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Fri, 3 Apr 2026 14:18:01 -0700 Subject: [PATCH] Chore: Better typed status manager messages (#12509) --- src/documents/consumer.py | 14 ++-- src/documents/plugins/helpers.py | 91 +++++++++++++++++++++----- src/documents/tests/utils.py | 13 ++-- src/paperless/consumers.py | 25 +++++-- src/paperless/tests/test_websockets.py | 10 ++- 5 files changed, 114 insertions(+), 39 deletions(-) diff --git a/src/documents/consumer.py b/src/documents/consumer.py index 6ae5914b7..f68fa0685 100644 --- a/src/documents/consumer.py +++ b/src/documents/consumer.py @@ -139,14 +139,12 @@ class ConsumerPluginMixin: message, current_progress, max_progress, - extra_args={ - "document_id": document_id, - "owner_id": self.metadata.owner_id if self.metadata.owner_id else None, - "users_can_view": (self.metadata.view_users or []) - + (self.metadata.change_users or []), - "groups_can_view": (self.metadata.view_groups or []) - + (self.metadata.change_groups or []), - }, + document_id=document_id, + owner_id=self.metadata.owner_id if self.metadata.owner_id else None, + users_can_view=(self.metadata.view_users or []) + + (self.metadata.change_users or []), + groups_can_view=(self.metadata.view_groups or []) + + (self.metadata.change_groups or []), ) def _fail( diff --git a/src/documents/plugins/helpers.py b/src/documents/plugins/helpers.py index e5cfde3b8..e30591125 100644 --- a/src/documents/plugins/helpers.py +++ b/src/documents/plugins/helpers.py @@ -1,6 +1,9 @@ import enum -from collections.abc import Mapping from typing import TYPE_CHECKING +from typing import Literal +from typing import Self +from typing import TypeAlias +from typing import TypedDict from asgiref.sync import async_to_sync from channels.layers import get_channel_layer @@ -16,6 +19,59 @@ class ProgressStatusOptions(enum.StrEnum): FAILED = "FAILED" +class PermissionsData(TypedDict, total=False): + """Permission fields included in status messages for access control.""" + + owner_id: int | None + users_can_view: list[int] + groups_can_view: list[int] + + +class ProgressUpdateData(TypedDict): + filename: str | None + task_id: str | None + current_progress: int + max_progress: int + status: str + message: str + document_id: int | None + owner_id: int | None + users_can_view: list[int] + groups_can_view: list[int] + + +class StatusUpdatePayload(TypedDict): + type: Literal["status_update"] + data: ProgressUpdateData + + +class DocumentsDeletedData(TypedDict): + documents: list[int] + + +class DocumentsDeletedPayload(TypedDict): + type: Literal["documents_deleted"] + data: DocumentsDeletedData + + +class DocumentUpdatedData(TypedDict): + document_id: int + modified: str + owner_id: int | None + users_can_view: list[int] + groups_can_view: list[int] + + +class DocumentUpdatedPayload(TypedDict): + type: Literal["document_updated"] + data: DocumentUpdatedData + + +WebsocketPayload: TypeAlias = ( + StatusUpdatePayload | DocumentsDeletedPayload | DocumentUpdatedPayload +) + + class BaseStatusManager: """ Handles sending of progress information via the channel layer, with proper management @@ -25,11 +81,11 @@ class BaseStatusManager: def __init__(self) -> None: self._channel: RedisPubSubChannelLayer | None = None - def __enter__(self): + def __enter__(self) -> Self: self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: self.close() def open(self) -> None: @@ -48,7 +104,7 @@ class BaseStatusManager: async_to_sync(self._channel.flush) self._channel = None - def send(self, payload: Mapping[str, object]) -> None: + def send(self, payload: WebsocketPayload) -> None: # Ensure the layer is open self.open() @@ -72,36 +128,36 @@ class ProgressManager(BaseStatusManager): message: str, current_progress: int, max_progress: int, - extra_args: dict[str, str | int | None] | None = None, + *, + document_id: int | None = None, + owner_id: int | None = None, + users_can_view: list[int] | None = None, + groups_can_view: list[int] | None = None, ) -> None: - data: dict[str, object] = { + data: ProgressUpdateData = { "filename": self.filename, "task_id": self.task_id, "current_progress": current_progress, "max_progress": max_progress, "status": status, "message": message, + "document_id": document_id, + "owner_id": owner_id, + "users_can_view": users_can_view or [], + "groups_can_view": groups_can_view or [], } - if extra_args is not None: - data.update(extra_args) - - payload: dict[str, object] = { - "type": "status_update", - "data": data, - } - + payload: StatusUpdatePayload = {"type": "status_update", "data": data} self.send(payload) class DocumentsStatusManager(BaseStatusManager): def send_documents_deleted(self, documents: list[int]) -> None: - payload: dict[str, object] = { + payload: DocumentsDeletedPayload = { "type": "documents_deleted", "data": { "documents": documents, }, } - self.send(payload) def send_document_updated( @@ -113,7 +169,7 @@ class DocumentsStatusManager(BaseStatusManager): users_can_view: list[int] | None = None, groups_can_view: list[int] | None = None, ) -> None: - payload: dict[str, object] = { + payload: DocumentUpdatedPayload = { "type": "document_updated", "data": { "document_id": document_id, @@ -123,5 +179,4 @@ class DocumentsStatusManager(BaseStatusManager): "groups_can_view": groups_can_view or [], }, } - self.send(payload) diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index cc4190974..98c8258b8 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -435,7 +435,11 @@ class DummyProgressManager: message: str, current_progress: int, max_progress: int, - extra_args: dict[str, str | int] | None = None, + *, + document_id: int | None = None, + owner_id: int | None = None, + users_can_view: list[int] | None = None, + groups_can_view: list[int] | None = None, ) -> None: # Ensure the layer is open self.open() @@ -449,9 +453,10 @@ class DummyProgressManager: "max_progress": max_progress, "status": status, "message": message, + "document_id": document_id, + "owner_id": owner_id, + "users_can_view": users_can_view or [], + "groups_can_view": groups_can_view or [], }, } - if extra_args is not None: - payload["data"].update(extra_args) - self.payloads.append(payload) diff --git a/src/paperless/consumers.py b/src/paperless/consumers.py index 9d59a1a5a..4a3cda8fe 100644 --- a/src/paperless/consumers.py +++ b/src/paperless/consumers.py @@ -1,16 +1,27 @@ +from __future__ import annotations + import json -from typing import Any +from typing import TYPE_CHECKING from channels.generic.websocket import AsyncWebsocketConsumer +if TYPE_CHECKING: + from django.contrib.auth.base_user import AbstractBaseUser + from django.contrib.auth.models import AnonymousUser + + from documents.plugins.helpers import DocumentsDeletedPayload + from documents.plugins.helpers import DocumentUpdatedPayload + from documents.plugins.helpers import PermissionsData + from documents.plugins.helpers import StatusUpdatePayload + class StatusConsumer(AsyncWebsocketConsumer): def _authenticated(self) -> bool: - user: Any = self.scope.get("user") + user: AbstractBaseUser | AnonymousUser | None = self.scope.get("user") return user is not None and user.is_authenticated - async def _can_view(self, data: dict[str, Any]) -> bool: - user: Any = self.scope.get("user") + async def _can_view(self, data: PermissionsData) -> bool: + user: AbstractBaseUser | AnonymousUser | None = self.scope.get("user") if user is None: return False owner_id = data.get("owner_id") @@ -32,19 +43,19 @@ class StatusConsumer(AsyncWebsocketConsumer): async def disconnect(self, code: int) -> None: await self.channel_layer.group_discard("status_updates", self.channel_name) - async def status_update(self, event: dict[str, Any]) -> None: + async def status_update(self, event: StatusUpdatePayload) -> None: if not self._authenticated(): await self.close() elif await self._can_view(event["data"]): await self.send(json.dumps(event)) - async def documents_deleted(self, event: dict[str, Any]) -> None: + async def documents_deleted(self, event: DocumentsDeletedPayload) -> None: if not self._authenticated(): await self.close() else: await self.send(json.dumps(event)) - async def document_updated(self, event: dict[str, Any]) -> None: + async def document_updated(self, event: DocumentUpdatedPayload) -> None: if not self._authenticated(): await self.close() elif await self._can_view(event["data"]): diff --git a/src/paperless/tests/test_websockets.py b/src/paperless/tests/test_websockets.py index bffc44f82..9f7c9a652 100644 --- a/src/paperless/tests/test_websockets.py +++ b/src/paperless/tests/test_websockets.py @@ -200,7 +200,10 @@ class TestWebSockets: "Test message", 1, 10, - extra_args={"foo": "bar"}, + document_id=42, + owner_id=1, + users_can_view=[2, 3], + groups_can_view=[4], ) assert mock_group_send.call_args[0][1] == { @@ -212,7 +215,10 @@ class TestWebSockets: "max_progress": 10, "status": ProgressStatusOptions.STARTED, "message": "Test message", - "foo": "bar", + "document_id": 42, + "owner_id": 1, + "users_can_view": [2, 3], + "groups_can_view": [4], }, }