diff --git a/src/paperless_ai/base_model.py b/src/paperless_ai/base_model.py index 2924f2c8c..44fe75f5b 100644 --- a/src/paperless_ai/base_model.py +++ b/src/paperless_ai/base_model.py @@ -1,4 +1,4 @@ -from llama_index.core.bridge.pydantic import BaseModel +from pydantic import BaseModel class DocumentClassifierSchema(BaseModel): diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index f662a7bee..e48094f4f 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -1,10 +1,6 @@ import logging import sys -from llama_index.core import VectorStoreIndex -from llama_index.core.prompts import PromptTemplate -from llama_index.core.query_engine import RetrieverQueryEngine - from documents.models import Document from paperless_ai.client import AIClient from paperless_ai.indexing import load_or_build_index @@ -14,15 +10,21 @@ logger = logging.getLogger("paperless_ai.chat") MAX_SINGLE_DOC_CONTEXT_CHARS = 15000 SINGLE_DOC_SNIPPET_CHARS = 800 -CHAT_PROMPT_TMPL = PromptTemplate( - template="""Context information is below. +CHAT_PROMPT_TEMPLATE = """Context information is below. --------------------- {context_str} --------------------- Given the context information and not prior knowledge, answer the query. Query: {query_str} - Answer:""", -) + Answer:""" + + +def _get_prompt_template(): + from llama_index.core.prompts import PromptTemplate + + return PromptTemplate( + template=CHAT_PROMPT_TEMPLATE, + ) def stream_chat_with_documents(query_str: str, documents: list[Document]): @@ -43,6 +45,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): yield "Sorry, I couldn't find any content to answer your question." return + from llama_index.core import VectorStoreIndex + local_index = VectorStoreIndex(nodes=nodes) retriever = local_index.as_retriever( similarity_top_k=3 if len(documents) == 1 else 5, @@ -85,10 +89,16 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): for node in top_nodes ) - prompt = CHAT_PROMPT_TMPL.partial_format( - context_str=context, - query_str=query_str, - ).format(llm=client.llm) + prompt = ( + _get_prompt_template() + .partial_format( + context_str=context, + query_str=query_str, + ) + .format(llm=client.llm) + ) + + from llama_index.core.query_engine import RetrieverQueryEngine query_engine = RetrieverQueryEngine.from_args( retriever=retriever, diff --git a/src/paperless_ai/client.py b/src/paperless_ai/client.py index 62ff5f5c8..d70fe57b5 100644 --- a/src/paperless_ai/client.py +++ b/src/paperless_ai/client.py @@ -1,9 +1,10 @@ import logging +from typing import TYPE_CHECKING -from llama_index.core.llms import ChatMessage -from llama_index.core.program.function_program import get_function_tool -from llama_index.llms.ollama import Ollama -from llama_index.llms.openai import OpenAI +if TYPE_CHECKING: + from llama_index.core.llms import ChatMessage + from llama_index.llms.ollama import Ollama + from llama_index.llms.openai import OpenAI from paperless.config import AIConfig from paperless_ai.base_model import DocumentClassifierSchema @@ -20,14 +21,18 @@ class AIClient: self.settings = AIConfig() self.llm = self.get_llm() - def get_llm(self) -> Ollama | OpenAI: + def get_llm(self) -> "Ollama | OpenAI": if self.settings.llm_backend == "ollama": + from llama_index.llms.ollama import Ollama + return Ollama( model=self.settings.llm_model or "llama3.1", base_url=self.settings.llm_endpoint or "http://localhost:11434", request_timeout=120, ) elif self.settings.llm_backend == "openai": + from llama_index.llms.openai import OpenAI + return OpenAI( model=self.settings.llm_model or "gpt-3.5-turbo", api_base=self.settings.llm_endpoint or None, @@ -43,6 +48,9 @@ class AIClient: self.settings.llm_model, ) + from llama_index.core.llms import ChatMessage + from llama_index.core.program.function_program import get_function_tool + user_msg = ChatMessage(role="user", content=prompt) tool = get_function_tool(DocumentClassifierSchema) result = self.llm.chat_with_tools( @@ -58,7 +66,7 @@ class AIClient: parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs) return parsed.model_dump() - def run_chat(self, messages: list[ChatMessage]) -> str: + def run_chat(self, messages: list["ChatMessage"]) -> str: logger.debug( "Running chat query against %s with model %s", self.settings.llm_backend, diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index 686f73341..3966ce872 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -5,9 +5,9 @@ if TYPE_CHECKING: from pathlib import Path from django.conf import settings -from llama_index.core.base.embeddings.base import BaseEmbedding -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding + +if TYPE_CHECKING: + from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import Note @@ -15,17 +15,21 @@ from paperless.config import AIConfig from paperless.models import LLMEmbeddingBackend -def get_embedding_model() -> BaseEmbedding: +def get_embedding_model() -> "BaseEmbedding": config = AIConfig() match config.llm_embedding_backend: case LLMEmbeddingBackend.OPENAI: + from llama_index.embeddings.openai import OpenAIEmbedding + return OpenAIEmbedding( model=config.llm_embedding_model or "text-embedding-3-small", api_key=config.llm_api_key, api_base=config.llm_endpoint or None, ) case LLMEmbeddingBackend.HUGGINGFACE: + from llama_index.embeddings.huggingface import HuggingFaceEmbedding + return HuggingFaceEmbedding( model_name=config.llm_embedding_model or "sentence-transformers/all-MiniLM-L6-v2", diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 53e4a9796..20a5f946d 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -4,26 +4,12 @@ from collections.abc import Callable from collections.abc import Iterable from datetime import timedelta from pathlib import Path +from typing import TYPE_CHECKING from typing import TypeVar -import faiss -import llama_index.core.settings as llama_settings from celery import states from django.conf import settings from django.utils import timezone -from llama_index.core import Document as LlamaDocument -from llama_index.core import StorageContext -from llama_index.core import VectorStoreIndex -from llama_index.core import load_index_from_storage -from llama_index.core.indices.prompt_helper import PromptHelper -from llama_index.core.node_parser import SimpleNodeParser -from llama_index.core.prompts import PromptTemplate -from llama_index.core.retrievers import VectorIndexRetriever -from llama_index.core.schema import BaseNode -from llama_index.core.storage.docstore import SimpleDocumentStore -from llama_index.core.storage.index_store import SimpleIndexStore -from llama_index.core.text_splitter import TokenTextSplitter -from llama_index.vector_stores.faiss import FaissVectorStore from documents.models import Document from documents.models import PaperlessTask @@ -31,6 +17,11 @@ from paperless_ai.embedding import build_llm_index_text from paperless_ai.embedding import get_embedding_dim from paperless_ai.embedding import get_embedding_model +if TYPE_CHECKING: + from llama_index.core import StorageContext + from llama_index.core import VectorStoreIndex + from llama_index.core.schema import BaseNode + _T = TypeVar("_T") IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] @@ -65,11 +56,17 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool: return True -def get_or_create_storage_context(*, rebuild=False): +def get_or_create_storage_context(*, rebuild=False) -> "StorageContext": """ Loads or creates the StorageContext (vector store, docstore, index store). If rebuild=True, deletes and recreates everything. """ + import faiss + from llama_index.core import StorageContext + from llama_index.core.storage.docstore import SimpleDocumentStore + from llama_index.core.storage.index_store import SimpleIndexStore + from llama_index.vector_stores.faiss import FaissVectorStore + if rebuild: shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True) settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) @@ -93,7 +90,7 @@ def get_or_create_storage_context(*, rebuild=False): ) -def build_document_node(document: Document) -> list[BaseNode]: +def build_document_node(document: Document) -> list["BaseNode"]: """ Given a Document, returns parsed Nodes ready for indexing. """ @@ -112,16 +109,23 @@ def build_document_node(document: Document) -> list[BaseNode]: "added": document.added.isoformat() if document.added else None, "modified": document.modified.isoformat(), } + from llama_index.core import Document as LlamaDocument + from llama_index.core.node_parser import SimpleNodeParser + doc = LlamaDocument(text=text, metadata=metadata) parser = SimpleNodeParser() return parser.get_nodes_from_documents([doc]) -def load_or_build_index(nodes=None): +def load_or_build_index(nodes=None) -> "VectorStoreIndex": """ Load an existing VectorStoreIndex if present, or build a new one using provided nodes if storage is empty. """ + import llama_index.core.settings as llama_settings + from llama_index.core import VectorStoreIndex + from llama_index.core import load_index_from_storage + embed_model = get_embedding_model() llama_settings.Settings.embed_model = embed_model storage_context = get_or_create_storage_context() @@ -143,7 +147,7 @@ def load_or_build_index(nodes=None): ) -def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex): +def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"): """ Removes existing documents from docstore for a given document from the index. This is necessary because FAISS IndexFlatL2 is append-only. @@ -183,6 +187,9 @@ def update_llm_index( return msg if rebuild or not vector_store_file_exists(): + import llama_index.core.settings as llama_settings + from llama_index.core import VectorStoreIndex + # remove meta.json to force re-detection of embedding dim (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True) # Rebuild index from scratch @@ -271,6 +278,10 @@ def llm_index_remove_document(document: Document): def truncate_content(content: str) -> str: + from llama_index.core.indices.prompt_helper import PromptHelper + from llama_index.core.prompts import PromptTemplate + from llama_index.core.text_splitter import TokenTextSplitter + prompt_helper = PromptHelper( context_window=8192, num_output=512, @@ -315,6 +326,8 @@ def query_similar_documents( else None ) + from llama_index.core.retrievers import VectorIndexRetriever + retriever = VectorIndexRetriever( index=index, similarity_top_k=top_k, diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 0a14425cf..902c907f2 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -45,7 +45,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None: patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, patch( - "paperless_ai.chat.RetrieverQueryEngine.from_args", + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, ): mock_client = MagicMock() @@ -76,7 +76,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, patch( - "paperless_ai.chat.RetrieverQueryEngine.from_args", + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever, ): diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py index 47053ab20..08582cc71 100644 --- a/src/paperless_ai/tests/test_client.py +++ b/src/paperless_ai/tests/test_client.py @@ -18,13 +18,13 @@ def mock_ai_config(): @pytest.fixture def mock_ollama_llm(): - with patch("paperless_ai.client.Ollama") as MockOllama: + with patch("llama_index.llms.ollama.Ollama") as MockOllama: yield MockOllama @pytest.fixture def mock_openai_llm(): - with patch("paperless_ai.client.OpenAI") as MockOpenAI: + with patch("llama_index.llms.openai.OpenAI") as MockOpenAI: yield MockOpenAI diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index 1fb69ee06..98da6e410 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -67,7 +67,7 @@ def test_get_embedding_model_openai(mock_ai_config): mock_ai_config.return_value.llm_api_key = "test_api_key" mock_ai_config.return_value.llm_endpoint = "http://test-url" - with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding: + with patch("llama_index.embeddings.openai.OpenAIEmbedding") as MockOpenAIEmbedding: model = get_embedding_model() MockOpenAIEmbedding.assert_called_once_with( model="text-embedding-3-small", @@ -84,7 +84,7 @@ def test_get_embedding_model_huggingface(mock_ai_config): ) with patch( - "paperless_ai.embedding.HuggingFaceEmbedding", + "llama_index.embeddings.huggingface.HuggingFaceEmbedding", ) as MockHuggingFaceEmbedding: model = get_embedding_model() MockHuggingFaceEmbedding.assert_called_once_with(