Files
paperless-ngx/src/paperless_ai/tests/test_embedding.py
T

279 lines
10 KiB
Python

import json
from unittest.mock import ANY
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.conf import settings
from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import _normalize_llm_index_text
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
@pytest.fixture
def mock_ai_config():
    """Patch ``AIConfig`` inside ``paperless_ai.embedding`` and yield the mock class.

    The patched config instance starts out with no dedicated embedding
    endpoint, internal endpoints allowed, and a context size of 8192. Tests
    adjust further attributes through ``return_value``.
    """
    with patch("paperless_ai.embedding.AIConfig") as config_cls:
        defaults = config_cls.return_value
        defaults.llm_embedding_endpoint = None
        defaults.llm_allow_internal_endpoints = True
        defaults.llm_context_size = 8192
        yield config_cls
@pytest.fixture
def mock_document():
    """Return a ``MagicMock`` specced on ``Document`` with populated metadata.

    The mock carries a title, filename, timestamps, two tags, a document
    type, a correspondent, an archive serial number, text content, and two
    custom fields.
    """

    def _named(name):
        # ``name`` has to be assigned after construction: passing it to the
        # MagicMock constructor would name the mock itself rather than set
        # the attribute.
        obj = MagicMock()
        obj.name = name
        return obj

    def _custom_field(field_name, value):
        # The instance stringifies to ``value`` and also exposes it as ``.value``.
        entry = MagicMock(__str__=lambda _self: value)
        entry.field = _named(field_name)
        entry.value = value
        return entry

    document = MagicMock(spec=Document)

    # Plain scalar attributes.
    document.title = "Test Title"
    document.filename = "test_file.pdf"
    document.created = "2023-01-01"
    document.added = "2023-01-02"
    document.modified = "2023-01-03"
    document.archive_serial_number = "12345"
    document.content = "This is the document content."

    # Related objects, each exposing a ``name``.
    document.document_type = _named("Invoice")
    document.correspondent = _named("Test Correspondent")
    document.tags.all = MagicMock(
        return_value=[_named("Tag1"), _named("Tag2")],
    )

    document.custom_fields.all = MagicMock(
        return_value=[
            _custom_field("Field1", "Value1"),
            _custom_field("Field2", "Value2"),
        ],
    )
    return document
def test_get_embedding_model_openai(mock_ai_config):
    """The OpenAI-like backend is built from the model name, API key and ``llm_endpoint``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"
    config.llm_endpoint = "http://test-url"

    target = "llama_index.embeddings.openai_like.OpenAILikeEmbedding"
    with patch(target) as embedding_cls:
        built = get_embedding_model()

        # The HTTP clients are accepted as-is; only their presence is checked.
        expected_kwargs = {
            "model_name": "text-embedding-3-small",
            "api_key": "test_api_key",
            "api_base": "http://test-url",
            "http_client": ANY,
            "async_http_client": ANY,
        }
        embedding_cls.assert_called_once_with(**expected_kwargs)
        assert built == embedding_cls.return_value
def test_get_embedding_model_openai_prefers_embedding_endpoint(mock_ai_config):
    """When both endpoints are set, ``llm_embedding_endpoint`` is used as ``api_base``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"
    config.llm_embedding_endpoint = "http://embedding-url"
    config.llm_endpoint = "http://test-url"

    target = "llama_index.embeddings.openai_like.OpenAILikeEmbedding"
    with patch(target) as embedding_cls:
        built = get_embedding_model()

        expected_kwargs = {
            "model_name": "text-embedding-3-small",
            "api_key": "test_api_key",
            # The dedicated embedding endpoint wins over the generic one.
            "api_base": "http://embedding-url",
            "http_client": ANY,
            "async_http_client": ANY,
        }
        embedding_cls.assert_called_once_with(**expected_kwargs)
        assert built == embedding_cls.return_value
def test_get_embedding_model_openai_blocks_internal_endpoint_when_disallowed(
    mock_ai_config,
):
    """A loopback endpoint raises ``ValueError`` once internal endpoints are disallowed."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OPENAI_LIKE
    config.llm_embedding_model = "text-embedding-3-small"
    config.llm_api_key = "test_api_key"
    config.llm_endpoint = "http://127.0.0.1:11434"
    config.llm_allow_internal_endpoints = False

    expected_error = pytest.raises(ValueError, match="non-public address")
    with expected_error:
        get_embedding_model()
def test_get_embedding_model_huggingface(mock_ai_config):
    """The HuggingFace backend receives the model name and the ``hf_cache`` folder."""
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
    config.llm_embedding_model = model_name

    target = "llama_index.embeddings.huggingface.HuggingFaceEmbedding"
    with patch(target) as embedding_cls:
        built = get_embedding_model()

        # The cache folder is expected under the configured data directory.
        cache_dir = settings.DATA_DIR / "hf_cache"
        embedding_cls.assert_called_once_with(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=str(cache_dir),
        )
        assert built == embedding_cls.return_value
def test_get_embedding_model_ollama(mock_ai_config):
    """The Ollama backend is built from the model name, ``llm_endpoint`` and ``num_ctx``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA
    config.llm_embedding_model = "embeddinggemma"
    config.llm_endpoint = "http://test-url"

    target = "llama_index.embeddings.ollama.OllamaEmbedding"
    with patch(target) as embedding_cls:
        built = get_embedding_model()

        expected_kwargs = {
            "model_name": "embeddinggemma",
            "base_url": "http://test-url",
            # Same value the mock_ai_config fixture sets for llm_context_size.
            "ollama_additional_kwargs": {"num_ctx": 8192},
        }
        embedding_cls.assert_called_once_with(**expected_kwargs)
        assert built == embedding_cls.return_value
def test_get_embedding_model_ollama_prefers_embedding_endpoint(mock_ai_config):
    """When both endpoints are set, ``llm_embedding_endpoint`` is used as ``base_url``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA
    config.llm_embedding_model = "embeddinggemma"
    config.llm_embedding_endpoint = "http://embedding-url"
    config.llm_endpoint = "http://test-url"

    target = "llama_index.embeddings.ollama.OllamaEmbedding"
    with patch(target) as embedding_cls:
        built = get_embedding_model()

        expected_kwargs = {
            "model_name": "embeddinggemma",
            # The dedicated embedding endpoint wins over the generic one.
            "base_url": "http://embedding-url",
            "ollama_additional_kwargs": {"num_ctx": 8192},
        }
        embedding_cls.assert_called_once_with(**expected_kwargs)
        assert built == embedding_cls.return_value
def test_get_embedding_model_ollama_blocks_internal_endpoint_when_disallowed(
    mock_ai_config,
):
    """A loopback endpoint raises ``ValueError`` once internal endpoints are disallowed."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = LLMEmbeddingBackend.OLLAMA
    config.llm_embedding_model = "embeddinggemma"
    config.llm_endpoint = "http://127.0.0.1:11434"
    config.llm_allow_internal_endpoints = False

    expected_error = pytest.raises(ValueError, match="non-public address")
    with expected_error:
        get_embedding_model()
