diff --git a/src/paperless_ai/base_model.py b/src/paperless_ai/base_model.py index 2924f2c8c..44fe75f5b 100644 --- a/src/paperless_ai/base_model.py +++ b/src/paperless_ai/base_model.py @@ -1,4 +1,4 @@ -from llama_index.core.bridge.pydantic import BaseModel +from pydantic import BaseModel class DocumentClassifierSchema(BaseModel): diff --git a/src/paperless_ai/chat.py b/src/paperless_ai/chat.py index f662a7bee..33603c45e 100644 --- a/src/paperless_ai/chat.py +++ b/src/paperless_ai/chat.py @@ -1,10 +1,6 @@ import logging import sys -from llama_index.core import VectorStoreIndex -from llama_index.core.prompts import PromptTemplate -from llama_index.core.query_engine import RetrieverQueryEngine - from documents.models import Document from paperless_ai.client import AIClient from paperless_ai.indexing import load_or_build_index @@ -14,15 +10,13 @@ logger = logging.getLogger("paperless_ai.chat") MAX_SINGLE_DOC_CONTEXT_CHARS = 15000 SINGLE_DOC_SNIPPET_CHARS = 800 -CHAT_PROMPT_TMPL = PromptTemplate( - template="""Context information is below. +CHAT_PROMPT_TMPL = """Context information is below. --------------------- {context_str} --------------------- Given the context information and not prior knowledge, answer the query. Query: {query_str} - Answer:""", -) + Answer:""" def stream_chat_with_documents(query_str: str, documents: list[Document]): @@ -43,6 +37,10 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): yield "Sorry, I couldn't find any content to answer your question." return + from llama_index.core import VectorStoreIndex + from llama_index.core.prompts import PromptTemplate + from llama_index.core.query_engine import RetrieverQueryEngine + local_index = VectorStoreIndex(nodes=nodes) retriever = local_index.as_retriever( similarity_top_k=3 if len(documents) == 1 else 5, @@ -85,7 +83,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]): for node in top_nodes ) - prompt = CHAT_PROMPT_TMPL.partial_format( + prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL) + prompt = prompt_template.partial_format( context_str=context, query_str=query_str, ).format(llm=client.llm) diff --git a/src/paperless_ai/client.py b/src/paperless_ai/client.py index 62ff5f5c8..d70fe57b5 100644 --- a/src/paperless_ai/client.py +++ b/src/paperless_ai/client.py @@ -1,9 +1,10 @@ import logging +from typing import TYPE_CHECKING -from llama_index.core.llms import ChatMessage -from llama_index.core.program.function_program import get_function_tool -from llama_index.llms.ollama import Ollama -from llama_index.llms.openai import OpenAI +if TYPE_CHECKING: + from llama_index.core.llms import ChatMessage + from llama_index.llms.ollama import Ollama + from llama_index.llms.openai import OpenAI from paperless.config import AIConfig from paperless_ai.base_model import DocumentClassifierSchema @@ -20,14 +21,18 @@ class AIClient: self.settings = AIConfig() self.llm = self.get_llm() - def get_llm(self) -> Ollama | OpenAI: + def get_llm(self) -> "Ollama | OpenAI": if self.settings.llm_backend == "ollama": + from llama_index.llms.ollama import Ollama + return Ollama( model=self.settings.llm_model or "llama3.1", base_url=self.settings.llm_endpoint or "http://localhost:11434", request_timeout=120, ) elif self.settings.llm_backend == "openai": + from llama_index.llms.openai import OpenAI + return OpenAI( model=self.settings.llm_model or "gpt-3.5-turbo", api_base=self.settings.llm_endpoint or None, @@ -43,6 +48,9 @@ class AIClient: self.settings.llm_model, ) + from llama_index.core.llms import ChatMessage + from llama_index.core.program.function_program import get_function_tool + user_msg = ChatMessage(role="user", content=prompt) tool = get_function_tool(DocumentClassifierSchema) result = self.llm.chat_with_tools( @@ -58,7 +66,7 @@ class AIClient: parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs) return parsed.model_dump() - def run_chat(self, messages: list[ChatMessage]) -> str: + def run_chat(self, messages: list["ChatMessage"]) -> str: logger.debug( "Running chat query against %s with model %s", self.settings.llm_backend, diff --git a/src/paperless_ai/embedding.py b/src/paperless_ai/embedding.py index 686f73341..f7193a8db 100644 --- a/src/paperless_ai/embedding.py +++ b/src/paperless_ai/embedding.py @@ -1,13 +1,12 @@ import json from typing import TYPE_CHECKING +from django.conf import settings + if TYPE_CHECKING: from pathlib import Path -from django.conf import settings -from llama_index.core.base.embeddings.base import BaseEmbedding -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding + from llama_index.core.base.embeddings.base import BaseEmbedding from documents.models import Document from documents.models import Note @@ -15,17 +14,21 @@ from paperless.config import AIConfig from paperless.models import LLMEmbeddingBackend -def get_embedding_model() -> BaseEmbedding: +def get_embedding_model() -> "BaseEmbedding": config = AIConfig() match config.llm_embedding_backend: case LLMEmbeddingBackend.OPENAI: + from llama_index.embeddings.openai import OpenAIEmbedding + return OpenAIEmbedding( model=config.llm_embedding_model or "text-embedding-3-small", api_key=config.llm_api_key, api_base=config.llm_endpoint or None, ) case LLMEmbeddingBackend.HUGGINGFACE: + from llama_index.embeddings.huggingface import HuggingFaceEmbedding + return HuggingFaceEmbedding( model_name=config.llm_embedding_model or "sentence-transformers/all-MiniLM-L6-v2", diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 53e4a9796..bee8f0dd9 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -4,26 +4,12 @@ from collections.abc import Callable from collections.abc import Iterable from datetime import timedelta from pathlib import Path +from typing import TYPE_CHECKING from typing import TypeVar -import faiss -import llama_index.core.settings as llama_settings from celery import states from django.conf import settings from django.utils import timezone -from llama_index.core import Document as LlamaDocument -from llama_index.core import StorageContext -from llama_index.core import VectorStoreIndex -from llama_index.core import load_index_from_storage -from llama_index.core.indices.prompt_helper import PromptHelper -from llama_index.core.node_parser import SimpleNodeParser -from llama_index.core.prompts import PromptTemplate -from llama_index.core.retrievers import VectorIndexRetriever -from llama_index.core.schema import BaseNode -from llama_index.core.storage.docstore import SimpleDocumentStore -from llama_index.core.storage.index_store import SimpleIndexStore -from llama_index.core.text_splitter import TokenTextSplitter -from llama_index.vector_stores.faiss import FaissVectorStore from documents.models import Document from documents.models import PaperlessTask @@ -34,6 +20,10 @@ from paperless_ai.embedding import get_embedding_model _T = TypeVar("_T") IterWrapper = Callable[[Iterable[_T]], Iterable[_T]] +if TYPE_CHECKING: + from llama_index.core import VectorStoreIndex + from llama_index.core.schema import BaseNode + def _identity(iterable: Iterable[_T]) -> Iterable[_T]: return iterable @@ -75,12 +65,23 @@ def get_or_create_storage_context(*, rebuild=False): settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True) if rebuild or not settings.LLM_INDEX_DIR.exists(): + import faiss + from llama_index.core import StorageContext + from llama_index.core.storage.docstore import SimpleDocumentStore + from llama_index.core.storage.index_store import SimpleIndexStore + from llama_index.vector_stores.faiss import FaissVectorStore + embedding_dim = get_embedding_dim() faiss_index = faiss.IndexFlatL2(embedding_dim) vector_store = FaissVectorStore(faiss_index=faiss_index) docstore = SimpleDocumentStore() index_store = SimpleIndexStore() else: + from llama_index.core import StorageContext + from llama_index.core.storage.docstore import SimpleDocumentStore + from llama_index.core.storage.index_store import SimpleIndexStore + from llama_index.vector_stores.faiss import FaissVectorStore + vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR) docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR) index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR) @@ -93,7 +94,7 @@ def get_or_create_storage_context(*, rebuild=False): ) -def build_document_node(document: Document) -> list[BaseNode]: +def build_document_node(document: Document) -> list["BaseNode"]: """ Given a Document, returns parsed Nodes ready for indexing. """ @@ -112,6 +113,9 @@ def build_document_node(document: Document) -> list[BaseNode]: "added": document.added.isoformat() if document.added else None, "modified": document.modified.isoformat(), } + from llama_index.core import Document as LlamaDocument + from llama_index.core.node_parser import SimpleNodeParser + doc = LlamaDocument(text=text, metadata=metadata) parser = SimpleNodeParser() return parser.get_nodes_from_documents([doc]) @@ -122,6 +126,10 @@ def load_or_build_index(nodes=None): Load an existing VectorStoreIndex if present, or build a new one using provided nodes if storage is empty. """ + import llama_index.core.settings as llama_settings + from llama_index.core import VectorStoreIndex + from llama_index.core import load_index_from_storage + embed_model = get_embedding_model() llama_settings.Settings.embed_model = embed_model storage_context = get_or_create_storage_context() @@ -143,7 +151,7 @@ def load_or_build_index(nodes=None): ) -def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex): +def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"): """ Removes existing documents from docstore for a given document from the index. This is necessary because FAISS IndexFlatL2 is append-only. @@ -174,6 +182,8 @@ def update_llm_index( """ Rebuild or update the LLM index. """ + from llama_index.core import VectorStoreIndex + nodes = [] documents = Document.objects.all() @@ -187,6 +197,8 @@ def update_llm_index( (settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True) # Rebuild index from scratch logger.info("Rebuilding LLM index.") + import llama_index.core.settings as llama_settings + embed_model = get_embedding_model() llama_settings.Settings.embed_model = embed_model storage_context = get_or_create_storage_context(rebuild=True) @@ -271,6 +283,10 @@ def llm_index_remove_document(document: Document): def truncate_content(content: str) -> str: + from llama_index.core.indices.prompt_helper import PromptHelper + from llama_index.core.prompts import PromptTemplate + from llama_index.core.text_splitter import TokenTextSplitter + prompt_helper = PromptHelper( context_window=8192, num_output=512, @@ -315,6 +331,8 @@ def query_similar_documents( else None ) + from llama_index.core.retrievers import VectorIndexRetriever + retriever = VectorIndexRetriever( index=index, similarity_top_k=top_k, diff --git a/src/paperless_ai/tests/test_ai_indexing.py b/src/paperless_ai/tests/test_ai_indexing.py index c36655f4d..724ac43e4 100644 --- a/src/paperless_ai/tests/test_ai_indexing.py +++ b/src/paperless_ai/tests/test_ai_indexing.py @@ -181,11 +181,11 @@ def test_load_or_build_index_builds_when_nodes_given( ) -> None: with ( patch( - "paperless_ai.indexing.load_index_from_storage", + "llama_index.core.load_index_from_storage", side_effect=ValueError("Index not found"), ), patch( - "paperless_ai.indexing.VectorStoreIndex", + "llama_index.core.VectorStoreIndex", return_value=MagicMock(), ) as mock_index_cls, patch( @@ -206,7 +206,7 @@ def test_load_or_build_index_raises_exception_when_no_nodes( ) -> None: with ( patch( - "paperless_ai.indexing.load_index_from_storage", + "llama_index.core.load_index_from_storage", side_effect=ValueError("Index not found"), ), patch( @@ -225,11 +225,11 @@ def test_load_or_build_index_succeeds_when_nodes_given( ) -> None: with ( patch( - "paperless_ai.indexing.load_index_from_storage", + "llama_index.core.load_index_from_storage", side_effect=ValueError("Index not found"), ), patch( - "paperless_ai.indexing.VectorStoreIndex", + "llama_index.core.VectorStoreIndex", return_value=MagicMock(), ) as mock_index_cls, patch( @@ -334,7 +334,7 @@ def test_query_similar_documents( patch( "paperless_ai.indexing.vector_store_file_exists", ) as mock_vector_store_exists, - patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls, + patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls, patch("paperless_ai.indexing.Document.objects.filter") as mock_filter, ): mock_storage.return_value = MagicMock() diff --git a/src/paperless_ai/tests/test_chat.py b/src/paperless_ai/tests/test_chat.py index 0a14425cf..902c907f2 100644 --- a/src/paperless_ai/tests/test_chat.py +++ b/src/paperless_ai/tests/test_chat.py @@ -45,7 +45,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None: patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, patch( - "paperless_ai.chat.RetrieverQueryEngine.from_args", + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, ): mock_client = MagicMock() @@ -76,7 +76,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non patch("paperless_ai.chat.AIClient") as mock_client_cls, patch("paperless_ai.chat.load_or_build_index") as mock_load_index, patch( - "paperless_ai.chat.RetrieverQueryEngine.from_args", + "llama_index.core.query_engine.RetrieverQueryEngine.from_args", ) as mock_query_engine_cls, patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever, ): diff --git a/src/paperless_ai/tests/test_client.py b/src/paperless_ai/tests/test_client.py index 47053ab20..08582cc71 100644 --- a/src/paperless_ai/tests/test_client.py +++ b/src/paperless_ai/tests/test_client.py @@ -18,13 +18,13 @@ def mock_ai_config(): @pytest.fixture def mock_ollama_llm(): - with patch("paperless_ai.client.Ollama") as MockOllama: + with patch("llama_index.llms.ollama.Ollama") as MockOllama: yield MockOllama @pytest.fixture def mock_openai_llm(): - with patch("paperless_ai.client.OpenAI") as MockOpenAI: + with patch("llama_index.llms.openai.OpenAI") as MockOpenAI: yield MockOpenAI diff --git a/src/paperless_ai/tests/test_embedding.py b/src/paperless_ai/tests/test_embedding.py index 1fb69ee06..98da6e410 100644 --- a/src/paperless_ai/tests/test_embedding.py +++ b/src/paperless_ai/tests/test_embedding.py @@ -67,7 +67,7 @@ def test_get_embedding_model_openai(mock_ai_config): mock_ai_config.return_value.llm_api_key = "test_api_key" mock_ai_config.return_value.llm_endpoint = "http://test-url" - with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding: + with patch("llama_index.embeddings.openai.OpenAIEmbedding") as MockOpenAIEmbedding: model = get_embedding_model() MockOpenAIEmbedding.assert_called_once_with( model="text-embedding-3-small", @@ -84,7 +84,7 @@ def test_get_embedding_model_huggingface(mock_ai_config): ) with patch( - "paperless_ai.embedding.HuggingFaceEmbedding", + "llama_index.embeddings.huggingface.HuggingFaceEmbedding", ) as MockHuggingFaceEmbedding: model = get_embedding_model() MockHuggingFaceEmbedding.assert_called_once_with(