mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-07 06:09:43 +00:00
Use metadata filter in document retriever
This commit is contained in:
+38
-36
@@ -78,6 +78,9 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
|
||||
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.vector_stores import FilterOperator
|
||||
from llama_index.core.vector_stores import MetadataFilter
|
||||
from llama_index.core.vector_stores import MetadataFilters
|
||||
from llama_index.core.vector_stores import VectorStoreQuery
|
||||
|
||||
class DocumentFilteredLanceDBRetriever(BaseRetriever):
|
||||
@@ -103,49 +106,48 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
query_top_k = min(max(similarity_top_k, 1), max_top_k)
|
||||
allowed_nodes: list[NodeWithScore] = []
|
||||
seen_node_ids: set[str] = set()
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=sorted(doc_ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
while query_top_k <= max_top_k:
|
||||
try:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
except Warning:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
try:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=max(similarity_top_k, 1),
|
||||
filters=filters,
|
||||
),
|
||||
)
|
||||
except Warning:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
for node_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
if node_id in seen_node_ids:
|
||||
continue
|
||||
for node_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
if node_id in seen_node_ids:
|
||||
continue
|
||||
|
||||
node = index.docstore.docs.get(node_id)
|
||||
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||
continue
|
||||
node = index.docstore.docs.get(node_id)
|
||||
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||
continue
|
||||
|
||||
seen_node_ids.add(node_id)
|
||||
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||
seen_node_ids.add(node_id)
|
||||
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||
|
||||
if len(allowed_nodes) >= similarity_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
if query_top_k == max_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
query_top_k = min(query_top_k * 2, max_top_k)
|
||||
if len(allowed_nodes) >= similarity_top_k:
|
||||
break
|
||||
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
|
||||
@@ -140,6 +140,7 @@ def build_document_node(
|
||||
# the token count and exceed embedding models with small context windows
|
||||
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||
doc = LlamaDocument(
|
||||
id_=str(document.id),
|
||||
text=text,
|
||||
metadata=metadata,
|
||||
excluded_embed_metadata_keys=list(metadata.keys()),
|
||||
|
||||
@@ -19,6 +19,7 @@ from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import PaperlessTaskFactory
|
||||
from paperless.models import ApplicationConfiguration
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.chat import _get_document_filtered_retriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -612,6 +613,36 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
|
||||
mock_retriever_cls.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_document_filtered_retriever_applies_lancedb_metadata_filter(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
allowed_document = DocumentFactory(
|
||||
title="Allowed",
|
||||
content="Allowed document content.",
|
||||
)
|
||||
DocumentFactory(
|
||||
title="Filtered",
|
||||
content="Filtered document content.",
|
||||
)
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
index,
|
||||
{str(allowed_document.pk)},
|
||||
similarity_top_k=5,
|
||||
)
|
||||
|
||||
nodes = retriever.retrieve("document content")
|
||||
|
||||
assert nodes
|
||||
assert {node.node.metadata["document_id"] for node in nodes} == {
|
||||
str(allowed_document.pk),
|
||||
}
|
||||
|
||||
|
||||
class TestUpdateLlmIndexEmptyDocumentSet:
|
||||
"""update_llm_index must persist an empty index when all documents are deleted.
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
|
||||
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
def test_document_filtered_retriever_applies_metadata_filter_and_caches() -> None:
|
||||
allowed_node1 = TextNode(
|
||||
text="Allowed content 1.",
|
||||
metadata={"document_id": "1", "title": "Allowed 1"},
|
||||
@@ -95,16 +95,15 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
foreign_node.node_id: foreign_node,
|
||||
}.get
|
||||
mock_index.vector_store.table.count_rows.return_value = 4
|
||||
mock_index.vector_store.query.side_effect = [
|
||||
MagicMock(
|
||||
ids=[foreign_node.node_id, allowed_node1.node_id],
|
||||
similarities=[0.9, 0.8],
|
||||
),
|
||||
MagicMock(
|
||||
ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id],
|
||||
similarities=[0.9, 0.7, 0.6],
|
||||
),
|
||||
]
|
||||
mock_index.vector_store.query.return_value = MagicMock(
|
||||
ids=[
|
||||
foreign_node.node_id,
|
||||
missing_node.node_id,
|
||||
allowed_node1.node_id,
|
||||
allowed_node2.node_id,
|
||||
],
|
||||
similarities=[0.9, 0.7, 0.6, 0.5],
|
||||
)
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
@@ -121,7 +120,13 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
allowed_node2.node_id,
|
||||
]
|
||||
assert cached_nodes == nodes
|
||||
assert mock_index.vector_store.query.call_count == 2
|
||||
assert mock_index.vector_store.query.call_count == 1
|
||||
query = mock_index.vector_store.query.call_args.args[0]
|
||||
assert query.similarity_top_k == 2
|
||||
assert len(query.filters.filters) == 1
|
||||
metadata_filter = query.filters.filters[0]
|
||||
assert metadata_filter.key == "document_id"
|
||||
assert metadata_filter.value == ["1", "2"]
|
||||
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user