Compare commits

...

1 Commit

Author SHA1 Message Date
Trenton H
c7d4fb1f8b Adds progress information to the classifier training for a better UX 2026-03-12 13:09:21 -07:00
4 changed files with 74 additions and 14 deletions

View File

@@ -9,6 +9,7 @@ from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Iterator from collections.abc import Iterator
from datetime import datetime from datetime import datetime
@@ -191,7 +192,12 @@ class DocumentClassifier:
target_file_temp.rename(target_file) target_file_temp.rename(target_file)
def train(self) -> bool: def train(
self,
status_callback: Callable[[str], None] | None = None,
) -> bool:
notify = status_callback if status_callback is not None else lambda _: None
# Get non-inbox documents # Get non-inbox documents
docs_queryset = ( docs_queryset = (
Document.objects.exclude( Document.objects.exclude(
@@ -213,6 +219,7 @@ class DocumentClassifier:
# Step 1: Extract and preprocess training data from the database. # Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...") logger.debug("Gathering data from database...")
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
hasher = sha256() hasher = sha256()
for doc in docs_queryset: for doc in docs_queryset:
y = -1 y = -1
@@ -290,6 +297,7 @@ class DocumentClassifier:
# Step 2: vectorize data # Step 2: vectorize data
logger.debug("Vectorizing data...") logger.debug("Vectorizing data...")
notify("Vectorizing document content...")
def content_generator() -> Iterator[str]: def content_generator() -> Iterator[str]:
""" """
@@ -316,6 +324,7 @@ class DocumentClassifier:
# Step 3: train the classifiers # Step 3: train the classifiers
if num_tags > 0: if num_tags > 0:
logger.debug("Training tags classifier...") logger.debug("Training tags classifier...")
notify(f"Training tags classifier ({num_tags} tag(s))...")
if num_tags == 1: if num_tags == 1:
# Special case where only one tag has auto: # Special case where only one tag has auto:
@@ -339,6 +348,9 @@ class DocumentClassifier:
if num_correspondents > 0: if num_correspondents > 0:
logger.debug("Training correspondent classifier...") logger.debug("Training correspondent classifier...")
notify(
f"Training correspondent classifier ({num_correspondents} correspondent(s))...",
)
self.correspondent_classifier = MLPClassifier(tol=0.01) self.correspondent_classifier = MLPClassifier(tol=0.01)
self.correspondent_classifier.fit(data_vectorized, labels_correspondent) self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
else: else:
@@ -349,6 +361,9 @@ class DocumentClassifier:
if num_document_types > 0: if num_document_types > 0:
logger.debug("Training document type classifier...") logger.debug("Training document type classifier...")
notify(
f"Training document type classifier ({num_document_types} type(s))...",
)
self.document_type_classifier = MLPClassifier(tol=0.01) self.document_type_classifier = MLPClassifier(tol=0.01)
self.document_type_classifier.fit(data_vectorized, labels_document_type) self.document_type_classifier.fit(data_vectorized, labels_document_type)
else: else:
@@ -361,6 +376,7 @@ class DocumentClassifier:
logger.debug( logger.debug(
"Training storage paths classifier...", "Training storage paths classifier...",
) )
notify(f"Training storage path classifier ({num_storage_paths} path(s))...")
self.storage_path_classifier = MLPClassifier(tol=0.01) self.storage_path_classifier = MLPClassifier(tol=0.01)
self.storage_path_classifier.fit( self.storage_path_classifier.fit(
data_vectorized, data_vectorized,

View File

@@ -1,13 +1,29 @@
from django.core.management.base import BaseCommand from __future__ import annotations
import time
from documents.management.commands.base import PaperlessCommand
from documents.tasks import train_classifier from documents.tasks import train_classifier
class Command(BaseCommand): class Command(PaperlessCommand):
help = ( help = (
"Trains the classifier on your data and saves the resulting models to a " "Trains the classifier on your data and saves the resulting models to a "
"file. The document consumer will then automatically use this new model." "file. The document consumer will then automatically use this new model."
) )
supports_progress_bar = False
supports_multiprocessing = False
def handle(self, *args, **options): def handle(self, *args, **options) -> None:
train_classifier(scheduled=False) start = time.monotonic()
with self.buffered_logging("paperless.tasks"):
train_classifier(
scheduled=False,
status_callback=lambda msg: self.console.print(f" {msg}"),
)
elapsed = time.monotonic() - start
self.console.print(
f"[green]✓[/green] Classifier training complete ({elapsed:.1f}s)",
)

View File

@@ -100,7 +100,11 @@ def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None:
@shared_task @shared_task
def train_classifier(*, scheduled=True) -> None: def train_classifier(
*,
scheduled=True,
status_callback: Callable[[str], None] | None = None,
) -> None:
task = PaperlessTask.objects.create( task = PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK type=PaperlessTask.TaskType.SCHEDULED_TASK
if scheduled if scheduled
@@ -136,7 +140,7 @@ def train_classifier(*, scheduled=True) -> None:
classifier = DocumentClassifier() classifier = DocumentClassifier()
try: try:
if classifier.train(): if classifier.train(status_callback=status_callback):
logger.info( logger.info(
f"Saving updated classifier model to {settings.MODEL_FILE}...", f"Saving updated classifier model to {settings.MODEL_FILE}...",
) )

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import filecmp import filecmp
import shutil import shutil
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock from unittest import mock
import pytest import pytest
@@ -11,6 +14,9 @@ from django.core.management import call_command
from django.test import TestCase from django.test import TestCase
from django.test import override_settings from django.test import override_settings
if TYPE_CHECKING:
from pytest_mock import MockerFixture
from documents.file_handling import generate_filename from documents.file_handling import generate_filename
from documents.models import Document from documents.models import Document
from documents.tasks import update_document_content_maybe_archive_file from documents.tasks import update_document_content_maybe_archive_file
@@ -135,14 +141,32 @@ class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
@pytest.mark.management @pytest.mark.management
class TestCreateClassifier(TestCase): class TestCreateClassifier:
@mock.patch( def test_create_classifier(self, mocker: MockerFixture) -> None:
"documents.management.commands.document_create_classifier.train_classifier", m = mocker.patch(
) "documents.management.commands.document_create_classifier.train_classifier",
def test_create_classifier(self, m) -> None: )
call_command("document_create_classifier")
m.assert_called_once() call_command("document_create_classifier", "--skip-checks")
m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY)
assert callable(m.call_args.kwargs["status_callback"])
def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None:
"""Callback passed to train_classifier writes each phase message to the console."""
m = mocker.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
def invoke_callback(**kwargs):
kwargs["status_callback"]("Vectorizing document content...")
m.side_effect = invoke_callback
stdout = StringIO()
call_command("document_create_classifier", "--skip-checks", stdout=stdout)
assert "Vectorizing document content..." in stdout.getvalue()
@pytest.mark.management @pytest.mark.management