diff --git a/src/paperless_ai/ai_classifier.py b/src/paperless_ai/ai_classifier.py index f07f5eacc..7de19629a 100644 --- a/src/paperless_ai/ai_classifier.py +++ b/src/paperless_ai/ai_classifier.py @@ -6,12 +6,12 @@ from django.conf import settings from django.contrib.auth.models import User from documents.models import Document -from documents.permissions import get_objects_for_user_owner_aware from paperless.config import AIConfig from paperless_ai.client import AIClient from paperless_ai.db import db_connection_released from paperless_ai.indexing import query_similar_documents from paperless_ai.indexing import truncate_content +from paperless_ai.indexing import visible_document_ids_for_user from paperless_ai.taxonomy import format_hints_for_prompt if TYPE_CHECKING: @@ -109,20 +109,7 @@ def get_context_for_document( user: User | None = None, max_docs: int = 5, ) -> str: - visible_documents = ( - get_objects_for_user_owner_aware( - user, - "view_document", - Document, - ) - if user - else None - ) - visible_document_ids = ( - list(visible_documents.values_list("pk", flat=True)) - if visible_documents is not None - else None - ) + visible_document_ids = visible_document_ids_for_user(user) similar_docs = query_similar_documents( document=doc, document_ids=visible_document_ids, diff --git a/src/paperless_ai/indexing.py b/src/paperless_ai/indexing.py index 499990799..d8f4290e3 100644 --- a/src/paperless_ai/indexing.py +++ b/src/paperless_ai/indexing.py @@ -5,6 +5,7 @@ from datetime import timedelta from typing import TYPE_CHECKING from django.conf import settings +from django.contrib.auth.models import User from django.utils import timezone from filelock import FileLock from filelock import ReadWriteLock @@ -12,6 +13,7 @@ from filelock import Timeout from documents.models import Document from documents.models import PaperlessTask +from documents.permissions import get_objects_for_user_owner_aware from documents.utils import IterWrapper from documents.utils import identity from paperless.config import AIConfig @@ -450,6 +452,23 @@ def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str] return {str(document_id) for document_id in document_ids} +def visible_document_ids_for_user(user: User | None) -> list[int] | None: + """Return the pks of documents ``user`` may view, or ``None`` for no filter. + + Returns ``None`` when ``user`` is ``None`` so retrieval runs unfiltered. Used + by both the similarity-context and taxonomy-hints paths to scope RAG + neighbours to documents the requesting user is allowed to see. + """ + if user is None: + return None + visible_documents = get_objects_for_user_owner_aware( + user, + "view_document", + Document, + ) + return list(visible_documents.values_list("pk", flat=True)) + + def retrieve_similar_nodes( document: Document, document_ids: Iterable[int | str] | None = None, diff --git a/src/paperless_ai/taxonomy.py b/src/paperless_ai/taxonomy.py index 6b9cb77d9..4a5487d5b 100644 --- a/src/paperless_ai/taxonomy.py +++ b/src/paperless_ai/taxonomy.py @@ -4,9 +4,9 @@ from typing import TypedDict from django.contrib.auth.models import User from documents.models import Document -from documents.permissions import get_objects_for_user_owner_aware from paperless.config import AIConfig from paperless_ai.indexing import retrieve_similar_nodes +from paperless_ai.indexing import visible_document_ids_for_user if TYPE_CHECKING: from llama_index.core.schema import NodeWithScore @@ -108,23 +108,8 @@ def get_taxonomy_hints_for_document( if not AIConfig().llm_embedding_backend: return None - visible_documents = ( - get_objects_for_user_owner_aware( - user, - "view_document", - Document, - ) - if user - else None - ) - visible_document_ids = ( - list(visible_documents.values_list("pk", flat=True)) - if visible_documents is not None - else None - ) - nodes = retrieve_similar_nodes( document=document, - document_ids=visible_document_ids, + document_ids=visible_document_ids_for_user(user), ) return build_taxonomy_hints_from_nodes(nodes) diff --git a/src/paperless_ai/tests/test_taxonomy.py b/src/paperless_ai/tests/test_taxonomy.py index 3249dac84..401933a91 100644 --- a/src/paperless_ai/tests/test_taxonomy.py +++ b/src/paperless_ai/tests/test_taxonomy.py @@ -157,11 +157,9 @@ class TestGetTaxonomyHintsForDocument: "paperless_ai.taxonomy.AIConfig", return_value=SimpleNamespace(llm_embedding_backend="huggingface"), ) - visible = mocker.MagicMock() - visible.values_list.return_value = [1, 2, 3] mocker.patch( - "paperless_ai.taxonomy.get_objects_for_user_owner_aware", - return_value=visible, + "paperless_ai.taxonomy.visible_document_ids_for_user", + return_value=[1, 2, 3], ) retrieve = mocker.patch( "paperless_ai.taxonomy.retrieve_similar_nodes",