From ad052459f36b03e9b54326ad56d6853c3fcbc07a Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:45:38 -0700 Subject: [PATCH] Use lancedb for chat retriever --- src/paperless_ai/chat.py | 29 ++++++++++++++++------------- src/paperless_ai/tests/test_chat.py | 29 +++++++++++++---------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index b2710c379..fbaea325c 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -80,7 +80,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: from llama_index.core.schema import NodeWithScore from llama_index.core.vector_stores import VectorStoreQuery - class DocumentFilteredFaissRetriever(BaseRetriever): + class DocumentFilteredLanceDBRetriever(BaseRetriever): def __init__(self): super().__init__() self._cached_query_str = None @@ -97,8 +97,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: ) ) - faiss_index = index.vector_store._faiss_index - max_top_k = faiss_index.ntotal + max_top_k = index.vector_store.table.count_rows() if max_top_k == 0: self._cached_query_str = query_bundle.query_str self._cached_nodes = [] @@ -109,20 +108,24 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: seen_node_ids: set[str] = set() while query_top_k <= max_top_k: - query_result = index.vector_store.query( - VectorStoreQuery( - query_embedding=query_bundle.embedding, - similarity_top_k=query_top_k, - ), - ) + try: + query_result = index.vector_store.query( + VectorStoreQuery( + query_embedding=query_bundle.embedding, + similarity_top_k=query_top_k, + ), + ) + except Warning: + self._cached_query_str = query_bundle.query_str + self._cached_nodes = [] + return [] - for vector_id, score in zip( + for node_id, score in zip( query_result.ids or [], query_result.similarities or [], strict=False, ): - node_id = index.index_struct.nodes_dict.get(vector_id) - if node_id is None or node_id in seen_node_ids: + if node_id in seen_node_ids: continue node = index.docstore.docs.get(node_id) @@ -148,7 +151,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: self._cached_nodes = allowed_nodes return allowed_nodes - return DocumentFilteredFaissRetriever() + return DocumentFilteredLanceDBRetriever() def stream_chat_with_documents(query_str: str, documents: list[Document]): diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index d72b22f32..3fbe59458 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -59,15 +59,12 @@ def assert_chat_output( def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None: - mock_index.index_struct.nodes_dict = { - str(vector_id): node.node_id for vector_id, node in enumerate(nodes) - } mock_index.docstore.docs.get.side_effect = { node.node_id: node for node in nodes }.get - mock_index.vector_store._faiss_index.ntotal = len(nodes) + mock_index.vector_store.table.count_rows.return_value = len(nodes) mock_index.vector_store.query.return_value = MagicMock( - ids=list(mock_index.index_struct.nodes_dict), + ids=[node.node_id for node in nodes], similarities=[0.1] * len(nodes), ) mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 @@ -92,21 +89,21 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None: ) mock_index = MagicMock() - mock_index.index_struct.nodes_dict = { - "0": foreign_node.node_id, - "1": missing_node.node_id, - "2": allowed_node1.node_id, - "3": allowed_node2.node_id, - } mock_index.docstore.docs.get.side_effect = { allowed_node1.node_id: allowed_node1, allowed_node2.node_id: allowed_node2, foreign_node.node_id: foreign_node, }.get - mock_index.vector_store._faiss_index.ntotal = 4 + mock_index.vector_store.table.count_rows.return_value = 4 mock_index.vector_store.query.side_effect = [ - MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]), - MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]), + MagicMock( + ids=[foreign_node.node_id, allowed_node1.node_id], + similarities=[0.9, 0.8], + ), + MagicMock( + ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id], + similarities=[0.9, 0.7, 0.6], + ), ] mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 @@ -128,9 +125,9 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None: assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1 -def test_document_filtered_retriever_handles_empty_faiss_index() -> None: +def test_document_filtered_retriever_handles_empty_vector_store() -> None: mock_index = MagicMock() - mock_index.vector_store._faiss_index.ntotal = 0 + mock_index.vector_store.table.count_rows.return_value = 0 mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 retriever = _get_document_filtered_retriever(