refactor(ai): query_similar_documents via metadata filter

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-06-03 08:35:09 -07:00
parent 6bb8212f20
commit 1f2af9087c
2 changed files with 55 additions and 52 deletions
+37 -52
View File
@@ -32,11 +32,11 @@ RAG_CHUNK_OVERLAP = 200
def _index_lock_path() -> Path:
"""Return the path used as the file lock for FAISS index mutations.
"""Return the path used as the file lock for LanceDB index mutations.
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock
while another worker still holds it.
rebuild — which calls store.drop_table() — cannot interfere with another
worker that still holds the lock.
"""
return settings.LLM_INDEX_LOCK
@@ -99,7 +99,7 @@ def build_document_node(
metadata = {
"document_id": str(document.id),
"title": document.title,
"tags": json.dumps([t.name for t in document.tags.all()]),
"tags": [t.name for t in document.tags.all()],
"correspondent": document.correspondent.name
if document.correspondent
else None,
@@ -190,7 +190,7 @@ def get_rag_prompt_helper(
def _iter_existing_modified(store: "PaperlessLanceVectorStore") -> list[dict]:
"""One representative row per document_id, for modified-time comparison."""
if LLM_INDEX_TABLE not in store.client.table_names():
if not store.table_exists():
return []
seen: dict[str, dict] = {}
for row in store.client.open_table(LLM_INDEX_TABLE).search().to_list():
@@ -334,9 +334,7 @@ def query_similar_documents(
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""
Runs a similarity query and returns top-k similar Document objects.
"""
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
@@ -348,62 +346,49 @@ def query_similar_documents(
)
return []
with FileLock(_index_lock_path()):
index = load_or_build_index()
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
# constrain only the node(s) that match the document IDs, if given
doc_node_ids = (
[
node.node_id
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in allowed_document_ids
]
if allowed_document_ids is not None
else None
)
if doc_node_ids is not None and not doc_node_ids:
return []
index = load_or_build_index()
from llama_index.core.retrievers import VectorIndexRetriever
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
doc_ids=doc_node_ids,
filters = None
if allowed_document_ids is not None:
filters = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(allowed_document_ids),
),
],
)
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
try:
results = retriever.retrieve(query_text)
except KeyError as e:
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
# append-only. Treat them as absent and return no results.
logger.debug(
"Skipping LLM similarity query for document %s due to a stale "
"FAISS position with no docstore node: %s",
document.pk,
e,
)
return []
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
filters=filters,
)
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
results = retriever.retrieve(query_text)
retrieved_document_ids: list[int] = []
for node in results:
document_id = node.metadata.get("document_id")
if document_id is None:
continue
normalized_document_id = str(document_id)
if (
allowed_document_ids is not None
and normalized_document_id not in allowed_document_ids
):
normalized = str(document_id)
if allowed_document_ids is not None and normalized not in allowed_document_ids:
continue
try:
retrieved_document_ids.append(int(normalized_document_id))
retrieved_document_ids.append(int(normalized))
except ValueError:
logger.warning(
"Skipping LLM index result with invalid document_id %r.",
@@ -955,3 +955,21 @@ class TestLanceDbIndexing:
rows = store.client.open_table(indexing.LLM_INDEX_TABLE).count_rows()
assert rows < big
assert rows >= 1
@pytest.mark.django_db
class TestQuerySimilarDocuments:
def test_query_similar_documents_respects_allowed_ids(
self,
temp_llm_index_dir,
mock_embed_model,
) -> None:
a = DocumentFactory.create(content="alpha shared content here")
b = DocumentFactory.create(content="beta shared content here")
c = DocumentFactory.create(content="gamma shared content here")
for doc in (a, b, c):
indexing.llm_index_add_or_update_document(doc)
results = indexing.query_similar_documents(a, document_ids=[b.id])
assert all(doc.id == b.id for doc in results)