diff --git a/src-ui/src/app/components/chat/chat/chat.component.html b/src-ui/src/app/components/chat/chat/chat.component.html
index f0b61b805..c5cada978 100644
--- a/src-ui/src/app/components/chat/chat/chat.component.html
+++ b/src-ui/src/app/components/chat/chat/chat.component.html
@@ -8,10 +8,21 @@
@for (message of messages; track message) {
-
- {{ message.content }}
- @if (message.isStreaming) { | }
-
+
+
+ {{ message.content }}
+ @if (message.isStreaming) { | }
+
+ @if (message.role === 'assistant' && message.references?.length) {
+
+ }
+
}
diff --git a/src-ui/src/app/components/chat/chat/chat.component.scss b/src-ui/src/app/components/chat/chat/chat.component.scss
index 4b00cce1b..ccd714f3c 100644
--- a/src-ui/src/app/components/chat/chat/chat.component.scss
+++ b/src-ui/src/app/components/chat/chat/chat.component.scss
@@ -7,6 +7,10 @@
overflow-y: auto;
}
+.chat-references {
+  // NOTE(review): presumably resets an inherited chat/monospace font for the
+  // reference links — confirm against the surrounding chat styles.
+  font-family: var(--bs-font-sans-serif);
+}
+
.dropdown-toggle::after {
display: none;
}
diff --git a/src-ui/src/app/components/chat/chat/chat.component.spec.ts b/src-ui/src/app/components/chat/chat/chat.component.spec.ts
index 0ccb04a99..a35117dc5 100644
--- a/src-ui/src/app/components/chat/chat/chat.component.spec.ts
+++ b/src-ui/src/app/components/chat/chat/chat.component.spec.ts
@@ -3,9 +3,13 @@ import { provideHttpClientTesting } from '@angular/common/http/testing'
import { ElementRef } from '@angular/core'
import { ComponentFixture, TestBed } from '@angular/core/testing'
import { NavigationEnd, Router } from '@angular/router'
+import { RouterTestingModule } from '@angular/router/testing'
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { Subject } from 'rxjs'
-import { ChatService } from 'src/app/services/chat.service'
+import {
+ CHAT_METADATA_DELIMITER,
+ ChatService,
+} from 'src/app/services/chat.service'
import { ChatComponent } from './chat.component'
describe('ChatComponent', () => {
@@ -18,7 +22,11 @@ describe('ChatComponent', () => {
beforeEach(async () => {
TestBed.configureTestingModule({
- imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent],
+ imports: [
+ NgxBootstrapIconsModule.pick(allIcons),
+ RouterTestingModule,
+ ChatComponent,
+ ],
providers: [
provideHttpClient(withInterceptorsFromDi()),
provideHttpClientTesting(),
@@ -84,6 +92,57 @@ describe('ChatComponent', () => {
expect(component.messages[1].isStreaming).toBe(false)
})
+ it('should parse references from the metadata trailer without showing it', () => {
+ component.input = 'Hello'
+ component.sendMessage()
+
+ mockStream$.next(
+ `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
+ )
+ jest.advanceTimersByTime(1000)
+
+ expect(component.messages[1].content).toBe('Hi there')
+ expect(component.messages[1].references).toEqual([
+ { id: 42, title: 'Bread Recipe' },
+ ])
+ })
+
+ it('should render document reference links under assistant messages', () => {
+ component.input = 'Hello'
+ component.sendMessage()
+
+ mockStream$.next(
+ `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
+ )
+ jest.advanceTimersByTime(1000)
+ fixture.detectChanges()
+
+ const link = fixture.nativeElement.querySelector('.chat-references a')
+ expect(link.textContent).toContain('Bread Recipe')
+ expect(link.getAttribute('href')).toContain('/documents/42')
+ })
+
+ it('should remove delimiter fragments that were already streamed', () => {
+ component.input = 'Hello'
+ component.sendMessage()
+
+ mockStream$.next(`Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`)
+ jest.advanceTimersByTime(1000)
+ expect(component.messages[1].content).toBe(
+ `Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`
+ )
+
+ mockStream$.next(
+ `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
+ )
+ jest.advanceTimersByTime(1000)
+
+ expect(component.messages[1].content).toBe('Hi there')
+ expect(component.messages[1].references).toEqual([
+ { id: 42, title: 'Bread Recipe' },
+ ])
+ })
+
it('should handle errors during streaming', () => {
component.input = 'Hello'
component.sendMessage()
diff --git a/src-ui/src/app/components/chat/chat/chat.component.ts b/src-ui/src/app/components/chat/chat/chat.component.ts
index 50d27e0b1..ca17d4825 100644
--- a/src-ui/src/app/components/chat/chat/chat.component.ts
+++ b/src-ui/src/app/components/chat/chat/chat.component.ts
@@ -1,16 +1,21 @@
import { Component, ElementRef, inject, OnInit, ViewChild } from '@angular/core'
import { FormsModule, ReactiveFormsModule } from '@angular/forms'
-import { NavigationEnd, Router } from '@angular/router'
+import { NavigationEnd, Router, RouterModule } from '@angular/router'
import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { filter, map } from 'rxjs'
-import { ChatMessage, ChatService } from 'src/app/services/chat.service'
+import {
+ ChatMessage,
+ ChatService,
+ parseChatResponse,
+} from 'src/app/services/chat.service'
@Component({
selector: 'pngx-chat',
imports: [
FormsModule,
ReactiveFormsModule,
+ RouterModule,
NgxBootstrapIconsModule,
NgbDropdownModule,
],
@@ -70,13 +75,24 @@ export class ChatComponent implements OnInit {
this.messages.push(assistantMessage)
this.loading = true
- let lastPartialLength = 0
+ let lastVisibleContent = ''
this.chatService.streamChat(this.documentId, this.input).subscribe({
next: (chunk) => {
- const delta = chunk.substring(lastPartialLength)
- lastPartialLength = chunk.length
- this.enqueueTypewriter(delta, assistantMessage)
+ const nextResponse = parseChatResponse(chunk)
+
+ if (nextResponse.content.length < lastVisibleContent.length) {
+ this.resetTypewriter(assistantMessage, nextResponse.content)
+ lastVisibleContent = nextResponse.content
+ } else {
+ const visibleDelta = nextResponse.content.substring(
+ lastVisibleContent.length
+ )
+ lastVisibleContent = nextResponse.content
+ this.enqueueTypewriter(visibleDelta, assistantMessage)
+ }
+
+ assistantMessage.references = nextResponse.references
},
error: () => {
assistantMessage.content += '\n\n⚠️ Error receiving response.'
@@ -93,6 +109,13 @@ export class ChatComponent implements OnInit {
this.input = ''
}
+  /**
+   * Discards any queued typewriter output and rewrites the message content
+   * wholesale. Used when a newly parsed chunk is *shorter* than what is
+   * already displayed — i.e. a partially streamed metadata delimiter was
+   * rendered and must now be retracted.
+   *
+   * NOTE(review): assumes the typewriter loop checks `typewriterActive`
+   * before emitting queued characters — confirm no already-scheduled timeout
+   * can still append stale text after this reset.
+   */
+  private resetTypewriter(message: ChatMessage, content: string): void {
+    this.typewriterBuffer = []
+    this.typewriterActive = false
+    message.content = content
+    this.scrollToBottom()
+  }
+
enqueueTypewriter(chunk: string, message: ChatMessage): void {
if (!chunk) return
diff --git a/src-ui/src/app/services/chat.service.spec.ts b/src-ui/src/app/services/chat.service.spec.ts
index b8ca957cb..d64897afe 100644
--- a/src-ui/src/app/services/chat.service.spec.ts
+++ b/src-ui/src/app/services/chat.service.spec.ts
@@ -9,7 +9,11 @@ import {
} from '@angular/common/http/testing'
import { TestBed } from '@angular/core/testing'
import { environment } from 'src/environments/environment'
-import { ChatService } from './chat.service'
+import {
+ CHAT_METADATA_DELIMITER,
+ ChatService,
+ parseChatResponse,
+} from './chat.service'
describe('ChatService', () => {
let service: ChatService
@@ -55,4 +59,22 @@ describe('ChatService', () => {
partialText: mockResponse,
} as any)
})
+
+ it('should parse chat references from the metadata trailer', () => {
+ const parsed = parseChatResponse(
+ `Answer text${CHAT_METADATA_DELIMITER}{"references":[{"id":1,"title":"Document 1"}]}`
+ )
+
+ expect(parsed.content).toBe('Answer text')
+ expect(parsed.references).toEqual([{ id: 1, title: 'Document 1' }])
+ })
+
+ it('should hide incomplete metadata trailer from the visible content', () => {
+ const parsed = parseChatResponse(
+ `Answer text${CHAT_METADATA_DELIMITER}{"references"`
+ )
+
+ expect(parsed.content).toBe('Answer text')
+ expect(parsed.references).toBeUndefined()
+ })
})
diff --git a/src-ui/src/app/services/chat.service.ts b/src-ui/src/app/services/chat.service.ts
index 9ddfb8330..bc4acbbdf 100644
--- a/src-ui/src/app/services/chat.service.ts
+++ b/src-ui/src/app/services/chat.service.ts
@@ -11,6 +11,46 @@ export interface ChatMessage {
role: 'user' | 'assistant'
content: string
isStreaming?: boolean
+ references?: ChatReference[]
+}
+
+/** A document cited by the assistant's answer, sent in the metadata trailer. */
+export interface ChatReference {
+  id: number
+  title: string
+}
+
+/** Result of splitting a streamed chat payload into visible text and metadata. */
+export interface ParsedChatResponse {
+  content: string
+  /** Absent until a complete metadata trailer has been received and parsed. */
+  references?: ChatReference[]
+}
+
+// Must stay in sync with CHAT_METADATA_DELIMITER emitted by the backend
+// (src/paperless_ai/chat.py).
+export const CHAT_METADATA_DELIMITER = '\n\n__PAPERLESS_CHAT_METADATA__'
+
+/**
+ * Splits a streamed chat payload into the user-visible answer text and the
+ * machine-readable reference list the backend appends after
+ * CHAT_METADATA_DELIMITER.
+ *
+ * While the trailer is still streaming in (delimiter present but its JSON is
+ * incomplete), the trailer is hidden from `content` and `references` stays
+ * undefined; once the JSON parses, `references` defaults to an empty array.
+ */
+export function parseChatResponse(response: string): ParsedChatResponse {
+  const delimiterIndex = response.indexOf(CHAT_METADATA_DELIMITER)
+
+  // No delimiter yet: the whole payload is visible answer text.
+  if (delimiterIndex === -1) {
+    return { content: response }
+  }
+
+  const metadataString = response.slice(
+    delimiterIndex + CHAT_METADATA_DELIMITER.length
+  )
+
+  try {
+    const metadata = JSON.parse(metadataString) as {
+      references?: ChatReference[]
+    }
+
+    return {
+      content: response.slice(0, delimiterIndex),
+      references: metadata.references ?? [],
+    }
+  } catch {
+    // Trailer JSON is incomplete (still streaming) or malformed: hide it from
+    // the visible content but report no references yet.
+    return {
+      content: response.slice(0, delimiterIndex),
+    }
+  }
+}
@Injectable({
diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py
index 33603c45e..f149a5fc5 100644
--- a/src/paperless_ai/chat.py
+++ b/src/paperless_ai/chat.py
@@ -1,3 +1,4 @@
+import json
import logging
import sys
@@ -9,6 +10,8 @@ logger = logging.getLogger("paperless_ai.chat")
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
SINGLE_DOC_SNIPPET_CHARS = 800
+CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__"
+MAX_CHAT_REFERENCES = 3
CHAT_PROMPT_TMPL = """Context information is below.
---------------------
@@ -19,6 +22,52 @@ CHAT_PROMPT_TMPL = """Context information is below.
Answer:"""
+def _build_document_reference(
+ document: Document,
+ title: str | None = None,
+) -> dict[str, int | str]:
+ return {
+ "id": document.pk,
+ "title": title or document.title or document.filename,
+ }
+
+
+def _get_document_references(
+    documents: list[Document],
+    top_nodes: list,
+) -> list[dict[str, int | str]]:
+    """Derive the cited-document list from the retriever's top nodes.
+
+    Walks ``top_nodes`` in retrieval order and collects at most
+    ``MAX_CHAT_REFERENCES`` unique documents, restricted to the documents the
+    chat is scoped to (``documents``) so stray index entries are ignored.
+
+    Args:
+        documents: The documents the chat request is scoped to.
+        top_nodes: Retrieved index nodes; each is expected to carry a
+            ``document_id`` (and optionally ``title``) in its metadata.
+
+    Returns:
+        An ordered, de-duplicated list of reference dicts.
+    """
+    allowed_documents = {doc.pk: doc for doc in documents}
+    references: list[dict[str, int | str]] = []
+    seen_document_ids: set[int] = set()
+
+    for node in top_nodes:
+        try:
+            document_id = int(node.metadata["document_id"])
+        except (KeyError, TypeError, ValueError):  # pragma: no cover
+            # Defensive: skip nodes without a usable document_id.
+            continue
+
+        # Keep the first occurrence only, and never cite a document outside
+        # the queried set.
+        if document_id in seen_document_ids or document_id not in allowed_documents:
+            continue
+
+        seen_document_ids.add(document_id)
+        document = allowed_documents[document_id]
+        references.append(
+            _build_document_reference(document, node.metadata.get("title")),
+        )
+
+        if len(references) >= MAX_CHAT_REFERENCES:  # pragma: no cover
+            break
+
+    return references
+
+
+def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str:
+    """Serialize references as the trailer appended after the streamed answer.
+
+    The trailer is ``CHAT_METADATA_DELIMITER`` followed by compact JSON; the
+    frontend (chat.service.ts) splits on the same delimiter to hide the
+    trailer from the visible message and extract the reference links.
+    """
+    return (
+        f"{CHAT_METADATA_DELIMITER}"
+        f"{json.dumps({'references': references}, separators=(',', ':'))}"
+    )
+
+
def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_or_build_index()
@@ -49,6 +98,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
if len(documents) == 1:
# Just one doc — provide full content
doc = documents[0]
+ references = [_build_document_reference(doc)]
# TODO: include document metadata in the context
content = doc.content or ""
context_body = content
@@ -78,6 +128,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
yield "Sorry, I couldn't find any content to answer your question."
return
+ references = _get_document_references(documents, top_nodes)
context = "\n\n".join(
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
for node in top_nodes
@@ -102,3 +153,6 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
+
+ if references:
+ yield _format_chat_metadata_trailer(references)
diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py
index 902c907f2..5e26ca0af 100644
--- a/src/paperless_ai/tests/test_chat.py
+++ b/src/paperless_ai/tests/test_chat.py
@@ -1,3 +1,4 @@
+import json
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -5,6 +6,7 @@ import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
+from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import stream_chat_with_documents
@@ -40,6 +42,21 @@ def mock_document():
return doc
+def assert_chat_output(
+    output: list[str],
+    *,
+    expected_chunks: list[str],
+    expected_references: list[dict[str, int | str]],
+) -> None:
+    """Assert a streamed chat output ends with the expected metadata trailer.
+
+    Checks that every element before the last matches ``expected_chunks`` and
+    that the final element is ``CHAT_METADATA_DELIMITER`` followed by a JSON
+    object carrying exactly ``expected_references``.
+    """
+    assert output[:-1] == expected_chunks
+
+    trailer = output[-1]
+    assert trailer.startswith(CHAT_METADATA_DELIMITER)
+    assert json.loads(trailer.removeprefix(CHAT_METADATA_DELIMITER)) == {
+        "references": expected_references,
+    }
+
+
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
@@ -68,7 +85,13 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
output = list(stream_chat_with_documents("What is this?", [mock_document]))
- assert output == ["chunk1", "chunk2"]
+ assert_chat_output(
+ output,
+ expected_chunks=["chunk1", "chunk2"],
+ expected_references=[
+ {"id": mock_document.pk, "title": "Test Document"},
+ ],
+ )
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None:
@@ -100,7 +123,20 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
mock_retriever = MagicMock()
- mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
+ mock_duplicate_node = TextNode(
+ text="More content for doc 1.",
+ metadata={"document_id": "1", "title": "Document 1 Duplicate"},
+ )
+ mock_foreign_node = TextNode(
+ text="Content for doc 3.",
+ metadata={"document_id": "3", "title": "Document 3"},
+ )
+ mock_retriever.retrieve.return_value = [
+ mock_node1,
+ mock_duplicate_node,
+ mock_node2,
+ mock_foreign_node,
+ ]
mock_as_retriever.return_value = mock_retriever
# Mock response stream
@@ -113,12 +149,19 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
- doc1 = MagicMock(pk=1)
- doc2 = MagicMock(pk=2)
+ doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
+ doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
- assert output == ["chunk1", "chunk2"]
+ assert_chat_output(
+ output,
+ expected_chunks=["chunk1", "chunk2"],
+ expected_references=[
+ {"id": 1, "title": "Document 1"},
+ {"id": 2, "title": "Document 2"},
+ ],
+ )
def test_stream_chat_no_matching_nodes() -> None: