mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-06 13:49:44 +00:00
Converts all these tests to fully use fixtures, factories and compsition + dropping DB setup where possible
This commit is contained in:
@@ -95,7 +95,9 @@ def build_llm_index_text(doc: Document) -> str:
|
||||
]
|
||||
|
||||
for instance in doc.custom_fields.all():
|
||||
lines.append(f"Custom Field - {instance.field.name}: {instance}")
|
||||
lines.append(
|
||||
f"Custom Field - {instance.field.name}: {instance.value_for_search}",
|
||||
)
|
||||
|
||||
lines.append("\nContent:\n")
|
||||
lines.append(doc.content or "")
|
||||
|
||||
@@ -3,8 +3,42 @@ from pathlib import Path
|
||||
import pytest
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
|
||||
from documents.models import Document
|
||||
from documents.tests.factories import CorrespondentFactory
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import DocumentTypeFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
|
||||
settings.LLM_INDEX_DIR = tmp_path
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document() -> Document:
|
||||
return DocumentFactory.build(
|
||||
title="Test Title",
|
||||
filename="test_file.pdf",
|
||||
correspondent=CorrespondentFactory.build(name="Test Correspondent"),
|
||||
document_type=DocumentTypeFactory.build(name="Invoice"),
|
||||
archive_serial_number=12345,
|
||||
content="This is the document content.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def similar_documents() -> list[Document]:
|
||||
return [
|
||||
DocumentFactory.build(
|
||||
title="Title 1",
|
||||
content="Content of document 1",
|
||||
filename="file1.txt",
|
||||
),
|
||||
DocumentFactory.build(
|
||||
title="",
|
||||
content="Content of document 2",
|
||||
filename="file2.txt",
|
||||
),
|
||||
DocumentFactory.build(title="", content="", filename=None),
|
||||
]
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
@@ -12,69 +9,16 @@ from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similar_documents():
|
||||
doc1 = MagicMock()
|
||||
doc1.content = "Content of document 1"
|
||||
doc1.title = "Title 1"
|
||||
doc1.filename = "file1.txt"
|
||||
|
||||
doc2 = MagicMock()
|
||||
doc2.content = "Content of document 2"
|
||||
doc2.title = None
|
||||
doc2.filename = "file2.txt"
|
||||
|
||||
doc3 = MagicMock()
|
||||
doc3.content = None
|
||||
doc3.title = None
|
||||
doc3.filename = None
|
||||
|
||||
return [doc1, doc2, doc3]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value = {
|
||||
def test_get_ai_document_classification_success(
|
||||
settings: SettingsWrapper,
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
) -> None:
|
||||
settings.LLM_BACKEND = "ollama"
|
||||
settings.LLM_MODEL = "some_model"
|
||||
mock_run = mocker.patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
mock_run.return_value = {
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
@@ -83,7 +27,7 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
result = get_ai_document_classification(document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
@@ -94,93 +38,86 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
||||
|
||||
# assert raises an exception
|
||||
def test_get_ai_document_classification_failure(
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.client.AIClient.run_llm_query",
|
||||
side_effect=Exception("LLM query failed"),
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
get_ai_document_classification(mock_document)
|
||||
get_ai_document_classification(document)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_EMBEDDING_MODEL="some_model",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_rag_if_configured(
|
||||
mock_build_prompt_with_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_with_rag.assert_called_once()
|
||||
settings: SettingsWrapper,
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
) -> None:
|
||||
settings.LLM_EMBEDDING_BACKEND = "huggingface"
|
||||
settings.LLM_EMBEDDING_MODEL = "some_model"
|
||||
settings.LLM_BACKEND = "ollama"
|
||||
settings.LLM_MODEL = "some_model"
|
||||
mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
mock_build.return_value = "Prompt with RAG"
|
||||
mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={})
|
||||
get_ai_document_classification(document)
|
||||
mock_build.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
|
||||
@patch("paperless.config.AIConfig")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_without_rag_if_not_configured(
|
||||
mock_ai_config,
|
||||
mock_build_prompt_without_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
settings: SettingsWrapper,
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
) -> None:
|
||||
settings.LLM_BACKEND = "ollama"
|
||||
settings.LLM_MODEL = "some_model"
|
||||
mock_ai_config = mocker.patch("paperless.config.AIConfig")
|
||||
mock_build = mocker.patch("paperless_ai.ai_classifier.build_prompt_without_rag")
|
||||
mocker.patch("paperless_ai.client.AIClient.run_llm_query", return_value={})
|
||||
mock_ai_config.llm_embedding_backend = None
|
||||
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_without_rag.assert_called_once()
|
||||
mock_build.return_value = "Prompt without RAG"
|
||||
get_ai_document_classification(document)
|
||||
mock_build.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_prompt_with_without_rag(mock_document):
|
||||
with patch(
|
||||
def test_prompt_with_without_rag(mocker: MockerFixture, document: Document) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "Additional context from similar documents:" not in prompt
|
||||
)
|
||||
prompt = build_prompt_without_rag(document)
|
||||
assert "Additional context from similar documents:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "Additional context from similar documents:" in prompt
|
||||
prompt = build_prompt_with_rag(document)
|
||||
assert "Additional context from similar documents:" in prompt
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
mock_query_similar_documents,
|
||||
mock_document,
|
||||
mock_similar_documents,
|
||||
):
|
||||
mock_query_similar_documents.return_value = mock_similar_documents
|
||||
|
||||
result = get_context_for_document(mock_document, max_docs=2)
|
||||
|
||||
expected_result = (
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
similar_documents: list[Document],
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.query_similar_documents",
|
||||
return_value=similar_documents,
|
||||
)
|
||||
result = get_context_for_document(document, max_docs=2)
|
||||
assert result == (
|
||||
"TITLE: Title 1\nContent of document 1\n\n"
|
||||
"TITLE: file2.txt\nContent of document 2"
|
||||
)
|
||||
assert result == expected_result
|
||||
mock_query_similar_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
def test_get_context_for_document_no_similar_docs(
|
||||
mocker: MockerFixture,
|
||||
document: Document,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"paperless_ai.ai_classifier.query_similar_documents",
|
||||
return_value=[],
|
||||
)
|
||||
assert get_context_for_document(document) == ""
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
@@ -14,7 +15,7 @@ from paperless_ai import indexing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_document(db):
|
||||
def real_document(db) -> Document:
|
||||
return Document.objects.create(
|
||||
title="Test Document",
|
||||
content="This is some test content.",
|
||||
@@ -23,36 +24,15 @@ def real_document(db):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
def mock_embed_model(mocker: MockerFixture) -> MagicMock:
|
||||
fake = MockEmbedding(embed_dim=384)
|
||||
mock = mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake)
|
||||
mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document) -> None:
|
||||
def test_build_document_node(real_document: Document) -> None:
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
@@ -60,37 +40,38 @@ def test_build_document_node(real_document) -> None:
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_removes_meta(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Pre-create a meta.json with incorrect data
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 1}),
|
||||
)
|
||||
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
from paperless.config import AIConfig
|
||||
@@ -106,9 +87,10 @@ def test_update_llm_index_removes_meta(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_partial_update(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
doc2 = Document.objects.create(
|
||||
title="Test Document 2",
|
||||
@@ -116,20 +98,16 @@ def test_update_llm_index_partial_update(
|
||||
added=timezone.now(),
|
||||
checksum="1234567890abcdef",
|
||||
)
|
||||
# Initial index
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document, doc2])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document, doc2])
|
||||
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# modify document
|
||||
updated_document = real_document
|
||||
updated_document.modified = timezone.now() # simulate modification
|
||||
updated_document.modified = timezone.now()
|
||||
|
||||
# new doc
|
||||
doc3 = Document.objects.create(
|
||||
title="Test Document 3",
|
||||
content="This is some test content 3.",
|
||||
@@ -137,110 +115,101 @@ def test_update_llm_index_partial_update(
|
||||
checksum="abcdef1234567890",
|
||||
)
|
||||
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mock_all.return_value = mock_queryset
|
||||
mock_queryset2 = MagicMock()
|
||||
mock_queryset2.exists.return_value = True
|
||||
mock_queryset2.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset2)
|
||||
|
||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Updating %d nodes in LLM index.",
|
||||
2,
|
||||
)
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger = mocker.patch("paperless_ai.indexing.logger")
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with("Updating %d nodes in LLM index.", 2)
|
||||
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_get_or_create_storage_context_raises_exception(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
with pytest.raises(Exception):
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
mocker.patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
)
|
||||
mock_index_cls = mocker.patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mock_storage = mocker.patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
|
||||
indexing.load_or_build_index(nodes=[indexing.build_document_node(real_document)])
|
||||
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
mocker.patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
)
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_succeeds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[MagicMock()],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
mocker.patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
)
|
||||
mock_index_cls = mocker.patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mock_storage = mocker.patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
|
||||
indexing.load_or_build_index(nodes=[MagicMock()])
|
||||
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_add_or_update_document_updates_existing_entry(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
@@ -250,9 +219,9 @@ def test_add_or_update_document_updates_existing_entry(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_remove_document_deletes_node_from_docstore(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
@@ -265,31 +234,29 @@ def test_remove_document_deletes_node_from_docstore(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_no_documents(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = False
|
||||
mock_queryset.__iter__.return_value = iter([])
|
||||
mock_all.return_value = mock_queryset
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = False
|
||||
mock_queryset.__iter__.return_value = iter([])
|
||||
mocker.patch("documents.models.Document.objects.all", return_value=mock_queryset)
|
||||
|
||||
# check log message
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"No documents found to index.",
|
||||
)
|
||||
mock_logger = mocker.patch("paperless_ai.indexing.logger")
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_logger.warning.assert_called_once_with("No documents found to index.")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
|
||||
# No existing tasks
|
||||
with patch("documents.tasks.llmindex_index") as mock_task:
|
||||
result = indexing.queue_llm_index_update_if_needed(
|
||||
rebuild=True,
|
||||
reason="test enqueue",
|
||||
)
|
||||
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_task = mocker.patch("documents.tasks.llmindex_index")
|
||||
result = indexing.queue_llm_index_update_if_needed(
|
||||
rebuild=True,
|
||||
reason="test enqueue",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_task.apply_async.assert_called_once_with(
|
||||
@@ -303,86 +270,71 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
)
|
||||
|
||||
# Existing running task
|
||||
with patch("documents.tasks.llmindex_index") as mock_task:
|
||||
result = indexing.queue_llm_index_update_if_needed(
|
||||
rebuild=False,
|
||||
reason="should skip",
|
||||
)
|
||||
mock_task2 = mocker.patch("documents.tasks.llmindex_index")
|
||||
result = indexing.queue_llm_index_update_if_needed(
|
||||
rebuild=False,
|
||||
reason="should skip",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
mock_task.apply_async.assert_not_called()
|
||||
mock_task2.apply_async.assert_not_called()
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: MockerFixture,
|
||||
settings: SettingsWrapper,
|
||||
) -> None:
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
) as mock_vector_store_exists,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
mock_vector_store_exists.return_value = True
|
||||
settings.LLM_EMBEDDING_BACKEND = "huggingface"
|
||||
settings.LLM_BACKEND = "ollama"
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
mock_storage = mocker.patch("paperless_ai.indexing.get_or_create_storage_context")
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=True)
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
mock_index = MagicMock()
|
||||
mocker.patch("paperless_ai.indexing.load_or_build_index", return_value=mock_index)
|
||||
|
||||
mock_node1 = MagicMock()
|
||||
mock_node1.metadata = {"document_id": 1}
|
||||
mock_retriever = MagicMock()
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever,
|
||||
)
|
||||
|
||||
mock_node2 = MagicMock()
|
||||
mock_node2.metadata = {"document_id": 2}
|
||||
mock_node1 = MagicMock()
|
||||
mock_node1.metadata = {"document_id": 1}
|
||||
mock_node2 = MagicMock()
|
||||
mock_node2.metadata = {"document_id": 2}
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
|
||||
mock_filter = mocker.patch(
|
||||
"paperless_ai.indexing.Document.objects.filter",
|
||||
return_value=mock_filtered_docs,
|
||||
)
|
||||
|
||||
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
|
||||
mock_filter.return_value = mock_filtered_docs
|
||||
result = indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
result = indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
mock_load_or_build_index.assert_called_once()
|
||||
mock_retriever_cls.assert_called_once()
|
||||
mock_retriever.retrieve.assert_called_once_with(
|
||||
"Test Document\nThis is some test content.",
|
||||
)
|
||||
mock_filter.assert_called_once_with(pk__in=[1, 2])
|
||||
|
||||
assert result == mock_filtered_docs
|
||||
mock_retriever.retrieve.assert_called_once_with(
|
||||
"Test Document\nThis is some test content.",
|
||||
)
|
||||
mock_filter.assert_called_once_with(pk__in=[1, 2])
|
||||
assert result == mock_filtered_docs
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.queue_llm_index_update_if_needed",
|
||||
) as mock_queue,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load,
|
||||
):
|
||||
result = indexing.query_similar_documents(
|
||||
real_document,
|
||||
top_k=2,
|
||||
)
|
||||
mocker.patch("paperless_ai.indexing.vector_store_file_exists", return_value=False)
|
||||
mock_queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
|
||||
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
result = indexing.query_similar_documents(real_document, top_k=2)
|
||||
|
||||
mock_queue.assert_called_once_with(
|
||||
rebuild=False,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from paperless_ai.chat import CHAT_METADATA_DELIMITER
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
@@ -15,31 +15,17 @@ def patch_embed_model():
|
||||
from llama_index.core import settings as llama_settings
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
||||
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
|
||||
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
|
||||
yield
|
||||
llama_settings.Settings.embed_model = None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_nodes():
|
||||
with patch(
|
||||
"llama_index.core.indices.vector_store.base.embed_nodes",
|
||||
) as mock_embed_nodes:
|
||||
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
|
||||
node.node_id: [0.1] * 1536 for node in nodes
|
||||
}
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock()
|
||||
doc.pk = 1
|
||||
doc.title = "Test Document"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.content = "This is the document content."
|
||||
return doc
|
||||
def patch_embed_nodes(mocker: MockerFixture):
|
||||
mock = mocker.patch("llama_index.core.indices.vector_store.base.embed_nodes")
|
||||
mock.side_effect = lambda nodes, *_args, **_kwargs: {
|
||||
node.node_id: [0.1] * 1536 for node in nodes
|
||||
}
|
||||
|
||||
|
||||
def assert_chat_output(
|
||||
@@ -57,127 +43,110 @@ def assert_chat_output(
|
||||
}
|
||||
|
||||
|
||||
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
def test_stream_chat_with_one_document_full_content(mocker: MockerFixture) -> None:
|
||||
mock_document = MagicMock()
|
||||
mock_document.pk = 1
|
||||
mock_document.title = "Test Document"
|
||||
mock_document.filename = "test_file.pdf"
|
||||
mock_document.content = "This is the document content."
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
mock_client = MagicMock()
|
||||
mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
mocker.patch(
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
return_value=mock_query_engine,
|
||||
)
|
||||
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[
|
||||
{"id": mock_document.pk, "title": "Test Document"},
|
||||
],
|
||||
)
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[{"id": mock_document.pk, "title": "Test Document"}],
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
def test_stream_chat_with_multiple_documents_retrieval(
|
||||
patch_embed_nodes,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
)
|
||||
mock_node2 = TextNode(
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
)
|
||||
mock_node2 = TextNode(
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||
mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)
|
||||
|
||||
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
|
||||
mock_retriever = MagicMock()
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_retriever.retrieve.return_value = [
|
||||
mock_node1,
|
||||
mock_duplicate_node,
|
||||
mock_node2,
|
||||
mock_foreign_node,
|
||||
]
|
||||
mock_as_retriever.return_value = mock_retriever
|
||||
mock_retriever = MagicMock()
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_retriever.retrieve.return_value = [
|
||||
mock_node1,
|
||||
mock_duplicate_node,
|
||||
mock_node2,
|
||||
mock_foreign_node,
|
||||
]
|
||||
mocker.patch.object(VectorStoreIndex, "as_retriever", return_value=mock_retriever)
|
||||
|
||||
# Mock response stream
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
mocker.patch(
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
return_value=mock_query_engine,
|
||||
)
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[
|
||||
{"id": 1, "title": "Document 1"},
|
||||
{"id": 2, "title": "Document 2"},
|
||||
],
|
||||
)
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[
|
||||
{"id": 1, "title": "Document 1"},
|
||||
{"id": 2, "title": "Document 2"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes() -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
def test_stream_chat_no_matching_nodes(mocker: MockerFixture) -> None:
|
||||
mock_client = MagicMock()
|
||||
mocker.patch("paperless_ai.chat.AIClient", return_value=mock_client)
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
mocker.patch("paperless_ai.chat.load_or_build_index", return_value=mock_index)
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
||||
|
||||
@@ -1,38 +1,35 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.llms.llm import ToolSelection
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||
mock_config = MagicMock()
|
||||
mock_config.llm_allow_internal_endpoints = True
|
||||
MockAIConfig.return_value = mock_config
|
||||
yield mock_config
|
||||
def mock_ai_config(mocker: MockerFixture) -> MagicMock:
|
||||
mock = mocker.patch("paperless_ai.client.AIConfig", spec=AIConfig)
|
||||
mock.return_value.llm_allow_internal_endpoints = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_llm():
|
||||
with patch("llama_index.llms.ollama.Ollama") as MockOllama:
|
||||
yield MockOllama
|
||||
def mock_ollama_llm(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("llama_index.llms.ollama.Ollama")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_llm():
|
||||
with patch("llama_index.llms.openai_like.OpenAILike") as MockOpenAILike:
|
||||
yield MockOpenAILike
|
||||
def mock_openai_llm(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("llama_index.llms.openai_like.OpenAILike")
|
||||
|
||||
|
||||
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
def test_get_llm_ollama(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "ollama"
|
||||
mock_ai_config.return_value.llm_model = "test_model"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
@@ -44,11 +41,11 @@ def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
|
||||
assert client.llm == mock_ollama_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
mock_ai_config.llm_backend = "openai-like"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
def test_get_llm_openai(mock_ai_config: MagicMock, mock_openai_llm: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_model = "test_model"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
@@ -62,31 +59,32 @@ def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
assert client.llm == mock_openai_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "openai-like"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
mock_ai_config.llm_endpoint = "http://127.0.0.1:1234"
|
||||
mock_ai_config.llm_allow_internal_endpoints = False
|
||||
def test_get_llm_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config: MagicMock,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_model = "test_model"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:1234"
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "unsupported"
|
||||
def test_get_llm_unsupported_backend(mock_ai_config: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "unsupported"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
def test_run_llm_query(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "ollama"
|
||||
mock_ai_config.return_value.llm_model = "test_model"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
|
||||
tool_selection = ToolSelection(
|
||||
tool_id="call_test",
|
||||
tool_name="DocumentClassifierSchema",
|
||||
@@ -99,7 +97,6 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
mock_llm_instance.chat_with_tools.return_value = MagicMock()
|
||||
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
|
||||
|
||||
@@ -109,10 +106,10 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
assert result["title"] == "Test Title"
|
||||
|
||||
|
||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
def test_run_chat(mock_ai_config: MagicMock, mock_ollama_llm: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_backend = "ollama"
|
||||
mock_ai_config.return_value.llm_model = "test_model"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.chat.return_value = "test_chat_result"
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from documents.models import CustomField
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import Note
|
||||
from documents.tests.factories import CorrespondentFactory
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import DocumentTypeFactory
|
||||
from documents.tests.factories import TagFactory
|
||||
from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
@@ -12,97 +22,89 @@ from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
MockAIConfig.return_value.llm_allow_internal_endpoints = True
|
||||
yield MockAIConfig
|
||||
def mock_ai_config(mocker: MockerFixture) -> MagicMock:
|
||||
mock = mocker.patch("paperless_ai.embedding.AIConfig", spec=AIConfig)
|
||||
mock.return_value.llm_allow_internal_endpoints = True
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
def full_document(db) -> Document:
|
||||
tag1 = TagFactory(name="Tag1")
|
||||
tag2 = TagFactory(name="Tag2")
|
||||
doc = DocumentFactory(
|
||||
title="Test Title",
|
||||
filename="test_file.pdf",
|
||||
created=datetime.date(2023, 1, 1),
|
||||
correspondent=CorrespondentFactory(name="Test Correspondent"),
|
||||
document_type=DocumentTypeFactory(name="Invoice"),
|
||||
archive_serial_number=12345,
|
||||
content="This is the document content.",
|
||||
)
|
||||
doc.tags.add(tag1, tag2)
|
||||
cf1 = CustomField.objects.create(
|
||||
name="Field1",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
cf2 = CustomField.objects.create(
|
||||
name="Field2",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomFieldInstance.objects.create(document=doc, field=cf1, value_text="Value1")
|
||||
CustomFieldInstance.objects.create(document=doc, field=cf2, value_text="Value2")
|
||||
Note.objects.create(document=doc, note="Note1")
|
||||
Note.objects.create(document=doc, note="Note2")
|
||||
return doc
|
||||
|
||||
|
||||
def test_get_embedding_model_openai(mock_ai_config):
|
||||
def test_get_embedding_model_openai(
|
||||
mock_ai_config: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
with patch(
|
||||
"llama_index.embeddings.openai_like.OpenAILikeEmbedding",
|
||||
) as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
api_base="http://test-url",
|
||||
)
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
mock_cls = mocker.patch("llama_index.embeddings.openai_like.OpenAILikeEmbedding")
|
||||
model = get_embedding_model()
|
||||
mock_cls.assert_called_once_with(
|
||||
model_name="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
api_base="http://test-url",
|
||||
)
|
||||
assert model == mock_cls.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
|
||||
mock_ai_config,
|
||||
):
|
||||
mock_ai_config: MagicMock,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://127.0.0.1:11434"
|
||||
mock_ai_config.return_value.llm_allow_internal_endpoints = False
|
||||
|
||||
with pytest.raises(ValueError, match="non-public address"):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
def test_get_embedding_model_huggingface(
|
||||
mock_ai_config: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||
mock_ai_config.return_value.llm_embedding_model = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
assert model == MockHuggingFaceEmbedding.return_value
|
||||
mock_cls = mocker.patch("llama_index.embeddings.huggingface.HuggingFaceEmbedding")
|
||||
model = get_embedding_model()
|
||||
mock_cls.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
assert model == mock_cls.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config: MagicMock) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Unsupported embedding backend: INVALID_BACKEND",
|
||||
@@ -110,47 +112,53 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
|
||||
def test_get_embedding_dim_infers_and_saves(
|
||||
temp_llm_index_dir: Path,
|
||||
mock_ai_config: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
class DummyEmbedding:
|
||||
def get_text_embedding(self, text):
|
||||
def get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.0] * 7
|
||||
|
||||
with patch(
|
||||
mock_get = mocker.patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
return_value=DummyEmbedding(),
|
||||
) as mock_get:
|
||||
dim = get_embedding_dim()
|
||||
mock_get.assert_called_once()
|
||||
|
||||
)
|
||||
dim = get_embedding_dim()
|
||||
mock_get.assert_called_once()
|
||||
assert dim == 7
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
|
||||
|
||||
|
||||
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
|
||||
def test_get_embedding_dim_reads_existing_meta(
|
||||
temp_llm_index_dir: Path,
|
||||
mock_ai_config: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
|
||||
)
|
||||
|
||||
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
|
||||
assert get_embedding_dim() == 11
|
||||
mock_get.assert_not_called()
|
||||
mock_get = mocker.patch("paperless_ai.embedding.get_embedding_model")
|
||||
assert get_embedding_dim() == 11
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
|
||||
def test_get_embedding_dim_raises_on_model_change(
|
||||
temp_llm_index_dir: Path,
|
||||
mock_ai_config: MagicMock,
|
||||
) -> None:
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai-like"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 11}),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Embedding model changed from old to text-embedding-3-small",
|
||||
@@ -158,21 +166,15 @@ def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_co
|
||||
get_embedding_dim()
|
||||
|
||||
|
||||
def test_build_llm_index_text(mock_document):
|
||||
with patch("documents.models.Note.objects.filter") as mock_notes_filter:
|
||||
mock_notes_filter.return_value = [
|
||||
MagicMock(note="Note1"),
|
||||
MagicMock(note="Note2"),
|
||||
]
|
||||
|
||||
result = build_llm_index_text(mock_document)
|
||||
|
||||
assert "Title: Test Title" in result
|
||||
assert "Filename: test_file.pdf" in result
|
||||
assert "Created: 2023-01-01" in result
|
||||
assert "Tags: Tag1, Tag2" in result
|
||||
assert "Document Type: Invoice" in result
|
||||
assert "Correspondent: Test Correspondent" in result
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
||||
@pytest.mark.django_db
|
||||
def test_build_llm_index_text(full_document: Document) -> None:
|
||||
result = build_llm_index_text(full_document)
|
||||
assert "Title: Test Title" in result
|
||||
assert "Filename: test_file.pdf" in result
|
||||
assert "Created: 2023-01-01" in result
|
||||
assert "Tags: Tag1, Tag2" in result
|
||||
assert "Document Type: Invoice" in result
|
||||
assert "Correspondent: Test Correspondent" in result
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
||||
|
||||
@@ -1,86 +1,89 @@
|
||||
from unittest.mock import patch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.tests.factories import CorrespondentFactory
|
||||
from documents.tests.factories import DocumentTypeFactory
|
||||
from documents.tests.factories import StoragePathFactory
|
||||
from documents.tests.factories import TagFactory
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
|
||||
_PATCH_TARGET = "paperless_ai.matching.get_objects_for_user_owner_aware"
|
||||
|
||||
class TestAIMatching(TestCase):
|
||||
def setUp(self) -> None:
|
||||
# Create test data for Tag
|
||||
self.tag1 = Tag.objects.create(name="Test Tag 1")
|
||||
self.tag2 = Tag.objects.create(name="Test Tag 2")
|
||||
|
||||
# Create test data for Correspondent
|
||||
self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
|
||||
self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")
|
||||
class TestAIMatching:
|
||||
def test_match_tags_by_name(self, mocker: MockerFixture) -> None:
|
||||
tags = [
|
||||
TagFactory.build(name="Test Tag 1"),
|
||||
TagFactory.build(name="Test Tag 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=tags)
|
||||
result = match_tags_by_name(["Test Tag 1", "Nonexistent Tag"], user=None)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "Test Tag 1"
|
||||
|
||||
# Create test data for DocumentType
|
||||
self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
|
||||
self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")
|
||||
def test_match_correspondents_by_name(self, mocker: MockerFixture) -> None:
|
||||
correspondents = [
|
||||
CorrespondentFactory.build(name="Test Correspondent 1"),
|
||||
CorrespondentFactory.build(name="Test Correspondent 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=correspondents)
|
||||
result = match_correspondents_by_name(
|
||||
["Test Correspondent 1", "Nonexistent Correspondent"],
|
||||
user=None,
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "Test Correspondent 1"
|
||||
|
||||
# Create test data for StoragePath
|
||||
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
|
||||
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
|
||||
def test_match_document_types_by_name(self, mocker: MockerFixture) -> None:
|
||||
document_types = [
|
||||
DocumentTypeFactory.build(name="Test Document Type 1"),
|
||||
DocumentTypeFactory.build(name="Test Document Type 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=document_types)
|
||||
result = match_document_types_by_name(
|
||||
["Test Document Type 1", "Nonexistent Document Type"],
|
||||
user=None,
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "Test Document Type 1"
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_correspondents_by_name(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = Correspondent.objects.all()
|
||||
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
|
||||
result = match_correspondents_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Correspondent 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_document_types_by_name(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = DocumentType.objects.all()
|
||||
names = ["Test Document Type 1", "Nonexistent Document Type"]
|
||||
result = match_document_types_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Document Type 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_storage_paths_by_name(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = StoragePath.objects.all()
|
||||
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
|
||||
result = match_storage_paths_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Storage Path 1")
|
||||
def test_match_storage_paths_by_name(self, mocker: MockerFixture) -> None:
|
||||
storage_paths = [
|
||||
StoragePathFactory.build(name="Test Storage Path 1"),
|
||||
StoragePathFactory.build(name="Test Storage Path 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=storage_paths)
|
||||
result = match_storage_paths_by_name(
|
||||
["Test Storage Path 1", "Nonexistent Storage Path"],
|
||||
user=None,
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "Test Storage Path 1"
|
||||
|
||||
def test_extract_unmatched_names(self) -> None:
|
||||
llm_names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
matched_objects = [self.tag1]
|
||||
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
|
||||
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
|
||||
tag = TagFactory.build(name="Test Tag 1")
|
||||
unmatched = extract_unmatched_names(["Test Tag 1", "Nonexistent Tag"], [tag])
|
||||
assert unmatched == ["Nonexistent Tag"]
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name_with_empty_names(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = [None, "", " "]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(result, [])
|
||||
def test_match_tags_by_name_with_empty_names(self, mocker: MockerFixture) -> None:
|
||||
tags = [
|
||||
TagFactory.build(name="Test Tag 1"),
|
||||
TagFactory.build(name="Test Tag 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=tags)
|
||||
result = match_tags_by_name([None, "", " "], user=None)
|
||||
assert result == []
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_with_fuzzy_matching(self, mock_get_objects) -> None:
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Taag 1", "Teest Tag 2"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
self.assertEqual(result[1].name, "Test Tag 2")
|
||||
def test_match_tags_with_fuzzy_matching(self, mocker: MockerFixture) -> None:
|
||||
tags = [
|
||||
TagFactory.build(name="Test Tag 1"),
|
||||
TagFactory.build(name="Test Tag 2"),
|
||||
]
|
||||
mocker.patch(_PATCH_TARGET, return_value=tags)
|
||||
result = match_tags_by_name(["Test Taag 1", "Teest Tag 2"], user=None)
|
||||
assert len(result) == 2
|
||||
assert result[0].name == "Test Tag 1"
|
||||
assert result[1].name == "Test Tag 2"
|
||||
|
||||
Reference in New Issue
Block a user