Use LanceDB for chat retriever

This commit is contained in:
shamoon
2026-06-02 09:45:38 -07:00
parent c5bfe008d7
commit ad052459f3
2 changed files with 29 additions and 29 deletions
+16 -13
View File
@@ -80,7 +80,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredFaissRetriever(BaseRetriever):
class DocumentFilteredLanceDBRetriever(BaseRetriever):
def __init__(self):
super().__init__()
self._cached_query_str = None
@@ -97,8 +97,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
)
)
faiss_index = index.vector_store._faiss_index
max_top_k = faiss_index.ntotal
max_top_k = index.vector_store.table.count_rows()
if max_top_k == 0:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
@@ -109,20 +108,24 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
seen_node_ids: set[str] = set()
while query_top_k <= max_top_k:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
try:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
except Warning:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
for vector_id, score in zip(
for node_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
node_id = index.index_struct.nodes_dict.get(vector_id)
if node_id is None or node_id in seen_node_ids:
if node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
@@ -148,7 +151,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
self._cached_nodes = allowed_nodes
return allowed_nodes
return DocumentFilteredFaissRetriever()
return DocumentFilteredLanceDBRetriever()
def stream_chat_with_documents(query_str: str, documents: list[Document]):
+13 -16
View File
@@ -59,15 +59,12 @@ def assert_chat_output(
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
mock_index.index_struct.nodes_dict = {
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
}
mock_index.docstore.docs.get.side_effect = {
node.node_id: node for node in nodes
}.get
mock_index.vector_store._faiss_index.ntotal = len(nodes)
mock_index.vector_store.table.count_rows.return_value = len(nodes)
mock_index.vector_store.query.return_value = MagicMock(
ids=list(mock_index.index_struct.nodes_dict),
ids=[node.node_id for node in nodes],
similarities=[0.1] * len(nodes),
)
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
@@ -92,21 +89,21 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
)
mock_index = MagicMock()
mock_index.index_struct.nodes_dict = {
"0": foreign_node.node_id,
"1": missing_node.node_id,
"2": allowed_node1.node_id,
"3": allowed_node2.node_id,
}
mock_index.docstore.docs.get.side_effect = {
allowed_node1.node_id: allowed_node1,
allowed_node2.node_id: allowed_node2,
foreign_node.node_id: foreign_node,
}.get
mock_index.vector_store._faiss_index.ntotal = 4
mock_index.vector_store.table.count_rows.return_value = 4
mock_index.vector_store.query.side_effect = [
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
MagicMock(
ids=[foreign_node.node_id, allowed_node1.node_id],
similarities=[0.9, 0.8],
),
MagicMock(
ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id],
similarities=[0.9, 0.7, 0.6],
),
]
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
@@ -128,9 +125,9 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
def test_document_filtered_retriever_handles_empty_vector_store() -> None:
mock_index = MagicMock()
mock_index.vector_store._faiss_index.ntotal = 0
mock_index.vector_store.table.count_rows.return_value = 0
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(