mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-05-01 04:19:26 +00:00
Enhancement: chat message document links (#12670)
This commit is contained in:
@@ -8,10 +8,21 @@
|
||||
<div class="chat-messages font-monospace small">
|
||||
@for (message of messages; track message) {
  <!-- One row per chat message; user messages are pushed to the end of the row and get the dark background. -->
  <div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
    <div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
      <span>
        {{ message.content }}
        @if (message.isStreaming) { <span class="blinking-cursor">|</span> }
      </span>
      <!-- Document links are rendered only for assistant messages that carry at least one reference. -->
      @if (message.role === 'assistant' && message.references?.length) {
        <div class="chat-references list-group mt-3">
          @for (reference of message.references; track reference.id) {
            <a class="list-group-item list-group-item-action text-primary" [routerLink]="['/documents', reference.id]">
              <i-bs width="0.9em" height="0.9em" name="file-text" class="me-1"></i-bs><span>{{ reference.title }}</span>
            </a>
          }
        </div>
      }
    </div>
  </div>
}
|
||||
<div #scrollAnchor></div>
|
||||
|
||||
@@ -7,6 +7,10 @@
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
// Reference links use the Bootstrap sans-serif font instead of the monospace
// font applied to the surrounding .chat-messages container.
.chat-references {
  font-family: var(--bs-font-sans-serif);
}
|
||||
|
||||
.dropdown-toggle::after {
|
||||
display: none;
|
||||
}
|
||||
|
||||
@@ -3,9 +3,13 @@ import { provideHttpClientTesting } from '@angular/common/http/testing'
|
||||
import { ElementRef } from '@angular/core'
|
||||
import { ComponentFixture, TestBed } from '@angular/core/testing'
|
||||
import { NavigationEnd, Router } from '@angular/router'
|
||||
import { RouterTestingModule } from '@angular/router/testing'
|
||||
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { Subject } from 'rxjs'
|
||||
import { ChatService } from 'src/app/services/chat.service'
|
||||
import {
|
||||
CHAT_METADATA_DELIMITER,
|
||||
ChatService,
|
||||
} from 'src/app/services/chat.service'
|
||||
import { ChatComponent } from './chat.component'
|
||||
|
||||
describe('ChatComponent', () => {
|
||||
@@ -18,7 +22,11 @@ describe('ChatComponent', () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
TestBed.configureTestingModule({
|
||||
imports: [NgxBootstrapIconsModule.pick(allIcons), ChatComponent],
|
||||
imports: [
|
||||
NgxBootstrapIconsModule.pick(allIcons),
|
||||
RouterTestingModule,
|
||||
ChatComponent,
|
||||
],
|
||||
providers: [
|
||||
provideHttpClient(withInterceptorsFromDi()),
|
||||
provideHttpClientTesting(),
|
||||
@@ -84,6 +92,57 @@ describe('ChatComponent', () => {
|
||||
expect(component.messages[1].isStreaming).toBe(false)
|
||||
})
|
||||
|
||||
it('should parse references from the metadata trailer without showing it', () => {
  component.input = 'Hello'
  component.sendMessage()

  // One chunk carrying the visible text followed by a complete metadata trailer.
  mockStream$.next(
    `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
  )
  jest.advanceTimersByTime(1000)

  // Only the text before the delimiter is shown; the trailer becomes references.
  expect(component.messages[1].content).toBe('Hi there')
  expect(component.messages[1].references).toEqual([
    { id: 42, title: 'Bread Recipe' },
  ])
})
|
||||
|
||||
it('should render document reference links under assistant messages', () => {
  component.input = 'Hello'
  component.sendMessage()

  mockStream$.next(
    `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
  )
  jest.advanceTimersByTime(1000)
  fixture.detectChanges()

  const link = fixture.nativeElement.querySelector('.chat-references a')
  // Fail with a clear expectation instead of a TypeError when no link was rendered.
  expect(link).not.toBeNull()
  expect(link.textContent).toContain('Bread Recipe')
  expect(link.getAttribute('href')).toContain('/documents/42')
})
|
||||
|
||||
it('should remove delimiter fragments that were already streamed', () => {
  component.input = 'Hello'
  component.sendMessage()

  // First chunk ends in the middle of the delimiter, so the fragment is still
  // treated as ordinary visible text at this point.
  mockStream$.next(`Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`)
  jest.advanceTimersByTime(1000)
  expect(component.messages[1].content).toBe(
    `Hi there${CHAT_METADATA_DELIMITER.slice(0, 8)}`
  )

  // Once the full delimiter and trailer arrive, the visible content shrinks
  // back to the text before the delimiter.
  mockStream$.next(
    `Hi there${CHAT_METADATA_DELIMITER}{"references":[{"id":42,"title":"Bread Recipe"}]}`
  )
  jest.advanceTimersByTime(1000)

  expect(component.messages[1].content).toBe('Hi there')
  expect(component.messages[1].references).toEqual([
    { id: 42, title: 'Bread Recipe' },
  ])
})
|
||||
|
||||
it('should handle errors during streaming', () => {
|
||||
component.input = 'Hello'
|
||||
component.sendMessage()
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import { Component, ElementRef, inject, OnInit, ViewChild } from '@angular/core'
|
||||
import { FormsModule, ReactiveFormsModule } from '@angular/forms'
|
||||
import { NavigationEnd, Router } from '@angular/router'
|
||||
import { NavigationEnd, Router, RouterModule } from '@angular/router'
|
||||
import { NgbDropdownModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { filter, map } from 'rxjs'
|
||||
import { ChatMessage, ChatService } from 'src/app/services/chat.service'
|
||||
import {
|
||||
ChatMessage,
|
||||
ChatService,
|
||||
parseChatResponse,
|
||||
} from 'src/app/services/chat.service'
|
||||
|
||||
@Component({
|
||||
selector: 'pngx-chat',
|
||||
imports: [
|
||||
FormsModule,
|
||||
ReactiveFormsModule,
|
||||
RouterModule,
|
||||
NgxBootstrapIconsModule,
|
||||
NgbDropdownModule,
|
||||
],
|
||||
@@ -70,13 +75,24 @@ export class ChatComponent implements OnInit {
|
||||
this.messages.push(assistantMessage)
|
||||
this.loading = true
|
||||
|
||||
let lastPartialLength = 0
|
||||
let lastVisibleContent = ''
|
||||
|
||||
this.chatService.streamChat(this.documentId, this.input).subscribe({
|
||||
next: (chunk) => {
|
||||
const delta = chunk.substring(lastPartialLength)
|
||||
lastPartialLength = chunk.length
|
||||
this.enqueueTypewriter(delta, assistantMessage)
|
||||
const nextResponse = parseChatResponse(chunk)
|
||||
|
||||
if (nextResponse.content.length < lastVisibleContent.length) {
|
||||
this.resetTypewriter(assistantMessage, nextResponse.content)
|
||||
lastVisibleContent = nextResponse.content
|
||||
} else {
|
||||
const visibleDelta = nextResponse.content.substring(
|
||||
lastVisibleContent.length
|
||||
)
|
||||
lastVisibleContent = nextResponse.content
|
||||
this.enqueueTypewriter(visibleDelta, assistantMessage)
|
||||
}
|
||||
|
||||
assistantMessage.references = nextResponse.references
|
||||
},
|
||||
error: () => {
|
||||
assistantMessage.content += '\n\n⚠️ Error receiving response.'
|
||||
@@ -93,6 +109,13 @@ export class ChatComponent implements OnInit {
|
||||
this.input = ''
|
||||
}
|
||||
|
||||
/**
 * Drops any queued typewriter output and sets the message text directly.
 *
 * Used by the stream handler when the newly parsed visible content is shorter
 * than what was previously seen, so the text cannot be extended by appending.
 *
 * @param message the chat message whose content is replaced
 * @param content the text the message should show from now on
 */
private resetTypewriter(message: ChatMessage, content: string): void {
  this.typewriterBuffer = []
  this.typewriterActive = false
  message.content = content
  this.scrollToBottom()
}
|
||||
|
||||
enqueueTypewriter(chunk: string, message: ChatMessage): void {
|
||||
if (!chunk) return
|
||||
|
||||
|
||||
@@ -9,7 +9,11 @@ import {
|
||||
} from '@angular/common/http/testing'
|
||||
import { TestBed } from '@angular/core/testing'
|
||||
import { environment } from 'src/environments/environment'
|
||||
import { ChatService } from './chat.service'
|
||||
import {
|
||||
CHAT_METADATA_DELIMITER,
|
||||
ChatService,
|
||||
parseChatResponse,
|
||||
} from './chat.service'
|
||||
|
||||
describe('ChatService', () => {
|
||||
let service: ChatService
|
||||
@@ -55,4 +59,22 @@ describe('ChatService', () => {
|
||||
partialText: mockResponse,
|
||||
} as any)
|
||||
})
|
||||
|
||||
it('should parse chat references from the metadata trailer', () => {
  // Complete trailer: text before the delimiter is the content, JSON after it is metadata.
  const parsed = parseChatResponse(
    `Answer text${CHAT_METADATA_DELIMITER}{"references":[{"id":1,"title":"Document 1"}]}`
  )

  expect(parsed.content).toBe('Answer text')
  expect(parsed.references).toEqual([{ id: 1, title: 'Document 1' }])
})
|
||||
|
||||
it('should hide incomplete metadata trailer from the visible content', () => {
  // The JSON after the delimiter is cut off, so it cannot be parsed yet.
  const parsed = parseChatResponse(
    `Answer text${CHAT_METADATA_DELIMITER}{"references"`
  )

  expect(parsed.content).toBe('Answer text')
  expect(parsed.references).toBeUndefined()
})
|
||||
})
|
||||
|
||||
@@ -11,6 +11,46 @@ export interface ChatMessage {
|
||||
role: 'user' | 'assistant'
|
||||
content: string
|
||||
isStreaming?: boolean
|
||||
references?: ChatReference[]
|
||||
}
|
||||
|
||||
/** A document linked from an assistant chat message. */
export interface ChatReference {
  /** Document id; used to build the `/documents/<id>` link in the chat view. */
  id: number
  /** Text shown for the link. */
  title: string
}
|
||||
|
||||
/** Result of splitting a raw chat response into visible text and metadata. */
export interface ParsedChatResponse {
  /** Response text with any metadata trailer removed. */
  content: string
  /** Present only once a metadata trailer has been parsed successfully. */
  references?: ChatReference[]
}
|
||||
|
||||
// Marker separating the visible answer text from the JSON metadata trailer.
// Must stay identical to CHAT_METADATA_DELIMITER in the Python paperless_ai.chat module.
export const CHAT_METADATA_DELIMITER = '\n\n__PAPERLESS_CHAT_METADATA__'
|
||||
|
||||
/**
 * Splits a raw chat response into its visible text and optional references.
 *
 * Everything before the first CHAT_METADATA_DELIMITER is the visible content;
 * everything after it is expected to be a JSON object of the form
 * `{"references": [...]}`.
 *
 * @param response the accumulated response text received so far
 * @returns the visible content, plus `references` when the trailer parsed:
 *   - no delimiter: the whole response as content, no `references`
 *   - delimiter with unparseable (e.g. still-streaming) JSON: content only
 *   - delimiter with parsed JSON: content and the `references` array, or `[]`
 *     when the field is missing or is not an array
 */
export function parseChatResponse(response: string): ParsedChatResponse {
  const delimiterIndex = response.indexOf(CHAT_METADATA_DELIMITER)

  if (delimiterIndex === -1) {
    return { content: response }
  }

  const content = response.slice(0, delimiterIndex)
  const metadataString = response.slice(
    delimiterIndex + CHAT_METADATA_DELIMITER.length
  )

  try {
    const metadata = JSON.parse(metadataString) as {
      references?: unknown
    }
    // A `null` payload throws on this property access and is handled by the
    // catch below, the same as an unparseable trailer.
    const references = metadata.references

    return {
      content,
      // Never hand a non-array value to the template's @for loop.
      references: Array.isArray(references)
        ? (references as ChatReference[])
        : [],
    }
  } catch {
    // Trailer is incomplete or malformed: hide it and report no references.
    return { content }
  }
}
|
||||
|
||||
@Injectable({
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
@@ -9,6 +10,8 @@ logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
||||
SINGLE_DOC_SNIPPET_CHARS = 800
|
||||
# Marker emitted between the streamed answer text and the JSON metadata trailer.
# The frontend chat service declares the same string to split the response.
CHAT_METADATA_DELIMITER = "\n\n__PAPERLESS_CHAT_METADATA__"
|
||||
# Upper bound on the number of references collected from retrieved nodes.
MAX_CHAT_REFERENCES = 3
|
||||
|
||||
CHAT_PROMPT_TMPL = """Context information is below.
|
||||
---------------------
|
||||
@@ -19,6 +22,52 @@ CHAT_PROMPT_TMPL = """Context information is below.
|
||||
Answer:"""
|
||||
|
||||
|
||||
def _build_document_reference(
|
||||
document: Document,
|
||||
title: str | None = None,
|
||||
) -> dict[str, int | str]:
|
||||
return {
|
||||
"id": document.pk,
|
||||
"title": title or document.title or document.filename,
|
||||
}
|
||||
|
||||
|
||||
def _get_document_references(
    documents: list[Document],
    top_nodes: list,
) -> list[dict[str, int | str]]:
    """Collect reference entries for the documents behind the retrieved nodes.

    Nodes are visited in order. A node contributes a reference only when its
    ``document_id`` metadata converts to an int, belongs to one of
    ``documents``, and has not been seen already. Collection stops once
    ``MAX_CHAT_REFERENCES`` entries have been gathered.

    Args:
        documents: Documents that may be referenced.
        top_nodes: Retrieved nodes exposing a ``metadata`` mapping.

    Returns:
        Reference dicts in node order, each built with the node's ``title``
        metadata as the preferred title.
    """
    allowed_documents = {doc.pk: doc for doc in documents}
    references: list[dict[str, int | str]] = []
    seen_document_ids: set[int] = set()

    for node in top_nodes:
        try:
            document_id = int(node.metadata["document_id"])
        except (KeyError, TypeError, ValueError):  # pragma: no cover
            # Node without a usable document id: nothing to link to.
            continue

        if document_id in seen_document_ids or document_id not in allowed_documents:
            continue

        seen_document_ids.add(document_id)
        document = allowed_documents[document_id]
        references.append(
            _build_document_reference(document, node.metadata.get("title")),
        )

        if len(references) >= MAX_CHAT_REFERENCES:  # pragma: no cover
            break

    return references
|
||||
|
||||
|
||||
def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str:
    """Return the metadata trailer: the delimiter followed by compact JSON.

    The JSON payload is ``{"references": references}`` serialized without
    whitespace after the separators.
    """
    return (
        f"{CHAT_METADATA_DELIMITER}"
        f"{json.dumps({'references': references}, separators=(',', ':'))}"
    )
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_or_build_index()
|
||||
@@ -49,6 +98,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
if len(documents) == 1:
|
||||
# Just one doc — provide full content
|
||||
doc = documents[0]
|
||||
references = [_build_document_reference(doc)]
|
||||
# TODO: include document metadata in the context
|
||||
content = doc.content or ""
|
||||
context_body = content
|
||||
@@ -78,6 +128,7 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
context = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
|
||||
for node in top_nodes
|
||||
@@ -102,3 +153,6 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
if references:
|
||||
yield _format_chat_metadata_trailer(references)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -5,6 +6,7 @@ import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from paperless_ai.chat import CHAT_METADATA_DELIMITER
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@@ -40,6 +42,21 @@ def mock_document():
|
||||
return doc
|
||||
|
||||
|
||||
def assert_chat_output(
    output: list[str],
    *,
    expected_chunks: list[str],
    expected_references: list[dict[str, int | str]],
) -> None:
    """Assert that streamed output is the expected chunks plus a metadata trailer.

    The last element of ``output`` must be the trailer: the delimiter followed
    by JSON equal to ``{"references": expected_references}``. Every element
    before it must match ``expected_chunks`` exactly.
    """
    # Fail with a readable message instead of an IndexError on empty output.
    assert output, "expected at least the metadata trailer in the output"
    assert output[:-1] == expected_chunks

    trailer = output[-1]
    assert trailer.startswith(CHAT_METADATA_DELIMITER)
    assert json.loads(trailer.removeprefix(CHAT_METADATA_DELIMITER)) == {
        "references": expected_references,
    }
|
||||
|
||||
|
||||
def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
@@ -68,7 +85,13 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[
|
||||
{"id": mock_document.pk, "title": "Test Document"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> None:
|
||||
@@ -100,7 +123,20 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
|
||||
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_retriever.retrieve.return_value = [
|
||||
mock_node1,
|
||||
mock_duplicate_node,
|
||||
mock_node2,
|
||||
mock_foreign_node,
|
||||
]
|
||||
mock_as_retriever.return_value = mock_retriever
|
||||
|
||||
# Mock response stream
|
||||
@@ -113,12 +149,19 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1)
|
||||
doc2 = MagicMock(pk=2)
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
assert_chat_output(
|
||||
output,
|
||||
expected_chunks=["chunk1", "chunk2"],
|
||||
expected_references=[
|
||||
{"id": 1, "title": "Document 1"},
|
||||
{"id": 2, "title": "Document 2"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes() -> None:
|
||||
|
||||
Reference in New Issue
Block a user