mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-10 23:59:43 +00:00
feat(ai): add schema version file tracking to LanceDB vector store
Adds a CURRENT_SCHEMA_VERSION constant, a schema_version.json file written inside the LanceDB data directory on table creation, and a stored_schema_version() helper that reads it back — forming the foundation for future schema migration detection without re-embedding all documents.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -9,11 +10,17 @@ from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
|
||||
from paperless_ai.vector_store import CURRENT_SCHEMA_VERSION
|
||||
from paperless_ai.vector_store import PaperlessLanceVectorStore
|
||||
|
||||
DIM = 8
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path: Path) -> PaperlessLanceVectorStore:
    """Provide a vector store whose URI is an ``idx`` path under ``tmp_path``."""
    index_uri = str(tmp_path / "idx")
    return PaperlessLanceVectorStore(uri=index_uri)
|
||||
|
||||
|
||||
def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode:
|
||||
node = TextNode(id_=node_id, text=text, metadata={"document_id": document_id})
|
||||
node.set_content(text)
|
||||
@@ -26,10 +33,6 @@ def _node(node_id: str, document_id: str, text: str, vec: float) -> TextNode:
|
||||
|
||||
|
||||
class TestPaperlessLanceVectorStoreCrud:
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
|
||||
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
|
||||
|
||||
def test_add_then_query_returns_node(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
@@ -188,9 +191,11 @@ class TestPaperlessLanceVectorStoreCrud:
|
||||
|
||||
class TestPaperlessLanceVectorStoreUpsert:
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
|
||||
s = PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
|
||||
s.add(
|
||||
def filled_store(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
) -> PaperlessLanceVectorStore:
|
||||
store.add(
|
||||
[
|
||||
_node("1-0", "1", "old0", 0.1),
|
||||
_node("1-1", "1", "old1", 0.2),
|
||||
@@ -198,18 +203,18 @@ class TestPaperlessLanceVectorStoreUpsert:
|
||||
_node("2-0", "2", "keep", 0.9),
|
||||
],
|
||||
)
|
||||
return s
|
||||
return store
|
||||
|
||||
def test_upsert_prunes_stale_chunks_and_keeps_others(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
filled_store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.upsert_document(
|
||||
filled_store.upsert_document(
|
||||
"1",
|
||||
[_node("1-0", "1", "new0", 0.1), _node("1-1", "1", "new1", 0.2)],
|
||||
)
|
||||
|
||||
table = store.client.open_table("documents")
|
||||
table = filled_store.client.open_table("documents")
|
||||
doc1 = sorted(
|
||||
r["id"] for r in table.search().where("document_id = '1'").to_list()
|
||||
)
|
||||
@@ -218,30 +223,26 @@ class TestPaperlessLanceVectorStoreUpsert:
|
||||
|
||||
def test_upsert_is_single_commit(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
filled_store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
table = store.client.open_table("documents")
|
||||
table = filled_store.client.open_table("documents")
|
||||
before = table.version
|
||||
store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)])
|
||||
assert store.client.open_table("documents").version == before + 1
|
||||
filled_store.upsert_document("1", [_node("1-0", "1", "new0", 0.1)])
|
||||
assert filled_store.client.open_table("documents").version == before + 1
|
||||
|
||||
def test_upsert_empty_nodes_removes_document(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
filled_store: PaperlessLanceVectorStore,
|
||||
) -> None:
|
||||
store.upsert_document("1", [])
|
||||
filled_store.upsert_document("1", [])
|
||||
|
||||
table = store.client.open_table("documents")
|
||||
table = filled_store.client.open_table("documents")
|
||||
remaining = sorted(r["document_id"] for r in table.search().to_list())
|
||||
assert "1" not in remaining
|
||||
assert "2" in remaining
|
||||
|
||||
|
||||
class TestPaperlessLanceVectorStoreMaintenance:
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path: Path) -> PaperlessLanceVectorStore:
|
||||
return PaperlessLanceVectorStore(uri=str(tmp_path / "idx"))
|
||||
|
||||
def test_maybe_create_ann_index_noop_below_threshold(
|
||||
self,
|
||||
store: PaperlessLanceVectorStore,
|
||||
@@ -415,3 +416,53 @@ class TestGetModifiedTimes:
|
||||
"1": "2024-01-01T00:00:00",
|
||||
"2": "2024-06-01T00:00:00",
|
||||
}
|
||||
|
||||
|
||||
class TestSchemaVersioning:
    """Tests for the ``schema_version.json`` file kept inside the store directory."""

    @pytest.fixture
    def uri(self, tmp_path: Path) -> str:
        """Return the URI of an ``idx`` path under ``tmp_path``."""
        return str(tmp_path / "idx")

    def test_version_file_written_on_table_creation(self, uri: str) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])

        version_file = Path(uri) / "schema_version.json"
        assert version_file.exists()
        assert json.loads(version_file.read_text())["version"] == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_returns_current_when_file_missing(
        self,
        uri: str,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])
        (Path(uri) / "schema_version.json").unlink()

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_returns_current_when_file_malformed(
        self,
        uri: str,
    ) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])
        # Invalid JSON raises a ValueError subclass, which the helper treats
        # the same way as a missing file.
        (Path(uri) / "schema_version.json").write_text("not json")

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_persists_after_reopen(self, uri: str) -> None:
        PaperlessLanceVectorStore(uri=uri).add([_node("1-0", "1", "text", 0.1)])

        reopened = PaperlessLanceVectorStore(uri=uri)
        # Without this check the assertion below would also pass via the
        # missing-file fallback, which returns the same value.
        assert (Path(uri) / "schema_version.json").exists()
        assert reopened.stored_schema_version() == CURRENT_SCHEMA_VERSION

    def test_stored_schema_version_reads_value_from_file(self, uri: str) -> None:
        PaperlessLanceVectorStore(uri=uri).add([_node("1-0", "1", "text", 0.1)])
        # Use a value that differs from the fallback so the assertion can only
        # pass when the file contents are actually read.
        other_version = CURRENT_SCHEMA_VERSION + 1
        (Path(uri) / "schema_version.json").write_text(
            json.dumps({"version": other_version}),
        )

        reopened = PaperlessLanceVectorStore(uri=uri)
        assert reopened.stored_schema_version() == other_version

    def test_drop_table_removes_version_file(self, uri: str) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.add([_node("1-0", "1", "text", 0.1)])
        assert (Path(uri) / "schema_version.json").exists()

        store.drop_table()
        assert not (Path(uri) / "schema_version.json").exists()

    def test_version_file_written_on_upsert_creation(self, uri: str) -> None:
        store = PaperlessLanceVectorStore(uri=uri)
        store.upsert_document("1", [_node("1-0", "1", "text", 0.1)])

        version_file = Path(uri) / "schema_version.json"
        assert json.loads(version_file.read_text())["version"] == CURRENT_SCHEMA_VERSION
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Final
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
@@ -18,7 +20,8 @@ from llama_index.core.vector_stores.utils import node_to_metadata_dict
|
||||
|
||||
logger = logging.getLogger("paperless_ai.vector_store")
|
||||
|
||||
DEFAULT_TABLE_NAME = "documents"
|
||||
DEFAULT_TABLE_NAME: Final = "documents"
|
||||
CURRENT_SCHEMA_VERSION: Final[int] = 1
|
||||
|
||||
# Below this many chunks, LanceDB's exact (brute-force) search is sufficient and
|
||||
# faster than building an ANN index (per LanceDB guidance, ~100K vectors).
|
||||
@@ -107,6 +110,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore):
|
||||
if self.table_exists():
|
||||
self._conn.drop_table(self._table_name)
|
||||
self._table = None
|
||||
self._schema_version_path.unlink(missing_ok=True)
|
||||
|
||||
def stored_model_name(self) -> str | None:
|
||||
"""Return the embedding model name stored in table schema metadata, or None."""
|
||||
@@ -116,6 +120,25 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore):
|
||||
value = meta.get(b"embed_model")
|
||||
return value.decode() if value else None
|
||||
|
||||
@property
def _schema_version_path(self) -> Path:
    """Path of the ``schema_version.json`` file inside the store's URI directory."""
    # NOTE(review): assumes self._uri is a local filesystem path — confirm that
    # remote (e.g. object-store) URIs are never passed to this store.
    store_dir = Path(self._uri)
    return store_dir / "schema_version.json"
