From 68b866aeee1051d41223720fd2a00d239e68652c Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:30:16 -0700 Subject: [PATCH] perf: fast skip in classifier train() via auto-label-set digest MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a fast-skip gate at the top of DocumentClassifier.train() that returns False after at most 5 DB queries (1x MAX(modified) on non-inbox docs + 4x MATCH_AUTO pk lists), avoiding the O(N) per-document label scan on no-op calls. Previously the classifier always iterated every document to build the label hash before it could decide to skip — ~8 s at 5k docs, scaling linearly. Changes: - FORMAT_VERSION 10 -> 11 (new field in pickle) - New field `last_auto_label_set_digest` stored after each full train - New static method `_compute_auto_label_set_digest()` (4 queries) - Fast-skip block before the document queryset; mirrors the inbox-tag exclusion used by the training queryset for an apples-to-apples MAX(modified) comparison - Remove old embedded skip check (after the full label scan) which had a correctness gap: MATCH_AUTO labels with no document assignments were invisible to the per-doc hash, so a new unassigned AUTO label would not trigger a retrain Co-Authored-By: Claude Sonnet 4.6 --- src/documents/classifier.py | 83 ++++- .../tests/test_classifier_train_skip.py | 325 ++++++++++++++++++ 2 files changed, 389 insertions(+), 19 deletions(-) create mode 100644 src/documents/tests/test_classifier_train_skip.py diff --git a/src/documents/classifier.py b/src/documents/classifier.py index 519e1eac5..d285800ff 100644 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from django.conf import settings from django.core.cache import cache from django.core.cache import caches +from django.db.models import Max from documents.caching import CACHE_5_MINUTES from documents.caching import 
CACHE_50_MINUTES @@ -99,7 +100,8 @@ class DocumentClassifier: # v8 - Added storage path classifier # v9 - Changed from hashing to time/ids for re-train check # v10 - HMAC-signed model file - FORMAT_VERSION = 10 + # v11 - Added auto-label-set digest for fast skip without full document scan + FORMAT_VERSION = 11 HMAC_SIZE = 32 # SHA-256 digest length @@ -108,6 +110,8 @@ class DocumentClassifier: self.last_doc_change_time: datetime | None = None # Hash of primary keys of AUTO matching values last used in training self.last_auto_type_hash: bytes | None = None + # Digest of the set of all MATCH_AUTO label PKs (fast-skip guard) + self.last_auto_label_set_digest: bytes | None = None self.data_vectorizer = None self.data_vectorizer_hash = None @@ -140,6 +144,29 @@ class DocumentClassifier: sha256, ).digest() + @staticmethod + def _compute_auto_label_set_digest() -> bytes: + """ + Return a SHA-256 digest of all MATCH_AUTO label PKs across the four + label types. Four cheap indexed queries; stable for any fixed set of + AUTO labels regardless of document assignments. 
+ """ + from documents.models import Correspondent + from documents.models import DocumentType + from documents.models import StoragePath + from documents.models import Tag + + hasher = sha256() + for model in (Correspondent, DocumentType, Tag, StoragePath): + pks = sorted( + model.objects.filter( + matching_algorithm=MatchingModel.MATCH_AUTO, + ).values_list("pk", flat=True), + ) + for pk in pks: + hasher.update(pk.to_bytes(4, "little", signed=False)) + return hasher.digest() + def load(self) -> None: from sklearn.exceptions import InconsistentVersionWarning @@ -161,6 +188,7 @@ class DocumentClassifier: schema_version, self.last_doc_change_time, self.last_auto_type_hash, + self.last_auto_label_set_digest, self.data_vectorizer, self.tags_binarizer, self.tags_classifier, @@ -202,6 +230,7 @@ class DocumentClassifier: self.FORMAT_VERSION, self.last_doc_change_time, self.last_auto_type_hash, + self.last_auto_label_set_digest, self.data_vectorizer, self.tags_binarizer, self.tags_classifier, @@ -224,6 +253,39 @@ class DocumentClassifier: ) -> bool: notify = status_callback if status_callback is not None else lambda _: None + # Fast skip: avoid the expensive per-document label scan when nothing + # has changed. Requires a prior training run to have populated both + # last_doc_change_time and last_auto_label_set_digest. 
+ if ( + self.last_doc_change_time is not None + and self.last_auto_label_set_digest is not None + ): + latest_mod = Document.objects.exclude( + tags__is_inbox_tag=True, + ).aggregate(Max("modified"))["modified__max"] + if latest_mod is not None and latest_mod <= self.last_doc_change_time: + current_digest = self._compute_auto_label_set_digest() + if current_digest == self.last_auto_label_set_digest: + logger.info("No updates since last training") + cache.set( + CLASSIFIER_MODIFIED_KEY, + self.last_doc_change_time, + CACHE_50_MINUTES, + ) + cache.set( + CLASSIFIER_HASH_KEY, + self.last_auto_type_hash.hex() + if self.last_auto_type_hash + else "", + CACHE_50_MINUTES, + ) + cache.set( + CLASSIFIER_VERSION_KEY, + self.FORMAT_VERSION, + CACHE_50_MINUTES, + ) + return False + # Get non-inbox documents docs_queryset = ( Document.objects.exclude( @@ -282,25 +344,7 @@ class DocumentClassifier: num_tags = len(labels_tags_unique) - # Check if retraining is actually required. - # A document has been updated since the classifier was trained - # New auto tags, types, correspondent, storage paths exist latest_doc_change = docs_queryset.latest("modified").modified - if ( - self.last_doc_change_time is not None - and self.last_doc_change_time >= latest_doc_change - ) and self.last_auto_type_hash == hasher.digest(): - logger.info("No updates since last training") - # Set the classifier information into the cache - # Caching for 50 minutes, so slightly less than the normal retrain time - cache.set( - CLASSIFIER_MODIFIED_KEY, - self.last_doc_change_time, - CACHE_50_MINUTES, - ) - cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES) - cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES) - return False # subtract 1 since -1 (null) is also part of the classes. 
@@ -416,6 +460,7 @@ class DocumentClassifier: self.last_doc_change_time = latest_doc_change self.last_auto_type_hash = hasher.digest() + self.last_auto_label_set_digest = self._compute_auto_label_set_digest() self._update_data_vectorizer_hash() # Set the classifier information into the cache diff --git a/src/documents/tests/test_classifier_train_skip.py b/src/documents/tests/test_classifier_train_skip.py new file mode 100644 index 000000000..e75a76174 --- /dev/null +++ b/src/documents/tests/test_classifier_train_skip.py @@ -0,0 +1,325 @@ +""" +Phase 1 — fast-skip optimisation in DocumentClassifier.train() + +The goal: when nothing has changed since the last training run, train() should +return False after at most 5 DB queries (1x MAX(modified) + 4x MATCH_AUTO pk +lists), not after a full per-document label scan. + +Correctness invariant: the skip must NOT fire when the set of AUTO-matching +labels has changed, even if no Document.modified timestamp has advanced (e.g. +a Tag's matching_algorithm was flipped to MATCH_AUTO after the last train). 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from django.db import connection +from django.test.utils import CaptureQueriesContext + +from documents.classifier import DocumentClassifier +from documents.models import Correspondent +from documents.models import Document +from documents.models import DocumentType +from documents.models import MatchingModel +from documents.models import StoragePath +from documents.models import Tag + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def classifier_settings(settings, tmp_path: Path): + """Point MODEL_FILE at a temp directory so tests are hermetic.""" + settings.MODEL_FILE = tmp_path / "model.pickle" + return settings + + +@pytest.fixture() +def classifier(classifier_settings): + """Fresh DocumentClassifier instance with test settings active.""" + return DocumentClassifier() + + +@pytest.fixture() +def label_corpus(classifier_settings): + """ + Minimal label + document corpus that produces a trainable classifier. + + Creates + ------- + Correspondents + c_auto — MATCH_AUTO, assigned to two docs + c_none — MATCH_NONE (control) + DocumentTypes + dt_auto — MATCH_AUTO, assigned to two docs + dt_none — MATCH_NONE (control) + Tags + t_auto — MATCH_AUTO, applied to two docs + t_none — MATCH_NONE (control, applied to one doc but never learned) + StoragePaths + sp_auto — MATCH_AUTO, assigned to two docs + sp_none — MATCH_NONE (control) + + Documents + doc_a, doc_b — assigned AUTO labels above + doc_c — control doc (MATCH_NONE labels only) + + The fixture returns a dict with all created objects for direct mutation in + individual tests. 
+ """ + c_auto = Correspondent.objects.create( + name="Auto Corp", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + c_none = Correspondent.objects.create( + name="Manual Corp", + matching_algorithm=MatchingModel.MATCH_NONE, + ) + + dt_auto = DocumentType.objects.create( + name="Invoice", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + dt_none = DocumentType.objects.create( + name="Other", + matching_algorithm=MatchingModel.MATCH_NONE, + ) + + t_auto = Tag.objects.create( + name="finance", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + t_none = Tag.objects.create( + name="misc", + matching_algorithm=MatchingModel.MATCH_NONE, + ) + + sp_auto = StoragePath.objects.create( + name="Finance Path", + path="finance/{correspondent}", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + sp_none = StoragePath.objects.create( + name="Other Path", + path="other/{correspondent}", + matching_algorithm=MatchingModel.MATCH_NONE, + ) + + doc_a = Document.objects.create( + title="Invoice from Auto Corp Jan", + content="quarterly invoice payment tax financial statement revenue", + correspondent=c_auto, + document_type=dt_auto, + storage_path=sp_auto, + checksum="aaa", + mime_type="application/pdf", + filename="invoice_a.pdf", + ) + doc_a.tags.set([t_auto]) + + doc_b = Document.objects.create( + title="Invoice from Auto Corp Feb", + content="monthly invoice billing statement account balance due", + correspondent=c_auto, + document_type=dt_auto, + storage_path=sp_auto, + checksum="bbb", + mime_type="application/pdf", + filename="invoice_b.pdf", + ) + doc_b.tags.set([t_auto]) + + # Control document — no AUTO labels, but has enough content to vectorize + doc_c = Document.objects.create( + title="Miscellaneous Notes", + content="meeting notes agenda discussion summary action items follow", + correspondent=c_none, + document_type=dt_none, + checksum="ccc", + mime_type="application/pdf", + filename="notes_c.pdf", + ) + doc_c.tags.set([t_none]) + + return { + "c_auto": c_auto, 
+ "c_none": c_none, + "dt_auto": dt_auto, + "dt_none": dt_none, + "t_auto": t_auto, + "t_none": t_none, + "sp_auto": sp_auto, + "sp_none": sp_none, + "doc_a": doc_a, + "doc_b": doc_b, + "doc_c": doc_c, + } + + +# --------------------------------------------------------------------------- +# Happy-path skip tests +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestFastSkipFires: + """The no-op path: nothing changed, so the second train() is skipped.""" + + def test_first_train_returns_true(self, classifier, label_corpus): + """First train on a fresh classifier must return True (did work).""" + assert classifier.train() is True + + def test_second_train_returns_false(self, classifier, label_corpus): + """Second train with no changes must return False (skipped).""" + classifier.train() + assert classifier.train() is False + + def test_fast_skip_runs_minimal_queries(self, classifier, label_corpus): + """ + The no-op path must use at most 5 DB queries: + 1x Document.objects.aggregate(Max('modified')) + 4x MATCH_AUTO pk lists (Correspondent / DocumentType / Tag / StoragePath) + + The previous implementation (before Phase 1) iterated every document + to build the label hash BEFORE it could decide to skip, which is O(N). + This test verifies the fast path is in place. + """ + classifier.train() + with CaptureQueriesContext(connection) as ctx: + result = classifier.train() + assert result is False + assert len(ctx.captured_queries) <= 5, ( + f"Fast skip used {len(ctx.captured_queries)} queries; expected ≤5.\n" + + "\n".join(q["sql"] for q in ctx.captured_queries) + ) + + def test_fast_skip_refreshes_cache_keys(self, classifier, label_corpus): + """ + Even on a skip, the cache keys must be refreshed so that the task + scheduler can detect the classifier is still current. 
+ """ + from django.core.cache import cache + + from documents.caching import CLASSIFIER_HASH_KEY + from documents.caching import CLASSIFIER_MODIFIED_KEY + from documents.caching import CLASSIFIER_VERSION_KEY + + classifier.train() + # Evict the keys to prove skip re-populates them + cache.delete(CLASSIFIER_MODIFIED_KEY) + cache.delete(CLASSIFIER_HASH_KEY) + cache.delete(CLASSIFIER_VERSION_KEY) + + result = classifier.train() + + assert result is False + assert cache.get(CLASSIFIER_MODIFIED_KEY) is not None + assert cache.get(CLASSIFIER_HASH_KEY) is not None + assert cache.get(CLASSIFIER_VERSION_KEY) is not None + + +# --------------------------------------------------------------------------- +# Correctness tests — skip must NOT fire when the world has changed +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db() +class TestFastSkipDoesNotFire: + """The skip guard must yield to a full retrain whenever labels change.""" + + def test_document_content_modification_triggers_retrain( + self, + classifier, + label_corpus, + ): + """Updating a document's content updates modified → retrain required.""" + classifier.train() + doc_a = label_corpus["doc_a"] + doc_a.content = "completely different words here now nothing same" + doc_a.save() + assert classifier.train() is True + + def test_document_label_reassignment_triggers_retrain( + self, + classifier, + label_corpus, + ): + """ + Reassigning a document to a different AUTO correspondent (touching + doc.modified) must trigger a retrain. 
+ """ + c_auto2 = Correspondent.objects.create( + name="Second Auto Corp", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + classifier.train() + doc_a = label_corpus["doc_a"] + doc_a.correspondent = c_auto2 + doc_a.save() + assert classifier.train() is True + + def test_matching_algorithm_change_on_assigned_tag_triggers_retrain( + self, + classifier, + label_corpus, + ): + """ + Flipping a tag's matching_algorithm to MATCH_AUTO after it is already + assigned to documents must trigger a retrain — even though no + Document.modified timestamp advances. + + This is the key correctness case for the auto-label-set digest: + the tag is already on doc_c, so once it becomes MATCH_AUTO + the classifier needs to learn it. + """ + # t_none is applied to doc_c (a control doc) via the fixture. + # We flip it to MATCH_AUTO; the set of learnable AUTO tags grows. + classifier.train() + t_none = label_corpus["t_none"] + t_none.matching_algorithm = MatchingModel.MATCH_AUTO + t_none.save(update_fields=["matching_algorithm"]) + # Document.modified is NOT touched — this test specifically verifies + # that the auto-label-set digest catches the change. + assert classifier.train() is True + + def test_new_auto_correspondent_triggers_retrain(self, classifier, label_corpus): + """ + Adding a brand-new MATCH_AUTO correspondent (unassigned to any doc) + must trigger a retrain: the auto-label-set has grown. + """ + classifier.train() + Correspondent.objects.create( + name="New Auto Corp", + matching_algorithm=MatchingModel.MATCH_AUTO, + ) + assert classifier.train() is True + + def test_removing_auto_label_triggers_retrain(self, classifier, label_corpus): + """ + Deleting a MATCH_AUTO correspondent shrinks the auto-label-set and + must trigger a retrain. 
+ """ + classifier.train() + label_corpus["c_auto"].delete() + assert classifier.train() is True + + def test_fresh_classifier_always_trains(self, classifier, label_corpus): + """ + A classifier that has never been trained (last_doc_change_time is None) + must always perform a full train, regardless of corpus state. + """ + assert classifier.last_doc_change_time is None + assert classifier.train() is True + + def test_no_documents_raises_value_error(self, classifier, classifier_settings): + """train() with an empty database must raise ValueError.""" + with pytest.raises(ValueError, match="No training data"): + classifier.train()