From 354df34e47935466da67f7d72f72fcb4d2113a8c Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Tue, 28 Apr 2026 13:00:20 -0700 Subject: [PATCH] Enhancement: chat message document links (#12670) --- .../components/chat/chat/chat.component.html | 19 ++++-- .../components/chat/chat/chat.component.scss | 4 ++ .../chat/chat/chat.component.spec.ts | 63 ++++++++++++++++++- .../components/chat/chat/chat.component.ts | 35 +++++++++-- src-ui/src/app/services/chat.service.spec.ts | 24 ++++++- src-ui/src/app/services/chat.service.ts | 40 ++++++++++++ src/paperless_ai/chat.py | 54 ++++++++++++++++ src/paperless_ai/tests/test_chat.py | 53 ++++++++++++++-- 8 files changed, 274 insertions(+), 18 deletions(-) diff --git a/src-ui/src/app/components/chat/chat/chat.component.html b/src-ui/src/app/components/chat/chat/chat.component.html index f0b61b805..c5cada978 100644 --- a/src-ui/src/app/components/chat/chat/chat.component.html +++ b/src-ui/src/app/components/chat/chat/chat.component.html @@ -8,10 +8,21 @@
@for (message of messages; track message) {
- - {{ message.content }} - @if (message.isStreaming) { | } - +
+ + {{ message.content }} + @if (message.isStreaming) { | } + + @if (message.role === 'assistant' && message.references?.length) { +
+ @for (reference of message.references; track reference.id) { + + {{ reference.title }} + + } +
+ } +
}
diff --git a/src-ui/src/app/components/chat/chat/chat.component.scss b/src-ui/src/app/components/chat/chat/chat.component.scss index 4b00cce1b..ccd714f3c 100644 --- a/src-ui/src/app/components/chat/chat/chat.component.scss +++ b/src-ui/src/app/components/chat/chat/chat.component.scss @@ -7,6 +7,10 @@ overflow-y: auto; } +.chat-references { + font-family: var(--bs-font-sans-serif); +} + .dropdown-toggle::after { display: none; } diff --git a/src-ui/src/app/components/chat/chat/chat.component.spec.ts b/src-ui/src/app/components/chat/chat/chat.component.spec.ts index 0ccb04a99..a35117dc5 100644 --- a/src-ui/src/app/components/chat/chat/chat.component.spec.ts +++ b/src-ui/src/app/components/chat/chat/chat.component.spec.ts @@ -3,9 +3,13 @@ import { provideHttpClientTesting } from '@angular/common/http/testing' import { ElementRef } from '@angular/core' import { ComponentFixture, TestBed } from '@angular/core/testing' import { NavigationEnd, Router } from '@angular/router' +import { RouterTestingModule } from '@angular/router/testing' import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' import { Subject } from 'rxjs' -import { ChatService } from 'src/app/services/chat.service' +import { + CHAT_METADATA_DELIMITER, + ChatService, +} from 'src/app/services/chat.service' import { ChatComponent } from './chat.component' describe('ChatComponent', () => { @@ -18,7 +22,11 @@ describe('ChatComponent', () => { beforeEach(async () => { TestBed.configureTestingModule({ - imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent], + imports: [ + NgxBootstrapIconsModule.pick(allIcons), + RouterTestingModule, + ChatComponent, + ], providers: [ provideHttpClient(withInterceptorsFromDi()), provideHttpClientTesting(), @@ -84,6 +92,57 @@ describe('ChatComponent', () => { expect(component.messages[1].isStreaming).toBe(false) }) + it('should parse references from the metadata trailer without showing it', () => { + component.input = 'Hello' + component.sendMessage() + + mockStream$.next( + `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}` + ) + jest.advanceTimersByTime(1000) + + expect(component.messages[1].content).toBe('Hi there') + expect(component.messages[1].references).toEqual([ + { id: 42, title: 'Bread Recipe' }, + ]) + }) + + it('should render document reference links under assistant messages', () => { + component.input = 'Hello' + component.sendMessage() + + mockStream$.next( + `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}` + ) + jest.advanceTimersByTime(1000) + fixture.detectChanges() + + const link = fixture.nativeElement.querySelector('.chat-references a') + expect(link.textContent).toContain('Bread Recipe') + expect(link.getAttribute('href')).toContain('/documents/42') + }) + + it('should remove delimiter fragments that were already streamed', () => { + component.input = 'Hello' + component.sendMessage() + + mockStream$.next(`Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`) + jest.advanceTimersByTime(1000) + expect(component.messages[1].content).toBe( + `Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}` + ) + + mockStream$.next( + `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}` + ) + jest.advanceTimersByTime(1000) + + expect(component.messages[1].content).toBe('Hi there') + expect(component.messages[1].references).toEqual([ + { id: 42, title: 'Bread Recipe' }, + ]) + }) + it('should handle errors during streaming', () => { component.input = 'Hello' component.sendMessage() diff --git a/src-ui/src/app/components/chat/chat/chat.component.ts b/src-ui/src/app/components/chat/chat/chat.component.ts index 50d27e0b1..ca17d4825 100644 --- a/src-ui/src/app/components/chat/chat/chat.component.ts +++ b/src-ui/src/app/components/chat/chat/chat.component.ts @@ -1,16 +1,21 @@ import { Component, ElementRef, inject, OnInit, ViewChild } from '@angular/core' import { FormsModule, ReactiveFormsModule } from '@angular/forms' -import { NavigationEnd, Router } from '@angular/router' +import { NavigationEnd, Router, RouterModule } from '@angular/router' import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap' import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons' import { filter, map } from 'rxjs' -import { ChatMessage, ChatService } from 'src/app/services/chat.service' +import { + ChatMessage, + ChatService, + parseChatResponse, +} from 'src/app/services/chat.service' @Component({ selector: 'pngx-chat', imports: [ FormsModule, ReactiveFormsModule, + RouterModule, NgxBootstrapIconsModule, NgbDropdownModule, ], @@ -70,13 +75,24 @@ export class ChatComponent implements OnInit { this.messages.push(assistantMessage) this.loading = true - let lastPartialLength = 0 + let lastVisibleContent = '' this.chatService.streamChat(this.documentId, this.input).subscribe({ next: (chunk) => { - const delta = chunk.substring(lastPartialLength) - lastPartialLength = chunk.length - this.enqueueTypewriter(delta, assistantMessage) + const nextResponse = parseChatResponse(chunk) + + if (nextResponse.content.length < lastVisibleContent.length) { + this.resetTypewriter(assistantMessage, nextResponse.content) + lastVisibleContent = nextResponse.content + } else { + const visibleDelta = nextResponse.content.substring( + lastVisibleContent.length + ) + lastVisibleContent = nextResponse.content + this.enqueueTypewriter(visibleDelta, assistantMessage) + } + + assistantMessage.references = nextResponse.references }, error: () => { assistantMessage.content += '\n\n⚠️ Error receiving response.' @@ -93,6 +109,13 @@ export class ChatComponent implements OnInit { this.input = '' } + private resetTypewriter(message: ChatMessage, content: string): void { + this.typewriterBuffer = [] + this.typewriterActive = false + message.content = content + this.scrollToBottom() + } + enqueueTypewriter(chunk: string, message: ChatMessage): void { if (!chunk) return diff --git a/src-ui/src/app/services/chat.service.spec.ts b/src-ui/src/app/services/chat.service.spec.ts index b8ca957cb..d64897afe 100644 --- a/src-ui/src/app/services/chat.service.spec.ts +++ b/src-ui/src/app/services/chat.service.spec.ts @@ -9,7 +9,11 @@ import { } from '@angular/common/http/testing' import { TestBed } from '@angular/core/testing' import { environment } from 'src/environments/environment' -import { ChatService } from './chat.service' +import { + CHAT_METADATA_DELIMITER, + ChatService, + parseChatResponse, +} from './chat.service' describe('ChatService', () => { let service: ChatService @@ -55,4 +59,22 @@ describe('ChatService', () => { partialText: mockResponse, } as any) }) + + it('should parse chat references from the metadata trailer', () => { + const parsed = parseChatResponse( + `Answer text${CHAT_METADATA_DELIMITER}{"references":[{"id":1,"title":"Document 1"}]}` + ) + + expect(parsed.content).toBe('Answer text') + expect(parsed.references).toEqual([{ id: 1, title: 'Document 1' }]) + }) + + it('should hide incomplete metadata trailer from the visible content', () => { + const parsed = parseChatResponse( + `Answer text${CHAT_METADATA_DELIMITER}{"references"` + ) + + expect(parsed.content).toBe('Answer text') + expect(parsed.references).toBeUndefined() + }) }) diff --git a/src-ui/src/app/services/chat.service.ts b/src-ui/src/app/services/chat.service.ts index 9ddfb8330..bc4acbbdf 100644 --- a/src-ui/src/app/services/chat.service.ts +++ b/src-ui/src/app/services/chat.service.ts @@ -11,6 +11,46 @@ export interface ChatMessage { role: 'user' | 'assistant' content: string isStreaming?: boolean + references?: ChatReference[] +} + +export interface ChatReference { + id: number + title: string +} + +export interface ParsedChatResponse { + content: string + references?: ChatReference[] +} + +export const CHAT_METADATA_DELIMITER = '\n\n__PAPERLESS_CHAT_METADATA__' + +export function parseChatResponse(response: string): ParsedChatResponse { + const delimiterIndex = response.indexOf(CHAT_METADATA_DELIMITER) + + if (delimiterIndex === -1) { + return { content: response } + } + + const metadataString = response.slice( + delimiterIndex + CHAT_METADATA_DELIMITER.length + ) + + try { + const metadata = JSON.parse(metadataString) as { + references?: ChatReference[] + } + + return { + content: response.slice(0, delimiterIndex), + references: metadata.references ?? [], + } + } catch { + return { + content: response.slice(0, delimiterIndex), + } + } } @Injectable({ diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index 33603c45e..f149a5fc5 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -1,3 +1,4 @@ +import json import logging import sys @@ -9,6 +10,8 @@ logger = logging.getLogger("paperless_ai.chat") MAX_SINGLE_DOC_CONTEXT_CHARS = 15000 SINGLE_DOC_SNIPPET_CHARS = 800 +CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__" +MAX_CHAT_REFERENCES = 3 CHAT_PROMPT_TMPL = """Context information is below. --------------------- @@ -19,6 +22,52 @@ CHAT_PROMPT_TMPL = """Context information is below. Answer:""" +def _build_document_reference( + document: Document, + title: str | None = None, +) -> dict[str, int | str]: + return { + "id": document.pk, + "title": title or document.title or document.filename, + } + + +def _get_document_references( + documents: list[Document], + top_nodes: list, +) -> list[dict[str, int | str]]: + allowed_documents = {doc.pk: doc for doc in documents} + references: list[dict[str, int | str]] = [] + seen_document_ids: set[int] = set() + + for node in top_nodes: + try: + document_id = int(node.metadata["document_id"]) + except (KeyError, TypeError, ValueError): # pragma: no cover + continue + + if document_id in seen_document_ids or document_id not in allowed_documents: + continue + + seen_document_ids.add(document_id) + document = allowed_documents[document_id] + references.append( + _build_document_reference(document, node.metadata.get("title")), + ) + + if len(references) >= MAX_CHAT_REFERENCES: # pragma: no cover + break + + return references + + +def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str: + return ( + f"{CHAT_METADATA_DELIMITER}" + f"{json.dumps({'references': references}, separators=(',', ':'))}" + ) + + def stream_chat_with_documents(query_str: str, documents: list[Document]): client = AIClient() index = load_or_build_index() @@ -49,6 +98,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): if len(documents) == 1: # Just one doc — provide full content doc = documents[0] + references = [_build_document_reference(doc)] # TODO: include document metadata in the context content = doc.content or "" context_body = content @@ -78,6 +128,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): yield "Sorry, I couldn't find any content to answer your question." return + references = _get_document_references(documents, top_nodes) context = "\n\n".join( f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}" for node in top_nodes @@ -102,3 +153,6 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): for chunk in response_stream.response_gen: yield chunk sys.stdout.flush() + + if references: + yield _format_chat_metadata_trailer(references) diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 902c907f2..5e26ca0af 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -1,3 +1,4 @@ +import json from unittest.mock import MagicMock from unittest.mock import patch @@ -5,6 +6,7 @@ import pytest from llama_index.core import VectorStoreIndex from llama_index.core.schema import TextNode +from paperless_ai.chat import CHAT_METADATA_DELIMITER from paperless_ai.chat import stream_chat_with_documents @@ -40,6 +42,21 @@ def mock_document(): return doc +def assert_chat_output( + output: list[str], + *, + expected_chunks: list[str], + expected_references: list[dict[str, int | str]], +) -> None: + assert output[:-1] == expected_chunks + + trailer = output[-1] + assert trailer.startswith(CHAT_METADATA_DELIMITER) + assert json.loads(trailer.removeprefix(CHAT_METADATA_DELIMITER)) == { + "references": expected_references, + } + + def test_stream_chat_with_one_document_full_content(mock_document) -> None: with ( patch("paperless_ai.chat.AIClient") as mock_client_cls, @@ -68,7 +85,13 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None: output = list(stream_chat_with_documents("What is this?", [mock_document])) - assert output == ["chunk1", "chunk2"] + assert_chat_output( + output, + expected_chunks=["chunk1", "chunk2"], + expected_references=[ + {"id": mock_document.pk, "title": "Test Document"}, + ], + ) def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None: @@ -100,7 +123,20 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non # Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2 mock_retriever = MagicMock() - mock_retriever.retrieve.return_value = [mock_node1, mock_node2] + mock_duplicate_node = TextNode( + text="More content for doc 1.", + metadata={"document_id": "1", "title": "Document 1 Duplicate"}, + ) + mock_foreign_node = TextNode( + text="Content for doc 3.", + metadata={"document_id": "3", "title": "Document 3"}, + ) + mock_retriever.retrieve.return_value = [ + mock_node1, + mock_duplicate_node, + mock_node2, + mock_foreign_node, + ] mock_as_retriever.return_value = mock_retriever # Mock response stream @@ -113,12 +149,19 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non mock_query_engine.query.return_value = mock_response_stream # Fake documents - doc1 = MagicMock(pk=1) - doc2 = MagicMock(pk=2) + doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf") + doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf") output = list(stream_chat_with_documents("What's up?", [doc1, doc2])) - assert output == ["chunk1", "chunk2"] + assert_chat_output( + output, + expected_chunks=["chunk1", "chunk2"], + expected_references=[ + {"id": 1, "title": "Document 1"}, + {"id": 2, "title": "Document 2"}, + ], + ) def test_stream_chat_no_matching_nodes() -> None: