diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index be1da80f6..6465cec9e 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -94,7 +94,7 @@ def _stream_chat_with_documents(query_str: str, documents: list[Document]): index = load_or_build_index() - doc_ids = [str(doc.pk) for doc in documents] + doc_ids = sorted(str(doc.pk) for doc in documents) filters = MetadataFilters( filters=[ MetadataFilter( diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index f7edc3fc9..52e36b15a 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -238,3 +238,46 @@ class TestStreamChatRetrieval: # Nothing indexed for this document yet. out = list(chat.stream_chat_with_documents("question?", [doc])) assert chat.CHAT_NO_CONTENT_MESSAGE in out + + def test_chat_filter_contains_only_requested_document_ids( + self, + temp_llm_index_dir, + mock_embed_model, + mocker, + ) -> None: + """The MetadataFilter passed to the retriever must be scoped to the + requested documents only — content from other indexed documents must + not be surfaced. + """ + from documents.tests.factories import DocumentFactory + from paperless_ai import indexing + + included = DocumentFactory.create(content="included document content") + excluded = DocumentFactory.create(content="excluded document content") + indexing.llm_index_add_or_update_document(included) + indexing.llm_index_add_or_update_document(excluded) + + # VectorIndexRetriever is imported inside _stream_chat_with_documents; + # patch it at the llama_index source so the lazy import picks it up. + captured_filters = [] + mock_retriever = mocker.MagicMock() + mock_retriever.retrieve.return_value = [] + + def capture_retriever(*args, **kwargs): + captured_filters.append(kwargs.get("filters")) + return mock_retriever + + mocker.patch("paperless_ai.chat.AIClient") + mocker.patch( + "llama_index.core.retrievers.VectorIndexRetriever", + side_effect=capture_retriever, + ) + + list(chat.stream_chat_with_documents("question?", [included])) + + assert captured_filters, "VectorIndexRetriever was never constructed" + filt = captured_filters[0] + assert filt is not None, "Retriever must receive a MetadataFilters" + filter_values = filt.filters[0].value + assert str(included.pk) in filter_values + assert str(excluded.pk) not in filter_values