From 7a9390959772735d3e4a5d9e314afb799f8be03f Mon Sep 17 00:00:00 2001 From: stumpylog <797416+stumpylog@users.noreply.github.com> Date: Mon, 11 May 2026 09:26:26 -0700 Subject: [PATCH] Converts all these tests to fully use fixtures, factories and composition + dropping DB setup where possible --- src/paperless_ai/embedding.py | 4 +- src/paperless_ai/tests/conftest.py | 36 +- src/paperless_ai/tests/test_ai_classifier.py | 207 ++++------ src/paperless_ai/tests/test_ai_indexing.py | 406 ++++++++----------- src/paperless_ai/tests/test_chat.py | 223 +++++----- src/paperless_ai/tests/test_client.py | 77 ++-- src/paperless_ai/tests/test_embedding.py | 196 ++++----- src/paperless_ai/tests/test_matching.py | 143 +++---- 8 files changed, 594 insertions(+), 698 deletions(-) diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index a96dd2429..1000f9df9 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -95,7 +95,9 @@ def build_llm_index_text(doc: Document) -> str: ] for instance in doc.custom_fields.all(): - lines.append(f"Custom Field - {instance.field.name}: {instance}") + lines.append( + f"Custom Field - {instance.field.name}: {instance.value_for_search}", + ) lines.append("\nContent:\n") lines.append(doc.content or "") diff --git a/src/paperless_ai/tests/conftest.py b/src/paperless_ai/tests/conftest.py index 2d71476c7..2dde7caed 100644 --- a/src/paperless_ai/tests/conftest.py +++ b/src/paperless_ai/tests/conftest.py @@ -3,8 +3,42 @@ from pathlib import Path import pytest from pytest_django.fixtures import SettingsWrapper +from documents.models import Document +from documents.tests.factories import CorrespondentFactory +from documents.tests.factories import DocumentFactory +from documents.tests.factories import DocumentTypeFactory + @pytest.fixture -def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper): +def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path: settings.LLM_INDEX_DIR =
tmp_path return tmp_path + + +@pytest.fixture +def document() -> Document: + return DocumentFactory.build( + title="Test Title", + filename="test_file.pdf", + correspondent=CorrespondentFactory.build(name="Test Correspondent"), + document_type=DocumentTypeFactory.build(name="Invoice"), + archive_serial_number=12345, + content="This is the document content.", + ) + + +@pytest.fixture +def similar_documents() -> list[Document]: + return [ + DocumentFactory.build( + title="Title 1", + content="Content of document 1", + filename="file1.txt", + ), + DocumentFactory.build( + title="", + content="Content of document 2", + filename="file2.txt", + ), + DocumentFactory.build(title="", content="", filename=None), + ] diff --git a/src/paperless_ai/tests/test_ai_classifier.py b/src/paperless_ai/tests/test_ai_classifier.py index 115d51cd4..dde704106 100644 --- a/src/paperless_ai/tests/test_ai_classifier.py +++ b/src/paperless_ai/tests/test_ai_classifier.py @@ -1,9 +1,6 @@ -import json -from unittest.mock import MagicMock -from unittest.mock import patch - import pytest -from django.test import override_settings +from pytest_django.fixtures import SettingsWrapper +from pytest_mock import MockerFixture from documents.models import Document from paperless_ai.ai_classifier import build_prompt_with_rag @@ -12,69 +9,16 @@ from paperless_ai.ai_classifier import get_ai_document_classification from paperless_ai.ai_classifier import get_context_for_document -@pytest.fixture -def mock_document(): - doc = MagicMock(spec=Document) - doc.title = "Test Title" - doc.filename = "test_file.pdf" - doc.created = "2023-01-01" - doc.added = "2023-01-02" - doc.modified = "2023-01-03" - - tag1 = MagicMock() - tag1.name = "Tag1" - tag2 = MagicMock() - tag2.name = "Tag2" - doc.tags.all = MagicMock(return_value=[tag1, tag2]) - - doc.document_type = MagicMock() - doc.document_type.name = "Invoice" - doc.correspondent = MagicMock() - doc.correspondent.name = "Test Correspondent" - doc.archive_serial_number 
= "12345" - doc.content = "This is the document content." - - cf1 = MagicMock(__str__=lambda x: "Value1") - cf1.field = MagicMock() - cf1.field.name = "Field1" - cf1.value = "Value1" - cf2 = MagicMock(__str__=lambda x: "Value2") - cf2.field = MagicMock() - cf2.field.name = "Field2" - cf2.value = "Value2" - doc.custom_fields.all = MagicMock(return_value=[cf1, cf2]) - - return doc - - -@pytest.fixture -def mock_similar_documents(): - doc1 = MagicMock() - doc1.content = "Content of document 1" - doc1.title = "Title 1" - doc1.filename = "file1.txt" - - doc2 = MagicMock() - doc2.content = "Content of document 2" - doc2.title = None - doc2.filename = "file2.txt" - - doc3 = MagicMock() - doc3.content = None - doc3.title = None - doc3.filename = None - - return [doc1, doc2, doc3] - - @pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -@override_settings( - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) -def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): - mock_run_llm_query.return_value = { +def test_get_ai_document_classification_success( + settings: SettingsWrapper, + mocker: MockerFixture, + document: Document, +) -> None: + settings.LLM_BACKEND = "ollama" + settings.LLM_MODEL = "some_model" + mock_run = mocker.patch("paperless_ai.client.AIClient.run_llm_query") + mock_run.return_value = { "title": "Test Title", "tags": ["test", "document"], "correspondents": ["John Doe"], @@ -83,7 +27,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen "dates": ["2023-01-01"], } - result = get_ai_document_classification(mock_document) + result = get_ai_document_classification(document) assert result["title"] == "Test Title" assert result["tags"] == ["test", "document"] @@ -94,93 +38,86 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen @pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -def 
test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): - mock_run_llm_query.side_effect = Exception("LLM query failed") - - # assert raises an exception +def test_get_ai_document_classification_failure( + mocker: MockerFixture, + document: Document, +) -> None: + mocker.patch( + "paperless_ai.client.AIClient.run_llm_query", + side_effect=Exception("LLM query failed"), + ) with pytest.raises(Exception): - get_ai_document_classification(mock_document) + get_ai_document_classification(document) @pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -@patch("paperless_ai.ai_classifier.build_prompt_with_rag") -@override_settings( - LLM_EMBEDDING_BACKEND="huggingface", - LLM_EMBEDDING_MODEL="some_model", - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) def test_use_rag_if_configured( - mock_build_prompt_with_rag, - mock_run_llm_query, - mock_document, -): - mock_build_prompt_with_rag.return_value = "Prompt with RAG" - mock_run_llm_query.return_value.text = json.dumps({}) - get_ai_document_classification(mock_document) - mock_build_prompt_with_rag.assert_called_once() + settings: SettingsWrapper, + mocker: MockerFixture, + document: Document, +) -> None: + settings.LLM_EMBEDDING_BACKEND = "huggingface" + settings.LLM_EMBEDDING_MODEL = "some_model" + settings.LLM_BACKEND = "ollama" + settings.LLM_MODEL = "some_model" + mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_with_rag") + mock_build.return_value = "Prompt with RAG" + mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={}) + get_ai_document_classification(document) + mock_build.assert_called_once() @pytest.mark.django_db -@patch("paperless_ai.client.AIClient.run_llm_query") -@patch("paperless_ai.ai_classifier.build_prompt_without_rag") -@patch("paperless.config.AIConfig") -@override_settings( - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) def test_use_without_rag_if_not_configured( - mock_ai_config, - 
mock_build_prompt_without_rag, - mock_run_llm_query, - mock_document, -): + settings: SettingsWrapper, + mocker: MockerFixture, + document: Document, +) -> None: + settings.LLM_BACKEND = "ollama" + settings.LLM_MODEL = "some_model" + mock_ai_config = mocker.patch("paperless.config.AIConfig") + mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_without_rag") + mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={}) mock_ai_config.llm_embedding_backend = None - mock_build_prompt_without_rag.return_value = "Prompt without RAG" - mock_run_llm_query.return_value.text = json.dumps({}) - get_ai_document_classification(mock_document) - mock_build_prompt_without_rag.assert_called_once() + mock_build.return_value = "Prompt without RAG" + get_ai_document_classification(document) + mock_build.assert_called_once() -@pytest.mark.django_db -@override_settings( - LLM_EMBEDDING_BACKEND="huggingface", - LLM_BACKEND="ollama", - LLM_MODEL="some_model", -) -def test_prompt_with_without_rag(mock_document): - with patch( +def test_prompt_with_without_rag(mocker: MockerFixture, document: Document) -> None: + mocker.patch( "paperless_ai.ai_classifier.get_context_for_document", return_value="Context from similar documents", - ): - prompt = build_prompt_without_rag(mock_document) - assert "Additional context from similar documents:" not in prompt + ) + prompt = build_prompt_without_rag(document) + assert "Additional context from similar documents:" not in prompt - prompt = build_prompt_with_rag(mock_document) - assert "Additional context from similar documents:" in prompt + prompt = build_prompt_with_rag(document) + assert "Additional context from similar documents:" in prompt -@patch("paperless_ai.ai_classifier.query_similar_documents") def test_get_context_for_document( - mock_query_similar_documents, - mock_document, - mock_similar_documents, -): - mock_query_similar_documents.return_value = mock_similar_documents - - result = 
get_context_for_document(mock_document, max_docs=2) - - expected_result = ( + mocker: MockerFixture, + document: Document, + similar_documents: list[Document], +) -> None: + mocker.patch( + "paperless_ai.ai_classifier.query_similar_documents", + return_value=similar_documents, + ) + result = get_context_for_document(document, max_docs=2) + assert result == ( "TITLE: Title 1\nContent of document 1\n\n" "TITLE: file2.txt\nContent of document 2" ) - assert result == expected_result - mock_query_similar_documents.assert_called_once() -def test_get_context_for_document_no_similar_docs(mock_document): - with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]): - result = get_context_for_document(mock_document) - assert result == "" +def test_get_context_for_document_no_similar_docs( + mocker: MockerFixture, + document: Document, +) -> None: + mocker.patch( + "paperless_ai.ai_classifier.query_similar_documents", + return_value=[], + ) + assert get_context_for_document(document) == "" diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index d02cf3b96..8627c700f 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -1,11 +1,12 @@ import json +from pathlib import Path from unittest.mock import MagicMock -from unittest.mock import patch import pytest -from django.test import override_settings from django.utils import timezone -from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.embeddings.mock_embed_model import MockEmbedding +from pytest_django.fixtures import SettingsWrapper +from pytest_mock import MockerFixture from documents.models import Document from documents.models import PaperlessTask @@ -14,7 +15,7 @@ from paperless_ai import indexing @pytest.fixture -def real_document(db): +def real_document(db) -> Document: return Document.objects.create( title="Test Document", content="This is some test content.", @@ 
-23,36 +24,15 @@ def real_document(db): @pytest.fixture -def mock_embed_model(): - fake = FakeEmbedding() - with ( - patch("paperless_ai.indexing.get_embedding_model") as mock_index, - patch( - "paperless_ai.embedding.get_embedding_model", - ) as mock_embedding, - ): - mock_index.return_value = fake - mock_embedding.return_value = fake - yield mock_index - - -class FakeEmbedding(BaseEmbedding): - # TODO: maybe a better way to do this? - def _aget_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_text_embedding(self, text: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def get_query_embedding_dim(self) -> int: - return 384 # Match your real FAISS config +def mock_embed_model(mocker: MockerFixture) -> MagicMock: + fake = MockEmbedding(embed_dim=384) + mock = mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake) + mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake) + return mock @pytest.mark.django_db -def test_build_document_node(real_document) -> None: +def test_build_document_node(real_document: Document) -> None: nodes = indexing.build_document_node(real_document) assert len(nodes) > 0 assert nodes[0].metadata["document_id"] == str(real_document.id) @@ -60,37 +40,38 @@ def test_build_document_node(real_document) -> None: @pytest.mark.django_db def test_update_llm_index( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - with patch("documents.models.Document.objects.all") as mock_all: - mock_queryset = MagicMock() - mock_queryset.exists.return_value = True - mock_queryset.__iter__.return_value = iter([real_document]) - mock_all.return_value = mock_queryset - 
indexing.update_llm_index(rebuild=True) + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + mock_queryset.__iter__.return_value = iter([real_document]) + mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset) - assert any(temp_llm_index_dir.glob("*.json")) + indexing.update_llm_index(rebuild=True) + + assert any(temp_llm_index_dir.glob("*.json")) @pytest.mark.django_db def test_update_llm_index_removes_meta( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - # Pre-create a meta.json with incorrect data (temp_llm_index_dir / "meta.json").write_text( json.dumps({"embedding_model": "old", "dim": 1}), ) - with patch("documents.models.Document.objects.all") as mock_all: - mock_queryset = MagicMock() - mock_queryset.exists.return_value = True - mock_queryset.__iter__.return_value = iter([real_document]) - mock_all.return_value = mock_queryset - indexing.update_llm_index(rebuild=True) + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + mock_queryset.__iter__.return_value = iter([real_document]) + mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset) + + indexing.update_llm_index(rebuild=True) meta = json.loads((temp_llm_index_dir / "meta.json").read_text()) from paperless.config import AIConfig @@ -106,9 +87,10 @@ def test_update_llm_index_removes_meta( @pytest.mark.django_db def test_update_llm_index_partial_update( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: doc2 = Document.objects.create( title="Test Document 2", @@ -116,20 +98,16 @@ def test_update_llm_index_partial_update( added=timezone.now(), checksum="1234567890abcdef", ) - # Initial index - with patch("documents.models.Document.objects.all") as 
mock_all: - mock_queryset = MagicMock() - mock_queryset.exists.return_value = True - mock_queryset.__iter__.return_value = iter([real_document, doc2]) - mock_all.return_value = mock_queryset - indexing.update_llm_index(rebuild=True) + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + mock_queryset.__iter__.return_value = iter([real_document, doc2]) + mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset) + indexing.update_llm_index(rebuild=True) - # modify document updated_document = real_document - updated_document.modified = timezone.now() # simulate modification + updated_document.modified = timezone.now() - # new doc doc3 = Document.objects.create( title="Test Document 3", content="This is some test content 3.", @@ -137,110 +115,101 @@ def test_update_llm_index_partial_update( checksum="abcdef1234567890", ) - with patch("documents.models.Document.objects.all") as mock_all: - mock_queryset = MagicMock() - mock_queryset.exists.return_value = True - mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3]) - mock_all.return_value = mock_queryset + mock_queryset2 = MagicMock() + mock_queryset2.exists.return_value = True + mock_queryset2.__iter__.return_value = iter([updated_document, doc2, doc3]) + mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset2) - # assert logs "Updating LLM index with %d new nodes and removing %d old nodes." 
- with patch("paperless_ai.indexing.logger") as mock_logger: - indexing.update_llm_index(rebuild=False) - mock_logger.info.assert_called_once_with( - "Updating %d nodes in LLM index.", - 2, - ) - indexing.update_llm_index(rebuild=False) + mock_logger = mocker.patch("paperless_ai.indexing.logger") + indexing.update_llm_index(rebuild=False) + mock_logger.info.assert_called_once_with("Updating %d nodes in LLM index.", 2) + + indexing.update_llm_index(rebuild=False) assert any(temp_llm_index_dir.glob("*.json")) def test_get_or_create_storage_context_raises_exception( - temp_llm_index_dir, - mock_embed_model, + temp_llm_index_dir: Path, + mock_embed_model: MagicMock, ) -> None: with pytest.raises(Exception): indexing.get_or_create_storage_context(rebuild=False) -@override_settings( - LLM_EMBEDDING_BACKEND="huggingface", -) +@pytest.mark.django_db def test_load_or_build_index_builds_when_nodes_given( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "llama_index.core.VectorStoreIndex", - return_value=MagicMock(), - ) as mock_index_cls, - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ) as mock_storage, - ): - mock_storage.return_value.persist_dir = temp_llm_index_dir - indexing.load_or_build_index( - nodes=[indexing.build_document_node(real_document)], - ) - mock_index_cls.assert_called_once() + mocker.patch( + "llama_index.core.load_index_from_storage", + side_effect=ValueError("Index not found"), + ) + mock_index_cls = mocker.patch( + "llama_index.core.VectorStoreIndex", + return_value=MagicMock(), + ) + mock_storage = mocker.patch( + "paperless_ai.indexing.get_or_create_storage_context", + return_value=MagicMock(), + ) + mock_storage.return_value.persist_dir = 
temp_llm_index_dir + + indexing.load_or_build_index(nodes=[indexing.build_document_node(real_document)]) + + mock_index_cls.assert_called_once() def test_load_or_build_index_raises_exception_when_no_nodes( - temp_llm_index_dir, - mock_embed_model, + temp_llm_index_dir: Path, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ), - ): - with pytest.raises(Exception): - indexing.load_or_build_index() + mocker.patch( + "llama_index.core.load_index_from_storage", + side_effect=ValueError("Index not found"), + ) + mocker.patch( + "paperless_ai.indexing.get_or_create_storage_context", + return_value=MagicMock(), + ) + with pytest.raises(Exception): + indexing.load_or_build_index() @pytest.mark.django_db def test_load_or_build_index_succeeds_when_nodes_given( - temp_llm_index_dir, - mock_embed_model, + temp_llm_index_dir: Path, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - with ( - patch( - "llama_index.core.load_index_from_storage", - side_effect=ValueError("Index not found"), - ), - patch( - "llama_index.core.VectorStoreIndex", - return_value=MagicMock(), - ) as mock_index_cls, - patch( - "paperless_ai.indexing.get_or_create_storage_context", - return_value=MagicMock(), - ) as mock_storage, - ): - mock_storage.return_value.persist_dir = temp_llm_index_dir - indexing.load_or_build_index( - nodes=[MagicMock()], - ) - mock_index_cls.assert_called_once() + mocker.patch( + "llama_index.core.load_index_from_storage", + side_effect=ValueError("Index not found"), + ) + mock_index_cls = mocker.patch( + "llama_index.core.VectorStoreIndex", + return_value=MagicMock(), + ) + mock_storage = mocker.patch( + "paperless_ai.indexing.get_or_create_storage_context", + return_value=MagicMock(), + ) + mock_storage.return_value.persist_dir = 
temp_llm_index_dir + + indexing.load_or_build_index(nodes=[MagicMock()]) + + mock_index_cls.assert_called_once() @pytest.mark.django_db def test_add_or_update_document_updates_existing_entry( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, ) -> None: indexing.update_llm_index(rebuild=True) indexing.llm_index_add_or_update_document(real_document) @@ -250,9 +219,9 @@ def test_add_or_update_document_updates_existing_entry( @pytest.mark.django_db def test_remove_document_deletes_node_from_docstore( - temp_llm_index_dir, - real_document, - mock_embed_model, + temp_llm_index_dir: Path, + real_document: Document, + mock_embed_model: MagicMock, ) -> None: indexing.update_llm_index(rebuild=True) index = indexing.load_or_build_index() @@ -265,31 +234,29 @@ def test_remove_document_deletes_node_from_docstore( @pytest.mark.django_db def test_update_llm_index_no_documents( - temp_llm_index_dir, - mock_embed_model, + temp_llm_index_dir: Path, + mock_embed_model: MagicMock, + mocker: MockerFixture, ) -> None: - with patch("documents.models.Document.objects.all") as mock_all: - mock_queryset = MagicMock() - mock_queryset.exists.return_value = False - mock_queryset.__iter__.return_value = iter([]) - mock_all.return_value = mock_queryset + mock_queryset = MagicMock() + mock_queryset.exists.return_value = False + mock_queryset.__iter__.return_value = iter([]) + mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset) - # check log message - with patch("paperless_ai.indexing.logger") as mock_logger: - indexing.update_llm_index(rebuild=True) - mock_logger.warning.assert_called_once_with( - "No documents found to index.", - ) + mock_logger = mocker.patch("paperless_ai.indexing.logger") + indexing.update_llm_index(rebuild=True) + mock_logger.warning.assert_called_once_with("No documents found to index.") @pytest.mark.django_db -def 
test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None: - # No existing tasks - with patch("documents.tasks.llmindex_index") as mock_task: - result = indexing.queue_llm_index_update_if_needed( - rebuild=True, - reason="test enqueue", - ) +def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent( + mocker: MockerFixture, +) -> None: + mock_task = mocker.patch("documents.tasks.llmindex_index") + result = indexing.queue_llm_index_update_if_needed( + rebuild=True, + reason="test enqueue", + ) assert result is True mock_task.apply_async.assert_called_once_with( @@ -303,86 +270,71 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() - status=PaperlessTask.Status.STARTED, ) - # Existing running task - with patch("documents.tasks.llmindex_index") as mock_task: - result = indexing.queue_llm_index_update_if_needed( - rebuild=False, - reason="should skip", - ) + mock_task2 = mocker.patch("documents.tasks.llmindex_index") + result = indexing.queue_llm_index_update_if_needed( + rebuild=False, + reason="should skip", + ) assert result is False - mock_task.apply_async.assert_not_called() + mock_task2.apply_async.assert_not_called() -@override_settings( - LLM_EMBEDDING_BACKEND="huggingface", - LLM_BACKEND="ollama", -) +@pytest.mark.django_db def test_query_similar_documents( - temp_llm_index_dir, - real_document, + temp_llm_index_dir: Path, + real_document: Document, + mocker: MockerFixture, + settings: SettingsWrapper, ) -> None: - with ( - patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage, - patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index, - patch( - "paperless_ai.indexing.vector_store_file_exists", - ) as mock_vector_store_exists, - patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls, - patch("paperless_ai.indexing.Document.objects.filter") as mock_filter, - ): - mock_storage.return_value = MagicMock() - 
mock_storage.return_value.persist_dir = temp_llm_index_dir - mock_vector_store_exists.return_value = True + settings.LLM_EMBEDDING_BACKEND = "huggingface" + settings.LLM_BACKEND = "ollama" - mock_index = MagicMock() - mock_load_or_build_index.return_value = mock_index + mock_storage = mocker.patch("paperless_ai.indexing.get_or_create_storage_context") + mock_storage.return_value.persist_dir = temp_llm_index_dir + mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=True) - mock_retriever = MagicMock() - mock_retriever_cls.return_value = mock_retriever + mock_index = MagicMock() + mocker.patch("paperless_ai.indexing.load_or_build_index", return_value=mock_index) - mock_node1 = MagicMock() - mock_node1.metadata = {"document_id": 1} + mock_retriever = MagicMock() + mocker.patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=mock_retriever, + ) - mock_node2 = MagicMock() - mock_node2.metadata = {"document_id": 2} + mock_node1 = MagicMock() + mock_node1.metadata = {"document_id": 1} + mock_node2 = MagicMock() + mock_node2.metadata = {"document_id": 2} + mock_retriever.retrieve.return_value = [mock_node1, mock_node2] - mock_retriever.retrieve.return_value = [mock_node1, mock_node2] + mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)] + mock_filter = mocker.patch( + "paperless_ai.indexing.Document.objects.filter", + return_value=mock_filtered_docs, + ) - mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)] - mock_filter.return_value = mock_filtered_docs + result = indexing.query_similar_documents(real_document, top_k=3) - result = indexing.query_similar_documents(real_document, top_k=3) - - mock_load_or_build_index.assert_called_once() - mock_retriever_cls.assert_called_once() - mock_retriever.retrieve.assert_called_once_with( - "Test Document\nThis is some test content.", - ) - mock_filter.assert_called_once_with(pk__in=[1, 2]) - - assert result == mock_filtered_docs + mock_retriever.retrieve.assert_called_once_with( 
+ "Test Document\nThis is some test content.", + ) + mock_filter.assert_called_once_with(pk__in=[1, 2]) + assert result == mock_filtered_docs @pytest.mark.django_db def test_query_similar_documents_triggers_update_when_index_missing( - temp_llm_index_dir, - real_document, + temp_llm_index_dir: Path, + real_document: Document, + mocker: MockerFixture, ) -> None: - with ( - patch( - "paperless_ai.indexing.vector_store_file_exists", - return_value=False, - ), - patch( - "paperless_ai.indexing.queue_llm_index_update_if_needed", - ) as mock_queue, - patch("paperless_ai.indexing.load_or_build_index") as mock_load, - ): - result = indexing.query_similar_documents( - real_document, - top_k=2, - ) + mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=False) + mock_queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed") + mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index") + + result = indexing.query_similar_documents(real_document, top_k=2) mock_queue.assert_called_once_with( rebuild=False, diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 5e26ca0af..4e158716b 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -1,10 +1,10 @@ import json from unittest.mock import MagicMock -from unittest.mock import patch import pytest from llama_index.core import VectorStoreIndex from llama_index.core.schema import TextNode +from pytest_mock import MockerFixture from paperless_ai.chat import CHAT_METADATA_DELIMITER from paperless_ai.chat import stream_chat_with_documents @@ -15,31 +15,17 @@ def patch_embed_model(): from llama_index.core import settings as llama_settings from llama_index.core.embeddings.mock_embed_model import MockEmbedding - # Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536) yield llama_settings.Settings.embed_model = None 
@pytest.fixture(autouse=True) -def patch_embed_nodes(): - with patch( - "llama_index.core.indices.vector_store.base.embed_nodes", - ) as mock_embed_nodes: - mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: { - node.node_id: [0.1] * 1536 for node in nodes - } - yield - - -@pytest.fixture -def mock_document(): - doc = MagicMock() - doc.pk = 1 - doc.title = "Test Document" - doc.filename = "test_file.pdf" - doc.content = "This is the document content." - return doc +def patch_embed_nodes(mocker: MockerFixture): + mock = mocker.patch("llama_index.core.indices.vector_store.base.embed_nodes") + mock.side_effect = lambda nodes, *_args, **_kwargs: { + node.node_id: [0.1] * 1536 for node in nodes + } def assert_chat_output( @@ -57,127 +43,110 @@ def assert_chat_output( } -def test_stream_chat_with_one_document_full_content(mock_document) -> None: - with ( - patch("paperless_ai.chat.AIClient") as mock_client_cls, - patch("paperless_ai.chat.load_or_build_index") as mock_load_index, - patch( - "llama_index.core.query_engine.RetrieverQueryEngine.from_args", - ) as mock_query_engine_cls, - ): - mock_client = MagicMock() - mock_client_cls.return_value = mock_client - mock_client.llm = MagicMock() +def test_stream_chat_with_one_document_full_content(mocker: MockerFixture) -> None: + mock_document = MagicMock() + mock_document.pk = 1 + mock_document.title = "Test Document" + mock_document.filename = "test_file.pdf" + mock_document.content = "This is the document content." 
- mock_node = TextNode( - text="This is node content.", - metadata={"document_id": str(mock_document.pk), "title": "Test Document"}, - ) - mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node] - mock_load_index.return_value = mock_index + mock_client = MagicMock() + mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client) - mock_response_stream = MagicMock() - mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) - mock_query_engine = MagicMock() - mock_query_engine_cls.return_value = mock_query_engine - mock_query_engine.query.return_value = mock_response_stream + mock_node = TextNode( + text="This is node content.", + metadata={"document_id": str(mock_document.pk), "title": "Test Document"}, + ) + mock_index = MagicMock() + mock_index.docstore.docs.values.return_value = [mock_node] + mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index) - output = list(stream_chat_with_documents("What is this?", [mock_document])) + mock_response_stream = MagicMock() + mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) + mock_query_engine = MagicMock() + mock_query_engine.query.return_value = mock_response_stream + mocker.patch( + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", + return_value=mock_query_engine, + ) - assert_chat_output( - output, - expected_chunks=["chunk1", "chunk2"], - expected_references=[ - {"id": mock_document.pk, "title": "Test Document"}, - ], - ) + output = list(stream_chat_with_documents("What is this?", [mock_document])) + + assert_chat_output( + output, + expected_chunks=["chunk1", "chunk2"], + expected_references=[{"id": mock_document.pk, "title": "Test Document"}], + ) -def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None: - with ( - patch("paperless_ai.chat.AIClient") as mock_client_cls, - patch("paperless_ai.chat.load_or_build_index") as mock_load_index, - patch( - 
"llama_index.core.query_engine.RetrieverQueryEngine.from_args", - ) as mock_query_engine_cls, - patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever, - ): - # Mock AIClient and LLM - mock_client = MagicMock() - mock_client_cls.return_value = mock_client - mock_client.llm = MagicMock() +def test_stream_chat_with_multiple_documents_retrieval( + patch_embed_nodes, + mocker: MockerFixture, +) -> None: + mock_client = MagicMock() + mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client) - # Create two real TextNodes - mock_node1 = TextNode( - text="Content for doc 1.", - metadata={"document_id": "1", "title": "Document 1"}, - ) - mock_node2 = TextNode( - text="Content for doc 2.", - metadata={"document_id": "2", "title": "Document 2"}, - ) - mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2] - mock_load_index.return_value = mock_index + mock_node1 = TextNode( + text="Content for doc 1.", + metadata={"document_id": "1", "title": "Document 1"}, + ) + mock_node2 = TextNode( + text="Content for doc 2.", + metadata={"document_id": "2", "title": "Document 2"}, + ) + mock_index = MagicMock() + mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2] + mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index) - # Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2 - mock_retriever = MagicMock() - mock_duplicate_node = TextNode( - text="More content for doc 1.", - metadata={"document_id": "1", "title": "Document 1 Duplicate"}, - ) - mock_foreign_node = TextNode( - text="Content for doc 3.", - metadata={"document_id": "3", "title": "Document 3"}, - ) - mock_retriever.retrieve.return_value = [ - mock_node1, - mock_duplicate_node, - mock_node2, - mock_foreign_node, - ] - mock_as_retriever.return_value = mock_retriever + mock_retriever = MagicMock() + mock_duplicate_node = TextNode( + text="More content for doc 1.", + 
metadata={"document_id": "1", "title": "Document 1 Duplicate"}, + ) + mock_foreign_node = TextNode( + text="Content for doc 3.", + metadata={"document_id": "3", "title": "Document 3"}, + ) + mock_retriever.retrieve.return_value = [ + mock_node1, + mock_duplicate_node, + mock_node2, + mock_foreign_node, + ] + mocker.patch.object(VectorStoreIndex, "as_retriever", return_value=mock_retriever) - # Mock response stream - mock_response_stream = MagicMock() - mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) + mock_response_stream = MagicMock() + mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) + mock_query_engine = MagicMock() + mock_query_engine.query.return_value = mock_response_stream + mocker.patch( + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", + return_value=mock_query_engine, + ) - # Mock RetrieverQueryEngine - mock_query_engine = MagicMock() - mock_query_engine_cls.return_value = mock_query_engine - mock_query_engine.query.return_value = mock_response_stream + doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf") + doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf") - # Fake documents - doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf") - doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf") + output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) - output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) - - assert_chat_output( - output, - expected_chunks=["chunk1", "chunk2"], - expected_references=[ - {"id": 1, "title": "Document 1"}, - {"id": 2, "title": "Document 2"}, - ], - ) + assert_chat_output( + output, + expected_chunks=["chunk1", "chunk2"], + expected_references=[ + {"id": 1, "title": "Document 1"}, + {"id": 2, "title": "Document 2"}, + ], + ) -def test_stream_chat_no_matching_nodes() -> None: - with ( - patch("paperless_ai.chat.AIClient") as mock_client_cls, - patch("paperless_ai.chat.load_or_build_index") as mock_load_index, - ): - 
mock_client = MagicMock() - mock_client_cls.return_value = mock_client - mock_client.llm = MagicMock() +def test_stream_chat_no_matching_nodes(mocker: MockerFixture) -> None: + mock_client = MagicMock() + mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client) - mock_index = MagicMock() - # No matching nodes - mock_index.docstore.docs.values.return_value = [] - mock_load_index.return_value = mock_index + mock_index = MagicMock() + mock_index.docstore.docs.values.return_value = [] + mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index) - output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) + output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) - assert output == ["Sorry, I couldn't find any content to answer your question."] + assert output == ["Sorry, I couldn't find any content to answer your question."] diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py index 35a881400..7c346ae06 100644 --- a/src/paperless_ai/tests/test_client.py +++ b/src/paperless_ai/tests/test_client.py @@ -1,38 +1,35 @@ from unittest.mock import MagicMock -from unittest.mock import patch import pytest from llama_index.core.llms import ChatMessage from llama_index.core.llms.llm import ToolSelection +from pytest_mock import MockerFixture +from paperless.config import AIConfig from paperless_ai.client import AIClient @pytest.fixture -def mock_ai_config(): - with patch("paperless_ai.client.AIConfig") as MockAIConfig: - mock_config = MagicMock() - mock_config.llm_allow_internal_endpoints = True - MockAIConfig.return_value = mock_config - yield mock_config +def mock_ai_config(mocker: MockerFixture) -> MagicMock: + mock = mocker.patch("paperless_ai.client.AIConfig", spec=AIConfig) + mock.return_value.llm_allow_internal_endpoints = True + return mock @pytest.fixture -def mock_ollama_llm(): - with patch("llama_index.llms.ollama.Ollama") as MockOllama: - yield MockOllama +def 
mock_ollama_llm(mocker: MockerFixture) -> MagicMock: + return mocker.patch("llama_index.llms.ollama.Ollama") @pytest.fixture -def mock_openai_llm(): - with patch("llama_index.llms.openai_like.OpenAILike") as MockOpenAILike: - yield MockOpenAILike +def mock_openai_llm(mocker: MockerFixture) -> MagicMock: + return mocker.patch("llama_index.llms.openai_like.OpenAILike") -def test_get_llm_ollama(mock_ai_config, mock_ollama_llm): - mock_ai_config.llm_backend = "ollama" - mock_ai_config.llm_model = "test_model" - mock_ai_config.llm_endpoint = "http://test-url" +def test_get_llm_ollama(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None: + mock_ai_config.return_value.llm_backend = "ollama" + mock_ai_config.return_value.llm_model = "test_model" + mock_ai_config.return_value.llm_endpoint = "http://test-url" client = AIClient() @@ -44,11 +41,11 @@ def test_get_llm_ollama(mock_ai_config, mock_ollama_llm): assert client.llm == mock_ollama_llm.return_value -def test_get_llm_openai(mock_ai_config, mock_openai_llm): - mock_ai_config.llm_backend = "openai-like" - mock_ai_config.llm_model = "test_model" - mock_ai_config.llm_api_key = "test_api_key" - mock_ai_config.llm_endpoint = "http://test-url" +def test_get_llm_openai(mock_ai_config: MagicMock, mock_openai_llm: MagicMock) -> None: + mock_ai_config.return_value.llm_backend = "openai-like" + mock_ai_config.return_value.llm_model = "test_model" + mock_ai_config.return_value.llm_api_key = "test_api_key" + mock_ai_config.return_value.llm_endpoint = "http://test-url" client = AIClient() @@ -62,31 +59,32 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm): assert client.llm == mock_openai_llm.return_value -def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(mock_ai_config): - mock_ai_config.llm_backend = "openai-like" - mock_ai_config.llm_model = "test_model" - mock_ai_config.llm_api_key = "test_api_key" - mock_ai_config.llm_endpoint = "http://127.0.0.1:1234" - 
mock_ai_config.llm_allow_internal_endpoints = False +def test_get_llm_openai_blocks_internal_endpoint_when_disallowed( + mock_ai_config: MagicMock, +) -> None: + mock_ai_config.return_value.llm_backend = "openai-like" + mock_ai_config.return_value.llm_model = "test_model" + mock_ai_config.return_value.llm_api_key = "test_api_key" + mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:1234" + mock_ai_config.return_value.llm_allow_internal_endpoints = False with pytest.raises(ValueError, match="non-public address"): AIClient() -def test_get_llm_unsupported_backend(mock_ai_config): - mock_ai_config.llm_backend = "unsupported" +def test_get_llm_unsupported_backend(mock_ai_config: MagicMock) -> None: + mock_ai_config.return_value.llm_backend = "unsupported" with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): AIClient() -def test_run_llm_query(mock_ai_config, mock_ollama_llm): - mock_ai_config.llm_backend = "ollama" - mock_ai_config.llm_model = "test_model" - mock_ai_config.llm_endpoint = "http://test-url" +def test_run_llm_query(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None: + mock_ai_config.return_value.llm_backend = "ollama" + mock_ai_config.return_value.llm_model = "test_model" + mock_ai_config.return_value.llm_endpoint = "http://test-url" mock_llm_instance = mock_ollama_llm.return_value - tool_selection = ToolSelection( tool_id="call_test", tool_name="DocumentClassifierSchema", @@ -99,7 +97,6 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm): "dates": ["2023-01-01"], }, ) - mock_llm_instance.chat_with_tools.return_value = MagicMock() mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection] @@ -109,10 +106,10 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm): assert result["title"] == "Test Title" -def test_run_chat(mock_ai_config, mock_ollama_llm): - mock_ai_config.llm_backend = "ollama" - mock_ai_config.llm_model = "test_model" - mock_ai_config.llm_endpoint = 
"http://test-url" +def test_run_chat(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None: + mock_ai_config.return_value.llm_backend = "ollama" + mock_ai_config.return_value.llm_model = "test_model" + mock_ai_config.return_value.llm_endpoint = "http://test-url" mock_llm_instance = mock_ollama_llm.return_value mock_llm_instance.chat.return_value = "test_chat_result" diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index e4e80cdf1..39d3c3e6e 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -1,10 +1,20 @@ +import datetime import json +from pathlib import Path from unittest.mock import MagicMock -from unittest.mock import patch import pytest +from pytest_mock import MockerFixture +from documents.models import CustomField +from documents.models import CustomFieldInstance from documents.models import Document +from documents.models import Note +from documents.tests.factories import CorrespondentFactory +from documents.tests.factories import DocumentFactory +from documents.tests.factories import DocumentTypeFactory +from documents.tests.factories import TagFactory +from paperless.config import AIConfig from paperless.models import LLMEmbeddingBackend from paperless_ai.embedding import build_llm_index_text from paperless_ai.embedding import get_embedding_dim @@ -12,97 +22,89 @@ from paperless_ai.embedding import get_embedding_model @pytest.fixture -def mock_ai_config(): - with patch("paperless_ai.embedding.AIConfig") as MockAIConfig: - MockAIConfig.return_value.llm_allow_internal_endpoints = True - yield MockAIConfig +def mock_ai_config(mocker: MockerFixture) -> MagicMock: + mock = mocker.patch("paperless_ai.embedding.AIConfig", spec=AIConfig) + mock.return_value.llm_allow_internal_endpoints = True + return mock @pytest.fixture -def mock_document(): - doc = MagicMock(spec=Document) - doc.title = "Test Title" - doc.filename = "test_file.pdf" - doc.created = 
"2023-01-01" - doc.added = "2023-01-02" - doc.modified = "2023-01-03" - - tag1 = MagicMock() - tag1.name = "Tag1" - tag2 = MagicMock() - tag2.name = "Tag2" - doc.tags.all = MagicMock(return_value=[tag1, tag2]) - - doc.document_type = MagicMock() - doc.document_type.name = "Invoice" - doc.correspondent = MagicMock() - doc.correspondent.name = "Test Correspondent" - doc.archive_serial_number = "12345" - doc.content = "This is the document content." - - cf1 = MagicMock(__str__=lambda x: "Value1") - cf1.field = MagicMock() - cf1.field.name = "Field1" - cf1.value = "Value1" - cf2 = MagicMock(__str__=lambda x: "Value2") - cf2.field = MagicMock() - cf2.field.name = "Field2" - cf2.value = "Value2" - doc.custom_fields.all = MagicMock(return_value=[cf1, cf2]) - +def full_document(db) -> Document: + tag1 = TagFactory(name="Tag1") + tag2 = TagFactory(name="Tag2") + doc = DocumentFactory( + title="Test Title", + filename="test_file.pdf", + created=datetime.date(2023, 1, 1), + correspondent=CorrespondentFactory(name="Test Correspondent"), + document_type=DocumentTypeFactory(name="Invoice"), + archive_serial_number=12345, + content="This is the document content.", + ) + doc.tags.add(tag1, tag2) + cf1 = CustomField.objects.create( + name="Field1", + data_type=CustomField.FieldDataType.STRING, + ) + cf2 = CustomField.objects.create( + name="Field2", + data_type=CustomField.FieldDataType.STRING, + ) + CustomFieldInstance.objects.create(document=doc, field=cf1, value_text="Value1") + CustomFieldInstance.objects.create(document=doc, field=cf2, value_text="Value2") + Note.objects.create(document=doc, note="Note1") + Note.objects.create(document=doc, note="Note2") return doc -def test_get_embedding_model_openai(mock_ai_config): +def test_get_embedding_model_openai( + mock_ai_config: MagicMock, + mocker: MockerFixture, +) -> None: mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small" 
mock_ai_config.return_value.llm_api_key = "test_api_key" mock_ai_config.return_value.llm_endpoint = "http://test-url" - - with patch( - "llama_index.embeddings.openai_like.OpenAILikeEmbedding", - ) as MockOpenAIEmbedding: - model = get_embedding_model() - MockOpenAIEmbedding.assert_called_once_with( - model_name="text-embedding-3-small", - api_key="test_api_key", - api_base="http://test-url", - ) - assert model == MockOpenAIEmbedding.return_value + mock_cls = mocker.patch("llama_index.embeddings.openai_like.OpenAILikeEmbedding") + model = get_embedding_model() + mock_cls.assert_called_once_with( + model_name="text-embedding-3-small", + api_key="test_api_key", + api_base="http://test-url", + ) + assert model == mock_cls.return_value def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed( - mock_ai_config, -): + mock_ai_config: MagicMock, +) -> None: mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small" mock_ai_config.return_value.llm_api_key = "test_api_key" mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:11434" mock_ai_config.return_value.llm_allow_internal_endpoints = False - with pytest.raises(ValueError, match="non-public address"): get_embedding_model() -def test_get_embedding_model_huggingface(mock_ai_config): +def test_get_embedding_model_huggingface( + mock_ai_config: MagicMock, + mocker: MockerFixture, +) -> None: mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE mock_ai_config.return_value.llm_embedding_model = ( "sentence-transformers/all-MiniLM-L6-v2" ) - - with patch( - "llama_index.embeddings.huggingface.HuggingFaceEmbedding", - ) as MockHuggingFaceEmbedding: - model = get_embedding_model() - MockHuggingFaceEmbedding.assert_called_once_with( - model_name="sentence-transformers/all-MiniLM-L6-v2", - ) - assert model == MockHuggingFaceEmbedding.return_value + mock_cls = 
mocker.patch("llama_index.embeddings.huggingface.HuggingFaceEmbedding") + model = get_embedding_model() + mock_cls.assert_called_once_with( + model_name="sentence-transformers/all-MiniLM-L6-v2", + ) + assert model == mock_cls.return_value -def test_get_embedding_model_invalid_backend(mock_ai_config): +def test_get_embedding_model_invalid_backend(mock_ai_config: MagicMock) -> None: mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND" - with pytest.raises( ValueError, match="Unsupported embedding backend: INVALID_BACKEND", @@ -110,47 +112,53 @@ def test_get_embedding_model_invalid_backend(mock_ai_config): get_embedding_model() -def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config): +def test_get_embedding_dim_infers_and_saves( + temp_llm_index_dir: Path, + mock_ai_config: MagicMock, + mocker: MockerFixture, +) -> None: mock_ai_config.return_value.llm_embedding_backend = "openai-like" mock_ai_config.return_value.llm_embedding_model = None class DummyEmbedding: - def get_text_embedding(self, text): + def get_text_embedding(self, text: str) -> list[float]: return [0.0] * 7 - with patch( + mock_get = mocker.patch( "paperless_ai.embedding.get_embedding_model", return_value=DummyEmbedding(), - ) as mock_get: - dim = get_embedding_dim() - mock_get.assert_called_once() - + ) + dim = get_embedding_dim() + mock_get.assert_called_once() assert dim == 7 meta = json.loads((temp_llm_index_dir / "meta.json").read_text()) assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7} -def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config): +def test_get_embedding_dim_reads_existing_meta( + temp_llm_index_dir: Path, + mock_ai_config: MagicMock, + mocker: MockerFixture, +) -> None: mock_ai_config.return_value.llm_embedding_backend = "openai-like" mock_ai_config.return_value.llm_embedding_model = None - (temp_llm_index_dir / "meta.json").write_text( json.dumps({"embedding_model": "text-embedding-3-small", 
"dim": 11}), ) - - with patch("paperless_ai.embedding.get_embedding_model") as mock_get: - assert get_embedding_dim() == 11 - mock_get.assert_not_called() + mock_get = mocker.patch("paperless_ai.embedding.get_embedding_model") + assert get_embedding_dim() == 11 + mock_get.assert_not_called() -def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config): +def test_get_embedding_dim_raises_on_model_change( + temp_llm_index_dir: Path, + mock_ai_config: MagicMock, +) -> None: mock_ai_config.return_value.llm_embedding_backend = "openai-like" mock_ai_config.return_value.llm_embedding_model = None - (temp_llm_index_dir / "meta.json").write_text( json.dumps({"embedding_model": "old", "dim": 11}), ) - with pytest.raises( RuntimeError, match="Embedding model changed from old to text-embedding-3-small", @@ -158,21 +166,15 @@ def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_co get_embedding_dim() -def test_build_llm_index_text(mock_document): - with patch("documents.models.Note.objects.filter") as mock_notes_filter: - mock_notes_filter.return_value = [ - MagicMock(note="Note1"), - MagicMock(note="Note2"), - ] - - result = build_llm_index_text(mock_document) - - assert "Title: Test Title" in result - assert "Filename: test_file.pdf" in result - assert "Created: 2023-01-01" in result - assert "Tags: Tag1, Tag2" in result - assert "Document Type: Invoice" in result - assert "Correspondent: Test Correspondent" in result - assert "Notes: Note1,Note2" in result - assert "Content:\n\nThis is the document content." 
in result - assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result +@pytest.mark.django_db +def test_build_llm_index_text(full_document: Document) -> None: + result = build_llm_index_text(full_document) + assert "Title: Test Title" in result + assert "Filename: test_file.pdf" in result + assert "Created: 2023-01-01" in result + assert "Tags: Tag1, Tag2" in result + assert "Document Type: Invoice" in result + assert "Correspondent: Test Correspondent" in result + assert "Notes: Note1,Note2" in result + assert "Content:\n\nThis is the document content." in result + assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result diff --git a/src/paperless_ai/tests/test_matching.py b/src/paperless_ai/tests/test_matching.py index 83cfd8a41..f59887be2 100644 --- a/src/paperless_ai/tests/test_matching.py +++ b/src/paperless_ai/tests/test_matching.py @@ -1,86 +1,89 @@ -from unittest.mock import patch +from pytest_mock import MockerFixture -from django.test import TestCase - -from documents.models import Correspondent -from documents.models import DocumentType -from documents.models import StoragePath -from documents.models import Tag +from documents.tests.factories import CorrespondentFactory +from documents.tests.factories import DocumentTypeFactory +from documents.tests.factories import StoragePathFactory +from documents.tests.factories import TagFactory from paperless_ai.matching import extract_unmatched_names from paperless_ai.matching import match_correspondents_by_name from paperless_ai.matching import match_document_types_by_name from paperless_ai.matching import match_storage_paths_by_name from paperless_ai.matching import match_tags_by_name +_PATCH_TARGET = "paperless_ai.matching.get_objects_for_user_owner_aware" -class TestAIMatching(TestCase): - def setUp(self) -> None: - # Create test data for Tag - self.tag1 = Tag.objects.create(name="Test Tag 1") - self.tag2 = Tag.objects.create(name="Test Tag 2") - # Create test data 
for Correspondent - self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1") - self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2") +class TestAIMatching: + def test_match_tags_by_name(self, mocker: MockerFixture) -> None: + tags = [ + TagFactory.build(name="Test Tag 1"), + TagFactory.build(name="Test Tag 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=tags) + result = match_tags_by_name(["Test Tag 1", "Nonexistent Tag"], user=None) + assert len(result) == 1 + assert result[0].name == "Test Tag 1" - # Create test data for DocumentType - self.document_type1 = DocumentType.objects.create(name="Test Document Type 1") - self.document_type2 = DocumentType.objects.create(name="Test Document Type 2") + def test_match_correspondents_by_name(self, mocker: MockerFixture) -> None: + correspondents = [ + CorrespondentFactory.build(name="Test Correspondent 1"), + CorrespondentFactory.build(name="Test Correspondent 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=correspondents) + result = match_correspondents_by_name( + ["Test Correspondent 1", "Nonexistent Correspondent"], + user=None, + ) + assert len(result) == 1 + assert result[0].name == "Test Correspondent 1" - # Create test data for StoragePath - self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1") - self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2") + def test_match_document_types_by_name(self, mocker: MockerFixture) -> None: + document_types = [ + DocumentTypeFactory.build(name="Test Document Type 1"), + DocumentTypeFactory.build(name="Test Document Type 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=document_types) + result = match_document_types_by_name( + ["Test Document Type 1", "Nonexistent Document Type"], + user=None, + ) + assert len(result) == 1 + assert result[0].name == "Test Document Type 1" - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_tags_by_name(self, 
mock_get_objects) -> None: - mock_get_objects.return_value = Tag.objects.all() - names = ["Test Tag 1", "Nonexistent Tag"] - result = match_tags_by_name(names, user=None) - self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, "Test Tag 1") - - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_correspondents_by_name(self, mock_get_objects) -> None: - mock_get_objects.return_value = Correspondent.objects.all() - names = ["Test Correspondent 1", "Nonexistent Correspondent"] - result = match_correspondents_by_name(names, user=None) - self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, "Test Correspondent 1") - - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_document_types_by_name(self, mock_get_objects) -> None: - mock_get_objects.return_value = DocumentType.objects.all() - names = ["Test Document Type 1", "Nonexistent Document Type"] - result = match_document_types_by_name(names, user=None) - self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, "Test Document Type 1") - - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_storage_paths_by_name(self, mock_get_objects) -> None: - mock_get_objects.return_value = StoragePath.objects.all() - names = ["Test Storage Path 1", "Nonexistent Storage Path"] - result = match_storage_paths_by_name(names, user=None) - self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, "Test Storage Path 1") + def test_match_storage_paths_by_name(self, mocker: MockerFixture) -> None: + storage_paths = [ + StoragePathFactory.build(name="Test Storage Path 1"), + StoragePathFactory.build(name="Test Storage Path 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=storage_paths) + result = match_storage_paths_by_name( + ["Test Storage Path 1", "Nonexistent Storage Path"], + user=None, + ) + assert len(result) == 1 + assert result[0].name == "Test Storage Path 1" def 
test_extract_unmatched_names(self) -> None: - llm_names = ["Test Tag 1", "Nonexistent Tag"] - matched_objects = [self.tag1] - unmatched_names = extract_unmatched_names(llm_names, matched_objects) - self.assertEqual(unmatched_names, ["Nonexistent Tag"]) + tag = TagFactory.build(name="Test Tag 1") + unmatched = extract_unmatched_names(["Test Tag 1", "Nonexistent Tag"], [tag]) + assert unmatched == ["Nonexistent Tag"] - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_tags_by_name_with_empty_names(self, mock_get_objects) -> None: - mock_get_objects.return_value = Tag.objects.all() - names = [None, "", " "] - result = match_tags_by_name(names, user=None) - self.assertEqual(result, []) + def test_match_tags_by_name_with_empty_names(self, mocker: MockerFixture) -> None: + tags = [ + TagFactory.build(name="Test Tag 1"), + TagFactory.build(name="Test Tag 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=tags) + result = match_tags_by_name([None, "", " "], user=None) + assert result == [] - @patch("paperless_ai.matching.get_objects_for_user_owner_aware") - def test_match_tags_with_fuzzy_matching(self, mock_get_objects) -> None: - mock_get_objects.return_value = Tag.objects.all() - names = ["Test Taag 1", "Teest Tag 2"] - result = match_tags_by_name(names, user=None) - self.assertEqual(len(result), 2) - self.assertEqual(result[0].name, "Test Tag 1") - self.assertEqual(result[1].name, "Test Tag 2") + def test_match_tags_with_fuzzy_matching(self, mocker: MockerFixture) -> None: + tags = [ + TagFactory.build(name="Test Tag 1"), + TagFactory.build(name="Test Tag 2"), + ] + mocker.patch(_PATCH_TARGET, return_value=tags) + result = match_tags_by_name(["Test Taag 1", "Teest Tag 2"], user=None) + assert len(result) == 2 + assert result[0].name == "Test Tag 1" + assert result[1].name == "Test Tag 2"