Compare commits

..

1 Commit

Author SHA1 Message Date
stumpylog
4721f64e4c Converts all call sites and test asserts to use apply_async and headers 2026-04-16 16:21:59 -07:00
21 changed files with 424 additions and 390 deletions

View File

@@ -17,7 +17,9 @@ from pikepdf import Pdf
from documents.converters import convert_from_tiff_to_pdf
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.models import Document
from documents.models import PaperlessTask
from documents.models import Tag
from documents.plugins.base import ConsumeTaskPlugin
from documents.plugins.base import StopConsumeTaskError
@@ -193,23 +195,36 @@ class BarcodePlugin(ConsumeTaskPlugin):
from documents import tasks
_SOURCE_TO_TRIGGER: dict[DocumentSource, PaperlessTask.TriggerSource] = {
DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME,
DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD,
DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME,
DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI,
}
trigger_source = _SOURCE_TO_TRIGGER.get(
self.input_doc.source,
PaperlessTask.TriggerSource.MANUAL,
)
# Create the split document tasks
for new_document in self.separate_pages(separator_pages):
copy_file_with_basic_stats(new_document, tmp_dir / new_document.name)
task = tasks.consume_file.delay(
ConsumableDocument(
# Same source, for templates
source=self.input_doc.source,
mailrule_id=self.input_doc.mailrule_id,
# Can't use same folder or the consume might grab it again
original_file=(tmp_dir / new_document.name).resolve(),
# Adding optional original_path for later uses in
# workflow matching
original_path=self.input_doc.original_file,
),
# All the same metadata
self.metadata,
task = tasks.consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
# Same source, for templates
source=self.input_doc.source,
mailrule_id=self.input_doc.mailrule_id,
# Can't use same folder or the consume might grab it again
original_file=(tmp_dir / new_document.name).resolve(),
# Adding optional original_path for later uses in
# workflow matching
original_path=self.input_doc.original_file,
),
"overrides": self.metadata,
},
headers={"trigger_source": trigger_source},
)
logger.info(f"Created new task {task.id} for {new_document.name}")

View File

