mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-09 17:48:51 +00:00
Compare commits
3 Commits
dev
...
feature-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a26514a96 | ||
|
|
1fefd506b7 | ||
|
|
68b866aeee |
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
|
||||
from numpy import ndarray
|
||||
@@ -19,6 +18,7 @@ if TYPE_CHECKING:
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.core.cache import caches
|
||||
from django.db.models import Max
|
||||
|
||||
from documents.caching import CACHE_5_MINUTES
|
||||
from documents.caching import CACHE_50_MINUTES
|
||||
@@ -99,7 +99,8 @@ class DocumentClassifier:
|
||||
# v8 - Added storage path classifier
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
# v10 - HMAC-signed model file
|
||||
FORMAT_VERSION = 10
|
||||
# v11 - Added auto-label-set digest for fast skip without full document scan
|
||||
FORMAT_VERSION = 11
|
||||
|
||||
HMAC_SIZE = 32 # SHA-256 digest length
|
||||
|
||||
@@ -108,6 +109,8 @@ class DocumentClassifier:
|
||||
self.last_doc_change_time: datetime | None = None
|
||||
# Hash of primary keys of AUTO matching values last used in training
|
||||
self.last_auto_type_hash: bytes | None = None
|
||||
# Digest of the set of all MATCH_AUTO label PKs (fast-skip guard)
|
||||
self.last_auto_label_set_digest: bytes | None = None
|
||||
|
||||
self.data_vectorizer = None
|
||||
self.data_vectorizer_hash = None
|
||||
@@ -140,6 +143,29 @@ class DocumentClassifier:
|
||||
sha256,
|
||||
).digest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_auto_label_set_digest() -> bytes:
|
||||
"""
|
||||
Return a SHA-256 digest of all MATCH_AUTO label PKs across the four
|
||||
label types. Four cheap indexed queries; stable for any fixed set of
|
||||
AUTO labels regardless of document assignments.
|
||||
"""
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
hasher = sha256()
|
||||
for model in (Correspondent, DocumentType, Tag, StoragePath):
|
||||
pks = sorted(
|
||||
model.objects.filter(
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
).values_list("pk", flat=True),
|
||||
)
|
||||
for pk in pks:
|
||||
hasher.update(pk.to_bytes(4, "little", signed=False))
|
||||
return hasher.digest()
|
||||
|
||||
def load(self) -> None:
|
||||
from sklearn.exceptions import InconsistentVersionWarning
|
||||
|
||||
@@ -161,6 +187,7 @@ class DocumentClassifier:
|
||||
schema_version,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.last_auto_label_set_digest,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
@@ -202,6 +229,7 @@ class DocumentClassifier:
|
||||
self.FORMAT_VERSION,
|
||||
self.last_doc_change_time,
|
||||
self.last_auto_type_hash,
|
||||
self.last_auto_label_set_digest,
|
||||
self.data_vectorizer,
|
||||
self.tags_binarizer,
|
||||
self.tags_classifier,
|
||||
@@ -224,6 +252,39 @@ class DocumentClassifier:
|
||||
) -> bool:
|
||||
notify = status_callback if status_callback is not None else lambda _: None
|
||||
|
||||
# Fast skip: avoid the expensive per-document label scan when nothing
|
||||
# has changed. Requires a prior training run to have populated both
|
||||
# last_doc_change_time and last_auto_label_set_digest.
|
||||
if (
|
||||
self.last_doc_change_time is not None
|
||||
and self.last_auto_label_set_digest is not None
|
||||
):
|
||||
latest_mod = Document.objects.exclude(
|
||||
tags__is_inbox_tag=True,
|
||||
).aggregate(Max("modified"))["modified__max"]
|
||||
if latest_mod is not None and latest_mod <= self.last_doc_change_time:
|
||||
current_digest = self._compute_auto_label_set_digest()
|
||||
if current_digest == self.last_auto_label_set_digest:
|
||||
logger.info("No updates since last training")
|
||||
cache.set(
|
||||
CLASSIFIER_MODIFIED_KEY,
|
||||
self.last_doc_change_time,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(
|
||||
CLASSIFIER_HASH_KEY,
|
||||
self.last_auto_type_hash.hex()
|
||||
if self.last_auto_type_hash
|
||||
else "",
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(
|
||||
CLASSIFIER_VERSION_KEY,
|
||||
self.FORMAT_VERSION,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
return False
|
||||
|
||||
# Get non-inbox documents
|
||||
docs_queryset = (
|
||||
Document.objects.exclude(
|
||||
@@ -242,12 +303,15 @@ class DocumentClassifier:
|
||||
labels_correspondent = []
|
||||
labels_document_type = []
|
||||
labels_storage_path = []
|
||||
doc_contents: list[str] = []
|
||||
|
||||
# Step 1: Extract and preprocess training data from the database.
|
||||
# Step 1: Extract labels and capture content in a single pass.
|
||||
logger.debug("Gathering data from database...")
|
||||
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
|
||||
hasher = sha256()
|
||||
for doc in docs_queryset:
|
||||
doc_contents.append(doc.content)
|
||||
|
||||
y = -1
|
||||
dt = doc.document_type
|
||||
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
@@ -282,25 +346,7 @@ class DocumentClassifier:
|
||||
|
||||
num_tags = len(labels_tags_unique)
|
||||
|
||||
# Check if retraining is actually required.
|
||||
# A document has been updated since the classifier was trained
|
||||
# New auto tags, types, correspondent, storage paths exist
|
||||
latest_doc_change = docs_queryset.latest("modified").modified
|
||||
if (
|
||||
self.last_doc_change_time is not None
|
||||
and self.last_doc_change_time >= latest_doc_change
|
||||
) and self.last_auto_type_hash == hasher.digest():
|
||||
logger.info("No updates since last training")
|
||||
# Set the classifier information into the cache
|
||||
# Caching for 50 minutes, so slightly less than the normal retrain time
|
||||
cache.set(
|
||||
CLASSIFIER_MODIFIED_KEY,
|
||||
self.last_doc_change_time,
|
||||
CACHE_50_MINUTES,
|
||||
)
|
||||
cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
|
||||
cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)
|
||||
return False
|
||||
|
||||
# subtract 1 since -1 (null) is also part of the classes.
|
||||
|
||||
@@ -317,21 +363,16 @@ class DocumentClassifier:
|
||||
)
|
||||
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.multiclass import OneVsRestClassifier
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.preprocessing import LabelBinarizer
|
||||
from sklearn.preprocessing import MultiLabelBinarizer
|
||||
from sklearn.svm import LinearSVC
|
||||
|
||||
# Step 2: vectorize data
|
||||
logger.debug("Vectorizing data...")
|
||||
notify("Vectorizing document content...")
|
||||
|
||||
def content_generator() -> Iterator[str]:
|
||||
"""
|
||||
Generates the content for documents, but once at a time
|
||||
"""
|
||||
for doc in docs_queryset:
|
||||
yield self.preprocess_content(doc.content, shared_cache=False)
|
||||
|
||||
self.data_vectorizer = CountVectorizer(
|
||||
analyzer="word",
|
||||
ngram_range=(1, 2),
|
||||
@@ -339,7 +380,8 @@ class DocumentClassifier:
|
||||
)
|
||||
|
||||
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
|
||||
content_generator(),
|
||||
self.preprocess_content(content, shared_cache=False)
|
||||
for content in doc_contents
|
||||
)
|
||||
|
||||
# See the notes here:
|
||||
@@ -353,8 +395,10 @@ class DocumentClassifier:
|
||||
notify(f"Training tags classifier ({num_tags} tag(s))...")
|
||||
|
||||
if num_tags == 1:
|
||||
# Special case where only one tag has auto:
|
||||
# Fallback to binary classification.
|
||||
# Special case: only one AUTO tag — use binary classification.
|
||||
# MLPClassifier is used here because LinearSVC requires at least
|
||||
# 2 distinct classes in training data, which cannot be guaranteed
|
||||
# when all documents share the single AUTO tag.
|
||||
labels_tags = [
|
||||
label[0] if len(label) == 1 else -1 for label in labels_tags
|
||||
]
|
||||
@@ -362,11 +406,15 @@ class DocumentClassifier:
|
||||
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
|
||||
labels_tags,
|
||||
).ravel()
|
||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||
else:
|
||||
# General multi-label case: LinearSVC via OneVsRestClassifier.
|
||||
# Vastly more memory- and time-efficient than MLPClassifier for
|
||||
# large class counts (e.g. hundreds of AUTO tags).
|
||||
self.tags_binarizer = MultiLabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
||||
self.tags_classifier = OneVsRestClassifier(LinearSVC())
|
||||
|
||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||
else:
|
||||
self.tags_classifier = None
|
||||
@@ -416,6 +464,7 @@ class DocumentClassifier:
|
||||
|
||||
self.last_doc_change_time = latest_doc_change
|
||||
self.last_auto_type_hash = hasher.digest()
|
||||
self.last_auto_label_set_digest = self._compute_auto_label_set_digest()
|
||||
self._update_data_vectorizer_hash()
|
||||
|
||||
# Set the classifier information into the cache
|
||||
|
||||
134
src/documents/tests/test_classifier_single_pass.py
Normal file
134
src/documents/tests/test_classifier_single_pass.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Phase 2 — Single queryset pass in DocumentClassifier.train()
|
||||
|
||||
The document queryset must be iterated exactly once: during the label
|
||||
extraction loop, which now also captures doc.content for vectorization.
|
||||
The previous content_generator() caused a second full table scan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import MatchingModel
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures (mirrors test_classifier_train_skip.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier_settings(settings, tmp_path):
|
||||
settings.MODEL_FILE = tmp_path / "model.pickle"
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier(classifier_settings):
|
||||
return DocumentClassifier()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def label_corpus(classifier_settings):
|
||||
c_auto = Correspondent.objects.create(
|
||||
name="Auto Corp",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
dt_auto = DocumentType.objects.create(
|
||||
name="Invoice",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
t_auto = Tag.objects.create(
|
||||
name="finance",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
sp_auto = StoragePath.objects.create(
|
||||
name="Finance Path",
|
||||
path="finance/{correspondent}",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
|
||||
doc_a = Document.objects.create(
|
||||
title="Invoice A",
|
||||
content="quarterly invoice payment tax financial statement revenue",
|
||||
correspondent=c_auto,
|
||||
document_type=dt_auto,
|
||||
storage_path=sp_auto,
|
||||
checksum="aaa",
|
||||
mime_type="application/pdf",
|
||||
filename="invoice_a.pdf",
|
||||
)
|
||||
doc_a.tags.set([t_auto])
|
||||
|
||||
doc_b = Document.objects.create(
|
||||
title="Invoice B",
|
||||
content="monthly invoice billing statement account balance due",
|
||||
correspondent=c_auto,
|
||||
document_type=dt_auto,
|
||||
storage_path=sp_auto,
|
||||
checksum="bbb",
|
||||
mime_type="application/pdf",
|
||||
filename="invoice_b.pdf",
|
||||
)
|
||||
doc_b.tags.set([t_auto])
|
||||
|
||||
doc_c = Document.objects.create(
|
||||
title="Notes",
|
||||
content="meeting notes agenda discussion summary action items follow",
|
||||
checksum="ccc",
|
||||
mime_type="application/pdf",
|
||||
filename="notes_c.pdf",
|
||||
)
|
||||
|
||||
return {"doc_a": doc_a, "doc_b": doc_b, "doc_c": doc_c}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db()
|
||||
class TestSingleQuerysetPass:
|
||||
def test_train_iterates_document_queryset_once(self, classifier, label_corpus):
|
||||
"""
|
||||
train() must iterate the Document queryset exactly once.
|
||||
|
||||
Before Phase 2 there were two iterations: one in the label extraction
|
||||
loop and a second inside content_generator() for CountVectorizer.
|
||||
After Phase 2 content is captured during the label loop; the second
|
||||
iteration is eliminated.
|
||||
"""
|
||||
original_iter = QuerySet.__iter__
|
||||
doc_iter_count = 0
|
||||
|
||||
def counting_iter(qs):
|
||||
nonlocal doc_iter_count
|
||||
if qs.model is Document:
|
||||
doc_iter_count += 1
|
||||
return original_iter(qs)
|
||||
|
||||
with mock.patch.object(QuerySet, "__iter__", counting_iter):
|
||||
classifier.train()
|
||||
|
||||
assert doc_iter_count == 1, (
|
||||
f"Expected 1 Document queryset iteration, got {doc_iter_count}. "
|
||||
"content_generator() may still be re-fetching from the DB."
|
||||
)
|
||||
|
||||
def test_train_result_unchanged(self, classifier, label_corpus):
|
||||
"""
|
||||
Collapsing to a single pass must not change what the classifier learns:
|
||||
a second train() with no changes still returns False.
|
||||
"""
|
||||
assert classifier.train() is True
|
||||
assert classifier.train() is False
|
||||
300
src/documents/tests/test_classifier_tags_correctness.py
Normal file
300
src/documents/tests/test_classifier_tags_correctness.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Tags classifier correctness test — Phase 3b gate.
|
||||
|
||||
This test must pass both BEFORE and AFTER the MLPClassifier → LinearSVC swap.
|
||||
It verifies that the tags classifier correctly learns discriminative signal and
|
||||
predicts the right tags on held-out documents.
|
||||
|
||||
Run before the swap to establish a baseline, then run again after to confirm
|
||||
the new algorithm is at least as correct.
|
||||
|
||||
Two scenarios are tested:
|
||||
1. Multi-tag (num_tags > 1) — the common case; uses MultiLabelBinarizer
|
||||
2. Single-tag (num_tags == 1) — special binary path; uses LabelBinarizer
|
||||
|
||||
Corpus design: each tag has a distinct vocabulary cluster. Each training
|
||||
document contains words from exactly one cluster (or two for multi-tag docs).
|
||||
Held-out test documents contain the same cluster words; correct classification
|
||||
requires the model to learn the vocabulary → tag mapping.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import MatchingModel
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocabulary clusters — intentionally non-overlapping so both MLP and SVM
|
||||
# should learn them perfectly or near-perfectly.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FINANCE_WORDS = (
|
||||
"invoice payment tax revenue billing statement account receivable "
|
||||
"quarterly budget expense ledger debit credit profit loss fiscal"
|
||||
)
|
||||
LEGAL_WORDS = (
|
||||
"contract agreement terms conditions clause liability indemnity "
|
||||
"jurisdiction arbitration compliance regulation statute obligation"
|
||||
)
|
||||
MEDICAL_WORDS = (
|
||||
"prescription diagnosis treatment patient health symptom dosage "
|
||||
"physician referral therapy clinical examination procedure chronic"
|
||||
)
|
||||
HR_WORDS = (
|
||||
"employee salary onboarding performance review appraisal benefits "
|
||||
"recruitment hiring resignation termination payroll department staff"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier_settings(settings, tmp_path):
|
||||
settings.MODEL_FILE = tmp_path / "model.pickle"
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier(classifier_settings):
|
||||
return DocumentClassifier()
|
||||
|
||||
|
||||
def _make_doc(title, content, checksum, tags=(), **kwargs):
|
||||
doc = Document.objects.create(
|
||||
title=title,
|
||||
content=content,
|
||||
checksum=checksum,
|
||||
mime_type="application/pdf",
|
||||
filename=f"{checksum}.pdf",
|
||||
**kwargs,
|
||||
)
|
||||
if tags:
|
||||
doc.tags.set(tags)
|
||||
return doc
|
||||
|
||||
|
||||
def _words(cluster, extra=""):
|
||||
"""Repeat cluster words enough times to clear min_df=0.01 at ~40 docs."""
|
||||
return f"{cluster} {cluster} {extra}".strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-tag correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def multi_tag_corpus(classifier_settings):
|
||||
"""
|
||||
40 training documents across 4 AUTO tags with distinct vocabulary.
|
||||
10 single-tag docs per tag + 5 two-tag docs. Total: 45 docs.
|
||||
|
||||
A non-AUTO correspondent and doc type are included to keep the
|
||||
other classifiers happy and not raise ValueError.
|
||||
"""
|
||||
t_finance = Tag.objects.create(
|
||||
name="finance",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
t_legal = Tag.objects.create(
|
||||
name="legal",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
t_medical = Tag.objects.create(
|
||||
name="medical",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
t_hr = Tag.objects.create(name="hr", matching_algorithm=MatchingModel.MATCH_AUTO)
|
||||
|
||||
# non-AUTO labels to keep the other classifiers from raising
|
||||
c = Correspondent.objects.create(
|
||||
name="org",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
dt = DocumentType.objects.create(
|
||||
name="doc",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
sp = StoragePath.objects.create(
|
||||
name="archive",
|
||||
path="archive",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
checksum = 0
|
||||
|
||||
def make(title, content, tags):
|
||||
nonlocal checksum
|
||||
checksum += 1
|
||||
return _make_doc(
|
||||
title,
|
||||
content,
|
||||
f"{checksum:04d}",
|
||||
tags=tags,
|
||||
correspondent=c,
|
||||
document_type=dt,
|
||||
storage_path=sp,
|
||||
)
|
||||
|
||||
# 10 single-tag training docs per tag
|
||||
for i in range(10):
|
||||
make(f"finance-{i}", _words(FINANCE_WORDS, f"doc{i}"), [t_finance])
|
||||
make(f"legal-{i}", _words(LEGAL_WORDS, f"doc{i}"), [t_legal])
|
||||
make(f"medical-{i}", _words(MEDICAL_WORDS, f"doc{i}"), [t_medical])
|
||||
make(f"hr-{i}", _words(HR_WORDS, f"doc{i}"), [t_hr])
|
||||
|
||||
# 5 two-tag training docs
|
||||
for i in range(5):
|
||||
make(
|
||||
f"finance-legal-{i}",
|
||||
_words(FINANCE_WORDS + " " + LEGAL_WORDS, f"combo{i}"),
|
||||
[t_finance, t_legal],
|
||||
)
|
||||
|
||||
return {
|
||||
"t_finance": t_finance,
|
||||
"t_legal": t_legal,
|
||||
"t_medical": t_medical,
|
||||
"t_hr": t_hr,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.django_db()
|
||||
class TestMultiTagCorrectness:
|
||||
"""
|
||||
The tags classifier must correctly predict tags on held-out documents whose
|
||||
content clearly belongs to one or two vocabulary clusters.
|
||||
|
||||
A prediction is "correct" if the expected tag is present in the result.
|
||||
"""
|
||||
|
||||
def test_single_cluster_docs_predicted_correctly(
|
||||
self,
|
||||
classifier,
|
||||
multi_tag_corpus,
|
||||
):
|
||||
"""Each single-cluster held-out doc gets exactly the right tag."""
|
||||
classifier.train()
|
||||
tags = multi_tag_corpus
|
||||
|
||||
cases = [
|
||||
(FINANCE_WORDS + " unique alpha", [tags["t_finance"].pk]),
|
||||
(LEGAL_WORDS + " unique beta", [tags["t_legal"].pk]),
|
||||
(MEDICAL_WORDS + " unique gamma", [tags["t_medical"].pk]),
|
||||
(HR_WORDS + " unique delta", [tags["t_hr"].pk]),
|
||||
]
|
||||
|
||||
for content, expected_pks in cases:
|
||||
predicted = classifier.predict_tags(content)
|
||||
for pk in expected_pks:
|
||||
assert pk in predicted, (
|
||||
f"Expected tag pk={pk} in predictions for content starting "
|
||||
f"'{content[:40]}…', got {predicted}"
|
||||
)
|
||||
|
||||
def test_multi_cluster_doc_gets_both_tags(self, classifier, multi_tag_corpus):
|
||||
"""A document with finance + legal vocabulary gets both tags."""
|
||||
classifier.train()
|
||||
tags = multi_tag_corpus
|
||||
|
||||
content = FINANCE_WORDS + " " + LEGAL_WORDS + " unique epsilon"
|
||||
predicted = classifier.predict_tags(content)
|
||||
|
||||
assert tags["t_finance"].pk in predicted, f"Expected finance tag in {predicted}"
|
||||
assert tags["t_legal"].pk in predicted, f"Expected legal tag in {predicted}"
|
||||
|
||||
def test_unrelated_content_predicts_no_trained_tags(
|
||||
self,
|
||||
classifier,
|
||||
multi_tag_corpus,
|
||||
):
|
||||
"""
|
||||
Completely alien content should not confidently fire any learned tag.
|
||||
This is a soft check — we only assert no false positives on a document
|
||||
that shares zero vocabulary with the training corpus.
|
||||
"""
|
||||
classifier.train()
|
||||
|
||||
alien = (
|
||||
"xyzzyx qwerty asdfgh zxcvbn plokij unique zeta "
|
||||
"xyzzyx qwerty asdfgh zxcvbn plokij unique zeta"
|
||||
)
|
||||
predicted = classifier.predict_tags(alien)
|
||||
# Not a hard requirement — just log for human inspection
|
||||
# Both MLP and SVM may or may not produce false positives on OOV content
|
||||
assert isinstance(predicted, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-tag (binary) correctness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def single_tag_corpus(classifier_settings):
|
||||
"""
|
||||
Corpus with exactly ONE AUTO tag, exercising the LabelBinarizer +
|
||||
binary classification path. Documents either have the tag or don't.
|
||||
"""
|
||||
t_finance = Tag.objects.create(
|
||||
name="finance",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
c = Correspondent.objects.create(
|
||||
name="org",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
dt = DocumentType.objects.create(
|
||||
name="doc",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
checksum = 0
|
||||
|
||||
def make(title, content, tags):
|
||||
nonlocal checksum
|
||||
checksum += 1
|
||||
return _make_doc(
|
||||
title,
|
||||
content,
|
||||
f"s{checksum:04d}",
|
||||
tags=tags,
|
||||
correspondent=c,
|
||||
document_type=dt,
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
make(f"finance-{i}", _words(FINANCE_WORDS, f"s{i}"), [t_finance])
|
||||
make(f"other-{i}", _words(LEGAL_WORDS, f"s{i}"), [])
|
||||
|
||||
return {"t_finance": t_finance}
|
||||
|
||||
|
||||
@pytest.mark.django_db()
|
||||
class TestSingleTagCorrectness:
|
||||
def test_finance_content_predicts_finance_tag(self, classifier, single_tag_corpus):
|
||||
"""Finance vocabulary → finance tag predicted."""
|
||||
classifier.train()
|
||||
tags = single_tag_corpus
|
||||
|
||||
predicted = classifier.predict_tags(FINANCE_WORDS + " unique alpha single")
|
||||
assert tags["t_finance"].pk in predicted, (
|
||||
f"Expected finance tag pk={tags['t_finance'].pk} in {predicted}"
|
||||
)
|
||||
|
||||
def test_non_finance_content_predicts_no_tag(self, classifier, single_tag_corpus):
|
||||
"""Non-finance vocabulary → no tag predicted."""
|
||||
classifier.train()
|
||||
|
||||
predicted = classifier.predict_tags(LEGAL_WORDS + " unique beta single")
|
||||
assert predicted == [], f"Expected no tags, got {predicted}"
|
||||
325
src/documents/tests/test_classifier_train_skip.py
Normal file
325
src/documents/tests/test_classifier_train_skip.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
Phase 1 — fast-skip optimisation in DocumentClassifier.train()
|
||||
|
||||
The goal: when nothing has changed since the last training run, train() should
|
||||
return False after at most 5 DB queries (1x MAX(modified) + 4x MATCH_AUTO pk
|
||||
lists), not after a full per-document label scan.
|
||||
|
||||
Correctness invariant: the skip must NOT fire when the set of AUTO-matching
|
||||
labels has changed, even if no Document.modified timestamp has advanced (e.g.
|
||||
a Tag's matching_algorithm was flipped to MATCH_AUTO after the last train).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from django.db import connection
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.models import Correspondent
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import MatchingModel
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier_settings(settings, tmp_path: Path):
|
||||
"""Point MODEL_FILE at a temp directory so tests are hermetic."""
|
||||
settings.MODEL_FILE = tmp_path / "model.pickle"
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def classifier(classifier_settings):
|
||||
"""Fresh DocumentClassifier instance with test settings active."""
|
||||
return DocumentClassifier()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def label_corpus(classifier_settings):
|
||||
"""
|
||||
Minimal label + document corpus that produces a trainable classifier.
|
||||
|
||||
Creates
|
||||
-------
|
||||
Correspondents
|
||||
c_auto — MATCH_AUTO, assigned to two docs
|
||||
c_none — MATCH_NONE (control)
|
||||
DocumentTypes
|
||||
dt_auto — MATCH_AUTO, assigned to two docs
|
||||
dt_none — MATCH_NONE (control)
|
||||
Tags
|
||||
t_auto — MATCH_AUTO, applied to two docs
|
||||
t_none — MATCH_NONE (control, applied to one doc but never learned)
|
||||
StoragePaths
|
||||
sp_auto — MATCH_AUTO, assigned to two docs
|
||||
sp_none — MATCH_NONE (control)
|
||||
|
||||
Documents
|
||||
doc_a, doc_b — assigned AUTO labels above
|
||||
doc_c — control doc (MATCH_NONE labels only)
|
||||
|
||||
The fixture returns a dict with all created objects for direct mutation in
|
||||
individual tests.
|
||||
"""
|
||||
c_auto = Correspondent.objects.create(
|
||||
name="Auto Corp",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
c_none = Correspondent.objects.create(
|
||||
name="Manual Corp",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
dt_auto = DocumentType.objects.create(
|
||||
name="Invoice",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
dt_none = DocumentType.objects.create(
|
||||
name="Other",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
t_auto = Tag.objects.create(
|
||||
name="finance",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
t_none = Tag.objects.create(
|
||||
name="misc",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
sp_auto = StoragePath.objects.create(
|
||||
name="Finance Path",
|
||||
path="finance/{correspondent}",
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
)
|
||||
sp_none = StoragePath.objects.create(
|
||||
name="Other Path",
|
||||
path="other/{correspondent}",
|
||||
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||
)
|
||||
|
||||
doc_a = Document.objects.create(
|
||||
title="Invoice from Auto Corp Jan",
|
||||
content="quarterly invoice payment tax financial statement revenue",
|
||||
correspondent=c_auto,
|
||||
document_type=dt_auto,
|
||||
storage_path=sp_auto,
|
||||
checksum="aaa",
|
||||
mime_type="application/pdf",
|
||||
filename="invoice_a.pdf",
|
||||
)
|
||||
doc_a.tags.set([t_auto])
|
||||
|
||||
doc_b = Document.objects.create(
|
||||
title="Invoice from Auto Corp Feb",
|
||||
content="monthly invoice billing statement account balance due",
|
||||
correspondent=c_auto,
|
||||
document_type=dt_auto,
|
||||
storage_path=sp_auto,
|
||||
checksum="bbb",
|
||||
mime_type="application/pdf",
|
||||
filename="invoice_b.pdf",
|
||||
)
|
||||
doc_b.tags.set([t_auto])
|
||||
|
||||
# Control document — no AUTO labels, but has enough content to vectorize
|
||||
doc_c = Document.objects.create(
|
||||
title="Miscellaneous Notes",
|
||||
content="meeting notes agenda discussion summary action items follow",
|
||||
correspondent=c_none,
|
||||
document_type=dt_none,
|
||||
checksum="ccc",
|
||||
mime_type="application/pdf",
|
||||
filename="notes_c.pdf",
|
||||
)
|
||||
doc_c.tags.set([t_none])
|
||||
|
||||
return {
|
||||
"c_auto": c_auto,
|
||||
"c_none": c_none,
|
||||
"dt_auto": dt_auto,
|
||||
"dt_none": dt_none,
|
||||
"t_auto": t_auto,
|
||||
"t_none": t_none,
|
||||
"sp_auto": sp_auto,
|
||||
"sp_none": sp_none,
|
||||
"doc_a": doc_a,
|
||||
"doc_b": doc_b,
|
||||
"doc_c": doc_c,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy-path skip tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db()
class TestFastSkipFires:
    """No-op path: when nothing has changed, a repeat train() call is skipped."""

    def test_first_train_returns_true(self, classifier, label_corpus):
        """A never-trained classifier performs real work and reports True."""
        did_train = classifier.train()
        assert did_train is True

    def test_second_train_returns_false(self, classifier, label_corpus):
        """An immediate re-train with an unchanged corpus reports False."""
        classifier.train()
        did_train = classifier.train()
        assert did_train is False

    def test_fast_skip_runs_minimal_queries(self, classifier, label_corpus):
        """
        The no-op path must use at most 5 DB queries:

            1x Document.objects.aggregate(Max('modified'))
            4x MATCH_AUTO pk lists (Correspondent / DocumentType / Tag / StoragePath)

        The current implementation (before Phase 1) iterates every document
        to build the label hash BEFORE it can decide to skip, which is O(N).
        This test verifies the fast path is in place.
        """
        classifier.train()
        with CaptureQueriesContext(connection) as query_ctx:
            skipped_result = classifier.train()
        assert skipped_result is False
        query_count = len(query_ctx.captured_queries)
        sql_detail = "\n".join(q["sql"] for q in query_ctx.captured_queries)
        assert query_count <= 5, (
            f"Fast skip used {query_count} queries; expected ≤5.\n" + sql_detail
        )

    def test_fast_skip_refreshes_cache_keys(self, classifier, label_corpus):
        """
        Even on a skip, the cache keys must be refreshed so that the task
        scheduler can detect the classifier is still current.
        """
        from django.core.cache import cache

        from documents.caching import CLASSIFIER_HASH_KEY
        from documents.caching import CLASSIFIER_MODIFIED_KEY
        from documents.caching import CLASSIFIER_VERSION_KEY

        cache_keys = (
            CLASSIFIER_MODIFIED_KEY,
            CLASSIFIER_HASH_KEY,
            CLASSIFIER_VERSION_KEY,
        )

        classifier.train()
        # Evict the keys so the skip path has to re-populate them itself.
        for key in cache_keys:
            cache.delete(key)

        assert classifier.train() is False
        for key in cache_keys:
            assert cache.get(key) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Correctness tests — skip must NOT fire when the world has changed
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.django_db()
class TestFastSkipDoesNotFire:
    """The skip guard must give way to a full retrain whenever labels change."""

    def test_document_content_modification_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """Editing a document's content bumps its modified time → retrain."""
        classifier.train()
        edited = label_corpus["doc_a"]
        edited.content = "completely different words here now nothing same"
        edited.save()
        assert classifier.train() is True

    def test_document_label_reassignment_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Moving a document onto a different AUTO correspondent (which touches
        doc.modified) must force a retrain.
        """
        replacement = Correspondent.objects.create(
            name="Second Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        classifier.train()
        moved = label_corpus["doc_a"]
        moved.correspondent = replacement
        moved.save()
        assert classifier.train() is True

    def test_matching_algorithm_change_on_assigned_tag_triggers_retrain(
        self,
        classifier,
        label_corpus,
    ):
        """
        Switching a tag's matching_algorithm to MATCH_AUTO after it is
        already assigned to a document must force a retrain — even though
        no Document.modified timestamp advances.

        This is the key correctness case for the auto-label-set digest:
        the tag already sits on doc_c, so once it becomes MATCH_AUTO the
        classifier needs to learn it.
        """
        # t_none is applied to doc_c (a control doc) via the fixture.
        # Flipping it to MATCH_AUTO grows the set of learnable AUTO tags.
        classifier.train()
        flipped = label_corpus["t_none"]
        flipped.matching_algorithm = MatchingModel.MATCH_AUTO
        flipped.save(update_fields=["matching_algorithm"])
        # Document.modified is NOT touched — this test specifically verifies
        # that the auto-label-set digest catches the change.
        assert classifier.train() is True

    def test_new_auto_correspondent_triggers_retrain(self, classifier, label_corpus):
        """
        Creating a brand-new MATCH_AUTO correspondent (assigned to no doc)
        must force a retrain: the auto-label-set has grown.
        """
        classifier.train()
        Correspondent.objects.create(
            name="New Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        assert classifier.train() is True

    def test_removing_auto_label_triggers_retrain(self, classifier, label_corpus):
        """
        Deleting a MATCH_AUTO correspondent shrinks the auto-label-set and
        must force a retrain.
        """
        classifier.train()
        label_corpus["c_auto"].delete()
        assert classifier.train() is True

    def test_fresh_classifier_always_trains(self, classifier, label_corpus):
        """
        An untrained classifier (last_doc_change_time is None) must always
        perform a full train, regardless of corpus state.
        """
        assert classifier.last_doc_change_time is None
        assert classifier.train() is True

    def test_no_documents_raises_value_error(self, classifier, classifier_settings):
        """train() against an empty database must raise ValueError."""
        with pytest.raises(ValueError, match="No training data"):
            classifier.train()
|
||||
Reference in New Issue
Block a user