Converts all these tests to fully use fixtures, factories and composition, and drops DB setup where possible

This commit is contained in:
stumpylog
2026-05-11 09:26:26 -07:00
parent 6cd5784bd7
commit 7a93909597
8 changed files with 594 additions and 698 deletions
+3 -1
View File
@@ -95,7 +95,9 @@ def build_llm_index_text(doc: Document) -> str:
]
for instance in doc.custom_fields.all():
lines.append(f"Custom Field - {instance.field.name}: {instance}")
lines.append(
f"Custom Field - {instance.field.name}: {instance.value_for_search}",
)
lines.append("\nContent:\n")
lines.append(doc.content or "")
+35 -1
View File
@@ -3,8 +3,42 @@ from pathlib import Path
import pytest
from pytest_django.fixtures import SettingsWrapper
from documents.models import Document
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import DocumentTypeFactory
@pytest.fixture
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
settings.LLM_INDEX_DIR = tmp_path
return tmp_path
@pytest.fixture
def document() -> Document:
    """Return a ``Document`` produced by ``DocumentFactory.build``.

    The related correspondent and document type are likewise created with
    their factories' ``build`` method.
    NOTE(review): presumably nothing here is saved to the database (that is
    what a factory ``build`` strategy normally means) -- confirm against the
    factory definitions.
    """
    # Build the related objects first, in the same order the arguments were
    # originally evaluated, then hand them to the document factory.
    correspondent = CorrespondentFactory.build(name="Test Correspondent")
    document_type = DocumentTypeFactory.build(name="Invoice")
    return DocumentFactory.build(
        title="Test Title",
        filename="test_file.pdf",
        correspondent=correspondent,
        document_type=document_type,
        archive_serial_number=12345,
        content="This is the document content.",
    )
@pytest.fixture
def similar_documents() -> list[Document]:
    """Return three factory-built documents with decreasing completeness.

    The first has a title, content and filename; the second has an empty
    title; the third has an empty title, empty content and no filename.
    """
    # One kwargs mapping per document, kept in the original order.
    field_sets = [
        {
            "title": "Title 1",
            "content": "Content of document 1",
            "filename": "file1.txt",
        },
        {
            "title": "",
            "content": "Content of document 2",
            "filename": "file2.txt",
        },
        {"title": "", "content": "", "filename": None},
    ]
    return [DocumentFactory.build(**fields) for fields in field_sets]
+72 -135
View File
@@ -1,9 +1,6 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.test import override_settings
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture
from documents.models import Document
from paperless_ai.ai_classifier import build_prompt_with_rag
@@ -12,69 +9,16 @@ from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)
doc.title = "Test Title"
doc.filename = "test_file.pdf"
doc.created = "2023-01-01"
doc.added = "2023-01-02"
doc.modified = "2023-01-03"
tag1 = MagicMock()
tag1.name = "Tag1"
tag2 = MagicMock()
tag2.name = "Tag2"
doc.tags.all = MagicMock(return_value=[tag1, tag2])
doc.document_type = MagicMock()
doc.document_type.name = "Invoice"
doc.correspondent = MagicMock()
doc.correspondent.name = "Test Correspondent"
doc.archive_serial_number = "12345"
doc.content = "This is the document content."
cf1 = MagicMock(__str__=lambda x: "Value1")
cf1.field = MagicMock()
cf1.field.name = "Field1"
cf1.value = "Value1"
cf2 = MagicMock(__str__=lambda x: "Value2")
cf2.field = MagicMock()
cf2.field.name = "Field2"
cf2.value = "Value2"
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
return doc
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value = {
def test_get_ai_document_classification_success(
settings: SettingsWrapper,
mocker: MockerFixture,
document: Document,
) -> None:
settings.LLM_BACKEND = "ollama"
settings.LLM_MODEL = "some_model"
mock_run = mocker.patch("paperless_ai.client.AIClient.run_llm_query")
mock_run.return_value = {
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
@@ -83,7 +27,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
"dates": ["2023-01-01"],
}
result = get_ai_document_classification(mock_document)
result = get_ai_document_classification(document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
@@ -94,93 +38,86 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
mock_run_llm_query.side_effect = Exception("LLM query failed")
# assert raises an exception
def test_get_ai_document_classification_failure(
mocker: MockerFixture,
document: Document,
) -> None:
mocker.patch(
"paperless_ai.client.AIClient.run_llm_query",
side_effect=Exception("LLM query failed"),
)
with pytest.raises(Exception):
get_ai_document_classification(mock_document)
get_ai_document_classification(document)
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_EMBEDDING_MODEL="some_model",
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
mock_build_prompt_with_rag,
mock_run_llm_query,
mock_document,
):
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
mock_run_llm_query.return_value.text = json.dumps({})
get_ai_document_classification(mock_document)
mock_build_prompt_with_rag.assert_called_once()
settings: SettingsWrapper,
mocker: MockerFixture,
document: Document,
) -> None:
settings.LLM_EMBEDDING_BACKEND = "huggingface"
settings.LLM_EMBEDDING_MODEL = "some_model"
settings.LLM_BACKEND = "ollama"
settings.LLM_MODEL = "some_model"
mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_with_rag")
mock_build.return_value = "Prompt with RAG"
mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={})
get_ai_document_classification(document)
mock_build.assert_called_once()
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
mock_ai_config,
mock_build_prompt_without_rag,
mock_run_llm_query,
mock_document,
):
settings: SettingsWrapper,
mocker: MockerFixture,
document: Document,
) -> None:
settings.LLM_BACKEND = "ollama"
settings.LLM_MODEL = "some_model"
mock_ai_config = mocker.patch("paperless.config.AIConfig")
mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_without_rag")
mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={})
mock_ai_config.llm_embedding_backend = None
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
mock_run_llm_query.return_value.text = json.dumps({})
get_ai_document_classification(mock_document)
mock_build_prompt_without_rag.assert_called_once()
mock_build.return_value = "Prompt without RAG"
get_ai_document_classification(document)
mock_build.assert_called_once()
@pytest.mark.django_db
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
with patch(
def test_prompt_with_without_rag(mocker: MockerFixture, document: Document) -> None:
mocker.patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
assert "Additional context from similar documents:" not in prompt
)
prompt = build_prompt_without_rag(document)
assert "Additional context from similar documents:" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "Additional context from similar documents:" in prompt
prompt = build_prompt_with_rag(document)
assert "Additional context from similar documents:" in prompt
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,
mock_document,
mock_similar_documents,
):
mock_query_similar_documents.return_value = mock_similar_documents
result = get_context_for_document(mock_document, max_docs=2)
expected_result = (
mocker: MockerFixture,
document: Document,
similar_documents: list[Document],
) -> None:
mocker.patch(
"paperless_ai.ai_classifier.query_similar_documents",
return_value=similar_documents,
)
result = get_context_for_document(document, max_docs=2)
assert result == (
"TITLE: Title 1\nContent of document 1\n\n"
"TITLE: file2.txt\nContent of document 2"
)
assert result == expected_result
mock_query_similar_documents.assert_called_once()
def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""
def test_get_context_for_document_no_similar_docs(
    mocker: MockerFixture,
    document: Document,
) -> None:
    """With no similar documents returned, the context is an empty string."""
    # Make the similarity lookup used by the classifier module return nothing.
    query_mock = mocker.patch("paperless_ai.ai_classifier.query_similar_documents")
    query_mock.return_value = []

    context = get_context_for_document(document)

    assert context == ""
+179 -227
View File
@@ -1,11 +1,12 @@
import json
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.test import override_settings
from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from pytest_django.fixtures import SettingsWrapper
from pytest_mock import MockerFixture
from documents.models import Document
from documents.models import PaperlessTask
@@ -14,7 +15,7 @@ from paperless_ai import indexing
@pytest.fixture
def real_document(db):
def real_document(db) -> Document:
return Document.objects.create(
title="Test Document",
content="This is some test content.",
@@ -23,36 +24,15 @@ def real_document(db):
@pytest.fixture
def mock_embed_model():
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384 # Match your real FAISS config
def mock_embed_model(mocker: MockerFixture) -> MagicMock:
    """Patch both ``get_embedding_model`` lookups to return one ``MockEmbedding``.

    The same 384-dimension ``MockEmbedding`` instance is returned from the
    patched ``paperless_ai.indexing.get_embedding_model`` and
    ``paperless_ai.embedding.get_embedding_model``.

    Returns:
        The mock that replaced ``paperless_ai.indexing.get_embedding_model``.
    """
    embedding = MockEmbedding(embed_dim=384)

    indexing_patch = mocker.patch("paperless_ai.indexing.get_embedding_model")
    indexing_patch.return_value = embedding

    embedding_patch = mocker.patch("paperless_ai.embedding.get_embedding_model")
    embedding_patch.return_value = embedding

    return indexing_patch
@pytest.mark.django_db
def test_build_document_node(real_document) -> None:
def test_build_document_node(real_document: Document) -> None:
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == str(real_document.id)
@@ -60,37 +40,38 @@ def test_build_document_node(real_document) -> None:
@pytest.mark.django_db
def test_update_llm_index(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
assert any(temp_llm_index_dir.glob("*.json"))
indexing.update_llm_index(rebuild=True)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
indexing.update_llm_index(rebuild=True)
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig
@@ -106,9 +87,10 @@ def test_update_llm_index_removes_meta(
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
doc2 = Document.objects.create(
title="Test Document 2",
@@ -116,20 +98,16 @@ def test_update_llm_index_partial_update(
added=timezone.now(),
checksum="1234567890abcdef",
)
# Initial index
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document, doc2])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document, doc2])
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
indexing.update_llm_index(rebuild=True)
# modify document
updated_document = real_document
updated_document.modified = timezone.now() # simulate modification
updated_document.modified = timezone.now()
# new doc
doc3 = Document.objects.create(
title="Test Document 3",
content="This is some test content 3.",
@@ -137,110 +115,101 @@ def test_update_llm_index_partial_update(
checksum="abcdef1234567890",
)
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
mock_all.return_value = mock_queryset
mock_queryset2 = MagicMock()
mock_queryset2.exists.return_value = True
mock_queryset2.__iter__.return_value = iter([updated_document, doc2, doc3])
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset2)
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.",
2,
)
indexing.update_llm_index(rebuild=False)
mock_logger = mocker.patch("paperless_ai.indexing.logger")
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with("Updating %d nodes in LLM index.", 2)
indexing.update_llm_index(rebuild=False)
assert any(temp_llm_index_dir.glob("*.json"))
def test_get_or_create_storage_context_raises_exception(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
) -> None:
with pytest.raises(Exception):
indexing.get_or_create_storage_context(rebuild=False)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
)
@pytest.mark.django_db
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[indexing.build_document_node(real_document)],
)
mock_index_cls.assert_called_once()
mocker.patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
)
mock_index_cls = mocker.patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
)
mock_storage = mocker.patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
)
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(nodes=[indexing.build_document_node(real_document)])
mock_index_cls.assert_called_once()
def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
),
):
with pytest.raises(Exception):
indexing.load_or_build_index()
mocker.patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
)
mocker.patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
)
with pytest.raises(Exception):
indexing.load_or_build_index()
@pytest.mark.django_db
def test_load_or_build_index_succeeds_when_nodes_given(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[MagicMock()],
)
mock_index_cls.assert_called_once()
mocker.patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
)
mock_index_cls = mocker.patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
)
mock_storage = mocker.patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
)
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(nodes=[MagicMock()])
mock_index_cls.assert_called_once()
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
) -> None:
indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
@@ -250,9 +219,9 @@ def test_add_or_update_document_updates_existing_entry(
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: MagicMock,
) -> None:
indexing.update_llm_index(rebuild=True)
index = indexing.load_or_build_index()
@@ -265,31 +234,29 @@ def test_remove_document_deletes_node_from_docstore(
@pytest.mark.django_db
def test_update_llm_index_no_documents(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mocker: MockerFixture,
) -> None:
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = False
mock_queryset.__iter__.return_value = iter([])
mock_all.return_value = mock_queryset
mock_queryset = MagicMock()
mock_queryset.exists.return_value = False
mock_queryset.__iter__.return_value = iter([])
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
# check log message
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with(
"No documents found to index.",
)
mock_logger = mocker.patch("paperless_ai.indexing.logger")
indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with("No documents found to index.")
@pytest.mark.django_db
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
# No existing tasks
with patch("documents.tasks.llmindex_index") as mock_task:
result = indexing.queue_llm_index_update_if_needed(
rebuild=True,
reason="test enqueue",
)
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent(
mocker: MockerFixture,
) -> None:
mock_task = mocker.patch("documents.tasks.llmindex_index")
result = indexing.queue_llm_index_update_if_needed(
rebuild=True,
reason="test enqueue",
)
assert result is True
mock_task.apply_async.assert_called_once_with(
@@ -303,86 +270,71 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
status=PaperlessTask.Status.STARTED,
)
# Existing running task
with patch("documents.tasks.llmindex_index") as mock_task:
result = indexing.queue_llm_index_update_if_needed(
rebuild=False,
reason="should skip",
)
mock_task2 = mocker.patch("documents.tasks.llmindex_index")
result = indexing.queue_llm_index_update_if_needed(
rebuild=False,
reason="should skip",
)
assert result is False
mock_task.apply_async.assert_not_called()
mock_task2.apply_async.assert_not_called()
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
)
@pytest.mark.django_db
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
mocker: MockerFixture,
settings: SettingsWrapper,
) -> None:
with (
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch(
"paperless_ai.indexing.vector_store_file_exists",
) as mock_vector_store_exists,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
mock_vector_store_exists.return_value = True
settings.LLM_EMBEDDING_BACKEND = "huggingface"
settings.LLM_BACKEND = "ollama"
mock_index = MagicMock()
mock_load_or_build_index.return_value = mock_index
mock_storage = mocker.patch("paperless_ai.indexing.get_or_create_storage_context")
mock_storage.return_value.persist_dir = temp_llm_index_dir
mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=True)
mock_retriever = MagicMock()
mock_retriever_cls.return_value = mock_retriever
mock_index = MagicMock()
mocker.patch("paperless_ai.indexing.load_or_build_index", return_value=mock_index)
mock_node1 = MagicMock()
mock_node1.metadata = {"document_id": 1}
mock_retriever = MagicMock()
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever,
)
mock_node2 = MagicMock()
mock_node2.metadata = {"document_id": 2}
mock_node1 = MagicMock()
mock_node1.metadata = {"document_id": 1}
mock_node2 = MagicMock()
mock_node2.metadata = {"document_id": 2}
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
mock_filter = mocker.patch(
"paperless_ai.indexing.Document.objects.filter",
return_value=mock_filtered_docs,
)
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
mock_filter.return_value = mock_filtered_docs
result = indexing.query_similar_documents(real_document, top_k=3)
result = indexing.query_similar_documents(real_document, top_k=3)
mock_load_or_build_index.assert_called_once()
mock_retriever_cls.assert_called_once()
mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.",
)
mock_filter.assert_called_once_with(pk__in=[1, 2])
assert result == mock_filtered_docs
mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.",
)
mock_filter.assert_called_once_with(pk__in=[1, 2])
assert result == mock_filtered_docs
@pytest.mark.django_db
def test_query_similar_documents_triggers_update_when_index_missing(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
mocker: MockerFixture,
) -> None:
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
return_value=False,
),
patch(
"paperless_ai.indexing.queue_llm_index_update_if_needed",
) as mock_queue,
patch("paperless_ai.indexing.load_or_build_index") as mock_load,
):
result = indexing.query_similar_documents(
real_document,
top_k=2,
)
mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=False)
mock_queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
result = indexing.query_similar_documents(real_document, top_k=2)
mock_queue.assert_called_once_with(
rebuild=False,
+96 -127
View File
@@ -1,10 +1,10 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from pytest_mock import MockerFixture
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import stream_chat_with_documents
@@ -15,31 +15,17 @@ def patch_embed_model():
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
yield
llama_settings.Settings.embed_model = None
@pytest.fixture(autouse=True)
def patch_embed_nodes():
with patch(
"llama_index.core.indices.vector_store.base.embed_nodes",
) as mock_embed_nodes:
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
node.node_id: [0.1] * 1536 for node in nodes
}
yield
@pytest.fixture
def mock_document():
doc = MagicMock()
doc.pk = 1
doc.title = "Test Document"
doc.filename = "test_file.pdf"
doc.content = "This is the document content."
return doc
def patch_embed_nodes(mocker: MockerFixture):
    """Replace llama-index's ``embed_nodes`` with a constant-vector stub.

    Each node passed to the stub is mapped, by its ``node_id``, to a list of
    1536 floats all equal to ``0.1``; any further arguments are ignored.
    """

    def _constant_embeddings(nodes, *_args, **_kwargs):
        # Same fixed vector for every node, keyed by node id.
        return {node.node_id: [0.1] * 1536 for node in nodes}

    mocker.patch(
        "llama_index.core.indices.vector_store.base.embed_nodes",
        side_effect=_constant_embeddings,
    )
def assert_chat_output(
@@ -57,127 +43,110 @@ def assert_chat_output(
}
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
def test_stream_chat_with_one_document_full_content(mocker: MockerFixture) -> None:
mock_document = MagicMock()
mock_document.pk = 1
mock_document.title = "Test Document"
mock_document.filename = "test_file.pdf"
mock_document.content = "This is the document content."
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_client = MagicMock()
mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)
output = list(stream_chat_with_documents("What is this?", [mock_document]))
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine.query.return_value = mock_response_stream
mocker.patch(
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
return_value=mock_query_engine,
)
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[
{"id": mock_document.pk, "title": "Test Document"},
],
)
output = list(stream_chat_with_documents("What is this?", [mock_document]))
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[{"id": mock_document.pk, "title": "Test Document"}],
)
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
def test_stream_chat_with_multiple_documents_retrieval(
patch_embed_nodes,
mocker: MockerFixture,
) -> None:
mock_client = MagicMock()
mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
)
mock_node2 = TextNode(
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
)
mock_node2 = TextNode(
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
mock_retriever = MagicMock()
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_retriever.retrieve.return_value = [
mock_node1,
mock_duplicate_node,
mock_node2,
mock_foreign_node,
]
mock_as_retriever.return_value = mock_retriever
mock_retriever = MagicMock()
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_retriever.retrieve.return_value = [
mock_node1,
mock_duplicate_node,
mock_node2,
mock_foreign_node,
]
mocker.patch.object(VectorStoreIndex, "as_retriever", return_value=mock_retriever)
# Mock response stream
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine.query.return_value = mock_response_stream
mocker.patch(
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
return_value=mock_query_engine,
)
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
# Fake documents
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[
{"id": 1, "title": "Document 1"},
{"id": 2, "title": "Document 2"},
],
)
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[
{"id": 1, "title": "Document 1"},
{"id": 2, "title": "Document 2"},
],
)
def test_stream_chat_no_matching_nodes(mocker: MockerFixture) -> None:
    """Chat falls back to an apology when the index holds no nodes at all.

    Both collaborators of ``paperless_ai.chat`` are replaced: ``AIClient``
    returns a bare ``MagicMock`` and ``load_or_build_index`` returns an index
    whose docstore is empty.
    """
    mock_client = MagicMock()
    mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)

    # An empty docstore means there is nothing to match the requested document.
    mock_index = MagicMock()
    mock_index.docstore.docs.values.return_value = []
    mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)

    output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))

    assert output == ["Sorry, I couldn't find any content to answer your question."]
