diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 9252ad852..a55479acf 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -32,11 +32,11 @@ RAG_CHUNK_OVERLAP = 200 def _index_lock_path() -> Path: - """Return the path used as the file lock for FAISS index mutations. + """Return the path used as the file lock for LanceDB index mutations. The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a - rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock - while another worker still holds it. + rebuild — which calls store.drop_table() — cannot interfere with another + worker that still holds the lock. """ return settings.LLM_INDEX_LOCK @@ -99,7 +99,7 @@ def build_document_node( metadata = { "document_id": str(document.id), "title": document.title, - "tags": json.dumps([t.name for t in document.tags.all()]), + "tags": [t.name for t in document.tags.all()], "correspondent": document.correspondent.name if document.correspondent else None, @@ -190,7 +190,7 @@ def get_rag_prompt_helper( def _iter_existing_modified(store: "PaperlessLanceVectorStore") -> list[dict]: """One representative row per document_id, for modified-time comparison.""" - if LLM_INDEX_TABLE not in store.client.table_names(): + if not store.table_exists(): return [] seen: dict[str, dict] = {} for row in store.client.open_table(LLM_INDEX_TABLE).search().to_list(): @@ -334,9 +334,7 @@ def query_similar_documents( top_k: int = 5, document_ids: Iterable[int | str] | None = None, ) -> list[Document]: - """ - Runs a similarity query and returns top-k similar Document objects. - """ + """Return up to ``top_k`` Documents most similar to ``document``.""" allowed_document_ids = normalize_document_ids(document_ids) if allowed_document_ids is not None and not allowed_document_ids: return [] @@ -348,62 +346,49 @@ def query_similar_documents( ) return [] - with FileLock(_index_lock_path()): - index = load_or_build_index() + from llama_index.core.retrievers import VectorIndexRetriever + from llama_index.core.vector_stores.types import FilterOperator + from llama_index.core.vector_stores.types import MetadataFilter + from llama_index.core.vector_stores.types import MetadataFilters - # constrain only the node(s) that match the document IDs, if given - doc_node_ids = ( - [ - node.node_id - for node in index.docstore.docs.values() - if node.metadata.get("document_id") in allowed_document_ids - ] - if allowed_document_ids is not None - else None - ) - if doc_node_ids is not None and not doc_node_ids: - return [] + index = load_or_build_index() - from llama_index.core.retrievers import VectorIndexRetriever - - retriever = VectorIndexRetriever( - index=index, - similarity_top_k=top_k, - doc_ids=doc_node_ids, + filters = None + if allowed_document_ids is not None: + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=sorted(allowed_document_ids), + ), + ], ) - config = AIConfig() - query_text = truncate_content( - (document.title or "") + "\n" + (document.content or ""), - chunk_size=config.llm_embedding_chunk_size, - context_size=config.llm_context_size, - ) - try: - results = retriever.retrieve(query_text) - except KeyError as e: - # Ghost FAISS positions remain after deletion because IndexFlatL2 is - # append-only. Treat them as absent and return no results. - logger.debug( - "Skipping LLM similarity query for document %s due to a stale " - "FAISS position with no docstore node: %s", - document.pk, - e, - ) - return [] + retriever = VectorIndexRetriever( + index=index, + similarity_top_k=top_k, + filters=filters, + ) + + config = AIConfig() + query_text = truncate_content( + (document.title or "") + "\n" + (document.content or ""), + chunk_size=config.llm_embedding_chunk_size, + context_size=config.llm_context_size, + ) + results = retriever.retrieve(query_text) retrieved_document_ids: list[int] = [] for node in results: document_id = node.metadata.get("document_id") if document_id is None: continue - normalized_document_id = str(document_id) - if ( - allowed_document_ids is not None - and normalized_document_id not in allowed_document_ids - ): + normalized = str(document_id) + if allowed_document_ids is not None and normalized not in allowed_document_ids: continue try: - retrieved_document_ids.append(int(normalized_document_id)) + retrieved_document_ids.append(int(normalized)) except ValueError: logger.warning( "Skipping LLM index result with invalid document_id %r.", diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index b9bf5af4c..81cd9d227 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -955,3 +955,21 @@ class TestLanceDbIndexing: rows = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows() assert rows < big assert rows >= 1 + + +@pytest.mark.django_db +class TestQuerySimilarDocuments: + def test_query_similar_documents_respects_allowed_ids( + self, + temp_llm_index_dir, + mock_embed_model, + ) -> None: + a = DocumentFactory.create(content="alpha shared content here") + b = DocumentFactory.create(content="beta shared content here") + c = DocumentFactory.create(content="gamma shared content here") + for doc in (a, b, c): + indexing.llm_index_add_or_update_document(doc) + + results = indexing.query_similar_documents(a, document_ids=[b.id]) + + assert all(doc.id == b.id for doc in results)