Feature: Add progress information to the classifier training for a better ux (#12331)

This commit is contained in:
Trenton H
2026-03-14 12:53:52 -07:00
committed by GitHub
parent 01abacab52
commit 9d69705e26
4 changed files with 77 additions and 14 deletions
+17 -1
View File
@@ -9,6 +9,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -191,7 +192,12 @@ class DocumentClassifier:
target_file_temp.rename(target_file)
def train(self) -> bool:
def train(
self,
status_callback: Callable[[str], None] | None = None,
) -> bool:
notify = status_callback if status_callback is not None else lambda _: None
# Get non-inbox documents
docs_queryset = (
Document.objects.exclude(
@@ -213,6 +219,7 @@ class DocumentClassifier:
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
hasher = sha256()
for doc in docs_queryset:
y = -1
@@ -290,6 +297,7 @@ class DocumentClassifier:
# Step 2: vectorize data
logger.debug("Vectorizing data...")
notify("Vectorizing document content...")
def content_generator() -> Iterator[str]:
"""
@@ -316,6 +324,7 @@ class DocumentClassifier:
# Step 3: train the classifiers
if num_tags > 0:
logger.debug("Training tags classifier...")
notify(f"Training tags classifier ({num_tags} tag(s))...")
if num_tags == 1:
# Special case where only one tag has auto:
@@ -339,6 +348,9 @@ class DocumentClassifier:
if num_correspondents > 0:
logger.debug("Training correspondent classifier...")
notify(
f"Training correspondent classifier ({num_correspondents} correspondent(s))...",
)
self.correspondent_classifier = MLPClassifier(tol=0.01)
self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
else:
@@ -349,6 +361,9 @@ class DocumentClassifier:
if num_document_types > 0:
logger.debug("Training document type classifier...")
notify(
f"Training document type classifier ({num_document_types} type(s))...",
)
self.document_type_classifier = MLPClassifier(tol=0.01)
self.document_type_classifier.fit(data_vectorized, labels_document_type)
else:
@@ -361,6 +376,7 @@ class DocumentClassifier:
logger.debug(
"Training storage paths classifier...",
)
notify(f"Training storage path classifier ({num_storage_paths} path(s))...")
self.storage_path_classifier = MLPClassifier(tol=0.01)
self.storage_path_classifier.fit(
data_vectorized,