mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-11 03:31:23 +00:00
Typing updates for consumer and tests
This commit is contained in:
@@ -6,10 +6,13 @@ from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
class StatusConsumer(AsyncWebsocketConsumer):
|
||||
def _authenticated(self) -> bool:
|
||||
return "user" in self.scope and self.scope["user"].is_authenticated
|
||||
user: Any = self.scope.get("user")
|
||||
return user is not None and user.is_authenticated
|
||||
|
||||
async def _can_view(self, data) -> bool:
|
||||
user = self.scope.get("user") if self.scope.get("user") else None
|
||||
async def _can_view(self, data: dict[str, Any]) -> bool:
|
||||
user: Any = self.scope.get("user")
|
||||
if user is None:
|
||||
return False
|
||||
owner_id = data.get("owner_id")
|
||||
users_can_view = data.get("users_can_view", [])
|
||||
groups_can_view = data.get("groups_can_view", [])
|
||||
@@ -30,22 +33,22 @@ class StatusConsumer(AsyncWebsocketConsumer):
|
||||
await self.channel_layer.group_add("status_updates", self.channel_name)
|
||||
await self.accept()
|
||||
|
||||
async def disconnect(self, close_code) -> None:
|
||||
async def disconnect(self, code: int) -> None:
|
||||
await self.channel_layer.group_discard("status_updates", self.channel_name)
|
||||
|
||||
async def status_update(self, event) -> None:
|
||||
async def status_update(self, event: dict[str, Any]) -> None:
|
||||
if not self._authenticated():
|
||||
await self.close()
|
||||
elif await self._can_view(event["data"]):
|
||||
await self.send(json.dumps(event))
|
||||
|
||||
async def documents_deleted(self, event) -> None:
|
||||
async def documents_deleted(self, event: dict[str, Any]) -> None:
|
||||
if not self._authenticated():
|
||||
await self.close()
|
||||
else:
|
||||
await self.send(json.dumps(event))
|
||||
|
||||
async def document_updated(self, event: Any) -> None:
|
||||
async def document_updated(self, event: dict[str, Any]) -> None:
|
||||
if not self._authenticated():
|
||||
await self.close()
|
||||
elif await self._can_view(event["data"]):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
from channels.layers import get_channel_layer
|
||||
from channels.testing import WebsocketCommunicator
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from documents.plugins.helpers import DocumentsStatusManager
|
||||
from documents.plugins.helpers import ProgressManager
|
||||
@@ -10,7 +11,7 @@ from paperless.asgi import application
|
||||
|
||||
class TestWebSockets:
|
||||
@pytest.fixture(autouse=True)
|
||||
def anyio_backend(self):
|
||||
def anyio_backend(self) -> str:
|
||||
return "asyncio"
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -21,7 +22,7 @@ class TestWebSockets:
|
||||
await communicator.disconnect()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_on_no_auth(self, mocker) -> None:
|
||||
async def test_close_on_no_auth(self, mocker: MockerFixture) -> None:
|
||||
mock_auth = mocker.patch(
|
||||
"paperless.consumers.StatusConsumer._authenticated",
|
||||
return_value=True,
|
||||
@@ -37,6 +38,7 @@ class TestWebSockets:
|
||||
|
||||
mock_auth.return_value = False
|
||||
channel_layer = get_channel_layer()
|
||||
assert channel_layer is not None
|
||||
|
||||
await channel_layer.group_send(
|
||||
"status_updates",
|
||||
@@ -65,7 +67,7 @@ class TestWebSockets:
|
||||
mock_close.assert_awaited_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_auth(self, mocker) -> None:
|
||||
async def test_auth(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"paperless.consumers.StatusConsumer._authenticated",
|
||||
return_value=True,
|
||||
@@ -78,7 +80,7 @@ class TestWebSockets:
|
||||
await communicator.disconnect()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_receive_status_update(self, mocker) -> None:
|
||||
async def test_receive_status_update(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"paperless.consumers.StatusConsumer._authenticated",
|
||||
return_value=True,
|
||||
@@ -90,6 +92,7 @@ class TestWebSockets:
|
||||
|
||||
message = {"type": "status_update", "data": {"task_id": "test"}}
|
||||
channel_layer = get_channel_layer()
|
||||
assert channel_layer is not None
|
||||
await channel_layer.group_send("status_updates", message)
|
||||
|
||||
assert await communicator.receive_json_from() == message
|
||||
@@ -97,18 +100,19 @@ class TestWebSockets:
|
||||
await communicator.disconnect()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_status_update_check_perms(self, mocker) -> None:
|
||||
async def test_status_update_check_perms(self, mocker: MockerFixture) -> None:
|
||||
user = mocker.MagicMock()
|
||||
user.is_authenticated = True
|
||||
user.is_superuser = False
|
||||
user.id = 1
|
||||
|
||||
communicator = WebsocketCommunicator(application, "/ws/status/")
|
||||
communicator.scope["user"] = user
|
||||
communicator.scope["user"] = user # type: ignore[typeddict-unknown-key]
|
||||
connected, _ = await communicator.connect()
|
||||
assert connected
|
||||
|
||||
channel_layer = get_channel_layer()
|
||||
assert channel_layer is not None
|
||||
|
||||
# Message received as owner
|
||||
message = {"type": "status_update", "data": {"task_id": "test", "owner_id": 1}}
|
||||
@@ -132,7 +136,7 @@ class TestWebSockets:
|
||||
await communicator.disconnect()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_receive_documents_deleted(self, mocker) -> None:
|
||||
async def test_receive_documents_deleted(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"paperless.consumers.StatusConsumer._authenticated",
|
||||
return_value=True,
|
||||
@@ -144,6 +148,7 @@ class TestWebSockets:
|
||||
|
||||
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
|
||||
channel_layer = get_channel_layer()
|
||||
assert channel_layer is not None
|
||||
await channel_layer.group_send("status_updates", message)
|
||||
|
||||
assert await communicator.receive_json_from() == message
|
||||
@@ -151,7 +156,7 @@ class TestWebSockets:
|
||||
await communicator.disconnect()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_receive_document_updated(self, mocker) -> None:
|
||||
async def test_receive_document_updated(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"paperless.consumers.StatusConsumer._authenticated",
|
||||
return_value=True,
|
||||
@@ -176,13 +181,14 @@ class TestWebSockets:
|
||||
},
|
||||
}
|
||||
channel_layer = get_channel_layer()
|
||||
assert channel_layer is not None
|
||||
await channel_layer.group_send("status_updates", message)
|
||||
|
||||
assert await communicator.receive_json_from() == message
|
||||
|
||||
await communicator.disconnect()
|
||||
|
||||
def test_manager_send_progress(self, mocker) -> None:
|
||||
def test_manager_send_progress(self, mocker: MockerFixture) -> None:
|
||||
mock_group_send = mocker.patch(
|
||||
"channels.layers.InMemoryChannelLayer.group_send",
|
||||
)
|
||||
@@ -209,7 +215,7 @@ class TestWebSockets:
|
||||
},
|
||||
}
|
||||
|
||||
def test_manager_send_documents_deleted(self, mocker) -> None:
|
||||
def test_manager_send_documents_deleted(self, mocker: MockerFixture) -> None:
|
||||
mock_group_send = mocker.patch(
|
||||
"channels.layers.InMemoryChannelLayer.group_send",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user