fix(tasks): resolve trigger_source header via TriggerSource enum lookup

Replace two hardcoded string comparisons ("scheduled", "system") with a
single TriggerSource(header_source) lookup so the enum values are the
single source of truth. Any valid TriggerSource DB value passed in the
header is accepted; invalid values fall through to the document-source /
MANUAL logic. Update tests to pass enum values in headers rather than raw
strings, and add a test for the invalid-header fallback path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-04-16 10:13:39 -07:00
parent 5bb8c11952
commit 4c01876a53
2 changed files with 18 additions and 7 deletions
+6 -5
View File
@@ -1097,12 +1097,13 @@ def _determine_trigger_source(
2. For consume_file tasks, the DocumentSource on the input document.
3. MANUAL as the catch-all for all other cases.
"""
# Explicit header takes priority -- covers beat ("scheduled") and system auto-runs ("system")
# Explicit header takes priority -- callers pass a TriggerSource DB value directly.
header_source = headers.get("trigger_source")
if header_source == "scheduled":
return PaperlessTask.TriggerSource.SCHEDULED
if header_source == "system":
return PaperlessTask.TriggerSource.SYSTEM
if header_source is not None:
try:
return PaperlessTask.TriggerSource(header_source)
except ValueError:
pass
if task_type == PaperlessTask.TaskType.CONSUME_FILE:
input_doc, _ = _get_consume_args(args, task_kwargs)
+12 -2
View File
@@ -94,7 +94,7 @@ class TestBeforeTaskPublishHandler:
"documents.tasks.train_classifier",
(),
{},
headers={"trigger_source": "scheduled"},
headers={"trigger_source": PaperlessTask.TriggerSource.SCHEDULED},
)
task = PaperlessTask.objects.get(task_id=task_id)
assert task.trigger_source == PaperlessTask.TriggerSource.SCHEDULED
@@ -104,11 +104,21 @@ class TestBeforeTaskPublishHandler:
"documents.tasks.llmindex_index",
(),
{"rebuild": True},
headers={"trigger_source": "system"},
headers={"trigger_source": PaperlessTask.TriggerSource.SYSTEM},
)
task = PaperlessTask.objects.get(task_id=task_id)
assert task.trigger_source == PaperlessTask.TriggerSource.SYSTEM
def test_invalid_header_falls_back_to_manual(self):
task_id = send_publish(
"documents.tasks.train_classifier",
(),
{},
headers={"trigger_source": "bogus_value"},
)
task = PaperlessTask.objects.get(task_id=task_id)
assert task.trigger_source == PaperlessTask.TriggerSource.MANUAL
def test_ignores_untracked_task(self):
send_publish("documents.tasks.some_untracked_task", (), {})
assert PaperlessTask.objects.count() == 0