mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-28 16:24:19 +00:00
Refactor(beta): extract retrieve_similar_nodes from query_similar_documents
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
@@ -449,12 +450,19 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str]
|
||||
return {str(document_id) for document_id in document_ids}
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
def retrieve_similar_nodes(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
top_k: int = 5,
|
||||
) -> list["NodeWithScore"]:
|
||||
"""Run ANN retrieval and return the raw NodeWithScore results.
|
||||
|
||||
Returns ``[]`` when the allow-list normalizes to empty, or when no index
|
||||
exists yet (queuing a build in that case). The ``retrieve()`` call is a slow
|
||||
embedding request, so it runs inside ``db_connection_released()`` to avoid
|
||||
pinning the pooled DB connection (#12976). Both ``query_similar_documents``
|
||||
and the taxonomy-hints path go through here, so they share that behavior.
|
||||
"""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
if allowed_document_ids is not None and not allowed_document_ids:
|
||||
return []
|
||||
@@ -494,7 +502,21 @@ def query_similar_documents(
|
||||
filters=filters,
|
||||
)
|
||||
with db_connection_released():
|
||||
results = retriever.retrieve(query_text)
|
||||
return retriever.retrieve(query_text)
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
results = retrieve_similar_nodes(
|
||||
document=document,
|
||||
document_ids=allowed_document_ids,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
retrieved_document_ids: list[int] = []
|
||||
for node in results:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -726,3 +727,58 @@ class TestQuerySimilarDocuments:
|
||||
results = indexing.query_similar_documents(a, document_ids=[b.id])
|
||||
|
||||
assert all(doc.id == b.id for doc in results)
|
||||
|
||||
|
||||
class TestRetrieveSimilarNodes:
|
||||
@pytest.mark.django_db
|
||||
def test_returns_raw_nodes_from_retriever(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True)
|
||||
mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
node1 = SimpleNamespace(metadata={"document_id": "1"})
|
||||
node2 = SimpleNamespace(metadata={"document_id": "2"})
|
||||
retriever = mocker.MagicMock()
|
||||
retriever.retrieve.return_value = [node1, node2]
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=retriever,
|
||||
)
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=3)
|
||||
|
||||
assert result == [node1, node2]
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_empty_allow_list_fails_closed(
|
||||
self,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, document_ids=[])
|
||||
|
||||
assert result == []
|
||||
load.assert_not_called()
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queues_update_when_index_missing(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False)
|
||||
queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed")
|
||||
|
||||
result = indexing.retrieve_similar_nodes(real_document, top_k=2)
|
||||
|
||||
assert result == []
|
||||
queue.assert_called_once_with(
|
||||
rebuild=False,
|
||||
reason="LLM index not found for similarity query.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user