mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-04-10 01:58:53 +00:00
Compare commits
3 Commits
dev
...
feature-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a26514a96 | ||
|
|
1fefd506b7 | ||
|
|
68b866aeee |
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from collections.abc import Iterator
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
@@ -19,6 +18,7 @@ if TYPE_CHECKING:
|
|||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.core.cache import caches
|
from django.core.cache import caches
|
||||||
|
from django.db.models import Max
|
||||||
|
|
||||||
from documents.caching import CACHE_5_MINUTES
|
from documents.caching import CACHE_5_MINUTES
|
||||||
from documents.caching import CACHE_50_MINUTES
|
from documents.caching import CACHE_50_MINUTES
|
||||||
@@ -99,7 +99,8 @@ class DocumentClassifier:
|
|||||||
# v8 - Added storage path classifier
|
# v8 - Added storage path classifier
|
||||||
# v9 - Changed from hashing to time/ids for re-train check
|
# v9 - Changed from hashing to time/ids for re-train check
|
||||||
# v10 - HMAC-signed model file
|
# v10 - HMAC-signed model file
|
||||||
FORMAT_VERSION = 10
|
# v11 - Added auto-label-set digest for fast skip without full document scan
|
||||||
|
FORMAT_VERSION = 11
|
||||||
|
|
||||||
HMAC_SIZE = 32 # SHA-256 digest length
|
HMAC_SIZE = 32 # SHA-256 digest length
|
||||||
|
|
||||||
@@ -108,6 +109,8 @@ class DocumentClassifier:
|
|||||||
self.last_doc_change_time: datetime | None = None
|
self.last_doc_change_time: datetime | None = None
|
||||||
# Hash of primary keys of AUTO matching values last used in training
|
# Hash of primary keys of AUTO matching values last used in training
|
||||||
self.last_auto_type_hash: bytes | None = None
|
self.last_auto_type_hash: bytes | None = None
|
||||||
|
# Digest of the set of all MATCH_AUTO label PKs (fast-skip guard)
|
||||||
|
self.last_auto_label_set_digest: bytes | None = None
|
||||||
|
|
||||||
self.data_vectorizer = None
|
self.data_vectorizer = None
|
||||||
self.data_vectorizer_hash = None
|
self.data_vectorizer_hash = None
|
||||||
@@ -140,6 +143,29 @@ class DocumentClassifier:
|
|||||||
sha256,
|
sha256,
|
||||||
).digest()
|
).digest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_auto_label_set_digest() -> bytes:
|
||||||
|
"""
|
||||||
|
Return a SHA-256 digest of all MATCH_AUTO label PKs across the four
|
||||||
|
label types. Four cheap indexed queries; stable for any fixed set of
|
||||||
|
AUTO labels regardless of document assignments.
|
||||||
|
"""
|
||||||
|
from documents.models import Correspondent
|
||||||
|
from documents.models import DocumentType
|
||||||
|
from documents.models import StoragePath
|
||||||
|
from documents.models import Tag
|
||||||
|
|
||||||
|
hasher = sha256()
|
||||||
|
for model in (Correspondent, DocumentType, Tag, StoragePath):
|
||||||
|
pks = sorted(
|
||||||
|
model.objects.filter(
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
).values_list("pk", flat=True),
|
||||||
|
)
|
||||||
|
for pk in pks:
|
||||||
|
hasher.update(pk.to_bytes(4, "little", signed=False))
|
||||||
|
return hasher.digest()
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
from sklearn.exceptions import InconsistentVersionWarning
|
from sklearn.exceptions import InconsistentVersionWarning
|
||||||
|
|
||||||
@@ -161,6 +187,7 @@ class DocumentClassifier:
|
|||||||
schema_version,
|
schema_version,
|
||||||
self.last_doc_change_time,
|
self.last_doc_change_time,
|
||||||
self.last_auto_type_hash,
|
self.last_auto_type_hash,
|
||||||
|
self.last_auto_label_set_digest,
|
||||||
self.data_vectorizer,
|
self.data_vectorizer,
|
||||||
self.tags_binarizer,
|
self.tags_binarizer,
|
||||||
self.tags_classifier,
|
self.tags_classifier,
|
||||||
@@ -202,6 +229,7 @@ class DocumentClassifier:
|
|||||||
self.FORMAT_VERSION,
|
self.FORMAT_VERSION,
|
||||||
self.last_doc_change_time,
|
self.last_doc_change_time,
|
||||||
self.last_auto_type_hash,
|
self.last_auto_type_hash,
|
||||||
|
self.last_auto_label_set_digest,
|
||||||
self.data_vectorizer,
|
self.data_vectorizer,
|
||||||
self.tags_binarizer,
|
self.tags_binarizer,
|
||||||
self.tags_classifier,
|
self.tags_classifier,
|
||||||
@@ -224,6 +252,39 @@ class DocumentClassifier:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
notify = status_callback if status_callback is not None else lambda _: None
|
notify = status_callback if status_callback is not None else lambda _: None
|
||||||
|
|
||||||
|
# Fast skip: avoid the expensive per-document label scan when nothing
|
||||||
|
# has changed. Requires a prior training run to have populated both
|
||||||
|
# last_doc_change_time and last_auto_label_set_digest.
|
||||||
|
if (
|
||||||
|
self.last_doc_change_time is not None
|
||||||
|
and self.last_auto_label_set_digest is not None
|
||||||
|
):
|
||||||
|
latest_mod = Document.objects.exclude(
|
||||||
|
tags__is_inbox_tag=True,
|
||||||
|
).aggregate(Max("modified"))["modified__max"]
|
||||||
|
if latest_mod is not None and latest_mod <= self.last_doc_change_time:
|
||||||
|
current_digest = self._compute_auto_label_set_digest()
|
||||||
|
if current_digest == self.last_auto_label_set_digest:
|
||||||
|
logger.info("No updates since last training")
|
||||||
|
cache.set(
|
||||||
|
CLASSIFIER_MODIFIED_KEY,
|
||||||
|
self.last_doc_change_time,
|
||||||
|
CACHE_50_MINUTES,
|
||||||
|
)
|
||||||
|
cache.set(
|
||||||
|
CLASSIFIER_HASH_KEY,
|
||||||
|
self.last_auto_type_hash.hex()
|
||||||
|
if self.last_auto_type_hash
|
||||||
|
else "",
|
||||||
|
CACHE_50_MINUTES,
|
||||||
|
)
|
||||||
|
cache.set(
|
||||||
|
CLASSIFIER_VERSION_KEY,
|
||||||
|
self.FORMAT_VERSION,
|
||||||
|
CACHE_50_MINUTES,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# Get non-inbox documents
|
# Get non-inbox documents
|
||||||
docs_queryset = (
|
docs_queryset = (
|
||||||
Document.objects.exclude(
|
Document.objects.exclude(
|
||||||
@@ -242,12 +303,15 @@ class DocumentClassifier:
|
|||||||
labels_correspondent = []
|
labels_correspondent = []
|
||||||
labels_document_type = []
|
labels_document_type = []
|
||||||
labels_storage_path = []
|
labels_storage_path = []
|
||||||
|
doc_contents: list[str] = []
|
||||||
|
|
||||||
# Step 1: Extract and preprocess training data from the database.
|
# Step 1: Extract labels and capture content in a single pass.
|
||||||
logger.debug("Gathering data from database...")
|
logger.debug("Gathering data from database...")
|
||||||
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
|
notify(f"Gathering data from {docs_queryset.count()} document(s)...")
|
||||||
hasher = sha256()
|
hasher = sha256()
|
||||||
for doc in docs_queryset:
|
for doc in docs_queryset:
|
||||||
|
doc_contents.append(doc.content)
|
||||||
|
|
||||||
y = -1
|
y = -1
|
||||||
dt = doc.document_type
|
dt = doc.document_type
|
||||||
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||||
@@ -282,25 +346,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
num_tags = len(labels_tags_unique)
|
num_tags = len(labels_tags_unique)
|
||||||
|
|
||||||
# Check if retraining is actually required.
|
|
||||||
# A document has been updated since the classifier was trained
|
|
||||||
# New auto tags, types, correspondent, storage paths exist
|
|
||||||
latest_doc_change = docs_queryset.latest("modified").modified
|
latest_doc_change = docs_queryset.latest("modified").modified
|
||||||
if (
|
|
||||||
self.last_doc_change_time is not None
|
|
||||||
and self.last_doc_change_time >= latest_doc_change
|
|
||||||
) and self.last_auto_type_hash == hasher.digest():
|
|
||||||
logger.info("No updates since last training")
|
|
||||||
# Set the classifier information into the cache
|
|
||||||
# Caching for 50 minutes, so slightly less than the normal retrain time
|
|
||||||
cache.set(
|
|
||||||
CLASSIFIER_MODIFIED_KEY,
|
|
||||||
self.last_doc_change_time,
|
|
||||||
CACHE_50_MINUTES,
|
|
||||||
)
|
|
||||||
cache.set(CLASSIFIER_HASH_KEY, hasher.hexdigest(), CACHE_50_MINUTES)
|
|
||||||
cache.set(CLASSIFIER_VERSION_KEY, self.FORMAT_VERSION, CACHE_50_MINUTES)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# subtract 1 since -1 (null) is also part of the classes.
|
# subtract 1 since -1 (null) is also part of the classes.
|
||||||
|
|
||||||
@@ -317,21 +363,16 @@ class DocumentClassifier:
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
|
from sklearn.multiclass import OneVsRestClassifier
|
||||||
from sklearn.neural_network import MLPClassifier
|
from sklearn.neural_network import MLPClassifier
|
||||||
from sklearn.preprocessing import LabelBinarizer
|
from sklearn.preprocessing import LabelBinarizer
|
||||||
from sklearn.preprocessing import MultiLabelBinarizer
|
from sklearn.preprocessing import MultiLabelBinarizer
|
||||||
|
from sklearn.svm import LinearSVC
|
||||||
|
|
||||||
# Step 2: vectorize data
|
# Step 2: vectorize data
|
||||||
logger.debug("Vectorizing data...")
|
logger.debug("Vectorizing data...")
|
||||||
notify("Vectorizing document content...")
|
notify("Vectorizing document content...")
|
||||||
|
|
||||||
def content_generator() -> Iterator[str]:
|
|
||||||
"""
|
|
||||||
Generates the content for documents, but once at a time
|
|
||||||
"""
|
|
||||||
for doc in docs_queryset:
|
|
||||||
yield self.preprocess_content(doc.content, shared_cache=False)
|
|
||||||
|
|
||||||
self.data_vectorizer = CountVectorizer(
|
self.data_vectorizer = CountVectorizer(
|
||||||
analyzer="word",
|
analyzer="word",
|
||||||
ngram_range=(1, 2),
|
ngram_range=(1, 2),
|
||||||
@@ -339,7 +380,8 @@ class DocumentClassifier:
|
|||||||
)
|
)
|
||||||
|
|
||||||
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
|
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
|
||||||
content_generator(),
|
self.preprocess_content(content, shared_cache=False)
|
||||||
|
for content in doc_contents
|
||||||
)
|
)
|
||||||
|
|
||||||
# See the notes here:
|
# See the notes here:
|
||||||
@@ -353,8 +395,10 @@ class DocumentClassifier:
|
|||||||
notify(f"Training tags classifier ({num_tags} tag(s))...")
|
notify(f"Training tags classifier ({num_tags} tag(s))...")
|
||||||
|
|
||||||
if num_tags == 1:
|
if num_tags == 1:
|
||||||
# Special case where only one tag has auto:
|
# Special case: only one AUTO tag — use binary classification.
|
||||||
# Fallback to binary classification.
|
# MLPClassifier is used here because LinearSVC requires at least
|
||||||
|
# 2 distinct classes in training data, which cannot be guaranteed
|
||||||
|
# when all documents share the single AUTO tag.
|
||||||
labels_tags = [
|
labels_tags = [
|
||||||
label[0] if len(label) == 1 else -1 for label in labels_tags
|
label[0] if len(label) == 1 else -1 for label in labels_tags
|
||||||
]
|
]
|
||||||
@@ -362,11 +406,15 @@ class DocumentClassifier:
|
|||||||
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
|
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
|
||||||
labels_tags,
|
labels_tags,
|
||||||
).ravel()
|
).ravel()
|
||||||
|
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||||
else:
|
else:
|
||||||
|
# General multi-label case: LinearSVC via OneVsRestClassifier.
|
||||||
|
# Vastly more memory- and time-efficient than MLPClassifier for
|
||||||
|
# large class counts (e.g. hundreds of AUTO tags).
|
||||||
self.tags_binarizer = MultiLabelBinarizer()
|
self.tags_binarizer = MultiLabelBinarizer()
|
||||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
||||||
|
self.tags_classifier = OneVsRestClassifier(LinearSVC())
|
||||||
|
|
||||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
|
||||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||||
else:
|
else:
|
||||||
self.tags_classifier = None
|
self.tags_classifier = None
|
||||||
@@ -416,6 +464,7 @@ class DocumentClassifier:
|
|||||||
|
|
||||||
self.last_doc_change_time = latest_doc_change
|
self.last_doc_change_time = latest_doc_change
|
||||||
self.last_auto_type_hash = hasher.digest()
|
self.last_auto_type_hash = hasher.digest()
|
||||||
|
self.last_auto_label_set_digest = self._compute_auto_label_set_digest()
|
||||||
self._update_data_vectorizer_hash()
|
self._update_data_vectorizer_hash()
|
||||||
|
|
||||||
# Set the classifier information into the cache
|
# Set the classifier information into the cache
|
||||||
|
|||||||
134
src/documents/tests/test_classifier_single_pass.py
Normal file
134
src/documents/tests/test_classifier_single_pass.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
Phase 2 — Single queryset pass in DocumentClassifier.train()
|
||||||
|
|
||||||
|
The document queryset must be iterated exactly once: during the label
|
||||||
|
extraction loop, which now also captures doc.content for vectorization.
|
||||||
|
The previous content_generator() caused a second full table scan.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from django.db.models.query import QuerySet
|
||||||
|
|
||||||
|
from documents.classifier import DocumentClassifier
|
||||||
|
from documents.models import Correspondent
|
||||||
|
from documents.models import Document
|
||||||
|
from documents.models import DocumentType
|
||||||
|
from documents.models import MatchingModel
|
||||||
|
from documents.models import StoragePath
|
||||||
|
from documents.models import Tag
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures (mirrors test_classifier_train_skip.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def classifier_settings(settings, tmp_path):
|
||||||
|
settings.MODEL_FILE = tmp_path / "model.pickle"
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def classifier(classifier_settings):
|
||||||
|
return DocumentClassifier()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def label_corpus(classifier_settings):
|
||||||
|
c_auto = Correspondent.objects.create(
|
||||||
|
name="Auto Corp",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
dt_auto = DocumentType.objects.create(
|
||||||
|
name="Invoice",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
t_auto = Tag.objects.create(
|
||||||
|
name="finance",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
sp_auto = StoragePath.objects.create(
|
||||||
|
name="Finance Path",
|
||||||
|
path="finance/{correspondent}",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_a = Document.objects.create(
|
||||||
|
title="Invoice A",
|
||||||
|
content="quarterly invoice payment tax financial statement revenue",
|
||||||
|
correspondent=c_auto,
|
||||||
|
document_type=dt_auto,
|
||||||
|
storage_path=sp_auto,
|
||||||
|
checksum="aaa",
|
||||||
|
mime_type="application/pdf",
|
||||||
|
filename="invoice_a.pdf",
|
||||||
|
)
|
||||||
|
doc_a.tags.set([t_auto])
|
||||||
|
|
||||||
|
doc_b = Document.objects.create(
|
||||||
|
title="Invoice B",
|
||||||
|
content="monthly invoice billing statement account balance due",
|
||||||
|
correspondent=c_auto,
|
||||||
|
document_type=dt_auto,
|
||||||
|
storage_path=sp_auto,
|
||||||
|
checksum="bbb",
|
||||||
|
mime_type="application/pdf",
|
||||||
|
filename="invoice_b.pdf",
|
||||||
|
)
|
||||||
|
doc_b.tags.set([t_auto])
|
||||||
|
|
||||||
|
doc_c = Document.objects.create(
|
||||||
|
title="Notes",
|
||||||
|
content="meeting notes agenda discussion summary action items follow",
|
||||||
|
checksum="ccc",
|
||||||
|
mime_type="application/pdf",
|
||||||
|
filename="notes_c.pdf",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"doc_a": doc_a, "doc_b": doc_b, "doc_c": doc_c}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db()
|
||||||
|
class TestSingleQuerysetPass:
|
||||||
|
def test_train_iterates_document_queryset_once(self, classifier, label_corpus):
|
||||||
|
"""
|
||||||
|
train() must iterate the Document queryset exactly once.
|
||||||
|
|
||||||
|
Before Phase 2 there were two iterations: one in the label extraction
|
||||||
|
loop and a second inside content_generator() for CountVectorizer.
|
||||||
|
After Phase 2 content is captured during the label loop; the second
|
||||||
|
iteration is eliminated.
|
||||||
|
"""
|
||||||
|
original_iter = QuerySet.__iter__
|
||||||
|
doc_iter_count = 0
|
||||||
|
|
||||||
|
def counting_iter(qs):
|
||||||
|
nonlocal doc_iter_count
|
||||||
|
if qs.model is Document:
|
||||||
|
doc_iter_count += 1
|
||||||
|
return original_iter(qs)
|
||||||
|
|
||||||
|
with mock.patch.object(QuerySet, "__iter__", counting_iter):
|
||||||
|
classifier.train()
|
||||||
|
|
||||||
|
assert doc_iter_count == 1, (
|
||||||
|
f"Expected 1 Document queryset iteration, got {doc_iter_count}. "
|
||||||
|
"content_generator() may still be re-fetching from the DB."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_train_result_unchanged(self, classifier, label_corpus):
|
||||||
|
"""
|
||||||
|
Collapsing to a single pass must not change what the classifier learns:
|
||||||
|
a second train() with no changes still returns False.
|
||||||
|
"""
|
||||||
|
assert classifier.train() is True
|
||||||
|
assert classifier.train() is False
|
||||||
300
src/documents/tests/test_classifier_tags_correctness.py
Normal file
300
src/documents/tests/test_classifier_tags_correctness.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""
|
||||||
|
Tags classifier correctness test — Phase 3b gate.
|
||||||
|
|
||||||
|
This test must pass both BEFORE and AFTER the MLPClassifier → LinearSVC swap.
|
||||||
|
It verifies that the tags classifier correctly learns discriminative signal and
|
||||||
|
predicts the right tags on held-out documents.
|
||||||
|
|
||||||
|
Run before the swap to establish a baseline, then run again after to confirm
|
||||||
|
the new algorithm is at least as correct.
|
||||||
|
|
||||||
|
Two scenarios are tested:
|
||||||
|
1. Multi-tag (num_tags > 1) — the common case; uses MultiLabelBinarizer
|
||||||
|
2. Single-tag (num_tags == 1) — special binary path; uses LabelBinarizer
|
||||||
|
|
||||||
|
Corpus design: each tag has a distinct vocabulary cluster. Each training
|
||||||
|
document contains words from exactly one cluster (or two for multi-tag docs).
|
||||||
|
Held-out test documents contain the same cluster words; correct classification
|
||||||
|
requires the model to learn the vocabulary → tag mapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from documents.classifier import DocumentClassifier
|
||||||
|
from documents.models import Correspondent
|
||||||
|
from documents.models import Document
|
||||||
|
from documents.models import DocumentType
|
||||||
|
from documents.models import MatchingModel
|
||||||
|
from documents.models import StoragePath
|
||||||
|
from documents.models import Tag
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vocabulary clusters — intentionally non-overlapping so both MLP and SVM
|
||||||
|
# should learn them perfectly or near-perfectly.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
FINANCE_WORDS = (
|
||||||
|
"invoice payment tax revenue billing statement account receivable "
|
||||||
|
"quarterly budget expense ledger debit credit profit loss fiscal"
|
||||||
|
)
|
||||||
|
LEGAL_WORDS = (
|
||||||
|
"contract agreement terms conditions clause liability indemnity "
|
||||||
|
"jurisdiction arbitration compliance regulation statute obligation"
|
||||||
|
)
|
||||||
|
MEDICAL_WORDS = (
|
||||||
|
"prescription diagnosis treatment patient health symptom dosage "
|
||||||
|
"physician referral therapy clinical examination procedure chronic"
|
||||||
|
)
|
||||||
|
HR_WORDS = (
|
||||||
|
"employee salary onboarding performance review appraisal benefits "
|
||||||
|
"recruitment hiring resignation termination payroll department staff"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def classifier_settings(settings, tmp_path):
|
||||||
|
settings.MODEL_FILE = tmp_path / "model.pickle"
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def classifier(classifier_settings):
|
||||||
|
return DocumentClassifier()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_doc(title, content, checksum, tags=(), **kwargs):
|
||||||
|
doc = Document.objects.create(
|
||||||
|
title=title,
|
||||||
|
content=content,
|
||||||
|
checksum=checksum,
|
||||||
|
mime_type="application/pdf",
|
||||||
|
filename=f"{checksum}.pdf",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if tags:
|
||||||
|
doc.tags.set(tags)
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
def _words(cluster, extra=""):
|
||||||
|
"""Repeat cluster words enough times to clear min_df=0.01 at ~40 docs."""
|
||||||
|
return f"{cluster} {cluster} {extra}".strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Multi-tag correctness
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def multi_tag_corpus(classifier_settings):
|
||||||
|
"""
|
||||||
|
40 training documents across 4 AUTO tags with distinct vocabulary.
|
||||||
|
10 single-tag docs per tag + 5 two-tag docs. Total: 45 docs.
|
||||||
|
|
||||||
|
A non-AUTO correspondent and doc type are included to keep the
|
||||||
|
other classifiers happy and not raise ValueError.
|
||||||
|
"""
|
||||||
|
t_finance = Tag.objects.create(
|
||||||
|
name="finance",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
t_legal = Tag.objects.create(
|
||||||
|
name="legal",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
t_medical = Tag.objects.create(
|
||||||
|
name="medical",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
t_hr = Tag.objects.create(name="hr", matching_algorithm=MatchingModel.MATCH_AUTO)
|
||||||
|
|
||||||
|
# non-AUTO labels to keep the other classifiers from raising
|
||||||
|
c = Correspondent.objects.create(
|
||||||
|
name="org",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||||
|
)
|
||||||
|
dt = DocumentType.objects.create(
|
||||||
|
name="doc",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||||
|
)
|
||||||
|
sp = StoragePath.objects.create(
|
||||||
|
name="archive",
|
||||||
|
path="archive",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
checksum = 0
|
||||||
|
|
||||||
|
def make(title, content, tags):
|
||||||
|
nonlocal checksum
|
||||||
|
checksum += 1
|
||||||
|
return _make_doc(
|
||||||
|
title,
|
||||||
|
content,
|
||||||
|
f"{checksum:04d}",
|
||||||
|
tags=tags,
|
||||||
|
correspondent=c,
|
||||||
|
document_type=dt,
|
||||||
|
storage_path=sp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 10 single-tag training docs per tag
|
||||||
|
for i in range(10):
|
||||||
|
make(f"finance-{i}", _words(FINANCE_WORDS, f"doc{i}"), [t_finance])
|
||||||
|
make(f"legal-{i}", _words(LEGAL_WORDS, f"doc{i}"), [t_legal])
|
||||||
|
make(f"medical-{i}", _words(MEDICAL_WORDS, f"doc{i}"), [t_medical])
|
||||||
|
make(f"hr-{i}", _words(HR_WORDS, f"doc{i}"), [t_hr])
|
||||||
|
|
||||||
|
# 5 two-tag training docs
|
||||||
|
for i in range(5):
|
||||||
|
make(
|
||||||
|
f"finance-legal-{i}",
|
||||||
|
_words(FINANCE_WORDS + " " + LEGAL_WORDS, f"combo{i}"),
|
||||||
|
[t_finance, t_legal],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"t_finance": t_finance,
|
||||||
|
"t_legal": t_legal,
|
||||||
|
"t_medical": t_medical,
|
||||||
|
"t_hr": t_hr,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db()
|
||||||
|
class TestMultiTagCorrectness:
|
||||||
|
"""
|
||||||
|
The tags classifier must correctly predict tags on held-out documents whose
|
||||||
|
content clearly belongs to one or two vocabulary clusters.
|
||||||
|
|
||||||
|
A prediction is "correct" if the expected tag is present in the result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_single_cluster_docs_predicted_correctly(
|
||||||
|
self,
|
||||||
|
classifier,
|
||||||
|
multi_tag_corpus,
|
||||||
|
):
|
||||||
|
"""Each single-cluster held-out doc gets exactly the right tag."""
|
||||||
|
classifier.train()
|
||||||
|
tags = multi_tag_corpus
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
(FINANCE_WORDS + " unique alpha", [tags["t_finance"].pk]),
|
||||||
|
(LEGAL_WORDS + " unique beta", [tags["t_legal"].pk]),
|
||||||
|
(MEDICAL_WORDS + " unique gamma", [tags["t_medical"].pk]),
|
||||||
|
(HR_WORDS + " unique delta", [tags["t_hr"].pk]),
|
||||||
|
]
|
||||||
|
|
||||||
|
for content, expected_pks in cases:
|
||||||
|
predicted = classifier.predict_tags(content)
|
||||||
|
for pk in expected_pks:
|
||||||
|
assert pk in predicted, (
|
||||||
|
f"Expected tag pk={pk} in predictions for content starting "
|
||||||
|
f"'{content[:40]}…', got {predicted}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_cluster_doc_gets_both_tags(self, classifier, multi_tag_corpus):
|
||||||
|
"""A document with finance + legal vocabulary gets both tags."""
|
||||||
|
classifier.train()
|
||||||
|
tags = multi_tag_corpus
|
||||||
|
|
||||||
|
content = FINANCE_WORDS + " " + LEGAL_WORDS + " unique epsilon"
|
||||||
|
predicted = classifier.predict_tags(content)
|
||||||
|
|
||||||
|
assert tags["t_finance"].pk in predicted, f"Expected finance tag in {predicted}"
|
||||||
|
assert tags["t_legal"].pk in predicted, f"Expected legal tag in {predicted}"
|
||||||
|
|
||||||
|
def test_unrelated_content_predicts_no_trained_tags(
|
||||||
|
self,
|
||||||
|
classifier,
|
||||||
|
multi_tag_corpus,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Completely alien content should not confidently fire any learned tag.
|
||||||
|
This is a soft check — we only assert no false positives on a document
|
||||||
|
that shares zero vocabulary with the training corpus.
|
||||||
|
"""
|
||||||
|
classifier.train()
|
||||||
|
|
||||||
|
alien = (
|
||||||
|
"xyzzyx qwerty asdfgh zxcvbn plokij unique zeta "
|
||||||
|
"xyzzyx qwerty asdfgh zxcvbn plokij unique zeta"
|
||||||
|
)
|
||||||
|
predicted = classifier.predict_tags(alien)
|
||||||
|
# Not a hard requirement — just log for human inspection
|
||||||
|
# Both MLP and SVM may or may not produce false positives on OOV content
|
||||||
|
assert isinstance(predicted, list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Single-tag (binary) correctness
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def single_tag_corpus(classifier_settings):
|
||||||
|
"""
|
||||||
|
Corpus with exactly ONE AUTO tag, exercising the LabelBinarizer +
|
||||||
|
binary classification path. Documents either have the tag or don't.
|
||||||
|
"""
|
||||||
|
t_finance = Tag.objects.create(
|
||||||
|
name="finance",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||||
|
)
|
||||||
|
c = Correspondent.objects.create(
|
||||||
|
name="org",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||||
|
)
|
||||||
|
dt = DocumentType.objects.create(
|
||||||
|
name="doc",
|
||||||
|
matching_algorithm=MatchingModel.MATCH_NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
checksum = 0
|
||||||
|
|
||||||
|
def make(title, content, tags):
|
||||||
|
nonlocal checksum
|
||||||
|
checksum += 1
|
||||||
|
return _make_doc(
|
||||||
|
title,
|
||||||
|
content,
|
||||||
|
f"s{checksum:04d}",
|
||||||
|
tags=tags,
|
||||||
|
correspondent=c,
|
||||||
|
document_type=dt,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
make(f"finance-{i}", _words(FINANCE_WORDS, f"s{i}"), [t_finance])
|
||||||
|
make(f"other-{i}", _words(LEGAL_WORDS, f"s{i}"), [])
|
||||||
|
|
||||||
|
return {"t_finance": t_finance}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db()
|
||||||
|
class TestSingleTagCorrectness:
|
||||||
|
def test_finance_content_predicts_finance_tag(self, classifier, single_tag_corpus):
|
||||||
|
"""Finance vocabulary → finance tag predicted."""
|
||||||
|
classifier.train()
|
||||||
|
tags = single_tag_corpus
|
||||||
|
|
||||||
|
predicted = classifier.predict_tags(FINANCE_WORDS + " unique alpha single")
|
||||||
|
assert tags["t_finance"].pk in predicted, (
|
||||||
|
f"Expected finance tag pk={tags['t_finance'].pk} in {predicted}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_non_finance_content_predicts_no_tag(self, classifier, single_tag_corpus):
|
||||||
|
"""Non-finance vocabulary → no tag predicted."""
|
||||||
|
classifier.train()
|
||||||
|
|
||||||
|
predicted = classifier.predict_tags(LEGAL_WORDS + " unique beta single")
|
||||||
|
assert predicted == [], f"Expected no tags, got {predicted}"
|
||||||
325
src/documents/tests/test_classifier_train_skip.py
Normal file
325
src/documents/tests/test_classifier_train_skip.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
"""
|
||||||
|
Phase 1 — fast-skip optimisation in DocumentClassifier.train()
|
||||||
|
|
||||||
|
The goal: when nothing has changed since the last training run, train() should
|
||||||
|
return False after at most 5 DB queries (1x MAX(modified) + 4x MATCH_AUTO pk
|
||||||
|
lists), not after a full per-document label scan.
|
||||||
|
|
||||||
|
Correctness invariant: the skip must NOT fire when the set of AUTO-matching
|
||||||
|
labels has changed, even if no Document.modified timestamp has advanced (e.g.
|
||||||
|
a Tag's matching_algorithm was flipped to MATCH_AUTO after the last train).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest
from django.db import connection
from django.test.utils import CaptureQueriesContext

from documents.classifier import DocumentClassifier
from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import MatchingModel
from documents.models import StoragePath
from documents.models import Tag

if TYPE_CHECKING:
    from pathlib import Path
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def classifier_settings(settings, tmp_path: Path):
    """Redirect MODEL_FILE into a per-test temp dir so each test is hermetic."""
    settings.MODEL_FILE = tmp_path / "model.pickle"
    return settings
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def classifier(classifier_settings):
    """Return a brand-new, untrained DocumentClassifier using test settings."""
    return DocumentClassifier()
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def label_corpus(classifier_settings):
    """
    Build the smallest corpus that trains cleanly.

    One MATCH_AUTO and one MATCH_NONE instance of each label type
    (Correspondent, DocumentType, Tag, StoragePath), two documents carrying
    the AUTO labels, and one control document carrying only MATCH_NONE labels.

    Returns a dict of every created object so individual tests can mutate
    them directly.
    """
    # Hoist the two algorithm constants used throughout.
    auto = MatchingModel.MATCH_AUTO
    manual = MatchingModel.MATCH_NONE

    c_auto = Correspondent.objects.create(name="Auto Corp", matching_algorithm=auto)
    c_none = Correspondent.objects.create(name="Manual Corp", matching_algorithm=manual)

    dt_auto = DocumentType.objects.create(name="Invoice", matching_algorithm=auto)
    dt_none = DocumentType.objects.create(name="Other", matching_algorithm=manual)

    t_auto = Tag.objects.create(name="finance", matching_algorithm=auto)
    t_none = Tag.objects.create(name="misc", matching_algorithm=manual)

    sp_auto = StoragePath.objects.create(
        name="Finance Path",
        path="finance/{correspondent}",
        matching_algorithm=auto,
    )
    sp_none = StoragePath.objects.create(
        name="Other Path",
        path="other/{correspondent}",
        matching_algorithm=manual,
    )

    doc_a = Document.objects.create(
        title="Invoice from Auto Corp Jan",
        content="quarterly invoice payment tax financial statement revenue",
        correspondent=c_auto,
        document_type=dt_auto,
        storage_path=sp_auto,
        checksum="aaa",
        mime_type="application/pdf",
        filename="invoice_a.pdf",
    )
    doc_a.tags.set([t_auto])

    doc_b = Document.objects.create(
        title="Invoice from Auto Corp Feb",
        content="monthly invoice billing statement account balance due",
        correspondent=c_auto,
        document_type=dt_auto,
        storage_path=sp_auto,
        checksum="bbb",
        mime_type="application/pdf",
        filename="invoice_b.pdf",
    )
    doc_b.tags.set([t_auto])

    # Control document: only MATCH_NONE labels, but enough content to vectorize.
    doc_c = Document.objects.create(
        title="Miscellaneous Notes",
        content="meeting notes agenda discussion summary action items follow",
        correspondent=c_none,
        document_type=dt_none,
        checksum="ccc",
        mime_type="application/pdf",
        filename="notes_c.pdf",
    )
    doc_c.tags.set([t_none])

    return {
        "c_auto": c_auto,
        "c_none": c_none,
        "dt_auto": dt_auto,
        "dt_none": dt_none,
        "t_auto": t_auto,
        "t_none": t_none,
        "sp_auto": sp_auto,
        "sp_none": sp_none,
        "doc_a": doc_a,
        "doc_b": doc_b,
        "doc_c": doc_c,
    }
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Happy-path skip tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db()
class TestFastSkipFires:
    """No-op path: with nothing changed, the second train() call is skipped."""

    def test_first_train_returns_true(self, classifier, label_corpus):
        """A never-trained classifier must do real work and report True."""
        assert classifier.train() is True

    def test_second_train_returns_false(self, classifier, label_corpus):
        """An immediate re-train on an unchanged corpus must be skipped."""
        classifier.train()
        assert classifier.train() is False

    def test_fast_skip_runs_minimal_queries(self, classifier, label_corpus):
        """
        The no-op path may issue at most 5 DB queries: one
        Document.objects.aggregate(Max('modified')) plus one MATCH_AUTO pk
        list per label model (Correspondent / DocumentType / Tag /
        StoragePath). A per-document label scan would be O(N) and blow well
        past this budget, so this test pins the fast path in place.
        """
        classifier.train()
        with CaptureQueriesContext(connection) as captured:
            outcome = classifier.train()
        assert outcome is False
        assert len(captured.captured_queries) <= 5, (
            f"Fast skip used {len(captured.captured_queries)} queries; expected ≤5.\n"
            + "\n".join(q["sql"] for q in captured.captured_queries)
        )

    def test_fast_skip_refreshes_cache_keys(self, classifier, label_corpus):
        """
        Even when training is skipped, the cache keys must be re-populated so
        the task scheduler can see the classifier is still current.
        """
        from django.core.cache import cache

        from documents.caching import CLASSIFIER_HASH_KEY
        from documents.caching import CLASSIFIER_MODIFIED_KEY
        from documents.caching import CLASSIFIER_VERSION_KEY

        keys = (CLASSIFIER_MODIFIED_KEY, CLASSIFIER_HASH_KEY, CLASSIFIER_VERSION_KEY)

        classifier.train()
        # Evict the keys to prove the skip path re-populates them.
        for key in keys:
            cache.delete(key)

        assert classifier.train() is False
        for key in keys:
            assert cache.get(key) is not None
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Correctness tests — skip must NOT fire when the world has changed
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db()
class TestFastSkipDoesNotFire:
    """Any change to documents or to the AUTO label set must force a retrain."""

    def test_document_content_modification_triggers_retrain(self, classifier, label_corpus):
        """Editing content advances Document.modified, so a retrain must run."""
        classifier.train()
        doc = label_corpus["doc_a"]
        doc.content = "completely different words here now nothing same"
        doc.save()
        assert classifier.train() is True

    def test_document_label_reassignment_triggers_retrain(self, classifier, label_corpus):
        """
        Reassigning a document to a different AUTO correspondent (which
        touches doc.modified) must trigger a retrain.
        """
        other = Correspondent.objects.create(
            name="Second Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        classifier.train()
        doc = label_corpus["doc_a"]
        doc.correspondent = other
        doc.save()
        assert classifier.train() is True

    def test_matching_algorithm_change_on_assigned_tag_triggers_retrain(self, classifier, label_corpus):
        """
        Flipping an already-assigned tag to MATCH_AUTO grows the learnable
        label set WITHOUT advancing any Document.modified timestamp — the
        key correctness case for the auto-label-set digest. The tag is
        already on a document via the fixture, so once it becomes MATCH_AUTO
        the classifier needs to learn it.
        """
        classifier.train()
        tag = label_corpus["t_none"]
        tag.matching_algorithm = MatchingModel.MATCH_AUTO
        tag.save(update_fields=["matching_algorithm"])
        # No document is touched here on purpose; the digest must catch it.
        assert classifier.train() is True

    def test_new_auto_correspondent_triggers_retrain(self, classifier, label_corpus):
        """
        A brand-new MATCH_AUTO correspondent (assigned to no document) still
        grows the auto-label-set and must trigger a retrain.
        """
        classifier.train()
        Correspondent.objects.create(
            name="New Auto Corp",
            matching_algorithm=MatchingModel.MATCH_AUTO,
        )
        assert classifier.train() is True

    def test_removing_auto_label_triggers_retrain(self, classifier, label_corpus):
        """Deleting a MATCH_AUTO correspondent shrinks the set → retrain."""
        classifier.train()
        label_corpus["c_auto"].delete()
        assert classifier.train() is True

    def test_fresh_classifier_always_trains(self, classifier, label_corpus):
        """
        A classifier that was never trained (last_doc_change_time is None)
        must always perform a full train, whatever the corpus state.
        """
        assert classifier.last_doc_change_time is None
        assert classifier.train() is True

    def test_no_documents_raises_value_error(self, classifier, classifier_settings):
        """Training against an empty database must raise ValueError."""
        with pytest.raises(ValueError, match="No training data"):
            classifier.train()
Reference in New Issue
Block a user