+37 -40
View File
@@ -1,38 +1,35 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection
from pytest_mock import MockerFixture
from paperless.config import AIConfig
from paperless_ai.client import AIClient
@pytest.fixture
def mock_ai_config(mocker: MockerFixture) -> MagicMock:
    """Patch ``paperless_ai.client.AIConfig`` and return the patched class mock.

    Tests configure the config *instance* through ``mock.return_value``.
    Internal endpoints are allowed by default; a test that needs the
    restriction sets ``llm_allow_internal_endpoints`` back to ``False``.
    """
    mock = mocker.patch("paperless_ai.client.AIConfig", spec=AIConfig)
    mock.return_value.llm_allow_internal_endpoints = True
    return mock
@pytest.fixture
def mock_ollama_llm(mocker: MockerFixture) -> MagicMock:
    """Replace ``llama_index.llms.ollama.Ollama`` and return the class mock."""
    return mocker.patch("llama_index.llms.ollama.Ollama")
@pytest.fixture
def mock_openai_llm(mocker: MockerFixture) -> MagicMock:
    """Replace ``llama_index.llms.openai_like.OpenAILike`` and return the class mock."""
    return mocker.patch("llama_index.llms.openai_like.OpenAILike")
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_endpoint = "http://test-url"
def test_get_llm_ollama(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
mock_ai_config.return_value.llm_backend = "ollama"
mock_ai_config.return_value.llm_model = "test_model"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
client = AIClient()
@@ -44,11 +41,11 @@ def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
assert client.llm == mock_ollama_llm.return_value
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
mock_ai_config.llm_backend = "openai-like"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_api_key = "test_api_key"
mock_ai_config.llm_endpoint = "http://test-url"
def test_get_llm_openai(mock_ai_config: MagicMock, mock_openai_llm: MagicMock) -> None:
mock_ai_config.return_value.llm_backend = "openai-like"
mock_ai_config.return_value.llm_model = "test_model"
mock_ai_config.return_value.llm_api_key = "test_api_key"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
client = AIClient()
@@ -62,31 +59,32 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
assert client.llm == mock_openai_llm.return_value
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(
    mock_ai_config: MagicMock,
) -> None:
    """``AIClient()`` rejects a loopback endpoint once internal endpoints are disallowed."""
    config = mock_ai_config.return_value
    config.llm_backend = "openai-like"
    config.llm_model = "test_model"
    config.llm_api_key = "test_api_key"
    config.llm_endpoint = "http://127.0.0.1:1234"
    # Override the fixture default (True) so the endpoint check is enforced.
    config.llm_allow_internal_endpoints = False

    with pytest.raises(ValueError, match="non-public address"):
        AIClient()
def test_get_llm_unsupported_backend(mock_ai_config: MagicMock) -> None:
    """An unknown ``llm_backend`` value makes ``AIClient()`` raise ``ValueError``."""
    mock_ai_config.return_value.llm_backend = "unsupported"

    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        AIClient()
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_endpoint = "http://test-url"
def test_run_llm_query(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
mock_ai_config.return_value.llm_backend = "ollama"
mock_ai_config.return_value.llm_model = "test_model"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
tool_selection = ToolSelection(
tool_id="call_test",
tool_name="DocumentClassifierSchema",
@@ -99,7 +97,6 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
"dates": ["2023-01-01"],
},
)
mock_llm_instance.chat_with_tools.return_value = MagicMock()
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
@@ -109,10 +106,10 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
assert result["title"] == "Test Title"
def test_run_chat(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_endpoint = "http://test-url"
def test_run_chat(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
mock_ai_config.return_value.llm_backend = "ollama"
mock_ai_config.return_value.llm_model = "test_model"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.chat.return_value = "test_chat_result"
+99 -97
View File
@@ -1,10 +1,20 @@
import datetime
import json
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from pytest_mock import MockerFixture
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import Note
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import DocumentTypeFactory
from documents.tests.factories import TagFactory
from paperless.config import AIConfig
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
@@ -12,97 +22,89 @@ from paperless_ai.embedding import get_embedding_model
@pytest.fixture
def mock_ai_config(mocker: MockerFixture) -> MagicMock:
    """Patch ``paperless_ai.embedding.AIConfig`` and return the patched class mock.

    The config instance seen by the code under test is ``mock.return_value``;
    it allows internal endpoints unless a test turns that off.
    """
    mock = mocker.patch("paperless_ai.embedding.AIConfig", spec=AIConfig)
    mock.return_value.llm_allow_internal_endpoints = True
    return mock
@pytest.fixture
def full_document(db) -> Document:
    """Return a saved ``Document`` with every relation the index text renders.

    The document is persisted (hence the ``db`` fixture) and carries two tags,
    a correspondent, a document type, two string custom fields and two notes.
    """
    tag1 = TagFactory(name="Tag1")
    tag2 = TagFactory(name="Tag2")
    doc = DocumentFactory(
        title="Test Title",
        filename="test_file.pdf",
        created=datetime.date(2023, 1, 1),
        correspondent=CorrespondentFactory(name="Test Correspondent"),
        document_type=DocumentTypeFactory(name="Invoice"),
        archive_serial_number=12345,
        content="This is the document content.",
    )
    doc.tags.add(tag1, tag2)

    # Custom fields are created in Field1, Field2 order; the consuming test
    # asserts on them in that same order.
    cf1 = CustomField.objects.create(
        name="Field1",
        data_type=CustomField.FieldDataType.STRING,
    )
    cf2 = CustomField.objects.create(
        name="Field2",
        data_type=CustomField.FieldDataType.STRING,
    )
    CustomFieldInstance.objects.create(document=doc, field=cf1, value_text="Value1")
    CustomFieldInstance.objects.create(document=doc, field=cf2, value_text="Value2")

    Note.objects.create(document=doc, note="Note1")
    Note.objects.create(document=doc, note="Note2")
    return doc
def test_get_embedding_model_openai(
    mock_ai_config: MagicMock,
    mocker: MockerFixture,
) -> None:
    """The OpenAI-like backend builds ``OpenAILikeEmbedding`` from the config values."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"
    config.llm_endpoint = "http://test-url"

    mock_cls = mocker.patch("llama_index.embeddings.openai_like.OpenAILikeEmbedding")
    model = get_embedding_model()

    mock_cls.assert_called_once_with(
        model_name="text-embedding-3-small",
        api_key="test_api_key",
        api_base="http://test-url",
    )
    assert model == mock_cls.return_value
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
    mock_ai_config: MagicMock,
) -> None:
    """``get_embedding_model()`` rejects a loopback endpoint when internal endpoints are off."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"
    config.llm_endpoint = "http://127.0.0.1:11434"
    # Override the fixture default (True) so the endpoint check is enforced.
    config.llm_allow_internal_endpoints = False

    with pytest.raises(ValueError, match="non-public address"):
        get_embedding_model()
def test_get_embedding_model_huggingface(
    mock_ai_config: MagicMock,
    mocker: MockerFixture,
) -> None:
    """The HuggingFace backend builds ``HuggingFaceEmbedding`` with only the model name."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
    config.llm_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"

    mock_cls = mocker.patch("llama_index.embeddings.huggingface.HuggingFaceEmbedding")
    model = get_embedding_model()

    mock_cls.assert_called_once_with(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )
    assert model == mock_cls.return_value
def test_get_embedding_model_invalid_backend(mock_ai_config):
def test_get_embedding_model_invalid_backend(mock_ai_config: MagicMock) -> None:
mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"
with pytest.raises(
ValueError,
match="Unsupported embedding backend: INVALID_BACKEND",
@@ -110,47 +112,53 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
get_embedding_model()
def test_get_embedding_dim_infers_and_saves(
    temp_llm_index_dir: Path,
    mock_ai_config: MagicMock,
    mocker: MockerFixture,
) -> None:
    """With no ``meta.json`` present, the dimension is probed and then persisted.

    The stub embedding returns a 7-element vector, so 7 is both the returned
    dimension and the value written to ``meta.json``.
    """
    mock_ai_config.return_value.llm_embedding_backend = "openai-like"
    # NOTE(review): the expected metadata below names "text-embedding-3-small"
    # although no model is configured -- presumably a backend default; confirm
    # against paperless_ai.embedding.
    mock_ai_config.return_value.llm_embedding_model = None

    class DummyEmbedding:
        def get_text_embedding(self, text: str) -> list[float]:
            return [0.0] * 7

    mock_get = mocker.patch(
        "paperless_ai.embedding.get_embedding_model",
        return_value=DummyEmbedding(),
    )
    dim = get_embedding_dim()
    mock_get.assert_called_once()

    assert dim == 7
    meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
    assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
def test_get_embedding_dim_reads_existing_meta(
    temp_llm_index_dir: Path,
    mock_ai_config: MagicMock,
    mocker: MockerFixture,
) -> None:
    """An existing ``meta.json`` supplies the dimension without loading a model."""
    mock_ai_config.return_value.llm_embedding_backend = "openai-like"
    mock_ai_config.return_value.llm_embedding_model = None

    (temp_llm_index_dir / "meta.json").write_text(
        json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
    )

    mock_get = mocker.patch("paperless_ai.embedding.get_embedding_model")
    assert get_embedding_dim() == 11
    # The cached value must short-circuit the (expensive) model construction.
    mock_get.assert_not_called()
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
def test_get_embedding_dim_raises_on_model_change(
temp_llm_index_dir: Path,
mock_ai_config: MagicMock,
) -> None:
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
mock_ai_config.return_value.llm_embedding_model = None
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 11}),
)
with pytest.raises(
RuntimeError,
match="Embedding model changed from old to text-embedding-3-small",
@@ -158,21 +166,15 @@ def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_co
get_embedding_dim()
@pytest.mark.django_db
def test_build_llm_index_text(full_document: Document) -> None:
    """The index text contains every metadata line, the notes and the content.

    Runs against a saved document from the ``full_document`` fixture, so the
    tags, custom fields and notes come from the database rather than mocks.
    """
    result = build_llm_index_text(full_document)

    assert "Title: Test Title" in result
    assert "Filename: test_file.pdf" in result
    assert "Created: 2023-01-01" in result
    assert "Tags: Tag1, Tag2" in result
    assert "Document Type: Invoice" in result
    assert "Correspondent: Test Correspondent" in result
    assert "Notes: Note1,Note2" in result
    assert "Content:\n\nThis is the document content." in result
    # Both custom fields must appear on consecutive lines, in creation order.
    assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
+73 -70
View File
@@ -1,86 +1,89 @@
from unittest.mock import patch
from pytest_mock import MockerFixture
from django.test import TestCase
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentTypeFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
# Dotted path of the object lookup used by paperless_ai.matching; tests that
# call mocker.patch(_PATCH_TARGET, ...) swap it for a fixed in-memory list.
_PATCH_TARGET = "paperless_ai.matching.get_objects_for_user_owner_aware"
class TestAIMatching:
    """Name-matching tests for ``paperless_ai.matching``.

    Every candidate object is created with the factory ``build`` strategy and
    handed to the matcher by patching ``_PATCH_TARGET``, so the class needs
    neither ``TestCase`` nor a ``setUp`` that writes rows.
    """

    def test_match_tags_by_name(self, mocker: MockerFixture) -> None:
        """Only the tag whose name is known is returned."""
        tags = [
            TagFactory.build(name="Test Tag 1"),
            TagFactory.build(name="Test Tag 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=tags)

        result = match_tags_by_name(["Test Tag 1", "Nonexistent Tag"], user=None)

        assert len(result) == 1
        assert result[0].name == "Test Tag 1"

    def test_match_correspondents_by_name(self, mocker: MockerFixture) -> None:
        """Only the correspondent whose name is known is returned."""
        correspondents = [
            CorrespondentFactory.build(name="Test Correspondent 1"),
            CorrespondentFactory.build(name="Test Correspondent 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=correspondents)

        result = match_correspondents_by_name(
            ["Test Correspondent 1", "Nonexistent Correspondent"],
            user=None,
        )

        assert len(result) == 1
        assert result[0].name == "Test Correspondent 1"

    def test_match_document_types_by_name(self, mocker: MockerFixture) -> None:
        """Only the document type whose name is known is returned."""
        document_types = [
            DocumentTypeFactory.build(name="Test Document Type 1"),
            DocumentTypeFactory.build(name="Test Document Type 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=document_types)

        result = match_document_types_by_name(
            ["Test Document Type 1", "Nonexistent Document Type"],
            user=None,
        )

        assert len(result) == 1
        assert result[0].name == "Test Document Type 1"

    def test_match_storage_paths_by_name(self, mocker: MockerFixture) -> None:
        """Only the storage path whose name is known is returned."""
        storage_paths = [
            StoragePathFactory.build(name="Test Storage Path 1"),
            StoragePathFactory.build(name="Test Storage Path 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=storage_paths)

        result = match_storage_paths_by_name(
            ["Test Storage Path 1", "Nonexistent Storage Path"],
            user=None,
        )

        assert len(result) == 1
        assert result[0].name == "Test Storage Path 1"

    def test_extract_unmatched_names(self) -> None:
        """Names without a matched object are reported back unchanged."""
        tag = TagFactory.build(name="Test Tag 1")

        unmatched = extract_unmatched_names(["Test Tag 1", "Nonexistent Tag"], [tag])

        assert unmatched == ["Nonexistent Tag"]

    def test_match_tags_by_name_with_empty_names(self, mocker: MockerFixture) -> None:
        """``None``, empty and whitespace-only names never match anything."""
        tags = [
            TagFactory.build(name="Test Tag 1"),
            TagFactory.build(name="Test Tag 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=tags)

        result = match_tags_by_name([None, "", " "], user=None)

        assert result == []

    def test_match_tags_with_fuzzy_matching(self, mocker: MockerFixture) -> None:
        """Slightly misspelled names still resolve to the intended tags, in order."""
        tags = [
            TagFactory.build(name="Test Tag 1"),
            TagFactory.build(name="Test Tag 2"),
        ]
        mocker.patch(_PATCH_TARGET, return_value=tags)

        result = match_tags_by_name(["Test Taag 1", "Teest Tag 2"], user=None)

        assert len(result) == 2
        assert result[0].name == "Test Tag 1"
        assert result[1].name == "Test Tag 2"