mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-06 13:49:44 +00:00
Use lancedb for chat retriever
This commit is contained in:
+16
-13
@@ -80,7 +80,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.vector_stores import VectorStoreQuery
|
||||
|
||||
class DocumentFilteredFaissRetriever(BaseRetriever):
|
||||
class DocumentFilteredLanceDBRetriever(BaseRetriever):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._cached_query_str = None
|
||||
@@ -97,8 +97,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
|
||||
)
|
||||
)
|
||||
|
||||
faiss_index = index.vector_store._faiss_index
|
||||
max_top_k = faiss_index.ntotal
|
||||
max_top_k = index.vector_store.table.count_rows()
|
||||
if max_top_k == 0:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
@@ -109,20 +108,24 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
|
||||
seen_node_ids: set[str] = set()
|
||||
|
||||
while query_top_k <= max_top_k:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
try:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
except Warning:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
for vector_id, score in zip(
|
||||
for node_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
node_id = index.index_struct.nodes_dict.get(vector_id)
|
||||
if node_id is None or node_id in seen_node_ids:
|
||||
if node_id in seen_node_ids:
|
||||
continue
|
||||
|
||||
node = index.docstore.docs.get(node_id)
|
||||
@@ -148,7 +151,7 @@ def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k:
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
return DocumentFilteredFaissRetriever()
|
||||
return DocumentFilteredLanceDBRetriever()
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
|
||||
@@ -59,15 +59,12 @@ def assert_chat_output(
|
||||
|
||||
|
||||
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
node.node_id: node for node in nodes
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = len(nodes)
|
||||
mock_index.vector_store.table.count_rows.return_value = len(nodes)
|
||||
mock_index.vector_store.query.return_value = MagicMock(
|
||||
ids=list(mock_index.index_struct.nodes_dict),
|
||||
ids=[node.node_id for node in nodes],
|
||||
similarities=[0.1] * len(nodes),
|
||||
)
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
@@ -92,21 +89,21 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
)
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
"0": foreign_node.node_id,
|
||||
"1": missing_node.node_id,
|
||||
"2": allowed_node1.node_id,
|
||||
"3": allowed_node2.node_id,
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
allowed_node1.node_id: allowed_node1,
|
||||
allowed_node2.node_id: allowed_node2,
|
||||
foreign_node.node_id: foreign_node,
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = 4
|
||||
mock_index.vector_store.table.count_rows.return_value = 4
|
||||
mock_index.vector_store.query.side_effect = [
|
||||
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
|
||||
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
|
||||
MagicMock(
|
||||
ids=[foreign_node.node_id, allowed_node1.node_id],
|
||||
similarities=[0.9, 0.8],
|
||||
),
|
||||
MagicMock(
|
||||
ids=[foreign_node.node_id, missing_node.node_id, allowed_node2.node_id],
|
||||
similarities=[0.9, 0.7, 0.6],
|
||||
),
|
||||
]
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
@@ -128,9 +125,9 @@ def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
|
||||
|
||||
|
||||
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
|
||||
def test_document_filtered_retriever_handles_empty_vector_store() -> None:
|
||||
mock_index = MagicMock()
|
||||
mock_index.vector_store._faiss_index.ntotal = 0
|
||||
mock_index.vector_store.table.count_rows.return_value = 0
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
|
||||
Reference in New Issue
Block a user