diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 642fd7809..78da89833 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -1097,12 +1097,13 @@ def _determine_trigger_source( 2. For consume_file tasks, the DocumentSource on the input document. 3. MANUAL as the catch-all for all other cases. """ - # Explicit header takes priority -- covers beat ("scheduled") and system auto-runs ("system") + # Explicit header takes priority -- callers pass a TriggerSource DB value directly. header_source = headers.get("trigger_source") - if header_source == "scheduled": - return PaperlessTask.TriggerSource.SCHEDULED - if header_source == "system": - return PaperlessTask.TriggerSource.SYSTEM + if header_source is not None: + try: + return PaperlessTask.TriggerSource(header_source) + except ValueError: + pass if task_type == PaperlessTask.TaskType.CONSUME_FILE: input_doc, _ = _get_consume_args(args, task_kwargs) diff --git a/src/documents/tests/test_task_signals.py b/src/documents/tests/test_task_signals.py index 94f35b8b1..04554eb84 100644 --- a/src/documents/tests/test_task_signals.py +++ b/src/documents/tests/test_task_signals.py @@ -94,7 +94,7 @@ class TestBeforeTaskPublishHandler: "documents.tasks.train_classifier", (), {}, - headers={"trigger_source": "scheduled"}, + headers={"trigger_source": PaperlessTask.TriggerSource.SCHEDULED}, ) task = PaperlessTask.objects.get(task_id=task_id) assert task.trigger_source == PaperlessTask.TriggerSource.SCHEDULED @@ -104,11 +104,21 @@ class TestBeforeTaskPublishHandler: "documents.tasks.llmindex_index", (), {"rebuild": True}, - headers={"trigger_source": "system"}, + headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM}, ) task = PaperlessTask.objects.get(task_id=task_id) assert task.trigger_source == PaperlessTask.TriggerSource.SYSTEM + def test_invalid_header_falls_back_to_manual(self): + task_id = send_publish( + "documents.tasks.train_classifier", + (), + {}, + headers={"trigger_source": "bogus_value"}, + ) + task = PaperlessTask.objects.get(task_id=task_id) + assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL + def test_ignores_untracked_task(self): send_publish("documents.tasks.some_untracked_task", (), {}) assert PaperlessTask.objects.count() == 0