def test_get_embedding_model_invalid_backend(mock_ai_config):
    """An unrecognised backend value raises ``ValueError`` naming that backend."""
    mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"

    expected_error = pytest.raises(
        ValueError,
        match="Unsupported embedding backend: INVALID_BACKEND",
    )
    with expected_error:
        get_embedding_model()
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
    """The dimension is taken from a probe embedding and written to ``meta.json``.

    ``temp_llm_index_dir`` is a fixture defined outside this module; it is
    presumably the directory where the index metadata is stored.
    """
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai-like"
    config.llm_embedding_model = None

    class SevenDimEmbedding:
        """Stand-in embedding model that always returns a 7-element vector."""

        def get_text_embedding(self, text):
            return [0.0] * 7

    with patch(
        "paperless_ai.embedding.get_embedding_model",
        return_value=SevenDimEmbedding(),
    ) as model_factory:
        inferred = get_embedding_dim()
        model_factory.assert_called_once()

    assert inferred == 7

    # The metadata file must record both the model name and the dimension.
    meta_path = temp_llm_index_dir / "meta.json"
    stored = json.loads(meta_path.read_text())
    assert stored == {"embedding_model": "text-embedding-3-small", "dim": 7}
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
    """An existing ``meta.json`` supplies the dimension without building a model."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai-like"
    config.llm_embedding_model = None

    # Seed the metadata file with a matching model name and a known dimension.
    existing_meta = {"embedding_model": "text-embedding-3-small", "dim": 11}
    meta_path = temp_llm_index_dir / "meta.json"
    meta_path.write_text(json.dumps(existing_meta))

    with patch("paperless_ai.embedding.get_embedding_model") as model_factory:
        assert get_embedding_dim() == 11
        model_factory.assert_not_called()
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
    """A ``meta.json`` recorded for a different model makes the lookup raise ``RuntimeError``."""
    config = mock_ai_config.return_value
    config.llm_embedding_backend = "openai-like"
    config.llm_embedding_model = None

    # Metadata left behind by a previous, differently named model.
    stale_meta = {"embedding_model": "old", "dim": 11}
    meta_path = temp_llm_index_dir / "meta.json"
    meta_path.write_text(json.dumps(stale_meta))

    expected_error = pytest.raises(
        RuntimeError,
        match="Embedding model changed from old to text-embedding-3-small",
    )
    with expected_error:
        get_embedding_dim()
def test_build_llm_index_text(mock_document):
    """The index text contains every metadata section, the notes, and the content."""
    notes = [MagicMock(note="Note1"), MagicMock(note="Note2")]

    with patch("documents.models.Note.objects.filter", return_value=notes):
        result = build_llm_index_text(mock_document)

    expected_fragments = (
        "Title: Test Title",
        "Filename: test_file.pdf",
        "Created: 2023-01-01",
        "Tags: Tag1, Tag2",
        "Document Type: Invoice",
        "Correspondent: Test Correspondent",
        "Notes: Note1,Note2",
        "Content:\n\nThis is the document content.",
        "Custom Field - Field1: Value1\nCustom Field - Field2: Value2",
    )
    for fragment in expected_fragments:
        assert fragment in result
def test_build_llm_index_text_normalizes_ocr_punctuation_runs(mock_document):
    """Long dot/underscore leader runs are collapsed while short punctuation survives."""
    content_lines = [
        "Introduction ................................................ 7",
        "Hardware Limitation ________________________________________ 9",
        "Keep short punctuation like INV-100 and ellipses...",
    ]
    mock_document.content = "\n".join(content_lines)

    with patch("documents.models.Note.objects.filter", return_value=[]):
        result = build_llm_index_text(mock_document)

    # The first two fragments show the leaders removed; the last two show
    # that a hyphenated id and a three-dot ellipsis are left untouched.
    expected_fragments = (
        "Introduction 7",
        "Hardware Limitation 9",
        "INV-100",
        "ellipses...",
    )
    for fragment in expected_fragments:
        assert fragment in result
def test_normalize_llm_index_text_collapses_ocr_leaders_without_joining_lines():
    """Leader runs become a single space and the newline between lines is kept."""
    raw = "A........B\nC____D----E"
    expected = "A B\nC D E"

    normalized = _normalize_llm_index_text(raw)

    assert normalized == expected
def test_normalize_llm_index_text_collapses_non_breaking_spaces():
    """Non-breaking spaces around a leader run collapse into one plain space."""
    raw = "A\u00a0........\u00a0B"
    expected = "A B"

    normalized = _normalize_llm_index_text(raw)

    assert normalized == expected