Feature: Add progress information to the classifier training for a better ux (#12331)

This commit is contained in:
Trenton H
2026-03-14 12:53:52 -07:00
committed by GitHub
parent 01abacab52
commit 9d69705e26
4 changed files with 77 additions and 14 deletions
+31 -7
View File
@@ -1,7 +1,10 @@
from __future__ import annotations
import filecmp
import shutil
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
import pytest
@@ -11,6 +14,9 @@ from django.core.management import call_command
from django.test import TestCase
from django.test import override_settings
if TYPE_CHECKING:
from pytest_mock import MockerFixture
from documents.file_handling import generate_filename
from documents.models import Document
from documents.tasks import update_document_content_maybe_archive_file
@@ -135,14 +141,32 @@ class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
@pytest.mark.management
class TestCreateClassifier(TestCase):
@mock.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
def test_create_classifier(self, m) -> None:
call_command("document_create_classifier")
class TestCreateClassifier:
def test_create_classifier(self, mocker: MockerFixture) -> None:
m = mocker.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
m.assert_called_once()
call_command("document_create_classifier", "--skip-checks")
m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY)
assert callable(m.call_args.kwargs["status_callback"])
def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None:
"""Callback passed to train_classifier writes each phase message to the console."""
m = mocker.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
def invoke_callback(**kwargs):
kwargs["status_callback"]("Vectorizing document content...")
m.side_effect = invoke_callback
stdout = StringIO()
call_command("document_create_classifier", "--skip-checks", stdout=stdout)
assert "Vectorizing document content..." in stdout.getvalue()
@pytest.mark.management