@@ -22,6 +22,7 @@ from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import PaperlessTask
from documents.models import StoragePath
from documents.models import Tag
from documents.permissions import set_permissions_for_object
@@ -113,7 +114,10 @@ def set_correspondent(
affected_docs = list(qs.values_list("pk", flat=True))
qs.update(correspondent=correspondent)
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -132,8 +136,9 @@ def set_storage_path(doc_ids: list[int], storage_path: StoragePath) -> Literal["
affected_docs = list(qs.values_list("pk", flat=True))
qs.update(storage_path=storage_path)
bulk_update_documents.delay(
document_ids=affected_docs,
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -151,7 +156,10 @@ def set_document_type(doc_ids: list[int], document_type: DocumentType) -> Litera
affected_docs = list(qs.values_list("pk", flat=True))
qs.update(document_type=document_type)
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -177,7 +185,10 @@ def add_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
DocumentTagRelationship.objects.bulk_create(to_create)
if affected_docs:
bulk_update_documents.delay(document_ids=list(affected_docs))
bulk_update_documents.apply_async(
kwargs={"document_ids": list(affected_docs)},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -195,7 +206,10 @@ def remove_tag(doc_ids: list[int], tag: int) -> Literal["OK"]:
qs.delete()
if affected_docs:
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -254,7 +268,10 @@ def modify_tags(
)
if affected_docs:
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
except Exception as e:
logger.error(f"Error modifying tags: {e}")
return "ERROR"
@@ -326,7 +343,10 @@ def modify_custom_fields(
field_id__in=remove_custom_fields,
).hard_delete()
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -369,8 +389,9 @@ def delete(doc_ids: list[int]) -> Literal["OK"]:
def reprocess(doc_ids: list[int]) -> Literal["OK"]:
for document_id in doc_ids:
update_document_content_maybe_archive_file.delay(
document_id=document_id,
update_document_content_maybe_archive_file.apply_async(
kwargs={"document_id": document_id},
headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL},
)
return "OK"
@@ -396,7 +417,10 @@ def set_permissions(
affected_docs = list(qs.values_list("pk", flat=True))
bulk_update_documents.delay(document_ids=affected_docs)
bulk_update_documents.apply_async(
kwargs={"document_ids": affected_docs},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return "OK"
@@ -407,6 +431,7 @@ def rotate(
*,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
logger.info(
f"Attempting to rotate {len(doc_ids)} documents by {degrees} degrees.",
@@ -453,13 +478,16 @@ def rotate(
if user is not None:
overrides.actor_id = user.id
consume_file.delay(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
overrides,
consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
"overrides": overrides,
},
headers={"trigger_source": trigger_source},
)
logger.info(
f"Queued new rotated version for document {root_doc.id} by {degrees} degrees",
@@ -478,6 +506,7 @@ def merge(
archive_fallback: bool = False,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
logger.info(
f"Attempting to merge {len(doc_ids)} documents into a single document.",
@@ -556,12 +585,12 @@ def merge(
logger.info("Adding merged document to the task queue.")
consume_task = consume_file.s(
ConsumableDocument(
input_doc=ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
),
overrides,
)
overrides=overrides,
).set(headers={"trigger_source": trigger_source})
if delete_originals:
backup = release_archive_serial_numbers(affected_docs)
@@ -589,6 +618,7 @@ def split(
delete_originals: bool = False,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
logger.info(
f"Attempting to split document {doc_ids[0]} into {len(pages)} documents",
@@ -631,12 +661,12 @@ def split(
)
consume_tasks.append(
consume_file.s(
ConsumableDocument(
input_doc=ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
),
overrides,
),
overrides=overrides,
).set(headers={"trigger_source": trigger_source}),
)
if delete_originals:
@@ -669,6 +699,7 @@ def delete_pages(
*,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
logger.info(
f"Attempting to delete pages {pages} from {len(doc_ids)} documents",
@@ -698,13 +729,16 @@ def delete_pages(
overrides = DocumentMetadataOverrides().from_document(root_doc)
if user is not None:
overrides.actor_id = user.id
consume_file.delay(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
overrides,
consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
"overrides": overrides,
},
headers={"trigger_source": trigger_source},
)
logger.info(
f"Queued new version for document {root_doc.id} after deleting pages {pages}",
@@ -724,6 +758,7 @@ def edit_pdf(
include_metadata: bool = True,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
"""
Operations is a list of dictionaries describing the final PDF pages.
@@ -781,13 +816,16 @@ def edit_pdf(
if user is not None:
overrides.owner_id = user.id
overrides.actor_id = user.id
consume_file.delay(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
overrides,
consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
"overrides": overrides,
},
headers={"trigger_source": trigger_source},
)
else:
consume_tasks = []
@@ -812,12 +850,12 @@ def edit_pdf(
pdf.save(version_filepath)
consume_tasks.append(
consume_file.s(
ConsumableDocument(
input_doc=ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=version_filepath,
),
overrides,
),
overrides=overrides,
).set(headers={"trigger_source": trigger_source}),
)
if delete_original:
@@ -853,6 +891,7 @@ def remove_password(
include_metadata: bool = True,
source_mode: SourceMode = SourceModeChoices.LATEST_VERSION,
user: User | None = None,
trigger_source: PaperlessTask.TriggerSource = PaperlessTask.TriggerSource.WEB_UI,
) -> Literal["OK"]:
"""
Remove password protection from PDF documents.
@@ -887,13 +926,16 @@ def remove_password(
if user is not None:
overrides.owner_id = user.id
overrides.actor_id = user.id
consume_file.delay(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
overrides,
consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
root_document_id=root_doc.id,
),
"overrides": overrides,
},
headers={"trigger_source": trigger_source},
)
else:
consume_tasks = []
@@ -908,12 +950,12 @@ def remove_password(
consume_tasks.append(
consume_file.s(
ConsumableDocument(
input_doc=ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
),
overrides,
),
overrides=overrides,
).set(headers={"trigger_source": trigger_source}),
)
if delete_original:

View File

@@ -27,6 +27,7 @@ from watchfiles import watch
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.models import PaperlessTask
from documents.models import Tag
from documents.parsers import get_supported_file_extensions
from documents.tasks import consume_file
@@ -338,12 +339,15 @@ def _consume_file(
# Queue for consumption
try:
logger.info(f"Adding {filepath} to the task queue")
consume_file.delay(
ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
),
DocumentMetadataOverrides(tag_ids=tag_ids),
consume_file.apply_async(
kwargs={
"input_doc": ConsumableDocument(
source=DocumentSource.ConsumeFolder,
original_file=filepath,
),
"overrides": DocumentMetadataOverrides(tag_ids=tag_ids),
},
headers={"trigger_source": PaperlessTask.TriggerSource.FOLDER_CONSUME},
)
except Exception:
logger.exception(f"Error while queuing document {filepath}")

View File

@@ -1604,6 +1604,7 @@ class RotateDocumentsSerializer(DocumentSelectionSerializer, SourceModeValidatio
required=False,
default=bulk_edit.SourceModeChoices.LATEST_VERSION,
)
from_webui = serializers.BooleanField(required=False, default=False)
class MergeDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin):
@@ -1617,6 +1618,7 @@ class MergeDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin
required=False,
default=bulk_edit.SourceModeChoices.LATEST_VERSION,
)
from_webui = serializers.BooleanField(required=False, default=False)
class EditPdfDocumentsSerializer(DocumentListSerializer, SourceModeValidationMixin):
@@ -1628,6 +1630,7 @@ class EditPdfDocumentsSerializer(DocumentListSerializer, SourceModeValidationMix
required=False,
default=bulk_edit.SourceModeChoices.LATEST_VERSION,
)
from_webui = serializers.BooleanField(required=False, default=False)
def validate(self, attrs):
documents = attrs["documents"]
@@ -1679,6 +1682,7 @@ class RemovePasswordDocumentsSerializer(
required=False,
default=bulk_edit.SourceModeChoices.LATEST_VERSION,
)
from_webui = serializers.BooleanField(required=False, default=False)
class DeleteDocumentsSerializer(DocumentSelectionSerializer):
@@ -1726,6 +1730,7 @@ class BulkEditSerializer(
)
parameters = serializers.DictField(allow_empty=True, default={}, write_only=True)
from_webui = serializers.BooleanField(required=False, default=False)
def _validate_tag_id_list(self, tags, name="tags") -> None:
if not isinstance(tags, list):
@@ -2398,7 +2403,10 @@ class StoragePathSerializer(MatchingModelSerializer, OwnedObjectSerializer):
"""
doc_ids = [doc.id for doc in instance.documents.all()]
if doc_ids:
bulk_edit.bulk_update_documents.delay(doc_ids)
bulk_edit.bulk_update_documents.apply_async(
kwargs={"document_ids": doc_ids},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return super().update(instance, validated_data)

View File

@@ -34,7 +34,6 @@ from documents import matching
from documents.caching import clear_document_caches
from documents.caching import invalidate_llm_suggestions_cache
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentSource
from documents.file_handling import create_source_path_directory
from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_filename
@@ -1024,27 +1023,9 @@ _CELERY_STATE_TO_STATUS: dict[str, PaperlessTask.Status] = {
"REVOKED": PaperlessTask.Status.REVOKED,
}
_DOCUMENT_SOURCE_TO_TRIGGER: dict[DocumentSource, PaperlessTask.TriggerSource] = {
DocumentSource.ConsumeFolder: PaperlessTask.TriggerSource.FOLDER_CONSUME,
DocumentSource.ApiUpload: PaperlessTask.TriggerSource.API_UPLOAD,
DocumentSource.MailFetch: PaperlessTask.TriggerSource.EMAIL_CONSUME,
DocumentSource.WebUI: PaperlessTask.TriggerSource.WEB_UI,
}
def _get_consume_args(
args: tuple,
task_kwargs: dict,
) -> tuple[Any | None, Any | None]:
"""Extract (input_doc, overrides) from consume_file task arguments."""
input_doc = args[0] if args else task_kwargs.get("input_doc")
overrides = args[1] if len(args) >= 2 else task_kwargs.get("overrides")
return input_doc, overrides
def _extract_input_data(
task_type: PaperlessTask.TaskType,
args: tuple,
task_kwargs: dict,
) -> dict:
"""Build the input_data dict stored on the PaperlessTask record.
@@ -1055,8 +1036,9 @@ def _extract_input_data(
types store no input data and return {}.
"""
if task_type == PaperlessTask.TaskType.CONSUME_FILE:
input_doc, overrides = _get_consume_args(args, task_kwargs)
if input_doc is None: # pragma: no cover
input_doc = task_kwargs.get("input_doc")
overrides = task_kwargs.get("overrides")
if input_doc is None:
return {}
data: dict = {
"filename": input_doc.original_file.name,
@@ -1081,7 +1063,7 @@ def _extract_input_data(
return data
if task_type == PaperlessTask.TaskType.MAIL_FETCH:
account_ids = args[0] if args else task_kwargs.get("account_ids")
account_ids = task_kwargs.get("account_ids")
if account_ids is not None:
return {"account_ids": account_ids}
return {}
@@ -1090,46 +1072,30 @@ def _extract_input_data(
def _determine_trigger_source(
task_type: PaperlessTask.TaskType,
args: tuple,
task_kwargs: dict,
headers: dict,
) -> PaperlessTask.TriggerSource:
"""Resolve the TriggerSource for a task being published to the broker.
Priority order:
1. Explicit trigger_source header (set by beat schedule or apply_async callers).
2. For consume_file tasks, the DocumentSource on the input document.
3. MANUAL as the catch-all for all other cases.
Reads the trigger_source header set by the caller; falls back to MANUAL
when the header is absent or contains an unrecognised value.
"""
# Explicit header takes priority -- callers pass a TriggerSource DB value directly.
header_source = headers.get("trigger_source")
if header_source is not None:
try:
return PaperlessTask.TriggerSource(header_source)
except ValueError:
pass
if task_type == PaperlessTask.TaskType.CONSUME_FILE:
input_doc, _ = _get_consume_args(args, task_kwargs)
if input_doc is not None:
return _DOCUMENT_SOURCE_TO_TRIGGER.get(
input_doc.source,
PaperlessTask.TriggerSource.API_UPLOAD,
)
return PaperlessTask.TriggerSource.MANUAL
def _extract_owner_id(
task_type: PaperlessTask.TaskType,
args: tuple,
task_kwargs: dict,
) -> int | None:
"""Return the owner_id from consume_file overrides, or None for all other task types."""
if task_type != PaperlessTask.TaskType.CONSUME_FILE:
return None
_, overrides = _get_consume_args(args, task_kwargs)
overrides = task_kwargs.get("overrides")
if overrides and hasattr(overrides, "owner_id"):
return overrides.owner_id
return None # pragma: no cover
@@ -1177,17 +1143,12 @@ def before_task_publish_handler(
try:
close_old_connections()
args, task_kwargs, _ = body
_, task_kwargs, _ = body
task_id = headers["id"]
input_data = _extract_input_data(task_type, args, task_kwargs)
trigger_source = _determine_trigger_source(
task_type,
args,
task_kwargs,
headers,
)
owner_id = _extract_owner_id(task_type, args, task_kwargs)
input_data = _extract_input_data(task_type, task_kwargs)
trigger_source = _determine_trigger_source(headers)
owner_id = _extract_owner_id(task_type, task_kwargs)
PaperlessTask.objects.create(
task_id=task_id,

View File

@@ -40,6 +40,7 @@ from documents.models import Correspondent
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import PaperlessTask
from documents.models import ShareLink
from documents.models import ShareLinkBundle
from documents.models import StoragePath
@@ -600,7 +601,10 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:
)
if affected:
bulk_update_documents.delay(document_ids=list(affected))
bulk_update_documents.apply_async(
kwargs={"document_ids": list(affected)},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
@shared_task

View File

@@ -26,7 +26,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.user = user
self.client.force_authenticate(user=user)
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
self.async_task = patcher.start()
self.addCleanup(patcher.stop)
self.c1 = Correspondent.objects.create(name="c1")
@@ -62,7 +62,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
m.return_value = return_value
m.__name__ = method_name
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_set_correspondent(self, bulk_update_task_mock) -> None:
self.assertNotEqual(self.doc1.correspondent, self.c1)
response = self.client.post(
@@ -79,9 +79,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.doc1.refresh_from_db()
self.assertEqual(self.doc1.correspondent, self.c1)
bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk])
bulk_update_task_mock.assert_called_once()
self.assertCountEqual(
bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"],
[self.doc1.pk],
)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_unset_correspondent(self, bulk_update_task_mock) -> None:
self.doc1.correspondent = self.c1
self.doc1.save()
@@ -103,7 +107,7 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.doc1.refresh_from_db()
self.assertIsNone(self.doc1.correspondent)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_set_type(self, bulk_update_task_mock) -> None:
self.assertNotEqual(self.doc1.document_type, self.dt1)
response = self.client.post(
@@ -120,9 +124,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.doc1.refresh_from_db()
self.assertEqual(self.doc1.document_type, self.dt1)
bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk])
bulk_update_task_mock.assert_called_once()
self.assertCountEqual(
bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"],
[self.doc1.pk],
)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_unset_type(self, bulk_update_task_mock) -> None:
self.doc1.document_type = self.dt1
self.doc1.save()
@@ -141,9 +149,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.doc1.refresh_from_db()
self.assertIsNone(self.doc1.document_type)
bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk])
bulk_update_task_mock.assert_called_once()
self.assertCountEqual(
bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"],
[self.doc1.pk],
)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_add_tag(self, bulk_update_task_mock) -> None:
self.assertFalse(self.doc1.tags.filter(pk=self.t1.pk).exists())
@@ -163,9 +175,13 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.assertTrue(self.doc1.tags.filter(pk=self.t1.pk).exists())
bulk_update_task_mock.assert_called_once_with(document_ids=[self.doc1.pk])
bulk_update_task_mock.assert_called_once()
self.assertCountEqual(
bulk_update_task_mock.call_args.kwargs["kwargs"]["document_ids"],
[self.doc1.pk],
)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_remove_tag(self, bulk_update_task_mock) -> None:
self.doc1.tags.add(self.t1)

View File

@@ -537,7 +537,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
async_task.id = "task-123"
with mock.patch("documents.views.consume_file") as consume_mock:
consume_mock.delay.return_value = async_task
consume_mock.apply_async.return_value = async_task
resp = self.client.post(
f"/api/documents/{root.id}/update_version/",
{"document": upload, "version_label": " New Version "},
@@ -546,8 +546,9 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.data, "task-123")
consume_mock.delay.assert_called_once()
input_doc, overrides = consume_mock.delay.call_args[0]
consume_mock.apply_async.assert_called_once()
task_kwargs = consume_mock.apply_async.call_args.kwargs["kwargs"]
input_doc, overrides = task_kwargs["input_doc"], task_kwargs["overrides"]
self.assertEqual(input_doc.root_document_id, root.id)
self.assertEqual(input_doc.source, DocumentSource.ApiUpload)
self.assertEqual(overrides.version_label, "New Version")
@@ -571,7 +572,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
async_task.id = "task-123"
with mock.patch("documents.views.consume_file") as consume_mock:
consume_mock.delay.return_value = async_task
consume_mock.apply_async.return_value = async_task
resp = self.client.post(
f"/api/documents/{version.id}/update_version/",
{"document": upload, "version_label": " New Version "},
@@ -580,8 +581,9 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.data, "task-123")
consume_mock.delay.assert_called_once()
input_doc, overrides = consume_mock.delay.call_args[0]
consume_mock.apply_async.assert_called_once()
task_kwargs = consume_mock.apply_async.call_args.kwargs["kwargs"]
input_doc, overrides = task_kwargs["input_doc"], task_kwargs["overrides"]
self.assertEqual(input_doc.root_document_id, root.id)
self.assertEqual(overrides.version_label, "New Version")
self.assertEqual(overrides.actor_id, self.user.id)
@@ -595,7 +597,7 @@ class TestDocumentVersioningApi(DirectoriesMixin, APITestCase):
upload = self._make_pdf_upload()
with mock.patch("documents.views.consume_file") as consume_mock:
consume_mock.delay.side_effect = Exception("boom")
consume_mock.apply_async.side_effect = Exception("boom")
resp = self.client.post(
f"/api/documents/{root.id}/update_version/",
{"document": upload},

View File

@@ -47,11 +47,11 @@ from documents.models import Workflow
from documents.models import WorkflowAction
from documents.models import WorkflowTrigger
from documents.signals.handlers import run_workflows
from documents.tests.utils import ConsumeTaskMixin
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import DocumentConsumeDelayMixin
class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
class TestDocumentApi(DirectoriesMixin, ConsumeTaskMixin, APITestCase):
def setUp(self) -> None:
super().setUp()
@@ -1400,9 +1400,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "simple.pdf")
self.assertTrue(
@@ -1432,9 +1430,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "outside.pdf")
self.assertEqual(overrides.filename, "outside.pdf")
@@ -1474,9 +1470,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "outside.pdf")
self.assertEqual(overrides.filename, "outside.pdf")
@@ -1558,9 +1552,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "simple.pdf")
self.assertTrue(
@@ -1612,9 +1604,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(overrides.title, "my custom title")
self.assertIsNone(overrides.correspondent_id)
@@ -1634,9 +1624,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(overrides.correspondent_id, c.id)
self.assertIsNone(overrides.title)
@@ -1670,9 +1658,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(overrides.document_type_id, dt.id)
self.assertIsNone(overrides.correspondent_id)
@@ -1706,9 +1692,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(overrides.storage_path_id, sp.id)
self.assertIsNone(overrides.correspondent_id)
@@ -1743,9 +1727,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertCountEqual(overrides.tag_ids, [t1.id, t2.id])
self.assertIsNone(overrides.document_type_id)
@@ -1790,9 +1772,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
_, overrides = self.get_last_consume_delay_call_args()
_, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(overrides.created, created.date())
@@ -1809,9 +1789,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "simple.pdf")
self.assertEqual(overrides.filename, "simple.pdf")
@@ -1841,9 +1819,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "simple.pdf")
self.assertEqual(overrides.filename, "simple.pdf")
@@ -1898,9 +1874,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
new_overrides, _ = run_workflows(
trigger_type=WorkflowTrigger.WorkflowTriggerType.CONSUMPTION,
@@ -1946,9 +1920,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, overrides = self.get_last_consume_delay_call_args()
input_doc, overrides = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.original_file.name, "simple.pdf")
self.assertEqual(overrides.filename, "simple.pdf")
@@ -2047,9 +2019,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.consume_file_mock.assert_called_once()
input_doc, _ = self.get_last_consume_delay_call_args()
input_doc, _ = self.assert_queue_consumption_task_call_args()
self.assertEqual(input_doc.source, WorkflowTrigger.DocumentSourceChoices.WEB_UI)

View File

@@ -291,7 +291,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(StoragePath.objects.count(), 2)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_update_storage_path(self, bulk_update_mock) -> None:
"""
GIVEN:
@@ -316,11 +316,12 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
bulk_update_mock.assert_called_once()
args, _ = bulk_update_mock.call_args
self.assertCountEqual(
[document.pk],
bulk_update_mock.call_args.kwargs["kwargs"]["document_ids"],
)
self.assertCountEqual([document.pk], args[0])
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
@mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
def test_api_delete_storage_path(self, bulk_update_mock) -> None:
"""
GIVEN:
@@ -347,7 +348,11 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
# only called once
bulk_update_mock.assert_called_once_with([document.pk])
bulk_update_mock.assert_called_once()
self.assertCountEqual(
[document.pk],
bulk_update_mock.call_args.kwargs["kwargs"]["document_ids"],
)
def test_test_storage_path(self) -> None:
"""

View File

@@ -17,8 +17,8 @@ from documents.data_models import DocumentSource
from documents.models import Document
from documents.models import Tag
from documents.plugins.base import StopConsumeTaskError
from documents.tests.utils import ConsumeTaskMixin
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import DocumentConsumeDelayMixin
from documents.tests.utils import DummyProgressManager
from documents.tests.utils import FileSystemAssertsMixin
from documents.tests.utils import SampleDirMixin
@@ -601,7 +601,7 @@ class TestBarcodeNewConsume(
DirectoriesMixin,
FileSystemAssertsMixin,
SampleDirMixin,
DocumentConsumeDelayMixin,
ConsumeTaskMixin,
TestCase,
):
@override_settings(CONSUMER_ENABLE_BARCODES=True)
@@ -646,7 +646,7 @@ class TestBarcodeNewConsume(
for (
new_input_doc,
new_doc_overrides,
) in self.get_all_consume_delay_call_args():
) in self.get_all_consume_task_call_args():
self.assertIsFile(new_input_doc.original_file)
self.assertEqual(new_input_doc.original_path, temp_copy)
self.assertEqual(new_input_doc.source, DocumentSource.ConsumeFolder)

View File

@@ -31,7 +31,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.group1 = Group.objects.create(name="group1")
self.group2 = Group.objects.create(name="group2")
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
self.async_task = patcher.start()
self.addCleanup(patcher.stop)
self.c1 = Correspondent.objects.create(name="c1")
@@ -74,7 +74,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 3)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
def test_unset_correspondent(self) -> None:
@@ -82,7 +82,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
bulk_edit.set_correspondent([self.doc1.id, self.doc2.id, self.doc3.id], None)
self.assertEqual(Document.objects.filter(correspondent=self.c2).count(), 0)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
def test_set_document_type(self) -> None:
@@ -93,7 +93,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 3)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
def test_unset_document_type(self) -> None:
@@ -101,7 +101,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
bulk_edit.set_document_type([self.doc1.id, self.doc2.id, self.doc3.id], None)
self.assertEqual(Document.objects.filter(document_type=self.dt2).count(), 0)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
def test_set_document_storage_path(self) -> None:
@@ -123,7 +123,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.assertEqual(Document.objects.filter(storage_path=None).count(), 4)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id])
@@ -154,7 +154,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.assertEqual(Document.objects.filter(storage_path=None).count(), 5)
self.async_task.assert_called()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id])
@@ -166,7 +166,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 4)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc3.id])
def test_remove_tag(self) -> None:
@@ -174,7 +174,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
bulk_edit.remove_tag([self.doc1.id, self.doc3.id, self.doc4.id], self.t1.id)
self.assertEqual(Document.objects.filter(tags__id=self.t1.id).count(), 1)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc4.id])
def test_modify_tags(self) -> None:
@@ -191,7 +191,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.assertCountEqual(list(self.doc3.tags.all()), [self.t2, tag_unrelated])
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
# TODO: doc3 should not be affected, but the query for that is rather complicated
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
@@ -248,7 +248,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
def test_modify_custom_fields_with_values(self) -> None:
@@ -325,7 +325,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.async_task.assert_called_once()
_, kwargs = self.async_task.call_args
kwargs = self.async_task.call_args.kwargs["kwargs"]
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
# removal of document link cf, should also remove symmetric link
@@ -428,7 +428,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.assertEqual(source_doc.id, version2.id)
self.assertNotEqual(source_doc.id, version1.id)
@mock.patch("documents.tasks.bulk_update_documents.delay")
@mock.patch("documents.tasks.bulk_update_documents.apply_async")
def test_set_permissions(self, m) -> None:
doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id]
@@ -467,7 +467,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(groups_with_perms.count(), 1)
@mock.patch("documents.tasks.bulk_update_documents.delay")
@mock.patch("documents.tasks.bulk_update_documents.apply_async")
def test_set_permissions_merge(self, m) -> None:
doc_ids = [self.doc1.id, self.doc2.id, self.doc3.id]
@@ -643,20 +643,20 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
mock_consume_file.assert_called()
consume_file_args, _ = mock_consume_file.call_args
call_kwargs = mock_consume_file.call_args.kwargs
self.assertEqual(
Path(consume_file_args[0].original_file).name,
Path(call_kwargs["input_doc"].original_file).name,
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertEqual(call_kwargs["overrides"].title, None)
# No metadata_document_id, delete_originals False, so ASN should be None
self.assertIsNone(consume_file_args[1].asn)
self.assertIsNone(call_kwargs["overrides"].asn)
# With metadata_document_id overrides
result = bulk_edit.merge(doc_ids, metadata_document_id=metadata_document_id)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (merged)")
self.assertEqual(consume_file_args[1].created, self.doc2.created)
call_kwargs = mock_consume_file.call_args.kwargs
self.assertEqual(call_kwargs["overrides"].title, "B (merged)")
self.assertEqual(call_kwargs["overrides"].created, self.doc2.created)
self.assertEqual(result, "OK")
@@ -720,16 +720,15 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_called()
mock_delete_documents.assert_called()
consume_sig = mock_consume_file.return_value
consume_sig.apply_async.assert_called_once()
mock_consume_file.return_value.set.return_value.apply_async.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
call_kwargs = mock_consume_file.call_args.kwargs
self.assertEqual(
Path(consume_file_args[0].original_file).name,
Path(call_kwargs["input_doc"].original_file).name,
expected_filename,
)
self.assertEqual(consume_file_args[1].title, None)
self.assertEqual(consume_file_args[1].asn, 101)
self.assertEqual(call_kwargs["overrides"].title, None)
self.assertEqual(call_kwargs["overrides"].asn, 101)
delete_documents_args, _ = mock_delete_documents.call_args
self.assertEqual(
@@ -764,7 +763,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.doc1.archive_serial_number = 111
self.doc1.save()
sig = mock.Mock()
sig.apply_async.side_effect = Exception("boom")
sig.set.return_value.apply_async.side_effect = Exception("boom")
mock_consume_file.return_value = sig
with self.assertRaises(Exception):
@@ -801,8 +800,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].asn, 202)
self.assertEqual(mock_consume_file.call_args.kwargs["overrides"].asn, 202)
def test_restore_archive_serial_numbers_task(self) -> None:
"""
@@ -843,9 +841,8 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
mock_consume_file.assert_called()
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(
Path(consume_file_args[0].original_file).name,
Path(mock_consume_file.call_args.kwargs["input_doc"].original_file).name,
expected_filename,
)
@@ -889,9 +886,11 @@ class TestPDFActions(DirectoriesMixin, TestCase):
user = User.objects.create(username="test_user")
result = bulk_edit.split(doc_ids, pages, delete_originals=False, user=user)
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")
self.assertIsNone(consume_file_args[1].asn)
self.assertEqual(
mock_consume_file.call_args.kwargs["overrides"].title,
"B (split 2)",
)
self.assertIsNone(mock_consume_file.call_args.kwargs["overrides"].asn)
self.assertEqual(result, "OK")
@@ -953,8 +952,10 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(result, "OK")
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")
self.assertEqual(
mock_consume_file.call_args.kwargs["overrides"].title,
"B (split 2)",
)
mock_delete_documents.assert_called()
mock_chord.assert_called_once()
@@ -1001,7 +1002,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.doc2.refresh_from_db()
self.assertEqual(self.doc2.archive_serial_number, 222)
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
def test_split_with_errors(self, mock_save_pdf, mock_consume_file) -> None:
"""
@@ -1025,7 +1026,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_not_called()
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
def test_rotate(self, mock_consume_delay):
"""
GIVEN:
@@ -1042,12 +1043,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_delay.call_args_list,
doc_ids,
):
consumable, overrides = call.args
self.assertEqual(consumable.root_document_id, expected_id)
self.assertIsNotNone(overrides)
task_kwargs = call.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, expected_id)
self.assertIsNotNone(task_kwargs["overrides"])
self.assertEqual(result, "OK")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
def test_rotate_with_error(
self,
@@ -1073,7 +1074,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertIn(expected_str, error_str)
mock_consume_delay.assert_not_called()
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
def test_rotate_non_pdf(
self,
mock_consume_delay,
@@ -1091,13 +1092,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_str = f"Document {self.img_doc.id} is not a PDF, skipping rotation"
self.assertTrue(any(expected_str in line for line in cm.output))
self.assertEqual(mock_consume_delay.call_count, 1)
consumable, overrides = mock_consume_delay.call_args[0]
self.assertEqual(consumable.root_document_id, self.doc2.id)
self.assertIsNotNone(overrides)
task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id)
self.assertIsNotNone(task_kwargs["overrides"])
self.assertEqual(result, "OK")
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.open")
def test_rotate_explicit_selection_uses_root_source_when_root_selected(
self,
@@ -1124,7 +1125,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_open.assert_called_once_with(self.doc2.source_path)
mock_consume_delay.assert_called_once()
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
def test_delete_pages(self, mock_magic, mock_pdf_save, mock_consume_delay):
@@ -1142,14 +1143,16 @@ class TestPDFActions(DirectoriesMixin, TestCase):
result = bulk_edit.delete_pages(doc_ids, pages)
mock_pdf_save.assert_called_once()
mock_consume_delay.assert_called_once()
consumable, overrides = mock_consume_delay.call_args[0]
self.assertEqual(consumable.root_document_id, self.doc2.id)
self.assertTrue(str(consumable.original_file).endswith("_pages_deleted.pdf"))
self.assertIsNotNone(overrides)
task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id)
self.assertTrue(
str(task_kwargs["input_doc"].original_file).endswith("_pages_deleted.pdf"),
)
self.assertIsNotNone(task_kwargs["overrides"])
self.assertEqual(result, "OK")
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.open")
def test_delete_pages_explicit_selection_uses_root_source_when_root_selected(
self,
@@ -1176,7 +1179,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_open.assert_called_once_with(self.doc2.source_path)
mock_consume_delay.assert_called_once()
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.Pdf.save")
def test_delete_pages_with_error(self, mock_pdf_save, mock_consume_delay):
"""
@@ -1259,8 +1262,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
result = bulk_edit.edit_pdf(doc_ids, operations, delete_original=True)
self.assertEqual(result, "OK")
mock_chord.assert_called_once()
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].asn, 250)
self.assertEqual(mock_consume_file.call_args.kwargs["overrides"].asn, 250)
self.doc2.refresh_from_db()
self.assertIsNone(self.doc2.archive_serial_number)
@@ -1297,7 +1299,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.doc2.refresh_from_db()
self.assertEqual(self.doc2.archive_serial_number, 333)
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
def test_edit_pdf_with_update_document(self, mock_consume_delay):
"""
GIVEN:
@@ -1319,13 +1321,15 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(result, "OK")
mock_consume_delay.assert_called_once()
consumable, overrides = mock_consume_delay.call_args[0]
self.assertEqual(consumable.root_document_id, self.doc2.id)
self.assertTrue(str(consumable.original_file).endswith("_edited.pdf"))
self.assertIsNotNone(overrides)
task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"]
self.assertEqual(task_kwargs["input_doc"].root_document_id, self.doc2.id)
self.assertTrue(
str(task_kwargs["input_doc"].original_file).endswith("_edited.pdf"),
)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.new")
@mock.patch("pikepdf.open")
def test_edit_pdf_explicit_selection_uses_root_source_when_root_selected(
@@ -1433,7 +1437,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_consume_file.assert_not_called()
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_update_document(
@@ -1468,18 +1472,18 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf.remove_unreferenced_resources.assert_called_once()
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
consumable, overrides = mock_consume_delay.call_args[0]
task_kwargs = mock_consume_delay.call_args.kwargs["kwargs"]
expected_path = temp_dir / f"{doc.id}_unprotected.pdf"
self.assertTrue(expected_path.exists())
self.assertEqual(
Path(consumable.original_file).resolve(),
Path(task_kwargs["input_doc"].original_file).resolve(),
expected_path.resolve(),
)
self.assertEqual(consumable.root_document_id, doc.id)
self.assertIsNotNone(overrides)
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.data_models.magic.from_file", return_value="application/pdf")
@mock.patch("documents.tasks.consume_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("pikepdf.open")
def test_remove_password_explicit_selection_uses_root_source_when_root_selected(
self,
@@ -1548,9 +1552,9 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
mock_consume_file.assert_called_once()
consume_args, _ = mock_consume_file.call_args
consumable_document = consume_args[0]
overrides = consume_args[1]
call_kwargs = mock_consume_file.call_args.kwargs
consumable_document = call_kwargs["input_doc"]
overrides = call_kwargs["overrides"]
expected_path = temp_dir / f"{doc.id}_unprotected.pdf"
self.assertTrue(expected_path.exists())
self.assertEqual(
@@ -1558,7 +1562,9 @@ class TestPDFActions(DirectoriesMixin, TestCase):
expected_path.resolve(),
)
self.assertEqual(overrides.owner_id, user.id)
mock_group.assert_called_once_with([mock_consume_file.return_value])
mock_group.assert_called_once_with(
[mock_consume_file.return_value.set.return_value],
)
mock_group.return_value.delay.assert_called_once()
mock_chord.assert_not_called()

View File

@@ -97,12 +97,10 @@ def consumer_filter() -> ConsumerFilter:
@pytest.fixture
def mock_consume_file_delay(mocker: MockerFixture) -> MagicMock:
"""Mock the consume_file.delay celery task."""
mock_task = mocker.patch(
"""Mock the consume_file task."""
return mocker.patch(
"documents.management.commands.document_consumer.consume_file",
)
mock_task.delay = mocker.MagicMock()
return mock_task
@pytest.fixture
@@ -453,9 +451,9 @@ class TestConsumeFile:
subdirs_as_tags=False,
)
mock_consume_file_delay.delay.assert_called_once()
call_args = mock_consume_file_delay.delay.call_args
consumable_doc = call_args[0][0]
mock_consume_file_delay.apply_async.assert_called_once()
call_args = mock_consume_file_delay.apply_async.call_args
consumable_doc = call_args.kwargs["kwargs"]["input_doc"]
assert isinstance(consumable_doc, ConsumableDocument)
assert consumable_doc.original_file == target
assert consumable_doc.source == DocumentSource.ConsumeFolder
@@ -471,7 +469,7 @@ class TestConsumeFile:
consumption_dir=consumption_dir,
subdirs_as_tags=False,
)
mock_consume_file_delay.delay.assert_not_called()
mock_consume_file_delay.apply_async.assert_not_called()
def test_consume_directory(
self,
@@ -487,7 +485,7 @@ class TestConsumeFile:
consumption_dir=consumption_dir,
subdirs_as_tags=False,
)
mock_consume_file_delay.delay.assert_not_called()
mock_consume_file_delay.apply_async.assert_not_called()
def test_consume_with_permission_error(
self,
@@ -506,7 +504,7 @@ class TestConsumeFile:
consumption_dir=consumption_dir,
subdirs_as_tags=False,
)
mock_consume_file_delay.delay.assert_not_called()
mock_consume_file_delay.apply_async.assert_not_called()
def test_consume_with_tags_error(
self,
@@ -529,9 +527,9 @@ class TestConsumeFile:
consumption_dir=consumption_dir,
subdirs_as_tags=True,
)
mock_consume_file_delay.delay.assert_called_once()
call_args = mock_consume_file_delay.delay.call_args
overrides = call_args[0][1]
mock_consume_file_delay.apply_async.assert_called_once()
call_args = mock_consume_file_delay.apply_async.call_args
overrides = call_args.kwargs["kwargs"]["overrides"]
assert overrides.tag_ids is None
@@ -629,7 +627,7 @@ class TestCommandOneshot:
cmd = Command()
cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False)
mock_consume_file_delay.delay.assert_called_once()
mock_consume_file_delay.apply_async.assert_called_once()
def test_processes_recursive(
self,
@@ -652,7 +650,7 @@ class TestCommandOneshot:
cmd = Command()
cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False)
mock_consume_file_delay.delay.assert_called_once()
mock_consume_file_delay.apply_async.assert_called_once()
def test_ignores_unsupported_extensions(
self,
@@ -671,7 +669,7 @@ class TestCommandOneshot:
cmd = Command()
cmd.handle(directory=str(consumption_dir), oneshot=True, testing=False)
mock_consume_file_delay.delay.assert_not_called()
mock_consume_file_delay.apply_async.assert_not_called()
class ConsumerThread(Thread):
@@ -795,12 +793,12 @@ class TestCommandWatch:
target = consumption_dir / "document.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
def test_detects_moved_file(
self,
@@ -821,12 +819,12 @@ class TestCommandWatch:
target = consumption_dir / "document.pdf"
shutil.move(temp_location, target)
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
def test_handles_slow_write(
self,
@@ -847,12 +845,12 @@ class TestCommandWatch:
f.flush()
sleep(0.05)
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
def test_ignores_macos_files(
self,
@@ -868,13 +866,15 @@ class TestCommandWatch:
(consumption_dir / "._document.pdf").write_bytes(b"test")
shutil.copy(sample_pdf, consumption_dir / "valid.pdf")
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
assert mock_consume_file_delay.delay.call_count == 1
call_args = mock_consume_file_delay.delay.call_args[0][0]
assert mock_consume_file_delay.apply_async.call_count == 1
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
"input_doc"
]
assert call_args.original_file.name == "valid.pdf"
@pytest.mark.django_db
@@ -924,12 +924,12 @@ class TestCommandWatchPolling:
# Actively wait for consumption
# Polling needs: interval (0.5s) + stability (0.1s) + next poll (0.5s) + margin
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=5.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
@pytest.mark.management
@@ -953,12 +953,12 @@ class TestCommandWatchRecursive:
target = subdir / "document.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
def test_subdirs_as_tags(
self,
@@ -983,15 +983,15 @@ class TestCommandWatchRecursive:
target = subdir / "document.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.delay, timeout_s=2.0)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=2.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_called()
mock_consume_file_delay.apply_async.assert_called()
mock_tags.assert_called()
call_args = mock_consume_file_delay.delay.call_args
overrides = call_args[0][1]
call_args = mock_consume_file_delay.apply_async.call_args
overrides = call_args.kwargs["kwargs"]["overrides"]
assert overrides.tag_ids is not None
assert len(overrides.tag_ids) == 2
@@ -1021,7 +1021,7 @@ class TestCommandWatchEdgeCases:
if thread.exception:
raise thread.exception
mock_consume_file_delay.delay.assert_not_called()
mock_consume_file_delay.apply_async.assert_not_called()
@pytest.mark.usefixtures("mock_supported_extensions")
def test_handles_task_exception(
@@ -1035,7 +1035,7 @@ class TestCommandWatchEdgeCases:
mock_task = mocker.patch(
"documents.management.commands.document_consumer.consume_file",
)
mock_task.delay.side_effect = Exception("Task error")
mock_task.apply_async.side_effect = Exception("Task error")
thread = ConsumerThread(consumption_dir, scratch_dir)
try:

View File

@@ -31,7 +31,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
self.client.force_authenticate(self.user)
self.document = DocumentFactory.create()
@mock.patch("documents.views.build_share_link_bundle.delay")
@mock.patch("documents.views.build_share_link_bundle.apply_async")
def test_create_bundle_triggers_build_job(self, delay_mock) -> None:
payload = {
"document_ids": [self.document.pk],
@@ -45,7 +45,8 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
bundle = ShareLinkBundle.objects.get(pk=response.data["id"])
self.assertEqual(bundle.documents.count(), 1)
self.assertEqual(bundle.status, ShareLinkBundle.Status.PENDING)
delay_mock.assert_called_once_with(bundle.pk)
delay_mock.assert_called_once()
self.assertEqual(delay_mock.call_args.kwargs["kwargs"]["bundle_id"], bundle.pk)
def test_create_bundle_rejects_missing_documents(self) -> None:
payload = {
@@ -73,7 +74,7 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
self.assertIn("document_ids", response.data)
perms_mock.assert_called()
@mock.patch("documents.views.build_share_link_bundle.delay")
@mock.patch("documents.views.build_share_link_bundle.apply_async")
def test_rebuild_bundle_resets_state(self, delay_mock) -> None:
bundle = ShareLinkBundle.objects.create(
slug="rebuild-slug",
@@ -94,7 +95,8 @@ class ShareLinkBundleAPITests(DirectoriesMixin, APITestCase):
self.assertIsNone(bundle.last_error)
self.assertIsNone(bundle.size_bytes)
self.assertEqual(bundle.file_path, "")
delay_mock.assert_called_once_with(bundle.pk)
delay_mock.assert_called_once()
self.assertEqual(delay_mock.call_args.kwargs["kwargs"]["bundle_id"], bundle.pk)
def test_rebuild_bundle_rejects_processing_status(self) -> None:
bundle = ShareLinkBundle.objects.create(

View File

@@ -23,7 +23,7 @@ class TestTagHierarchy(DirectoriesMixin, APITestCase):
self.parent = Tag.objects.create(name="Parent")
self.child = Tag.objects.create(name="Child", tn_parent=self.parent)
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay")
patcher = mock.patch("documents.bulk_edit.bulk_update_documents.apply_async")
self.async_task = patcher.start()
self.addCleanup(patcher.stop)

View File

@@ -59,8 +59,9 @@ class TestBeforeTaskPublishHandler:
def test_creates_task_for_consume_file(self, consume_input_doc, consume_overrides):
task_id = send_publish(
"documents.tasks.consume_file",
(consume_input_doc, consume_overrides),
{},
(),
{"input_doc": consume_input_doc, "overrides": consume_overrides},
headers={"trigger_source": PaperlessTask.TriggerSource.WEB_UI},
)
task = PaperlessTask.objects.get(task_id=task_id)
assert task.task_type == PaperlessTask.TaskType.CONSUME_FILE
@@ -167,37 +168,19 @@ class TestBeforeTaskPublishHandler:
before_task_publish_handler(sender=None, headers=None, body=None)
assert PaperlessTask.objects.count() == 0
@pytest.mark.parametrize(
("document_source", "expected_trigger_source"),
[
pytest.param(
DocumentSource.ConsumeFolder,
PaperlessTask.TriggerSource.FOLDER_CONSUME,
id="folder_consume",
),
pytest.param(
DocumentSource.MailFetch,
PaperlessTask.TriggerSource.EMAIL_CONSUME,
id="email_consume",
),
],
)
def test_consume_document_source_maps_to_trigger_source(
def test_consume_file_without_trigger_source_header_defaults_to_manual(
self,
consume_input_doc,
consume_overrides,
document_source: DocumentSource,
expected_trigger_source: PaperlessTask.TriggerSource,
) -> None:
"""DocumentSource on the input doc maps to the correct TriggerSource."""
consume_input_doc.source = document_source
"""Without a trigger_source header the handler defaults to MANUAL."""
task_id = send_publish(
"documents.tasks.consume_file",
(consume_input_doc, consume_overrides),
{},
(),
{"input_doc": consume_input_doc, "overrides": consume_overrides},
)
task = PaperlessTask.objects.get(task_id=task_id)
assert task.trigger_source == expected_trigger_source
assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL
@pytest.mark.django_db

View File

@@ -231,14 +231,16 @@ class ConsumerProgressMixin:
self.send_progress_patcher.stop()
class DocumentConsumeDelayMixin:
class ConsumeTaskMixin:
"""
Provides mocking of the consume_file asynchronous task and useful utilities
for decoding its arguments
"""
def setUp(self) -> None:
self.consume_file_patcher = mock.patch("documents.tasks.consume_file.delay")
self.consume_file_patcher = mock.patch(
"documents.tasks.consume_file.apply_async",
)
self.consume_file_mock = self.consume_file_patcher.start()
super().setUp()
@@ -246,48 +248,22 @@ class DocumentConsumeDelayMixin:
super().tearDown()
self.consume_file_patcher.stop()
def get_last_consume_delay_call_args(
def assert_queue_consumption_task_call_args(
self,
) -> tuple[ConsumableDocument, DocumentMetadataOverrides]:
"""
Returns the most recent arguments to the async task
"""
# Must be at least 1 call
self.consume_file_mock.assert_called()
"""Assert the task was queued exactly once and return its call args."""
self.consume_file_mock.assert_called_once()
task_kwargs = self.consume_file_mock.call_args.kwargs["kwargs"]
return (task_kwargs["input_doc"], task_kwargs["overrides"])
args, _ = self.consume_file_mock.call_args
input_doc, overrides = args
return (input_doc, overrides)
def get_all_consume_delay_call_args(
def get_all_consume_task_call_args(
self,
) -> Iterator[tuple[ConsumableDocument, DocumentMetadataOverrides]]:
"""
Iterates over all calls to the async task and returns the arguments
"""
# Must be at least 1 call
"""Iterate over all queued consume task calls and yield their call args."""
self.consume_file_mock.assert_called()
for args, kwargs in self.consume_file_mock.call_args_list:
input_doc, overrides = args
yield (input_doc, overrides)
def get_specific_consume_delay_call_args(
self,
index: int,
) -> tuple[ConsumableDocument, DocumentMetadataOverrides]:
"""
Returns the arguments of a specific call to the async task
"""
# Must be at least 1 call
self.consume_file_mock.assert_called()
args, _ = self.consume_file_mock.call_args_list[index]
input_doc, overrides = args
return (input_doc, overrides)
for call in self.consume_file_mock.call_args_list:
task_kwargs = call.kwargs["kwargs"]
yield (task_kwargs["input_doc"], task_kwargs["overrides"])
class TestMigrations(TransactionTestCase):

View File

@@ -1771,9 +1771,9 @@ class DocumentViewSet(
if request.user is not None:
overrides.actor_id = request.user.id
async_task = consume_file.delay(
input_doc,
overrides,
async_task = consume_file.apply_async(
kwargs={"input_doc": input_doc, "overrides": overrides},
headers={"trigger_source": PaperlessTask.TriggerSource.WEB_UI},
)
logger.debug(
f"Updated document {root_doc.id} with new version",
@@ -2450,6 +2450,7 @@ class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin):
"edit_pdf",
"remove_password",
}
METHOD_NAMES_REQUIRING_TRIGGER_SOURCE = METHOD_NAMES_REQUIRING_USER
def _has_document_permissions(
self,
@@ -2540,12 +2541,19 @@ class DocumentOperationPermissionMixin(PassUserMixin, DocumentSelectionMixin):
parameters = {
k: v
for k, v in validated_data.items()
if k not in {"documents", "all", "filters"}
if k not in {"documents", "all", "filters", "from_webui"}
}
user = self.request.user
from_webui = validated_data.get("from_webui", False)
if method.__name__ in self.METHOD_NAMES_REQUIRING_USER:
parameters["user"] = user
if method.__name__ in self.METHOD_NAMES_REQUIRING_TRIGGER_SOURCE:
parameters["trigger_source"] = (
PaperlessTask.TriggerSource.WEB_UI
if from_webui
else PaperlessTask.TriggerSource.API_UPLOAD
)
if not self._has_document_permissions(
user=user,
@@ -2629,12 +2637,19 @@ class BulkEditView(DocumentOperationPermissionMixin):
user = self.request.user
method = serializer.validated_data.get("method")
parameters = serializer.validated_data.get("parameters")
from_webui = serializer.validated_data.get("from_webui", False)
documents = self._resolve_document_ids(
user=user,
validated_data=serializer.validated_data,
)
if method.__name__ in self.METHOD_NAMES_REQUIRING_USER:
parameters["user"] = user
if method.__name__ in self.METHOD_NAMES_REQUIRING_TRIGGER_SOURCE:
parameters["trigger_source"] = (
PaperlessTask.TriggerSource.WEB_UI
if from_webui
else PaperlessTask.TriggerSource.API_UPLOAD
)
if not self._has_document_permissions(
user=user,
documents=documents,
@@ -2928,9 +2943,15 @@ class PostDocumentView(GenericAPIView[Any]):
custom_fields=custom_fields,
)
async_task = consume_file.delay(
input_doc,
input_doc_overrides,
async_task = consume_file.apply_async(
kwargs={"input_doc": input_doc, "overrides": input_doc_overrides},
headers={
"trigger_source": (
PaperlessTask.TriggerSource.WEB_UI
if from_webui
else PaperlessTask.TriggerSource.API_UPLOAD
),
},
)
return Response(async_task.id)
@@ -3566,7 +3587,10 @@ class StoragePathViewSet(PermissionsAwareDocumentCountMixin, ModelViewSet[Storag
response = super().destroy(request, *args, **kwargs)
if doc_ids:
bulk_edit.bulk_update_documents.delay(doc_ids)
bulk_edit.bulk_update_documents.apply_async(
kwargs={"document_ids": doc_ids},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
return response
@@ -4088,7 +4112,10 @@ class ShareLinkBundleViewSet(PassUserMixin, ModelViewSet[ShareLinkBundle]):
"file_path",
],
)
build_share_link_bundle.delay(bundle.pk)
build_share_link_bundle.apply_async(
kwargs={"bundle_id": bundle.pk},
headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL},
)
bundle.document_total = len(ordered_documents)
response_serializer = self.get_serializer(bundle)
headers = self.get_success_headers(response_serializer.data)
@@ -4121,7 +4148,10 @@ class ShareLinkBundleViewSet(PassUserMixin, ModelViewSet[ShareLinkBundle]):
"file_path",
],
)
build_share_link_bundle.delay(bundle.pk)
build_share_link_bundle.apply_async(
kwargs={"bundle_id": bundle.pk},
headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL},
)
bundle.document_total = (
getattr(bundle, "document_total", None) or bundle.documents.count()
)

View File

@@ -37,6 +37,7 @@ from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.loggers import LoggingMixin
from documents.models import Correspondent
from documents.models import PaperlessTask
from documents.parsers import is_mime_type_supported
from documents.tasks import consume_file
from paperless.network import is_public_ip
@@ -893,8 +894,12 @@ class MailAccountHandler(LoggingMixin):
)
consume_task = consume_file.s(
input_doc,
doc_overrides,
input_doc=input_doc,
overrides=doc_overrides,
).set(
headers={
"trigger_source": PaperlessTask.TriggerSource.EMAIL_CONSUME,
},
)
consume_tasks.append(consume_task)
@@ -991,9 +996,9 @@ class MailAccountHandler(LoggingMixin):
)
consume_task = consume_file.s(
input_doc,
doc_overrides,
)
input_doc=input_doc,
overrides=doc_overrides,
).set(headers={"trigger_source": PaperlessTask.TriggerSource.EMAIL_CONSUME})
queue_consumption_tasks(
consume_tasks=[consume_task],

View File

@@ -359,7 +359,8 @@ class MailMocker(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
consume_tasks,
expected_signatures,
):
input_doc, overrides = consume_task.args
input_doc = consume_task.kwargs["input_doc"]
overrides = consume_task.kwargs["overrides"]
# assert the file exists
self.assertIsFile(input_doc.original_file)
@@ -2022,7 +2023,7 @@ class TestMailAccountProcess(APITestCase):
)
self.url = f"/api/mail_accounts/{self.account.pk}/process/"
@mock.patch("paperless_mail.tasks.process_mail_accounts.delay")
@mock.patch("paperless_mail.tasks.process_mail_accounts.apply_async")
def test_mail_account_process_view(self, m) -> None:
response = self.client.post(self.url)
self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@@ -23,6 +23,7 @@ from rest_framework.viewsets import ModelViewSet
from rest_framework.viewsets import ReadOnlyModelViewSet
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.models import PaperlessTask
from documents.permissions import PaperlessObjectPermissions
from documents.permissions import has_perms_owner_aware
from documents.views import PassUserMixin
@@ -155,7 +156,10 @@ class MailAccountViewSet(PassUserMixin, ModelViewSet[MailAccount]):
@action(methods=["post"], detail=True)
def process(self, request, pk=None):
account = self.get_object()
process_mail_accounts.delay([account.pk])
process_mail_accounts.apply_async(
kwargs={"account_ids": [account.pk]},
headers={"trigger_source": PaperlessTask.TriggerSource.MANUAL},
)
return Response({"result": "OK"})