class TestPaperlessLanceVectorStoreUpsert:
    """Exercise ``PaperlessLanceVectorStore.upsert_document`` on a seeded store."""

    @pytest.fixture
    def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
        """Return a store under ``tmp_path`` pre-populated with four chunks.

        Three chunks belong to document "1" and one to document "2", so a
        test can tell pruning of one document apart from damage to another.
        """
        seeded = PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
        initial_nodes = [
            _node("1-0", "1", "old0", 0.1),
            _node("1-1", "1", "old1", 0.2),
            _node("1-2", "1", "old2", 0.3),
            _node("2-0", "2", "keep", 0.9),
        ]
        seeded.add(initial_nodes)
        return seeded

    def test_upsert_prunes_stale_chunks_and_keeps_others(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """Upserting fewer chunks drops the surplus one and spares document "2"."""
        replacement = [
            _node("1-0", "1", "new0", 0.1),
            _node("1-1", "1", "new1", 0.2),
        ]
        store.upsert_document("1", replacement)

        table = store.client.open_table("documents")
        matches = table.search().where("document_id = '1'").to_list()
        remaining_ids = [row["id"] for row in matches]
        remaining_ids.sort()

        # The third original chunk ("1-2") must be gone.
        assert remaining_ids == ["1-0", "1-1"]
        # Two replacement chunks for document "1" plus the untouched "2-0".
        assert table.count_rows() == 3

    def test_upsert_is_single_commit(
        self,
        store: PaperlessLanceVectorStore,
    ) -> None:
        """An upsert advances the table version by exactly one."""
        version_before = store.client.open_table("documents").version
        store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)])
        version_after = store.client.open_table("documents").version
        assert version_after == version_before + 1
def upsert_document(self, document_id: str, nodes: list[BaseNode]) -> list[str]:
    """Replace the stored chunks of ``document_id`` with ``nodes``.

    When the table already exists, everything happens in one
    ``merge_insert`` keyed on ``id``: rows whose id matches are overwritten,
    unseen ids are inserted, and existing rows for this document that are
    absent from ``nodes`` are removed through
    ``when_not_matched_by_source_delete``. That last clause is what drops
    leftover trailing chunks after an edit shrinks the chunk count.

    When no table exists yet, it is created from the rows, with the vector
    dimension taken from the first node's embedding.

    An empty ``nodes`` is handled as a removal of the document.

    Returns the node ids of ``nodes`` in input order (``[]`` for a removal).
    """
    if nodes:
        payload = [self._row(entry) for entry in nodes]
        if self._table is not None:
            # Build the update/insert half first, then attach the scoped
            # prune clause so only this document's stale rows are deleted.
            builder = (
                self._table.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
            )
            builder.when_not_matched_by_source_delete(
                f"document_id = '{_escape(document_id)}'",
            ).execute(payload)
        else:
            # First write ever: the schema needs the embedding width.
            embedding_dim = len(nodes[0].get_embedding())
            self._table = self._conn.create_table(
                self._table_name,
                payload,
                schema=self._schema(embedding_dim),
            )
        return [entry.node_id for entry in nodes]

    # No indexable content: treat as a removal.
    # NOTE(review): delete() filters on the ``doc_id`` column, whereas the
    # merge above prunes on ``document_id`` -- confirm both columns always
    # hold the same value for a given document.
    self.delete(document_id)
    return []