diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 307ee840f..499990799 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -22,6 +22,7 @@ from paperless_ai.embedding import get_embedding_model if TYPE_CHECKING: from llama_index.core.schema import BaseNode + from llama_index.core.schema import NodeWithScore from paperless_ai.vector_store import PaperlessSqliteVecVectorStore @@ -449,12 +450,19 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str] return {str(document_id) for document_id in document_ids} -def query_similar_documents( +def retrieve_similar_nodes( document: Document, - top_k: int = 5, document_ids: Iterable[int | str] | None = None, -) -> list[Document]: - """Return up to ``top_k`` Documents most similar to ``document``.""" + top_k: int = 5, +) -> list["NodeWithScore"]: + """Run ANN retrieval and return the raw NodeWithScore results. + + Returns ``[]`` when the allow-list normalizes to empty, or when no index + exists yet (queuing a build in that case). The ``retrieve()`` call is a slow + embedding request, so it runs inside ``db_connection_released()`` to avoid + pinning the pooled DB connection (#12976). Both ``query_similar_documents`` + and the taxonomy-hints path go through here, so they share that behavior. + """ allowed_document_ids = normalize_document_ids(document_ids) if allowed_document_ids is not None and not allowed_document_ids: return [] @@ -494,7 +502,21 @@ def query_similar_documents( filters=filters, ) with db_connection_released(): - results = retriever.retrieve(query_text) + return retriever.retrieve(query_text) + + +def query_similar_documents( + document: Document, + top_k: int = 5, + document_ids: Iterable[int | str] | None = None, +) -> list[Document]: + """Return up to ``top_k`` Documents most similar to ``document``.""" + allowed_document_ids = normalize_document_ids(document_ids) + results = retrieve_similar_nodes( + document=document, + document_ids=allowed_document_ids, + top_k=top_k, + ) retrieved_document_ids: list[int] = [] for node in results: diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index 42062e82d..8deb9658f 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -1,4 +1,5 @@ from pathlib import Path +from types import SimpleNamespace from unittest.mock import MagicMock from unittest.mock import patch @@ -726,3 +727,58 @@ class TestQuerySimilarDocuments: results = indexing.query_similar_documents(a, document_ids=[b.id]) assert all(doc.id == b.id for doc in results) + + +class TestRetrieveSimilarNodes: + @pytest.mark.django_db + def test_returns_raw_nodes_from_retriever( + self, + temp_llm_index_dir: Path, + real_document: Document, + mocker: pytest_mock.MockerFixture, + ) -> None: + mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=True) + mocker.patch("paperless_ai.indexing.load_or_build_index") + node1 = SimpleNamespace(metadata={"document_id": "1"}) + node2 = SimpleNamespace(metadata={"document_id": "2"}) + retriever = mocker.MagicMock() + retriever.retrieve.return_value = [node1, node2] + mocker.patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=retriever, + ) + + result = indexing.retrieve_similar_nodes(real_document, top_k=3) + + assert result == [node1, node2] + + @pytest.mark.django_db + def test_empty_allow_list_fails_closed( + self, + real_document: Document, + mocker: pytest_mock.MockerFixture, + ) -> None: + load = mocker.patch("paperless_ai.indexing.load_or_build_index") + + result = indexing.retrieve_similar_nodes(real_document, document_ids=[]) + + assert result == [] + load.assert_not_called() + + @pytest.mark.django_db + def test_queues_update_when_index_missing( + self, + temp_llm_index_dir: Path, + real_document: Document, + mocker: pytest_mock.MockerFixture, + ) -> None: + mocker.patch("paperless_ai.indexing.llm_index_exists", return_value=False) + queue = mocker.patch("paperless_ai.indexing.queue_llm_index_update_if_needed") + + result = indexing.retrieve_similar_nodes(real_document, top_k=2) + + assert result == [] + queue.assert_called_once_with( + rebuild=False, + reason="LLM index not found for similarity query.", + )