mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-10 01:58:53 +00:00
perf: fast skip in classifier train() via auto-label-set digest
Add a fast-skip gate at the top of DocumentClassifier.train() that returns False after at most 5 DB queries (1x MAX(modified) on non-inbox docs + 4x MATCH_AUTO pk lists), avoiding the O(N) per-document label scan on no-op calls. Previously the classifier always iterated every document to build the label hash before it could decide to skip — ~8 s at 5k docs, scaling linearly. Changes: - FORMAT_VERSION 10 -> 11 (new field in pickle) - New field `last_auto_label_set_digest` stored after each full train - New static method `_compute_auto_label_set_digest()` (4 queries) - Fast-skip block before the document queryset; mirrors the inbox-tag exclusion used by the training queryset for an apples-to-apples MAX(modified) comparison - Remove old embedded skip check (after the full label scan) which had a correctness gap: MATCH_AUTO labels with no document assignments were invisible to the per-doc hash, so a new unassigned AUTO label would not trigger a retrain Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.core.cache import caches
|
||||
from django.db.models import Max
|
||||
|
||||
from documents.caching import CACHE_5_MINUTES
|
||||
from documents.caching import CACHE_50_MINUTES
|
||||
@@ -99,7 +100,8 @@ class DocumentClassifier:
|
||||
# v8 - Added storage path classifier
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
# v10 - HMAC-signed model file
|
||||
FORMAT_VERSION = 10
|
||||
# v11 - Added auto-label-set digest for fast skip without full document scan
|
||||
FORMAT_VERSION = 11
|
||||
|
||||
HMAC_SIZE = 32 # SHA-256 digest length
|
||||
|
||||
@@ -108,6 +110,8 @@ class DocumentClassifier:
|
||||
self.last_doc_change_time: datetime | None = None
|
||||
# Hash of primary keys of AUTO matching values last used in training
|
||||
self.last_auto_type_hash: bytes | None = None
|
||||
# Digest of the set of all MATCH_AUTO label PKs (fast-skip guard)
|
||||
self.last_auto_label_set_digest: bytes | None = None
|
||||
|
||||
self.data_vectorizer = None
|
||||
self.data_vectorizer_hash = None
|
||||
@@ -140,6 +144,29 @@ class DocumentClassifier:
|
||||
sha256,
|
||||
).digest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_auto_label_set_digest() -> bytes:
|
||||
"""
|
||||
Return a SHA-256 digest of all MATCH_AUTO label PKs across the four
|
||||
label types. Four cheap indexed queries; stable for any fixed set of
|
||||
AUTO labels regardless of document assignments.
|
||||
"""
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
hasher = sha256()
|
||||
for model in (Correspondent, DocumentType, Tag, StoragePath):
|
||||
pks = sorted(
|
||||
model.objects.filter(
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
).values_list("pk", flat=True),
|
||||
)
|
||||
for pk in pks:
|
||||
hasher.update(pk.to_bytes(4, "little", signed=False))
|
||||
return hasher.digest()
|
||||
|
||||
def load(self) -> None:
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
@@ -161,6 +188,7 @@ class DocumentClassifier:
|
||||
schema_version,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.last_auto_label_set_digest,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
@@ -202,6 +230,7 @@ class DocumentClassifier:
|
||||
self.FORMAT_VERSION,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.last_auto_label_set_digest,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
@@ -224,6 +253,39 @@ class DocumentClassifier:
|
||||
) -> bool:
|
||||
notify = status_callback if status_callback is not None else lambda _: None
|
||||
|
||||
# Fast skip: avoid the expensive per-document label scan when nothing
|
||||
# has changed. Requires a prior training run to have populated both
|
||||
# last_doc_change_time and last_auto_label_set_digest.
|
||||
if (
|
||||
self.last_doc_change_time is not None
|
||||
and self.last_auto_label_set_digest is not None
|
||||
):
|
||||
latest_mod = Document.objects.exclude(
|
||||
tags__is_inbox_tag=True,
|
||||
).aggregate(Max("modified"))["modified__max"]
|
||||
if latest_mod is not None and latest_mod <= self.last_doc_change_time:
|
||||
current_digest = self._compute_auto_label_set_digest()
|
||||
if current_digest == self.last_auto_label_set_digest:
|
||||
logger.info("No updates since last training")
|
||||
cache.set(
|
||||
CLASSIFIER_MODIFIED_KEY,
|
||||
self.last_doc_change_time,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(
|
||||
CLASSIFIER_HASH_KEY,
|
||||
self.last_auto_type_hash.hex()
|
||||
if self.last_auto_type_hash
|
||||
else "",
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(
|
||||
CLASSIFIER_VERSION_KEY,
|
||||
self.FORMAT_VERSION,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
return False
|
||||
|
||||
# Get non-inbox documents
|
||||
docs_queryset = (
|
||||
Document.objects.exclude(
|
||||
@@ -282,25 +344,7 @@ class DocumentClassifier:
|
||||
|
||||
num_tags = len(labels_tags_unique)
|
||||
|
||||
# Check if retraining is actually required.
|
||||
# A document has been updated since the classifier was trained
|
||||
# New auto tags, types, correspondent, storage paths exist
|
||||
latest_doc_change = docs_queryset.latest("modified").modified
|
||||
if (
|
||||
self.last_doc_change_time is not None
|
||||
and self.last_doc_change_time >= latest_doc_change
|
||||
) and self.last_auto_type_hash == hasher.digest():
|
||||
logger.info("No updates since last training")
|
||||
# Set the classifier information into the cache
|
||||
# Caching for 50 minutes, so slightly less than the normal retrain time
|
||||
cache.set(
|
||||
CLASSIFIER_MODIFIED_KEY,
|
||||
self.last_doc_change_time,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
|
||||
cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)
|
||||
return False
|
||||
|
||||
# subtract 1 since -1 (null) is also part of the classes.
|
||||
|
||||
@@ -416,6 +460,7 @@ class DocumentClassifier:
|
||||
|
||||
self.last_doc_change_time = latest_doc_change
|
||||
self.last_auto_type_hash = hasher.digest()
|
||||
self.last_auto_label_set_digest = self._compute_auto_label_set_digest()
|
||||
self._update_data_vectorizer_hash()
|
||||
|
||||
# Set the classifier information into the cache
|
||||
|
||||
325
src/documents/tests/test_classifier_train_skip.py
Normal file
325
src/documents/tests/test_classifier_train_skip.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
Phase 1 — fast-skip optimisation in DocumentClassifier.train()
|
||||
|
||||
The goal: when nothing has changed since the last training run, train() should
|
||||
return False after at most 5 DB queries (1x MAX(modified) + 4x MATCH_AUTO pk
|
||||
lists), not after a full per-document label scan.
|
||||
|
||||
Correctness invariant: the skip must NOT fire when the set of AUTO-matching
|
||||
labels has changed, even if no Document.modified timestamp has advanced (e.g.
|
||||
a Tag's matching_algorithm was flipped to MATCH_AUTO after the last train).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from django.db import connection
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import MatchingModel
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
def classifier_settings(settings, tmp_path: Path):
    """Redirect MODEL_FILE into a per-test temp dir so tests stay hermetic."""
    model_path = tmp_path / "model.pickle"
    settings.MODEL_FILE = model_path
    return settings
|
||||
|
||||
|
||||
@pytest.fixture()
def classifier(classifier_settings):
    """A brand-new, never-trained DocumentClassifier under test settings."""
    return DocumentClassifier()
|
||||
|
||||
|
||||
@pytest.fixture()
def label_corpus(classifier_settings):
    """
    Build the smallest corpus that trains successfully.

    Each of the four label types gets one MATCH_AUTO instance (learned by
    the classifier) and one MATCH_NONE control. Three documents exist:
    doc_a / doc_b carry every AUTO label; doc_c carries only controls but
    has real content so it still vectorizes.

    Returns a dict of every created object (keys: c_auto, c_none, dt_auto,
    dt_none, t_auto, t_none, sp_auto, sp_none, doc_a, doc_b, doc_c) so
    individual tests can mutate them directly.
    """
    auto_correspondent = Correspondent.objects.create(
        name="Auto Corp",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    control_correspondent = Correspondent.objects.create(
        name="Manual Corp",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    auto_doctype = DocumentType.objects.create(
        name="Invoice",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    control_doctype = DocumentType.objects.create(
        name="Other",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    auto_tag = Tag.objects.create(
        name="finance",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    control_tag = Tag.objects.create(
        name="misc",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    auto_path = StoragePath.objects.create(
        name="Finance Path",
        path="finance/{correspondent}",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    control_path = StoragePath.objects.create(
        name="Other Path",
        path="other/{correspondent}",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    doc_a = Document.objects.create(
        title="Invoice from Auto Corp Jan",
        content="quarterly invoice payment tax financial statement revenue",
        correspondent=auto_correspondent,
        document_type=auto_doctype,
        storage_path=auto_path,
        checksum="aaa",
        mime_type="application/pdf",
        filename="invoice_a.pdf",
    )
    doc_a.tags.set([auto_tag])

    doc_b = Document.objects.create(
        title="Invoice from Auto Corp Feb",
        content="monthly invoice billing statement account balance due",
        correspondent=auto_correspondent,
        document_type=auto_doctype,
        storage_path=auto_path,
        checksum="bbb",
        mime_type="application/pdf",
        filename="invoice_b.pdf",
    )
    doc_b.tags.set([auto_tag])

    # Control document: carries only MATCH_NONE labels, but enough content
    # for the vectorizer to work with.
    doc_c = Document.objects.create(
        title="Miscellaneous Notes",
        content="meeting notes agenda discussion summary action items follow",
        correspondent=control_correspondent,
        document_type=control_doctype,
        checksum="ccc",
        mime_type="application/pdf",
        filename="notes_c.pdf",
    )
    doc_c.tags.set([control_tag])

    return {
        "c_auto": auto_correspondent,
        "c_none": control_correspondent,
        "dt_auto": auto_doctype,
        "dt_none": control_doctype,
        "t_auto": auto_tag,
        "t_none": control_tag,
        "sp_auto": auto_path,
        "sp_none": control_path,
        "doc_a": doc_a,
        "doc_b": doc_b,
        "doc_c": doc_c,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy-path skip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db()
class TestFastSkipFires:
    """No-op path: with an unchanged corpus, a second train() must skip."""

    def test_first_train_returns_true(self, classifier, label_corpus):
        """A never-trained classifier always does real work -> True."""
        assert classifier.train() is True

    def test_second_train_returns_false(self, classifier, label_corpus):
        """With nothing changed, the follow-up train() is a no-op -> False."""
        classifier.train()
        assert classifier.train() is False

    def test_fast_skip_runs_minimal_queries(self, classifier, label_corpus):
        """
        The skip decision must cost at most 5 DB queries:
        1x aggregate MAX(modified) over non-inbox documents, plus
        4x MATCH_AUTO pk lists (one per label type).

        Before the fast-skip gate existed, train() hashed labels
        document-by-document before it could decide to skip — O(N) work.
        """
        classifier.train()
        with CaptureQueriesContext(connection) as captured:
            outcome = classifier.train()
        assert outcome is False
        query_count = len(captured.captured_queries)
        assert query_count <= 5, (
            f"Fast skip used {query_count} queries; expected ≤5.\n"
            + "\n".join(q["sql"] for q in captured.captured_queries)
        )

    def test_fast_skip_refreshes_cache_keys(self, classifier, label_corpus):
        """
        A skipped train() must still re-populate the classifier cache keys,
        otherwise the scheduler cannot tell that the model is current.
        """
        from django.core.cache import cache

        from documents.caching import CLASSIFIER_HASH_KEY
        from documents.caching import CLASSIFIER_MODIFIED_KEY
        from documents.caching import CLASSIFIER_VERSION_KEY

        classifier.train()
        # Drop all three keys so only the skip path can restore them.
        for key in (
            CLASSIFIER_MODIFIED_KEY,
            CLASSIFIER_HASH_KEY,
            CLASSIFIER_VERSION_KEY,
        ):
            cache.delete(key)

        assert classifier.train() is False
        assert cache.get(CLASSIFIER_MODIFIED_KEY) is not None
        assert cache.get(CLASSIFIER_HASH_KEY) is not None
        assert cache.get(CLASSIFIER_VERSION_KEY) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Correctness tests — skip must NOT fire when the world has changed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db()
class TestFastSkipDoesNotFire:
    """Whenever documents or AUTO labels change, the guard must retrain."""

    def test_document_content_modification_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """Editing content bumps Document.modified, defeating the skip."""
        classifier.train()
        changed = label_corpus["doc_a"]
        changed.content = "completely different words here now nothing same"
        changed.save()
        assert classifier.train() is True

    def test_document_label_reassignment_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Moving a document to another AUTO correspondent touches
        Document.modified and must force a retrain.
        """
        replacement = Correspondent.objects.create(
            name="Second Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        classifier.train()
        moved = label_corpus["doc_a"]
        moved.correspondent = replacement
        moved.save()
        assert classifier.train() is True

    def test_matching_algorithm_change_on_assigned_tag_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Flipping an already-assigned tag to MATCH_AUTO must retrain even
        though no Document.modified timestamp advances.

        This is the central correctness case for the auto-label-set digest:
        the tag already sits on documents, so once it becomes MATCH_AUTO it
        is learnable and the model is stale.
        """
        classifier.train()
        # The fixture puts t_none on doc_c; flip it without saving any
        # document, so only the digest can detect the grown AUTO label set.
        flipped = label_corpus["t_none"]
        flipped.matching_algorithm = MatchingModel.MATCH_AUTO
        flipped.save(update_fields=["matching_algorithm"])
        assert classifier.train() is True

    def test_new_auto_correspondent_triggers_retrain(self, classifier, label_corpus):
        """A brand-new, unassigned MATCH_AUTO label must defeat the skip."""
        classifier.train()
        Correspondent.objects.create(
            name="New Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        assert classifier.train() is True

    def test_removing_auto_label_triggers_retrain(self, classifier, label_corpus):
        """Deleting an AUTO label shrinks the digest input -> retrain."""
        classifier.train()
        label_corpus["c_auto"].delete()
        assert classifier.train() is True

    def test_fresh_classifier_always_trains(self, classifier, label_corpus):
        """With last_doc_change_time unset, train() must never skip."""
        assert classifier.last_doc_change_time is None
        assert classifier.train() is True

    def test_no_documents_raises_value_error(self, classifier, classifier_settings):
        """An empty database yields no training data -> ValueError."""
        with pytest.raises(ValueError, match="No training data"):
            classifier.train()
|
||||
Reference in New Issue
Block a user