Fix: use response synthesizer for RAG doc chat (#12751)

This commit is contained in:
shamoon
2026-05-08 13:01:44 -07:00
committed by GitHub
parent 8769dc894e
commit 57b91ad2cf
4 changed files with 69 additions and 55 deletions
+19 -46
View File
@@ -4,14 +4,14 @@ import sys
from documents.models import Document
from paperless_ai.client import AIClient
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
logger = logging.getLogger("paperless_ai.chat")
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
SINGLE_DOC_SNIPPET_CHARS = 800
CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__"
MAX_CHAT_REFERENCES = 3
CHAT_RETRIEVER_TOP_K = 5
CHAT_PROMPT_TMPL = """Context information is below.
---------------------
@@ -89,66 +89,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
from llama_index.core import VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
local_index = VectorStoreIndex(nodes=nodes)
retriever = local_index.as_retriever(
similarity_top_k=3 if len(documents) == 1 else 5,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
)
if len(documents) == 1:
# Just one doc — provide full content
doc = documents[0]
references = [_build_document_reference(doc)]
# TODO: include document metadata in the context
content = doc.content or ""
context_body = content
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
yield "Sorry, I couldn't find any content to answer your question."
return
if len(content) > MAX_SINGLE_DOC_CONTEXT_CHARS:
logger.info(
"Truncating single-document context from %s to %s characters",
len(content),
MAX_SINGLE_DOC_CONTEXT_CHARS,
)
context_body = content[:MAX_SINGLE_DOC_CONTEXT_CHARS]
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) > 0:
snippets = "\n\n".join(
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
for node in top_nodes
)
context_body = f"{context_body}\n\nTOP MATCHES:\n{snippets}"
context = f"TITLE: {doc.title or doc.filename}\n{context_body}"
else:
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
yield "Sorry, I couldn't find any content to answer your question."
return
references = _get_document_references(documents, top_nodes)
context = "\n\n".join(
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
for node in top_nodes
)
references = _get_document_references(documents, top_nodes)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
prompt = prompt_template.partial_format(
context_str=context,
query_str=query_str,
).format(llm=client.llm)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
logger.debug("Document chat prompt: %s", prompt)
logger.debug("Document chat query: %s", query_str)
response_stream = query_engine.query(prompt)
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
+29 -8
View File
@@ -22,6 +22,11 @@ if TYPE_CHECKING:
logger = logging.getLogger("paperless_ai.indexing")
RAG_CONTEXT_WINDOW = 8192
RAG_NUM_OUTPUT = 512
RAG_CHUNK_SIZE = 1024
RAG_CHUNK_OVERLAP = 200
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
from documents.tasks import llmindex_index
@@ -111,7 +116,10 @@ def build_document_node(document: Document) -> list["BaseNode"]:
from llama_index.core.node_parser import SimpleNodeParser
doc = LlamaDocument(text=text, metadata=metadata)
parser = SimpleNodeParser()
parser = SimpleNodeParser(
chunk_size=RAG_CHUNK_SIZE,
chunk_overlap=get_rag_chunk_overlap(),
)
return parser.get_nodes_from_documents([doc])
@@ -168,6 +176,21 @@ def vector_store_file_exists():
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
def get_rag_chunk_overlap() -> int:
return min(RAG_CHUNK_OVERLAP, RAG_CHUNK_SIZE - 1)
def get_rag_prompt_helper():
from llama_index.core.indices.prompt_helper import PromptHelper
return PromptHelper(
context_window=RAG_CONTEXT_WINDOW,
num_output=RAG_NUM_OUTPUT,
chunk_overlap_ratio=0.1,
chunk_size_limit=RAG_CHUNK_SIZE,
)
def update_llm_index(
*,
iter_wrapper: IterWrapper[Document] = identity,
@@ -277,17 +300,15 @@ def llm_index_remove_document(document: Document):
def truncate_content(content: str) -> str:
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.prompts import PromptTemplate
from llama_index.core.text_splitter import TokenTextSplitter
prompt_helper = PromptHelper(
context_window=8192,
num_output=512,
chunk_overlap_ratio=0.1,
chunk_size_limit=None,
prompt_helper = get_rag_prompt_helper()
splitter = TokenTextSplitter(
separator=" ",
chunk_size=RAG_CHUNK_SIZE,
chunk_overlap=get_rag_chunk_overlap(),
)
splitter = TokenTextSplitter(separator=" ", chunk_size=512, chunk_overlap=50)
content_chunks = splitter.split_text(content)
truncated_chunks = prompt_helper.truncate(
prompt=PromptTemplate(template="{content}"),
@@ -58,6 +58,24 @@ def test_build_document_node(real_document) -> None:
assert nodes[0].metadata["document_id"] == str(real_document.id)
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
with patch("llama_index.core.node_parser.SimpleNodeParser") as mock_parser:
mock_parser.return_value.get_nodes_from_documents.return_value = []
indexing.build_document_node(real_document)
mock_parser.assert_called_once_with(chunk_size=1024, chunk_overlap=200)
def test_get_rag_chunk_overlap_clamps_to_chunk_size() -> None:
with (
patch("paperless_ai.indexing.RAG_CHUNK_SIZE", 64),
patch("paperless_ai.indexing.RAG_CHUNK_OVERLAP", 128),
):
assert indexing.get_rag_chunk_overlap() == 63
@pytest.mark.django_db
def test_update_llm_index(
temp_llm_index_dir,
+3 -1
View File
@@ -57,7 +57,7 @@ def assert_chat_output(
}
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
def test_stream_chat_with_one_document_retrieval(mock_document) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
@@ -85,6 +85,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
output = list(stream_chat_with_documents("What is this?", [mock_document]))
mock_query_engine.query.assert_called_once_with("What is this?")
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
@@ -154,6 +155,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
mock_query_engine.query.assert_called_once_with("What's up?")
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],