mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-06 13:49:44 +00:00
feat(ai): add LanceDB-backed vector store adapter
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeRelationship
|
||||
from llama_index.core.schema import RelatedNodeInfo
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
|
||||
from paperless_ai.vector_store import PaperlessLanceVectorStore
|
||||
|
||||
DIM = 8
|
||||
|
||||
|
||||
def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode:
|
||||
node = TextNode(id_=node_id, text=text, metadata={"document_id": document_id})
|
||||
node.set_content(text)
|
||||
node.embedding = [vec] * DIM
|
||||
# Use relationships so ref_doc_id resolves correctly (it's a read-only property)
|
||||
node.relationships = {
|
||||
NodeRelationship.SOURCE: RelatedNodeInfo(node_id=document_id),
|
||||
}
|
||||
return node
|
||||
|
||||
|
||||
class TestPaperlessLanceVectorStoreCrud:
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
|
||||
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
|
||||
|
||||
def test_add_then_query_returns_node(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.add([_node("1-0", "1", "alpha", 0.1), _node("2-0", "2", "beta", 0.9)])
|
||||
|
||||
result = store.query(
|
||||
VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=1),
|
||||
)
|
||||
|
||||
assert len(result.nodes) == 1
|
||||
assert result.nodes[0].metadata["document_id"] == "1"
|
||||
|
||||
def test_query_empty_table_returns_empty_no_raise(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
result = store.query(
|
||||
VectorStoreQuery(query_embedding=[0.1] * DIM, similarity_top_k=5),
|
||||
)
|
||||
assert result.nodes == []
|
||||
assert result.ids == []
|
||||
|
||||
def test_delete_removes_all_chunks_of_document(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.add([_node("1-0", "1", "a", 0.1), _node("1-1", "1", "b", 0.2)])
|
||||
store.add([_node("2-0", "2", "c", 0.9)])
|
||||
|
||||
store.delete("1")
|
||||
|
||||
assert store.client.open_table("documents").count_rows() == 1
|
||||
|
||||
def test_query_with_in_filter_scopes_results(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.add([_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)])
|
||||
|
||||
result = store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=[0.1] * DIM,
|
||||
similarity_top_k=5,
|
||||
filters=MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=["2"],
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
assert [n.metadata["document_id"] for n in result.nodes] == ["2"]
|
||||
|
||||
def test_get_nodes_filter_returns_empty_cleanly(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.add([_node("1-0", "1", "a", 0.1)])
|
||||
nodes = store.get_nodes(
|
||||
filters=MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=["999"],
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
assert nodes == []
|
||||
|
||||
def test_fresh_instance_filters_existing_table(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
uri = str(tmp_path / "idx")
|
||||
PaperlessLanceVectorStore(uri=uri).add(
|
||||
[_node("1-0", "1", "a", 0.1), _node("2-0", "2", "b", 0.1)],
|
||||
)
|
||||
|
||||
reopened = PaperlessLanceVectorStore(uri=uri)
|
||||
result = reopened.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=[0.1] * DIM,
|
||||
similarity_top_k=5,
|
||||
filters=MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=["1"],
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
assert [n.metadata["document_id"] for n in result.nodes] == ["1"]
|
||||
|
||||
def test_table_exists_and_drop(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
assert store.table_exists() is False
|
||||
store.add([_node("1-0", "1", "a", 0.1)])
|
||||
assert store.table_exists() is True
|
||||
assert store.vector_dim() == DIM
|
||||
store.drop_table()
|
||||
assert store.table_exists() is False
|
||||
@@ -0,0 +1,193 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.core.vector_stores.types import FilterCondition
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryResult
|
||||
from llama_index.core.vector_stores.utils import metadata_dict_to_node
|
||||
from llama_index.core.vector_stores.utils import node_to_metadata_dict
|
||||
|
||||
logger = logging.getLogger("paperless_ai.vector_store")
|
||||
|
||||
DEFAULT_TABLE_NAME = "documents"
|
||||
|
||||
|
||||
def _escape(value: str) -> str:
|
||||
return str(value).replace("'", "''")
|
||||
|
||||
|
||||
def _build_where(filters: MetadataFilters | None) -> str | None:
|
||||
"""Translate the EQ / IN filters we use into a Lance SQL predicate on the
|
||||
top-level ``document_id`` column."""
|
||||
if filters is None or not filters.filters:
|
||||
return None
|
||||
clauses: list[str] = []
|
||||
for f in filters.filters:
|
||||
if f.operator == FilterOperator.IN:
|
||||
vals = ",".join(f"'{_escape(v)}'" for v in f.value)
|
||||
clauses.append(f"{f.key} IN ({vals})")
|
||||
elif f.operator == FilterOperator.EQ:
|
||||
clauses.append(f"{f.key} = '{_escape(f.value)}'")
|
||||
else: # pragma: no cover - we only ever build EQ/IN filters
|
||||
raise NotImplementedError(f"Unsupported filter operator: {f.operator}")
|
||||
joiner = " OR " if filters.condition == FilterCondition.OR else " AND "
|
||||
return joiner.join(clauses)
|
||||
|
||||
|
||||
class PaperlessLanceVectorStore(BasePydanticVectorStore):
|
||||
"""A llama-index vector store backed directly by a LanceDB table.
|
||||
|
||||
Stores one row per node with the node id, its document id (both as the
|
||||
``ref_doc_id`` delete key ``doc_id`` and a top-level filter column
|
||||
``document_id``), the embedding, and the serialised node (text + metadata)
|
||||
as JSON. ``stores_text`` lets llama-index run off this store alone, with no
|
||||
separate docstore or index store.
|
||||
"""
|
||||
|
||||
stores_text: bool = True
|
||||
flat_metadata: bool = True
|
||||
|
||||
_uri: str = PrivateAttr()
|
||||
_table_name: str = PrivateAttr()
|
||||
_conn: Any = PrivateAttr()
|
||||
_table: Any = PrivateAttr()
|
||||
|
||||
def __init__(self, uri: str, table_name: str = DEFAULT_TABLE_NAME) -> None:
|
||||
super().__init__()
|
||||
self._uri = uri
|
||||
self._table_name = table_name
|
||||
self._conn = lancedb.connect(uri)
|
||||
existing = list(self._conn.table_names())
|
||||
self._table = (
|
||||
self._conn.open_table(table_name) if table_name in existing else None
|
||||
)
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
return self._conn
|
||||
|
||||
def table_exists(self) -> bool:
|
||||
return self._table_name in list(self._conn.table_names())
|
||||
|
||||
def vector_dim(self) -> int | None:
|
||||
if self._table is None:
|
||||
return None
|
||||
return self._table.schema.field("vector").type.list_size
|
||||
|
||||
def drop_table(self) -> None:
|
||||
if self.table_exists():
|
||||
self._conn.drop_table(self._table_name)
|
||||
self._table = None
|
||||
|
||||
@staticmethod
|
||||
def _schema(dim: int) -> pa.Schema:
|
||||
return pa.schema(
|
||||
[
|
||||
pa.field("id", pa.string()),
|
||||
pa.field("doc_id", pa.string()),
|
||||
pa.field("document_id", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float32(), dim)),
|
||||
pa.field("node_content", pa.string()),
|
||||
],
|
||||
)
|
||||
|
||||
def _row(self, node: BaseNode) -> dict[str, Any]:
|
||||
meta = node_to_metadata_dict(
|
||||
node,
|
||||
remove_text=False,
|
||||
flat_metadata=self.flat_metadata,
|
||||
)
|
||||
return {
|
||||
"id": node.node_id,
|
||||
"doc_id": node.ref_doc_id,
|
||||
"document_id": str(node.metadata.get("document_id")),
|
||||
"vector": node.get_embedding(),
|
||||
"node_content": json.dumps(meta),
|
||||
}
|
||||
|
||||
def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
|
||||
if not nodes:
|
||||
return []
|
||||
rows = [self._row(node) for node in nodes]
|
||||
if self._table is None:
|
||||
dim = len(nodes[0].get_embedding())
|
||||
self._table = self._conn.create_table(
|
||||
self._table_name,
|
||||
rows,
|
||||
schema=self._schema(dim),
|
||||
)
|
||||
else:
|
||||
self._table.add(rows)
|
||||
return [node.node_id for node in nodes]
|
||||
|
||||
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
|
||||
if self._table is not None:
|
||||
self._table.delete(f'doc_id = "{_escape(ref_doc_id)}"')
|
||||
|
||||
def delete_nodes(
|
||||
self,
|
||||
node_ids: list[str] | None = None,
|
||||
filters: MetadataFilters | None = None,
|
||||
**delete_kwargs: Any,
|
||||
) -> None:
|
||||
if self._table is None:
|
||||
return
|
||||
if node_ids:
|
||||
ids = ",".join(f'"{_escape(n)}"' for n in node_ids)
|
||||
self._table.delete(f"id IN ({ids})")
|
||||
elif filters is not None:
|
||||
where = _build_where(filters)
|
||||
if where:
|
||||
self._table.delete(where)
|
||||
|
||||
def _rows_to_nodes(self, rows: list[dict[str, Any]]) -> list[BaseNode]:
|
||||
nodes: list[BaseNode] = []
|
||||
for row in rows:
|
||||
node = metadata_dict_to_node(json.loads(row["node_content"]))
|
||||
node.embedding = list(row["vector"])
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def get_nodes(
|
||||
self,
|
||||
node_ids: list[str] | None = None,
|
||||
filters: MetadataFilters | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseNode]:
|
||||
if self._table is None:
|
||||
return []
|
||||
query = self._table.search()
|
||||
where = _build_where(filters)
|
||||
if node_ids:
|
||||
ids = ",".join(f'"{_escape(n)}"' for n in node_ids)
|
||||
query = query.where(f"id IN ({ids})")
|
||||
elif where:
|
||||
query = query.where(where)
|
||||
return self._rows_to_nodes(query.to_list())
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: VectorStoreQuery,
|
||||
**kwargs: Any,
|
||||
) -> VectorStoreQueryResult:
|
||||
if self._table is None:
|
||||
return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
|
||||
top_k = query.similarity_top_k or 10
|
||||
search = self._table.search(query.query_embedding).limit(top_k)
|
||||
where = _build_where(query.filters)
|
||||
if where:
|
||||
search = search.where(where)
|
||||
rows = search.to_list()
|
||||
nodes = self._rows_to_nodes(rows)
|
||||
# LanceDB returns squared-L2 distance; map to a descending similarity.
|
||||
sims = [1.0 / (1.0 + float(row["_distance"])) for row in rows]
|
||||
ids = [row["id"] for row in rows]
|
||||
return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)
|
||||
Reference in New Issue
Block a user