@pytest.fixture
def store(tmp_path: Path) -> PaperlessLanceVectorStore:
    """Build a vector store whose URI is the ``idx`` entry under ``tmp_path``."""
    index_dir = tmp_path / "idx"
    return PaperlessLanceVectorStore(uri=str(index_dir))
class TestSchemaVersioning:
    """Exercise the ``schema_version.json`` marker the store keeps under its URI."""

    @pytest.fixture
    def uri(self, tmp_path: Path) -> str:
        """Index location inside the per-test ``tmp_path`` directory."""
        return str(tmp_path / "idx")

    @pytest.fixture
    def version_file(self, uri: str) -> Path:
        """Path at which the store is expected to record its schema version."""
        return Path(uri) / "schema_version.json"

    def test_version_file_written_on_table_creation(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])

        assert version_file.exists()
        recorded = json.loads(version_file.read_text())
        assert recorded["version"] == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_returns_current_when_file_missing(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])
        version_file.unlink()

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_returns_current_when_file_unparseable(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        PaperlessLanceVectorStore(uri=uri).add([_node("1-0", "1", "text", 0.1)])
        # Simulate a truncated / corrupted marker file.
        version_file.write_text("{not valid json")

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_persists_after_reopen(self, uri: str) -> None:
        PaperlessLanceVectorStore(uri=uri).add([_node("1-0", "1", "text", 0.1)])

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_reads_recorded_value(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        # The fallback for a missing/unreadable file is CURRENT_SCHEMA_VERSION,
        # so asserting on that value alone cannot prove the file is actually
        # read. Record a different value and expect it back.
        # NOTE(review): assumes constructing the store does not itself react to
        # the recorded version (e.g. migrate or reject it) — confirm.
        PaperlessLanceVectorStore(uri=uri).add([_node("1-0", "1", "text", 0.1)])
        recorded = CURRENT_SCHEMA_VERSION + 1
        version_file.write_text(json.dumps({"version": recorded}))

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == recorded

    def test_drop_table_removes_version_file(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])
        assert version_file.exists()

        store.drop_table()
        assert not version_file.exists()

    def test_version_file_written_on_upsert_creation(
        self,
        uri: str,
        version_file: Path,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.upsert_document("1", [_node("1-0", "1", "text", 0.1)])

        recorded = json.loads(version_file.read_text())
        assert recorded["version"] == CURRENT_SCHEMA_VERSION
DEFAULT_TABLE_NAME: Final = "documents"
# Schema version written to ``schema_version.json`` whenever a table is created.
CURRENT_SCHEMA_VERSION: Final[int] = 1


# The definitions below are methods of ``PaperlessLanceVectorStore``; the class
# header lies outside this hunk.
@property
def _schema_version_path(self) -> Path:
    """Path of the JSON file, under the store URI, holding the schema version."""
    return Path(self._uri) / "schema_version.json"


def stored_schema_version(self) -> int:
    """Return the schema version recorded on disk.

    Falls back to ``CURRENT_SCHEMA_VERSION`` when the file is missing or its
    content cannot be interpreted as ``{"version": <int>}``.

    Missing means either the table predates versioning or was just created and
    the write hasn't happened yet — treat conservatively as already current.
    Malformed content is handled the same way so a damaged marker file never
    makes this accessor raise.
    """
    try:
        payload = json.loads(self._schema_version_path.read_text(encoding="utf-8"))
        return int(payload["version"])
    except (
        FileNotFoundError,  # marker file was never written or has been removed
        KeyError,  # JSON object without a "version" key
        TypeError,  # valid JSON of the wrong shape: list, null, string, null version
        ValueError,  # invalid JSON, undecodable bytes, or a non-numeric version
        OverflowError,  # version parsed as an infinite float
    ):
        return CURRENT_SCHEMA_VERSION


def _write_schema_version(self, version: int) -> None:
    """Record ``version`` in the schema version file, creating parent directories.

    The payload is written to a sibling temporary file and then renamed over the
    target, so an interrupted write cannot leave a truncated marker behind.
    """
    target = self._schema_version_path
    target.parent.mkdir(parents=True, exist_ok=True)
    staging = target.with_name(target.name + ".tmp")
    staging.write_text(json.dumps({"version": version}), encoding="utf-8")
    staging.replace(target)