Use metadata filter in document retriever

This commit is contained in:
shamoon
2026-06-03 09:34:31 -07:00
parent a02946f7c1
commit eab0466b9e
4 changed files with 87 additions and 48 deletions
+38 -36
View File
@@ -78,6 +78,9 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import FilterOperator
from llama_index.core.vector_stores import MetadataFilter
from llama_index.core.vector_stores import MetadataFilters
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredLanceDBRetriever(BaseRetriever):
@@ -103,49 +106,48 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
self._cached_nodes = []
return []
query_top_k = min(max(similarity_top_k, 1), max_top_k)
allowed_nodes: list[NodeWithScore] = []
seen_node_ids: set[str] = set()
filters = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(doc_ids),
),
],
)
while query_top_k <= max_top_k:
try:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
except Warning:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
try:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=max(similarity_top_k, 1),
filters=filters,
),
)
except Warning:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
for node_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
if node_id in seen_node_ids:
continue
for node_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
if node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
if len(allowed_nodes) >= similarity_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
if query_top_k == max_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
query_top_k = min(query_top_k * 2, max_top_k)
if len(allowed_nodes) >= similarity_top_k:
break
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
+1
View File
@@ -140,6 +140,7 @@ def build_document_node(
# the token count and exceed embedding models with small context windows
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
doc = LlamaDocument(
id_=str(document.id),
text=text,
metadata=metadata,
excluded_embed_metadata_keys=list(metadata.keys()),
@@ -19,6 +19,7 @@ from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
from paperless.models import ApplicationConfiguration
from paperless_ai import indexing
from paperless_ai.chat import _get_document_filtered_retriever
@pytest.fixture
@@ -612,6 +613,36 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
mock_retriever_cls.assert_not_called()
@pytest.mark.django_db
def test_document_filtered_retriever_applies_lancedb_metadata_filter(
    temp_llm_index_dir,
    mock_embed_model: MagicMock,
) -> None:
    """Retrieval through a rebuilt index only yields nodes of allowed documents.

    Two documents are indexed, but only one primary key is handed to the
    retriever, so every returned node must carry that document's id.
    """
    permitted = DocumentFactory(
        title="Allowed",
        content="Allowed document content.",
    )
    # Indexed as well, but deliberately left out of the allow-list.
    DocumentFactory(
        title="Filtered",
        content="Filtered document content.",
    )
    permitted_id = str(permitted.pk)

    indexing.update_llm_index(rebuild=True)
    retriever = _get_document_filtered_retriever(
        indexing.load_or_build_index(),
        {permitted_id},
        similarity_top_k=5,
    )
    results = retriever.retrieve("document content")

    # At least one node must come back, and none may belong to another document.
    assert results
    returned_ids = {hit.node.metadata["document_id"] for hit in results}
    assert returned_ids == {permitted_id}
class TestUpdateLlmIndexEmptyDocumentSet:
"""update_llm_index must persist an empty index when all documents are deleted.
+17 -12
View File
@@ -70,7 +70,7 @@ def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
def test_document_filtered_retriever_applies_metadata_filter_and_caches() -> None:
allowed_node1 = TextNode(
text="Allowed content 1.",
metadata={"document_id": "1", "title": "Allowed 1"},
@@ -95,16 +95,15 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
foreign_node.node_id: foreign_node,
}.get
mock_index.vector_store.table.count_rows.return_value = 4
mock_index.vector_store.query.side_effect = [
MagicMock(
ids=[foreign_node.node_id, allowed_node1.node_id],
similarities=[0.9, 0.8],
),
MagicMock(
ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id],
similarities=[0.9, 0.7, 0.6],
),
]
mock_index.vector_store.query.return_value = MagicMock(
ids=[
foreign_node.node_id,
missing_node.node_id,
allowed_node1.node_id,
allowed_node2.node_id,
],
similarities=[0.9, 0.7, 0.6, 0.5],
)
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
@@ -121,7 +120,13 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
allowed_node2.node_id,
]
assert cached_nodes == nodes
assert mock_index.vector_store.query.call_count == 2
assert mock_index.vector_store.query.call_count == 1
query = mock_index.vector_store.query.call_args.args[0]
assert query.similarity_top_k == 2
assert len(query.filters.filters) == 1
metadata_filter = query.filters.filters[0]
assert metadata_filter.key == "document_id"
assert metadata_filter.value == ["1", "2"]
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1