Compare commits

..

1 Commit

Author SHA1 Message Date
Trenton H
22e5ac58ba Refactors imports, mostly in the AI area, to be lazier imports 2026-03-06 21:00:46 -08:00
8 changed files with 83 additions and 48 deletions

View File

@@ -1,4 +1,4 @@
from llama_index.core.bridge.pydantic import BaseModel
from pydantic import BaseModel
class DocumentClassifierSchema(BaseModel):

View File

@@ -1,10 +1,6 @@
import logging
import sys
from llama_index.core import VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from documents.models import Document
from paperless_ai.client import AIClient
from paperless_ai.indexing import load_or_build_index
@@ -14,15 +10,21 @@ logger = logging.getLogger("paperless_ai.chat")
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
SINGLE_DOC_SNIPPET_CHARS = 800
CHAT_PROMPT_TMPL = PromptTemplate(
template="""Context information is below.
CHAT_PROMPT_TEMPLATE = """Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer:""",
)
Answer:"""
def _get_prompt_template():
from llama_index.core.prompts import PromptTemplate
return PromptTemplate(
template=CHAT_PROMPT_TEMPLATE,
)
def stream_chat_with_documents(query_str: str, documents: list[Document]):
@@ -43,6 +45,8 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
yield "Sorry, I couldn't find any content to answer your question."
return
from llama_index.core import VectorStoreIndex
local_index = VectorStoreIndex(nodes=nodes)
retriever = local_index.as_retriever(
similarity_top_k=3 if len(documents) == 1 else 5,
@@ -85,10 +89,16 @@ def stream_chat_with_documents(query_str: str, documents: list[Document]):
for node in top_nodes
)
prompt = CHAT_PROMPT_TMPL.partial_format(
context_str=context,
query_str=query_str,
).format(llm=client.llm)
prompt = (
_get_prompt_template()
.partial_format(
context_str=context,
query_str=query_str,
)
.format(llm=client.llm)
)
from llama_index.core.query_engine import RetrieverQueryEngine
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,

View File

@@ -1,9 +1,10 @@
import logging
from typing import TYPE_CHECKING
from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
if TYPE_CHECKING:
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from paperless.config import AIConfig
from paperless_ai.base_model import DocumentClassifierSchema
@@ -20,14 +21,18 @@ class AIClient:
self.settings = AIConfig()
self.llm = self.get_llm()
def get_llm(self) -> Ollama | OpenAI:
def get_llm(self) -> "Ollama | OpenAI":
if self.settings.llm_backend == "ollama":
from llama_index.llms.ollama import Ollama
return Ollama(
model=self.settings.llm_model or "llama3.1",
base_url=self.settings.llm_endpoint or "http://localhost:11434",
request_timeout=120,
)
elif self.settings.llm_backend == "openai":
from llama_index.llms.openai import OpenAI
return OpenAI(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_base=self.settings.llm_endpoint or None,
@@ -43,6 +48,9 @@ class AIClient:
self.settings.llm_model,
)
from llama_index.core.llms import ChatMessage
from llama_index.core.program.function_program import get_function_tool
user_msg = ChatMessage(role="user", content=prompt)
tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
@@ -58,7 +66,7 @@ class AIClient:
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list[ChatMessage]) -> str:
def run_chat(self, messages: list["ChatMessage"]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,

View File

@@ -5,9 +5,9 @@ if TYPE_CHECKING:
from pathlib import Path
from django.conf import settings
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
if TYPE_CHECKING:
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from documents.models import Note
@@ -15,17 +15,21 @@ from paperless.config import AIConfig
from paperless.models import LLMEmbeddingBackend
def get_embedding_model() -> BaseEmbedding:
def get_embedding_model() -> "BaseEmbedding":
config = AIConfig()
match config.llm_embedding_backend:
case LLMEmbeddingBackend.OPENAI:
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model=config.llm_embedding_model or "text-embedding-3-small",
api_key=config.llm_api_key,
api_base=config.llm_endpoint or None,
)
case LLMEmbeddingBackend.HUGGINGFACE:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
return HuggingFaceEmbedding(
model_name=config.llm_embedding_model
or "sentence-transformers/all-MiniLM-L6-v2",

View File

@@ -4,26 +4,12 @@ from collections.abc import Callable
from collections.abc import Iterable
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TypeVar
import faiss
import llama_index.core.settings as llama_settings
from celery import states
from django.conf import settings
from django.utils import timezone
from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.prompts import PromptTemplate
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.text_splitter import TokenTextSplitter
from llama_index.vector_stores.faiss import FaissVectorStore
from documents.models import Document
from documents.models import PaperlessTask
@@ -31,6 +17,11 @@ from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import BaseNode
_T = TypeVar("_T")
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
@@ -65,11 +56,17 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
return True
def get_or_create_storage_context(*, rebuild=False):
def get_or_create_storage_context(*, rebuild=False) -> "StorageContext":
"""
Loads or creates the StorageContext (vector store, docstore, index store).
If rebuild=True, deletes and recreates everything.
"""
import faiss
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
if rebuild:
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
@@ -93,7 +90,7 @@ def get_or_create_storage_context(*, rebuild=False):
)
def build_document_node(document: Document) -> list[BaseNode]:
def build_document_node(document: Document) -> list["BaseNode"]:
"""
Given a Document, returns parsed Nodes ready for indexing.
"""
@@ -112,16 +109,23 @@ def build_document_node(document: Document) -> list[BaseNode]:
"added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
}
from llama_index.core import Document as LlamaDocument
from llama_index.core.node_parser import SimpleNodeParser
doc = LlamaDocument(text=text, metadata=metadata)
parser = SimpleNodeParser()
return parser.get_nodes_from_documents([doc])
def load_or_build_index(nodes=None):
def load_or_build_index(nodes=None) -> "VectorStoreIndex":
"""
Load an existing VectorStoreIndex if present,
or build a new one using provided nodes if storage is empty.
"""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context()
@@ -143,7 +147,7 @@ def load_or_build_index(nodes=None):
)
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
"""
Removes existing documents from docstore for a given document from the index.
This is necessary because FAISS IndexFlatL2 is append-only.
@@ -183,6 +187,9 @@ def update_llm_index(
return msg
if rebuild or not vector_store_file_exists():
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
@@ -271,6 +278,10 @@ def llm_index_remove_document(document: Document):
def truncate_content(content: str) -> str:
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.prompts import PromptTemplate
from llama_index.core.text_splitter import TokenTextSplitter
prompt_helper = PromptHelper(
context_window=8192,
num_output=512,
@@ -315,6 +326,8 @@ def query_similar_documents(
else None
)
from llama_index.core.retrievers import VectorIndexRetriever
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,

View File

@@ -45,7 +45,7 @@ def test_stream_chat_with_one_document_full_content(mock_document) -> None:
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat.RetrieverQueryEngine.from_args",
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
mock_client = MagicMock()
@@ -76,7 +76,7 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat.RetrieverQueryEngine.from_args",
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
):

View File

@@ -18,13 +18,13 @@ def mock_ai_config():
@pytest.fixture
def mock_ollama_llm():
with patch("paperless_ai.client.Ollama") as MockOllama:
with patch("llama_index.llms.ollama.Ollama") as MockOllama:
yield MockOllama
@pytest.fixture
def mock_openai_llm():
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
with patch("llama_index.llms.openai.OpenAI") as MockOpenAI:
yield MockOpenAI

View File

@@ -67,7 +67,7 @@ def test_get_embedding_model_openai(mock_ai_config):
mock_ai_config.return_value.llm_api_key = "test_api_key"
mock_ai_config.return_value.llm_endpoint = "http://test-url"
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
with patch("llama_index.embeddings.openai.OpenAIEmbedding") as MockOpenAIEmbedding:
model = get_embedding_model()
MockOpenAIEmbedding.assert_called_once_with(
model="text-embedding-3-small",
@@ -84,7 +84,7 @@ def test_get_embedding_model_huggingface(mock_ai_config):
)
with patch(
"paperless_ai.embedding.HuggingFaceEmbedding",
"llama_index.embeddings.huggingface.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding:
model = get_embedding_model()
MockHuggingFaceEmbedding.assert_called_once_with(