Chore: Better typed status manager messages (#12509)

This commit is contained in:
Trenton H
2026-04-03 14:18:01 -07:00
committed by GitHub
parent d0f8a98a9a
commit c2f02851da
5 changed files with 114 additions and 39 deletions
+6 -8
View File
@@ -139,14 +139,12 @@ class ConsumerPluginMixin:
message,
current_progress,
max_progress,
extra_args={
"document_id": document_id,
"owner_id": self.metadata.owner_id if self.metadata.owner_id else None,
"users_can_view": (self.metadata.view_users or [])
+ (self.metadata.change_users or []),
"groups_can_view": (self.metadata.view_groups or [])
+ (self.metadata.change_groups or []),
},
document_id=document_id,
owner_id=self.metadata.owner_id if self.metadata.owner_id else None,
users_can_view=(self.metadata.view_users or [])
+ (self.metadata.change_users or []),
groups_can_view=(self.metadata.view_groups or [])
+ (self.metadata.change_groups or []),
)
def _fail(
+73 -18
View File
@@ -1,6 +1,9 @@
import enum
from collections.abc import Mapping
from typing import TYPE_CHECKING
from typing import Literal
from typing import Self
from typing import TypeAlias
from typing import TypedDict
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
@@ -16,6 +19,59 @@ class ProgressStatusOptions(enum.StrEnum):
FAILED = "FAILED"
class PermissionsData(TypedDict, total=False):
"""Permission fields included in status messages for access control."""
owner_id: int | None
users_can_view: list[int]
groups_can_view: list[int]
class ProgressUpdateData(TypedDict):
filename: str | None
task_id: str | None
current_progress: int
max_progress: int
status: str
message: str
document_id: int | None
owner_id: int | None
users_can_view: list[int]
groups_can_view: list[int]
class StatusUpdatePayload(TypedDict):
type: Literal["status_update"]
data: ProgressUpdateData
class DocumentsDeletedData(TypedDict):
documents: list[int]
class DocumentsDeletedPayload(TypedDict):
type: Literal["documents_deleted"]
data: DocumentsDeletedData
class DocumentUpdatedData(TypedDict):
document_id: int
modified: str
owner_id: int | None
users_can_view: list[int]
groups_can_view: list[int]
class DocumentUpdatedPayload(TypedDict):
type: Literal["document_updated"]
data: DocumentUpdatedData
WebsocketPayload: TypeAlias = (
StatusUpdatePayload | DocumentsDeletedPayload | DocumentUpdatedPayload
)
class BaseStatusManager:
"""
Handles sending of progress information via the channel layer, with proper management
@@ -25,11 +81,11 @@ class BaseStatusManager:
def __init__(self) -> None:
self._channel: RedisPubSubChannelLayer | None = None
def __enter__(self):
def __enter__(self) -> Self:
self.open()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
self.close()
def open(self) -> None:
@@ -48,7 +104,7 @@ class BaseStatusManager:
async_to_sync(self._channel.flush)
self._channel = None
def send(self, payload: Mapping[str, object]) -> None:
def send(self, payload: WebsocketPayload) -> None:
# Ensure the layer is open
self.open()
@@ -72,36 +128,36 @@ class ProgressManager(BaseStatusManager):
message: str,
current_progress: int,
max_progress: int,
extra_args: dict[str, str | int | None] | None = None,
*,
document_id: int | None = None,
owner_id: int | None = None,
users_can_view: list[int] | None = None,
groups_can_view: list[int] | None = None,
) -> None:
data: dict[str, object] = {
data: ProgressUpdateData = {
"filename": self.filename,
"task_id": self.task_id,
"current_progress": current_progress,
"max_progress": max_progress,
"status": status,
"message": message,
"document_id": document_id,
"owner_id": owner_id,
"users_can_view": users_can_view or [],
"groups_can_view": groups_can_view or [],
}
if extra_args is not None:
data.update(extra_args)
payload: dict[str, object] = {
"type": "status_update",
"data": data,
}
payload: StatusUpdatePayload = {"type": "status_update", "data": data}
self.send(payload)
class DocumentsStatusManager(BaseStatusManager):
def send_documents_deleted(self, documents: list[int]) -> None:
payload: dict[str, object] = {
payload: DocumentsDeletedPayload = {
"type": "documents_deleted",
"data": {
"documents": documents,
},
}
self.send(payload)
def send_document_updated(
@@ -113,7 +169,7 @@ class DocumentsStatusManager(BaseStatusManager):
users_can_view: list[int] | None = None,
groups_can_view: list[int] | None = None,
) -> None:
payload: dict[str, object] = {
payload: DocumentUpdatedPayload = {
"type": "document_updated",
"data": {
"document_id": document_id,
@@ -123,5 +179,4 @@ class DocumentsStatusManager(BaseStatusManager):
"groups_can_view": groups_can_view or [],
},
}
self.send(payload)
+9 -4
View File
@@ -435,7 +435,11 @@ class DummyProgressManager:
message: str,
current_progress: int,
max_progress: int,
extra_args: dict[str, str | int] | None = None,
*,
document_id: int | None = None,
owner_id: int | None = None,
users_can_view: list[int] | None = None,
groups_can_view: list[int] | None = None,
) -> None:
# Ensure the layer is open
self.open()
@@ -449,9 +453,10 @@ class DummyProgressManager:
"max_progress": max_progress,
"status": status,
"message": message,
"document_id": document_id,
"owner_id": owner_id,
"users_can_view": users_can_view or [],
"groups_can_view": groups_can_view or [],
},
}
if extra_args is not None:
payload["data"].update(extra_args)
self.payloads.append(payload)
+18 -7
View File
@@ -1,16 +1,27 @@
from __future__ import annotations
import json
from typing import Any
from typing import TYPE_CHECKING
from channels.generic.websocket import AsyncWebsocketConsumer
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from documents.plugins.helpers import DocumentsDeletedPayload
from documents.plugins.helpers import DocumentUpdatedPayload
from documents.plugins.helpers import PermissionsData
from documents.plugins.helpers import StatusUpdatePayload
class StatusConsumer(AsyncWebsocketConsumer):
def _authenticated(self) -> bool:
user: Any = self.scope.get("user")
user: AbstractBaseUser | AnonymousUser | None = self.scope.get("user")
return user is not None and user.is_authenticated
async def _can_view(self, data: dict[str, Any]) -> bool:
user: Any = self.scope.get("user")
async def _can_view(self, data: PermissionsData) -> bool:
user: AbstractBaseUser | AnonymousUser | None = self.scope.get("user")
if user is None:
return False
owner_id = data.get("owner_id")
@@ -32,19 +43,19 @@ class StatusConsumer(AsyncWebsocketConsumer):
async def disconnect(self, code: int) -> None:
await self.channel_layer.group_discard("status_updates", self.channel_name)
async def status_update(self, event: dict[str, Any]) -> None:
async def status_update(self, event: StatusUpdatePayload) -> None:
if not self._authenticated():
await self.close()
elif await self._can_view(event["data"]):
await self.send(json.dumps(event))
async def documents_deleted(self, event: dict[str, Any]) -> None:
async def documents_deleted(self, event: DocumentsDeletedPayload) -> None:
if not self._authenticated():
await self.close()
else:
await self.send(json.dumps(event))
async def document_updated(self, event: dict[str, Any]) -> None:
async def document_updated(self, event: DocumentUpdatedPayload) -> None:
if not self._authenticated():
await self.close()
elif await self._can_view(event["data"]):
+8 -2
View File
@@ -200,7 +200,10 @@ class TestWebSockets:
"Test message",
1,
10,
extra_args={"foo": "bar"},
document_id=42,
owner_id=1,
users_can_view=[2, 3],
groups_can_view=[4],
)
assert mock_group_send.call_args[0][1] == {
@@ -212,7 +215,10 @@ class TestWebSockets:
"max_progress": 10,
"status": ProgressStatusOptions.STARTED,
"message": "Test message",
"foo": "bar",
"document_id": 42,
"owner_id": 1,
"users_can_view": [2, 3],
"groups_can_view": [4],
},
}