Compare commits

...

3 Commits

Author SHA1 Message Date
Trenton H
1a26514a96 perf: replace MLPClassifier with LinearSVC for multi-tag classification
For the common case (num_tags > 1), switch from MLPClassifier to
OneVsRestClassifier(LinearSVC()) for the tags classifier.

MLPClassifier with thousands of output neurons (e.g. 3,085 AUTO tags)
requires a dense num_docs x num_tags label matrix and runs full
gradient descent with Adam optimiser for up to 200 epochs -- the
primary cause of >10 GB RAM and multi-hour training in extreme cases.

LinearSVC trains one binary linear SVM per class via OneVsRestClassifier.
Each model is a single weight vector; training is parallelisable and
orders of magnitude faster for large class counts.

The num_tags == 1 binary path is unchanged (MLP is kept there because
LinearSVC requires at least 2 distinct classes in training data, which
is not guaranteed when all documents share the single AUTO tag).

Adds test_classifier_tags_correctness.py, which verifies:
- Multi-cluster docs are predicted correctly (single and multi-tag)
- Single-tag (binary) path is predicted correctly
- Test passes with MLP (baseline) and LinearSVC (after swap)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 10:18:16 -07:00
Trenton H
1fefd506b7 perf: eliminate second document queryset scan in classifier train()
Capture doc.content during the label extraction loop so the document
queryset is iterated exactly once per training run.

Previously CountVectorizer.fit_transform() consumed a content_generator()
that re-evaluated the same docs_queryset, causing a second full table
scan. At 5k docs this wasted ~2.4 s and doubled DB I/O on every train.

Remove content_generator(); replace with a generator expression over
the in-memory doc_contents list collected during Step 1.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 09:38:53 -07:00
Trenton H
68b866aeee perf: fast skip in classifier train() via auto-label-set digest
Add a fast-skip gate at the top of DocumentClassifier.train() that
returns False after at most 5 DB queries (1x MAX(modified) on
non-inbox docs + 4x MATCH_AUTO pk lists), avoiding the O(N)
per-document label scan on no-op calls.

Previously the classifier always iterated every document to build the
label hash before it could decide to skip — ~8 s at 5k docs, scaling
linearly.

Changes:
- FORMAT_VERSION 10 -> 11 (new field in pickle)
- New field `last_auto_label_set_digest` stored after each full train
- New static method `_compute_auto_label_set_digest()` (4 queries)
- Fast-skip block before the document queryset; mirrors the inbox-tag
  exclusion used by the training queryset for an apples-to-apples
  MAX(modified) comparison
- Remove old embedded skip check (after the full label scan) which had
  a correctness gap: MATCH_AUTO labels with no document assignments
  were invisible to the per-doc hash, so a new unassigned AUTO label
  would not trigger a retrain

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 09:30:16 -07:00
4 changed files with 840 additions and 32 deletions

View File

