feat(ai): add LanceDB-backed vector store adapter

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-06-02 13:24:04 -07:00
parent df4607a492
commit b758cd1bdb
2 changed files with 338 additions and 0 deletions
+145
View File
@@ -0,0 +1,145 @@
from pathlib import Path
import pytest
from llama_index.core.schema import NodeRelationship
from llama_index.core.schema import RelatedNodeInfo
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from paperless_ai.vector_store import PaperlessLanceVectorStore
# Length of the synthetic embedding vectors used for both the stored nodes and
# the query vectors in this module.
DIM = 8
def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode:
    """Build a ``TextNode`` fixture whose embedding repeats ``vec`` ``DIM`` times.

    The node carries ``document_id`` both in its metadata and as the id of its
    SOURCE relationship.
    """
    # ref_doc_id is a read-only property, so it has to be supplied through the
    # SOURCE relationship rather than assigned directly.
    source = RelatedNodeInfo(node_id=document_id)
    chunk = TextNode(
        id_=node_id,
        text=text,
        metadata={"document_id": document_id},
    )
    chunk.set_content(text)
    chunk.embedding = [vec for _ in range(DIM)]
    chunk.relationships = {NodeRelationship.SOURCE: source}
    return chunk
class TestPaperlessLanceVectorStoreCrud:
    """Add / query / delete round-trips against a store in a temp directory."""

    @staticmethod
    def _only_documents(*document_ids: str) -> MetadataFilters:
        """Return an ``IN`` filter on ``document_id`` for the given ids."""
        scope = MetadataFilter(
            key="document_id",
            operator=FilterOperator.IN,
            value=list(document_ids),
        )
        return MetadataFilters(filters=[scope])

    @pytest.fixture
    def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
        """Provide a store rooted at a per-test ``idx`` directory."""
        index_dir = tmp_path / "idx"
        return PaperlessLanceVectorStore(uri=str(index_dir))

    def test_add_then_query_returns_node(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """A top-1 query with document 1's embedding returns document 1."""
        first = _node("1-0", "1", "alpha", 0.1)
        second = _node("2-0", "2", "beta", 0.9)
        store.add([first, second])

        probe = VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1)
        hits = store.query(probe).nodes

        assert len(hits) == 1
        assert hits[0].metadata["document_id"] == "1"

    def test_query_empty_table_returns_empty_no_raise(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """Querying before anything was added yields an empty result."""
        probe = VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=5)
        result = store.query(probe)

        assert result.nodes == []
        assert result.ids == []

    def test_delete_removes_all_chunks_of_document(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """Deleting by ref doc id removes every chunk of that document."""
        doc_one_chunks = [_node("1-0", "1", "a", 0.1), _node("1-1", "1", "b", 0.2)]
        doc_two_chunks = [_node("2-0", "2", "c", 0.9)]
        store.add(doc_one_chunks)
        store.add(doc_two_chunks)

        store.delete("1")

        table = store.client.open_table("documents")
        assert table.count_rows() == 1

    def test_query_with_in_filter_scopes_results(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """An ``IN`` filter restricts query hits to the listed documents."""
        store.add([_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)])

        probe = VectorStoreQuery(
            query_embedding=[0.1] * DIM,
            similarity_top_k=5,
            filters=self._only_documents("2"),
        )
        found = [hit.metadata["document_id"] for hit in store.query(probe).nodes]

        assert found == ["2"]

    def test_get_nodes_filter_returns_empty_cleanly(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """``get_nodes`` with a filter that matches nothing returns ``[]``."""
        store.add([_node("1-0", "1", "a", 0.1)])

        nodes = store.get_nodes(filters=self._only_documents("999"))

        assert nodes == []

    def test_fresh_instance_filters_existing_table(
        self,
        tmp_path: Path,
    ) -> None:
        """A second store opened on the same URI can filter the existing table."""
        uri = str(tmp_path / "idx")
        writer = PaperlessLanceVectorStore(uri=uri)
        writer.add([_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)])

        reopened = PaperlessLanceVectorStore(uri=uri)
        probe = VectorStoreQuery(
            query_embedding=[0.1] * DIM,
            similarity_top_k=5,
            filters=self._only_documents("1"),
        )
        found = [hit.metadata["document_id"] for hit in reopened.query(probe).nodes]

        assert found == ["1"]

    def test_table_exists_and_drop(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """The table appears on first add, reports its dim, and can be dropped."""
        assert store.table_exists() is False

        store.add([_node("1-0", "1", "a", 0.1)])
        assert store.table_exists() is True
        assert store.vector_dim() == DIM

        store.drop_table()
        assert store.table_exists() is False
+193
View File
@@ -0,0 +1,193 @@
import json
import logging
from typing import Any
import lancedb
import pyarrow as pa
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.vector_stores.types import FilterCondition
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.core.vector_stores.utils import metadata_dict_to_node
from llama_index.core.vector_stores.utils import node_to_metadata_dict
# Module logger, registered under an explicit name rather than ``__name__``.
logger = logging.getLogger("paperless_ai.vector_store")
# Table name used when a store is constructed without an explicit ``table_name``.
DEFAULT_TABLE_NAME = "documents"
def _escape(value: str) -> str:
return str(value).replace("'", "''")
def _build_where(filters: MetadataFilters | None) -> str | None:
    """Translate EQ / IN metadata filters into a Lance SQL predicate.

    Each filter becomes one clause on the column named by its ``key`` (in
    practice the top-level ``document_id`` column); clauses are joined with
    ``OR`` when ``filters.condition`` is ``FilterCondition.OR`` and with
    ``AND`` otherwise.

    Args:
        filters: The filters to translate, or ``None``.

    Returns:
        The predicate string, or ``None`` when there is nothing to filter on.

    Raises:
        NotImplementedError: If a filter uses an operator other than EQ or IN.
    """
    if filters is None or not filters.filters:
        return None
    clauses: list[str] = []
    for f in filters.filters:
        # NOTE: ``f.key`` is interpolated verbatim as a column name, so keys
        # must come from trusted code; only the values are escaped.
        if f.operator == FilterOperator.IN:
            raw = f.value
            if raw is None:
                members: list[Any] = []
            elif isinstance(raw, (list, tuple, set, frozenset)):
                members = list(raw)
            else:
                # A lone scalar would otherwise be iterated character by
                # character (str) or fail to iterate at all (int/float).
                members = [raw]
            if not members:
                # ``IN ()`` is not a valid predicate. An empty membership list
                # matches no row, so emit a clause that is never true.
                clauses.append("1 = 0")
                continue
            vals = ",".join(f"'{_escape(v)}'" for v in members)
            clauses.append(f"{f.key} IN ({vals})")
        elif f.operator == FilterOperator.EQ:
            clauses.append(f"{f.key} = '{_escape(f.value)}'")
        else:  # pragma: no cover - we only ever build EQ/IN filters
            raise NotImplementedError(f"Unsupported filter operator: {f.operator}")
    joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
    return joiner.join(clauses)
class PaperlessLanceVectorStore(BasePydanticVectorStore):
    """A llama-index vector store backed directly by a LanceDB table.

    Stores one row per node with the node id, its document id (both as the
    ``ref_doc_id`` delete key ``doc_id`` and a top-level filter column
    ``document_id``), the embedding, and the serialised node (text + metadata)
    as JSON. ``stores_text`` lets llama-index run off this store alone, with no
    separate docstore or index store.

    The table is created by the first ``add`` call, which also fixes the vector
    dimension. Until then the read and delete methods treat the store as empty.
    """

    stores_text: bool = True
    flat_metadata: bool = True

    _uri: str = PrivateAttr()
    _table_name: str = PrivateAttr()
    _conn: Any = PrivateAttr()
    # Open table handle, or None while the table does not exist yet.
    _table: Any = PrivateAttr()

    def __init__(self, uri: str, table_name: str = DEFAULT_TABLE_NAME) -> None:
        """Connect to the LanceDB database at ``uri``.

        Args:
            uri: Location passed to ``lancedb.connect``.
            table_name: Table to use; it is opened now if it already exists.
        """
        super().__init__()
        self._uri = uri
        self._table_name = table_name
        self._conn = lancedb.connect(uri)
        self._table = (
            self._conn.open_table(table_name) if self.table_exists() else None
        )

    @property
    def client(self) -> Any:
        """The underlying LanceDB connection."""
        return self._conn

    @staticmethod
    def _id_predicate(node_ids: list[str]) -> str:
        """Return an ``id IN (...)`` predicate for the given node ids.

        Values are written as single-quoted literals, which is the quoting
        style ``_escape`` escapes for.
        """
        quoted = ",".join(f"'{_escape(node_id)}'" for node_id in node_ids)
        return f"id IN ({quoted})"

    def table_exists(self) -> bool:
        """Return whether the connection currently lists this store's table."""
        return self._table_name in list(self._conn.table_names())

    def vector_dim(self) -> int | None:
        """Return the fixed size of the ``vector`` column, or ``None``.

        ``None`` is returned while this instance holds no open table.
        """
        if self._table is None:
            return None
        return self._table.schema.field("vector").type.list_size

    def drop_table(self) -> None:
        """Drop the table if it exists and forget the cached handle."""
        if self.table_exists():
            self._conn.drop_table(self._table_name)
        self._table = None

    @staticmethod
    def _schema(dim: int) -> pa.Schema:
        """Return the Arrow schema for a table of ``dim``-sized vectors."""
        return pa.schema(
            [
                pa.field("id", pa.string()),
                pa.field("doc_id", pa.string()),
                pa.field("document_id", pa.string()),
                pa.field("vector", pa.list_(pa.float32(), dim)),
                pa.field("node_content", pa.string()),
            ],
        )

    def _row(self, node: BaseNode) -> dict[str, Any]:
        """Convert ``node`` into one table row.

        ``document_id`` is the ``str`` of the metadata value, so a node without
        that metadata key is stored with the text ``"None"``.
        """
        meta = node_to_metadata_dict(
            node,
            remove_text=False,
            flat_metadata=self.flat_metadata,
        )
        return {
            "id": node.node_id,
            "doc_id": node.ref_doc_id,
            "document_id": str(node.metadata.get("document_id")),
            "vector": node.get_embedding(),
            "node_content": json.dumps(meta),
        }

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert ``nodes`` and return their node ids.

        The first non-empty call creates the table, taking the vector dimension
        from the first node's embedding. An empty list is a no-op.
        """
        if not nodes:
            return []
        rows = [self._row(node) for node in nodes]
        if self._table is None:
            dim = len(nodes[0].get_embedding())
            self._table = self._conn.create_table(
                self._table_name,
                rows,
                schema=self._schema(dim),
            )
        else:
            self._table.add(rows)
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete every row whose ``doc_id`` equals ``ref_doc_id``.

        Does nothing while no table is open.
        """
        if self._table is not None:
            # Single-quoted literal: ``_escape`` doubles single quotes, which
            # is only a valid escape inside single-quoted literals.
            self._table.delete(f"doc_id = '{_escape(ref_doc_id)}'")

    def delete_nodes(
        self,
        node_ids: list[str] | None = None,
        filters: MetadataFilters | None = None,
        **delete_kwargs: Any,
    ) -> None:
        """Delete rows by node id or by metadata filter.

        When ``node_ids`` is non-empty it is used and ``filters`` is ignored.
        With neither argument, or while no table is open, nothing is deleted.
        """
        if self._table is None:
            return
        if node_ids:
            self._table.delete(self._id_predicate(node_ids))
        elif filters is not None:
            where = _build_where(filters)
            if where:
                self._table.delete(where)

    def _rows_to_nodes(self, rows: list[dict[str, Any]]) -> list[BaseNode]:
        """Rebuild nodes from rows, restoring each node's embedding."""
        nodes: list[BaseNode] = []
        for row in rows:
            node = metadata_dict_to_node(json.loads(row["node_content"]))
            node.embedding = list(row["vector"])
            nodes.append(node)
        return nodes

    def get_nodes(
        self,
        node_ids: list[str] | None = None,
        filters: MetadataFilters | None = None,
        **kwargs: Any,
    ) -> list[BaseNode]:
        """Fetch nodes by id or by metadata filter, without a vector search.

        When ``node_ids`` is non-empty it is used and ``filters`` is ignored.
        With neither argument the search is unfiltered. Returns ``[]`` while no
        table is open.
        """
        if self._table is None:
            return []
        query = self._table.search()
        where = _build_where(filters)
        if node_ids:
            query = query.where(self._id_predicate(node_ids))
        elif where:
            query = query.where(where)
        return self._rows_to_nodes(query.to_list())

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Run a nearest-neighbour search for ``query.query_embedding``.

        Returns at most ``query.similarity_top_k`` nodes (10 when unset),
        restricted by ``query.filters`` when given. Returns an empty result
        while no table is open.
        """
        if self._table is None:
            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
        top_k = query.similarity_top_k or 10
        search = self._table.search(query.query_embedding).limit(top_k)
        where = _build_where(query.filters)
        if where:
            # Filter before the nearest-neighbour cut so a scoped query still
            # returns up to ``top_k`` matching rows, not whichever of the
            # global top ``top_k`` happen to match.
            search = search.where(where, prefilter=True)
        rows = search.to_list()
        nodes = self._rows_to_nodes(rows)
        # LanceDB returns squared-L2 distance; map to a descending similarity.
        sims = [1.0 / (1.0 + float(row["_distance"])) for row in rows]
        ids = [row["id"] for row in rows]
        return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)