refactor(ai): chat uses a stock filtered retriever

Delete _get_document_filtered_retriever (74-line custom FAISS retriever
with expanding top_k loop) and rewrite _stream_chat_with_documents to use
a stock VectorIndexRetriever with MetadataFilters(IN). The no-content
pre-check now calls index.vector_store.get_nodes(filters=...) which
returns [] cleanly for un-indexed documents. Move FakeEmbedding and
mock_embed_model fixture to conftest.py so both test_chat.py and
test_ai_indexing.py share them.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
stumpylog
2026-06-03 09:04:26 -07:00
parent d0a7c47f92
commit 788ae5d4e5
4 changed files with 114 additions and 253 deletions
+24 -95
View File
@@ -75,82 +75,6 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
)
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredFaissRetriever(BaseRetriever):
def __init__(self):
super().__init__()
self._cached_query_str = None
self._cached_nodes = []
def _retrieve(self, query_bundle):
if query_bundle.query_str == self._cached_query_str:
return self._cached_nodes
if query_bundle.embedding is None:
query_bundle.embedding = (
index._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs,
)
)
faiss_index = index.vector_store._faiss_index
max_top_k = faiss_index.ntotal
if max_top_k == 0:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
query_top_k = min(max(similarity_top_k, 1), max_top_k)
allowed_nodes: list[NodeWithScore] = []
seen_node_ids: set[str] = set()
while query_top_k <= max_top_k:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
for vector_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
node_id = index.index_struct.nodes_dict.get(vector_id)
if node_id is None or node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
if len(allowed_nodes) >= similarity_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
if query_top_k == max_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
query_top_k = min(query_top_k * 2, max_top_k)
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
return DocumentFilteredFaissRetriever()
def stream_chat_with_documents(query_str: str, documents: list[Document]):
try:
yield from _stream_chat_with_documents(query_str, documents)
@@ -160,31 +84,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
index = load_or_build_index()
doc_ids = [str(doc.pk) for doc in documents]
filters = MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=doc_ids,
),
],
)
# Filter only the node(s) that match the document IDs
nodes = [
node
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in doc_ids
]
if len(nodes) == 0:
# No indexed content for these documents -> bail early (before touching the LLM).
if not index.vector_store.get_nodes(filters=filters):
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
client = AIClient()
retriever = _get_document_filtered_retriever(
index,
set(doc_ids),
CHAT_RETRIEVER_TOP_K,
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
top_nodes = retriever.retrieve(query_str)
@@ -202,7 +134,6 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
@@ -211,9 +142,7 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]):
)
logger.debug("Document chat query: %s", query_str)
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
+31 -1
View File
@@ -1,10 +1,40 @@
from pathlib import Path
from unittest.mock import patch
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding
from pytest_django.fixtures import SettingsWrapper
@pytest.fixture
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
settings.LLM_INDEX_DIR = tmp_path
return tmp_path
class FakeEmbedding(BaseEmbedding):
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384
@pytest.fixture
def mock_embed_model():
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
@@ -9,7 +9,6 @@ from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from faker import Faker
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from documents.models import PaperlessTask
@@ -30,35 +29,6 @@ def real_document(db):
)
@pytest.fixture
def mock_embed_model():
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384 # Match your real FAISS config
@pytest.mark.django_db
def test_build_document_node(real_document) -> None:
nodes = indexing.build_document_node(real_document)
+59 -127
View File
@@ -5,9 +5,9 @@ from unittest.mock import patch
import pytest
from llama_index.core.schema import TextNode
from paperless_ai import chat
from paperless_ai.chat import CHAT_ERROR_MESSAGE
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import _get_document_filtered_retriever
from paperless_ai.chat import stream_chat_with_documents
@@ -58,91 +58,6 @@ def assert_chat_output(
}
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
mock_index.index_struct.nodes_dict = {
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
}
mock_index.docstore.docs.get.side_effect = {
node.node_id: node for node in nodes
}.get
mock_index.vector_store._faiss_index.ntotal = len(nodes)
mock_index.vector_store.query.return_value = MagicMock(
ids=list(mock_index.index_struct.nodes_dict),
similarities=[0.1] * len(nodes),
)
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
allowed_node1 = TextNode(
text="Allowed content 1.",
metadata={"document_id": "1", "title": "Allowed 1"},
)
allowed_node2 = TextNode(
text="Allowed content 2.",
metadata={"document_id": "2", "title": "Allowed 2"},
)
foreign_node = TextNode(
text="Foreign content.",
metadata={"document_id": "3", "title": "Foreign"},
)
missing_node = TextNode(
text="Missing content.",
metadata={"document_id": "1", "title": "Missing"},
)
mock_index = MagicMock()
mock_index.index_struct.nodes_dict = {
"0": foreign_node.node_id,
"1": missing_node.node_id,
"2": allowed_node1.node_id,
"3": allowed_node2.node_id,
}
mock_index.docstore.docs.get.side_effect = {
allowed_node1.node_id: allowed_node1,
allowed_node2.node_id: allowed_node2,
foreign_node.node_id: foreign_node,
}.get
mock_index.vector_store._faiss_index.ntotal = 4
mock_index.vector_store.query.side_effect = [
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
]
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1", "2"},
similarity_top_k=2,
)
nodes = retriever.retrieve("question")
cached_nodes = retriever.retrieve("question")
assert [node.node.node_id for node in nodes] == [
allowed_node1.node_id,
allowed_node2.node_id,
]
assert cached_nodes == nodes
assert mock_index.vector_store.query.call_count == 2
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
mock_index = MagicMock()
mock_index.vector_store._faiss_index.ntotal = 0
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1"},
similarity_top_k=2,
)
assert retriever.retrieve("question") == []
mock_index.vector_store.query.assert_not_called()
@pytest.mark.django_db
def test_stream_chat_with_one_document_retrieval(
mock_document,
@@ -164,17 +79,31 @@ def test_stream_chat_with_one_document_retrieval(
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
add_vector_query_results(mock_index, [mock_node])
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(
metadata={
"document_id": str(mock_document.pk),
"title": "Test Document",
},
),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
output = list(stream_chat_with_documents("What is this?", [mock_document]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What is this?", [mock_document]))
mock_query_engine.query.assert_called_once_with("What is this?")
patch_embed_nodes.assert_not_called()
@@ -196,12 +125,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
@@ -210,41 +137,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [
mock_node1,
mock_node2,
mock_duplicate_node,
mock_foreign_node,
]
add_vector_query_results(
mock_index,
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
)
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
# Mock response stream
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
mock_query_engine.query.assert_called_once_with("What's up?")
patch_embed_nodes.assert_not_called()
@@ -268,8 +186,8 @@ def test_stream_chat_no_matching_nodes() -> None:
mock_client.llm = MagicMock()
mock_index = MagicMock()
# No matching nodes
mock_index.docstore.docs.values.return_value = []
# No matching nodes in the store
mock_index.vector_store.get_nodes.return_value = []
mock_load_index.return_value = mock_index
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
@@ -281,28 +199,42 @@ def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat._get_document_filtered_retriever",
) as mock_get_retriever,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": "1", "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
# Nodes found so we get past the pre-check
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
mock_load_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
mock_get_retriever.return_value = mock_retriever
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
) as mock_retriever_cls:
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError(
"private provider detail",
)
mock_retriever_cls.return_value = mock_retriever
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
assert output == [CHAT_ERROR_MESSAGE]
assert "Failed to stream document chat response" in caplog.text
assert "private provider detail" in caplog.text
@pytest.mark.django_db
class TestStreamChatRetrieval:
def test_no_nodes_yields_no_content_message(
self,
temp_llm_index_dir,
mock_embed_model,
) -> None:
from documents.tests.factories import DocumentFactory
doc = DocumentFactory.create(content="hello world")
# Nothing indexed for this document yet.
out = list(chat.stream_chat_with_documents("question?", [doc]))
assert chat.CHAT_NO_CONTENT_MESSAGE in out