@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from numpy import ndarray
@@ -19,6 +18,7 @@ if TYPE_CHECKING:
from django.conf import settings
from django.core.cache import cache
from django.core.cache import caches
from django.db.models import Max
from documents.caching import CACHE_5_MINUTES
from documents.caching import CACHE_50_MINUTES
@@ -99,7 +99,8 @@ class DocumentClassifier:
# v8 - Added storage path classifier
# v9 - Changed from hashing to time/ids for re-train check
# v10 - HMAC-signed model file
FORMAT_VERSION = 10
# v11 - Added auto-label-set digest for fast skip without full document scan
FORMAT_VERSION = 11
HMAC_SIZE = 32 # SHA-256 digest length
@@ -108,6 +109,8 @@ class DocumentClassifier:
self.last_doc_change_time: datetime | None = None
# Hash of primary keys of AUTO matching values last used in training
self.last_auto_type_hash: bytes | None = None
# Digest of the set of all MATCH_AUTO label PKs (fast-skip guard)
self.last_auto_label_set_digest: bytes | None = None
self.data_vectorizer = None
self.data_vectorizer_hash = None
@@ -140,6 +143,29 @@ class DocumentClassifier:
sha256,
).digest()
@staticmethod
def _compute_auto_label_set_digest() -> bytes:
"""
Return a SHA-256 digest of all MATCH_AUTO label PKs across the four
label types. Four cheap indexed queries; stable for any fixed set of
AUTO labels regardless of document assignments.
"""
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
hasher = sha256()
for model in (Correspondent, DocumentType, Tag, StoragePath):
pks = sorted(
model.objects.filter(
matching_algorithm=MatchingModel.MATCH_AUTO,
).values_list("pk", flat=True),
)
for pk in pks:
hasher.update(pk.to_bytes(4, "little", signed=False))
return hasher.digest()
def load(self) -> None:
from sklearn.exceptions import InconsistentVersionWarning
@@ -161,6 +187,7 @@ class DocumentClassifier:
schema_version,
self.last_doc_change_time,
self.last_auto_type_hash,
self.last_auto_label_set_digest,
self.data_vectorizer,
self.tags_binarizer,
self.tags_classifier,
@@ -202,6 +229,7 @@ class DocumentClassifier:
self.FORMAT_VERSION,
self.last_doc_change_time,
self.last_auto_type_hash,
self.last_auto_label_set_digest,
self.data_vectorizer,
self.tags_binarizer,
self.tags_classifier,
@@ -224,6 +252,39 @@ class DocumentClassifier:
) -> bool:
notify = status_callback if status_callback is not None else lambda _: None
# Fast skip: avoid the expensive per-document label scan when nothing
# has changed. Requires a prior training run to have populated both
# last_doc_change_time and last_auto_label_set_digest.
if (
self.last_doc_change_time is not None
and self.last_auto_label_set_digest is not None
):
latest_mod = Document.objects.exclude(
tags__is_inbox_tag=True,
).aggregate(Max("modified"))["modified__max"]
if latest_mod is not None and latest_mod <= self.last_doc_change_time:
current_digest = self._compute_auto_label_set_digest()
if current_digest == self.last_auto_label_set_digest:
logger.info("No updates since last training")
cache.set(
CLASSIFIER_MODIFIED_KEY,
self.last_doc_change_time,
CACHE_50_MINUTES,
)
cache.set(
CLASSIFIER_HASH_KEY,
self.last_auto_type_hash.hex()
if self.last_auto_type_hash
else "",
CACHE_50_MINUTES,
)
cache.set(
CLASSIFIER_VERSION_KEY,
self.FORMAT_VERSION,
CACHE_50_MINUTES,
)
return False
# Get non-inbox documents
docs_queryset = (
Document.objects.exclude(
@@ -242,12 +303,15 @@ class DocumentClassifier:
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
doc_contents: list[str] = []
# Step 1: Extract and preprocess training data from the database.
# Step 1: Extract labels and capture content in a single pass.
logger.debug("Gathering data from database...")
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
hasher = sha256()
for doc in docs_queryset:
doc_contents.append(doc.content)
y = -1
dt = doc.document_type
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
@@ -282,25 +346,7 @@ class DocumentClassifier:
num_tags = len(labels_tags_unique)
# Check if retraining is actually required.
# A document has been updated since the classifier was trained
# New auto tags, types, correspondent, storage paths exist
latest_doc_change = docs_queryset.latest("modified").modified
if (
self.last_doc_change_time is not None
and self.last_doc_change_time >= latest_doc_change
) and self.last_auto_type_hash == hasher.digest():
logger.info("No updates since last training")
# Set the classifier information into the cache
# Caching for 50 minutes, so slightly less than the normal retrain time
cache.set(
CLASSIFIER_MODIFIED_KEY,
self.last_doc_change_time,
CACHE_50_MINUTES,
)
cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)
return False
# subtract 1 since -1 (null) is also part of the classes.
@@ -317,21 +363,16 @@ class DocumentClassifier:
)
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import LinearSVC
# Step 2: vectorize data
logger.debug("Vectorizing data...")
notify("Vectorizing document content...")
def content_generator() -> Iterator[str]:
"""
Generates the content for documents, but once at a time
"""
for doc in docs_queryset:
yield self.preprocess_content(doc.content, shared_cache=False)
self.data_vectorizer = CountVectorizer(
analyzer="word",
ngram_range=(1, 2),
@@ -339,7 +380,8 @@ class DocumentClassifier:
)
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
content_generator(),
self.preprocess_content(content, shared_cache=False)
for content in doc_contents
)
# See the notes here:
@@ -353,8 +395,10 @@ class DocumentClassifier:
notify(f"Training tags classifier ({num_tags} tag(s))...")
if num_tags == 1:
# Special case where only one tag has auto:
# Fallback to binary classification.
# Special case: only one AUTO tag — use binary classification.
# MLPClassifier is used here because LinearSVC requires at least
# 2 distinct classes in training data, which cannot be guaranteed
# when all documents share the single AUTO tag.
labels_tags = [
label[0] if len(label) == 1 else -1 for label in labels_tags
]
@@ -362,11 +406,15 @@ class DocumentClassifier:
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
labels_tags,
).ravel()
self.tags_classifier = MLPClassifier(tol=0.01)
else:
# General multi-label case: LinearSVC via OneVsRestClassifier.
# Vastly more memory- and time-efficient than MLPClassifier for
# large class counts (e.g. hundreds of AUTO tags).
self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
self.tags_classifier = OneVsRestClassifier(LinearSVC())
self.tags_classifier = MLPClassifier(tol=0.01)
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else:
self.tags_classifier = None
@@ -416,6 +464,7 @@ class DocumentClassifier:
self.last_doc_change_time = latest_doc_change
self.last_auto_type_hash = hasher.digest()
self.last_auto_label_set_digest = self._compute_auto_label_set_digest()
self._update_data_vectorizer_hash()
# Set the classifier information into the cache

