From 75a31ee09b331544d4c9b112b8b98fe32bb3bf04 Mon Sep 17 00:00:00 2001 From: stumpylog <797416+stumpylog@users.noreply.github.com> Date: Fri, 5 Jun 2026 13:31:32 -0700 Subject: [PATCH] Extracts some common code into helpers instead of duplication --- src/paperless_ai/chat.py | 32 ++++----------- src/paperless_ai/indexing.py | 80 ++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 66 deletions(-) diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index 6465cec9e..10d2b96d8 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -4,6 +4,7 @@ import sys from documents.models import Document from paperless_ai.client import AIClient +from paperless_ai.indexing import _document_id_filters from paperless_ai.indexing import get_rag_prompt_helper from paperless_ai.indexing import load_or_build_index @@ -79,7 +80,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): try: yield from _stream_chat_with_documents(query_str, documents) except Exception as e: - logger.exception(f"Failed to stream document chat response: {e}", exc_info=True) + logger.exception("Failed to stream document chat response: %s", e) yield CHAT_ERROR_MESSAGE @@ -88,30 +89,9 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.response_synthesizers import get_response_synthesizer from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.core.vector_stores.types import FilterOperator - from llama_index.core.vector_stores.types import MetadataFilter - from llama_index.core.vector_stores.types import MetadataFilters index = load_or_build_index() - - doc_ids = sorted(str(doc.pk) for doc in documents) - filters = MetadataFilters( - filters=[ - MetadataFilter( - key="document_id", - operator=FilterOperator.IN, - value=doc_ids, - ), - ], - ) - - # No indexed content for these documents -> bail early (before touching the LLM). - if not index.vector_store.has_nodes(filters=filters): - logger.warning("No nodes found for the given documents.") - yield CHAT_NO_CONTENT_MESSAGE - return - - client = AIClient() + filters = _document_id_filters(str(doc.pk) for doc in documents) retriever = VectorIndexRetriever( index=index, @@ -120,11 +100,13 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): ) top_nodes = retriever.retrieve(query_str) - if len(top_nodes) == 0: - logger.warning("Retriever returned no nodes for the given documents.") + if not top_nodes: + logger.warning("No nodes found for the given documents.") yield CHAT_NO_CONTENT_MESSAGE return + client = AIClient() + references = _get_document_references(documents, top_nodes) prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL) diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index d8c14c1ac..3e6cff3d3 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -192,6 +192,36 @@ def get_rag_prompt_helper( ) +def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None: + """Embed ``nodes`` in place using ``embed_model``.""" + from llama_index.core.schema import MetadataMode + + texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes] + for node, emb in zip( + nodes, + embed_model.get_text_embedding_batch(texts), + strict=True, + ): + node.embedding = emb + + +def _document_id_filters(doc_ids): + """Return a MetadataFilters IN filter scoped to ``doc_ids``.""" + from llama_index.core.vector_stores.types import FilterOperator + from llama_index.core.vector_stores.types import MetadataFilter + from llama_index.core.vector_stores.types import MetadataFilters + + return MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=sorted(doc_ids), + ), + ], + ) + + def get_llm_index_compaction_retention() -> int: """Seconds of MVCC version history to keep during compaction.""" return 60 * 60 # 1 hour: safe for in-flight readers, reclaims daily @@ -203,8 +233,6 @@ def update_llm_index( rebuild=False, ) -> str: """Rebuild or incrementally update the LLM index.""" - from llama_index.core.schema import MetadataMode - if not rebuild and llm_index_exists() and embedding_dim_mismatch(): logger.warning("Embedding dimension changed; forcing LLM index rebuild.") rebuild = True @@ -219,19 +247,13 @@ def update_llm_index( embed_model = get_embedding_model() with write_store() as store: - if rebuild or not llm_index_exists(): + if rebuild or not store.table_exists(): (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True) logger.info("Rebuilding LLM index.") store.drop_table() for document in iter_wrapper(documents): nodes = build_document_node(document, chunk_size=chunk_size) - texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes] - for node, emb in zip( - nodes, - embed_model.get_text_embedding_batch(texts), - strict=True, - ): - node.embedding = emb + _embed_nodes(nodes, embed_model) store.add(nodes) msg = "LLM index rebuilt successfully." else: @@ -242,13 +264,7 @@ def update_llm_index( if existing.get(doc_id) == document.modified.isoformat(): continue nodes = build_document_node(document, chunk_size=chunk_size) - texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes] - for node, emb in zip( - nodes, - embed_model.get_text_embedding_batch(texts), - strict=True, - ): - node.embedding = emb + _embed_nodes(nodes, embed_model) store.upsert_document(doc_id, nodes) changed += 1 msg = ( @@ -265,18 +281,9 @@ def update_llm_index( def llm_index_add_or_update_document(document: Document): """Add or atomically replace a document's chunks in the index.""" - from llama_index.core.schema import MetadataMode - new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size()) if new_nodes: - embed_model = get_embedding_model() - texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in new_nodes] - for node, emb in zip( - new_nodes, - embed_model.get_text_embedding_batch(texts), - strict=True, - ): - node.embedding = emb + _embed_nodes(new_nodes, get_embedding_model()) with write_store() as store: store.upsert_document(str(document.id), new_nodes) @@ -350,23 +357,14 @@ def query_similar_documents( return [] from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.core.vector_stores.types import FilterOperator - from llama_index.core.vector_stores.types import MetadataFilter - from llama_index.core.vector_stores.types import MetadataFilters index = load_or_build_index() - filters = None - if allowed_document_ids is not None: - filters = MetadataFilters( - filters=[ - MetadataFilter( - key="document_id", - operator=FilterOperator.IN, - value=sorted(allowed_document_ids), - ), - ], - ) + filters = ( + _document_id_filters(allowed_document_ids) + if allowed_document_ids is not None + else None + ) retriever = VectorIndexRetriever( index=index,