Refactor(beta): extract retrieve_similar_nodes from query_similar_documents

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-06-12 14:51:44 -07:00
parent a020f64d08
commit 73062bd5ab
2 changed files with 83 additions and 5 deletions
+27 -5
View File
@@ -22,6 +22,7 @@ from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core.schema import BaseNode
from llama_index.core.schema import NodeWithScore
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@@ -449,12 +450,19 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
return {str(document_id) for document_id in document_ids}
def query_similar_documents(
def retrieve_similar_nodes(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
top_k: int = 5,
) -> list["NodeWithScore"]:
"""Run ANN retrieval and return the raw NodeWithScore results.
Returns ``[]`` when the allow-list normalizes to empty, or when no index
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
embedding request, so it runs inside ``db_connection_released()`` to avoid
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
and the taxonomy-hints path go through here, so they share that behavior.
"""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
@@ -494,7 +502,21 @@ def query_similar_documents(
filters=filters,
)
with db_connection_released():
results = retriever.retrieve(query_text)
return retriever.retrieve(query_text)
def query_similar_documents(
document: Document,
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
results = retrieve_similar_nodes(
document=document,
document_ids=allowed_document_ids,
top_k=top_k,
)
retrieved_document_ids: list[int] = []
for node in results:
@@ -1,4 +1,5 @@
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -726,3 +727,58 @@ class TestQuerySimilarDocuments:
results = indexing.query_similar_documents(a, document_ids=[b.id])
assert all(doc.id == b.id for doc in results)
class TestRetrieveSimilarNodes:
    """Exercise ``indexing.retrieve_similar_nodes`` with its collaborators mocked."""

    @pytest.mark.django_db
    def test_returns_raw_nodes_from_retriever(
        self,
        temp_llm_index_dir: Path,
        real_document: Document,
        mocker: pytest_mock.MockerFixture,
    ) -> None:
        """With an index present, the retriever's results come back unchanged."""
        first_node = SimpleNamespace(metadata={"document_id": "1"})
        second_node = SimpleNamespace(metadata={"document_id": "2"})

        # Stand-in retriever that yields the two fake nodes above.
        fake_retriever = mocker.MagicMock()
        fake_retriever.retrieve.return_value = [first_node, second_node]

        mocker.patch("paperless_ai.indexing.load_or_build_index")
        mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
        mocker.patch(
            "llama_index.core.retrievers.VectorIndexRetriever",
            return_value=fake_retriever,
        )

        nodes = indexing.retrieve_similar_nodes(real_document, top_k=3)

        assert nodes == [first_node, second_node]

    @pytest.mark.django_db
    def test_empty_allow_list_fails_closed(
        self,
        real_document: Document,
        mocker: pytest_mock.MockerFixture,
    ) -> None:
        """An empty ``document_ids`` allow-list yields ``[]`` without loading the index."""
        index_loader = mocker.patch("paperless_ai.indexing.load_or_build_index")

        nodes = indexing.retrieve_similar_nodes(real_document, document_ids=[])

        assert nodes == []
        index_loader.assert_not_called()

    @pytest.mark.django_db
    def test_queues_update_when_index_missing(
        self,
        temp_llm_index_dir: Path,
        real_document: Document,
        mocker: pytest_mock.MockerFixture,
    ) -> None:
        """A missing index yields ``[]`` and queues a non-rebuild index update."""
        enqueue_update = mocker.patch(
            "paperless_ai.indexing.queue_llm_index_update_if_needed",
        )
        mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)

        nodes = indexing.retrieve_similar_nodes(real_document, top_k=2)

        assert nodes == []
        enqueue_update.assert_called_once_with(
            rebuild=False,
            reason="LLM index not found for similarity query.",
        )