From eab0466b9ed0e27b23b1ca7965a03759678e612e Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Wed, 3 Jun 2026 09:34:31 -0700 Subject: [PATCH] Use metadata filter in document retriever --- src/paperless_ai/chat.py | 74 +++++++++++----------- src/paperless_ai/indexing.py | 1 + src/paperless_ai/tests/test_ai_indexing.py | 31 +++++++++ src/paperless_ai/tests/test_chat.py | 29 +++++---- 4 files changed, 87 insertions(+), 48 deletions(-) diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index fbaea325c..c2bcff2a1 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -78,6 +78,9 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int): from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.schema import NodeWithScore + from llama_index.core.vector_stores import FilterOperator + from llama_index.core.vector_stores import MetadataFilter + from llama_index.core.vector_stores import MetadataFilters from llama_index.core.vector_stores import VectorStoreQuery class DocumentFilteredLanceDBRetriever(BaseRetriever): @@ -103,49 +106,48 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: self._cached_nodes = [] return [] - query_top_k = min(max(similarity_top_k, 1), max_top_k) allowed_nodes: list[NodeWithScore] = [] seen_node_ids: set[str] = set() + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=sorted(doc_ids), + ), + ], + ) - while query_top_k <= max_top_k: - try: - query_result = index.vector_store.query( - VectorStoreQuery( - query_embedding=query_bundle.embedding, - similarity_top_k=query_top_k, - ), - ) - except Warning: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = [] - return [] + try: + query_result = index.vector_store.query( + VectorStoreQuery( + query_embedding=query_bundle.embedding, + similarity_top_k=max(similarity_top_k, 1), + filters=filters, + ), + ) + except Warning: + self._cached_query_str = query_bundle.query_str + self._cached_nodes = [] + return [] - for node_id, score in zip( - query_result.ids or [], - query_result.similarities or [], - strict=False, - ): - if node_id in seen_node_ids: - continue + for node_id, score in zip( + query_result.ids or [], + query_result.similarities or [], + strict=False, + ): + if node_id in seen_node_ids: + continue - node = index.docstore.docs.get(node_id) - if node is None or node.metadata.get("document_id") not in doc_ids: - continue + node = index.docstore.docs.get(node_id) + if node is None or node.metadata.get("document_id") not in doc_ids: + continue - seen_node_ids.add(node_id) - allowed_nodes.append(NodeWithScore(node=node, score=score)) + seen_node_ids.add(node_id) + allowed_nodes.append(NodeWithScore(node=node, score=score)) - if len(allowed_nodes) >= similarity_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - if query_top_k == max_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - query_top_k = min(query_top_k * 2, max_top_k) + if len(allowed_nodes) >= similarity_top_k: + break self._cached_query_str = query_bundle.query_str self._cached_nodes = allowed_nodes diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 3e88a93f6..7556d7d4e 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -140,6 +140,7 @@ def build_document_node( # the token count and exceed embedding models with small context windows # (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048). doc = LlamaDocument( + id_=str(document.id), text=text, metadata=metadata, excluded_embed_metadata_keys=list(metadata.keys()), diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index 1d74f7786..c820b1ce3 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -19,6 +19,7 @@ from documents.tests.factories import DocumentFactory from documents.tests.factories import PaperlessTaskFactory from paperless.models import ApplicationConfiguration from paperless_ai import indexing +from paperless_ai.chat import _get_document_filtered_retriever @pytest.fixture @@ -612,6 +613,36 @@ def test_query_similar_documents_empty_allow_list_fails_closed( mock_retriever_cls.assert_not_called() +@pytest.mark.django_db +def test_document_filtered_retriever_applies_lancedb_metadata_filter( + temp_llm_index_dir, + mock_embed_model: MagicMock, +) -> None: + allowed_document = DocumentFactory( + title="Allowed", + content="Allowed document content.", + ) + DocumentFactory( + title="Filtered", + content="Filtered document content.", + ) + indexing.update_llm_index(rebuild=True) + index = indexing.load_or_build_index() + + retriever = _get_document_filtered_retriever( + index, + {str(allowed_document.pk)}, + similarity_top_k=5, + ) + + nodes = retriever.retrieve("document content") + + assert nodes + assert {node.node.metadata["document_id"] for node in nodes} == { + str(allowed_document.pk), + } + + class TestUpdateLlmIndexEmptyDocumentSet: """update_llm_index must persist an empty index when all documents are deleted. diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 3fbe59458..52c4bf619 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -70,7 +70,7 @@ def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None: mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 -def test_document_filtered_retriever_expands_filters_and_caches() -> None: +def test_document_filtered_retriever_applies_metadata_filter_and_caches() -> None: allowed_node1 = TextNode( text="Allowed content 1.", metadata={"document_id": "1", "title": "Allowed 1"}, @@ -95,16 +95,15 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None: foreign_node.node_id: foreign_node, }.get mock_index.vector_store.table.count_rows.return_value = 4 - mock_index.vector_store.query.side_effect = [ - MagicMock( - ids=[foreign_node.node_id, allowed_node1.node_id], - similarities=[0.9, 0.8], - ), - MagicMock( - ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id], - similarities=[0.9, 0.7, 0.6], - ), - ] + mock_index.vector_store.query.return_value = MagicMock( + ids=[ + foreign_node.node_id, + missing_node.node_id, + allowed_node1.node_id, + allowed_node2.node_id, + ], + similarities=[0.9, 0.7, 0.6, 0.5], + ) mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 retriever = _get_document_filtered_retriever( @@ -121,7 +120,13 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None: allowed_node2.node_id, ] assert cached_nodes == nodes - assert mock_index.vector_store.query.call_count == 2 + assert mock_index.vector_store.query.call_count == 1 + query = mock_index.vector_store.query.call_args.args[0] + assert query.similarity_top_k == 2 + assert len(query.filters.filters) == 1 + metadata_filter = query.filters.filters[0] + assert metadata_filter.key == "document_id" + assert metadata_filter.value == ["1", "2"] assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1