diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index f149a5fc5..40c901db7 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -4,14 +4,14 @@ import sys from documents.models import Document from paperless_ai.client import AIClient +from paperless_ai.indexing import get_rag_prompt_helper from paperless_ai.indexing import load_or_build_index logger = logging.getLogger("paperless_ai.chat") -MAX_SINGLE_DOC_CONTEXT_CHARS = 15000 -SINGLE_DOC_SNIPPET_CHARS = 800 CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__" MAX_CHAT_REFERENCES = 3 +CHAT_RETRIEVER_TOP_K = 5 CHAT_PROMPT_TMPL = """Context information is below. --------------------- @@ -89,66 +89,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): from llama_index.core import VectorStoreIndex from llama_index.core.prompts import PromptTemplate from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.response_synthesizers import get_response_synthesizer local_index = VectorStoreIndex(nodes=nodes) retriever = local_index.as_retriever( - similarity_top_k=3 if len(documents) == 1 else 5, + similarity_top_k=CHAT_RETRIEVER_TOP_K, ) - if len(documents) == 1: - # Just one doc — provide full content - doc = documents[0] - references = [_build_document_reference(doc)] - # TODO: include document metadata in the context - content = doc.content or "" - context_body = content + top_nodes = retriever.retrieve(query_str) + if len(top_nodes) == 0: + logger.warning("Retriever returned no nodes for the given documents.") + yield "Sorry, I couldn't find any content to answer your question." + return - if len(content) > MAX_SINGLE_DOC_CONTEXT_CHARS: - logger.info( - "Truncating single-document context from %s to %s characters", - len(content), - MAX_SINGLE_DOC_CONTEXT_CHARS, - ) - context_body = content[:MAX_SINGLE_DOC_CONTEXT_CHARS] - - top_nodes = retriever.retrieve(query_str) - if len(top_nodes) > 0: - snippets = "\n\n".join( - f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}" - for node in top_nodes - ) - context_body = f"{context_body}\n\nTOP MATCHES:\n{snippets}" - - context = f"TITLE: {doc.title or doc.filename}\n{context_body}" - else: - top_nodes = retriever.retrieve(query_str) - - if len(top_nodes) == 0: - logger.warning("Retriever returned no nodes for the given documents.") - yield "Sorry, I couldn't find any content to answer your question." - return - - references = _get_document_references(documents, top_nodes) - context = "\n\n".join( - f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}" - for node in top_nodes - ) + references = _get_document_references(documents, top_nodes) prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL) - prompt = prompt_template.partial_format( - context_str=context, - query_str=query_str, - ).format(llm=client.llm) + response_synthesizer = get_response_synthesizer( + llm=client.llm, + prompt_helper=get_rag_prompt_helper(), + text_qa_template=prompt_template, + streaming=True, + ) query_engine = RetrieverQueryEngine.from_args( retriever=retriever, llm=client.llm, + response_synthesizer=response_synthesizer, streaming=True, ) - logger.debug("Document chat prompt: %s", prompt) + logger.debug("Document chat query: %s", query_str) - response_stream = query_engine.query(prompt) + response_stream = query_engine.query(query_str) for chunk in response_stream.response_gen: yield chunk diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index b8c865214..7a9796008 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -22,6 +22,11 @@ if TYPE_CHECKING: logger = logging.getLogger("paperless_ai.indexing") +RAG_CONTEXT_WINDOW = 8192 +RAG_NUM_OUTPUT = 512 +RAG_CHUNK_SIZE = 1024 +RAG_CHUNK_OVERLAP = 200 + def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool: from documents.tasks import llmindex_index @@ -111,7 +116,10 @@ def build_document_node(document: Document) -> list["BaseNode"]: from llama_index.core.node_parser import SimpleNodeParser doc = LlamaDocument(text=text, metadata=metadata) - parser = SimpleNodeParser() + parser = SimpleNodeParser( + chunk_size=RAG_CHUNK_SIZE, + chunk_overlap=get_rag_chunk_overlap(), + ) return parser.get_nodes_from_documents([doc]) @@ -168,6 +176,21 @@ def vector_store_file_exists(): return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists() +def get_rag_chunk_overlap() -> int: + return min(RAG_CHUNK_OVERLAP, RAG_CHUNK_SIZE - 1) + + +def get_rag_prompt_helper(): + from llama_index.core.indices.prompt_helper import PromptHelper + + return PromptHelper( + context_window=RAG_CONTEXT_WINDOW, + num_output=RAG_NUM_OUTPUT, + chunk_overlap_ratio=0.1, + chunk_size_limit=RAG_CHUNK_SIZE, + ) + + def update_llm_index( *, iter_wrapper: IterWrapper[Document] = identity, @@ -277,17 +300,15 @@ def llm_index_remove_document(document: Document): def truncate_content(content: str) -> str: - from llama_index.core.indices.prompt_helper import PromptHelper from llama_index.core.prompts import PromptTemplate from llama_index.core.text_splitter import TokenTextSplitter - prompt_helper = PromptHelper( - context_window=8192, - num_output=512, - chunk_overlap_ratio=0.1, - chunk_size_limit=None, + prompt_helper = get_rag_prompt_helper() + splitter = TokenTextSplitter( + separator=" ", + chunk_size=RAG_CHUNK_SIZE, + chunk_overlap=get_rag_chunk_overlap(), ) - splitter = TokenTextSplitter(separator=" ", chunk_size=512, chunk_overlap=50) content_chunks = splitter.split_text(content) truncated_chunks = prompt_helper.truncate( prompt=PromptTemplate(template="{content}"), diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index d02cf3b96..09fbc2038 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -58,6 +58,24 @@ def test_build_document_node(real_document) -> None: assert nodes[0].metadata["document_id"] == str(real_document.id) +@pytest.mark.django_db +def test_build_document_node_uses_rag_chunk_settings(real_document) -> None: + with patch("llama_index.core.node_parser.SimpleNodeParser") as mock_parser: + mock_parser.return_value.get_nodes_from_documents.return_value = [] + + indexing.build_document_node(real_document) + + mock_parser.assert_called_once_with(chunk_size=1024, chunk_overlap=200) + + +def test_get_rag_chunk_overlap_clamps_to_chunk_size() -> None: + with ( + patch("paperless_ai.indexing.RAG_CHUNK_SIZE", 64), + patch("paperless_ai.indexing.RAG_CHUNK_OVERLAP", 128), + ): + assert indexing.get_rag_chunk_overlap() == 63 + + @pytest.mark.django_db def test_update_llm_index( temp_llm_index_dir, diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 5e26ca0af..c7beb50d0 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -57,7 +57,7 @@ def assert_chat_output( } -def test_stream_chat_with_one_document_full_content(mock_document) -> None: +def test_stream_chat_with_one_document_retrieval(mock_document) -> None: with ( patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, @@ -85,6 +85,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None: output = list(stream_chat_with_documents("What is this?", [mock_document])) + mock_query_engine.query.assert_called_once_with("What is this?") assert_chat_output( output, expected_chunks=["chunk1", "chunk2"], @@ -154,6 +155,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) + mock_query_engine.query.assert_called_once_with("What's up?") assert_chat_output( output, expected_chunks=["chunk1", "chunk2"],