mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-06 05:39:45 +00:00
refactor(ai): chat uses a stock filtered retriever
Delete _get_document_filtered_retriever (74-line custom FAISS retriever with expanding top_k loop) and rewrite _stream_chat_with_documents to use a stock VectorIndexRetriever with MetadataFilters(IN). The no-content pre-check now calls index.vector_store.get_nodes(filters=...) which returns [] cleanly for un-indexed documents. Move FakeEmbedding and mock_embed_model fixture to conftest.py so both test_chat.py and test_ai_indexing.py share them. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
+24
-95
@@ -75,82 +75,6 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
|
||||
)
|
||||
|
||||
|
||||
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.vector_stores import VectorStoreQuery
|
||||
|
||||
class DocumentFilteredFaissRetriever(BaseRetriever):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._cached_query_str = None
|
||||
self._cached_nodes = []
|
||||
|
||||
def _retrieve(self, query_bundle):
|
||||
if query_bundle.query_str == self._cached_query_str:
|
||||
return self._cached_nodes
|
||||
|
||||
if query_bundle.embedding is None:
|
||||
query_bundle.embedding = (
|
||||
index._embed_model.get_agg_embedding_from_queries(
|
||||
query_bundle.embedding_strs,
|
||||
)
|
||||
)
|
||||
|
||||
faiss_index = index.vector_store._faiss_index
|
||||
max_top_k = faiss_index.ntotal
|
||||
if max_top_k == 0:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
query_top_k = min(max(similarity_top_k, 1), max_top_k)
|
||||
allowed_nodes: list[NodeWithScore] = []
|
||||
seen_node_ids: set[str] = set()
|
||||
|
||||
while query_top_k <= max_top_k:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
|
||||
for vector_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
node_id = index.index_struct.nodes_dict.get(vector_id)
|
||||
if node_id is None or node_id in seen_node_ids:
|
||||
continue
|
||||
|
||||
node = index.docstore.docs.get(node_id)
|
||||
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||
continue
|
||||
|
||||
seen_node_ids.add(node_id)
|
||||
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||
|
||||
if len(allowed_nodes) >= similarity_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
if query_top_k == max_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
query_top_k = min(query_top_k * 2, max_top_k)
|
||||
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
return DocumentFilteredFaissRetriever()
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
try:
|
||||
yield from _stream_chat_with_documents(query_str, documents)
|
||||
@@ -160,31 +84,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
|
||||
|
||||
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [str(doc.pk) for doc in documents]
|
||||
filters = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=doc_ids,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
# No indexed content for these documents -> bail early (before touching the LLM).
|
||||
if not index.vector_store.get_nodes(filters=filters):
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
client = AIClient()
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
index,
|
||||
set(doc_ids),
|
||||
CHAT_RETRIEVER_TOP_K,
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=CHAT_RETRIEVER_TOP_K,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
@@ -202,7 +134,6 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
@@ -211,9 +142,7 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
)
|
||||
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
|
||||
response_stream = query_engine.query(query_str)
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -1,10 +1,40 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
|
||||
settings.LLM_INDEX_DIR = tmp_path
|
||||
return tmp_path
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
@@ -9,7 +9,6 @@ from django.contrib.auth.models import User
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from faker import Faker
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
@@ -30,35 +29,6 @@ def real_document(db):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document) -> None:
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
|
||||
@@ -5,9 +5,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from paperless_ai import chat
|
||||
from paperless_ai.chat import CHAT_ERROR_MESSAGE
|
||||
from paperless_ai.chat import CHAT_METADATA_DELIMITER
|
||||
from paperless_ai.chat import _get_document_filtered_retriever
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@@ -58,91 +58,6 @@ def assert_chat_output(
|
||||
}
|
||||
|
||||
|
||||
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
node.node_id: node for node in nodes
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = len(nodes)
|
||||
mock_index.vector_store.query.return_value = MagicMock(
|
||||
ids=list(mock_index.index_struct.nodes_dict),
|
||||
similarities=[0.1] * len(nodes),
|
||||
)
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
|
||||
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
allowed_node1 = TextNode(
|
||||
text="Allowed content 1.",
|
||||
metadata={"document_id": "1", "title": "Allowed 1"},
|
||||
)
|
||||
allowed_node2 = TextNode(
|
||||
text="Allowed content 2.",
|
||||
metadata={"document_id": "2", "title": "Allowed 2"},
|
||||
)
|
||||
foreign_node = TextNode(
|
||||
text="Foreign content.",
|
||||
metadata={"document_id": "3", "title": "Foreign"},
|
||||
)
|
||||
missing_node = TextNode(
|
||||
text="Missing content.",
|
||||
metadata={"document_id": "1", "title": "Missing"},
|
||||
)
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
"0": foreign_node.node_id,
|
||||
"1": missing_node.node_id,
|
||||
"2": allowed_node1.node_id,
|
||||
"3": allowed_node2.node_id,
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
allowed_node1.node_id: allowed_node1,
|
||||
allowed_node2.node_id: allowed_node2,
|
||||
foreign_node.node_id: foreign_node,
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = 4
|
||||
mock_index.vector_store.query.side_effect = [
|
||||
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
|
||||
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
|
||||
]
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1", "2"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
nodes = retriever.retrieve("question")
|
||||
cached_nodes = retriever.retrieve("question")
|
||||
|
||||
assert [node.node.node_id for node in nodes] == [
|
||||
allowed_node1.node_id,
|
||||
allowed_node2.node_id,
|
||||
]
|
||||
assert cached_nodes == nodes
|
||||
assert mock_index.vector_store.query.call_count == 2
|
||||
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
|
||||
|
||||
|
||||
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
|
||||
mock_index = MagicMock()
|
||||
mock_index.vector_store._faiss_index.ntotal = 0
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
assert retriever.retrieve("question") == []
|
||||
mock_index.vector_store.query.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_stream_chat_with_one_document_retrieval(
|
||||
mock_document,
|
||||
@@ -164,17 +79,31 @@ def test_stream_chat_with_one_document_retrieval(
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
add_vector_query_results(mock_index, [mock_node])
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(
|
||||
metadata={
|
||||
"document_id": str(mock_document.pk),
|
||||
"title": "Test Document",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What is this?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -196,12 +125,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
@@ -210,41 +137,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [
|
||||
mock_node1,
|
||||
mock_node2,
|
||||
mock_duplicate_node,
|
||||
mock_foreign_node,
|
||||
]
|
||||
add_vector_query_results(
|
||||
mock_index,
|
||||
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
|
||||
)
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
# Mock response stream
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
|
||||
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What's up?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -268,8 +186,8 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
# No matching nodes in the store
|
||||
mock_index.vector_store.get_nodes.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
@@ -281,28 +199,42 @@ def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat._get_document_filtered_retriever",
|
||||
) as mock_get_retriever,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": "1", "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
# Nodes found so we get past the pre-check
|
||||
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
|
||||
mock_get_retriever.return_value = mock_retriever
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
) as mock_retriever_cls:
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError(
|
||||
"private provider detail",
|
||||
)
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == [CHAT_ERROR_MESSAGE]
|
||||
assert "Failed to stream document chat response" in caplog.text
|
||||
assert "private provider detail" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestStreamChatRetrieval:
|
||||
def test_no_nodes_yields_no_content_message(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
from documents.tests.factories import DocumentFactory
|
||||
|
||||
doc = DocumentFactory.create(content="hello world")
|
||||
# Nothing indexed for this document yet.
|
||||
out = list(chat.stream_chat_with_documents("question?", [doc]))
|
||||
assert chat.CHAT_NO_CONTENT_MESSAGE in out
|
||||
|
||||
Reference in New Issue
Block a user