Enhancement: chat message document links (#12670)

This commit is contained in:
shamoon
2026-04-28 13:00:20 -07:00
committed by GitHub
parent 8e6fd010a0
commit 354df34e47
8 changed files with 274 additions and 18 deletions

View File

@@ -8,10 +8,21 @@
<div class="chat-messages font-monospace small">
@for (message of messages; track message) {
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
<span class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
{{ message.content }}
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
</span>
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
<span>
{{ message.content }}
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
</span>
@if (message.role === 'assistant' && message.references?.length) {
<div class="chat-references list-group mt-3">
@for (reference of message.references; track reference.id) {
<a class="list-group-item list-group-item-action text-primary" [routerLink]="['/documents', reference.id]">
<i-bs width="0.9em" height="0.9em" name="file-text" class="me-1"></i-bs><span>{{ reference.title }}</span>
</a>
}
</div>
}
</div>
</div>
}
<div #scrollAnchor></div>

View File

@@ -7,6 +7,10 @@
overflow-y: auto;
}
.chat-references {
font-family: var(--bs-font-sans-serif);
}
.dropdown-toggle::after {
display: none;
}

View File

@@ -3,9 +3,13 @@ import { provideHttpClientTesting } from '@angular/common/http/testing'
import { ElementRef } from '@angular/core'
import { ComponentFixture, TestBed } from '@angular/core/testing'
import { NavigationEnd, Router } from '@angular/router'
import { RouterTestingModule } from '@angular/router/testing'
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { Subject } from 'rxjs'
import { ChatService } from 'src/app/services/chat.service'
import {
CHAT_METADATA_DELIMITER,
ChatService,
} from 'src/app/services/chat.service'
import { ChatComponent } from './chat.component'
describe('ChatComponent', () => {
@@ -18,7 +22,11 @@ describe('ChatComponent', () => {
beforeEach(async () => {
TestBed.configureTestingModule({
imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent],
imports: [
NgxBootstrapIconsModule.pick(allIcons),
RouterTestingModule,
ChatComponent,
],
providers: [
provideHttpClient(withInterceptorsFromDi()),
provideHttpClientTesting(),
@@ -84,6 +92,57 @@ describe('ChatComponent', () => {
expect(component.messages[1].isStreaming).toBe(false)
})
it('should parse references from the metadata trailer without showing it', () => {
component.input = 'Hello'
component.sendMessage()
mockStream$.next(
`Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
)
jest.advanceTimersByTime(1000)
expect(component.messages[1].content).toBe('Hi there')
expect(component.messages[1].references).toEqual([
{ id: 42, title: 'Bread Recipe' },
])
})
it('should render document reference links under assistant messages', () => {
component.input = 'Hello'
component.sendMessage()
mockStream$.next(
`Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
)
jest.advanceTimersByTime(1000)
fixture.detectChanges()
const link = fixture.nativeElement.querySelector('.chat-references a')
expect(link.textContent).toContain('Bread Recipe')
expect(link.getAttribute('href')).toContain('/documents/42')
})
it('should remove delimiter fragments that were already streamed', () => {
component.input = 'Hello'
component.sendMessage()
mockStream$.next(`Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`)
jest.advanceTimersByTime(1000)
expect(component.messages[1].content).toBe(
`Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`
)
mockStream$.next(
`Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
)
jest.advanceTimersByTime(1000)
expect(component.messages[1].content).toBe('Hi there')
expect(component.messages[1].references).toEqual([
{ id: 42, title: 'Bread Recipe' },
])
})
it('should handle errors during streaming', () => {
component.input = 'Hello'
component.sendMessage()

View File

@@ -1,16 +1,21 @@
import { Component, ElementRef, inject, OnInit, ViewChild } from '@angular/core'
import { FormsModule, ReactiveFormsModule } from '@angular/forms'
import { NavigationEnd, Router } from '@angular/router'
import { NavigationEnd, Router, RouterModule } from '@angular/router'
import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { filter, map } from 'rxjs'
import { ChatMessage, ChatService } from 'src/app/services/chat.service'
import {
ChatMessage,
ChatService,
parseChatResponse,
} from 'src/app/services/chat.service'
@Component({
selector: 'pngx-chat',
imports: [
FormsModule,
ReactiveFormsModule,
RouterModule,
NgxBootstrapIconsModule,
NgbDropdownModule,
],
@@ -70,13 +75,24 @@ export class ChatComponent implements OnInit {
this.messages.push(assistantMessage)
this.loading = true
let lastPartialLength = 0
let lastVisibleContent = ''
this.chatService.streamChat(this.documentId, this.input).subscribe({
next: (chunk) => {
const delta = chunk.substring(lastPartialLength)
lastPartialLength = chunk.length
this.enqueueTypewriter(delta, assistantMessage)
const nextResponse = parseChatResponse(chunk)
if (nextResponse.content.length < lastVisibleContent.length) {
this.resetTypewriter(assistantMessage, nextResponse.content)
lastVisibleContent = nextResponse.content
} else {
const visibleDelta = nextResponse.content.substring(
lastVisibleContent.length
)
lastVisibleContent = nextResponse.content
this.enqueueTypewriter(visibleDelta, assistantMessage)
}
assistantMessage.references = nextResponse.references
},
error: () => {
assistantMessage.content += '\n\n⚠ Error receiving response.'
@@ -93,6 +109,13 @@ export class ChatComponent implements OnInit {
this.input = ''
}
// Discards any queued typewriter chunks and replaces the message text wholesale.
// Called when the parsed visible content is shorter than what was already shown
// (i.e. a partially streamed metadata-delimiter fragment was rendered and must
// be removed once the full delimiter arrives).
private resetTypewriter(message: ChatMessage, content: string): void {
this.typewriterBuffer = [] // drop pending animation chunks
this.typewriterActive = false // presumably halts the in-flight typewriter loop — confirm against enqueueTypewriter
message.content = content // overwrite with the corrected visible text
this.scrollToBottom()
}
enqueueTypewriter(chunk: string, message: ChatMessage): void {
if (!chunk) return

View File

@@ -9,7 +9,11 @@ import {
} from '@angular/common/http/testing'
import { TestBed } from '@angular/core/testing'
import { environment } from 'src/environments/environment'
import { ChatService } from './chat.service'
import {
CHAT_METADATA_DELIMITER,
ChatService,
parseChatResponse,
} from './chat.service'
describe('ChatService', () => {
let service: ChatService
@@ -55,4 +59,22 @@ describe('ChatService', () => {
partialText: mockResponse,
} as any)
})
it('should parse chat references from the metadata trailer', () => {
const parsed = parseChatResponse(
`Answer text${CHAT_METADATA_DELIMITER}{"references":[{"id":1,"title":"Document 1"}]}`
)
expect(parsed.content).toBe('Answer text')
expect(parsed.references).toEqual([{ id: 1, title: 'Document 1' }])
})
it('should hide incomplete metadata trailer from the visible content', () => {
const parsed = parseChatResponse(
`Answer text${CHAT_METADATA_DELIMITER}{"references"`
)
expect(parsed.content).toBe('Answer text')
expect(parsed.references).toBeUndefined()
})
})

View File

@@ -11,6 +11,46 @@ export interface ChatMessage {
role: 'user' | 'assistant'
content: string
isStreaming?: boolean
references?: ChatReference[]
}
export interface ChatReference {
  id: number
  title: string
}

export interface ParsedChatResponse {
  content: string
  references?: ChatReference[]
}

export const CHAT_METADATA_DELIMITER = '\n\n__PAPERLESS_CHAT_METADATA__'

/**
 * Splits a streamed chat response into the user-visible answer text and the
 * machine-readable reference list appended after CHAT_METADATA_DELIMITER.
 *
 * Returns the raw text untouched when no delimiter is present. When the
 * delimiter is present but its JSON trailer does not parse (e.g. it is still
 * streaming in), the trailer is hidden and `references` stays undefined.
 */
export function parseChatResponse(response: string): ParsedChatResponse {
  const at = response.indexOf(CHAT_METADATA_DELIMITER)
  if (at < 0) {
    return { content: response }
  }

  const content = response.slice(0, at)
  const trailer = response.slice(at + CHAT_METADATA_DELIMITER.length)
  try {
    const metadata: { references?: ChatReference[] } = JSON.parse(trailer)
    return { content, references: metadata.references ?? [] }
  } catch {
    // Trailer is incomplete or malformed — hide it, expose no references yet.
    return { content }
  }
}
@Injectable({

View File

@@ -1,3 +1,4 @@
import json
import logging
import sys
@@ -9,6 +10,8 @@ logger = logging.getLogger("paperless_ai.chat")
# Cap on how much of a single document's full content is placed in the context.
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
# Per-node excerpt length used when building multi-document context.
SINGLE_DOC_SNIPPET_CHARS = 800
# Marker separating the streamed answer text from the JSON metadata trailer;
# must stay in sync with CHAT_METADATA_DELIMITER in the frontend chat service.
CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__"
# Maximum number of document references emitted in the metadata trailer.
MAX_CHAT_REFERENCES = 3
CHAT_PROMPT_TMPL = """Context information is below.
---------------------
@@ -19,6 +22,52 @@ CHAT_PROMPT_TMPL = """Context information is below.
Answer:"""
def _build_document_reference(
document: Document,
title: str | None = None,
) -> dict[str, int | str]:
return {
"id": document.pk,
"title": title or document.title or document.filename,
}
def _get_document_references(
    documents: list[Document],
    top_nodes: list,
) -> list[dict[str, int | str]]:
    """Collect unique document references, in node order, for the metadata trailer.

    Walks the retrieved ``top_nodes``, keeps only nodes whose ``document_id``
    maps to one of the allowed ``documents``, de-duplicates by document id, and
    stops after MAX_CHAT_REFERENCES entries.
    """
    by_pk = {doc.pk: doc for doc in documents}
    collected: list[dict[str, int | str]] = []
    emitted: set[int] = set()
    for node in top_nodes:
        try:
            pk = int(node.metadata["document_id"])
        except (KeyError, TypeError, ValueError):  # pragma: no cover
            # Node carries no usable document id — skip it.
            continue
        if pk in emitted or pk not in by_pk:
            continue
        emitted.add(pk)
        collected.append(
            _build_document_reference(by_pk[pk], node.metadata.get("title")),
        )
        if len(collected) >= MAX_CHAT_REFERENCES:  # pragma: no cover
            break
    return collected
def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str:
    """Render the machine-readable trailer yielded after the streamed answer.

    The trailer is the delimiter followed by compact JSON (no whitespace) so the
    frontend can split and parse it deterministically.
    """
    payload = json.dumps({"references": references}, separators=(",", ":"))
    return CHAT_METADATA_DELIMITER + payload
def stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_or_build_index()
@@ -49,6 +98,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
if len(documents) == 1:
# Just one doc — provide full content
doc = documents[0]
references = [_build_document_reference(doc)]
# TODO: include document metadata in the context
content = doc.content or ""
context_body = content
@@ -78,6 +128,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
yield "Sorry, I couldn't find any content to answer your question."
return
references = _get_document_references(documents, top_nodes)
context = "\n\n".join(
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
for node in top_nodes
@@ -102,3 +153,6 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
if references:
yield _format_chat_metadata_trailer(references)

View File

@@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -5,6 +6,7 @@ import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import stream_chat_with_documents
@@ -40,6 +42,21 @@ def mock_document():
return doc
def assert_chat_output(
    output: list[str],
    *,
    expected_chunks: list[str],
    expected_references: list[dict[str, int | str]],
) -> None:
    """Assert a streamed chat output is the expected text chunks followed by
    exactly one metadata trailer carrying the expected references."""
    chunks, trailer = output[:-1], output[-1]
    assert chunks == expected_chunks
    assert trailer.startswith(CHAT_METADATA_DELIMITER)
    metadata = json.loads(trailer.removeprefix(CHAT_METADATA_DELIMITER))
    assert metadata == {"references": expected_references}
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
@@ -68,7 +85,13 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
output = list(stream_chat_with_documents("What is this?", [mock_document]))
assert output == ["chunk1", "chunk2"]
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[
{"id": mock_document.pk, "title": "Test Document"},
],
)
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None:
@@ -100,7 +123,20 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_retriever.retrieve.return_value = [
mock_node1,
mock_duplicate_node,
mock_node2,
mock_foreign_node,
]
mock_as_retriever.return_value = mock_retriever
# Mock response stream
@@ -113,12 +149,19 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1)
doc2 = MagicMock(pk=2)
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
assert output == ["chunk1", "chunk2"]
assert_chat_output(
output,
expected_chunks=["chunk1", "chunk2"],
expected_references=[
{"id": 1, "title": "Document 1"},
{"id": 2, "title": "Document 2"},
],
)
def test_stream_chat_no_matching_nodes() -> None: