mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-20 20:34:20 +00:00
Fix (beta): truncate embedding queries for small chunk size (#13028)
This commit is contained in:
@@ -443,6 +443,18 @@ def truncate_content(
|
||||
return " ".join(truncated_chunks)
|
||||
|
||||
|
||||
def truncate_embedding_query(content: str, *, chunk_size: int) -> str:
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
|
||||
splitter = TokenTextSplitter(
|
||||
separator=" ",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=0,
|
||||
)
|
||||
content_chunks = splitter.split_text(content)
|
||||
return content_chunks[0] if content_chunks else ""
|
||||
|
||||
|
||||
def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str] | None:
|
||||
if document_ids is None:
|
||||
return None
|
||||
@@ -476,10 +488,9 @@ def query_similar_documents(
|
||||
else None
|
||||
)
|
||||
|
||||
query_text = truncate_content(
|
||||
query_text = truncate_embedding_query(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
# Hold the shared read lock for the whole retrieval so the connection is
|
||||
# never open across a compaction swap. The retrieve() call generates a
|
||||
|
||||
@@ -137,6 +137,16 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None:
|
||||
assert prompt_helper.context_window == 4096
|
||||
|
||||
|
||||
def test_truncate_embedding_query_returns_single_chunk() -> None:
|
||||
content = " ".join(f"word{i}" for i in range(200))
|
||||
|
||||
result = indexing.truncate_embedding_query(content, chunk_size=32)
|
||||
|
||||
assert result
|
||||
assert result != content
|
||||
assert "word199" not in result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir: Path,
|
||||
@@ -393,6 +403,42 @@ def test_query_similar_documents(
|
||||
assert result == mock_filtered_docs
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_EMBEDDING_CHUNK_SIZE=32,
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents_truncates_query_to_embedding_chunk_size(
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
real_document.content = " ".join(f"word{i}" for i in range(200))
|
||||
with (
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch(
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
) as mock_vector_store_exists,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
patch("paperless_ai.indexing.truncate_content") as mock_truncate_content,
|
||||
):
|
||||
mock_vector_store_exists.return_value = True
|
||||
mock_load_or_build_index.return_value = MagicMock()
|
||||
mock_truncate_content.return_value = "wrong helper"
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
mock_filter.return_value = []
|
||||
|
||||
indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
mock_truncate_content.assert_not_called()
|
||||
query_text = mock_retriever.retrieve.call_args.args[0]
|
||||
assert query_text
|
||||
assert "word199" not in query_text
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
temp_llm_index_dir: Path,
|
||||
|
||||
Reference in New Issue
Block a user