Compare commits

...

3 Commits

Author SHA1 Message Date
Trenton H
a1cf82245e We need to super this 2026-02-27 13:57:43 -08:00
Trenton H
e7480ca3eb Missed this old call here 2026-02-27 13:24:54 -08:00
Trenton H
151e7d5abd Switches the 2 indexing methods to also use Rich now 2026-02-27 13:22:18 -08:00
4 changed files with 57 additions and 29 deletions

View File

@@ -1,22 +1,26 @@
from django.core.management import BaseCommand
from django.db import transaction
from documents.management.commands.mixins import ProgressBarMixin
from documents.management.commands.base import PaperlessCommand
from documents.tasks import index_optimize
from documents.tasks import index_reindex
class Command(ProgressBarMixin, BaseCommand):
class Command(PaperlessCommand):
help = "Manages the document index."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument("command", choices=["reindex", "optimize"])
self.add_argument_progress_bar_mixin(parser)
def handle(self, *args, **options):
self.handle_progress_bar_mixin(**options)
with transaction.atomic():
if options["command"] == "reindex":
index_reindex(progress_bar_disable=self.no_progress_bar)
index_reindex(
iter_wrapper=lambda docs: self.track(
docs,
description="Indexing documents...",
),
)
elif options["command"] == "optimize":
index_optimize()

View File

@@ -1,22 +1,22 @@
from django.core.management import BaseCommand
from django.db import transaction
from typing import Any
from documents.management.commands.mixins import ProgressBarMixin
from documents.management.commands.base import PaperlessCommand
from documents.tasks import llmindex_index
class Command(ProgressBarMixin, BaseCommand):
class Command(PaperlessCommand):
help = "Manages the LLM-based vector index for Paperless."
def add_arguments(self, parser):
def add_arguments(self, parser: Any) -> None:
super().add_arguments(parser)
parser.add_argument("command", choices=["rebuild", "update"])
self.add_argument_progress_bar_mixin(parser)
def handle(self, *args, **options):
self.handle_progress_bar_mixin(**options)
with transaction.atomic():
llmindex_index(
progress_bar_disable=self.no_progress_bar,
rebuild=options["command"] == "rebuild",
scheduled=False,
)
def handle(self, *args: Any, **options: Any) -> None:
llmindex_index(
rebuild=options["command"] == "rebuild",
scheduled=False,
iter_wrapper=lambda docs: self.track(
docs,
description="Indexing documents...",
),
)

View File

@@ -4,11 +4,13 @@ import logging
import shutil
import uuid
import zipfile
from collections.abc import Callable
from collections.abc import Iterable
from pathlib import Path
from tempfile import TemporaryDirectory
from tempfile import mkstemp
from typing import TypeVar
import tqdm
from celery import Task
from celery import shared_task
from celery import states
@@ -66,11 +68,19 @@ from paperless_ai.indexing import llm_index_add_or_update_document
from paperless_ai.indexing import llm_index_remove_document
from paperless_ai.indexing import update_llm_index
_T = TypeVar("_T")
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry
logger = logging.getLogger("paperless.tasks")
def _identity(iterable: Iterable[_T]) -> Iterable[_T]:
return iterable
@shared_task
def index_optimize() -> None:
ix = index.open_index()
@@ -78,13 +88,13 @@ def index_optimize() -> None:
writer.commit(optimize=True)
def index_reindex(*, progress_bar_disable=False) -> None:
def index_reindex(*, iter_wrapper: IterWrapper[Document] = _identity) -> None:
documents = Document.objects.all()
ix = index.open_index(recreate=True)
with AsyncWriter(ix) as writer:
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
for document in iter_wrapper(documents):
index.update_document(writer, document)
@@ -594,7 +604,7 @@ def update_document_parent_tags(tag: Tag, new_parent: Tag) -> None:
@shared_task
def llmindex_index(
*,
progress_bar_disable=True,
iter_wrapper: IterWrapper[Document] = _identity,
rebuild=False,
scheduled=True,
auto=False,
@@ -617,7 +627,7 @@ def llmindex_index(
try:
result = update_llm_index(
progress_bar_disable=progress_bar_disable,
iter_wrapper=iter_wrapper,
rebuild=rebuild,
)
task.status = states.SUCCESS

View File

@@ -1,11 +1,13 @@
import logging
import shutil
from collections.abc import Callable
from collections.abc import Iterable
from datetime import timedelta
from pathlib import Path
from typing import TypeVar
import faiss
import llama_index.core.settings as llama_settings
import tqdm
from celery import states
from django.conf import settings
from django.utils import timezone
@@ -29,6 +31,14 @@ from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
_T = TypeVar("_T")
IterWrapper = Callable[[Iterable[_T]], Iterable[_T]]
def _identity(iterable: Iterable[_T]) -> Iterable[_T]:
return iterable
logger = logging.getLogger("paperless_ai.indexing")
@@ -156,7 +166,11 @@ def vector_store_file_exists():
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
def update_llm_index(
*,
iter_wrapper: IterWrapper[Document] = _identity,
rebuild=False,
) -> str:
"""
Rebuild or update the LLM index.
"""
@@ -176,7 +190,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=True)
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
for document in iter_wrapper(documents):
document_nodes = build_document_node(document)
nodes.extend(document_nodes)
@@ -184,7 +198,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
show_progress=not progress_bar_disable,
show_progress=False,
)
msg = "LLM index rebuilt successfully."
else:
@@ -196,7 +210,7 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
for node in index.docstore.get_nodes(all_node_ids)
}
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
for document in iter_wrapper(documents):
doc_id = str(document.id)
document_modified = document.modified.isoformat()