diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index b2710c379..63e9267b0 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -75,82 +75,6 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str ) -def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int): - from llama_index.core.base.base_retriever import BaseRetriever - from llama_index.core.schema import NodeWithScore - from llama_index.core.vector_stores import VectorStoreQuery - - class DocumentFilteredFaissRetriever(BaseRetriever): - def __init__(self): - super().__init__() - self._cached_query_str = None - self._cached_nodes = [] - - def _retrieve(self, query_bundle): - if query_bundle.query_str == self._cached_query_str: - return self._cached_nodes - - if query_bundle.embedding is None: - query_bundle.embedding = ( - index._embed_model.get_agg_embedding_from_queries( - query_bundle.embedding_strs, - ) - ) - - faiss_index = index.vector_store._faiss_index - max_top_k = faiss_index.ntotal - if max_top_k == 0: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = [] - return [] - - query_top_k = min(max(similarity_top_k, 1), max_top_k) - allowed_nodes: list[NodeWithScore] = [] - seen_node_ids: set[str] = set() - - while query_top_k <= max_top_k: - query_result = index.vector_store.query( - VectorStoreQuery( - query_embedding=query_bundle.embedding, - similarity_top_k=query_top_k, - ), - ) - - for vector_id, score in zip( - query_result.ids or [], - query_result.similarities or [], - strict=False, - ): - node_id = index.index_struct.nodes_dict.get(vector_id) - if node_id is None or node_id in seen_node_ids: - continue - - node = index.docstore.docs.get(node_id) - if node is None or node.metadata.get("document_id") not in doc_ids: - continue - - seen_node_ids.add(node_id) - allowed_nodes.append(NodeWithScore(node=node, score=score)) - - if len(allowed_nodes) >= similarity_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - if query_top_k == max_top_k: - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - query_top_k = min(query_top_k * 2, max_top_k) - - self._cached_query_str = query_bundle.query_str - self._cached_nodes = allowed_nodes - return allowed_nodes - - return DocumentFilteredFaissRetriever() - - def stream_chat_with_documents(query_str: str, documents: list[Document]): try: yield from _stream_chat_with_documents(query_str, documents) @@ -160,31 +84,39 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): def _stream_chat_with_documents(query_str: str, documents: list[Document]): - client = AIClient() + from llama_index.core.prompts import PromptTemplate + from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.response_synthesizers import get_response_synthesizer + from llama_index.core.retrievers import VectorIndexRetriever + from llama_index.core.vector_stores.types import FilterOperator + from llama_index.core.vector_stores.types import MetadataFilter + from llama_index.core.vector_stores.types import MetadataFilters + index = load_or_build_index() doc_ids = [str(doc.pk) for doc in documents] + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="document_id", + operator=FilterOperator.IN, + value=doc_ids, + ), + ], + ) - # Filter only the node(s) that match the document IDs - nodes = [ - node - for node in index.docstore.docs.values() - if node.metadata.get("document_id") in doc_ids - ] - - if len(nodes) == 0: + # No indexed content for these documents -> bail early (before touching the LLM). + if not index.vector_store.get_nodes(filters=filters): logger.warning("No nodes found for the given documents.") yield CHAT_NO_CONTENT_MESSAGE return - from llama_index.core.prompts import PromptTemplate - from llama_index.core.query_engine import RetrieverQueryEngine - from llama_index.core.response_synthesizers import get_response_synthesizer + client = AIClient() - retriever = _get_document_filtered_retriever( - index, - set(doc_ids), - CHAT_RETRIEVER_TOP_K, + retriever = VectorIndexRetriever( + index=index, + similarity_top_k=CHAT_RETRIEVER_TOP_K, + filters=filters, ) top_nodes = retriever.retrieve(query_str) @@ -202,7 +134,6 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): text_qa_template=prompt_template, streaming=True, ) - query_engine = RetrieverQueryEngine.from_args( retriever=retriever, llm=client.llm, @@ -211,9 +142,7 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): ) logger.debug("Document chat query: %s", query_str) - response_stream = query_engine.query(query_str) - for chunk in response_stream.response_gen: yield chunk sys.stdout.flush() diff --git a/src/paperless_ai/tests/conftest.py b/src/paperless_ai/tests/conftest.py index 2d71476c7..67fcf0faa 100644 --- a/src/paperless_ai/tests/conftest.py +++ b/src/paperless_ai/tests/conftest.py @@ -1,10 +1,40 @@ from pathlib import Path +from unittest.mock import patch import pytest +from llama_index.core.base.embeddings.base import BaseEmbedding from pytest_django.fixtures import SettingsWrapper @pytest.fixture -def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper): +def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path: settings.LLM_INDEX_DIR = tmp_path return tmp_path + + +class FakeEmbedding(BaseEmbedding): + def _aget_query_embedding(self, query: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def _get_query_embedding(self, query: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def _get_text_embedding(self, text: str) -> list[float]: + return [0.1] * self.get_query_embedding_dim() + + def get_query_embedding_dim(self) -> int: + return 384 + + +@pytest.fixture +def mock_embed_model(): + fake = FakeEmbedding() + with ( + patch("paperless_ai.indexing.get_embedding_model") as mock_index, + patch( + "paperless_ai.embedding.get_embedding_model", + ) as mock_embedding, + ): + mock_index.return_value = fake + mock_embedding.return_value = fake + yield mock_index diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index 294d57c91..201edf052 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -9,7 +9,6 @@ from django.contrib.auth.models import User from django.test import override_settings from django.utils import timezone from faker import Faker -from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import PaperlessTask @@ -30,35 +29,6 @@ def real_document(db): ) -@pytest.fixture -def mock_embed_model(): - fake = FakeEmbedding() - with ( - patch("paperless_ai.indexing.get_embedding_model") as mock_index, - patch( - "paperless_ai.embedding.get_embedding_model", - ) as mock_embedding, - ): - mock_index.return_value = fake - mock_embedding.return_value = fake - yield mock_index - - -class FakeEmbedding(BaseEmbedding): - # TODO: maybe a better way to do this? - def _aget_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_query_embedding(self, query: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def _get_text_embedding(self, text: str) -> list[float]: - return [0.1] * self.get_query_embedding_dim() - - def get_query_embedding_dim(self) -> int: - return 384 # Match your real FAISS config - - @pytest.mark.django_db def test_build_document_node(real_document) -> None: nodes = indexing.build_document_node(real_document) diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index d72b22f32..f7edc3fc9 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -5,9 +5,9 @@ from unittest.mock import patch import pytest from llama_index.core.schema import TextNode +from paperless_ai import chat from paperless_ai.chat import CHAT_ERROR_MESSAGE from paperless_ai.chat import CHAT_METADATA_DELIMITER -from paperless_ai.chat import _get_document_filtered_retriever from paperless_ai.chat import stream_chat_with_documents @@ -58,91 +58,6 @@ def assert_chat_output( } -def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None: - mock_index.index_struct.nodes_dict = { - str(vector_id): node.node_id for vector_id, node in enumerate(nodes) - } - mock_index.docstore.docs.get.side_effect = { - node.node_id: node for node in nodes - }.get - mock_index.vector_store._faiss_index.ntotal = len(nodes) - mock_index.vector_store.query.return_value = MagicMock( - ids=list(mock_index.index_struct.nodes_dict), - similarities=[0.1] * len(nodes), - ) - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - -def test_document_filtered_retriever_expands_filters_and_caches() -> None: - allowed_node1 = TextNode( - text="Allowed content 1.", - metadata={"document_id": "1", "title": "Allowed 1"}, - ) - allowed_node2 = TextNode( - text="Allowed content 2.", - metadata={"document_id": "2", "title": "Allowed 2"}, - ) - foreign_node = TextNode( - text="Foreign content.", - metadata={"document_id": "3", "title": "Foreign"}, - ) - missing_node = TextNode( - text="Missing content.", - metadata={"document_id": "1", "title": "Missing"}, - ) - - mock_index = MagicMock() - mock_index.index_struct.nodes_dict = { - "0": foreign_node.node_id, - "1": missing_node.node_id, - "2": allowed_node1.node_id, - "3": allowed_node2.node_id, - } - mock_index.docstore.docs.get.side_effect = { - allowed_node1.node_id: allowed_node1, - allowed_node2.node_id: allowed_node2, - foreign_node.node_id: foreign_node, - }.get - mock_index.vector_store._faiss_index.ntotal = 4 - mock_index.vector_store.query.side_effect = [ - MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]), - MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]), - ] - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - retriever = _get_document_filtered_retriever( - mock_index, - {"1", "2"}, - similarity_top_k=2, - ) - - nodes = retriever.retrieve("question") - cached_nodes = retriever.retrieve("question") - - assert [node.node.node_id for node in nodes] == [ - allowed_node1.node_id, - allowed_node2.node_id, - ] - assert cached_nodes == nodes - assert mock_index.vector_store.query.call_count == 2 - assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1 - - -def test_document_filtered_retriever_handles_empty_faiss_index() -> None: - mock_index = MagicMock() - mock_index.vector_store._faiss_index.ntotal = 0 - mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536 - - retriever = _get_document_filtered_retriever( - mock_index, - {"1"}, - similarity_top_k=2, - ) - - assert retriever.retrieve("question") == [] - mock_index.vector_store.query.assert_not_called() - - @pytest.mark.django_db def test_stream_chat_with_one_document_retrieval( mock_document, @@ -164,17 +79,31 @@ def test_stream_chat_with_one_document_retrieval( metadata={"document_id": str(mock_document.pk), "title": "Test Document"}, ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node] - add_vector_query_results(mock_index, [mock_node]) + # Simulate get_nodes returning nodes (content exists) + mock_index.vector_store.get_nodes.return_value = [mock_node] mock_load_index.return_value = mock_index + mock_retriever_instance = MagicMock() + mock_retriever_instance.retrieve.return_value = [ + MagicMock( + metadata={ + "document_id": str(mock_document.pk), + "title": "Test Document", + }, + ), + ] + mock_response_stream = MagicMock() mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) mock_query_engine = MagicMock() mock_query_engine_cls.return_value = mock_query_engine mock_query_engine.query.return_value = mock_response_stream - output = list(stream_chat_with_documents("What is this?", [mock_document])) + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=mock_retriever_instance, + ): + output = list(stream_chat_with_documents("What is this?", [mock_document])) mock_query_engine.query.assert_called_once_with("What is this?") patch_embed_nodes.assert_not_called() @@ -196,12 +125,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, ): - # Mock AIClient and LLM mock_client = MagicMock() mock_client_cls.return_value = mock_client mock_client.llm = MagicMock() - # Create two real TextNodes mock_node1 = TextNode( text="Content for doc 1.", metadata={"document_id": "1", "title": "Document 1"}, @@ -210,41 +137,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non text="Content for doc 2.", metadata={"document_id": "2", "title": "Document 2"}, ) - mock_duplicate_node = TextNode( - text="More content for doc 1.", - metadata={"document_id": "1", "title": "Document 1 Duplicate"}, - ) - mock_foreign_node = TextNode( - text="Content for doc 3.", - metadata={"document_id": "3", "title": "Document 3"}, - ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [ - mock_node1, - mock_node2, - mock_duplicate_node, - mock_foreign_node, - ] - add_vector_query_results( - mock_index, - [mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node], - ) + # Simulate get_nodes returning nodes (content exists) + mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2] mock_load_index.return_value = mock_index - # Mock response stream + mock_retriever_instance = MagicMock() + mock_retriever_instance.retrieve.return_value = [ + MagicMock(metadata={"document_id": "1", "title": "Document 1"}), + MagicMock(metadata={"document_id": "2", "title": "Document 2"}), + ] + mock_response_stream = MagicMock() mock_response_stream.response_gen = iter(["chunk1", "chunk2"]) - # Mock RetrieverQueryEngine mock_query_engine = MagicMock() mock_query_engine_cls.return_value = mock_query_engine mock_query_engine.query.return_value = mock_response_stream - # Fake documents doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf") doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf") - output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + return_value=mock_retriever_instance, + ): + output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) mock_query_engine.query.assert_called_once_with("What's up?") patch_embed_nodes.assert_not_called() @@ -268,8 +186,8 @@ def test_stream_chat_no_matching_nodes() -> None: mock_client.llm = MagicMock() mock_index = MagicMock() - # No matching nodes - mock_index.docstore.docs.values.return_value = [] + # No matching nodes in the store + mock_index.vector_store.get_nodes.return_value = [] mock_load_index.return_value = mock_index output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) @@ -281,28 +199,42 @@ def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None: with ( patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, - patch( - "paperless_ai.chat._get_document_filtered_retriever", - ) as mock_get_retriever, ): mock_client = MagicMock() mock_client_cls.return_value = mock_client mock_client.llm = MagicMock() - mock_node = TextNode( - text="This is node content.", - metadata={"document_id": "1", "title": "Test Document"}, - ) mock_index = MagicMock() - mock_index.docstore.docs.values.return_value = [mock_node] + # Nodes found so we get past the pre-check + mock_index.vector_store.get_nodes.return_value = [MagicMock()] mock_load_index.return_value = mock_index - mock_retriever = MagicMock() - mock_retriever.retrieve.side_effect = RuntimeError("private provider detail") - mock_get_retriever.return_value = mock_retriever + with patch( + "llama_index.core.retrievers.VectorIndexRetriever", + ) as mock_retriever_cls: + mock_retriever = MagicMock() + mock_retriever.retrieve.side_effect = RuntimeError( + "private provider detail", + ) + mock_retriever_cls.return_value = mock_retriever - output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) + output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)])) assert output == [CHAT_ERROR_MESSAGE] assert "Failed to stream document chat response" in caplog.text assert "private provider detail" in caplog.text + + +@pytest.mark.django_db +class TestStreamChatRetrieval: + def test_no_nodes_yields_no_content_message( + self, + temp_llm_index_dir, + mock_embed_model, + ) -> None: + from documents.tests.factories import DocumentFactory + + doc = DocumentFactory.create(content="hello world") + # Nothing indexed for this document yet. + out = list(chat.stream_chat_with_documents("question?", [doc])) + assert chat.CHAT_NO_CONTENT_MESSAGE in out