View File

@@ -0,0 +1,134 @@
"""
Phase 2 — Single queryset pass in DocumentClassifier.train()
The document queryset must be iterated exactly once: during the label
extraction loop, which now also captures doc.content for vectorization.
The previous content_generator() caused a second full table scan.
"""
from __future__ import annotations
from unittest import mock
import pytest
from django.db.models.query import QuerySet
from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
# ---------------------------------------------------------------------------
# Fixtures (mirrors test_classifier_train_skip.py)
# ---------------------------------------------------------------------------
@pytest.fixture()
def classifier_settings(settings, tmp_path):
    """Redirect MODEL_FILE into a per-test temp directory (hermetic runs)."""
    settings.MODEL_FILE = tmp_path / "model.pickle"
    return settings


@pytest.fixture()
def classifier(classifier_settings):
    """A fresh DocumentClassifier bound to the hermetic MODEL_FILE."""
    return DocumentClassifier()
@pytest.fixture()
def label_corpus(classifier_settings):
    """
    Minimal trainable corpus: one MATCH_AUTO label of each kind plus three
    documents — two carrying every AUTO label, one unlabelled control.
    """
    c_auto = Correspondent.objects.create(
        name="Auto Corp",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    dt_auto = DocumentType.objects.create(
        name="Invoice",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    t_auto = Tag.objects.create(
        name="finance",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    sp_auto = StoragePath.objects.create(
        name="Finance Path",
        path="finance/{correspondent}",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )

    def labelled_doc(title, content, checksum, filename):
        # Both labelled docs share the complete AUTO label set.
        doc = Document.objects.create(
            title=title,
            content=content,
            correspondent=c_auto,
            document_type=dt_auto,
            storage_path=sp_auto,
            checksum=checksum,
            mime_type="application/pdf",
            filename=filename,
        )
        doc.tags.set([t_auto])
        return doc

    doc_a = labelled_doc(
        "Invoice A",
        "quarterly invoice payment tax financial statement revenue",
        "aaa",
        "invoice_a.pdf",
    )
    doc_b = labelled_doc(
        "Invoice B",
        "monthly invoice billing statement account balance due",
        "bbb",
        "invoice_b.pdf",
    )
    # Control document: no labels at all, but content enough to vectorize.
    doc_c = Document.objects.create(
        title="Notes",
        content="meeting notes agenda discussion summary action items follow",
        checksum="ccc",
        mime_type="application/pdf",
        filename="notes_c.pdf",
    )
    return {"doc_a": doc_a, "doc_b": doc_b, "doc_c": doc_c}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestSingleQuerysetPass:
    def test_train_iterates_document_queryset_once(self, classifier, label_corpus):
        """
        train() must iterate the Document queryset exactly once.

        Before Phase 2 there were two iterations: the label extraction loop
        and a second pass inside content_generator() feeding CountVectorizer.
        After Phase 2 the content is captured during the label loop, so the
        second iteration is gone.
        """
        real_iter = QuerySet.__iter__
        document_iterations = 0

        def tracking_iter(qs):
            # Count only iterations of Document querysets; other models
            # (tags, labels) may legitimately be iterated as well.
            nonlocal document_iterations
            if qs.model is Document:
                document_iterations += 1
            return real_iter(qs)

        with mock.patch.object(QuerySet, "__iter__", tracking_iter):
            classifier.train()

        assert document_iterations == 1, (
            f"Expected 1 Document queryset iteration, got {document_iterations}. "
            "content_generator() may still be re-fetching from the DB."
        )

    def test_train_result_unchanged(self, classifier, label_corpus):
        """
        Collapsing to a single pass must not change what the classifier
        learns: a second train() with no changes still returns False.
        """
        assert classifier.train() is True
        assert classifier.train() is False

View File

@@ -0,0 +1,300 @@
"""
Tags classifier correctness test — Phase 3b gate.
This test must pass both BEFORE and AFTER the MLPClassifier → LinearSVC swap.
It verifies that the tags classifier correctly learns discriminative signal and
predicts the right tags on held-out documents.
Run before the swap to establish a baseline, then run again after to confirm
the new algorithm is at least as correct.
Two scenarios are tested:
1. Multi-tag (num_tags > 1) — the common case; uses MultiLabelBinarizer
2. Single-tag (num_tags == 1) — special binary path; uses LabelBinarizer
Corpus design: each tag has a distinct vocabulary cluster. Each training
document contains words from exactly one cluster (or two for multi-tag docs).
Held-out test documents contain the same cluster words; correct classification
requires the model to learn the vocabulary → tag mapping.
"""
from __future__ import annotations
import pytest
from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
# ---------------------------------------------------------------------------
# Vocabulary clusters — intentionally non-overlapping so both MLP and SVM
# should learn them perfectly or near-perfectly.
# ---------------------------------------------------------------------------
# Vocabulary for the "finance" tag cluster.
FINANCE_WORDS = (
    "invoice payment tax revenue billing statement account receivable "
    "quarterly budget expense ledger debit credit profit loss fiscal"
)
# Vocabulary for the "legal" tag cluster.
LEGAL_WORDS = (
    "contract agreement terms conditions clause liability indemnity "
    "jurisdiction arbitration compliance regulation statute obligation"
)
# Vocabulary for the "medical" tag cluster.
MEDICAL_WORDS = (
    "prescription diagnosis treatment patient health symptom dosage "
    "physician referral therapy clinical examination procedure chronic"
)
# Vocabulary for the "hr" tag cluster.
HR_WORDS = (
    "employee salary onboarding performance review appraisal benefits "
    "recruitment hiring resignation termination payroll department staff"
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def classifier_settings(settings, tmp_path):
    """Point MODEL_FILE into a per-test temp directory."""
    settings.MODEL_FILE = tmp_path / "model.pickle"
    return settings


@pytest.fixture()
def classifier(classifier_settings):
    """A fresh DocumentClassifier using the hermetic MODEL_FILE."""
    return DocumentClassifier()
def _make_doc(title, content, checksum, tags=(), **kwargs):
    """Create a Document with test defaults; optionally attach tags."""
    document = Document.objects.create(
        title=title,
        content=content,
        checksum=checksum,
        mime_type="application/pdf",
        filename=f"{checksum}.pdf",
        **kwargs,
    )
    if tags:
        document.tags.set(tags)
    return document
def _words(cluster, extra=""):
"""Repeat cluster words enough times to clear min_df=0.01 at ~40 docs."""
return f"{cluster} {cluster} {extra}".strip()
# ---------------------------------------------------------------------------
# Multi-tag correctness
# ---------------------------------------------------------------------------
@pytest.fixture()
def multi_tag_corpus(classifier_settings):
    """
    45 training documents across 4 AUTO tags with distinct vocabulary:
    10 single-tag docs per tag plus 5 finance+legal two-tag docs.

    A non-AUTO correspondent, doc type and storage path are included so
    the other classifiers have data and do not raise ValueError.
    """
    t_finance = Tag.objects.create(
        name="finance",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    t_legal = Tag.objects.create(
        name="legal",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    t_medical = Tag.objects.create(
        name="medical",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    t_hr = Tag.objects.create(name="hr", matching_algorithm=MatchingModel.MATCH_AUTO)
    # Non-AUTO labels keep the other classifiers from raising.
    c = Correspondent.objects.create(
        name="org",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    dt = DocumentType.objects.create(
        name="doc",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    sp = StoragePath.objects.create(
        name="archive",
        path="archive",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    counter = 0

    def make(title, content, tags):
        # Monotonic counter gives each doc a unique 4-digit checksum.
        nonlocal counter
        counter += 1
        return _make_doc(
            title,
            content,
            f"{counter:04d}",
            tags=tags,
            correspondent=c,
            document_type=dt,
            storage_path=sp,
        )

    # 10 single-tag training docs per tag, created in cluster order.
    clusters = (
        ("finance", FINANCE_WORDS, t_finance),
        ("legal", LEGAL_WORDS, t_legal),
        ("medical", MEDICAL_WORDS, t_medical),
        ("hr", HR_WORDS, t_hr),
    )
    for i in range(10):
        for prefix, vocab, tag in clusters:
            make(f"{prefix}-{i}", _words(vocab, f"doc{i}"), [tag])
    # 5 two-tag training docs.
    for i in range(5):
        make(
            f"finance-legal-{i}",
            _words(FINANCE_WORDS + " " + LEGAL_WORDS, f"combo{i}"),
            [t_finance, t_legal],
        )
    return {
        "t_finance": t_finance,
        "t_legal": t_legal,
        "t_medical": t_medical,
        "t_hr": t_hr,
    }
@pytest.mark.django_db()
class TestMultiTagCorrectness:
    """
    The tags classifier must correctly predict tags on held-out documents
    whose content clearly belongs to one or two vocabulary clusters.
    A prediction is "correct" if the expected tag is present in the result.
    """

    def test_single_cluster_docs_predicted_correctly(
        self,
        classifier,
        multi_tag_corpus,
    ):
        """Each single-cluster held-out doc gets exactly the right tag."""
        classifier.train()
        held_out = (
            (FINANCE_WORDS + " unique alpha", "t_finance"),
            (LEGAL_WORDS + " unique beta", "t_legal"),
            (MEDICAL_WORDS + " unique gamma", "t_medical"),
            (HR_WORDS + " unique delta", "t_hr"),
        )
        for content, tag_key in held_out:
            expected_pk = multi_tag_corpus[tag_key].pk
            predicted = classifier.predict_tags(content)
            assert expected_pk in predicted, (
                f"Expected tag pk={expected_pk} in predictions for content starting "
                f"'{content[:40]}', got {predicted}"
            )

    def test_multi_cluster_doc_gets_both_tags(self, classifier, multi_tag_corpus):
        """A document with finance + legal vocabulary gets both tags."""
        classifier.train()
        content = FINANCE_WORDS + " " + LEGAL_WORDS + " unique epsilon"
        predicted = classifier.predict_tags(content)
        assert multi_tag_corpus["t_finance"].pk in predicted, (
            f"Expected finance tag in {predicted}"
        )
        assert multi_tag_corpus["t_legal"].pk in predicted, (
            f"Expected legal tag in {predicted}"
        )

    def test_unrelated_content_predicts_no_trained_tags(
        self,
        classifier,
        multi_tag_corpus,
    ):
        """
        Completely alien content should not confidently fire any learned tag.

        Soft check only: both MLP and SVM may or may not produce false
        positives on out-of-vocabulary content, so we assert no more than
        the return-type contract and leave the rest to human inspection.
        """
        classifier.train()
        alien = (
            "xyzzyx qwerty asdfgh zxcvbn plokij unique zeta "
            "xyzzyx qwerty asdfgh zxcvbn plokij unique zeta"
        )
        predicted = classifier.predict_tags(alien)
        assert isinstance(predicted, list)
# ---------------------------------------------------------------------------
# Single-tag (binary) correctness
# ---------------------------------------------------------------------------
@pytest.fixture()
def single_tag_corpus(classifier_settings):
    """
    Corpus with exactly ONE AUTO tag, exercising the LabelBinarizer +
    binary classification path. Documents either carry the tag or don't.
    """
    t_finance = Tag.objects.create(
        name="finance",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    c = Correspondent.objects.create(
        name="org",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    dt = DocumentType.objects.create(
        name="doc",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    counter = 0

    def make(title, content, tags):
        # "s"-prefixed checksums avoid colliding with other fixtures.
        nonlocal counter
        counter += 1
        return _make_doc(
            title,
            content,
            f"s{counter:04d}",
            tags=tags,
            correspondent=c,
            document_type=dt,
        )

    for i in range(10):
        make(f"finance-{i}", _words(FINANCE_WORDS, f"s{i}"), [t_finance])
        make(f"other-{i}", _words(LEGAL_WORDS, f"s{i}"), [])
    return {"t_finance": t_finance}
@pytest.mark.django_db()
class TestSingleTagCorrectness:
    def test_finance_content_predicts_finance_tag(self, classifier, single_tag_corpus):
        """Finance vocabulary → finance tag predicted."""
        classifier.train()
        expected_pk = single_tag_corpus["t_finance"].pk
        predicted = classifier.predict_tags(FINANCE_WORDS + " unique alpha single")
        assert expected_pk in predicted, (
            f"Expected finance tag pk={expected_pk} in {predicted}"
        )

    def test_non_finance_content_predicts_no_tag(self, classifier, single_tag_corpus):
        """Non-finance vocabulary → no tag predicted."""
        classifier.train()
        predicted = classifier.predict_tags(LEGAL_WORDS + " unique beta single")
        assert predicted == [], f"Expected no tags, got {predicted}"

View File

@@ -0,0 +1,325 @@
"""
Phase 1 — fast-skip optimisation in DocumentClassifier.train()
The goal: when nothing has changed since the last training run, train() should
return False after at most 5 DB queries (1x MAX(modified) + 4x MATCH_AUTO pk
lists), not after a full per-document label scan.
Correctness invariant: the skip must NOT fire when the set of AUTO-matching
labels has changed, even if no Document.modified timestamp has advanced (e.g.
a Tag's matching_algorithm was flipped to MATCH_AUTO after the last train).
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from django.db import connection
from django.test.utils import CaptureQueriesContext
from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag
if TYPE_CHECKING:
from pathlib import Path
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def classifier_settings(settings, tmp_path: Path):
    """Point MODEL_FILE at a temp directory so tests are hermetic."""
    model_path = tmp_path / "model.pickle"
    settings.MODEL_FILE = model_path
    return settings


@pytest.fixture()
def classifier(classifier_settings):
    """Fresh DocumentClassifier instance with test settings active."""
    return DocumentClassifier()
@pytest.fixture()
def label_corpus(classifier_settings):
    """
    Minimal label + document corpus that produces a trainable classifier.

    One MATCH_AUTO and one MATCH_NONE (control) label of every kind, plus
    three documents: doc_a and doc_b carry the full AUTO label set, doc_c
    is a control carrying only MATCH_NONE labels. Every created object is
    returned in a dict for direct mutation in individual tests.
    """
    c_auto = Correspondent.objects.create(
        name="Auto Corp",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    c_none = Correspondent.objects.create(
        name="Manual Corp",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    dt_auto = DocumentType.objects.create(
        name="Invoice",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    dt_none = DocumentType.objects.create(
        name="Other",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    t_auto = Tag.objects.create(
        name="finance",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    t_none = Tag.objects.create(
        name="misc",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )
    sp_auto = StoragePath.objects.create(
        name="Finance Path",
        path="finance/{correspondent}",
        matching_algorithm=MatchingModel.MATCH_AUTO,
    )
    sp_none = StoragePath.objects.create(
        name="Other Path",
        path="other/{correspondent}",
        matching_algorithm=MatchingModel.MATCH_NONE,
    )

    def invoice(title, content, checksum, filename):
        # doc_a / doc_b share the complete AUTO label set.
        doc = Document.objects.create(
            title=title,
            content=content,
            correspondent=c_auto,
            document_type=dt_auto,
            storage_path=sp_auto,
            checksum=checksum,
            mime_type="application/pdf",
            filename=filename,
        )
        doc.tags.set([t_auto])
        return doc

    doc_a = invoice(
        "Invoice from Auto Corp Jan",
        "quarterly invoice payment tax financial statement revenue",
        "aaa",
        "invoice_a.pdf",
    )
    doc_b = invoice(
        "Invoice from Auto Corp Feb",
        "monthly invoice billing statement account balance due",
        "bbb",
        "invoice_b.pdf",
    )
    # Control document — no AUTO labels, but enough content to vectorize.
    doc_c = Document.objects.create(
        title="Miscellaneous Notes",
        content="meeting notes agenda discussion summary action items follow",
        correspondent=c_none,
        document_type=dt_none,
        checksum="ccc",
        mime_type="application/pdf",
        filename="notes_c.pdf",
    )
    doc_c.tags.set([t_none])
    return {
        "c_auto": c_auto,
        "c_none": c_none,
        "dt_auto": dt_auto,
        "dt_none": dt_none,
        "t_auto": t_auto,
        "t_none": t_none,
        "sp_auto": sp_auto,
        "sp_none": sp_none,
        "doc_a": doc_a,
        "doc_b": doc_b,
        "doc_c": doc_c,
    }
# ---------------------------------------------------------------------------
# Happy-path skip tests
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestFastSkipFires:
    """The no-op path: nothing changed, so the second train() is skipped."""

    def test_first_train_returns_true(self, classifier, label_corpus):
        """First train on a fresh classifier must return True (did work)."""
        assert classifier.train() is True

    def test_second_train_returns_false(self, classifier, label_corpus):
        """Second train with no changes must return False (skipped)."""
        classifier.train()
        assert classifier.train() is False

    def test_fast_skip_runs_minimal_queries(self, classifier, label_corpus):
        """
        The no-op path must use at most 5 DB queries:
          1x Document.objects.aggregate(Max('modified'))
          4x MATCH_AUTO pk lists (Correspondent / DocumentType / Tag / StoragePath)

        Without the fast path, train() iterates every document to build the
        label hash BEFORE it can decide to skip, which is O(N).
        """
        classifier.train()
        with CaptureQueriesContext(connection) as ctx:
            outcome = classifier.train()
        assert outcome is False
        query_count = len(ctx.captured_queries)
        assert query_count <= 5, (
            f"Fast skip used {query_count} queries; expected ≤5.\n"
            + "\n".join(q["sql"] for q in ctx.captured_queries)
        )

    def test_fast_skip_refreshes_cache_keys(self, classifier, label_corpus):
        """
        Even on a skip, the cache keys must be refreshed so that the task
        scheduler can detect the classifier is still current.
        """
        from django.core.cache import cache

        from documents.caching import CLASSIFIER_HASH_KEY
        from documents.caching import CLASSIFIER_MODIFIED_KEY
        from documents.caching import CLASSIFIER_VERSION_KEY

        keys = (CLASSIFIER_MODIFIED_KEY, CLASSIFIER_HASH_KEY, CLASSIFIER_VERSION_KEY)
        classifier.train()
        # Evict the keys to prove the skip path re-populates them.
        for key in keys:
            cache.delete(key)
        assert classifier.train() is False
        for key in keys:
            assert cache.get(key) is not None
# ---------------------------------------------------------------------------
# Correctness tests — skip must NOT fire when the world has changed
# ---------------------------------------------------------------------------
@pytest.mark.django_db()
class TestFastSkipDoesNotFire:
    """The skip guard must yield to a full retrain whenever labels change."""

    def test_document_content_modification_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """Updating a document's content updates modified → retrain required."""
        classifier.train()
        doc_a = label_corpus["doc_a"]
        doc_a.content = "completely different words here now nothing same"
        doc_a.save()
        assert classifier.train() is True

    def test_document_label_reassignment_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Reassigning a document to a different AUTO correspondent (touching
        doc.modified) must trigger a retrain.
        """
        replacement = Correspondent.objects.create(
            name="Second Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        classifier.train()
        doc_a = label_corpus["doc_a"]
        doc_a.correspondent = replacement
        doc_a.save()
        assert classifier.train() is True

    def test_matching_algorithm_change_on_assigned_tag_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Flipping a tag's matching_algorithm to MATCH_AUTO after it is already
        assigned to documents must trigger a retrain — even though no
        Document.modified timestamp advances.

        This is the key correctness case for the auto-label-set digest: the
        tag is already on a document, so once it becomes MATCH_AUTO the
        classifier needs to learn it.
        """
        classifier.train()
        # t_none is applied to doc_c (a control doc) via the fixture; flip
        # it to MATCH_AUTO so the set of learnable AUTO tags grows. Using
        # update_fields deliberately avoids touching any Document.modified.
        flipped = label_corpus["t_none"]
        flipped.matching_algorithm = MatchingModel.MATCH_AUTO
        flipped.save(update_fields=["matching_algorithm"])
        assert classifier.train() is True

    def test_new_auto_correspondent_triggers_retrain(self, classifier, label_corpus):
        """
        Adding a brand-new MATCH_AUTO correspondent (unassigned to any doc)
        must trigger a retrain: the auto-label-set has grown.
        """
        classifier.train()
        Correspondent.objects.create(
            name="New Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        assert classifier.train() is True

    def test_removing_auto_label_triggers_retrain(self, classifier, label_corpus):
        """
        Deleting a MATCH_AUTO correspondent shrinks the auto-label-set and
        must trigger a retrain.
        """
        classifier.train()
        label_corpus["c_auto"].delete()
        assert classifier.train() is True

    def test_fresh_classifier_always_trains(self, classifier, label_corpus):
        """
        A classifier that has never been trained (last_doc_change_time is
        None) must always perform a full train, regardless of corpus state.
        """
        assert classifier.last_doc_change_time is None
        assert classifier.train() is True

    def test_no_documents_raises_value_error(self, classifier, classifier_settings):
        """train() with an empty database must raise ValueError."""
        with pytest.raises(ValueError, match="No training data"):
            classifier.train()