mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-13 12:41:23 +00:00
Compare commits
1 Commits
dependabot
...
feature-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7d4fb1f8b |
@@ -9,6 +9,7 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -191,7 +192,12 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
target_file_temp.rename(target_file)
|
target_file_temp.rename(target_file)
|
||||||
|
|
||||||
def train(self) -> bool:
|
def train(
|
||||||
|
self,
|
||||||
|
status_callback: Callable[[str], None] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
notify = status_callback if status_callback is not None else lambda _: None
|
||||||
|
|
||||||
# Get non-inbox documents
|
# Get non-inbox documents
|
||||||
docs_queryset = (
|
docs_queryset = (
|
||||||
Document.objects.exclude(
|
Document.objects.exclude(
|
||||||
@@ -213,6 +219,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
# Step 1: Extract and preprocess training data from the database.
|
# Step 1: Extract and preprocess training data from the database.
|
||||||
logger.debug("Gathering data from database...")
|
logger.debug("Gathering data from database...")
|
||||||
|
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
|
||||||
hasher = sha256()
|
hasher = sha256()
|
||||||
for doc in docs_queryset:
|
for doc in docs_queryset:
|
||||||
y = -1
|
y = -1
|
||||||
@@ -290,6 +297,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
# Step 2: vectorize data
|
# Step 2: vectorize data
|
||||||
logger.debug("Vectorizing data...")
|
logger.debug("Vectorizing data...")
|
||||||
|
notify("Vectorizing document content...")
|
||||||
|
|
||||||
def content_generator() -> Iterator[str]:
|
def content_generator() -> Iterator[str]:
|
||||||
"""
|
"""
|
||||||
@@ -316,6 +324,7 @@ class DocumentClassifier:
|
|||||||
# Step 3: train the classifiers
|
# Step 3: train the classifiers
|
||||||
if num_tags > 0:
|
if num_tags > 0:
|
||||||
logger.debug("Training tags classifier...")
|
logger.debug("Training tags classifier...")
|
||||||
|
notify(f"Training tags classifier ({num_tags} tag(s))...")
|
||||||
|
|
||||||
if num_tags == 1:
|
if num_tags == 1:
|
||||||
# Special case where only one tag has auto:
|
# Special case where only one tag has auto:
|
||||||
@@ -339,6 +348,9 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
if num_correspondents > 0:
|
if num_correspondents > 0:
|
||||||
logger.debug("Training correspondent classifier...")
|
logger.debug("Training correspondent classifier...")
|
||||||
|
notify(
|
||||||
|
f"Training correspondent classifier ({num_correspondents} correspondent(s))...",
|
||||||
|
)
|
||||||
self.correspondent_classifier = MLPClassifier(tol=0.01)
|
self.correspondent_classifier = MLPClassifier(tol=0.01)
|
||||||
self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
|
self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
|
||||||
else:
|
else:
|
||||||
@@ -349,6 +361,9 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
if num_document_types > 0:
|
if num_document_types > 0:
|
||||||
logger.debug("Training document type classifier...")
|
logger.debug("Training document type classifier...")
|
||||||
|
notify(
|
||||||
|
f"Training document type classifier ({num_document_types} type(s))...",
|
||||||
|
)
|
||||||
self.document_type_classifier = MLPClassifier(tol=0.01)
|
self.document_type_classifier = MLPClassifier(tol=0.01)
|
||||||
self.document_type_classifier.fit(data_vectorized, labels_document_type)
|
self.document_type_classifier.fit(data_vectorized, labels_document_type)
|
||||||
else:
|
else:
|
||||||
@@ -361,6 +376,7 @@ class DocumentClassifier:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"Training storage paths classifier...",
|
"Training storage paths classifier...",
|
||||||
)
|
)
|
||||||
|
notify(f"Training storage path classifier ({num_storage_paths} path(s))...")
|
||||||
self.storage_path_classifier = MLPClassifier(tol=0.01)
|
self.storage_path_classifier = MLPClassifier(tol=0.01)
|
||||||
self.storage_path_classifier.fit(
|
self.storage_path_classifier.fit(
|
||||||
data_vectorized,
|
data_vectorized,
|
||||||
|
|||||||
@@ -1,13 +1,29 @@
|
|||||||
from django.core.management.base import BaseCommand
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from documents.management.commands.base import PaperlessCommand
|
||||||
from documents.tasks import train_classifier
|
from documents.tasks import train_classifier
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(PaperlessCommand):
|
||||||
help = (
|
help = (
|
||||||
"Trains the classifier on your data and saves the resulting models to a "
|
"Trains the classifier on your data and saves the resulting models to a "
|
||||||
"file. The document consumer will then automatically use this new model."
|
"file. The document consumer will then automatically use this new model."
|
||||||
)
|
)
|
||||||
|
supports_progress_bar = False
|
||||||
|
supports_multiprocessing = False
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options) -> None:
|
||||||
train_classifier(scheduled=False)
|
start = time.monotonic()
|
||||||
|
|
||||||
|
with self.buffered_logging("paperless.tasks"):
|
||||||
|
train_classifier(
|
||||||
|
scheduled=False,
|
||||||
|
status_callback=lambda msg: self.console.print(f" {msg}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
self.console.print(
|
||||||
|
f"[green]✓[/green] Classifier training complete ({elapsed:.1f}s)",
|
||||||
|
)
|
||||||
|
|||||||
@@ -100,7 +100,11 @@ def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def train_classifier(*, scheduled=True) -> None:
|
def train_classifier(
|
||||||
|
*,
|
||||||
|
scheduled=True,
|
||||||
|
status_callback: Callable[[str], None] | None = None,
|
||||||
|
) -> None:
|
||||||
task = PaperlessTask.objects.create(
|
task = PaperlessTask.objects.create(
|
||||||
type=PaperlessTask.TaskType.SCHEDULED_TASK
|
type=PaperlessTask.TaskType.SCHEDULED_TASK
|
||||||
if scheduled
|
if scheduled
|
||||||
@@ -136,7 +140,7 @@ def train_classifier(*, scheduled=True) -> None:
|
|||||||
classifier = DocumentClassifier()
|
classifier = DocumentClassifier()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if classifier.train():
|
if classifier.train(status_callback=status_callback):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Saving updated classifier model to {settings.MODEL_FILE}...",
|
f"Saving updated classifier model to {settings.MODEL_FILE}...",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import filecmp
|
import filecmp
|
||||||
import shutil
|
import shutil
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,6 +14,9 @@ from django.core.management import call_command
|
|||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from documents.file_handling import generate_filename
|
from documents.file_handling import generate_filename
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.tasks import update_document_content_maybe_archive_file
|
from documents.tasks import update_document_content_maybe_archive_file
|
||||||
@@ -135,14 +141,32 @@ class TestRenamer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.management
|
@pytest.mark.management
|
||||||
class TestCreateClassifier(TestCase):
|
class TestCreateClassifier:
|
||||||
@mock.patch(
|
def test_create_classifier(self, mocker: MockerFixture) -> None:
|
||||||
"documents.management.commands.document_create_classifier.train_classifier",
|
m = mocker.patch(
|
||||||
)
|
"documents.management.commands.document_create_classifier.train_classifier",
|
||||||
def test_create_classifier(self, m) -> None:
|
)
|
||||||
call_command("document_create_classifier")
|
|
||||||
|
|
||||||
m.assert_called_once()
|
call_command("document_create_classifier", "--skip-checks")
|
||||||
|
|
||||||
|
m.assert_called_once_with(scheduled=False, status_callback=mocker.ANY)
|
||||||
|
assert callable(m.call_args.kwargs["status_callback"])
|
||||||
|
|
||||||
|
def test_create_classifier_callback_output(self, mocker: MockerFixture) -> None:
|
||||||
|
"""Callback passed to train_classifier writes each phase message to the console."""
|
||||||
|
m = mocker.patch(
|
||||||
|
"documents.management.commands.document_create_classifier.train_classifier",
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke_callback(**kwargs):
|
||||||
|
kwargs["status_callback"]("Vectorizing document content...")
|
||||||
|
|
||||||
|
m.side_effect = invoke_callback
|
||||||
|
|
||||||
|
stdout = StringIO()
|
||||||
|
call_command("document_create_classifier", "--skip-checks", stdout=stdout)
|
||||||
|
|
||||||
|
assert "Vectorizing document content..." in stdout.getvalue()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.management
|
@pytest.mark.management
|
||||||
|
|||||||
Reference in New Issue
Block a user