Compare commits

..

3 Commits

5 changed files with 80 additions and 17 deletions

View File

@@ -9,6 +9,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -191,7 +192,12 @@ class DocumentClassifier:
target_file_temp.rename(target_file)
def train(self) -> bool:
def train(
self,
status_callback: Callable[[str], None] | None = None,
) -> bool:
notify = status_callback if status_callback is not None else lambda _: None
# Get non-inbox documents
docs_queryset = (
Document.objects.exclude(
@@ -213,6 +219,7 @@ class DocumentClassifier:
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
hasher = sha256()
for doc in docs_queryset:
y = -1
@@ -290,6 +297,7 @@ class DocumentClassifier:
# Step 2: vectorize data
logger.debug("Vectorizing data...")
notify("Vectorizing document content...")
def content_generator() -> Iterator[str]:
"""
@@ -316,6 +324,7 @@ class DocumentClassifier:
# Step 3: train the classifiers
if num_tags > 0:
logger.debug("Training tags classifier...")
notify(f"Training tags classifier ({num_tags} tag(s))...")
if num_tags == 1:
# Special case where only one tag has auto:
@@ -339,6 +348,9 @@ class DocumentClassifier:
if num_correspondents > 0:
logger.debug("Training correspondent classifier...")
notify(
f"Training correspondent classifier ({num_correspondents} correspondent(s))...",
)
self.correspondent_classifier = MLPClassifier(tol=0.01)
self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
else:
@@ -349,6 +361,9 @@ class DocumentClassifier:
if num_document_types > 0:
logger.debug("Training document type classifier...")
notify(
f"Training document type classifier ({num_document_types} type(s))...",
)
self.document_type_classifier = MLPClassifier(tol=0.01)
self.document_type_classifier.fit(data_vectorized, labels_document_type)
else:
@@ -361,6 +376,7 @@ class DocumentClassifier:
logger.debug(
"Training storage paths classifier...",
)
notify(f"Training storage path classifier ({num_storage_paths} path(s))...")
self.storage_path_classifier = MLPClassifier(tol=0.01)
self.storage_path_classifier.fit(
data_vectorized,

View File

@@ -1,13 +1,32 @@
from django.core.management.base import BaseCommand
from __future__ import annotations
import time
from documents.management.commands.base import PaperlessCommand
from documents.tasks import train_classifier
class Command(BaseCommand):
class Command(PaperlessCommand):
help = (
"Trains the classifier on your data and saves the resulting models to a "
"file. The document consumer will then automatically use this new model."
)
supports_progress_bar = False
supports_multiprocessing = False
def handle(self, *args, **options):
train_classifier(scheduled=False)
def handle(self, *args, **options) -> None:
start = time.monotonic()
with (
self.buffered_logging("paperless.tasks"),
self.buffered_logging("paperless.classifier"),
):
train_classifier(
scheduled=False,
status_callback=lambda msg: self.console.print(f" {msg}"),
)
elapsed = time.monotonic() - start
self.console.print(
f"[green]✓[/green] Classifier training complete ({elapsed:.1f}s)",
)

View File

@@ -100,7 +100,11 @@ def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None:
@shared_task
def train_classifier(*, scheduled=True) -> None:
def train_classifier(
*,
scheduled=True,
status_callback: Callable[[str], None] | None = None,
) -> None:
task = PaperlessTask.objects.create(
type=PaperlessTask.TaskType.SCHEDULED_TASK
if scheduled
@@ -136,7 +140,7 @@ def train_classifier(*, scheduled=True) -> None:
classifier = DocumentClassifier()
try:
if classifier.train():
if classifier.train(status_callback=status_callback):
logger.info(
f"Saving updated classifier model to {settings.MODEL_FILE}...",
)

View File

@@ -1,7 +1,10 @@
from __future__ import annotations
import filecmp
import shutil
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
import pytest
@@ -11,6 +14,9 @@ from django.core.management import call_command
from django.test import TestCase
from django.test import override_settings
if TYPE_CHECKING:
from pytest_mock import MockerFixture
from documents.file_handling import generate_filename
from documents.models import Document
from documents.tasks import update_document_content_maybe_archive_file
@@ -135,14 +141,32 @@ class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
@pytest.mark.management
class TestCreateClassifier(TestCase):
@mock.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
def test_create_classifier(self, m) -> None:
call_command("document_create_classifier")
class TestCreateClassifier:
def test_create_classifier(self, mocker: MockerFixture) -> None:
m = mocker.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
m.assert_called_once()
call_command("document_create_classifier", "--skip-checks")
m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY)
assert callable(m.call_args.kwargs["status_callback"])
def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None:
"""Callback passed to train_classifier writes each phase message to the console."""
m = mocker.patch(
"documents.management.commands.document_create_classifier.train_classifier",
)
def invoke_callback(**kwargs):
kwargs["status_callback"]("Vectorizing document content...")
m.side_effect = invoke_callback
stdout = StringIO()
call_command("document_create_classifier", "--skip-checks", stdout=stdout)
assert "Vectorizing document content..." in stdout.getvalue()
@pytest.mark.management

6
uv.lock generated
View File

@@ -3683,11 +3683,11 @@ wheels = [
[[package]]
name = "pyjwt"
version = "2.12.0"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a8/10/e8192be5f38f3e8e7e046716de4cae33d56fd5ae08927a823bb916be36c1/pyjwt-2.12.0.tar.gz", hash = "sha256:2f62390b667cd8257de560b850bb5a883102a388829274147f1d724453f8fb02", size = 102511, upload-time = "2026-03-12T17:15:30.831Z" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/15/70/70f895f404d363d291dcf62c12c85fdd47619ad9674ac0f53364d035925a/pyjwt-2.12.0-py3-none-any.whl", hash = "sha256:9bb459d1bdd0387967d287f5656bf7ec2b9a26645d1961628cda1764e087fd6e", size = 29700, upload-time = "2026-03-12T17:15:29.257Z" },
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
]
[package.optional-dependencies]