mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-25 06:44:19 +00:00
Fix: use response synthesizer for RAG doc chat (#12751)
This commit is contained in:
+19
-46
@@ -4,14 +4,14 @@ import sys
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import get_rag_prompt_helper
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
|
||||
logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
||||
SINGLE_DOC_SNIPPET_CHARS = 800
|
||||
CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__"
|
||||
MAX_CHAT_REFERENCES = 3
|
||||
CHAT_RETRIEVER_TOP_K = 5
|
||||
|
||||
CHAT_PROMPT_TMPL = """Context information is below.
|
||||
---------------------
|
||||
@@ -89,66 +89,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
|
||||
local_index = VectorStoreIndex(nodes=nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
similarity_top_k=CHAT_RETRIEVER_TOP_K,
|
||||
)
|
||||
|
||||
if len(documents) == 1:
|
||||
# Just one doc — provide full content
|
||||
doc = documents[0]
|
||||
references = [_build_document_reference(doc)]
|
||||
# TODO: include document metadata in the context
|
||||
content = doc.content or ""
|
||||
context_body = content
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if len(top_nodes) == 0:
|
||||
logger.warning("Retriever returned no nodes for the given documents.")
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
if len(content) > MAX_SINGLE_DOC_CONTEXT_CHARS:
|
||||
logger.info(
|
||||
"Truncating single-document context from %s to %s characters",
|
||||
len(content),
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS,
|
||||
)
|
||||
context_body = content[:MAX_SINGLE_DOC_CONTEXT_CHARS]
|
||||
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if len(top_nodes) > 0:
|
||||
snippets = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
|
||||
for node in top_nodes
|
||||
)
|
||||
context_body = f"{context_body}\n\nTOP MATCHES:\n{snippets}"
|
||||
|
||||
context = f"TITLE: {doc.title or doc.filename}\n{context_body}"
|
||||
else:
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
|
||||
if len(top_nodes) == 0:
|
||||
logger.warning("Retriever returned no nodes for the given documents.")
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
context = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
|
||||
for node in top_nodes
|
||||
)
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
prompt = prompt_template.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
).format(llm=client.llm)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
response_synthesizer=response_synthesizer,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
logger.debug("Document chat prompt: %s", prompt)
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
|
||||
response_stream = query_engine.query(prompt)
|
||||
response_stream = query_engine.query(query_str)
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
|
||||
@@ -22,6 +22,11 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger("paperless_ai.indexing")
|
||||
|
||||
# Shared sizing constants for the RAG helpers in this module.
# Passed to PromptHelper as ``context_window`` by get_rag_prompt_helper().
RAG_CONTEXT_WINDOW = 8192
# Passed to PromptHelper as ``num_output`` by get_rag_prompt_helper().
RAG_NUM_OUTPUT = 512
# Chunk size handed to the node parser / text splitter, and used as the
# PromptHelper ``chunk_size_limit``.
RAG_CHUNK_SIZE = 1024
# Requested chunk overlap; callers read it through get_rag_chunk_overlap(),
# which caps it below RAG_CHUNK_SIZE.
RAG_CHUNK_OVERLAP = 200
|
||||
|
||||
|
||||
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
from documents.tasks import llmindex_index
|
||||
@@ -111,7 +116,10 @@ def build_document_node(document: Document) -> list["BaseNode"]:
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
|
||||
doc = LlamaDocument(text=text, metadata=metadata)
|
||||
parser = SimpleNodeParser()
|
||||
parser = SimpleNodeParser(
|
||||
chunk_size=RAG_CHUNK_SIZE,
|
||||
chunk_overlap=get_rag_chunk_overlap(),
|
||||
)
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
@@ -168,6 +176,21 @@ def vector_store_file_exists():
|
||||
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||
|
||||
|
||||
def get_rag_chunk_overlap() -> int:
    """Return the chunk overlap to use when splitting text for RAG.

    The configured ``RAG_CHUNK_OVERLAP`` is capped at ``RAG_CHUNK_SIZE - 1``,
    so the returned overlap is always strictly smaller than the chunk size.
    """
    largest_allowed = RAG_CHUNK_SIZE - 1
    if RAG_CHUNK_OVERLAP > largest_allowed:
        return largest_allowed
    return RAG_CHUNK_OVERLAP
|
||||
|
||||
|
||||
def get_rag_prompt_helper():
    """Return a ``PromptHelper`` configured from the module's RAG constants.

    ``PromptHelper`` is imported inside the function rather than at module
    import time.
    """
    from llama_index.core.indices.prompt_helper import PromptHelper

    helper_settings = {
        "context_window": RAG_CONTEXT_WINDOW,
        "num_output": RAG_NUM_OUTPUT,
        "chunk_overlap_ratio": 0.1,
        "chunk_size_limit": RAG_CHUNK_SIZE,
    }
    return PromptHelper(**helper_settings)
|
||||
|
||||
|
||||
def update_llm_index(
|
||||
*,
|
||||
iter_wrapper: IterWrapper[Document] = identity,
|
||||
@@ -277,17 +300,15 @@ def llm_index_remove_document(document: Document):
|
||||
|
||||
|
||||
def truncate_content(content: str) -> str:
|
||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
|
||||
prompt_helper = PromptHelper(
|
||||
context_window=8192,
|
||||
num_output=512,
|
||||
chunk_overlap_ratio=0.1,
|
||||
chunk_size_limit=None,
|
||||
prompt_helper = get_rag_prompt_helper()
|
||||
splitter = TokenTextSplitter(
|
||||
separator=" ",
|
||||
chunk_size=RAG_CHUNK_SIZE,
|
||||
chunk_overlap=get_rag_chunk_overlap(),
|
||||
)
|
||||
splitter = TokenTextSplitter(separator=" ", chunk_size=512, chunk_overlap=50)
|
||||
content_chunks = splitter.split_text(content)
|
||||
truncated_chunks = prompt_helper.truncate(
|
||||
prompt=PromptTemplate(template="{content}"),
|
||||
|
||||
@@ -58,6 +58,24 @@ def test_build_document_node(real_document) -> None:
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
    """build_document_node must construct its parser with the RAG chunk settings."""
    parser_patch = patch("llama_index.core.node_parser.SimpleNodeParser")
    with parser_patch as parser_cls:
        # The parser's output is irrelevant here; only its construction matters.
        parser_cls.return_value.get_nodes_from_documents.return_value = []

        indexing.build_document_node(real_document)

        parser_cls.assert_called_once_with(chunk_size=1024, chunk_overlap=200)
|
||||
|
||||
|
||||
def test_get_rag_chunk_overlap_clamps_to_chunk_size() -> None:
    """An overlap larger than the chunk size is clamped to chunk size minus one."""
    oversized = {"RAG_CHUNK_SIZE": 64, "RAG_CHUNK_OVERLAP": 128}
    with patch.multiple("paperless_ai.indexing", **oversized):
        overlap = indexing.get_rag_chunk_overlap()

    assert overlap == 63
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
|
||||
@@ -57,7 +57,7 @@ def assert_chat_output(
|
||||
}
|
||||
|
||||
|
||||
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
def test_stream_chat_with_one_document_retrieval(mock_document) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
@@ -85,6 +85,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What is this?")
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
@@ -154,6 +155,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What's up?")
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
|
||||
Reference in New Issue
Block a user