# File: src/paperless_ai/taxonomy.py
import logging
from typing import TYPE_CHECKING
from typing import TypedDict

if TYPE_CHECKING:
    from llama_index.core.schema import NodeWithScore

logger = logging.getLogger("paperless_ai.taxonomy")

# Pairs each node-metadata key with the ``TaxonomyHints`` field it feeds.
_METADATA_KEY_TO_HINT_FIELD: tuple[tuple[str, str], ...] = (
    ("tags", "tags"),
    ("document_type", "document_types"),
    ("correspondent", "correspondents"),
    ("storage_path", "storage_paths"),
)


class TaxonomyHints(TypedDict):
    """Unique, sorted taxonomy names grouped by kind."""

    tags: list[str]
    document_types: list[str]
    correspondents: list[str]
    storage_paths: list[str]


def _usable_names(value: object) -> list[str]:
    """Return the non-blank string names held in one metadata value."""
    candidates: list[object]
    if isinstance(value, str):
        # A bare string is ONE name; iterating it would yield its characters.
        candidates = [value]
    elif isinstance(value, (list, tuple, set, frozenset)):
        candidates = list(value)
    else:
        # ``None``, a missing key, or an unexpected type carries no names.
        return []

    # Non-strings are dropped so the later ``sorted()`` never compares mixed
    # types; blank / whitespace-only strings are not meaningful names.
    return [
        name for name in candidates if isinstance(name, str) and name.strip()
    ]


def build_taxonomy_hints_from_nodes(
    nodes: list["NodeWithScore"],
) -> TaxonomyHints:
    """Collect the unique, sorted taxonomy names carried on retrieved nodes.

    Reads ``tags``, ``document_type``, ``correspondent`` and ``storage_path``
    from each node's metadata. Each value may be a single string or a
    list / tuple / set of strings. Missing keys, ``None``, empty or
    whitespace-only strings, and non-string entries are skipped. Names are
    returned exactly as stored (no stripping or case folding).

    No cap is applied: the result only contains names present on the nodes
    passed in, so it is bounded by the caller's retrieval ``top_k``.

    Args:
        nodes: Objects exposing a ``metadata`` mapping (or ``None``).

    Returns:
        A ``TaxonomyHints`` dict whose four lists are de-duplicated and
        sorted, making the output deterministic for a given input.
    """
    collected: dict[str, set[str]] = {
        hint_field: set() for _, hint_field in _METADATA_KEY_TO_HINT_FIELD
    }

    for node in nodes:
        # A node may carry ``None`` instead of a metadata mapping.
        metadata = node.metadata or {}
        for metadata_key, hint_field in _METADATA_KEY_TO_HINT_FIELD:
            value = metadata.get(metadata_key)
            collected[hint_field].update(_usable_names(value))

    return TaxonomyHints(
        tags=sorted(collected["tags"]),
        document_types=sorted(collected["document_types"]),
        correspondents=sorted(collected["correspondents"]),
        storage_paths=sorted(collected["storage_paths"]),
    )
# File: src/paperless_ai/tests/test_taxonomy.py
from types import SimpleNamespace

from paperless_ai.taxonomy import build_taxonomy_hints_from_nodes


def make_node(**metadata: object) -> SimpleNamespace:
    """Build a minimal NodeWithScore substitute exposing only ``.metadata``."""
    node = SimpleNamespace()
    node.metadata = metadata
    return node


class TestBuildTaxonomyHintsFromNodes:
    """Behavioural checks for ``build_taxonomy_hints_from_nodes``."""

    def test_returns_all_four_keys(self) -> None:
        expected_keys = {
            "tags",
            "document_types",
            "correspondents",
            "storage_paths",
        }
        result = build_taxonomy_hints_from_nodes([])
        assert set(result.keys()) == expected_keys

    def test_collects_and_sorts_values(self) -> None:
        single = make_node(
            tags=["Taxes", "Bloodwork"],
            document_type="Invoice",
            correspondent="IRS",
            storage_path="Financial",
        )
        result = build_taxonomy_hints_from_nodes([single])
        expected = {
            "tags": ["Bloodwork", "Taxes"],
            "document_types": ["Invoice"],
            "correspondents": ["IRS"],
            "storage_paths": ["Financial"],
        }
        for field, wanted in expected.items():
            assert result[field] == wanted

    def test_deduplicates_across_nodes(self) -> None:
        first = make_node(tags=["Taxes"], document_type="Invoice")
        second = make_node(tags=["Taxes", "Medical"], document_type="Invoice")
        result = build_taxonomy_hints_from_nodes([first, second])
        assert result["tags"] == ["Medical", "Taxes"]
        assert result["document_types"] == ["Invoice"]

    def test_none_values_skipped(self) -> None:
        sparse = make_node(
            tags=["Taxes", None, ""],
            document_type=None,
            correspondent=None,
            storage_path=None,
        )
        result = build_taxonomy_hints_from_nodes([sparse])
        assert result["tags"] == ["Taxes"]
        for field in ("document_types", "correspondents", "storage_paths"):
            assert result[field] == []

    def test_missing_storage_path_key_handled(self) -> None:
        # Pre-enrichment nodes have no storage_path key at all.
        legacy = make_node(tags=["Taxes"], document_type="Invoice")
        result = build_taxonomy_hints_from_nodes([legacy])
        assert result["storage_paths"] == []

    def test_empty_node_list_all_empty(self) -> None:
        all_empty = {
            field: []
            for field in (
                "tags",
                "document_types",
                "correspondents",
                "storage_paths",
            )
        }
        assert build_taxonomy_hints_from_nodes([]) == all_empty

    def test_output_stable_across_calls(self) -> None:
        nodes = [make_node(tags=["b", "a", "c"])]
        first_run = build_taxonomy_hints_from_nodes(nodes)
        second_run = build_taxonomy_hints_from_nodes(nodes)
        assert first_run == second_run