From fbf4e32646a7c698ca0c9d17b1ecc34074617c3b Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:40:04 -0700 Subject: [PATCH] Chore: Converts all call sites and test asserts to use apply_async and headers (#12591) --- src/documents/barcodes.py | 41 +++-- src/documents/bulk_edit.py | 146 +++++++++++------ .../management/commands/document_consumer.py | 16 +- src/documents/serialisers.py | 10 +- src/documents/signals/handlers.py | 67 ++------ src/documents/tasks.py | 6 +- src/documents/tests/test_api_bulk_edit.py | 38 +++-- src/documents/tests/test_api_custom_fields.py | 4 +- .../tests/test_api_document_versions.py | 16 +- src/documents/tests/test_api_documents.py | 64 ++------ src/documents/tests/test_api_objects.py | 17 +- src/documents/tests/test_barcodes.py | 6 +- src/documents/tests/test_bulk_edit.py | 152 +++++++++--------- .../tests/test_management_consumer.py | 70 ++++---- .../tests/test_share_link_bundles.py | 10 +- src/documents/tests/test_tag_hierarchy.py | 2 +- src/documents/tests/test_task_signals.py | 41 ++--- src/documents/tests/test_workflows.py | 42 ++--- src/documents/tests/utils.py | 52 ++---- src/documents/views.py | 50 ++++-- src/documents/workflows/actions.py | 14 +- src/paperless_mail/mail.py | 15 +- src/paperless_mail/tests/test_mail.py | 5 +- src/paperless_mail/views.py | 6 +- 24 files changed, 465 insertions(+), 425 deletions(-) diff --git a/src/documents/barcodes.py b/src/documents/barcodes.py index 38a28081a..2bb96f1ea 100644 --- a/src/documents/barcodes.py +++ b/src/documents/barcodes.py @@ -17,7 +17,9 @@ from pikepdf import Pdf from documents.converters import convert_from_tiff_to_pdf from documents.data_models import ConsumableDocument from documents.data_models import DocumentMetadataOverrides +from documents.data_models import DocumentSource from documents.models import Document +from documents.models import PaperlessTask from documents.models import Tag from documents.plugins.base import ConsumeTaskPlugin from documents.plugins.base import StopConsumeTaskError @@ -193,23 +195,36 @@ class BarcodePlugin(ConsumeTaskPlugin): from documents import tasks + _SOURCE_TO_TRIGGER: dict[DocumentSource, PaperlessTask.TriggerSource] = { + DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME, + DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD, + DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME, + DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI, + } + trigger_source = _SOURCE_TO_TRIGGER.get( + self.input_doc.source, + PaperlessTask.TriggerSource.MANUAL, + ) + # Create the split document tasks for new_document in self.separate_pages(separator_pages): copy_file_with_basic_stats(new_document, tmp_dir / new_document.name) - task = tasks.consume_file.delay( - ConsumableDocument( - # Same source, for templates - source=self.input_doc.source, - mailrule_id=self.input_doc.mailrule_id, - # Can't use same folder or the consume might grab it again - original_file=(tmp_dir / new_document.name).resolve(), - # Adding optional original_path for later uses in - # workflow matching - original_path=self.input_doc.original_file, - ), - # All the same metadata - self.metadata, + task = tasks.consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + # Same source, for templates + source=self.input_doc.source, + mailrule_id=self.input_doc.mailrule_id, + # Can't use same folder or the consume might grab it again + original_file=(tmp_dir / new_document.name).resolve(), + # Adding optional original_path for later uses in + # workflow matching + original_path=self.input_doc.original_file, + ), + "overrides": self.metadata, + }, + headers={"trigger_source": trigger_source}, ) logger.info(f"Created new task {task.id} for {new_document.name}") diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 3f80b699d..65ce3c785 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -22,6 +22,7 @@ from documents.models import CustomField from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType +from documents.models import PaperlessTask from documents.models import StoragePath from documents.models import Tag from documents.permissions import set_permissions_for_object @@ -113,7 +114,10 @@ def set_correspondent( affected_docs = list(qs.values_list("pk", flat=True)) qs.update(correspondent=correspondent) - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -132,8 +136,9 @@ def set_storage_path(doc_ids: list[int], storage_path: StoragePath) -> Literal[" affected_docs = list(qs.values_list("pk", flat=True)) qs.update(storage_path=storage_path) - bulk_update_documents.delay( - document_ids=affected_docs, + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, ) return "OK" @@ -151,7 +156,10 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera affected_docs = list(qs.values_list("pk", flat=True)) qs.update(document_type=document_type) - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -177,7 +185,10 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: DocumentTagRelationship.objects.bulk_create(to_create) if affected_docs: - bulk_update_documents.delay(document_ids=list(affected_docs)) + bulk_update_documents.apply_async( + kwargs={"document_ids": list(affected_docs)}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -195,7 +206,10 @@ def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]: qs.delete() if affected_docs: - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -254,7 +268,10 @@ def modify_tags( ) if affected_docs: - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) except Exception as e: logger.error(f"Error modifying tags: {e}") return "ERROR" @@ -326,7 +343,10 @@ def modify_custom_fields( field_id__in=remove_custom_fields, ).hard_delete() - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -369,8 +389,9 @@ def delete(doc_ids: list[int]) -> Literal["OK"]: def reprocess(doc_ids: list[int]) -> Literal["OK"]: for document_id in doc_ids: - update_document_content_maybe_archive_file.delay( - document_id=document_id, + update_document_content_maybe_archive_file.apply_async( + kwargs={"document_id": document_id}, + headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL}, ) return "OK" @@ -396,7 +417,10 @@ def set_permissions( affected_docs = list(qs.values_list("pk", flat=True)) - bulk_update_documents.delay(document_ids=affected_docs) + bulk_update_documents.apply_async( + kwargs={"document_ids": affected_docs}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return "OK" @@ -407,6 +431,7 @@ def rotate( *, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: logger.info( f"Attempting to rotate {len(doc_ids)} documents by {degrees} degrees.", @@ -453,13 +478,16 @@ def rotate( if user is not None: overrides.actor_id = user.id - consume_file.delay( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file=filepath, - root_document_id=root_doc.id, - ), - overrides, + consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=filepath, + root_document_id=root_doc.id, + ), + "overrides": overrides, + }, + headers={"trigger_source": trigger_source}, ) logger.info( f"Queued new rotated version for document {root_doc.id} by {degrees} degrees", @@ -478,6 +506,7 @@ def merge( archive_fallback: bool = False, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: logger.info( f"Attempting to merge {len(doc_ids)} documents into a single document.", @@ -556,12 +585,12 @@ def merge( logger.info("Adding merged document to the task queue.") consume_task = consume_file.s( - ConsumableDocument( + input_doc=ConsumableDocument( source=DocumentSource.ConsumeFolder, original_file=filepath, ), - overrides, - ) + overrides=overrides, + ).set(headers={"trigger_source": trigger_source}) if delete_originals: backup = release_archive_serial_numbers(affected_docs) @@ -577,7 +606,7 @@ def merge( restore_archive_serial_numbers(backup) raise else: - consume_task.delay() + consume_task.apply_async() return "OK" @@ -589,6 +618,7 @@ def split( delete_originals: bool = False, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: logger.info( f"Attempting to split document {doc_ids[0]} into {len(pages)} documents", @@ -631,12 +661,12 @@ def split( ) consume_tasks.append( consume_file.s( - ConsumableDocument( + input_doc=ConsumableDocument( source=DocumentSource.ConsumeFolder, original_file=filepath, ), - overrides, - ), + overrides=overrides, + ).set(headers={"trigger_source": trigger_source}), ) if delete_originals: @@ -669,6 +699,7 @@ def delete_pages( *, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: logger.info( f"Attempting to delete pages {pages} from {len(doc_ids)} documents", @@ -698,13 +729,16 @@ def delete_pages( overrides = DocumentMetadataOverrides().from_document(root_doc) if user is not None: overrides.actor_id = user.id - consume_file.delay( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file=filepath, - root_document_id=root_doc.id, - ), - overrides, + consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=filepath, + root_document_id=root_doc.id, + ), + "overrides": overrides, + }, + headers={"trigger_source": trigger_source}, ) logger.info( f"Queued new version for document {root_doc.id} after deleting pages {pages}", @@ -724,6 +758,7 @@ def edit_pdf( include_metadata: bool = True, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: """ Operations is a list of dictionaries describing the final PDF pages. @@ -781,13 +816,16 @@ def edit_pdf( if user is not None: overrides.owner_id = user.id overrides.actor_id = user.id - consume_file.delay( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file=filepath, - root_document_id=root_doc.id, - ), - overrides, + consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=filepath, + root_document_id=root_doc.id, + ), + "overrides": overrides, + }, + headers={"trigger_source": trigger_source}, ) else: consume_tasks = [] @@ -812,12 +850,12 @@ def edit_pdf( pdf.save(version_filepath) consume_tasks.append( consume_file.s( - ConsumableDocument( + input_doc=ConsumableDocument( source=DocumentSource.ConsumeFolder, original_file=version_filepath, ), - overrides, - ), + overrides=overrides, + ).set(headers={"trigger_source": trigger_source}), ) if delete_original: @@ -853,6 +891,7 @@ def remove_password( include_metadata: bool = True, source_mode: SourceMode = SourceModeChoices.LATEST_VERSION, user: User | None = None, + trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI, ) -> Literal["OK"]: """ Remove password protection from PDF documents. @@ -887,13 +926,16 @@ def remove_password( if user is not None: overrides.owner_id = user.id overrides.actor_id = user.id - consume_file.delay( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file=filepath, - root_document_id=root_doc.id, - ), - overrides, + consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=filepath, + root_document_id=root_doc.id, + ), + "overrides": overrides, + }, + headers={"trigger_source": trigger_source}, ) else: consume_tasks = [] @@ -908,12 +950,12 @@ def remove_password( consume_tasks.append( consume_file.s( - ConsumableDocument( + input_doc=ConsumableDocument( source=DocumentSource.ConsumeFolder, original_file=filepath, ), - overrides, - ), + overrides=overrides, + ).set(headers={"trigger_source": trigger_source}), ) if delete_original: diff --git a/src/documents/management/commands/document_consumer.py b/src/documents/management/commands/document_consumer.py index 5ba8d30cd..acfbf6e3d 100644 --- a/src/documents/management/commands/document_consumer.py +++ b/src/documents/management/commands/document_consumer.py @@ -27,6 +27,7 @@ from watchfiles import watch from documents.data_models import ConsumableDocument from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentSource +from documents.models import PaperlessTask from documents.models import Tag from documents.parsers import get_supported_file_extensions from documents.tasks import consume_file @@ -338,12 +339,15 @@ def _consume_file( # Queue for consumption try: logger.info(f"Adding {filepath} to the task queue") - consume_file.delay( - ConsumableDocument( - source=DocumentSource.ConsumeFolder, - original_file=filepath, - ), - DocumentMetadataOverrides(tag_ids=tag_ids), + consume_file.apply_async( + kwargs={ + "input_doc": ConsumableDocument( + source=DocumentSource.ConsumeFolder, + original_file=filepath, + ), + "overrides": DocumentMetadataOverrides(tag_ids=tag_ids), + }, + headers={"trigger_source": PaperlessTask.TriggerSource.FOLDER_CONSUME}, ) except Exception: logger.exception(f"Error while queuing document {filepath}") diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 986fdf720..94dca2035 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1604,6 +1604,7 @@ class RotateDocumentsSerializer(DocumentSelectionSerializer, SourceModeValidatio required=False, default=bulk_edit.SourceModeChoices.LATEST_VERSION, ) + from_webui = serializers.BooleanField(required=False, default=False) class MergeDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin): @@ -1617,6 +1618,7 @@ class MergeDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin required=False, default=bulk_edit.SourceModeChoices.LATEST_VERSION, ) + from_webui = serializers.BooleanField(required=False, default=False) class EditPdfDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin): @@ -1628,6 +1630,7 @@ class EditPdfDocumentsSerializer(DocumentListSerializer, SourceModeValidationMix required=False, default=bulk_edit.SourceModeChoices.LATEST_VERSION, ) + from_webui = serializers.BooleanField(required=False, default=False) def validate(self, attrs): documents = attrs["documents"] @@ -1679,6 +1682,7 @@ class RemovePasswordDocumentsSerializer( required=False, default=bulk_edit.SourceModeChoices.LATEST_VERSION, ) + from_webui = serializers.BooleanField(required=False, default=False) class DeleteDocumentsSerializer(DocumentSelectionSerializer): @@ -1726,6 +1730,7 @@ class BulkEditSerializer( ) parameters = serializers.DictField(allow_empty=True, default={}, write_only=True) + from_webui = serializers.BooleanField(required=False, default=False) def _validate_tag_id_list(self, tags, name="tags") -> None: if not isinstance(tags, list): @@ -2398,7 +2403,10 @@ class StoragePathSerializer(MatchingModelSerializer, OwnedObjectSerializer): """ doc_ids = [doc.id for doc in instance.documents.all()] if doc_ids: - bulk_edit.bulk_update_documents.delay(doc_ids) + bulk_edit.bulk_update_documents.apply_async( + kwargs={"document_ids": doc_ids}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return super().update(instance, validated_data) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 3e04bc52a..952ce4df9 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -34,7 +34,6 @@ from documents import matching from documents.caching import clear_document_caches from documents.caching import invalidate_llm_suggestions_cache from documents.data_models import ConsumableDocument -from documents.data_models import DocumentSource from documents.file_handling import create_source_path_directory from documents.file_handling import delete_empty_directories from documents.file_handling import generate_filename @@ -709,7 +708,7 @@ def check_paths_and_prune_custom_fields( and instance.fields.count() > 0 and instance.extra_data ): # Only select fields, for now - process_cf_select_update.delay(instance) + process_cf_select_update.apply_async(kwargs={"custom_field": instance}) @receiver(models.signals.post_delete, sender=CustomField) @@ -1024,27 +1023,9 @@ _CELERY_STATE_TO_STATUS: dict[str, PaperlessTask.Status] = { "REVOKED": PaperlessTask.Status.REVOKED, } -_DOCUMENT_SOURCE_TO_TRIGGER: dict[DocumentSource, PaperlessTask.TriggerSource] = { - DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME, - DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD, - DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME, - DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI, -} - - -def _get_consume_args( - args: tuple, - task_kwargs: dict, -) -> tuple[Any | None, Any | None]: - """Extract (input_doc, overrides) from consume_file task arguments.""" - input_doc = args[0] if args else task_kwargs.get("input_doc") - overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides") - return input_doc, overrides - def _extract_input_data( task_type: PaperlessTask.TaskType, - args: tuple, task_kwargs: dict, ) -> dict: """Build the input_data dict stored on the PaperlessTask record. @@ -1055,8 +1036,9 @@ def _extract_input_data( types store no input data and return {}. """ if task_type == PaperlessTask.TaskType.CONSUME_FILE: - input_doc, overrides = _get_consume_args(args, task_kwargs) - if input_doc is None: # pragma: no cover + input_doc = task_kwargs.get("input_doc") + overrides = task_kwargs.get("overrides") + if input_doc is None: return {} data: dict = { "filename": input_doc.original_file.name, @@ -1081,7 +1063,7 @@ def _extract_input_data( return data if task_type == PaperlessTask.TaskType.MAIL_FETCH: - account_ids = args[0] if args else task_kwargs.get("account_ids") + account_ids = task_kwargs.get("account_ids") if account_ids is not None: return {"account_ids": account_ids} return {} @@ -1090,46 +1072,30 @@ def _extract_input_data( def _determine_trigger_source( - task_type: PaperlessTask.TaskType, - args: tuple, - task_kwargs: dict, headers: dict, ) -> PaperlessTask.TriggerSource: """Resolve the TriggerSource for a task being published to the broker. - Priority order: - 1. Explicit trigger_source header (set by beat schedule or apply_async callers). - 2. For consume_file tasks, the DocumentSource on the input document. - 3. MANUAL as the catch-all for all other cases. + Reads the trigger_source header set by the caller; falls back to MANUAL + when the header is absent or contains an unrecognised value. """ - # Explicit header takes priority -- callers pass a TriggerSource DB value directly. header_source = headers.get("trigger_source") if header_source is not None: try: return PaperlessTask.TriggerSource(header_source) except ValueError: pass - - if task_type == PaperlessTask.TaskType.CONSUME_FILE: - input_doc, _ = _get_consume_args(args, task_kwargs) - if input_doc is not None: - return _DOCUMENT_SOURCE_TO_TRIGGER.get( - input_doc.source, - PaperlessTask.TriggerSource.API_UPLOAD, - ) - return PaperlessTask.TriggerSource.MANUAL def _extract_owner_id( task_type: PaperlessTask.TaskType, - args: tuple, task_kwargs: dict, ) -> int | None: """Return the owner_id from consume_file overrides, or None for all other task types.""" if task_type != PaperlessTask.TaskType.CONSUME_FILE: return None - _, overrides = _get_consume_args(args, task_kwargs) + overrides = task_kwargs.get("overrides") if overrides and hasattr(overrides, "owner_id"): return overrides.owner_id return None # pragma: no cover @@ -1177,17 +1143,12 @@ def before_task_publish_handler( try: close_old_connections() - args, task_kwargs, _ = body + _, task_kwargs, _ = body task_id = headers["id"] - input_data = _extract_input_data(task_type, args, task_kwargs) - trigger_source = _determine_trigger_source( - task_type, - args, - task_kwargs, - headers, - ) - owner_id = _extract_owner_id(task_type, args, task_kwargs) + input_data = _extract_input_data(task_type, task_kwargs) + trigger_source = _determine_trigger_source(headers) + owner_id = _extract_owner_id(task_type, task_kwargs) PaperlessTask.objects.create( task_id=task_id, @@ -1399,7 +1360,7 @@ def add_or_update_document_in_llm_index(sender, document, **kwargs): if ai_config.llm_index_enabled: from documents.tasks import update_document_in_llm_index - update_document_in_llm_index.delay(document) + update_document_in_llm_index.apply_async(kwargs={"document": document}) @receiver(models.signals.post_delete, sender=Document) @@ -1415,4 +1376,4 @@ def delete_document_from_llm_index( if ai_config.llm_index_enabled: from documents.tasks import remove_document_from_llm_index - remove_document_from_llm_index.delay(instance) + remove_document_from_llm_index.apply_async(kwargs={"document": instance}) diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 86a8047bc..3bb3ff40c 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -40,6 +40,7 @@ from documents.models import Correspondent from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType +from documents.models import PaperlessTask from documents.models import ShareLink from documents.models import ShareLinkBundle from documents.models import StoragePath @@ -600,7 +601,10 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None: ) if affected: - bulk_update_documents.delay(document_ids=list(affected)) + bulk_update_documents.apply_async( + kwargs={"document_ids": list(affected)}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) @shared_task diff --git a/src/documents/tests/test_api_bulk_edit.py b/src/documents/tests/test_api_bulk_edit.py index ff780dccd..2eb68de5f 100644 --- a/src/documents/tests/test_api_bulk_edit.py +++ b/src/documents/tests/test_api_bulk_edit.py @@ -26,7 +26,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.user = user self.client.force_authenticate(user=user) - patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay") + patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") self.async_task = patcher.start() self.addCleanup(patcher.stop) self.c1 = Correspondent.objects.create(name="c1") @@ -62,7 +62,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): m.return_value = return_value m.__name__ = method_name - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_set_correspondent(self, bulk_update_task_mock) -> None: self.assertNotEqual(self.doc1.correspondent, self.c1) response = self.client.post( @@ -79,9 +79,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.doc1.refresh_from_db() self.assertEqual(self.doc1.correspondent, self.c1) - bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk]) + bulk_update_task_mock.assert_called_once() + self.assertCountEqual( + bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"], + [self.doc1.pk], + ) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_unset_correspondent(self, bulk_update_task_mock) -> None: self.doc1.correspondent = self.c1 self.doc1.save() @@ -103,7 +107,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.doc1.refresh_from_db() self.assertIsNone(self.doc1.correspondent) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_set_type(self, bulk_update_task_mock) -> None: self.assertNotEqual(self.doc1.document_type, self.dt1) response = self.client.post( @@ -120,9 +124,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.doc1.refresh_from_db() self.assertEqual(self.doc1.document_type, self.dt1) - bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk]) + bulk_update_task_mock.assert_called_once() + self.assertCountEqual( + bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"], + [self.doc1.pk], + ) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_unset_type(self, bulk_update_task_mock) -> None: self.doc1.document_type = self.dt1 self.doc1.save() @@ -141,9 +149,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.doc1.refresh_from_db() self.assertIsNone(self.doc1.document_type) - bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk]) + bulk_update_task_mock.assert_called_once() + self.assertCountEqual( + bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"], + [self.doc1.pk], + ) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_add_tag(self, bulk_update_task_mock) -> None: self.assertFalse(self.doc1.tags.filter(pk=self.t1.pk).exists()) @@ -163,9 +175,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.assertTrue(self.doc1.tags.filter(pk=self.t1.pk).exists()) - bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk]) + bulk_update_task_mock.assert_called_once() + self.assertCountEqual( + bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"], + [self.doc1.pk], + ) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_remove_tag(self, bulk_update_task_mock) -> None: self.doc1.tags.add(self.t1) diff --git a/src/documents/tests/test_api_custom_fields.py b/src/documents/tests/test_api_custom_fields.py index 3606102ac..1e1cebd7c 100644 --- a/src/documents/tests/test_api_custom_fields.py +++ b/src/documents/tests/test_api_custom_fields.py @@ -278,7 +278,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase): doc.refresh_from_db() self.assertEqual(doc.custom_fields.first().value, None) - @mock.patch("documents.signals.handlers.process_cf_select_update.delay") + @mock.patch("documents.signals.handlers.process_cf_select_update.apply_async") def test_custom_field_update_offloaded_once(self, mock_delay) -> None: """ GIVEN: @@ -322,7 +322,7 @@ class TestCustomFieldsAPI(DirectoriesMixin, APITestCase): } cf_select.save() - mock_delay.assert_called_once_with(cf_select) + mock_delay.assert_called_once_with(kwargs={"custom_field": cf_select}) def test_create_custom_field_monetary_validation(self) -> None: """ diff --git a/src/documents/tests/test_api_document_versions.py b/src/documents/tests/test_api_document_versions.py index d95e78fe9..bde13354a 100644 --- a/src/documents/tests/test_api_document_versions.py +++ b/src/documents/tests/test_api_document_versions.py @@ -537,7 +537,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): async_task.id = "task-123" with mock.patch("documents.views.consume_file") as consume_mock: - consume_mock.delay.return_value = async_task + consume_mock.apply_async.return_value = async_task resp = self.client.post( f"/api/documents/{root.id}/update_version/", {"document": upload, "version_label": " New Version "}, @@ -546,8 +546,9 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.data, "task-123") - consume_mock.delay.assert_called_once() - input_doc, overrides = consume_mock.delay.call_args[0] + consume_mock.apply_async.assert_called_once() + task_kwargs = consume_mock.apply_async.call_args.kwargs["kwargs"] + input_doc, overrides = task_kwargs["input_doc"], task_kwargs["overrides"] self.assertEqual(input_doc.root_document_id, root.id) self.assertEqual(input_doc.source, DocumentSource.ApiUpload) self.assertEqual(overrides.version_label, "New Version") @@ -571,7 +572,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): async_task.id = "task-123" with mock.patch("documents.views.consume_file") as consume_mock: - consume_mock.delay.return_value = async_task + consume_mock.apply_async.return_value = async_task resp = self.client.post( f"/api/documents/{version.id}/update_version/", {"document": upload, "version_label": " New Version "}, @@ -580,8 +581,9 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.data, "task-123") - consume_mock.delay.assert_called_once() - input_doc, overrides = consume_mock.delay.call_args[0] + consume_mock.apply_async.assert_called_once() + task_kwargs = consume_mock.apply_async.call_args.kwargs["kwargs"] + input_doc, overrides = task_kwargs["input_doc"], task_kwargs["overrides"] self.assertEqual(input_doc.root_document_id, root.id) self.assertEqual(overrides.version_label, "New Version") self.assertEqual(overrides.actor_id, self.user.id) @@ -595,7 +597,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase): upload = self._make_pdf_upload() with mock.patch("documents.views.consume_file") as consume_mock: - consume_mock.delay.side_effect = Exception("boom") + consume_mock.apply_async.side_effect = Exception("boom") resp = self.client.post( f"/api/documents/{root.id}/update_version/", {"document": upload}, diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 76efc1b41..24165c087 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -47,11 +47,11 @@ from documents.models import Workflow from documents.models import WorkflowAction from documents.models import WorkflowTrigger from documents.signals.handlers import run_workflows +from documents.tests.utils import ConsumeTaskMixin from documents.tests.utils import DirectoriesMixin -from documents.tests.utils import DocumentConsumeDelayMixin -class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): +class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase): def setUp(self) -> None: super().setUp() @@ -1400,9 +1400,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertTrue( @@ -1432,9 +1430,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "outside.pdf") self.assertEqual(overrides.filename, "outside.pdf") @@ -1474,9 +1470,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "outside.pdf") self.assertEqual(overrides.filename, "outside.pdf") @@ -1558,9 +1552,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertTrue( @@ -1612,9 +1604,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(overrides.title, "my custom title") self.assertIsNone(overrides.correspondent_id) @@ -1634,9 +1624,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(overrides.correspondent_id, c.id) self.assertIsNone(overrides.title) @@ -1670,9 +1658,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(overrides.document_type_id, dt.id) self.assertIsNone(overrides.correspondent_id) @@ -1706,9 +1692,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(overrides.storage_path_id, sp.id) self.assertIsNone(overrides.correspondent_id) @@ -1743,9 +1727,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertCountEqual(overrides.tag_ids, [t1.id, t2.id]) self.assertIsNone(overrides.document_type_id) @@ -1790,9 +1772,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - _, overrides = self.get_last_consume_delay_call_args() + _, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(overrides.created, created.date()) @@ -1809,9 +1789,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertEqual(overrides.filename, "simple.pdf") @@ -1841,9 +1819,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertEqual(overrides.filename, "simple.pdf") @@ -1898,9 +1874,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() new_overrides, _ = run_workflows( trigger_type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION, @@ -1946,9 +1920,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, overrides = self.get_last_consume_delay_call_args() + input_doc, overrides = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.original_file.name, "simple.pdf") self.assertEqual(overrides.filename, "simple.pdf") @@ -2047,9 +2019,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) - self.consume_file_mock.assert_called_once() - - input_doc, _ = self.get_last_consume_delay_call_args() + input_doc, _ = self.assert_queue_consumption_task_call_args() self.assertEqual(input_doc.source, WorkflowTrigger.DocumentSourceChoices.WEB_UI) diff --git a/src/documents/tests/test_api_objects.py b/src/documents/tests/test_api_objects.py index bf1ac4d9c..05911febc 100644 --- a/src/documents/tests/test_api_objects.py +++ b/src/documents/tests/test_api_objects.py @@ -291,7 +291,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(StoragePath.objects.count(), 2) - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_update_storage_path(self, bulk_update_mock) -> None: """ GIVEN: @@ -316,11 +316,12 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase): bulk_update_mock.assert_called_once() - args, _ = bulk_update_mock.call_args + self.assertCountEqual( + [document.pk], + bulk_update_mock.call_args.kwargs["kwargs"]["document_ids"], + ) - self.assertCountEqual([document.pk], args[0]) - - @mock.patch("documents.bulk_edit.bulk_update_documents.delay") + @mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") def test_api_delete_storage_path(self, bulk_update_mock) -> None: """ GIVEN: @@ -347,7 +348,11 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase): self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) # only called once - bulk_update_mock.assert_called_once_with([document.pk]) + bulk_update_mock.assert_called_once() + self.assertCountEqual( + [document.pk], + bulk_update_mock.call_args.kwargs["kwargs"]["document_ids"], + ) def test_test_storage_path(self) -> None: """ diff --git a/src/documents/tests/test_barcodes.py b/src/documents/tests/test_barcodes.py index 4d8da62a3..0a8da4eb0 100644 --- a/src/documents/tests/test_barcodes.py +++ b/src/documents/tests/test_barcodes.py @@ -17,8 +17,8 @@ from documents.data_models import DocumentSource from documents.models import Document from documents.models import Tag from documents.plugins.base import StopConsumeTaskError +from documents.tests.utils import ConsumeTaskMixin from documents.tests.utils import DirectoriesMixin -from documents.tests.utils import DocumentConsumeDelayMixin from documents.tests.utils import DummyProgressManager from documents.tests.utils import FileSystemAssertsMixin from documents.tests.utils import SampleDirMixin @@ -601,7 +601,7 @@ class TestBarcodeNewConsume( DirectoriesMixin, FileSystemAssertsMixin, SampleDirMixin, - DocumentConsumeDelayMixin, + ConsumeTaskMixin, TestCase, ): @override_settings(CONSUMER_ENABLE_BARCODES=True) @@ -646,7 +646,7 @@ class TestBarcodeNewConsume( for ( new_input_doc, new_doc_overrides, - ) in self.get_all_consume_delay_call_args(): + ) in self.get_all_consume_task_call_args(): self.assertIsFile(new_input_doc.original_file) self.assertEqual(new_input_doc.original_path, temp_copy) self.assertEqual(new_input_doc.source, DocumentSource.ConsumeFolder) diff --git a/src/documents/tests/test_bulk_edit.py b/src/documents/tests/test_bulk_edit.py index 9b6c3c468..0c44157a5 100644 --- a/src/documents/tests/test_bulk_edit.py +++ b/src/documents/tests/test_bulk_edit.py @@ -31,7 +31,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.group1 = Group.objects.create(name="group1") self.group2 = Group.objects.create(name="group2") - patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay") + patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") self.async_task = patcher.start() self.addCleanup(patcher.stop) self.c1 = Correspondent.objects.create(name="c1") @@ -74,7 +74,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 3) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) def test_unset_correspondent(self) -> None: @@ -82,7 +82,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): bulk_edit.set_correspondent([self.doc1.id, self.doc2.id, self.doc3.id], None) self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 0) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id]) def test_set_document_type(self) -> None: @@ -93,7 +93,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 3) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) def test_unset_document_type(self) -> None: @@ -101,7 +101,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): bulk_edit.set_document_type([self.doc1.id, self.doc2.id, self.doc3.id], None) self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 0) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id]) def test_set_document_storage_path(self) -> None: @@ -123,7 +123,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.assertEqual(Document.objects.filter(storage_path=None).count(), 4) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id]) @@ -154,7 +154,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.assertEqual(Document.objects.filter(storage_path=None).count(), 5) self.async_task.assert_called() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id]) @@ -166,7 +166,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 4) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc3.id]) def test_remove_tag(self) -> None: @@ -174,7 +174,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): bulk_edit.remove_tag([self.doc1.id, self.doc3.id, self.doc4.id], self.t1.id) self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 1) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc4.id]) def test_modify_tags(self) -> None: @@ -191,7 +191,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.assertCountEqual(list(self.doc3.tags.all()), [self.t2, tag_unrelated]) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] # TODO: doc3 should not be affected, but the query for that is rather complicated self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id]) @@ -248,7 +248,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) def test_modify_custom_fields_with_values(self) -> None: @@ -325,7 +325,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.async_task.assert_called_once() - _, kwargs = self.async_task.call_args + kwargs = self.async_task.call_args.kwargs["kwargs"] self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) # removal of document link cf, should also remove symmetric link @@ -428,7 +428,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.assertEqual(source_doc.id, version2.id) self.assertNotEqual(source_doc.id, version1.id) - @mock.patch("documents.tasks.bulk_update_documents.delay") + @mock.patch("documents.tasks.bulk_update_documents.apply_async") def test_set_permissions(self, m) -> None: doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id] @@ -467,7 +467,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase): ) self.assertEqual(groups_with_perms.count(), 1) - @mock.patch("documents.tasks.bulk_update_documents.delay") + @mock.patch("documents.tasks.bulk_update_documents.apply_async") def test_set_permissions_merge(self, m) -> None: doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id] @@ -643,20 +643,20 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) mock_consume_file.assert_called() - consume_file_args, _ = mock_consume_file.call_args + call_kwargs = mock_consume_file.call_args.kwargs self.assertEqual( - Path(consume_file_args[0].original_file).name, + Path(call_kwargs["input_doc"].original_file).name, expected_filename, ) - self.assertEqual(consume_file_args[1].title, None) + self.assertEqual(call_kwargs["overrides"].title, None) # No metadata_document_id, delete_originals False, so ASN should be None - self.assertIsNone(consume_file_args[1].asn) + self.assertIsNone(call_kwargs["overrides"].asn) # With metadata_document_id overrides result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id) - consume_file_args, _ = mock_consume_file.call_args - self.assertEqual(consume_file_args[1].title, "B (merged)") - self.assertEqual(consume_file_args[1].created, self.doc2.created) + call_kwargs = mock_consume_file.call_args.kwargs + self.assertEqual(call_kwargs["overrides"].title, "B (merged)") + self.assertEqual(call_kwargs["overrides"].created, self.doc2.created) self.assertEqual(result, "OK") @@ -720,16 +720,15 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_consume_file.assert_called() mock_delete_documents.assert_called() - consume_sig = mock_consume_file.return_value - consume_sig.apply_async.assert_called_once() + mock_consume_file.return_value.set.return_value.apply_async.assert_called_once() - consume_file_args, _ = mock_consume_file.call_args + call_kwargs = mock_consume_file.call_args.kwargs self.assertEqual( - Path(consume_file_args[0].original_file).name, + Path(call_kwargs["input_doc"].original_file).name, expected_filename, ) - self.assertEqual(consume_file_args[1].title, None) - self.assertEqual(consume_file_args[1].asn, 101) + self.assertEqual(call_kwargs["overrides"].title, None) + self.assertEqual(call_kwargs["overrides"].asn, 101) delete_documents_args, _ = mock_delete_documents.call_args self.assertEqual( @@ -764,7 +763,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.doc1.archive_serial_number = 111 self.doc1.save() sig = mock.Mock() - sig.apply_async.side_effect = Exception("boom") + sig.set.return_value.apply_async.side_effect = Exception("boom") mock_consume_file.return_value = sig with self.assertRaises(Exception): @@ -801,8 +800,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) self.assertEqual(result, "OK") - consume_file_args, _ = mock_consume_file.call_args - self.assertEqual(consume_file_args[1].asn, 202) + self.assertEqual(mock_consume_file.call_args.kwargs["overrides"].asn, 202) def test_restore_archive_serial_numbers_task(self) -> None: """ @@ -843,9 +841,8 @@ class TestPDFActions(DirectoriesMixin, TestCase): ) mock_consume_file.assert_called() - consume_file_args, _ = mock_consume_file.call_args self.assertEqual( - Path(consume_file_args[0].original_file).name, + Path(mock_consume_file.call_args.kwargs["input_doc"].original_file).name, expected_filename, ) @@ -889,9 +886,11 @@ class TestPDFActions(DirectoriesMixin, TestCase): user = User.objects.create(username="test_user") result = bulk_edit.split(doc_ids, pages, delete_originals=False, user=user) self.assertEqual(mock_consume_file.call_count, 2) - consume_file_args, _ = mock_consume_file.call_args - self.assertEqual(consume_file_args[1].title, "B (split 2)") - self.assertIsNone(consume_file_args[1].asn) + self.assertEqual( + mock_consume_file.call_args.kwargs["overrides"].title, + "B (split 2)", + ) + self.assertIsNone(mock_consume_file.call_args.kwargs["overrides"].asn) self.assertEqual(result, "OK") @@ -953,8 +952,10 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(result, "OK") self.assertEqual(mock_consume_file.call_count, 2) - consume_file_args, _ = mock_consume_file.call_args - self.assertEqual(consume_file_args[1].title, "B (split 2)") + self.assertEqual( + mock_consume_file.call_args.kwargs["overrides"].title, + "B (split 2)", + ) mock_delete_documents.assert_called() mock_chord.assert_called_once() @@ -1001,7 +1002,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.doc2.refresh_from_db() self.assertEqual(self.doc2.archive_serial_number, 222) - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.Pdf.save") def test_split_with_errors(self, mock_save_pdf, mock_consume_file) -> None: """ @@ -1025,7 +1026,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_consume_file.assert_not_called() - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") def test_rotate(self, mock_consume_delay): """ GIVEN: @@ -1042,12 +1043,12 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_consume_delay.call_args_list, doc_ids, ): - consumable, overrides = call.args - self.assertEqual(consumable.root_document_id, expected_id) - self.assertIsNotNone(overrides) + task_kwargs = call.kwargs["kwargs"] + self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id) + self.assertIsNotNone(task_kwargs["overrides"]) self.assertEqual(result, "OK") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.Pdf.save") def test_rotate_with_error( self, @@ -1073,7 +1074,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertIn(expected_str, error_str) mock_consume_delay.assert_not_called() - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") def test_rotate_non_pdf( self, mock_consume_delay, @@ -1091,13 +1092,13 @@ class TestPDFActions(DirectoriesMixin, TestCase): expected_str = f"Document {self.img_doc.id} is not a PDF, skipping rotation" self.assertTrue(any(expected_str in line for line in cm.output)) self.assertEqual(mock_consume_delay.call_count, 1) - consumable, overrides = mock_consume_delay.call_args[0] - self.assertEqual(consumable.root_document_id, self.doc2.id) - self.assertIsNotNone(overrides) + task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"] + self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id) + self.assertIsNotNone(task_kwargs["overrides"]) self.assertEqual(result, "OK") @mock.patch("documents.data_models.magic.from_file", return_value="application/pdf") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.open") def test_rotate_explicit_selection_uses_root_source_when_root_selected( self, @@ -1124,7 +1125,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_open.assert_called_once_with(self.doc2.source_path) mock_consume_delay.assert_called_once() - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.Pdf.save") @mock.patch("documents.data_models.magic.from_file", return_value="application/pdf") def test_delete_pages(self, mock_magic, mock_pdf_save, mock_consume_delay): @@ -1142,14 +1143,16 @@ class TestPDFActions(DirectoriesMixin, TestCase): result = bulk_edit.delete_pages(doc_ids, pages) mock_pdf_save.assert_called_once() mock_consume_delay.assert_called_once() - consumable, overrides = mock_consume_delay.call_args[0] - self.assertEqual(consumable.root_document_id, self.doc2.id) - self.assertTrue(str(consumable.original_file).endswith("_pages_deleted.pdf")) - self.assertIsNotNone(overrides) + task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"] + self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id) + self.assertTrue( + str(task_kwargs["input_doc"].original_file).endswith("_pages_deleted.pdf"), + ) + self.assertIsNotNone(task_kwargs["overrides"]) self.assertEqual(result, "OK") @mock.patch("documents.data_models.magic.from_file", return_value="application/pdf") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.open") def test_delete_pages_explicit_selection_uses_root_source_when_root_selected( self, @@ -1176,7 +1179,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_open.assert_called_once_with(self.doc2.source_path) mock_consume_delay.assert_called_once() - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.Pdf.save") def test_delete_pages_with_error(self, mock_pdf_save, mock_consume_delay): """ @@ -1259,8 +1262,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True) self.assertEqual(result, "OK") mock_chord.assert_called_once() - consume_file_args, _ = mock_consume_file.call_args - self.assertEqual(consume_file_args[1].asn, 250) + self.assertEqual(mock_consume_file.call_args.kwargs["overrides"].asn, 250) self.doc2.refresh_from_db() self.assertIsNone(self.doc2.archive_serial_number) @@ -1297,7 +1299,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.doc2.refresh_from_db() self.assertEqual(self.doc2.archive_serial_number, 333) - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") def test_edit_pdf_with_update_document(self, mock_consume_delay): """ GIVEN: @@ -1319,13 +1321,15 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(result, "OK") mock_consume_delay.assert_called_once() - consumable, overrides = mock_consume_delay.call_args[0] - self.assertEqual(consumable.root_document_id, self.doc2.id) - self.assertTrue(str(consumable.original_file).endswith("_edited.pdf")) - self.assertIsNotNone(overrides) + task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"] + self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id) + self.assertTrue( + str(task_kwargs["input_doc"].original_file).endswith("_edited.pdf"), + ) + self.assertIsNotNone(task_kwargs["overrides"]) @mock.patch("documents.data_models.magic.from_file", return_value="application/pdf") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.new") @mock.patch("pikepdf.open") def test_edit_pdf_explicit_selection_uses_root_source_when_root_selected( @@ -1433,7 +1437,7 @@ class TestPDFActions(DirectoriesMixin, TestCase): mock_consume_file.assert_not_called() @mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("documents.bulk_edit.tempfile.mkdtemp") @mock.patch("pikepdf.open") def test_remove_password_update_document( @@ -1468,18 +1472,18 @@ class TestPDFActions(DirectoriesMixin, TestCase): fake_pdf.remove_unreferenced_resources.assert_called_once() mock_update_document.assert_not_called() mock_consume_delay.assert_called_once() - consumable, overrides = mock_consume_delay.call_args[0] + task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"] expected_path = temp_dir / f"{doc.id}_unprotected.pdf" self.assertTrue(expected_path.exists()) self.assertEqual( - Path(consumable.original_file).resolve(), + Path(task_kwargs["input_doc"].original_file).resolve(), expected_path.resolve(), ) - self.assertEqual(consumable.root_document_id, doc.id) - self.assertIsNotNone(overrides) + self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id) + self.assertIsNotNone(task_kwargs["overrides"]) @mock.patch("documents.data_models.magic.from_file", return_value="application/pdf") - @mock.patch("documents.tasks.consume_file.delay") + @mock.patch("documents.tasks.consume_file.apply_async") @mock.patch("pikepdf.open") def test_remove_password_explicit_selection_uses_root_source_when_root_selected( self, @@ -1548,9 +1552,9 @@ class TestPDFActions(DirectoriesMixin, TestCase): self.assertEqual(result, "OK") mock_open.assert_called_once_with(doc.source_path, password="secret") mock_consume_file.assert_called_once() - consume_args, _ = mock_consume_file.call_args - consumable_document = consume_args[0] - overrides = consume_args[1] + call_kwargs = mock_consume_file.call_args.kwargs + consumable_document = call_kwargs["input_doc"] + overrides = call_kwargs["overrides"] expected_path = temp_dir / f"{doc.id}_unprotected.pdf" self.assertTrue(expected_path.exists()) self.assertEqual( @@ -1558,7 +1562,9 @@ class TestPDFActions(DirectoriesMixin, TestCase): expected_path.resolve(), ) self.assertEqual(overrides.owner_id, user.id) - mock_group.assert_called_once_with([mock_consume_file.return_value]) + mock_group.assert_called_once_with( + [mock_consume_file.return_value.set.return_value], + ) mock_group.return_value.delay.assert_called_once() mock_chord.assert_not_called() diff --git a/src/documents/tests/test_management_consumer.py b/src/documents/tests/test_management_consumer.py index f4451f545..707397788 100644 --- a/src/documents/tests/test_management_consumer.py +++ b/src/documents/tests/test_management_consumer.py @@ -97,12 +97,10 @@ def consumer_filter() -> ConsumerFilter: @pytest.fixture def mock_consume_file_delay(mocker: MockerFixture) -> MagicMock: - """Mock the consume_file.delay celery task.""" - mock_task = mocker.patch( + """Mock the consume_file task.""" + return mocker.patch( "documents.management.commands.document_consumer.consume_file", ) - mock_task.delay = mocker.MagicMock() - return mock_task @pytest.fixture @@ -453,9 +451,9 @@ class TestConsumeFile: subdirs_as_tags=False, ) - mock_consume_file_delay.delay.assert_called_once() - call_args = mock_consume_file_delay.delay.call_args - consumable_doc = call_args[0][0] + mock_consume_file_delay.apply_async.assert_called_once() + call_args = mock_consume_file_delay.apply_async.call_args + consumable_doc = call_args.kwargs["kwargs"]["input_doc"] assert isinstance(consumable_doc, ConsumableDocument) assert consumable_doc.original_file == target assert consumable_doc.source == DocumentSource.ConsumeFolder @@ -471,7 +469,7 @@ class TestConsumeFile: consumption_dir=consumption_dir, subdirs_as_tags=False, ) - mock_consume_file_delay.delay.assert_not_called() + mock_consume_file_delay.apply_async.assert_not_called() def test_consume_directory( self, @@ -487,7 +485,7 @@ class TestConsumeFile: consumption_dir=consumption_dir, subdirs_as_tags=False, ) - mock_consume_file_delay.delay.assert_not_called() + mock_consume_file_delay.apply_async.assert_not_called() def test_consume_with_permission_error( self, @@ -506,7 +504,7 @@ class TestConsumeFile: consumption_dir=consumption_dir, subdirs_as_tags=False, ) - mock_consume_file_delay.delay.assert_not_called() + mock_consume_file_delay.apply_async.assert_not_called() def test_consume_with_tags_error( self, @@ -529,9 +527,9 @@ class TestConsumeFile: consumption_dir=consumption_dir, subdirs_as_tags=True, ) - mock_consume_file_delay.delay.assert_called_once() - call_args = mock_consume_file_delay.delay.call_args - overrides = call_args[0][1] + mock_consume_file_delay.apply_async.assert_called_once() + call_args = mock_consume_file_delay.apply_async.call_args + overrides = call_args.kwargs["kwargs"]["overrides"] assert overrides.tag_ids is None @@ -629,7 +627,7 @@ class TestCommandOneshot: cmd = Command() cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False) - mock_consume_file_delay.delay.assert_called_once() + mock_consume_file_delay.apply_async.assert_called_once() def test_processes_recursive( self, @@ -652,7 +650,7 @@ class TestCommandOneshot: cmd = Command() cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False) - mock_consume_file_delay.delay.assert_called_once() + mock_consume_file_delay.apply_async.assert_called_once() def test_ignores_unsupported_extensions( self, @@ -671,7 +669,7 @@ class TestCommandOneshot: cmd = Command() cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False) - mock_consume_file_delay.delay.assert_not_called() + mock_consume_file_delay.apply_async.assert_not_called() class ConsumerThread(Thread): @@ -795,12 +793,12 @@ class TestCommandWatch: target = consumption_dir / "document.pdf" shutil.copy(sample_pdf, target) - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() def test_detects_moved_file( self, @@ -821,12 +819,12 @@ class TestCommandWatch: target = consumption_dir / "document.pdf" shutil.move(temp_location, target) - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() def test_handles_slow_write( self, @@ -847,12 +845,12 @@ class TestCommandWatch: f.flush() sleep(0.05) - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() def test_ignores_macos_files( self, @@ -868,13 +866,15 @@ class TestCommandWatch: (consumption_dir / "._document.pdf").write_bytes(b"test") shutil.copy(sample_pdf, consumption_dir / "valid.pdf") - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - assert mock_consume_file_delay.delay.call_count == 1 - call_args = mock_consume_file_delay.delay.call_args[0][0] + assert mock_consume_file_delay.apply_async.call_count == 1 + call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][ + "input_doc" + ] assert call_args.original_file.name == "valid.pdf" @pytest.mark.django_db @@ -924,12 +924,12 @@ class TestCommandWatchPolling: # Actively wait for consumption # Polling needs: interval (0.5s) + stability (0.1s) + next poll (0.5s) + margin - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=5.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() @pytest.mark.management @@ -953,12 +953,12 @@ class TestCommandWatchRecursive: target = subdir / "document.pdf" shutil.copy(sample_pdf, target) - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() def test_subdirs_as_tags( self, @@ -983,15 +983,15 @@ class TestCommandWatchRecursive: target = subdir / "document.pdf" shutil.copy(sample_pdf, target) - wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0) + wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0) if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_called() + mock_consume_file_delay.apply_async.assert_called() mock_tags.assert_called() - call_args = mock_consume_file_delay.delay.call_args - overrides = call_args[0][1] + call_args = mock_consume_file_delay.apply_async.call_args + overrides = call_args.kwargs["kwargs"]["overrides"] assert overrides.tag_ids is not None assert len(overrides.tag_ids) == 2 @@ -1021,7 +1021,7 @@ class TestCommandWatchEdgeCases: if thread.exception: raise thread.exception - mock_consume_file_delay.delay.assert_not_called() + mock_consume_file_delay.apply_async.assert_not_called() @pytest.mark.usefixtures("mock_supported_extensions") def test_handles_task_exception( @@ -1035,7 +1035,7 @@ class TestCommandWatchEdgeCases: mock_task = mocker.patch( "documents.management.commands.document_consumer.consume_file", ) - mock_task.delay.side_effect = Exception("Task error") + mock_task.apply_async.side_effect = Exception("Task error") thread = ConsumerThread(consumption_dir, scratch_dir) try: diff --git a/src/documents/tests/test_share_link_bundles.py b/src/documents/tests/test_share_link_bundles.py index c82260819..6d22d9a4e 100644 --- a/src/documents/tests/test_share_link_bundles.py +++ b/src/documents/tests/test_share_link_bundles.py @@ -31,7 +31,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase): self.client.force_authenticate(self.user) self.document = DocumentFactory.create() - @mock.patch("documents.views.build_share_link_bundle.delay") + @mock.patch("documents.views.build_share_link_bundle.apply_async") def test_create_bundle_triggers_build_job(self, delay_mock) -> None: payload = { "document_ids": [self.document.pk], @@ -45,7 +45,8 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase): bundle = ShareLinkBundle.objects.get(pk=response.data["id"]) self.assertEqual(bundle.documents.count(), 1) self.assertEqual(bundle.status, ShareLinkBundle.Status.PENDING) - delay_mock.assert_called_once_with(bundle.pk) + delay_mock.assert_called_once() + self.assertEqual(delay_mock.call_args.kwargs["kwargs"]["bundle_id"], bundle.pk) def test_create_bundle_rejects_missing_documents(self) -> None: payload = { @@ -73,7 +74,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase): self.assertIn("document_ids", response.data) perms_mock.assert_called() - @mock.patch("documents.views.build_share_link_bundle.delay") + @mock.patch("documents.views.build_share_link_bundle.apply_async") def test_rebuild_bundle_resets_state(self, delay_mock) -> None: bundle = ShareLinkBundle.objects.create( slug="rebuild-slug", @@ -94,7 +95,8 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase): self.assertIsNone(bundle.last_error) self.assertIsNone(bundle.size_bytes) self.assertEqual(bundle.file_path, "") - delay_mock.assert_called_once_with(bundle.pk) + delay_mock.assert_called_once() + self.assertEqual(delay_mock.call_args.kwargs["kwargs"]["bundle_id"], bundle.pk) def test_rebuild_bundle_rejects_processing_status(self) -> None: bundle = ShareLinkBundle.objects.create( diff --git a/src/documents/tests/test_tag_hierarchy.py b/src/documents/tests/test_tag_hierarchy.py index 57aa27e3a..0bb4c75c5 100644 --- a/src/documents/tests/test_tag_hierarchy.py +++ b/src/documents/tests/test_tag_hierarchy.py @@ -23,7 +23,7 @@ class TestTagHierarchy(DirectoriesMixin, APITestCase): self.parent = Tag.objects.create(name="Parent") self.child = Tag.objects.create(name="Child", tn_parent=self.parent) - patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay") + patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async") self.async_task = patcher.start() self.addCleanup(patcher.stop) diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index 80b5e5075..56d964822 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -59,8 +59,9 @@ class TestBeforeTaskPublishHandler: def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides): task_id = send_publish( "documents.tasks.consume_file", - (consume_input_doc, consume_overrides), - {}, + (), + {"input_doc": consume_input_doc, "overrides": consume_overrides}, + headers={"trigger_source": PaperlessTask.TriggerSource.WEB_UI}, ) task = PaperlessTask.objects.get(task_id=task_id) assert task.task_type == PaperlessTask.TaskType.CONSUME_FILE @@ -102,8 +103,8 @@ class TestBeforeTaskPublishHandler: task_id = send_publish( "documents.tasks.consume_file", - (consume_input_doc, overrides), - {}, + (), + {"input_doc": consume_input_doc, "overrides": overrides}, ) task = PaperlessTask.objects.get(task_id=task_id) @@ -116,8 +117,8 @@ class TestBeforeTaskPublishHandler: task_id = send_publish( "documents.tasks.consume_file", - (consume_input_doc, overrides), - {}, + (), + {"input_doc": consume_input_doc, "overrides": overrides}, ) task = PaperlessTask.objects.get(task_id=task_id) @@ -167,37 +168,19 @@ class TestBeforeTaskPublishHandler: before_task_publish_handler(sender=None, headers=None, body=None) assert PaperlessTask.objects.count() == 0 - @pytest.mark.parametrize( - ("document_source", "expected_trigger_source"), - [ - pytest.param( - DocumentSource.ConsumeFolder, - PaperlessTask.TriggerSource.FOLDER_CONSUME, - id="folder_consume", - ), - pytest.param( - DocumentSource.MailFetch, - PaperlessTask.TriggerSource.EMAIL_CONSUME, - id="email_consume", - ), - ], - ) - def test_consume_document_source_maps_to_trigger_source( + def test_consume_file_without_trigger_source_header_defaults_to_manual( self, consume_input_doc, consume_overrides, - document_source: DocumentSource, - expected_trigger_source: PaperlessTask.TriggerSource, ) -> None: - """DocumentSource on the input doc maps to the correct TriggerSource.""" - consume_input_doc.source = document_source + """Without a trigger_source header the handler defaults to MANUAL.""" task_id = send_publish( "documents.tasks.consume_file", - (consume_input_doc, consume_overrides), - {}, + (), + {"input_doc": consume_input_doc, "overrides": consume_overrides}, ) task = PaperlessTask.objects.get(task_id=task_id) - assert task.trigger_source == expected_trigger_source + assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL @pytest.mark.django_db diff --git a/src/documents/tests/test_workflows.py b/src/documents/tests/test_workflows.py index cbb4781f5..87cc54779 100644 --- a/src/documents/tests/test_workflows.py +++ b/src/documents/tests/test_workflows.py @@ -3636,7 +3636,7 @@ class TestWorkflows( PAPERLESS_FORCE_SCRIPT_NAME="/paperless", BASE_URL="/paperless/", ) - @mock.patch("documents.workflows.webhooks.send_webhook.delay") + @mock.patch("documents.workflows.webhooks.send_webhook.apply_async") def test_workflow_webhook_action_body(self, mock_post) -> None: """ GIVEN: @@ -3685,20 +3685,22 @@ class TestWorkflows( run_workflows(WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED, doc) mock_post.assert_called_once_with( - url="http://paperless-ngx.com", - data=( - f"Test message: http://localhost:8000/paperless/documents/{doc.id}/" - f" with id {doc.id}" - ), - headers={}, - files=None, - as_json=False, + kwargs={ + "url": "http://paperless-ngx.com", + "data": ( + f"Test message: http://localhost:8000/paperless/documents/{doc.id}/" + f" with id {doc.id}" + ), + "headers": {}, + "files": None, + "as_json": False, + }, ) @override_settings( PAPERLESS_URL="http://localhost:8000", ) - @mock.patch("documents.workflows.webhooks.send_webhook.delay") + @mock.patch("documents.workflows.webhooks.send_webhook.apply_async") def test_workflow_webhook_action_w_files(self, mock_post) -> None: """ GIVEN: @@ -3750,11 +3752,13 @@ class TestWorkflows( run_workflows(WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED, doc) mock_post.assert_called_once_with( - url="http://paperless-ngx.com", - data=f"Test message: http://localhost:8000/documents/{doc.id}/", - headers={}, - files={"file": ("simple.pdf", mock.ANY, "application/pdf")}, - as_json=False, + kwargs={ + "url": "http://paperless-ngx.com", + "data": f"Test message: http://localhost:8000/documents/{doc.id}/", + "headers": {}, + "files": {"file": ("simple.pdf", mock.ANY, "application/pdf")}, + "as_json": False, + }, ) @mock.patch("documents.signals.handlers.execute_webhook_action") @@ -4036,7 +4040,7 @@ class TestWorkflows( ) self.assertIn(expected_str, cm.output[0]) - @mock.patch("documents.workflows.webhooks.send_webhook.delay") + @mock.patch("documents.workflows.webhooks.send_webhook.apply_async") def test_workflow_webhook_action_consumption(self, mock_post) -> None: """ GIVEN: @@ -4376,7 +4380,7 @@ class TestWorkflows( @override_settings( PAPERLESS_URL="http://localhost:8000", ) - @mock.patch("documents.workflows.webhooks.send_webhook.delay") + @mock.patch("documents.workflows.webhooks.send_webhook.apply_async") def test_workflow_trash_with_webhook_action(self, mock_webhook_delay): """ GIVEN: @@ -4384,7 +4388,7 @@ class TestWorkflows( WHEN: - Document matches and workflow runs THEN: - - Webhook .delay() is called with complete data including file bytes + - Webhook .apply_async() is called with complete data including file bytes - Document is moved to trash (soft deleted) - Webhook task has all necessary data and doesn't rely on document existence """ @@ -4434,7 +4438,7 @@ class TestWorkflows( run_workflows(WorkflowTrigger.WorkflowTriggerType.DOCUMENT_UPDATED, doc) mock_webhook_delay.assert_called_once() - call_kwargs = mock_webhook_delay.call_args[1] + call_kwargs = mock_webhook_delay.call_args[1]["kwargs"] self.assertEqual(call_kwargs["url"], "https://paperless-ngx.com/webhook") self.assertEqual( call_kwargs["data"], diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index 98c8258b8..530f588e8 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -231,14 +231,16 @@ class ConsumerProgressMixin: self.send_progress_patcher.stop() -class DocumentConsumeDelayMixin: +class ConsumeTaskMixin: """ Provides mocking of the consume_file asynchronous task and useful utilities for decoding its arguments """ def setUp(self) -> None: - self.consume_file_patcher = mock.patch("documents.tasks.consume_file.delay") + self.consume_file_patcher = mock.patch( + "documents.tasks.consume_file.apply_async", + ) self.consume_file_mock = self.consume_file_patcher.start() super().setUp() @@ -246,48 +248,22 @@ class DocumentConsumeDelayMixin: super().tearDown() self.consume_file_patcher.stop() - def get_last_consume_delay_call_args( + def assert_queue_consumption_task_call_args( self, ) -> tuple[ConsumableDocument, DocumentMetadataOverrides]: - """ - Returns the most recent arguments to the async task - """ - # Must be at least 1 call - self.consume_file_mock.assert_called() + """Assert the task was queued exactly once and return its call args.""" + self.consume_file_mock.assert_called_once() + task_kwargs = self.consume_file_mock.call_args.kwargs["kwargs"] + return (task_kwargs["input_doc"], task_kwargs["overrides"]) - args, _ = self.consume_file_mock.call_args - input_doc, overrides = args - - return (input_doc, overrides) - - def get_all_consume_delay_call_args( + def get_all_consume_task_call_args( self, ) -> Iterator[tuple[ConsumableDocument, DocumentMetadataOverrides]]: - """ - Iterates over all calls to the async task and returns the arguments - """ - # Must be at least 1 call + """Iterate over all queued consume task calls and yield their call args.""" self.consume_file_mock.assert_called() - - for args, kwargs in self.consume_file_mock.call_args_list: - input_doc, overrides = args - - yield (input_doc, overrides) - - def get_specific_consume_delay_call_args( - self, - index: int, - ) -> tuple[ConsumableDocument, DocumentMetadataOverrides]: - """ - Returns the arguments of a specific call to the async task - """ - # Must be at least 1 call - self.consume_file_mock.assert_called() - - args, _ = self.consume_file_mock.call_args_list[index] - input_doc, overrides = args - - return (input_doc, overrides) + for call in self.consume_file_mock.call_args_list: + task_kwargs = call.kwargs["kwargs"] + yield (task_kwargs["input_doc"], task_kwargs["overrides"]) class TestMigrations(TransactionTestCase): diff --git a/src/documents/views.py b/src/documents/views.py index 21fd9a4a8..a8fd29a84 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1802,9 +1802,9 @@ class DocumentViewSet( if request.user is not None: overrides.actor_id = request.user.id - async_task = consume_file.delay( - input_doc, - overrides, + async_task = consume_file.apply_async( + kwargs={"input_doc": input_doc, "overrides": overrides}, + headers={"trigger_source": PaperlessTask.TriggerSource.WEB_UI}, ) logger.debug( f"Updated document {root_doc.id} with new version", @@ -2481,6 +2481,7 @@ class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin): "edit_pdf", "remove_password", } + METHOD_NAMES_REQUIRING_TRIGGER_SOURCE = METHOD_NAMES_REQUIRING_USER def _has_document_permissions( self, @@ -2571,12 +2572,19 @@ class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin): parameters = { k: v for k, v in validated_data.items() - if k not in {"documents", "all", "filters"} + if k not in {"documents", "all", "filters", "from_webui"} } user = self.request.user + from_webui = validated_data.get("from_webui", False) if method.__name__ in self.METHOD_NAMES_REQUIRING_USER: parameters["user"] = user + if method.__name__ in self.METHOD_NAMES_REQUIRING_TRIGGER_SOURCE: + parameters["trigger_source"] = ( + PaperlessTask.TriggerSource.WEB_UI + if from_webui + else PaperlessTask.TriggerSource.API_UPLOAD + ) if not self._has_document_permissions( user=user, @@ -2660,12 +2668,19 @@ class BulkEditView(DocumentOperationPermissionMixin): user = self.request.user method = serializer.validated_data.get("method") parameters = serializer.validated_data.get("parameters") + from_webui = serializer.validated_data.get("from_webui", False) documents = self._resolve_document_ids( user=user, validated_data=serializer.validated_data, ) if method.__name__ in self.METHOD_NAMES_REQUIRING_USER: parameters["user"] = user + if method.__name__ in self.METHOD_NAMES_REQUIRING_TRIGGER_SOURCE: + parameters["trigger_source"] = ( + PaperlessTask.TriggerSource.WEB_UI + if from_webui + else PaperlessTask.TriggerSource.API_UPLOAD + ) if not self._has_document_permissions( user=user, documents=documents, @@ -2959,9 +2974,15 @@ class PostDocumentView(GenericAPIView[Any]): custom_fields=custom_fields, ) - async_task = consume_file.delay( - input_doc, - input_doc_overrides, + async_task = consume_file.apply_async( + kwargs={"input_doc": input_doc, "overrides": input_doc_overrides}, + headers={ + "trigger_source": ( + PaperlessTask.TriggerSource.WEB_UI + if from_webui + else PaperlessTask.TriggerSource.API_UPLOAD + ), + }, ) return Response(async_task.id) @@ -3607,7 +3628,10 @@ class StoragePathViewSet(PermissionsAwareDocumentCountMixin, ModelViewSet[Storag response = super().destroy(request, *args, **kwargs) if doc_ids: - bulk_edit.bulk_update_documents.delay(doc_ids) + bulk_edit.bulk_update_documents.apply_async( + kwargs={"document_ids": doc_ids}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, + ) return response @@ -4117,7 +4141,10 @@ class ShareLinkBundleViewSet(PassUserMixin, ModelViewSet[ShareLinkBundle]): "file_path", ], ) - build_share_link_bundle.delay(bundle.pk) + build_share_link_bundle.apply_async( + kwargs={"bundle_id": bundle.pk}, + headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL}, + ) bundle.document_total = len(ordered_documents) response_serializer = self.get_serializer(bundle) headers = self.get_success_headers(response_serializer.data) @@ -4150,7 +4177,10 @@ class ShareLinkBundleViewSet(PassUserMixin, ModelViewSet[ShareLinkBundle]): "file_path", ], ) - build_share_link_bundle.delay(bundle.pk) + build_share_link_bundle.apply_async( + kwargs={"bundle_id": bundle.pk}, + headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL}, + ) bundle.document_total = ( getattr(bundle, "document_total", None) or bundle.documents.count() ) diff --git a/src/documents/workflows/actions.py b/src/documents/workflows/actions.py index 9744048e5..9089c6dd8 100644 --- a/src/documents/workflows/actions.py +++ b/src/documents/workflows/actions.py @@ -253,12 +253,14 @@ def execute_webhook_action( document.mime_type, ), } - send_webhook.delay( - url=action.webhook.url, - data=data, - headers=headers, - files=files, - as_json=action.webhook.as_json, + send_webhook.apply_async( + kwargs={ + "url": action.webhook.url, + "data": data, + "headers": headers, + "files": files, + "as_json": action.webhook.as_json, + }, ) logger.debug( f"Webhook to {action.webhook.url} queued", diff --git a/src/paperless_mail/mail.py b/src/paperless_mail/mail.py index 75bb7b134..430bceb4f 100644 --- a/src/paperless_mail/mail.py +++ b/src/paperless_mail/mail.py @@ -37,6 +37,7 @@ from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentSource from documents.loggers import LoggingMixin from documents.models import Correspondent +from documents.models import PaperlessTask from documents.parsers import is_mime_type_supported from documents.tasks import consume_file from paperless.network import is_public_ip @@ -893,8 +894,12 @@ class MailAccountHandler(LoggingMixin): ) consume_task = consume_file.s( - input_doc, - doc_overrides, + input_doc=input_doc, + overrides=doc_overrides, + ).set( + headers={ + "trigger_source": PaperlessTask.TriggerSource.EMAIL_CONSUME, + }, ) consume_tasks.append(consume_task) @@ -991,9 +996,9 @@ class MailAccountHandler(LoggingMixin): ) consume_task = consume_file.s( - input_doc, - doc_overrides, - ) + input_doc=input_doc, + overrides=doc_overrides, + ).set(headers={"trigger_source": PaperlessTask.TriggerSource.EMAIL_CONSUME}) queue_consumption_tasks( consume_tasks=[consume_task], diff --git a/src/paperless_mail/tests/test_mail.py b/src/paperless_mail/tests/test_mail.py index 7d038dd38..26ee5307e 100644 --- a/src/paperless_mail/tests/test_mail.py +++ b/src/paperless_mail/tests/test_mail.py @@ -359,7 +359,8 @@ class MailMocker(DirectoriesMixin, FileSystemAssertsMixin, TestCase): consume_tasks, expected_signatures, ): - input_doc, overrides = consume_task.args + input_doc = consume_task.kwargs["input_doc"] + overrides = consume_task.kwargs["overrides"] # assert the file exists self.assertIsFile(input_doc.original_file) @@ -2022,7 +2023,7 @@ class TestMailAccountProcess(APITestCase): ) self.url = f"/api/mail_accounts/{self.account.pk}/process/" - @mock.patch("paperless_mail.tasks.process_mail_accounts.delay") + @mock.patch("paperless_mail.tasks.process_mail_accounts.apply_async") def test_mail_account_process_view(self, m) -> None: response = self.client.post(self.url) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index 8d6a7fa03..2e36f1b03 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -24,6 +24,7 @@ from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ReadOnlyModelViewSet from documents.filters import ObjectOwnedOrGrantedPermissionsFilter +from documents.models import PaperlessTask from documents.permissions import PaperlessObjectPermissions from documents.permissions import has_perms_owner_aware from documents.views import PassUserMixin @@ -156,7 +157,10 @@ class MailAccountViewSet(PassUserMixin, ModelViewSet[MailAccount]): @action(methods=["post"], detail=True) def process(self, request, pk=None): account = self.get_object() - process_mail_accounts.delay([account.pk]) + process_mail_accounts.apply_async( + kwargs={"account_ids": [account.pk]}, + headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL}, + ) return Response({"result": "OK"})