diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 1e9da7ce6..87934ab52 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Callable from collections.abc import Iterator from datetime import datetime @@ -191,7 +192,12 @@ class DocumentClassifier: target_file_temp.rename(target_file) - def train(self) -> bool: + def train( + self, + status_callback: Callable[[str], None] | None = None, + ) -> bool: + notify = status_callback if status_callback is not None else lambda _: None + # Get non-inbox documents docs_queryset = ( Document.objects.exclude( @@ -213,6 +219,7 @@ class DocumentClassifier: # Step 1: Extract and preprocess training data from the database. logger.debug("Gathering data from database...") + notify(f"Gathering data from {docs_queryset.count()} document(s)...") hasher = sha256() for doc in docs_queryset: y = -1 @@ -290,6 +297,7 @@ class DocumentClassifier: # Step 2: vectorize data logger.debug("Vectorizing data...") + notify("Vectorizing document content...") def content_generator() -> Iterator[str]: """ @@ -316,6 +324,7 @@ class DocumentClassifier: # Step 3: train the classifiers if num_tags > 0: logger.debug("Training tags classifier...") + notify(f"Training tags classifier ({num_tags} tag(s))...") if num_tags == 1: # Special case where only one tag has auto: @@ -339,6 +348,9 @@ class DocumentClassifier: if num_correspondents > 0: logger.debug("Training correspondent classifier...") + notify( + f"Training correspondent classifier ({num_correspondents} correspondent(s))...", + ) self.correspondent_classifier = MLPClassifier(tol=0.01) self.correspondent_classifier.fit(data_vectorized, labels_correspondent) else: @@ -349,6 +361,9 @@ class DocumentClassifier: if num_document_types > 0: logger.debug("Training document type classifier...") + notify( + f"Training document type classifier ({num_document_types} type(s))...", + ) self.document_type_classifier = MLPClassifier(tol=0.01) self.document_type_classifier.fit(data_vectorized, labels_document_type) else: @@ -361,6 +376,7 @@ class DocumentClassifier: logger.debug( "Training storage paths classifier...", ) + notify(f"Training storage path classifier ({num_storage_paths} path(s))...") self.storage_path_classifier = MLPClassifier(tol=0.01) self.storage_path_classifier.fit( data_vectorized, diff --git a/src/documents/management/commands/document_create_classifier.py b/src/documents/management/commands/document_create_classifier.py index f7903aac7..b662195a7 100644 --- a/src/documents/management/commands/document_create_classifier.py +++ b/src/documents/management/commands/document_create_classifier.py @@ -1,13 +1,32 @@ -from django.core.management.base import BaseCommand +from __future__ import annotations +import time + +from documents.management.commands.base import PaperlessCommand from documents.tasks import train_classifier -class Command(BaseCommand): +class Command(PaperlessCommand): help = ( "Trains the classifier on your data and saves the resulting models to a " "file. The document consumer will then automatically use this new model." ) + supports_progress_bar = False + supports_multiprocessing = False - def handle(self, *args, **options): - train_classifier(scheduled=False) + def handle(self, *args, **options) -> None: + start = time.monotonic() + + with ( + self.buffered_logging("paperless.tasks"), + self.buffered_logging("paperless.classifier"), + ): + train_classifier( + scheduled=False, + status_callback=lambda msg: self.console.print(f" {msg}"), + ) + + elapsed = time.monotonic() - start + self.console.print( + f"[green]✓[/green] Classifier training complete ({elapsed:.1f}s)", + ) diff --git a/src/documents/tasks.py b/src/documents/tasks.py index 86b6b2716..378695731 100644 --- a/src/documents/tasks.py +++ b/src/documents/tasks.py @@ -100,7 +100,11 @@ def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None: @shared_task -def train_classifier(*, scheduled=True) -> None: +def train_classifier( + *, + scheduled=True, + status_callback: Callable[[str], None] | None = None, +) -> None: task = PaperlessTask.objects.create( type=PaperlessTask.TaskType.SCHEDULED_TASK if scheduled @@ -136,7 +140,7 @@ def train_classifier(*, scheduled=True) -> None: classifier = DocumentClassifier() try: - if classifier.train(): + if classifier.train(status_callback=status_callback): logger.info( f"Saving updated classifier model to {settings.MODEL_FILE}...", ) diff --git a/src/documents/tests/test_management.py b/src/documents/tests/test_management.py index 03959a85b..2a62173b1 100644 --- a/src/documents/tests/test_management.py +++ b/src/documents/tests/test_management.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import filecmp import shutil from io import StringIO from pathlib import Path +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -11,6 +14,9 @@ from django.core.management import call_command from django.test import TestCase from django.test import override_settings +if TYPE_CHECKING: + from pytest_mock import MockerFixture + from documents.file_handling import generate_filename from documents.models import Document from documents.tasks import update_document_content_maybe_archive_file @@ -135,14 +141,32 @@ class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase): @pytest.mark.management -class TestCreateClassifier(TestCase): - @mock.patch( - "documents.management.commands.document_create_classifier.train_classifier", - ) - def test_create_classifier(self, m) -> None: - call_command("document_create_classifier") +class TestCreateClassifier: + def test_create_classifier(self, mocker: MockerFixture) -> None: + m = mocker.patch( + "documents.management.commands.document_create_classifier.train_classifier", + ) - m.assert_called_once() + call_command("document_create_classifier", "--skip-checks") + + m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY) + assert callable(m.call_args.kwargs["status_callback"]) + + def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None: + """Callback passed to train_classifier writes each phase message to the console.""" + m = mocker.patch( + "documents.management.commands.document_create_classifier.train_classifier", + ) + + def invoke_callback(**kwargs): + kwargs["status_callback"]("Vectorizing document content...") + + m.side_effect = invoke_callback + + stdout = StringIO() + call_command("document_create_classifier", "--skip-checks", stdout=stdout) + + assert "Vectorizing document content..." in stdout.getvalue() @pytest.mark.management