diff --git a/src/paperless_ai/tests/test_vector_store.py b/src/paperless_ai/tests/test_vector_store.py index 33df3b4ac..8e2f1a7a7 100644 --- a/src/paperless_ai/tests/test_vector_store.py +++ b/src/paperless_ai/tests/test_vector_store.py @@ -143,3 +143,27 @@ class TestPaperlessLanceVectorStoreCrud: assert store.vector_dim() == DIM store.drop_table() assert store.table_exists() is False + + def test_build_where_or_condition(self) -> None: + from llama_index.core.vector_stores.types import FilterCondition + + from paperless_ai.vector_store import _build_where + + where = _build_where( + MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.EQ, + value="1", + ), + MetadataFilter( + key="document_id", + operator=FilterOperator.EQ, + value="2", + ), + ], + condition=FilterCondition.OR, + ), + ) + assert where == "document_id = '1' OR document_id = '2'" diff --git a/src/paperless_ai/vector_store.py b/src/paperless_ai/vector_store.py index 2509cf3b4..96bd31e66 100644 --- a/src/paperless_ai/vector_store.py +++ b/src/paperless_ai/vector_store.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Sequence from typing import Any import lancedb @@ -65,7 +66,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore): self._uri = uri self._table_name = table_name self._conn = lancedb.connect(uri) - existing = list(self._conn.table_names()) + existing = self._conn.list_tables().tables self._table = ( self._conn.open_table(table_name) if table_name in existing else None ) @@ -75,7 +76,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore): return self._conn def table_exists(self) -> bool: - return self._table_name in list(self._conn.table_names()) + return self._table_name in self._conn.list_tables().tables def vector_dim(self) -> int | None: if self._table is None: @@ -113,7 +114,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore): "node_content": json.dumps(meta), } - def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: if not nodes: return [] rows = [self._row(node) for node in nodes] @@ -130,7 +131,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore): def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: if self._table is not None: - self._table.delete(f'doc_id = "{_escape(ref_doc_id)}"') + self._table.delete(f"doc_id = '{_escape(ref_doc_id)}'") def delete_nodes( self, @@ -180,14 +181,14 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore): ) -> VectorStoreQueryResult: if self._table is None: return VectorStoreQueryResult(nodes=[], similarities=[], ids=[]) - top_k = query.similarity_top_k or 10 + top_k = query.similarity_top_k if query.similarity_top_k is not None else 10 search = self._table.search(query.query_embedding).limit(top_k) where = _build_where(query.filters) if where: search = search.where(where) rows = search.to_list() nodes = self._rows_to_nodes(rows) - # LanceDB returns squared-L2 distance; map to a descending similarity. + # LanceDB returns an L2 distance (smaller = closer); map to a descending similarity. sims = [1.0 / (1.0 + float(row["_distance"])) for row in rows] ids = [row["id"] for row in rows] return VectorStoreQueryResult(nodes=nodes, similarities=sims, ids=ids)