mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-03-07 09:41:22 +00:00
Compare commits
1 Commits
dev
...
chore/lazy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22e5ac58ba |
@@ -1,4 +1,4 @@
|
||||
from llama_index.core.bridge.pydantic import BaseModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DocumentClassifierSchema(BaseModel):
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
@@ -14,15 +10,21 @@ logger = logging.getLogger("paperless_ai.chat")
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
||||
SINGLE_DOC_SNIPPET_CHARS = 800
|
||||
|
||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||
template="""Context information is below.
|
||||
CHAT_PROMPT_TEMPLATE = """Context information is below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given the context information and not prior knowledge, answer the query.
|
||||
Query: {query_str}
|
||||
Answer:""",
|
||||
)
|
||||
Answer:"""
|
||||
|
||||
|
||||
def _get_prompt_template():
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
|
||||
return PromptTemplate(
|
||||
template=CHAT_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
@@ -43,6 +45,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
local_index = VectorStoreIndex(nodes=nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
@@ -85,10 +89,16 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
for node in top_nodes
|
||||
)
|
||||
|
||||
prompt = CHAT_PROMPT_TMPL.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
).format(llm=client.llm)
|
||||
prompt = (
|
||||
_get_prompt_template()
|
||||
.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
)
|
||||
.format(llm=client.llm)
|
||||
)
|
||||
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.program.function_program import get_function_tool
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.base_model import DocumentClassifierSchema
|
||||
@@ -20,14 +21,18 @@ class AIClient:
|
||||
self.settings = AIConfig()
|
||||
self.llm = self.get_llm()
|
||||
|
||||
def get_llm(self) -> Ollama | OpenAI:
|
||||
def get_llm(self) -> "Ollama | OpenAI":
|
||||
if self.settings.llm_backend == "ollama":
|
||||
from llama_index.llms.ollama import Ollama
|
||||
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3.1",
|
||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
||||
request_timeout=120,
|
||||
)
|
||||
elif self.settings.llm_backend == "openai":
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_base=self.settings.llm_endpoint or None,
|
||||
@@ -43,6 +48,9 @@ class AIClient:
|
||||
self.settings.llm_model,
|
||||
)
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.program.function_program import get_function_tool
|
||||
|
||||
user_msg = ChatMessage(role="user", content=prompt)
|
||||
tool = get_function_tool(DocumentClassifierSchema)
|
||||
result = self.llm.chat_with_tools(
|
||||
@@ -58,7 +66,7 @@ class AIClient:
|
||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||
return parsed.model_dump()
|
||||
|
||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||
def run_chat(self, messages: list["ChatMessage"]) -> str:
|
||||
logger.debug(
|
||||
"Running chat query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
|
||||
@@ -5,9 +5,9 @@ if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import Note
|
||||
@@ -15,17 +15,21 @@ from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
|
||||
|
||||
def get_embedding_model() -> BaseEmbedding:
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
config = AIConfig()
|
||||
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI:
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
api_base=config.llm_endpoint or None,
|
||||
)
|
||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
return HuggingFaceEmbedding(
|
||||
model_name=config.llm_embedding_model
|
||||
or "sentence-transformers/all-MiniLM-L6-v2",
|
||||
|
||||
@@ -4,26 +4,12 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
import faiss
|
||||
import llama_index.core.settings as llama_settings
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
@@ -31,6 +17,11 @@ from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
_T = TypeVar("_T")
|
||||
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
|
||||
|
||||
@@ -65,11 +56,17 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
def get_or_create_storage_context(*, rebuild=False) -> "StorageContext":
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
import faiss
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -93,7 +90,7 @@ def get_or_create_storage_context(*, rebuild=False):
|
||||
)
|
||||
|
||||
|
||||
def build_document_node(document: Document) -> list[BaseNode]:
|
||||
def build_document_node(document: Document) -> list["BaseNode"]:
|
||||
"""
|
||||
Given a Document, returns parsed Nodes ready for indexing.
|
||||
"""
|
||||
@@ -112,16 +109,23 @@ def build_document_node(document: Document) -> list[BaseNode]:
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
"modified": document.modified.isoformat(),
|
||||
}
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
|
||||
doc = LlamaDocument(text=text, metadata=metadata)
|
||||
parser = SimpleNodeParser()
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(nodes=None):
|
||||
def load_or_build_index(nodes=None) -> "VectorStoreIndex":
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
"""
|
||||
import llama_index.core.settings as llama_settings
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
@@ -143,7 +147,7 @@ def load_or_build_index(nodes=None):
|
||||
)
|
||||
|
||||
|
||||
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
|
||||
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
|
||||
"""
|
||||
Removes existing documents from docstore for a given document from the index.
|
||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||
@@ -183,6 +187,9 @@ def update_llm_index(
|
||||
return msg
|
||||
|
||||
if rebuild or not vector_store_file_exists():
|
||||
import llama_index.core.settings as llama_settings
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
# remove meta.json to force re-detection of embedding dim
|
||||
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||
# Rebuild index from scratch
|
||||
@@ -271,6 +278,10 @@ def llm_index_remove_document(document: Document):
|
||||
|
||||
|
||||
def truncate_content(content: str) -> str:
|
||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
|
||||
prompt_helper = PromptHelper(
|
||||
context_window=8192,
|
||||
num_output=512,
|
||||
@@ -315,6 +326,8 @@ def query_similar_documents(
|
||||
else None
|
||||
)
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=top_k,
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
@@ -76,7 +76,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||
):
|
||||
|
||||
@@ -18,13 +18,13 @@ def mock_ai_config():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_llm():
|
||||
with patch("paperless_ai.client.Ollama") as MockOllama:
|
||||
with patch("llama_index.llms.ollama.Ollama") as MockOllama:
|
||||
yield MockOllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_llm():
|
||||
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
|
||||
with patch("llama_index.llms.openai.OpenAI") as MockOpenAI:
|
||||
yield MockOpenAI
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def test_get_embedding_model_openai(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
mock_ai_config.return_value.llm_endpoint = "http://test-url"
|
||||
|
||||
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||
with patch("llama_index.embeddings.openai.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model="text-embedding-3-small",
|
||||
@@ -84,7 +84,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
)
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.HuggingFaceEmbedding",
|
||||
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
|
||||
Reference in New Issue
Block a user