|
||||
|
||||
def stored_schema_version(self) -> int:
    """Return the schema version recorded on disk, or CURRENT_SCHEMA_VERSION as a fallback.

    The fallback applies when the version file is missing — either the table
    predates versioning or was just created and the write hasn't happened yet —
    and when its contents cannot be read as ``{"version": <int>}``. Both cases
    are treated conservatively as already current.
    """
    try:
        payload = json.loads(self._schema_version_path.read_text())
        return int(payload["version"])
    except (FileNotFoundError, KeyError, TypeError, ValueError):
        # ValueError covers invalid JSON and non-numeric strings. TypeError
        # covers well-formed JSON of the wrong shape, such as a top-level list
        # or number, or a null "version" value.
        return CURRENT_SCHEMA_VERSION
|
||||
|
||||
def _write_schema_version(self, version: int) -> None:
    """Write ``version`` to the schema version file as ``{"version": <version>}``.

    The file's parent directory is created first if it does not exist yet.
    """
    target = self._schema_version_path
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps({"version": version})
    target.write_text(payload)
|
||||
|
||||
def config_mismatch(self, model_name: str) -> bool:
|
||||
"""True when the stored model name differs from ``model_name``.
|
||||
|
||||
@@ -171,6 +194,7 @@ class PaperlessLanceVectorStore(BasePydanticVectorStore):
|
||||
rows,
|
||||
schema=self._schema(dim, self._embed_model_name),
|
||||
)
|
||||
self._write_schema_version(CURRENT_SCHEMA_VERSION)
|
||||
return True
|
||||
|
||||
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user