Typing updates for consumer and tests

This commit is contained in:
Trenton H
2026-03-10 13:20:24 -07:00
parent 24f81edccf
commit c751d7a757
2 changed files with 26 additions and 17 deletions

View File

@@ -6,10 +6,13 @@ from channels.generic.websocket import AsyncWebsocketConsumer
class StatusConsumer(AsyncWebsocketConsumer):
def _authenticated(self) -> bool:
return "user" in self.scope and self.scope["user"].is_authenticated
user: Any = self.scope.get("user")
return user is not None and user.is_authenticated
async def _can_view(self, data) -> bool:
user = self.scope.get("user") if self.scope.get("user") else None
async def _can_view(self, data: dict[str, Any]) -> bool:
user: Any = self.scope.get("user")
if user is None:
return False
owner_id = data.get("owner_id")
users_can_view = data.get("users_can_view", [])
groups_can_view = data.get("groups_can_view", [])
@@ -30,22 +33,22 @@ class StatusConsumer(AsyncWebsocketConsumer):
await self.channel_layer.group_add("status_updates", self.channel_name)
await self.accept()
async def disconnect(self, close_code) -> None:
async def disconnect(self, code: int) -> None:
await self.channel_layer.group_discard("status_updates", self.channel_name)
async def status_update(self, event) -> None:
async def status_update(self, event: dict[str, Any]) -> None:
if not self._authenticated():
await self.close()
elif await self._can_view(event["data"]):
await self.send(json.dumps(event))
async def documents_deleted(self, event) -> None:
async def documents_deleted(self, event: dict[str, Any]) -> None:
if not self._authenticated():
await self.close()
else:
await self.send(json.dumps(event))
async def document_updated(self, event: Any) -> None:
async def document_updated(self, event: dict[str, Any]) -> None:
if not self._authenticated():
await self.close()
elif await self._can_view(event["data"]):

View File

@@ -1,6 +1,7 @@
import pytest
from channels.layers import get_channel_layer
from channels.testing import WebsocketCommunicator
from pytest_mock import MockerFixture
from documents.plugins.helpers import DocumentsStatusManager
from documents.plugins.helpers import ProgressManager
@@ -10,7 +11,7 @@ from paperless.asgi import application
class TestWebSockets:
@pytest.fixture(autouse=True)
def anyio_backend(self):
def anyio_backend(self) -> str:
return "asyncio"
@pytest.mark.anyio
@@ -21,7 +22,7 @@ class TestWebSockets:
await communicator.disconnect()
@pytest.mark.anyio
async def test_close_on_no_auth(self, mocker) -> None:
async def test_close_on_no_auth(self, mocker: MockerFixture) -> None:
mock_auth = mocker.patch(
"paperless.consumers.StatusConsumer._authenticated",
return_value=True,
@@ -37,6 +38,7 @@ class TestWebSockets:
mock_auth.return_value = False
channel_layer = get_channel_layer()
assert channel_layer is not None
await channel_layer.group_send(
"status_updates",
@@ -65,7 +67,7 @@ class TestWebSockets:
mock_close.assert_awaited_once()
@pytest.mark.anyio
async def test_auth(self, mocker) -> None:
async def test_auth(self, mocker: MockerFixture) -> None:
mocker.patch(
"paperless.consumers.StatusConsumer._authenticated",
return_value=True,
@@ -78,7 +80,7 @@ class TestWebSockets:
await communicator.disconnect()
@pytest.mark.anyio
async def test_receive_status_update(self, mocker) -> None:
async def test_receive_status_update(self, mocker: MockerFixture) -> None:
mocker.patch(
"paperless.consumers.StatusConsumer._authenticated",
return_value=True,
@@ -90,6 +92,7 @@ class TestWebSockets:
message = {"type": "status_update", "data": {"task_id": "test"}}
channel_layer = get_channel_layer()
assert channel_layer is not None
await channel_layer.group_send("status_updates", message)
assert await communicator.receive_json_from() == message
@@ -97,18 +100,19 @@ class TestWebSockets:
await communicator.disconnect()
@pytest.mark.anyio
async def test_status_update_check_perms(self, mocker) -> None:
async def test_status_update_check_perms(self, mocker: MockerFixture) -> None:
user = mocker.MagicMock()
user.is_authenticated = True
user.is_superuser = False
user.id = 1
communicator = WebsocketCommunicator(application, "/ws/status/")
communicator.scope["user"] = user
communicator.scope["user"] = user # type: ignore[typeddict-unknown-key]
connected, _ = await communicator.connect()
assert connected
channel_layer = get_channel_layer()
assert channel_layer is not None
# Message received as owner
message = {"type": "status_update", "data": {"task_id": "test", "owner_id": 1}}
@@ -132,7 +136,7 @@ class TestWebSockets:
await communicator.disconnect()
@pytest.mark.anyio
async def test_receive_documents_deleted(self, mocker) -> None:
async def test_receive_documents_deleted(self, mocker: MockerFixture) -> None:
mocker.patch(
"paperless.consumers.StatusConsumer._authenticated",
return_value=True,
@@ -144,6 +148,7 @@ class TestWebSockets:
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
channel_layer = get_channel_layer()
assert channel_layer is not None
await channel_layer.group_send("status_updates", message)
assert await communicator.receive_json_from() == message
@@ -151,7 +156,7 @@ class TestWebSockets:
await communicator.disconnect()
@pytest.mark.anyio
async def test_receive_document_updated(self, mocker) -> None:
async def test_receive_document_updated(self, mocker: MockerFixture) -> None:
mocker.patch(
"paperless.consumers.StatusConsumer._authenticated",
return_value=True,
@@ -176,13 +181,14 @@ class TestWebSockets:
},
}
channel_layer = get_channel_layer()
assert channel_layer is not None
await channel_layer.group_send("status_updates", message)
assert await communicator.receive_json_from() == message
await communicator.disconnect()
def test_manager_send_progress(self, mocker) -> None:
def test_manager_send_progress(self, mocker: MockerFixture) -> None:
mock_group_send = mocker.patch(
"channels.layers.InMemoryChannelLayer.group_send",
)
@@ -209,7 +215,7 @@ class TestWebSockets:
},
}
def test_manager_send_documents_deleted(self, mocker) -> None:
def test_manager_send_documents_deleted(self, mocker: MockerFixture) -> None:
mock_group_send = mocker.patch(
"channels.layers.InMemoryChannelLayer.group_send",
)