mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-11 08:09:44 +00:00
Typing updates for consumer and tests
This commit is contained in:
@@ -6,10 +6,13 @@ from channels.generic.websocket import AsyncWebsocketConsumer
|
|||||||
|
|
||||||
class StatusConsumer(AsyncWebsocketConsumer):
|
class StatusConsumer(AsyncWebsocketConsumer):
|
||||||
def _authenticated(self) -> bool:
|
def _authenticated(self) -> bool:
|
||||||
return "user" in self.scope and self.scope["user"].is_authenticated
|
user: Any = self.scope.get("user")
|
||||||
|
return user is not None and user.is_authenticated
|
||||||
|
|
||||||
async def _can_view(self, data) -> bool:
|
async def _can_view(self, data: dict[str, Any]) -> bool:
|
||||||
user = self.scope.get("user") if self.scope.get("user") else None
|
user: Any = self.scope.get("user")
|
||||||
|
if user is None:
|
||||||
|
return False
|
||||||
owner_id = data.get("owner_id")
|
owner_id = data.get("owner_id")
|
||||||
users_can_view = data.get("users_can_view", [])
|
users_can_view = data.get("users_can_view", [])
|
||||||
groups_can_view = data.get("groups_can_view", [])
|
groups_can_view = data.get("groups_can_view", [])
|
||||||
@@ -30,22 +33,22 @@ class StatusConsumer(AsyncWebsocketConsumer):
|
|||||||
await self.channel_layer.group_add("status_updates", self.channel_name)
|
await self.channel_layer.group_add("status_updates", self.channel_name)
|
||||||
await self.accept()
|
await self.accept()
|
||||||
|
|
||||||
async def disconnect(self, close_code) -> None:
|
async def disconnect(self, code: int) -> None:
|
||||||
await self.channel_layer.group_discard("status_updates", self.channel_name)
|
await self.channel_layer.group_discard("status_updates", self.channel_name)
|
||||||
|
|
||||||
async def status_update(self, event) -> None:
|
async def status_update(self, event: dict[str, Any]) -> None:
|
||||||
if not self._authenticated():
|
if not self._authenticated():
|
||||||
await self.close()
|
await self.close()
|
||||||
elif await self._can_view(event["data"]):
|
elif await self._can_view(event["data"]):
|
||||||
await self.send(json.dumps(event))
|
await self.send(json.dumps(event))
|
||||||
|
|
||||||
async def documents_deleted(self, event) -> None:
|
async def documents_deleted(self, event: dict[str, Any]) -> None:
|
||||||
if not self._authenticated():
|
if not self._authenticated():
|
||||||
await self.close()
|
await self.close()
|
||||||
else:
|
else:
|
||||||
await self.send(json.dumps(event))
|
await self.send(json.dumps(event))
|
||||||
|
|
||||||
async def document_updated(self, event: Any) -> None:
|
async def document_updated(self, event: dict[str, Any]) -> None:
|
||||||
if not self._authenticated():
|
if not self._authenticated():
|
||||||
await self.close()
|
await self.close()
|
||||||
elif await self._can_view(event["data"]):
|
elif await self._can_view(event["data"]):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from channels.layers import get_channel_layer
|
from channels.layers import get_channel_layer
|
||||||
from channels.testing import WebsocketCommunicator
|
from channels.testing import WebsocketCommunicator
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from documents.plugins.helpers import DocumentsStatusManager
|
from documents.plugins.helpers import DocumentsStatusManager
|
||||||
from documents.plugins.helpers import ProgressManager
|
from documents.plugins.helpers import ProgressManager
|
||||||
@@ -10,7 +11,7 @@ from paperless.asgi import application
|
|||||||
|
|
||||||
class TestWebSockets:
|
class TestWebSockets:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def anyio_backend(self):
|
def anyio_backend(self) -> str:
|
||||||
return "asyncio"
|
return "asyncio"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -21,7 +22,7 @@ class TestWebSockets:
|
|||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_close_on_no_auth(self, mocker) -> None:
|
async def test_close_on_no_auth(self, mocker: MockerFixture) -> None:
|
||||||
mock_auth = mocker.patch(
|
mock_auth = mocker.patch(
|
||||||
"paperless.consumers.StatusConsumer._authenticated",
|
"paperless.consumers.StatusConsumer._authenticated",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
@@ -37,6 +38,7 @@ class TestWebSockets:
|
|||||||
|
|
||||||
mock_auth.return_value = False
|
mock_auth.return_value = False
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
|
assert channel_layer is not None
|
||||||
|
|
||||||
await channel_layer.group_send(
|
await channel_layer.group_send(
|
||||||
"status_updates",
|
"status_updates",
|
||||||
@@ -65,7 +67,7 @@ class TestWebSockets:
|
|||||||
mock_close.assert_awaited_once()
|
mock_close.assert_awaited_once()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_auth(self, mocker) -> None:
|
async def test_auth(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"paperless.consumers.StatusConsumer._authenticated",
|
"paperless.consumers.StatusConsumer._authenticated",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
@@ -78,7 +80,7 @@ class TestWebSockets:
|
|||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_receive_status_update(self, mocker) -> None:
|
async def test_receive_status_update(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"paperless.consumers.StatusConsumer._authenticated",
|
"paperless.consumers.StatusConsumer._authenticated",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
@@ -90,6 +92,7 @@ class TestWebSockets:
|
|||||||
|
|
||||||
message = {"type": "status_update", "data": {"task_id": "test"}}
|
message = {"type": "status_update", "data": {"task_id": "test"}}
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
|
assert channel_layer is not None
|
||||||
await channel_layer.group_send("status_updates", message)
|
await channel_layer.group_send("status_updates", message)
|
||||||
|
|
||||||
assert await communicator.receive_json_from() == message
|
assert await communicator.receive_json_from() == message
|
||||||
@@ -97,18 +100,19 @@ class TestWebSockets:
|
|||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_status_update_check_perms(self, mocker) -> None:
|
async def test_status_update_check_perms(self, mocker: MockerFixture) -> None:
|
||||||
user = mocker.MagicMock()
|
user = mocker.MagicMock()
|
||||||
user.is_authenticated = True
|
user.is_authenticated = True
|
||||||
user.is_superuser = False
|
user.is_superuser = False
|
||||||
user.id = 1
|
user.id = 1
|
||||||
|
|
||||||
communicator = WebsocketCommunicator(application, "/ws/status/")
|
communicator = WebsocketCommunicator(application, "/ws/status/")
|
||||||
communicator.scope["user"] = user
|
communicator.scope["user"] = user # type: ignore[typeddict-unknown-key]
|
||||||
connected, _ = await communicator.connect()
|
connected, _ = await communicator.connect()
|
||||||
assert connected
|
assert connected
|
||||||
|
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
|
assert channel_layer is not None
|
||||||
|
|
||||||
# Message received as owner
|
# Message received as owner
|
||||||
message = {"type": "status_update", "data": {"task_id": "test", "owner_id": 1}}
|
message = {"type": "status_update", "data": {"task_id": "test", "owner_id": 1}}
|
||||||
@@ -132,7 +136,7 @@ class TestWebSockets:
|
|||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_receive_documents_deleted(self, mocker) -> None:
|
async def test_receive_documents_deleted(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"paperless.consumers.StatusConsumer._authenticated",
|
"paperless.consumers.StatusConsumer._authenticated",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
@@ -144,6 +148,7 @@ class TestWebSockets:
|
|||||||
|
|
||||||
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
|
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
|
assert channel_layer is not None
|
||||||
await channel_layer.group_send("status_updates", message)
|
await channel_layer.group_send("status_updates", message)
|
||||||
|
|
||||||
assert await communicator.receive_json_from() == message
|
assert await communicator.receive_json_from() == message
|
||||||
@@ -151,7 +156,7 @@ class TestWebSockets:
|
|||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_receive_document_updated(self, mocker) -> None:
|
async def test_receive_document_updated(self, mocker: MockerFixture) -> None:
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"paperless.consumers.StatusConsumer._authenticated",
|
"paperless.consumers.StatusConsumer._authenticated",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
@@ -176,13 +181,14 @@ class TestWebSockets:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
|
assert channel_layer is not None
|
||||||
await channel_layer.group_send("status_updates", message)
|
await channel_layer.group_send("status_updates", message)
|
||||||
|
|
||||||
assert await communicator.receive_json_from() == message
|
assert await communicator.receive_json_from() == message
|
||||||
|
|
||||||
await communicator.disconnect()
|
await communicator.disconnect()
|
||||||
|
|
||||||
def test_manager_send_progress(self, mocker) -> None:
|
def test_manager_send_progress(self, mocker: MockerFixture) -> None:
|
||||||
mock_group_send = mocker.patch(
|
mock_group_send = mocker.patch(
|
||||||
"channels.layers.InMemoryChannelLayer.group_send",
|
"channels.layers.InMemoryChannelLayer.group_send",
|
||||||
)
|
)
|
||||||
@@ -209,7 +215,7 @@ class TestWebSockets:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_manager_send_documents_deleted(self, mocker) -> None:
|
def test_manager_send_documents_deleted(self, mocker: MockerFixture) -> None:
|
||||||
mock_group_send = mocker.patch(
|
mock_group_send = mocker.patch(
|
||||||
"channels.layers.InMemoryChannelLayer.group_send",
|
"channels.layers.InMemoryChannelLayer.group_send",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user