Extracts some common code into helpers instead of duplication

This commit is contained in:
stumpylog
2026-06-05 13:31:32 -07:00
parent a23888aa1b
commit 75a31ee09b
2 changed files with 46 additions and 66 deletions
+7 -25
View File
@@ -4,6 +4,7 @@ import sys
from documents.models import Document
from paperless_ai.client import AIClient
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
@@ -79,7 +80,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
try:
yield from _stream_chat_with_documents(query_str, documents)
except Exception as e:
logger.exception(f"Failed to stream document chat response: {e}", exc_info=True)
logger.exception("Failed to stream document chat response: %s", e)
yield CHAT_ERROR_MESSAGE
@@ -88,30 +89,9 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
index = load_or_build_index()
doc_ids = sorted(str(doc.pk) for doc in documents)
filters = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=doc_ids,
),
],
)
# No indexed content for these documents -> bail early (before touching the LLM).
if not index.vector_store.has_nodes(filters=filters):
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
client = AIClient()
filters = _document_id_filters(str(doc.pk) for doc in documents)
retriever = VectorIndexRetriever(
index=index,
@@ -120,11 +100,13 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
)
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
client = AIClient()
references = _get_document_references(documents, top_nodes)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
+39 -41
View File
@@ -192,6 +192,36 @@ def get_rag_prompt_helper(
)
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
"""Embed ``nodes`` in place using ``embed_model``."""
from llama_index.core.schema import MetadataMode
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
def _document_id_filters(doc_ids):
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(doc_ids),
),
],
)
def get_llm_index_compaction_retention() -> int:
"""Seconds of MVCC version history to keep during compaction."""
return 60 * 60 # 1 hour: safe for in-flight readers, reclaims daily
@@ -203,8 +233,6 @@ def update_llm_index(
rebuild=False,
) -> str:
"""Rebuild or incrementally update the LLM index."""
from llama_index.core.schema import MetadataMode
if not rebuild and llm_index_exists() and embedding_dim_mismatch():
logger.warning("Embedding dimension changed; forcing LLM index rebuild.")
rebuild = True
@@ -219,19 +247,13 @@ def update_llm_index(
embed_model = get_embedding_model()
with write_store() as store:
if rebuild or not llm_index_exists():
if rebuild or not store.table_exists():
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
logger.info("Rebuilding LLM index.")
store.drop_table()
for document in iter_wrapper(documents):
nodes = build_document_node(document, chunk_size=chunk_size)
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
_embed_nodes(nodes, embed_model)
store.add(nodes)
msg = "LLM index rebuilt successfully."
else:
@@ -242,13 +264,7 @@ def update_llm_index(
if existing.get(doc_id) == document.modified.isoformat():
continue
nodes = build_document_node(document, chunk_size=chunk_size)
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
_embed_nodes(nodes, embed_model)
store.upsert_document(doc_id, nodes)
changed += 1
msg = (
@@ -265,18 +281,9 @@ def update_llm_index(
def llm_index_add_or_update_document(document: Document):
"""Add or atomically replace a document's chunks in the index."""
from llama_index.core.schema import MetadataMode
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
if new_nodes:
embed_model = get_embedding_model()
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in new_nodes]
for node, emb in zip(
new_nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
_embed_nodes(new_nodes, get_embedding_model())
with write_store() as store:
store.upsert_document(str(document.id), new_nodes)
@@ -350,23 +357,14 @@ def query_similar_documents(
return []
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
index = load_or_build_index()
filters = None
if allowed_document_ids is not None:
filters = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(allowed_document_ids),
),
],
)
filters = (
_document_id_filters(allowed_document_ids)
if allowed_document_ids is not None
else None
)
retriever = VectorIndexRetriever(
index=index,