From c751d7a75782c676b052d5d229162f6b91c477e5 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:20:24 -0700 Subject: [PATCH] Typing updates for consumer and tests --- src/paperless/consumers.py | 17 ++++++++++------- src/paperless/tests/test_websockets.py | 26 ++++++++++++++++---------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/paperless/consumers.py b/src/paperless/consumers.py index d06d584c1..90f385135 100644 --- a/src/paperless/consumers.py +++ b/src/paperless/consumers.py @@ -6,10 +6,13 @@ from channels.generic.websocket import AsyncWebsocketConsumer class StatusConsumer(AsyncWebsocketConsumer): def _authenticated(self) -> bool: - return "user" in self.scope and self.scope["user"].is_authenticated + user: Any = self.scope.get("user") + return user is not None and user.is_authenticated - async def _can_view(self, data) -> bool: - user = self.scope.get("user") if self.scope.get("user") else None + async def _can_view(self, data: dict[str, Any]) -> bool: + user: Any = self.scope.get("user") + if user is None: + return False owner_id = data.get("owner_id") users_can_view = data.get("users_can_view", []) groups_can_view = data.get("groups_can_view", []) @@ -30,22 +33,22 @@ class StatusConsumer(AsyncWebsocketConsumer): await self.channel_layer.group_add("status_updates", self.channel_name) await self.accept() - async def disconnect(self, close_code) -> None: + async def disconnect(self, code: int) -> None: await self.channel_layer.group_discard("status_updates", self.channel_name) - async def status_update(self, event) -> None: + async def status_update(self, event: dict[str, Any]) -> None: if not self._authenticated(): await self.close() elif await self._can_view(event["data"]): await self.send(json.dumps(event)) - async def documents_deleted(self, event) -> None: + async def documents_deleted(self, event: dict[str, Any]) -> None: if not self._authenticated(): await self.close() else: await self.send(json.dumps(event)) - async def document_updated(self, event: Any) -> None: + async def document_updated(self, event: dict[str, Any]) -> None: if not self._authenticated(): await self.close() elif await self._can_view(event["data"]): diff --git a/src/paperless/tests/test_websockets.py b/src/paperless/tests/test_websockets.py index f9e87a407..8c08ee887 100644 --- a/src/paperless/tests/test_websockets.py +++ b/src/paperless/tests/test_websockets.py @@ -1,6 +1,7 @@ import pytest from channels.layers import get_channel_layer from channels.testing import WebsocketCommunicator +from pytest_mock import MockerFixture from documents.plugins.helpers import DocumentsStatusManager from documents.plugins.helpers import ProgressManager @@ -10,7 +11,7 @@ from paperless.asgi import application class TestWebSockets: @pytest.fixture(autouse=True) - def anyio_backend(self): + def anyio_backend(self) -> str: return "asyncio" @pytest.mark.anyio @@ -21,7 +22,7 @@ class TestWebSockets: await communicator.disconnect() @pytest.mark.anyio - async def test_close_on_no_auth(self, mocker) -> None: + async def test_close_on_no_auth(self, mocker: MockerFixture) -> None: mock_auth = mocker.patch( "paperless.consumers.StatusConsumer._authenticated", return_value=True, @@ -37,6 +38,7 @@ class TestWebSockets: mock_auth.return_value = False channel_layer = get_channel_layer() + assert channel_layer is not None await channel_layer.group_send( "status_updates", @@ -65,7 +67,7 @@ class TestWebSockets: mock_close.assert_awaited_once() @pytest.mark.anyio - async def test_auth(self, mocker) -> None: + async def test_auth(self, mocker: MockerFixture) -> None: mocker.patch( "paperless.consumers.StatusConsumer._authenticated", return_value=True, @@ -78,7 +80,7 @@ class TestWebSockets: await communicator.disconnect() @pytest.mark.anyio - async def test_receive_status_update(self, mocker) -> None: + async def test_receive_status_update(self, mocker: MockerFixture) -> None: mocker.patch( "paperless.consumers.StatusConsumer._authenticated", return_value=True, @@ -90,6 +92,7 @@ class TestWebSockets: message = {"type": "status_update", "data": {"task_id": "test"}} channel_layer = get_channel_layer() + assert channel_layer is not None await channel_layer.group_send("status_updates", message) assert await communicator.receive_json_from() == message @@ -97,18 +100,19 @@ class TestWebSockets: await communicator.disconnect() @pytest.mark.anyio - async def test_status_update_check_perms(self, mocker) -> None: + async def test_status_update_check_perms(self, mocker: MockerFixture) -> None: user = mocker.MagicMock() user.is_authenticated = True user.is_superuser = False user.id = 1 communicator = WebsocketCommunicator(application, "/ws/status/") - communicator.scope["user"] = user + communicator.scope["user"] = user # type: ignore[typeddict-unknown-key] connected, _ = await communicator.connect() assert connected channel_layer = get_channel_layer() + assert channel_layer is not None # Message received as owner message = {"type": "status_update", "data": {"task_id": "test", "owner_id": 1}} @@ -132,7 +136,7 @@ class TestWebSockets: await communicator.disconnect() @pytest.mark.anyio - async def test_receive_documents_deleted(self, mocker) -> None: + async def test_receive_documents_deleted(self, mocker: MockerFixture) -> None: mocker.patch( "paperless.consumers.StatusConsumer._authenticated", return_value=True, @@ -144,6 +148,7 @@ class TestWebSockets: message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}} channel_layer = get_channel_layer() + assert channel_layer is not None await channel_layer.group_send("status_updates", message) assert await communicator.receive_json_from() == message @@ -151,7 +156,7 @@ class TestWebSockets: await communicator.disconnect() @pytest.mark.anyio - async def test_receive_document_updated(self, mocker) -> None: + async def test_receive_document_updated(self, mocker: MockerFixture) -> None: mocker.patch( "paperless.consumers.StatusConsumer._authenticated", return_value=True, @@ -176,13 +181,14 @@ class TestWebSockets: }, } channel_layer = get_channel_layer() + assert channel_layer is not None await channel_layer.group_send("status_updates", message) assert await communicator.receive_json_from() == message await communicator.disconnect() - def test_manager_send_progress(self, mocker) -> None: + def test_manager_send_progress(self, mocker: MockerFixture) -> None: mock_group_send = mocker.patch( "channels.layers.InMemoryChannelLayer.group_send", ) @@ -209,7 +215,7 @@ class TestWebSockets: }, } - def test_manager_send_documents_deleted(self, mocker) -> None: + def test_manager_send_documents_deleted(self, mocker: MockerFixture) -> None: mock_group_send = mocker.patch( "channels.layers.InMemoryChannelLayer.group_send", )