mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-06-26 07:14:18 +00:00
Merge branch 'beta' into dev
This commit is contained in:
@@ -58,6 +58,7 @@ repos:
|
||||
rev: "v2.24.1"
|
||||
hooks:
|
||||
- id: pyproject-fmt
|
||||
additional_dependencies: [tomli]
|
||||
# Dockerfile hooks
|
||||
- repo: https://github.com/AleksaC/hadolint-py
|
||||
rev: v2.14.0
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
# documentation.
|
||||
services:
|
||||
broker:
|
||||
image: docker.io/library/redis:8
|
||||
image: docker.io/valkey/valkey:9-alpine
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- redisdata:/data
|
||||
|
||||
@@ -65,6 +65,11 @@ copies you created in the steps above.
|
||||
|
||||
Please review the [migration instructions](migration-v3.md) before upgrading Paperless-ngx to v3.0, it includes some breaking changes that require manual intervention before upgrading.
|
||||
|
||||
!!! note
|
||||
|
||||
Upgrading to v3 clears the existing task history; previously completed, failed, or
|
||||
acknowledged tasks will no longer appear in the task list afterward. No action is required.
|
||||
|
||||
### Docker Route {#docker-updating}
|
||||
|
||||
If a new release of paperless-ngx is available, upgrading depends on how
|
||||
@@ -500,6 +505,33 @@ task scheduler.
|
||||
python3 manage.py document_index reindex --if-needed
|
||||
```
|
||||
|
||||
### Managing the LLM (AI) index {#llm-index}
|
||||
|
||||
When the [AI features](advanced_usage.md#ai-features) are enabled with an embedding
|
||||
backend, Paperless-ngx maintains a vector index of your documents used for
|
||||
Retrieval-Augmented Generation (RAG), similar-document retrieval, and document chat. The
|
||||
index is updated automatically on the schedule set by
|
||||
[`PAPERLESS_LLM_INDEX_TASK_CRON`](configuration.md#PAPERLESS_LLM_INDEX_TASK_CRON), but you
|
||||
can manage it manually:
|
||||
|
||||
```
|
||||
document_llmindex {rebuild,update,compact}
|
||||
```
|
||||
|
||||
Specify `rebuild` to build the index from scratch from all documents in the database. Use
|
||||
this the first time you enable the feature, or after changing the embedding backend or
|
||||
model.
|
||||
|
||||
Specify `update` to incrementally index new and changed documents. This is what the
|
||||
scheduled task runs.
|
||||
|
||||
Specify `compact` to reclaim space and optimize the on-disk vector store.
|
||||
|
||||
!!! note
|
||||
|
||||
These commands have no effect unless AI is enabled and an embedding backend is
|
||||
configured.
|
||||
|
||||
### Clearing the database read cache
|
||||
|
||||
If the database read cache is enabled, **you must run this command** after making any changes to the database outside the application context.
|
||||
|
||||
+83
-2
@@ -97,6 +97,85 @@ when using this feature:
|
||||
of these correspondents to ANY new document, if both are set to
|
||||
automatic matching.
|
||||
|
||||
## AI features {#ai-features}
|
||||
|
||||
Paperless-ngx includes a set of optional features backed by a large language model
|
||||
(LLM): AI-assisted suggestions, similar-document retrieval, and a document chat. They
|
||||
are **off by default** and never replace the built-in, non-LLM
|
||||
[matching and suggestions](#matching).
|
||||
|
||||
!!! warning
|
||||
|
||||
Enabling these features sends document content (and metadata) to the LLM backend you
|
||||
configure. If that backend is a remote/hosted provider, your documents leave your
|
||||
server and may incur usage charges. Consider the privacy implications before enabling,
|
||||
and prefer a local backend (Ollama, or a self-hosted OpenAI-compatible gateway) if that
|
||||
matters to you.
|
||||
|
||||
All AI settings can be supplied as `PAPERLESS_AI_*` environment variables (see
|
||||
[configuration](configuration.md#ai)) or set in the admin under
|
||||
**Settings → Application Configuration**; the database value takes precedence over the
|
||||
environment.
|
||||
|
||||
### Enabling the AI features
|
||||
|
||||
At a minimum you need to enable AI and choose an LLM backend:
|
||||
|
||||
- [`PAPERLESS_AI_ENABLED`](configuration.md#PAPERLESS_AI_ENABLED) — master switch.
|
||||
- [`PAPERLESS_AI_LLM_BACKEND`](configuration.md#PAPERLESS_AI_LLM_BACKEND) — `ollama`
|
||||
(runs locally) or `openai-like` (OpenAI itself or any OpenAI-compatible API).
|
||||
- [`PAPERLESS_AI_LLM_MODEL`](configuration.md#PAPERLESS_AI_LLM_MODEL), and for
|
||||
`openai-like` usually [`PAPERLESS_AI_LLM_API_KEY`](configuration.md#PAPERLESS_AI_LLM_API_KEY)
|
||||
and/or [`PAPERLESS_AI_LLM_ENDPOINT`](configuration.md#PAPERLESS_AI_LLM_ENDPOINT). Ollama
|
||||
requires `PAPERLESS_AI_LLM_ENDPOINT` pointing at your Ollama server.
|
||||
|
||||
### AI-assisted suggestions
|
||||
|
||||
With AI enabled, Paperless-ngx can suggest a title, tags, correspondent, document type,
|
||||
storage path and dates by sending the document to the LLM. This is **opt-in per request**
|
||||
and surfaces through the "Suggest" control on the document detail page, alongside the
|
||||
classic classifier-based suggestions — it does not disable them. Suggestion output
|
||||
language can be steered with
|
||||
[`PAPERLESS_AI_LLM_OUTPUT_LANGUAGE`](configuration.md#PAPERLESS_AI_LLM_OUTPUT_LANGUAGE)
|
||||
(otherwise it follows the user's UI language).
|
||||
|
||||
### The LLM index (RAG) and similar documents
|
||||
|
||||
Setting an embedding backend turns on the **LLM index**, a vector index of your documents
|
||||
that enables Retrieval-Augmented Generation (RAG). When enabled, suggestions are grounded
|
||||
in similar existing documents, and the document chat can retrieve relevant context.
|
||||
|
||||
Enable it by setting
|
||||
[`PAPERLESS_AI_LLM_EMBEDDING_BACKEND`](configuration.md#PAPERLESS_AI_LLM_EMBEDDING_BACKEND)
|
||||
(`huggingface` for fully-local embeddings, or `ollama` / `openai-like`). The index is only
|
||||
built when AI is enabled **and** an embedding backend is set.
|
||||
|
||||
The index is updated automatically on a schedule controlled by
|
||||
[`PAPERLESS_LLM_INDEX_TASK_CRON`](configuration.md#PAPERLESS_LLM_INDEX_TASK_CRON) (daily by
|
||||
default), and can be rebuilt or compacted manually — see
|
||||
[Managing the LLM index](administration.md#llm-index).
|
||||
|
||||
!!! note
|
||||
|
||||
Local embeddings via `huggingface` download the embedding model on first use into the
|
||||
Paperless data directory. The first run therefore needs network access and some disk
|
||||
space.
|
||||
|
||||
### Document chat
|
||||
|
||||
When the LLM index is enabled, the chat control in the top app toolbar answers questions
|
||||
about your documents. It operates over a single document or across multiple documents
|
||||
depending on the current view, and its answers include links to the source documents it
|
||||
drew from.
|
||||
|
||||
### AI Security notes
|
||||
|
||||
- Document content is passed to the LLM as **untrusted data**.
|
||||
- By default Paperless-ngx allows AI endpoints that resolve to private/loopback addresses
|
||||
(for local backends). Set
|
||||
[`PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS`](configuration.md#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS)
|
||||
to `false` to block them.
|
||||
|
||||
## Hooking into the consumption process {#consume-hooks}
|
||||
|
||||
Sometimes you may want to do something arbitrary whenever a document is
|
||||
@@ -846,7 +925,7 @@ Paperless is able to utilize barcodes for automatically performing some tasks. B
|
||||
|
||||
At this time, the library utilized for detection of barcodes supports the following types:
|
||||
|
||||
- AN-13/UPC-A
|
||||
- EAN-13/UPC-A
|
||||
- UPC-E
|
||||
- EAN-8
|
||||
- Code 128
|
||||
@@ -855,7 +934,9 @@ At this time, the library utilized for detection of barcodes supports the follow
|
||||
- Codabar
|
||||
- Interleaved 2 of 5
|
||||
- QR Code
|
||||
- SQ Code
|
||||
- Data Matrix
|
||||
- Aztec
|
||||
- PDF417
|
||||
|
||||
For usage in Paperless, the type of barcode does not matter, only the contents of it.
|
||||
|
||||
|
||||
@@ -227,6 +227,7 @@ Version-aware endpoints:
|
||||
- `PATCH /api/documents/{id}/`: content updates target the selected version (`?version={version_id}`) or latest version by default; non-content metadata updates target the root document.
|
||||
- `GET /api/documents/{id}/download/`, `GET /api/documents/{id}/preview/`, `GET /api/documents/{id}/thumb/`, `GET /api/documents/{id}/metadata/`: accept `?version={version_id}`.
|
||||
- `POST /api/documents/{id}/update_version/`: uploads a new version using multipart form field `document` and optional `version_label`.
|
||||
- `PATCH /api/documents/{id}/versions/{version_id}/`: updates the `version_label` of a specific version.
|
||||
- `DELETE /api/documents/{root_id}/versions/{version_id}/`: deletes a non-root version.
|
||||
|
||||
## Permissions
|
||||
@@ -445,3 +446,9 @@ Initial API version.
|
||||
large lists of object IDs for operations affecting many objects.
|
||||
- The legacy `title_content` document search parameter is deprecated and will be removed in a future version.
|
||||
Clients should use `text` for simple title-and-content search and `title_search` for title-only search.
|
||||
- The task tracking system was redesigned. The tasks list (`/api/tasks/`) is now paginated, and the
|
||||
task object exposes `task_type` (formerly `task_name`) and `trigger_source` (formerly `type`). New
|
||||
read-only endpoints `/api/tasks/summary/`, `/api/tasks/status_counts/`, and `/api/tasks/active/`
|
||||
provide aggregate views, and `POST /api/tasks/run/` lets privileged users dispatch supported tasks.
|
||||
API v9 continues to serve the unpaginated list with the legacy field names until support for v9 is
|
||||
dropped.
|
||||
|
||||
+28
-17
@@ -22,7 +22,11 @@ or applicable default will be utilized instead.
|
||||
|
||||
## Required services
|
||||
|
||||
### Redis Broker
|
||||
### Message Broker
|
||||
|
||||
Paperless-ngx uses a Redis-compatible message broker. Any broker that
|
||||
speaks the Redis protocol works here, including [Valkey](https://valkey.io/)
|
||||
(the default in the bundled Docker Compose files) and Redis itself.
|
||||
|
||||
#### [`PAPERLESS_REDIS=<url>`](#PAPERLESS_REDIS) {#PAPERLESS_REDIS}
|
||||
|
||||
@@ -30,21 +34,21 @@ or applicable default will be utilized instead.
|
||||
fetching, index optimization and for training the automatic document
|
||||
matcher.
|
||||
|
||||
- If your Redis server needs login credentials PAPERLESS_REDIS =
|
||||
- If your broker needs login credentials PAPERLESS_REDIS =
|
||||
`redis://<username>:<password>@<host>:<port>`
|
||||
- With the requirepass option PAPERLESS_REDIS =
|
||||
`redis://:<password>@<host>:<port>`
|
||||
- To include the redis database index PAPERLESS_REDIS =
|
||||
- To include the database index PAPERLESS_REDIS =
|
||||
`redis://<username>:<password>@<host>:<port>/<DBIndex>`
|
||||
|
||||
[More information on securing your Redis
|
||||
Instance](https://redis.io/docs/latest/operate/oss_and_stack/management/security).
|
||||
[More information on securing your broker
|
||||
instance](https://valkey.io/topics/security/).
|
||||
|
||||
Defaults to `redis://localhost:6379`.
|
||||
|
||||
#### [`PAPERLESS_REDIS_PREFIX=<prefix>`](#PAPERLESS_REDIS_PREFIX) {#PAPERLESS_REDIS_PREFIX}
|
||||
|
||||
: Prefix to be used in Redis for keys and channels. Useful for sharing one Redis server among multiple Paperless instances.
|
||||
: Prefix to be used in the broker for keys and channels. Useful for sharing one broker among multiple Paperless instances.
|
||||
|
||||
Defaults to no prefix.
|
||||
|
||||
@@ -58,14 +62,14 @@ and the relevant connection variables.
|
||||
#### [`PAPERLESS_DBENGINE=<engine>`](#PAPERLESS_DBENGINE) {#PAPERLESS_DBENGINE}
|
||||
|
||||
: Specifies the database engine to use. Accepted values are `sqlite`, `postgresql`,
|
||||
and `mariadb`.
|
||||
|
||||
Defaults to `sqlite` if not set.
|
||||
and `mariadb`. PostgreSQL and MariaDB users must set this explicitly.
|
||||
|
||||
PostgreSQL and MariaDB both require [`PAPERLESS_DBHOST`](#PAPERLESS_DBHOST) to be
|
||||
set. SQLite does not use any other connection variables; the database file is always
|
||||
located at `<PAPERLESS_DATA_DIR>/db.sqlite3`.
|
||||
|
||||
Defaults to `sqlite`.
|
||||
|
||||
!!! warning
|
||||
Using MariaDB comes with some caveats.
|
||||
See [MySQL Caveats](advanced_usage.md#mysql-caveats).
|
||||
@@ -238,7 +242,7 @@ dictionaries; for example, `pool.max_size=20` sets
|
||||
|
||||
#### [`PAPERLESS_DB_READ_CACHE_ENABLED=<bool>`](#PAPERLESS_DB_READ_CACHE_ENABLED) {#PAPERLESS_DB_READ_CACHE_ENABLED}
|
||||
|
||||
: Caches the database read query results into Redis. This can significantly improve application response times by caching database queries, at the cost of slightly increased memory usage.
|
||||
: Caches the database read query results into the broker. This can significantly improve application response times by caching database queries, at the cost of slightly increased memory usage.
|
||||
|
||||
Defaults to `false`.
|
||||
|
||||
@@ -258,18 +262,18 @@ dictionaries; for example, `pool.max_size=20` sets
|
||||
|
||||
A high TTL increases memory usage over time. Memory may be used until end of TTL, even if the cache is invalidated with the `invalidate_cachalot` command.
|
||||
|
||||
In case of an out-of-memory (OOM) situation, Redis may stop accepting new data — including cache entries, scheduled tasks, and documents to consume.
|
||||
If your system has limited RAM, consider configuring a dedicated Redis instance for the read cache, with a memory limit and the eviction policy set to `allkeys-lru`.
|
||||
For more details, refer to the [Redis eviction policy documentation](https://redis.io/docs/latest/develop/reference/eviction/), and see the `PAPERLESS_READ_CACHE_REDIS_URL` setting to specify a separate Redis broker.
|
||||
In case of an out-of-memory (OOM) situation, the broker may stop accepting new data — including cache entries, scheduled tasks, and documents to consume.
|
||||
If your system has limited RAM, consider configuring a dedicated broker instance for the read cache, with a memory limit and the eviction policy set to `allkeys-lru`.
|
||||
For more details, refer to the [Redis eviction policy documentation](https://redis.io/docs/latest/develop/reference/eviction/), and see the `PAPERLESS_READ_CACHE_REDIS_URL` setting to specify a separate broker.
|
||||
|
||||
#### [`PAPERLESS_READ_CACHE_REDIS_URL=<url>`](#PAPERLESS_READ_CACHE_REDIS_URL) {#PAPERLESS_READ_CACHE_REDIS_URL}
|
||||
|
||||
: Defines the Redis instance used for the read cache.
|
||||
: Defines the broker instance used for the read cache.
|
||||
|
||||
Defaults to `None`.
|
||||
|
||||
!!! Note
|
||||
If this value is not set, the same Redis instance used for scheduled tasks will be used for caching as well.
|
||||
If this value is not set, the same broker instance used for scheduled tasks will be used for caching as well.
|
||||
|
||||
## Optional Services
|
||||
|
||||
@@ -888,7 +892,7 @@ modes are available:
|
||||
|
||||
The default is `auto`.
|
||||
|
||||
For the `skip`, `redo`, and `force` modes, read more about OCR
|
||||
For the `redo` and `force` modes, read more about OCR
|
||||
behaviour in the [OCRmyPDF
|
||||
documentation](https://ocrmypdf.readthedocs.io/en/latest/advanced.html#when-ocr-is-skipped).
|
||||
|
||||
@@ -2068,6 +2072,13 @@ context by default.
|
||||
|
||||
Defaults to 8192.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
|
||||
|
||||
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
|
||||
local or slow inference servers that need more time to generate responses.
|
||||
|
||||
Defaults to 120.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
|
||||
|
||||
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
|
||||
@@ -2120,7 +2131,7 @@ used with the OpenAI-compatible backend to target a custom provider or local gat
|
||||
|
||||
Defaults to true, which allows internal endpoints.
|
||||
|
||||
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
|
||||
#### [`PAPERLESS_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_LLM_INDEX_TASK_CRON) {#PAPERLESS_LLM_INDEX_TASK_CRON}
|
||||
|
||||
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
|
||||
AI is enabled and the LLM embedding backend is set.
|
||||
|
||||
+13
-12
@@ -94,16 +94,16 @@ first-time setup.
|
||||
```
|
||||
|
||||
7. You can now either ...
|
||||
- install Redis or
|
||||
- install a Redis-compatible broker (e.g. Valkey or Redis) or
|
||||
|
||||
- use the included `scripts/start_services.sh` to use Docker to fire
|
||||
up a Redis instance (and some other services such as Tika,
|
||||
up a broker instance (and some other services such as Tika,
|
||||
Gotenberg and a database server) or
|
||||
|
||||
- spin up a bare Redis container
|
||||
- spin up a bare broker container
|
||||
|
||||
```bash
|
||||
docker run -d -p 6379:6379 --restart unless-stopped redis:latest
|
||||
docker run -d -p 6379:6379 --restart unless-stopped docker.io/valkey/valkey:9-alpine
|
||||
```
|
||||
|
||||
8. Continue with either back-end or front-end development – or both :-).
|
||||
@@ -132,7 +132,7 @@ uv run manage.py runserver & \
|
||||
```
|
||||
|
||||
You might need the front end to test your back end code.
|
||||
This assumes that you have AngularJS installed on your system.
|
||||
This assumes that you have Angular installed on your system.
|
||||
Go to the [Front end development](#front-end-development) section for further details.
|
||||
To build the front end once use this command:
|
||||
|
||||
@@ -174,7 +174,7 @@ To add a new development package `uv add --dev <package>`
|
||||
|
||||
## Front end development
|
||||
|
||||
The front end is built using AngularJS. In order to get started, you need Node.js (version 24+) and
|
||||
The front end is built using Angular. In order to get started, you need Node.js (version 24+) and
|
||||
`pnpm`.
|
||||
|
||||
!!! note
|
||||
@@ -248,12 +248,12 @@ that authentication is working.
|
||||
## Localization
|
||||
|
||||
Paperless-ngx is available in many different languages. Since Paperless-ngx
|
||||
consists both of a Django application and an AngularJS front end, both
|
||||
consists both of a Django application and an Angular front end, both
|
||||
these parts have to be translated separately.
|
||||
|
||||
### Front end localization
|
||||
|
||||
- The AngularJS front end does localization according to the [Angular
|
||||
- The Angular front end does localization according to the [Angular
|
||||
documentation](https://angular.io/guide/i18n).
|
||||
- The source language of the project is "en_US".
|
||||
- The source strings end up in the file `src-ui/messages.xlf`.
|
||||
@@ -495,7 +495,7 @@ class MyCustomParser:
|
||||
self._tempdir = Path(
|
||||
tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR)
|
||||
)
|
||||
self._text: str | None = None
|
||||
self._text: str = ""
|
||||
self._archive_path: Path | None = None
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
@@ -553,7 +553,8 @@ def parse(
|
||||
**Result accessors**
|
||||
|
||||
```python
|
||||
def get_text(self) -> str | None:
|
||||
def get_text(self) -> str:
|
||||
# Return the extracted text, or an empty string if none was found.
|
||||
return self._text
|
||||
|
||||
def get_date(self) -> "datetime.datetime | None":
|
||||
@@ -684,7 +685,7 @@ class XmlDocumentParser:
|
||||
def __init__(self, logging_group: object = None) -> None:
|
||||
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
|
||||
self._tempdir = Path(tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR))
|
||||
self._text: str | None = None
|
||||
self._text: str = ""
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
@@ -702,7 +703,7 @@ class XmlDocumentParser:
|
||||
except ET.ParseError as e:
|
||||
raise ParseError(f"XML parse error: {e}") from e
|
||||
|
||||
def get_text(self) -> str | None:
|
||||
def get_text(self) -> str:
|
||||
return self._text
|
||||
|
||||
def get_date(self):
|
||||
|
||||
+29
-6
@@ -70,7 +70,16 @@ elsewhere. Here are a couple notes about that.
|
||||
Paperless-ngx determines the type of a file by inspecting its content
|
||||
rather than its file extensions. However, files processed via the
|
||||
consumption directory will be rejected if they have a file extension that
|
||||
not supported by any of the available parsers.
|
||||
is not supported by any of the available parsers.
|
||||
|
||||
## _Are duplicate documents rejected?_
|
||||
|
||||
**A:** Not by default. As of v3, a file whose contents match an existing document is still
|
||||
consumed, and the duplicate is flagged in the UI — open the document and check the
|
||||
**Duplicates** tab to review documents that share the same content. If you prefer the old
|
||||
behavior of rejecting duplicates during consumption, set
|
||||
[`PAPERLESS_CONSUMER_DELETE_DUPLICATES`](configuration.md#PAPERLESS_CONSUMER_DELETE_DUPLICATES)
|
||||
to `true`.
|
||||
|
||||
## _Will paperless-ngx run on Raspberry Pi?_
|
||||
|
||||
@@ -118,10 +127,24 @@ able to run paperless, you're a bit on your own. If you can't run the
|
||||
docker image, the documentation has instructions for bare metal
|
||||
installs.
|
||||
|
||||
## _What about the Redis licensing change and using one of the open source forks_?
|
||||
## _Does Paperless-ngx use AI, and is my data private?_
|
||||
|
||||
Currently (October 2024), forks of Redis such as Valkey or Redirect are not officially supported by our upstream
|
||||
libraries, so using one of these to replace Redis is not officially supported.
|
||||
**A:** Paperless-ngx includes optional AI features — LLM-based suggestions, document chat,
|
||||
and similar-document retrieval — that are **disabled by default**. They only run when you
|
||||
enable them and configure an LLM backend. The built-in tag/correspondent suggestions use a
|
||||
local, non-LLM machine-learning model and do not send your data anywhere. If you enable the
|
||||
LLM features, document content is sent to whichever backend you configure — this can be a
|
||||
fully local backend (e.g. Ollama) or a remote provider. See
|
||||
[AI features](advanced_usage.md#ai-features) for details.
|
||||
|
||||
However, they do claim to be compatible with the Redis protocol and will likely work, but we will
|
||||
not be updating from using Redis as the broker officially just yet.
|
||||
## _Which message broker should I use_?
|
||||
|
||||
Paperless-ngx talks to a Redis-compatible message broker, so any broker that
|
||||
implements the Redis protocol will work. The bundled Docker Compose files
|
||||
default to [Valkey](https://valkey.io/), the open-source fork created after
|
||||
Redis' licensing change, but Redis itself and other wire-compatible brokers
|
||||
(such as Microsoft's Garnet) are equally fine.
|
||||
|
||||
Existing installs can switch broker implementations in place: point
|
||||
[`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) at the new instance and
|
||||
reuse the same data volume.
|
||||
|
||||
+2
-1
@@ -35,9 +35,10 @@ physical documents into a searchable online archive so you can keep, well, _less
|
||||
- _New!_ Supports remote OCR with Azure AI (opt-in).
|
||||
- Documents are saved as PDF/A format which is designed for long term storage, alongside the unaltered originals.
|
||||
- Uses machine-learning to automatically add tags, correspondents and document types to your documents.
|
||||
- **New**: Paperless-ngx can now leverage AI (Large Language Models or LLMs) for document suggestions. This is an optional feature that can be enabled (and is disabled by default).
|
||||
- **New**: Paperless-ngx can optionally leverage AI (Large Language Models or LLMs) for document suggestions, chatting with your documents, and similar-document retrieval. These features are opt-in and disabled by default.
|
||||
- Supports PDF documents, images, plain text files, Office documents (Word, Excel, PowerPoint, and LibreOffice equivalents)[^1] and more.
|
||||
- Paperless stores your documents plain on disk. Filenames and folders are managed by paperless and their format can be configured freely with different configurations assigned to different documents.
|
||||
- Keep multiple **versions** of a document's file under a single entry, sharing one set of metadata.
|
||||
- **Beautiful, modern web application** that features:
|
||||
- Customizable dashboard with statistics.
|
||||
- Filtering by tags, correspondents, types, and more.
|
||||
|
||||
+19
-12
@@ -178,7 +178,7 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
|
||||
- `fonts-liberation` for generating thumbnails for plain text
|
||||
files
|
||||
- `imagemagick` >= 6 for PDF conversion
|
||||
- `gnupg` for handling encrypted documents
|
||||
- `gnupg` for decrypting GPG-encrypted email
|
||||
- `libpq-dev` for PostgreSQL
|
||||
- `libmagic-dev` for mime type detection
|
||||
- `mariadb-client` for MariaDB compile time
|
||||
@@ -226,7 +226,8 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
|
||||
build-essential python3-setuptools python3-wheel
|
||||
```
|
||||
|
||||
2. Install `redis` >= 6.0 and configure it to start automatically.
|
||||
2. Install a Redis-compatible broker (a current release of Valkey or
|
||||
Redis) and configure it to start automatically.
|
||||
|
||||
3. Optional: Install `postgresql` and configure a database, user, and
|
||||
password for Paperless-ngx. If you do not wish to use PostgreSQL,
|
||||
@@ -268,10 +269,10 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
|
||||
6. Configure Paperless-ngx. See [configuration](configuration.md) for details.
|
||||
Edit the included `paperless.conf` and adjust the settings to your
|
||||
needs. Required settings for getting Paperless-ngx running are:
|
||||
- [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) should point to your Redis server, such as
|
||||
- [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) should point to your broker, such as
|
||||
`redis://localhost:6379`.
|
||||
- [`PAPERLESS_DBENGINE`](configuration.md#PAPERLESS_DBENGINE) is optional, and should be one of `postgres`,
|
||||
`mariadb`, or `sqlite`
|
||||
- [`PAPERLESS_DBENGINE`](configuration.md#PAPERLESS_DBENGINE) should be one of `postgresql`,
|
||||
`mariadb`, or `sqlite`. PostgreSQL and MariaDB users must set this explicitly.
|
||||
- [`PAPERLESS_DBHOST`](configuration.md#PAPERLESS_DBHOST) should be the hostname on which your
|
||||
PostgreSQL server is running. Do not configure this to use
|
||||
SQLite instead. Also configure port, database name, user and
|
||||
@@ -297,7 +298,7 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
|
||||
|
||||
!!! warning
|
||||
|
||||
Ensure your Redis instance [is secured](https://redis.io/docs/latest/operate/oss_and_stack/management/security/).
|
||||
Ensure your broker instance [is secured](https://valkey.io/topics/security/).
|
||||
|
||||
7. Create the following directories if they do not already exist:
|
||||
- `/opt/paperless/media`
|
||||
@@ -389,9 +390,9 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
|
||||
`Require=paperless-webserver.socket` in the `webserver` script
|
||||
and configure `granian` to listen on port 80 (set `GRANIAN_PORT`).
|
||||
|
||||
These services rely on Redis and optionally the database server, but
|
||||
These services rely on the broker and optionally the database server, but
|
||||
don't need to be started in any particular order. The example files
|
||||
depend on Redis being started. If you use a database server, you
|
||||
depend on the broker being started. If you use a database server, you
|
||||
should add additional dependencies.
|
||||
|
||||
!!! note
|
||||
@@ -449,6 +450,12 @@ development documentation.
|
||||
You can migrate to Paperless-ngx from Paperless-ng or from the original
|
||||
Paperless project.
|
||||
|
||||
!!! note
|
||||
|
||||
Upgrading an existing Paperless-ngx installation from v2 to v3 has its own
|
||||
breaking changes and required steps. See the [v3 migration guide](migration-v3.md)
|
||||
before upgrading.
|
||||
|
||||
<h3 id="migration_ng">Migrating from Paperless-ng</h3>
|
||||
|
||||
Paperless-ngx is meant to be a drop-in replacement for Paperless-ng, and
|
||||
@@ -494,7 +501,7 @@ installation. Keep these points in mind:
|
||||
for other services, you might as well use it for Paperless as well.
|
||||
- The task scheduler of Paperless, which is used to execute periodic
|
||||
tasks such as email checking and maintenance, requires a
|
||||
[Redis](https://redis.io/) message broker instance. The
|
||||
Redis-compatible message broker instance (such as Valkey or Redis). The
|
||||
Docker Compose route takes care of that.
|
||||
- The layout of the folder structure for your documents and data
|
||||
remains the same, so you can plug your old Docker volumes into
|
||||
@@ -582,16 +589,16 @@ commands as well.
|
||||
|
||||
1. Stop and remove the Paperless container.
|
||||
2. If using an external database, stop that container.
|
||||
3. Update Redis configuration.
|
||||
3. Update broker configuration.
|
||||
1. If `REDIS_URL` is already set, change it to [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS)
|
||||
and continue to step 4.
|
||||
|
||||
1. Otherwise, add a new Redis service in `docker-compose.yml`,
|
||||
1. Otherwise, add a new broker service in `docker-compose.yml`,
|
||||
following [the example compose
|
||||
files](https://github.com/paperless-ngx/paperless-ngx/tree/main/docker/compose)
|
||||
|
||||
1. Set the environment variable [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) so it points to
|
||||
the new Redis container.
|
||||
the new broker container.
|
||||
|
||||
4. Update user mapping.
|
||||
1. If set, change the environment variable `PUID` to `USERMAP_UID`.
|
||||
|
||||
+2
-33
@@ -10,9 +10,9 @@ Check for the following issues:
|
||||
`CONSUMPTION_DIR` setting. Don't adjust this setting if you're
|
||||
using docker.
|
||||
|
||||
- Ensure that redis is up and running. Paperless does its task
|
||||
- Ensure that the broker is up and running. Paperless does its task
|
||||
processing asynchronously, and for documents to arrive at the task
|
||||
processor, it needs redis to run.
|
||||
processor, it needs the broker to run.
|
||||
|
||||
- Ensure that the task processor is running. Docker does this
|
||||
automatically. Manually invoke the task processor by executing
|
||||
@@ -149,37 +149,6 @@ operating system, if these are different from `1000`. See [Docker setup](setup.m
|
||||
Also ensure that you are able to read and write to the consumption
|
||||
directory on the host.
|
||||
|
||||
## OSError: \[Errno 19\] No such device when consuming files
|
||||
|
||||
If you experience errors such as:
|
||||
|
||||
```shell-session
|
||||
File "/usr/local/lib/python3.7/site-packages/whoosh/codec/base.py", line 570, in open_compound_file
|
||||
return CompoundStorage(dbfile, use_mmap=storage.supports_mmap)
|
||||
File "/usr/local/lib/python3.7/site-packages/whoosh/filedb/compound.py", line 75, in __init__
|
||||
self._source = mmap.mmap(fileno, 0, access=mmap.ACCESS_READ)
|
||||
OSError: [Errno 19] No such device
|
||||
|
||||
During handling of the above exception, another exception occurred:
|
||||
|
||||
Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.7/site-packages/django_q/cluster.py", line 436, in worker
|
||||
res = f(*task["args"], **task["kwargs"])
|
||||
File "/usr/src/paperless/src/documents/tasks.py", line 73, in consume_file
|
||||
override_tag_ids=override_tag_ids)
|
||||
File "/usr/src/paperless/src/documents/consumer.py", line 271, in try_consume_file
|
||||
raise ConsumerError(e)
|
||||
```
|
||||
|
||||
Paperless uses a search index to provide better and faster full text
|
||||
searching. This search index is stored inside the `data` folder. The
|
||||
search index uses memory-mapped files (mmap). The above error indicates
|
||||
that paperless was unable to create and open these files.
|
||||
|
||||
This happens when you're trying to store the data directory on certain
|
||||
file systems (mostly network shares) that don't support memory-mapped
|
||||
files.
|
||||
|
||||
## Web-UI stuck at "Loading\..."
|
||||
|
||||
This might have multiple reasons.
|
||||
|
||||
+21
-2
@@ -292,6 +292,23 @@ Once setup, navigating to the email settings page in Paperless-ngx will allow yo
|
||||
You can also submit a document using the REST API, see [POSTing documents](api.md#file-uploads)
|
||||
for details.
|
||||
|
||||
### Duplicate documents
|
||||
|
||||
By default, Paperless-ngx **does not reject duplicates**. If you consume a file whose
|
||||
contents exactly match an existing document (same checksum), the new copy is still
|
||||
consumed and a warning is logged. The task entry for the upload also flags that a
|
||||
duplicate was detected and links to the existing document(s).
|
||||
|
||||
To review duplicates, open a document and switch to the **Duplicates** tab on the
|
||||
document detail page. It lists other documents that share the same content, including any
|
||||
that are in the trash (shown with a badge), and links to each so you can decide which to
|
||||
keep.
|
||||
|
||||
If you would rather reject duplicates at consumption time (the pre-v3 behavior), set
|
||||
[`PAPERLESS_CONSUMER_DELETE_DUPLICATES`](configuration.md#PAPERLESS_CONSUMER_DELETE_DUPLICATES)
|
||||
to `true`. The duplicate file is then deleted instead of consumed, and the task fails with
|
||||
a "document already exists" message.
|
||||
|
||||
## Document Suggestions
|
||||
|
||||
Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a (non-LLM) machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user.
|
||||
@@ -306,7 +323,9 @@ Paperless-ngx includes several features that use AI to enhance the document mana
|
||||
so consider the privacy implications of using these features, especially if using a remote
|
||||
model or API provider instead of the default local model.
|
||||
|
||||
The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering. This uses the FAISS vector store.
|
||||
The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering.
|
||||
|
||||
See [AI features](advanced_usage.md#ai-features) for how to enable and configure these features, including choosing an LLM backend and setting up the LLM index for RAG.
|
||||
|
||||
### AI-Enhanced Suggestions
|
||||
|
||||
@@ -1097,7 +1116,7 @@ Paperless-ngx consists of the following components:
|
||||
errors (i.e., wrong email credentials, errors during consuming a
|
||||
specific file, etc).
|
||||
|
||||
- A [redis](https://redis.io/) message broker: This is a really
|
||||
- A message broker (such as Valkey or Redis): This is a really
|
||||
lightweight service that is responsible for getting the tasks from
|
||||
the webserver and the consumer to the task scheduler. These run in a
|
||||
different process (maybe even on different machines!), and
|
||||
|
||||
+1
-2
@@ -42,7 +42,6 @@ dependencies = [
|
||||
"drf-spectacular~=0.28",
|
||||
"drf-spectacular-sidecar~=2026.5.1",
|
||||
"drf-writable-nested~=0.7.1",
|
||||
"faiss-cpu>=1.10",
|
||||
"filelock~=3.29.0",
|
||||
"flower~=2.0.1",
|
||||
"gotenberg-client~=0.14.0",
|
||||
@@ -57,7 +56,6 @@ dependencies = [
|
||||
"llama-index-embeddings-openai-like>=0.2.2",
|
||||
"llama-index-llms-ollama>=0.9.1",
|
||||
"llama-index-llms-openai-like>=0.7.1",
|
||||
"llama-index-vector-stores-faiss>=0.5.2",
|
||||
"nltk~=3.9.1",
|
||||
"ocrmypdf~=17.4.2",
|
||||
"openai>=2.32",
|
||||
@@ -74,6 +72,7 @@ dependencies = [
|
||||
"scikit-learn~=1.8.0",
|
||||
"sentence-transformers>=5.4.1",
|
||||
"setproctitle~=1.3.4",
|
||||
"sqlite-vec==0.1.9",
|
||||
"tantivy~=0.26.0",
|
||||
"tika-client~=0.11.0",
|
||||
"torch~=2.12.0",
|
||||
|
||||
@@ -26,7 +26,7 @@ module.exports = {
|
||||
'abstract-paperless-service',
|
||||
],
|
||||
transformIgnorePatterns: [
|
||||
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
|
||||
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
|
||||
],
|
||||
moduleNameMapper: {
|
||||
...esmPreset.moduleNameMapper,
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
"ngx-cookie-service": "^21.3.1",
|
||||
"ngx-device-detector": "^11.0.0",
|
||||
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
|
||||
"normalize-diacritics": "^5.0.0",
|
||||
"pdfjs-dist": "^5.7.284",
|
||||
"rxjs": "^7.8.2",
|
||||
"tslib": "^2.8.1",
|
||||
|
||||
Generated
+11
@@ -71,6 +71,9 @@ importers:
|
||||
ngx-ui-tour-ng-bootstrap:
|
||||
specifier: ^18.0.0
|
||||
version: 18.0.0(4ccfccfbcf381a309618492b31e99276)
|
||||
normalize-diacritics:
|
||||
specifier: ^5.0.0
|
||||
version: 5.0.0
|
||||
pdfjs-dist:
|
||||
specifier: ^5.7.284
|
||||
version: 5.7.284
|
||||
@@ -5565,6 +5568,10 @@ packages:
|
||||
engines: {node: ^20.17.0 || >=22.9.0}
|
||||
hasBin: true
|
||||
|
||||
normalize-diacritics@5.0.0:
|
||||
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
|
||||
engines: {node: '>= 14.x', npm: '>= 6.x'}
|
||||
|
||||
normalize-path@3.0.0:
|
||||
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
|
||||
engines: {node: '>=0.10.0'}
|
||||
@@ -12985,6 +12992,10 @@ snapshots:
|
||||
dependencies:
|
||||
abbrev: 4.0.0
|
||||
|
||||
normalize-diacritics@5.0.0:
|
||||
dependencies:
|
||||
tslib: 2.8.1
|
||||
|
||||
normalize-path@3.0.0: {}
|
||||
|
||||
npm-bundled@5.0.0:
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
|
||||
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
|
||||
</button>
|
||||
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
|
||||
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
|
||||
</button>
|
||||
<div class="form-check form-switch mb-0 ms-2">
|
||||
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
|
||||
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
|
||||
@@ -81,7 +84,7 @@
|
||||
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
|
||||
<div class="dropdown-menu shadow" ngbDropdownMenu>
|
||||
@for (t of filterTargets; track t.id) {
|
||||
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
|
||||
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
|
||||
import { RouterTestingModule } from '@angular/router/testing'
|
||||
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
|
||||
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { throwError } from 'rxjs'
|
||||
import { of, throwError } from 'rxjs'
|
||||
import { routes } from 'src/app/app-routing.module'
|
||||
import {
|
||||
PaperlessTask,
|
||||
@@ -29,7 +29,11 @@ import { ToastService } from 'src/app/services/toast.service'
|
||||
import { environment } from 'src/environments/environment'
|
||||
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
|
||||
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
|
||||
import { TasksComponent, TaskSection } from './tasks.component'
|
||||
import {
|
||||
TaskFilterTargetID,
|
||||
TasksComponent,
|
||||
TaskSection,
|
||||
} from './tasks.component'
|
||||
|
||||
const tasks: PaperlessTask[] = [
|
||||
{
|
||||
@@ -154,6 +158,13 @@ const paginatedTasks: Results<PaperlessTask> = {
|
||||
results: tasks,
|
||||
}
|
||||
|
||||
const sectionCountResponse = {
|
||||
all: 7,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 2,
|
||||
}
|
||||
|
||||
describe('TasksComponent', () => {
|
||||
let component: TasksComponent
|
||||
let fixture: ComponentFixture<TasksComponent>
|
||||
@@ -221,6 +232,15 @@ describe('TasksComponent', () => {
|
||||
req.params.get('page') === '1'
|
||||
)
|
||||
.flush(paginatedTasks)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
!req.params.has('status')
|
||||
)
|
||||
.flush(sectionCountResponse)
|
||||
})
|
||||
|
||||
it('should display task sections with counts', () => {
|
||||
@@ -295,6 +315,7 @@ describe('TasksComponent', () => {
|
||||
const headerText = header.nativeElement.textContent
|
||||
|
||||
expect(headerText).toContain('Dismiss visible')
|
||||
expect(headerText).toContain('Dismiss all')
|
||||
expect(headerText).toContain('Auto refresh')
|
||||
expect(headerText).not.toContain('All types')
|
||||
expect(headerText).not.toContain('All sources')
|
||||
@@ -327,6 +348,74 @@ describe('TasksComponent', () => {
|
||||
expect(pagination).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should apply the selected section to the server-side task query', () => {
|
||||
component.setSection(TaskSection.NeedsAttention)
|
||||
|
||||
const req = httpTestingController.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page') === '1' &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('acknowledged') === 'false' &&
|
||||
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
|
||||
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
|
||||
)
|
||||
|
||||
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
|
||||
expect(component.totalTasks).toBe(2)
|
||||
})
|
||||
|
||||
it('should apply task type and trigger source filters to the server-side task query', () => {
|
||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[6]] })
|
||||
|
||||
component.setTriggerSource(PaperlessTaskTriggerSource.System)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
|
||||
request.params.get('trigger_source') ===
|
||||
PaperlessTaskTriggerSource.System
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[6]] })
|
||||
})
|
||||
|
||||
it('should apply text filters to the server-side task query', () => {
|
||||
component.filterText = 'invoice'
|
||||
jest.advanceTimersByTime(150)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('name') === 'invoice'
|
||||
)
|
||||
.flush({ count: 1, results: [tasks[0]] })
|
||||
|
||||
component.setFilterTarget(TaskFilterTargetID.Result)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(request) =>
|
||||
request.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
request.params.get('page_size') === '25' &&
|
||||
request.params.get('result') === 'invoice'
|
||||
)
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('should load a different task page when pagination changes', () => {
|
||||
component.setPage(2)
|
||||
|
||||
@@ -350,6 +439,27 @@ describe('TasksComponent', () => {
|
||||
expect(component.pagedTasks).toEqual([tasks[0]])
|
||||
})
|
||||
|
||||
it('should not replace section counts with current-page counts', () => {
|
||||
component.setPage(2)
|
||||
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('page_size') === '25' &&
|
||||
req.params.get('page') === '2'
|
||||
)
|
||||
.flush({
|
||||
count: 30,
|
||||
results: [tasks[0]],
|
||||
})
|
||||
|
||||
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
|
||||
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
|
||||
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
|
||||
})
|
||||
|
||||
it('should expose stable task type options and disable empty ones', () => {
|
||||
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
|
||||
PaperlessTaskType.TrainClassifier
|
||||
@@ -495,6 +605,46 @@ describe('TasksComponent', () => {
|
||||
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
|
||||
})
|
||||
|
||||
it('should support dismiss all tasks', () => {
|
||||
let modal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
||||
const dismissSpy = jest
|
||||
.spyOn(tasksService, 'dismissAllTasks')
|
||||
.mockReturnValue(of({}))
|
||||
const reloadPageSpy = jest
|
||||
.spyOn(component as any, 'reloadPage')
|
||||
.mockImplementation(() => undefined)
|
||||
|
||||
component.dismissAllTasks()
|
||||
|
||||
expect(modal).not.toBeUndefined()
|
||||
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
|
||||
modal.componentInstance.confirmClicked.emit()
|
||||
expect(dismissSpy).toHaveBeenCalled()
|
||||
expect(reloadPageSpy).toHaveBeenCalledWith(false)
|
||||
expect(component.selectedTasks.size).toBe(0)
|
||||
})
|
||||
|
||||
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
|
||||
const error = new Error('dismiss all failed')
|
||||
const toastSpy = jest.spyOn(toastService, 'showError')
|
||||
const dismissSpy = jest
|
||||
.spyOn(tasksService, 'dismissAllTasks')
|
||||
.mockReturnValue(throwError(() => error))
|
||||
|
||||
let modal: NgbModalRef
|
||||
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
|
||||
|
||||
component.dismissAllTasks()
|
||||
expect(modal).not.toBeUndefined()
|
||||
|
||||
modal.componentInstance.confirmClicked.emit()
|
||||
|
||||
expect(dismissSpy).toHaveBeenCalled()
|
||||
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
|
||||
expect(modal.componentInstance.buttonsEnabled).toBe(true)
|
||||
})
|
||||
|
||||
it('should dismiss the currently visible scoped and filtered tasks', () => {
|
||||
component.setSection(TaskSection.InProgress)
|
||||
component.setTaskType(PaperlessTaskType.SanityCheck)
|
||||
@@ -673,6 +823,9 @@ describe('TasksComponent', () => {
|
||||
})
|
||||
|
||||
it('should keep clearing selection independent from resetting filters', () => {
|
||||
component.resetFilter()
|
||||
expect(component.filterText).toBe('')
|
||||
|
||||
component.setTaskType(PaperlessTaskType.ConsumeFile)
|
||||
component.toggleSelected(tasks[0])
|
||||
expect(component.selectedTasks.size).toBe(1)
|
||||
|
||||
@@ -40,7 +40,7 @@ export enum TaskSection {
|
||||
Completed = 'completed',
|
||||
}
|
||||
|
||||
enum TaskFilterTargetID {
|
||||
export enum TaskFilterTargetID {
|
||||
Name,
|
||||
Result,
|
||||
}
|
||||
@@ -167,6 +167,12 @@ export class TasksComponent
|
||||
public readonly pageSize = 25
|
||||
public page: number = 1
|
||||
public totalTasks: number = 0
|
||||
public sectionCounts: Record<TaskSection, number> = {
|
||||
[TaskSection.All]: 0,
|
||||
[TaskSection.NeedsAttention]: 0,
|
||||
[TaskSection.InProgress]: 0,
|
||||
[TaskSection.Completed]: 0,
|
||||
}
|
||||
public pagedTasks: PaperlessTask[] = []
|
||||
public selectedSection: TaskSection = TaskSection.All
|
||||
public selectedTaskType: PaperlessTaskType | null = null
|
||||
@@ -282,6 +288,7 @@ export class TasksComponent
|
||||
.subscribe((query) => {
|
||||
this._filterText = query
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -334,6 +341,30 @@ export class TasksComponent
|
||||
}
|
||||
}
|
||||
|
||||
dismissAllTasks() {
|
||||
let modal = this.modalService.open(ConfirmDialogComponent, {
|
||||
backdrop: 'static',
|
||||
})
|
||||
modal.componentInstance.title = $localize`Confirm Dismiss All`
|
||||
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
|
||||
modal.componentInstance.btnClass = 'btn-warning'
|
||||
modal.componentInstance.btnCaption = $localize`Dismiss`
|
||||
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
|
||||
modal.componentInstance.buttonsEnabled = false
|
||||
modal.close()
|
||||
this.tasksService.dismissAllTasks().subscribe({
|
||||
next: () => {
|
||||
this.reloadPage(false)
|
||||
},
|
||||
error: (e) => {
|
||||
this.toastService.showError($localize`Error dismissing tasks`, e)
|
||||
modal.componentInstance.buttonsEnabled = true
|
||||
},
|
||||
})
|
||||
this.clearSelection()
|
||||
})
|
||||
}
|
||||
|
||||
expandTask(task: PaperlessTask) {
|
||||
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
|
||||
}
|
||||
@@ -446,9 +477,7 @@ export class TasksComponent
|
||||
}
|
||||
|
||||
sectionCount(section: TaskSection): number {
|
||||
return this.pagedTasks.filter((task) =>
|
||||
this.taskBelongsToSection(task, section)
|
||||
).length
|
||||
return this.sectionCounts[section]
|
||||
}
|
||||
|
||||
sectionShowsResults(section: TaskSection): boolean {
|
||||
@@ -458,16 +487,27 @@ export class TasksComponent
|
||||
setSection(section: TaskSection) {
|
||||
this.selectedSection = section
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setTaskType(taskType: PaperlessTaskType | null) {
|
||||
this.selectedTaskType = taskType
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
|
||||
this.selectedTriggerSource = triggerSource
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
setFilterTarget(filterTargetID: TaskFilterTargetID) {
|
||||
this.filterTargetID = filterTargetID
|
||||
if (this._filterText.length) {
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
}
|
||||
|
||||
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
|
||||
@@ -505,19 +545,32 @@ export class TasksComponent
|
||||
}
|
||||
|
||||
public resetFilter() {
|
||||
if (!this._filterText.length) {
|
||||
return
|
||||
}
|
||||
|
||||
this._filterText = ''
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
}
|
||||
|
||||
public resetFilters() {
|
||||
const hadFilter = this.isFiltered
|
||||
this.selectedTaskType = null
|
||||
this.selectedTriggerSource = null
|
||||
this.resetFilter()
|
||||
this._filterText = ''
|
||||
this.clearSelection()
|
||||
|
||||
if (hadFilter) {
|
||||
this.reloadPage(true)
|
||||
}
|
||||
}
|
||||
|
||||
filterInputKeyup(event: KeyboardEvent) {
|
||||
if (event.key == 'Enter') {
|
||||
this._filterText = (event.target as HTMLInputElement).value
|
||||
this.clearSelection()
|
||||
this.reloadPage(true)
|
||||
} else if (event.key === 'Escape') {
|
||||
this.resetFilter()
|
||||
}
|
||||
@@ -606,19 +659,86 @@ export class TasksComponent
|
||||
)
|
||||
}
|
||||
|
||||
private reloadSectionCounts() {
|
||||
this.tasksService
|
||||
.statusCounts(this.getParamsForSection(TaskSection.All))
|
||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
||||
.subscribe((counts) => {
|
||||
this.sectionCounts[TaskSection.All] = counts.all
|
||||
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
|
||||
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
|
||||
this.sectionCounts[TaskSection.Completed] = counts.completed
|
||||
})
|
||||
}
|
||||
|
||||
private getParamsForSection(
|
||||
section: TaskSection
|
||||
): Record<string, string | number | boolean | readonly string[]> {
|
||||
const params: Record<
|
||||
string,
|
||||
string | number | boolean | readonly string[]
|
||||
> = {
|
||||
acknowledged: false,
|
||||
}
|
||||
|
||||
const statuses = this.statusesForSection(section)
|
||||
if (statuses.length) {
|
||||
params.status = statuses
|
||||
}
|
||||
|
||||
if (this.selectedTaskType !== null) {
|
||||
params.task_type = this.selectedTaskType
|
||||
}
|
||||
|
||||
if (this.selectedTriggerSource !== null) {
|
||||
params.trigger_source = this.selectedTriggerSource
|
||||
}
|
||||
|
||||
if (this._filterText.length) {
|
||||
params[
|
||||
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
|
||||
] = this._filterText
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
|
||||
switch (section) {
|
||||
case TaskSection.NeedsAttention:
|
||||
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
|
||||
case TaskSection.InProgress:
|
||||
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
|
||||
case TaskSection.Completed:
|
||||
return [PaperlessTaskStatus.Success]
|
||||
default:
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
private reloadPage(resetToFirstPage: boolean = false) {
|
||||
if (resetToFirstPage) {
|
||||
this.page = 1
|
||||
}
|
||||
|
||||
this.reloadSectionCounts()
|
||||
|
||||
this.loading = true
|
||||
this.tasksService
|
||||
.list(this.page, this.pageSize, { acknowledged: false })
|
||||
.list(
|
||||
this.page,
|
||||
this.pageSize,
|
||||
this.getParamsForSection(this.selectedSection)
|
||||
)
|
||||
.pipe(first(), takeUntil(this.unsubscribeNotifier))
|
||||
.subscribe({
|
||||
next: (result) => {
|
||||
this.pagedTasks = result.results
|
||||
this.totalTasks = result.count
|
||||
this.sectionCounts[TaskSection.All] = result.count
|
||||
if (this.selectedSection !== TaskSection.All) {
|
||||
this.sectionCounts[this.selectedSection] = result.count
|
||||
}
|
||||
this.loading = false
|
||||
if (
|
||||
this.page > 1 &&
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<div class="chat-messages font-monospace small">
|
||||
@for (message of messages; track message) {
|
||||
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
|
||||
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
|
||||
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
|
||||
<span>
|
||||
{{ message.content }}
|
||||
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
|
||||
|
||||
@@ -188,4 +188,14 @@ describe('ChatComponent', () => {
|
||||
component.searchInputKeyDown(event)
|
||||
expect(component.sendMessage).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not send message on Enter key press while composing with IME', () => {
|
||||
jest.spyOn(component, 'sendMessage')
|
||||
const event = new KeyboardEvent('keydown', {
|
||||
key: 'Enter',
|
||||
isComposing: true,
|
||||
})
|
||||
component.searchInputKeyDown(event)
|
||||
expect(component.sendMessage).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -155,7 +155,10 @@ export class ChatComponent implements OnInit {
|
||||
}
|
||||
|
||||
public searchInputKeyDown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter') {
|
||||
if (
|
||||
event.key === 'Enter' &&
|
||||
!(event.isComposing || event.keyCode === 229)
|
||||
) {
|
||||
event.preventDefault()
|
||||
this.sendMessage()
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
@if (messageBold) {
|
||||
<p><b>{{messageBold}}</b></p>
|
||||
<p class="text-break"><b>{{messageBold}}</b></p>
|
||||
}
|
||||
@if (message) {
|
||||
<p class="mb-0" [innerHTML]="message"></p>
|
||||
<p class="mb-0 text-break" [innerHTML]="message"></p>
|
||||
}
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
|
||||
+5
-1
@@ -9,8 +9,11 @@
|
||||
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
|
||||
<ul class="list-group"
|
||||
cdkDropList
|
||||
[cdkDropListData]="documentIDs"
|
||||
(cdkDropListDropped)="onDrop($event)">
|
||||
@for (document of documents; track document.id) {
|
||||
@for (documentID of documentIDs; track documentID) {
|
||||
@let document = getDocument(documentID);
|
||||
@if (document) {
|
||||
<li class="list-group-item d-flex align-items-center" cdkDrag>
|
||||
<i-bs name="grip-vertical" class="me-2"></i-bs>
|
||||
<div class="d-flex flex-column">
|
||||
@@ -27,6 +30,7 @@
|
||||
</small>
|
||||
</div>
|
||||
</li>
|
||||
}
|
||||
}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
+2
-3
@@ -23,6 +23,7 @@ import {
|
||||
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
|
||||
import { ToastService } from 'src/app/services/toast.service'
|
||||
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
||||
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
||||
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
|
||||
|
||||
@@ -69,9 +70,7 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
|
||||
|
||||
public get filteredFields(): CustomField[] {
|
||||
return this.unusedFields.filter(
|
||||
(f) =>
|
||||
!this.filterText ||
|
||||
f.name.toLowerCase().includes(this.filterText.toLowerCase())
|
||||
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
+3
@@ -63,6 +63,7 @@
|
||||
[(ngModel)]="atom.value"
|
||||
[disabled]="disabled"
|
||||
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
|
||||
[searchFn]="selectOptionSearchFn"
|
||||
(mousedown)="$event.stopImmediatePropagation()"
|
||||
></ng-select>
|
||||
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
|
||||
@@ -81,6 +82,7 @@
|
||||
[disabled]="disabled"
|
||||
bindLabel="name"
|
||||
bindValue="id"
|
||||
[searchFn]="customFieldSearchFn"
|
||||
(mousedown)="$event.stopImmediatePropagation()"
|
||||
></ng-select>
|
||||
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
|
||||
@@ -125,6 +127,7 @@
|
||||
[(ngModel)]="atom.value"
|
||||
[disabled]="disabled"
|
||||
[multiple]="true"
|
||||
[searchFn]="selectOptionSearchFn"
|
||||
(mousedown)="$event.stopImmediatePropagation()"
|
||||
></ng-select>
|
||||
}
|
||||
|
||||
+9
@@ -36,6 +36,7 @@ import {
|
||||
CustomFieldQueryExpression,
|
||||
} from 'src/app/utils/custom-field-query-element'
|
||||
import { pngxPopperOptions } from 'src/app/utils/popper-options'
|
||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
||||
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
|
||||
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
|
||||
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
|
||||
@@ -281,6 +282,14 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
|
||||
|
||||
public readonly today: string = new Date().toLocaleDateString('en-CA')
|
||||
|
||||
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
|
||||
matchesSearchText(field?.name, term)
|
||||
|
||||
public selectOptionSearchFn = (
|
||||
term: string,
|
||||
option: { id: string; label: string }
|
||||
): boolean => matchesSearchText(option?.label, term)
|
||||
|
||||
constructor() {
|
||||
super()
|
||||
this.selectionModel = new CustomFieldQueriesModel()
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
[notFoundText]="notFoundText"
|
||||
[multiple]="multiple"
|
||||
[bindLabel]="bindLabel"
|
||||
[searchFn]="searchFn"
|
||||
bindValue="id"
|
||||
[virtualScroll]="items?.length > 100"
|
||||
(change)="onChange(value)"
|
||||
|
||||
@@ -112,6 +112,15 @@ describe('SelectComponent', () => {
|
||||
expect(createNewVal).toEqual('baz')
|
||||
})
|
||||
|
||||
it('should search items by independent normalized terms', () => {
|
||||
expect(
|
||||
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
|
||||
).toBeTruthy()
|
||||
expect(
|
||||
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
|
||||
).toBeFalsy()
|
||||
})
|
||||
|
||||
it('should clear search term on blur after delay', fakeAsync(() => {
|
||||
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
|
||||
component.onBlur()
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
import { RouterModule } from '@angular/router'
|
||||
import { NgSelectModule } from '@ng-select/ng-select'
|
||||
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
||||
import { AbstractInputComponent } from '../abstract-input'
|
||||
|
||||
@Component({
|
||||
@@ -99,6 +100,9 @@ export class SelectComponent extends AbstractInputComponent<number> {
|
||||
@Input()
|
||||
bindLabel: string = 'name'
|
||||
|
||||
public searchFn = (term: string, item: any): boolean =>
|
||||
matchesSearchText(item?.[this.bindLabel], term)
|
||||
|
||||
@Input()
|
||||
showFilter: boolean = false
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
[clearSearchOnAdd]="true"
|
||||
[hideSelected]="tags.length > 0"
|
||||
[addTag]="allowCreate ? createTagRef : false"
|
||||
[searchFn]="searchFn"
|
||||
addTagText="Add tag"
|
||||
i18n-addTagText
|
||||
(add)="onAdd($event)"
|
||||
|
||||
@@ -171,6 +171,15 @@ describe('TagsComponent', () => {
|
||||
expect(component.getTag(4)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should search tags by independent normalized terms including parents', () => {
|
||||
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
|
||||
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
|
||||
component.tags = [parent, child]
|
||||
|
||||
expect(component.searchFn('finance 26', child)).toBeTruthy()
|
||||
expect(component.searchFn('finance receipt', child)).toBeFalsy()
|
||||
})
|
||||
|
||||
it('should emit filtered documents', () => {
|
||||
component.value = [10]
|
||||
component.tags = tags
|
||||
|
||||
@@ -21,6 +21,7 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
|
||||
import { first, firstValueFrom, tap } from 'rxjs'
|
||||
import { Tag } from 'src/app/data/tag'
|
||||
import { TagService } from 'src/app/services/rest/tag.service'
|
||||
import { matchesSearchText } from 'src/app/utils/text-search'
|
||||
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
|
||||
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
|
||||
import { TagComponent } from '../../tag/tag.component'
|
||||
@@ -114,6 +115,14 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
|
||||
|
||||
public createTagRef: (name) => void
|
||||
|
||||
public searchFn = (term: string, tag: Tag): boolean =>
|
||||
matchesSearchText(
|
||||
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
|
||||
.flat()
|
||||
.join(' '),
|
||||
term
|
||||
)
|
||||
|
||||
getTag(id: number) {
|
||||
if (this.tags) {
|
||||
return this.tags.find((tag) => tag.id == id)
|
||||
|
||||
+2
-2
@@ -1,5 +1,5 @@
|
||||
<div class="btn-group">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
|
||||
@if (loading) {
|
||||
<div class="spinner-border spinner-border-sm" role="status"></div>
|
||||
} @else {
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
@if (aiEnabled) {
|
||||
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
|
||||
<span class="visually-hidden" i18n>Show suggestions</span>
|
||||
</button>
|
||||
|
||||
|
||||
+12
@@ -37,6 +37,18 @@ describe('SuggestionsDropdownComponent', () => {
|
||||
expect(component.getSuggestions.emit).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not emit getSuggestions when disabled', () => {
|
||||
jest.spyOn(component.getSuggestions, 'emit')
|
||||
component.disabled = true
|
||||
component.suggestions = null
|
||||
fixture.detectChanges()
|
||||
|
||||
component.clickSuggest()
|
||||
|
||||
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
|
||||
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
|
||||
component.aiEnabled = true
|
||||
fixture.detectChanges()
|
||||
|
||||
+8
@@ -47,6 +47,14 @@ export class SuggestionsDropdownComponent {
|
||||
addCorrespondent: EventEmitter<string> = new EventEmitter()
|
||||
|
||||
public clickSuggest(): void {
|
||||
if (
|
||||
this.disabled ||
|
||||
this.loading ||
|
||||
(this.suggestions && !this.aiEnabled)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!this.suggestions) {
|
||||
this.getSuggestions.emit(this)
|
||||
} else {
|
||||
|
||||
+3
-1
@@ -131,7 +131,9 @@
|
||||
@if (status.tasks.celery_status === 'OK') {
|
||||
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
|
||||
} @else {
|
||||
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
|
||||
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
|
||||
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
|
||||
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
|
||||
}
|
||||
</button>
|
||||
<ng-template #celeryStatus>
|
||||
|
||||
@@ -360,6 +360,14 @@ export const PaperlessConfigOptions: ConfigOption[] = [
|
||||
category: ConfigCategory.AI,
|
||||
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
|
||||
},
|
||||
{
|
||||
key: 'llm_request_timeout',
|
||||
title: $localize`LLM Request Timeout`,
|
||||
type: ConfigOptionType.Number,
|
||||
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
|
||||
category: ConfigCategory.AI,
|
||||
note: $localize`Timeout in seconds for LLM requests.`,
|
||||
},
|
||||
]
|
||||
|
||||
export interface PaperlessConfig extends ObjectWithId {
|
||||
@@ -401,4 +409,5 @@ export interface PaperlessConfig extends ObjectWithId {
|
||||
llm_api_key: string
|
||||
llm_endpoint: string
|
||||
llm_output_language: string
|
||||
llm_request_timeout: number
|
||||
}
|
||||
|
||||
@@ -64,3 +64,10 @@ export interface PaperlessTaskSummary {
|
||||
last_success: Date | null
|
||||
last_failure: Date | null
|
||||
}
|
||||
|
||||
export interface PaperlessTaskStatusCounts {
|
||||
all: number
|
||||
needs_attention: number
|
||||
in_progress: number
|
||||
completed: number
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Pipe, PipeTransform } from '@angular/core'
|
||||
import { MatchingModel } from '../data/matching-model'
|
||||
import { matchesSearchText } from '../utils/text-search'
|
||||
|
||||
@Pipe({
|
||||
name: 'filter',
|
||||
@@ -21,9 +22,7 @@ export class FilterPipe implements PipeTransform {
|
||||
typeof item[key] === 'string' || typeof item[key] === 'number'
|
||||
)
|
||||
return keys.some((key) => {
|
||||
return String(item[key])
|
||||
.toLowerCase()
|
||||
.includes(searchText.toLowerCase())
|
||||
return matchesSearchText(item[key], searchText)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,6 +80,27 @@ describe('TasksService', () => {
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
|
||||
tasksService.dismissAllTasks().subscribe()
|
||||
const req = httpTestingController.expectOne(
|
||||
`${environment.apiBaseUrl}tasks/acknowledge/`
|
||||
)
|
||||
expect(req.request.method).toEqual('POST')
|
||||
expect(req.request.body).toEqual({
|
||||
all: true,
|
||||
})
|
||||
req.flush([])
|
||||
// reload is then called
|
||||
httpTestingController
|
||||
.expectOne(
|
||||
(req: HttpRequest<unknown>) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('page_size') === '1000'
|
||||
)
|
||||
.flush({ count: 0, results: [] })
|
||||
})
|
||||
|
||||
it('groups mixed task types by status when reloading', () => {
|
||||
expect(tasksService.total).toEqual(0)
|
||||
const mockTasks = [
|
||||
@@ -221,4 +242,34 @@ describe('TasksService', () => {
|
||||
task_id: 'abc-123',
|
||||
})
|
||||
})
|
||||
|
||||
it('loads filtered task status counts', () => {
|
||||
tasksService
|
||||
.statusCounts({
|
||||
acknowledged: false,
|
||||
task_type: PaperlessTaskType.ConsumeFile,
|
||||
})
|
||||
.subscribe((res) => {
|
||||
expect(res).toEqual({
|
||||
all: 10,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 5,
|
||||
})
|
||||
})
|
||||
|
||||
const req = httpTestingController.expectOne(
|
||||
(req: HttpRequest<unknown>) =>
|
||||
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
|
||||
req.params.get('acknowledged') === 'false' &&
|
||||
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
|
||||
)
|
||||
expect(req.request.method).toEqual('GET')
|
||||
req.flush({
|
||||
all: 10,
|
||||
needs_attention: 2,
|
||||
in_progress: 3,
|
||||
completed: 5,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
|
||||
import {
|
||||
PaperlessTask,
|
||||
PaperlessTaskStatus,
|
||||
PaperlessTaskStatusCounts,
|
||||
PaperlessTaskType,
|
||||
} from 'src/app/data/paperless-task'
|
||||
import { Results } from 'src/app/data/results'
|
||||
@@ -88,7 +89,7 @@ export class TasksService {
|
||||
public list(
|
||||
page: number,
|
||||
pageSize: number,
|
||||
extraParams?: Record<string, string | number | boolean>
|
||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
||||
): Observable<Results<PaperlessTask>> {
|
||||
return this.http.get<Results<PaperlessTask>>(
|
||||
`${this.baseUrl}${this.endpoint}/`,
|
||||
@@ -102,6 +103,17 @@ export class TasksService {
|
||||
)
|
||||
}
|
||||
|
||||
public statusCounts(
|
||||
extraParams?: Record<string, string | number | boolean | readonly string[]>
|
||||
): Observable<PaperlessTaskStatusCounts> {
|
||||
return this.http.get<PaperlessTaskStatusCounts>(
|
||||
`${this.baseUrl}${this.endpoint}/status_counts/`,
|
||||
{
|
||||
params: extraParams,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
public dismissTasks(task_ids: Set<number>): Observable<any> {
|
||||
return this.http
|
||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
||||
@@ -116,6 +128,20 @@ export class TasksService {
|
||||
)
|
||||
}
|
||||
|
||||
public dismissAllTasks(): Observable<any> {
|
||||
return this.http
|
||||
.post(`${this.baseUrl}tasks/acknowledge/`, {
|
||||
all: true,
|
||||
})
|
||||
.pipe(
|
||||
first(),
|
||||
takeUntil(this.unsubscribeNotifer),
|
||||
tap(() => {
|
||||
this.reload()
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
public cancelPending(): void {
|
||||
this.unsubscribeNotifer.next(true)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import { matchesSearchText } from './text-search'
|
||||
|
||||
describe('text search utilities', () => {
|
||||
it('matches text accent-insensitively', () => {
|
||||
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
|
||||
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
|
||||
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
|
||||
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
|
||||
})
|
||||
|
||||
it('matches all whitespace-separated search terms independently', () => {
|
||||
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
|
||||
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
|
||||
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
|
||||
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,23 @@
|
||||
import { normalizeSync } from 'normalize-diacritics'
|
||||
|
||||
export type SearchTextValue =
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| bigint
|
||||
| null
|
||||
| undefined
|
||||
|
||||
export function normalizeSearchText(value: SearchTextValue): string {
|
||||
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
|
||||
}
|
||||
|
||||
export function matchesSearchText(
|
||||
value: SearchTextValue,
|
||||
searchText: SearchTextValue
|
||||
): boolean {
|
||||
const normalizedValue = normalizeSearchText(value)
|
||||
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
|
||||
|
||||
return searchTerms.every((term) => normalizedValue.includes(term))
|
||||
}
|
||||
@@ -904,6 +904,19 @@ def remove_password(
|
||||
doc.id,
|
||||
pair.source_doc.source_path,
|
||||
)
|
||||
try:
|
||||
with pikepdf.open(source_path) as pdf:
|
||||
if not pdf.is_encrypted:
|
||||
logger.info(
|
||||
"Skipping password removal for document %s because the "
|
||||
"source PDF is not encrypted",
|
||||
pair.root_doc.id,
|
||||
)
|
||||
continue
|
||||
except pikepdf.PasswordError:
|
||||
# Password-protected PDFs need the supplied password below.
|
||||
pass
|
||||
|
||||
with pikepdf.open(source_path, password=password) as pdf:
|
||||
filepath: Path = (
|
||||
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
|
||||
|
||||
@@ -70,13 +70,13 @@ def suggestions_last_modified(request, pk: int) -> datetime | None:
|
||||
|
||||
def metadata_etag(request, pk: int) -> str | None:
|
||||
"""
|
||||
Metadata is extracted from the original file, so use its checksum as the
|
||||
ETag
|
||||
Metadata responses include metadata as well as document fields, so include
|
||||
the modification time with the checksum so metadata-only changes invalidate cache.
|
||||
"""
|
||||
doc = resolve_effective_document_by_pk(pk, request).document
|
||||
if doc is None:
|
||||
return None
|
||||
return doc.checksum
|
||||
return f"{doc.checksum}:{doc.modified.isoformat()}"
|
||||
|
||||
|
||||
def metadata_last_modified(request, pk: int) -> datetime | None:
|
||||
|
||||
@@ -28,6 +28,7 @@ from django.db.models.functions import Cast
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters import DateFilter
|
||||
from django_filters.rest_framework import BooleanFilter
|
||||
from django_filters.rest_framework import CharFilter
|
||||
from django_filters.rest_framework import DateTimeFilter
|
||||
from django_filters.rest_framework import Filter
|
||||
from django_filters.rest_framework import FilterSet
|
||||
@@ -900,6 +901,16 @@ class ShareLinkBundleFilterSet(FilterSet):
|
||||
|
||||
|
||||
class PaperlessTaskFilterSet(FilterSet):
|
||||
name = CharFilter(
|
||||
method="filter_name",
|
||||
label="Name",
|
||||
)
|
||||
|
||||
result = CharFilter(
|
||||
method="filter_result",
|
||||
label="Result",
|
||||
)
|
||||
|
||||
task_type = MultipleChoiceFilter(
|
||||
choices=PaperlessTask.TaskType.choices,
|
||||
label="Task Type",
|
||||
@@ -939,7 +950,58 @@ class PaperlessTaskFilterSet(FilterSet):
|
||||
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
|
||||
fields = [
|
||||
"task_type",
|
||||
"trigger_source",
|
||||
"status",
|
||||
"acknowledged",
|
||||
"owner",
|
||||
"name",
|
||||
"result",
|
||||
]
|
||||
|
||||
def filter_name(self, queryset, name, value):
|
||||
if not value:
|
||||
return queryset
|
||||
|
||||
matching_task_types = [
|
||||
task_type
|
||||
for task_type, label in PaperlessTask.TaskType.choices
|
||||
if value.lower() in str(label).lower()
|
||||
]
|
||||
matching_trigger_sources = [
|
||||
trigger_source
|
||||
for trigger_source, label in PaperlessTask.TriggerSource.choices
|
||||
if value.lower() in str(label).lower()
|
||||
]
|
||||
|
||||
return queryset.filter(
|
||||
Q(input_data__filename__icontains=value)
|
||||
| Q(task_type__in=matching_task_types)
|
||||
| Q(trigger_source__in=matching_trigger_sources),
|
||||
)
|
||||
|
||||
def filter_result(self, queryset, name, value):
|
||||
if not value:
|
||||
return queryset
|
||||
|
||||
query = Q(result_data__reason__icontains=value) | Q(
|
||||
result_data__error_message__icontains=value,
|
||||
)
|
||||
|
||||
try:
|
||||
numeric_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
query |= Q(result_data__document_id=numeric_value) | Q(
|
||||
result_data__duplicate_of=numeric_value,
|
||||
)
|
||||
|
||||
if "duplicate" in value.lower():
|
||||
query |= Q(result_data__duplicate_of__isnull=False)
|
||||
|
||||
return queryset.filter(query)
|
||||
|
||||
def filter_is_complete(self, queryset, name, value):
|
||||
if value:
|
||||
|
||||
@@ -169,6 +169,10 @@ class FileStabilityTracker:
|
||||
self._tracked.pop(path, None)
|
||||
yield path
|
||||
|
||||
def is_tracking(self, path: Path) -> bool:
|
||||
"""Check whether a path is currently being tracked for stability."""
|
||||
return path.resolve() in self._tracked
|
||||
|
||||
def has_pending_files(self) -> bool:
|
||||
"""Check if there are files waiting for stability check."""
|
||||
return len(self._tracked) > 0
|
||||
@@ -370,6 +374,16 @@ class Command(BaseCommand):
|
||||
# Testing timeout in seconds
|
||||
testing_timeout_s: Final[float] = 0.5
|
||||
|
||||
# How often to perform a full-glob rescan of the consume directory as a
|
||||
# safety net. Each watchfiles watcher is torn down and recreated on every
|
||||
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
|
||||
# current directory contents as its baseline. A file that appears between
|
||||
# one batch and the next watcher's baseline is therefore never reported and
|
||||
# would sit in the consume directory forever. This periodic rescan re-injects
|
||||
# such files into the stability tracker (see GH issue #13011). Not currently
|
||||
# user-configurable; instances may override for testing.
|
||||
rescan_interval_s: float = 300.0
|
||||
|
||||
def add_arguments(self, parser) -> None:
|
||||
parser.add_argument(
|
||||
"directory",
|
||||
@@ -425,7 +439,7 @@ class Command(BaseCommand):
|
||||
)
|
||||
|
||||
# Process existing files
|
||||
self._process_existing_files(
|
||||
queued = self._process_existing_files(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
subdirs_as_tags=subdirs_as_tags,
|
||||
@@ -445,6 +459,7 @@ class Command(BaseCommand):
|
||||
polling_interval=polling_interval,
|
||||
stability_delay=stability_delay,
|
||||
is_testing=is_testing,
|
||||
queued=queued,
|
||||
)
|
||||
|
||||
logger.debug("Consumer exiting")
|
||||
@@ -456,11 +471,18 @@ class Command(BaseCommand):
|
||||
recursive: bool,
|
||||
subdirs_as_tags: bool,
|
||||
consumer_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""Process any existing files in the consumption directory."""
|
||||
) -> set[Path]:
|
||||
"""
|
||||
Process any existing files in the consumption directory.
|
||||
|
||||
Returns the set of resolved paths that were queued, so the watch loop
|
||||
can seed its in-flight set and avoid re-queuing them on the first
|
||||
rescan before the consume tasks have removed them from disk.
|
||||
"""
|
||||
logger.info(f"Processing existing files in {directory}")
|
||||
|
||||
glob_pattern = "**/*" if recursive else "*"
|
||||
queued: set[Path] = set()
|
||||
|
||||
for filepath in directory.glob(glob_pattern):
|
||||
# Use filter to check if file should be processed
|
||||
@@ -475,6 +497,48 @@ class Command(BaseCommand):
|
||||
consumption_dir=directory,
|
||||
subdirs_as_tags=subdirs_as_tags,
|
||||
)
|
||||
queued.add(filepath.resolve())
|
||||
|
||||
return queued
|
||||
|
||||
def _rescan_existing_files(
|
||||
self,
|
||||
*,
|
||||
directory: Path,
|
||||
recursive: bool,
|
||||
consumer_filter: ConsumerFilter,
|
||||
tracker: FileStabilityTracker,
|
||||
queued: set[Path],
|
||||
) -> None:
|
||||
"""
|
||||
Re-inject on-disk files the watcher never reported into the tracker.
|
||||
|
||||
Acts as a safety net for files stranded by the watcher-recreation gap
|
||||
(see ``rescan_interval_s``). Files already being tracked or already
|
||||
queued and awaiting consumption are skipped, so a file is never queued
|
||||
twice. Queued paths that have since left the directory are pruned so a
|
||||
later file reusing the same name is not skipped forever.
|
||||
"""
|
||||
# Prune in-flight paths that have left the directory
|
||||
for path in list(queued):
|
||||
if not path.exists():
|
||||
queued.discard(path)
|
||||
|
||||
glob_pattern = "**/*" if recursive else "*"
|
||||
|
||||
for filepath in directory.glob(glob_pattern):
|
||||
if not filepath.is_file():
|
||||
continue
|
||||
|
||||
if not consumer_filter(Change.added, str(filepath)):
|
||||
continue
|
||||
|
||||
resolved = filepath.resolve()
|
||||
if tracker.is_tracking(resolved) or resolved in queued:
|
||||
continue
|
||||
|
||||
logger.debug(f"Rescan found untracked file: {resolved}")
|
||||
tracker.track(resolved, Change.added)
|
||||
|
||||
def _watch_directory(
|
||||
self,
|
||||
@@ -486,11 +550,24 @@ class Command(BaseCommand):
|
||||
polling_interval: float,
|
||||
stability_delay: float,
|
||||
is_testing: bool,
|
||||
queued: set[Path] | None = None,
|
||||
) -> None:
|
||||
"""Watch directory for changes and process stable files."""
|
||||
use_polling = polling_interval > 0
|
||||
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
|
||||
|
||||
# Resolved paths that have been queued and are awaiting consumption.
|
||||
# Seeded from the startup scan so the first rescan does not re-queue
|
||||
# files whose consume tasks have not yet removed them from disk.
|
||||
queued = set() if queued is None else queued
|
||||
|
||||
# Full-glob safety net cadence (0 disables)
|
||||
rescan_interval_s = self.rescan_interval_s
|
||||
rescan_timeout_ms = (
|
||||
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
|
||||
)
|
||||
last_rescan = monotonic()
|
||||
|
||||
if use_polling:
|
||||
logger.info(
|
||||
f"Watching {directory} using polling (interval: {polling_interval}s)",
|
||||
@@ -505,6 +582,20 @@ class Command(BaseCommand):
|
||||
stability_timeout_ms = int(stability_delay * 1000)
|
||||
testing_timeout_ms = int(self.testing_timeout_s * 1000)
|
||||
|
||||
def cap_for_rescan(ms: int) -> int:
|
||||
"""
|
||||
Ensure the watch loop wakes often enough to run the rescan.
|
||||
|
||||
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
|
||||
only run that often. A timeout of 0 means "wait indefinitely",
|
||||
which would never wake to rescan; cap it at the rescan interval.
|
||||
"""
|
||||
if rescan_timeout_ms <= 0:
|
||||
return ms
|
||||
if ms <= 0:
|
||||
return rescan_timeout_ms
|
||||
return min(ms, rescan_timeout_ms)
|
||||
|
||||
# Calculate appropriate timeout for watch loop
|
||||
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
|
||||
# to ensure poll cycles can complete before timing out
|
||||
@@ -522,6 +613,8 @@ class Command(BaseCommand):
|
||||
# Not testing, wait indefinitely for first event
|
||||
timeout_ms = 0
|
||||
|
||||
timeout_ms = cap_for_rescan(timeout_ms)
|
||||
|
||||
self.stop_flag.clear()
|
||||
|
||||
while not self.stop_flag.is_set():
|
||||
@@ -551,10 +644,26 @@ class Command(BaseCommand):
|
||||
consumption_dir=directory,
|
||||
subdirs_as_tags=subdirs_as_tags,
|
||||
)
|
||||
# Remember it so the rescan does not re-queue it while
|
||||
# the consume task has yet to remove it from disk
|
||||
queued.add(stable_path)
|
||||
|
||||
# Exit watch loop to reconfigure timeout
|
||||
break
|
||||
|
||||
# Periodic full-glob safety net for files the watcher missed
|
||||
if rescan_timeout_ms > 0 and (
|
||||
monotonic() - last_rescan >= rescan_interval_s
|
||||
):
|
||||
self._rescan_existing_files(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
consumer_filter=consumer_filter,
|
||||
tracker=tracker,
|
||||
queued=queued,
|
||||
)
|
||||
last_rescan = monotonic()
|
||||
|
||||
# Determine next timeout
|
||||
if tracker.has_pending_files():
|
||||
# Check pending files at stability interval
|
||||
@@ -572,6 +681,8 @@ class Command(BaseCommand):
|
||||
# No pending files, wait indefinitely
|
||||
timeout_ms = 0
|
||||
|
||||
timeout_ms = cap_for_rescan(timeout_ms)
|
||||
|
||||
except KeyboardInterrupt: # pragma: nocover
|
||||
logger.info("Received interrupt, stopping consumer")
|
||||
self.stop_flag.set()
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any
|
||||
|
||||
from documents.management.commands.base import PaperlessCommand
|
||||
from documents.tasks import llmindex_index
|
||||
from paperless_ai.indexing import llm_index_compact
|
||||
|
||||
|
||||
class Command(PaperlessCommand):
|
||||
@@ -12,9 +13,12 @@ class Command(PaperlessCommand):
|
||||
|
||||
def add_arguments(self, parser: Any) -> None:
|
||||
super().add_arguments(parser)
|
||||
parser.add_argument("command", choices=["rebuild", "update"])
|
||||
parser.add_argument("command", choices=["rebuild", "update", "compact"])
|
||||
|
||||
def handle(self, *args: Any, **options: Any) -> None:
|
||||
if options["command"] == "compact":
|
||||
llm_index_compact()
|
||||
return
|
||||
llmindex_index(
|
||||
rebuild=options["command"] == "rebuild",
|
||||
iter_wrapper=lambda docs: self.track(
|
||||
|
||||
@@ -8,11 +8,15 @@ from documents.search._backend import get_backend
|
||||
from documents.search._backend import reset_backend
|
||||
from documents.search._schema import needs_rebuild
|
||||
from documents.search._schema import wipe_index
|
||||
from documents.search._translate import InvalidDateQuery
|
||||
from documents.search._translate import SearchQueryError
|
||||
|
||||
__all__ = [
|
||||
"InvalidDateQuery",
|
||||
"SearchHit",
|
||||
"SearchIndexLockError",
|
||||
"SearchMode",
|
||||
"SearchQueryError",
|
||||
"TantivyBackend",
|
||||
"TantivyRelevanceList",
|
||||
"WriteBatch",
|
||||
|
||||
@@ -866,8 +866,24 @@ class TantivyBackend:
|
||||
final_query = self._apply_permission_filter(mlt_query, user)
|
||||
|
||||
effective_limit = limit if limit is not None else searcher.num_docs
|
||||
# Fetch one extra to account for excluding the original document
|
||||
results = searcher.search(final_query, limit=effective_limit + 1)
|
||||
try:
|
||||
# Fetch one extra to account for excluding the original document
|
||||
results = searcher.search(final_query, limit=effective_limit + 1)
|
||||
except BaseException: # pragma: no cover
|
||||
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
|
||||
# soft-deleted documents (doc_freq can exceed the alive doc count),
|
||||
# which only surfaces for the More Like This query. The panic crosses
|
||||
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
|
||||
# BaseException, not an Exception — so catch BaseException and degrade
|
||||
# to "no similar documents" instead of bubbling a 500 to the client.
|
||||
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
|
||||
# Remove once the bundled tantivy includes that fix.
|
||||
logger.warning(
|
||||
"More Like This scoring panicked (likely stale tantivy segment "
|
||||
"stats after deletions); returning no results. A search index "
|
||||
"reindex will rebuild consistent statistics.",
|
||||
)
|
||||
return []
|
||||
|
||||
addrs = [addr for _score, addr in results.hits]
|
||||
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from datetime import date
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Final
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import tzinfo
|
||||
|
||||
_DATE_ONLY_FIELDS = frozenset({"created"})
|
||||
|
||||
_TODAY: Final[str] = "today"
|
||||
_YESTERDAY: Final[str] = "yesterday"
|
||||
_PREVIOUS_WEEK: Final[str] = "previous week"
|
||||
_THIS_MONTH: Final[str] = "this month"
|
||||
_PREVIOUS_MONTH: Final[str] = "previous month"
|
||||
_THIS_YEAR: Final[str] = "this year"
|
||||
_PREVIOUS_YEAR: Final[str] = "previous year"
|
||||
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
|
||||
|
||||
_DATE_KEYWORDS = frozenset(
|
||||
{
|
||||
_TODAY,
|
||||
_YESTERDAY,
|
||||
_PREVIOUS_WEEK,
|
||||
_THIS_MONTH,
|
||||
_PREVIOUS_MONTH,
|
||||
_THIS_YEAR,
|
||||
_PREVIOUS_YEAR,
|
||||
_PREVIOUS_QUARTER,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _fmt(dt: datetime) -> str:
|
||||
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
|
||||
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _iso_range(lo: datetime, hi: datetime) -> str:
|
||||
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
|
||||
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
|
||||
|
||||
def _quarter_start(d: date) -> date:
|
||||
"""Return the first day of the calendar quarter containing ``d``."""
|
||||
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
||||
|
||||
|
||||
def _midnight(d: date, tz: tzinfo) -> datetime:
|
||||
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
|
||||
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||
|
||||
|
||||
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
|
||||
"""
|
||||
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
|
||||
|
||||
``tz`` only determines what "today" is; the caller decides how the returned
|
||||
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
|
||||
"""
|
||||
today = datetime.now(tz).date()
|
||||
if keyword == _TODAY:
|
||||
return today, today + timedelta(days=1)
|
||||
if keyword == _YESTERDAY:
|
||||
return today - timedelta(days=1), today
|
||||
if keyword == _PREVIOUS_WEEK:
|
||||
this_monday = today - timedelta(days=today.weekday())
|
||||
return this_monday - timedelta(weeks=1), this_monday
|
||||
if keyword == _THIS_MONTH:
|
||||
first = today.replace(day=1)
|
||||
return first, first + relativedelta(months=1)
|
||||
if keyword == _PREVIOUS_MONTH:
|
||||
this_first = today.replace(day=1)
|
||||
return this_first - relativedelta(months=1), this_first
|
||||
if keyword == _THIS_YEAR:
|
||||
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
|
||||
if keyword == _PREVIOUS_YEAR:
|
||||
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
|
||||
if keyword == _PREVIOUS_QUARTER:
|
||||
this_quarter = _quarter_start(today)
|
||||
return this_quarter - relativedelta(months=3), this_quarter
|
||||
raise ValueError(f"Unknown keyword: {keyword}")
|
||||
|
||||
|
||||
def _date_only_range(keyword: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
For `created` (DateField): use the local calendar date, converted to
|
||||
midnight UTC boundaries. No offset arithmetic — date only.
|
||||
"""
|
||||
start, end = _keyword_bounds(keyword, tz)
|
||||
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
|
||||
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
|
||||
return _iso_range(lo, hi)
|
||||
|
||||
|
||||
def _datetime_range(keyword: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
|
||||
boundaries to UTC — full offset arithmetic required.
|
||||
"""
|
||||
start, end = _keyword_bounds(keyword, tz)
|
||||
return _iso_range(_midnight(start, tz), _midnight(end, tz))
|
||||
|
||||
|
||||
def _precision_bounds(digits: str) -> tuple[date, date] | None:
|
||||
"""
|
||||
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
|
||||
|
||||
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
|
||||
Returns None for any unparsable or out-of-range value (e.g. month 23),
|
||||
so callers can emit a no-match clause instead of erroring (Whoosh parity).
|
||||
"""
|
||||
try:
|
||||
if len(digits) == 4:
|
||||
year = int(digits)
|
||||
return date(year, 1, 1), date(year + 1, 1, 1)
|
||||
if len(digits) == 6:
|
||||
year, month = int(digits[:4]), int(digits[4:6])
|
||||
start = date(year, month, 1)
|
||||
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
|
||||
return start, end
|
||||
if len(digits) == 8:
|
||||
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
|
||||
return start, start + timedelta(days=1)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _utc_bounds_for_field(
|
||||
field: str,
|
||||
start: date,
|
||||
end: date,
|
||||
tz: tzinfo,
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""
|
||||
Convert calendar-date bounds to UTC datetimes per the field's storage type.
|
||||
|
||||
For DateField (``created``) the bounds are UTC midnight (no offset). For
|
||||
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
|
||||
converted to UTC, matching how each field is indexed.
|
||||
"""
|
||||
if field in _DATE_ONLY_FIELDS:
|
||||
return (
|
||||
datetime(start.year, start.month, start.day, tzinfo=UTC),
|
||||
datetime(end.year, end.month, end.day, tzinfo=UTC),
|
||||
)
|
||||
return (
|
||||
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
|
||||
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
|
||||
)
|
||||
|
||||
|
||||
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
|
||||
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
|
||||
lo, hi = _utc_bounds_for_field(field, start, end, tz)
|
||||
return f"{field}:{_iso_range(lo, hi)}"
|
||||
+27
-405
@@ -1,88 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC
|
||||
from datetime import date
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Final
|
||||
|
||||
import regex
|
||||
import tantivy
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
|
||||
from documents.search._dates import (
|
||||
_date_only_range, # noqa: F401 — re-exported for test imports
|
||||
)
|
||||
from documents.search._dates import (
|
||||
_datetime_range, # noqa: F401 — re-exported for test imports
|
||||
)
|
||||
from documents.search._tokenizer import simple_search_tokens
|
||||
from documents.search._translate import SearchQueryError
|
||||
from documents.search._translate import translate_query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import tzinfo
|
||||
|
||||
from django.contrib.auth.base_user import AbstractBaseUser
|
||||
|
||||
logger = logging.getLogger("paperless.search")
|
||||
|
||||
# Maximum seconds any single regex substitution may run.
|
||||
# Prevents ReDoS on adversarial user-supplied query strings.
|
||||
_REGEX_TIMEOUT: Final[float] = 1.0
|
||||
|
||||
_DATE_ONLY_FIELDS = frozenset({"created"})
|
||||
|
||||
_TODAY: Final[str] = "today"
|
||||
_YESTERDAY: Final[str] = "yesterday"
|
||||
_PREVIOUS_WEEK: Final[str] = "previous week"
|
||||
_THIS_MONTH: Final[str] = "this month"
|
||||
_PREVIOUS_MONTH: Final[str] = "previous month"
|
||||
_THIS_YEAR: Final[str] = "this year"
|
||||
_PREVIOUS_YEAR: Final[str] = "previous year"
|
||||
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
|
||||
|
||||
_DATE_KEYWORDS = frozenset(
|
||||
{
|
||||
_TODAY,
|
||||
_YESTERDAY,
|
||||
_PREVIOUS_WEEK,
|
||||
_THIS_MONTH,
|
||||
_PREVIOUS_MONTH,
|
||||
_THIS_YEAR,
|
||||
_PREVIOUS_YEAR,
|
||||
_PREVIOUS_QUARTER,
|
||||
},
|
||||
)
|
||||
|
||||
_DATE_KEYWORD_PATTERN = "|".join(
|
||||
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
|
||||
)
|
||||
|
||||
_FIELD_DATE_RE = regex.compile(
|
||||
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
|
||||
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
||||
|
|
||||
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
|
||||
)""",
|
||||
regex.IGNORECASE | regex.VERBOSE,
|
||||
)
|
||||
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
|
||||
_RELATIVE_RANGE_RE = regex.compile(
|
||||
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
|
||||
_WHOOSH_REL_RANGE_RE = regex.compile(
|
||||
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
|
||||
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
|
||||
_DATE8_RE = regex.compile(
|
||||
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
|
||||
)
|
||||
_YEAR_RANGE_RE = regex.compile(
|
||||
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
|
||||
# the NOT/MUST operators require no space between the operator and the term.
|
||||
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
|
||||
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
|
||||
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
|
||||
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
|
||||
# Uses Unicode properties to cover all blocks including Extension B+ planes.
|
||||
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
|
||||
@@ -117,303 +64,12 @@ def _build_cjk_query(
|
||||
return None
|
||||
|
||||
|
||||
def _fmt(dt: datetime) -> str:
|
||||
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
|
||||
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _iso_range(lo: datetime, hi: datetime) -> str:
|
||||
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
|
||||
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
|
||||
|
||||
def _date_only_range(keyword: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
For `created` (DateField): use the local calendar date, converted to
|
||||
midnight UTC boundaries. No offset arithmetic — date only.
|
||||
"""
|
||||
|
||||
today = datetime.now(tz).date()
|
||||
|
||||
def _quarter_start(d: date) -> date:
|
||||
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
||||
|
||||
if keyword == _TODAY:
|
||||
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
|
||||
return _iso_range(lo, lo + timedelta(days=1))
|
||||
if keyword == _YESTERDAY:
|
||||
y = today - timedelta(days=1)
|
||||
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
|
||||
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
|
||||
return _iso_range(lo, hi)
|
||||
if keyword == _PREVIOUS_WEEK:
|
||||
this_mon = today - timedelta(days=today.weekday())
|
||||
last_mon = this_mon - timedelta(weeks=1)
|
||||
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
|
||||
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
|
||||
return _iso_range(lo, hi)
|
||||
if keyword == _THIS_MONTH:
|
||||
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
|
||||
if today.month == 12:
|
||||
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
|
||||
else:
|
||||
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
|
||||
return _iso_range(lo, hi)
|
||||
if keyword == _PREVIOUS_MONTH:
|
||||
if today.month == 1:
|
||||
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
|
||||
else:
|
||||
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
|
||||
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
|
||||
return _iso_range(lo, hi)
|
||||
if keyword == _THIS_YEAR:
|
||||
lo = datetime(today.year, 1, 1, tzinfo=UTC)
|
||||
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
|
||||
if keyword == _PREVIOUS_YEAR:
|
||||
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
|
||||
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
|
||||
if keyword == _PREVIOUS_QUARTER:
|
||||
this_quarter = _quarter_start(today)
|
||||
last_quarter = this_quarter - relativedelta(months=3)
|
||||
lo = datetime(
|
||||
last_quarter.year,
|
||||
last_quarter.month,
|
||||
last_quarter.day,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
hi = datetime(
|
||||
this_quarter.year,
|
||||
this_quarter.month,
|
||||
this_quarter.day,
|
||||
tzinfo=UTC,
|
||||
)
|
||||
return _iso_range(lo, hi)
|
||||
raise ValueError(f"Unknown keyword: {keyword}")
|
||||
|
||||
|
||||
def _datetime_range(keyword: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
|
||||
boundaries to UTC — full offset arithmetic required.
|
||||
"""
|
||||
|
||||
now_local = datetime.now(tz)
|
||||
today = now_local.date()
|
||||
|
||||
def _midnight(d: date) -> datetime:
|
||||
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||
|
||||
def _quarter_start(d: date) -> date:
|
||||
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
|
||||
|
||||
if keyword == _TODAY:
|
||||
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
|
||||
if keyword == _YESTERDAY:
|
||||
y = today - timedelta(days=1)
|
||||
return _iso_range(_midnight(y), _midnight(today))
|
||||
if keyword == _PREVIOUS_WEEK:
|
||||
this_mon = today - timedelta(days=today.weekday())
|
||||
last_mon = this_mon - timedelta(weeks=1)
|
||||
return _iso_range(_midnight(last_mon), _midnight(this_mon))
|
||||
if keyword == _THIS_MONTH:
|
||||
first = today.replace(day=1)
|
||||
if today.month == 12:
|
||||
next_first = date(today.year + 1, 1, 1)
|
||||
else:
|
||||
next_first = date(today.year, today.month + 1, 1)
|
||||
return _iso_range(_midnight(first), _midnight(next_first))
|
||||
if keyword == _PREVIOUS_MONTH:
|
||||
this_first = today.replace(day=1)
|
||||
if today.month == 1:
|
||||
last_first = date(today.year - 1, 12, 1)
|
||||
else:
|
||||
last_first = date(today.year, today.month - 1, 1)
|
||||
return _iso_range(_midnight(last_first), _midnight(this_first))
|
||||
if keyword == _THIS_YEAR:
|
||||
return _iso_range(
|
||||
_midnight(date(today.year, 1, 1)),
|
||||
_midnight(date(today.year + 1, 1, 1)),
|
||||
)
|
||||
if keyword == _PREVIOUS_YEAR:
|
||||
return _iso_range(
|
||||
_midnight(date(today.year - 1, 1, 1)),
|
||||
_midnight(date(today.year, 1, 1)),
|
||||
)
|
||||
if keyword == _PREVIOUS_QUARTER:
|
||||
this_quarter = _quarter_start(today)
|
||||
last_quarter = this_quarter - relativedelta(months=3)
|
||||
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
|
||||
raise ValueError(f"Unknown keyword: {keyword}")
|
||||
|
||||
|
||||
def _rewrite_compact_date(query: str) -> str:
|
||||
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
|
||||
|
||||
def _sub(m: regex.Match[str]) -> str:
|
||||
raw = m.group(1)
|
||||
try:
|
||||
dt = datetime(
|
||||
int(raw[0:4]),
|
||||
int(raw[4:6]),
|
||||
int(raw[6:8]),
|
||||
int(raw[8:10]),
|
||||
int(raw[10:12]),
|
||||
int(raw[12:14]),
|
||||
tzinfo=UTC,
|
||||
)
|
||||
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
except ValueError:
|
||||
return str(m.group(0))
|
||||
|
||||
try:
|
||||
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError(
|
||||
"Query too complex to process (compact date rewrite timed out)",
|
||||
)
|
||||
|
||||
|
||||
def _rewrite_relative_range(query: str) -> str:
|
||||
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
|
||||
|
||||
def _sub(m: regex.Match[str]) -> str:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
def _offset(s: str | None) -> timedelta:
|
||||
if not s:
|
||||
return timedelta(0)
|
||||
sign = 1 if s[0] == "+" else -1
|
||||
n, unit = int(s[1:-1]), s[-1]
|
||||
return (
|
||||
sign
|
||||
* {
|
||||
"d": timedelta(days=n),
|
||||
"h": timedelta(hours=n),
|
||||
"m": timedelta(minutes=n),
|
||||
}[unit]
|
||||
)
|
||||
|
||||
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
|
||||
if lo > hi:
|
||||
lo, hi = hi, lo
|
||||
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
|
||||
try:
|
||||
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError(
|
||||
"Query too complex to process (relative range rewrite timed out)",
|
||||
)
|
||||
|
||||
|
||||
def _rewrite_whoosh_relative_range(query: str) -> str:
|
||||
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
|
||||
|
||||
Supports: second, minute, hour, day, week, month, year (singular and plural).
|
||||
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
def _sub(m: regex.Match[str]) -> str:
|
||||
n = int(m.group("n"))
|
||||
unit = m.group("unit").lower()
|
||||
delta_map: dict[str, timedelta | relativedelta] = {
|
||||
"second": timedelta(seconds=n),
|
||||
"minute": timedelta(minutes=n),
|
||||
"hour": timedelta(hours=n),
|
||||
"day": timedelta(days=n),
|
||||
"week": timedelta(weeks=n),
|
||||
"month": relativedelta(months=n),
|
||||
"year": relativedelta(years=n),
|
||||
}
|
||||
lo = now - delta_map[unit]
|
||||
return f"[{_fmt(lo)} TO {_fmt(now)}]"
|
||||
|
||||
try:
|
||||
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError(
|
||||
"Query too complex to process (Whoosh relative range rewrite timed out)",
|
||||
)
|
||||
|
||||
|
||||
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
|
||||
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
|
||||
|
||||
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
|
||||
converted and won't spuriously match here.
|
||||
|
||||
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
|
||||
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
|
||||
midnight boundaries converted to UTC — matching the ``_datetime_range``
|
||||
behaviour for keyword dates.
|
||||
"""
|
||||
|
||||
def _sub(m: regex.Match[str]) -> str:
|
||||
field = m.group("field")
|
||||
raw = m.group("date8")
|
||||
try:
|
||||
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
|
||||
d = date(year, month, day)
|
||||
if field in _DATE_ONLY_FIELDS:
|
||||
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
|
||||
hi = lo + timedelta(days=1)
|
||||
else:
|
||||
# DateTimeField: use local-timezone midnight → UTC
|
||||
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
|
||||
hi = datetime(
|
||||
(d + timedelta(days=1)).year,
|
||||
(d + timedelta(days=1)).month,
|
||||
(d + timedelta(days=1)).day,
|
||||
tzinfo=tz,
|
||||
).astimezone(UTC)
|
||||
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
except ValueError:
|
||||
return m.group(0)
|
||||
|
||||
try:
|
||||
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError(
|
||||
"Query too complex to process (8-digit date rewrite timed out)",
|
||||
)
|
||||
|
||||
|
||||
def _rewrite_year_range(query: str) -> str:
|
||||
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
|
||||
|
||||
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
|
||||
The upper bound is the start of the year after the end year (exclusive),
|
||||
matching the Whoosh convention of treating year-only ranges as full-year spans.
|
||||
"""
|
||||
|
||||
def _sub(m: regex.Match[str]) -> str:
|
||||
field = m.group("field")
|
||||
y1, y2 = int(m.group("y1")), int(m.group("y2"))
|
||||
# Whoosh swaps a reversed range when both years are explicit
|
||||
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
|
||||
# range spans the intended years instead of matching nothing.
|
||||
lo_year, hi_year = min(y1, y2), max(y1, y2)
|
||||
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
|
||||
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
|
||||
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
|
||||
|
||||
try:
|
||||
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError("Query too complex to process (year range rewrite timed out)")
|
||||
|
||||
|
||||
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
||||
"""
|
||||
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
|
||||
|
||||
Performs the first stage of query preprocessing, converting various date
|
||||
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
|
||||
- Compact 14-digit dates (YYYYMMDDHHmmss)
|
||||
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
|
||||
- 8-digit dates with field awareness (created:20240115)
|
||||
- Natural keywords (field:today, field:"previous quarter", etc.)
|
||||
Delegates to ``translate_query`` which handles all date forms, comma
|
||||
expansion, field aliasing, relative ranges, and operator normalization.
|
||||
|
||||
Args:
|
||||
query: Raw user query string
|
||||
@@ -425,35 +81,15 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
|
||||
Note:
|
||||
Bare keywords without field prefixes pass through unchanged.
|
||||
"""
|
||||
query = _rewrite_compact_date(query)
|
||||
query = _rewrite_whoosh_relative_range(query)
|
||||
query = _rewrite_year_range(query)
|
||||
query = _rewrite_8digit_date(query, tz)
|
||||
query = _rewrite_relative_range(query)
|
||||
|
||||
def _replace(m: regex.Match[str]) -> str:
|
||||
field = m.group("field")
|
||||
keyword = (m.group("quoted") or m.group("bare")).lower()
|
||||
if field in _DATE_ONLY_FIELDS:
|
||||
return f"{field}:{_date_only_range(keyword, tz)}"
|
||||
return f"{field}:{_datetime_range(keyword, tz)}"
|
||||
|
||||
try:
|
||||
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError(
|
||||
"Query too complex to process (date keyword rewrite timed out)",
|
||||
)
|
||||
return translate_query(query, tz)
|
||||
|
||||
|
||||
def normalize_query(query: str) -> str:
|
||||
"""
|
||||
Normalize query syntax for better search behavior.
|
||||
|
||||
Expands comma-separated field values to explicit AND clauses and
|
||||
collapses excessive whitespace for cleaner parsing:
|
||||
- tag:foo,bar → tag:foo AND tag:bar
|
||||
- multiple spaces → single spaces
|
||||
Delegates to ``translate_query`` which handles comma expansion, whitespace
|
||||
collapsing, operator normalization, and field aliasing.
|
||||
|
||||
Args:
|
||||
query: Query string after date rewriting
|
||||
@@ -461,29 +97,7 @@ def normalize_query(query: str) -> str:
|
||||
Returns:
|
||||
Normalized query string ready for Tantivy parsing
|
||||
"""
|
||||
|
||||
def _expand(m: regex.Match[str]) -> str:
|
||||
field = m.group(1)
|
||||
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
|
||||
return " AND ".join(f"{field}:{v}" for v in values)
|
||||
|
||||
try:
|
||||
query = regex.sub(
|
||||
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
|
||||
_expand,
|
||||
query,
|
||||
timeout=_REGEX_TIMEOUT,
|
||||
)
|
||||
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
|
||||
# Strip trailing dangling operators before Tantivy sees them.
|
||||
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
|
||||
# Replace " - " / " + " with a space: Tantivy requires no space between
|
||||
# the operator and its operand (-term / +term), so spaces on both sides
|
||||
# means this is a natural-language separator, not a query operator.
|
||||
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
|
||||
return query
|
||||
except TimeoutError: # pragma: no cover
|
||||
raise ValueError("Query too complex to process (normalization timed out)")
|
||||
return translate_query(query, UTC)
|
||||
|
||||
|
||||
def build_permission_filter(
|
||||
@@ -603,8 +217,16 @@ def parse_user_query(
|
||||
as a post-search score filter, not during query construction.
|
||||
"""
|
||||
|
||||
query_str = rewrite_natural_date_keywords(raw_query, tz)
|
||||
query_str = normalize_query(query_str)
|
||||
try:
|
||||
query_str = translate_query(raw_query, tz)
|
||||
except SearchQueryError:
|
||||
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
|
||||
# the view can return a 400 with a helpful message rather than falling
|
||||
# back to the raw (still-invalid) query.
|
||||
raise
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.warning("Query translation failed; using raw query", exc_info=True)
|
||||
query_str = raw_query
|
||||
|
||||
exact = index.parse_query(
|
||||
query_str,
|
||||
|
||||
@@ -0,0 +1,566 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeAlias
|
||||
|
||||
import regex
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from documents.search._dates import _DATE_KEYWORDS
|
||||
from documents.search._dates import _DATE_ONLY_FIELDS
|
||||
from documents.search._dates import _date_only_range
|
||||
from documents.search._dates import _datetime_range
|
||||
from documents.search._dates import _field_range_from_dates
|
||||
from documents.search._dates import _fmt
|
||||
from documents.search._dates import _precision_bounds
|
||||
from documents.search._dates import _utc_bounds_for_field
|
||||
|
||||
# Compiled regex that matches any known multi-word (or single-word) date keyword
|
||||
# at the start of a match position, longest alternatives first so "previous week"
|
||||
# wins over a hypothetical shorter "previous".
|
||||
_KEYWORD_VALUE_RE = regex.compile(
|
||||
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import tzinfo
|
||||
|
||||
# TODO: this module translates date queries into Tantivy *string* syntax, which
|
||||
# forces a workaround for something Tantivy's string parser cannot express on
|
||||
# date fields: open-ended ranges use far-past/far-future string sentinels
|
||||
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
|
||||
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
|
||||
# datetimes in range_query/term_query on Date fields. That support exists on
|
||||
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
|
||||
# it is blocked only on a published release > 0.26.0 and a dependency bump.
|
||||
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
|
||||
# no-match string sentinel.)
|
||||
|
||||
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
|
||||
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
|
||||
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
|
||||
|
||||
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
|
||||
DATE_FIELDS = frozenset({"created", "modified", "added"})
|
||||
|
||||
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
|
||||
# Preserved here so v2 queries using the old names continue to work without 400
|
||||
# errors instead of silently failing. Applied by _render to non-date field tokens.
|
||||
FIELD_ALIASES: dict[str, str] = {
|
||||
"type": "document_type",
|
||||
"type_id": "document_type_id",
|
||||
"path": "storage_path",
|
||||
"path_id": "storage_path_id",
|
||||
}
|
||||
|
||||
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
|
||||
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
|
||||
KNOWN_FIELDS = frozenset(
|
||||
{
|
||||
"title",
|
||||
"content",
|
||||
"correspondent",
|
||||
"document_type",
|
||||
"type", # v2 alias -> document_type
|
||||
"storage_path",
|
||||
"path", # v2 alias -> storage_path
|
||||
"tag",
|
||||
"tag_id",
|
||||
"correspondent_id",
|
||||
"document_type_id",
|
||||
"type_id", # v2 alias -> document_type_id
|
||||
"storage_path_id",
|
||||
"path_id", # v2 alias -> storage_path_id
|
||||
"owner_id",
|
||||
"viewer_id",
|
||||
"asn",
|
||||
"page_count",
|
||||
"num_notes",
|
||||
"created",
|
||||
"modified",
|
||||
"added",
|
||||
"original_filename",
|
||||
"checksum",
|
||||
"notes",
|
||||
"custom_fields",
|
||||
},
|
||||
)
|
||||
|
||||
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
|
||||
|
||||
# Matches the TO separator inside a range bracket. Handles three forms:
|
||||
# middle: "lo TO hi" (either lo or hi may be empty)
|
||||
# trailing: "lo TO" (open upper bound)
|
||||
# leading: "TO hi" (open lower bound)
|
||||
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
|
||||
# and split on the whitespace-delimited " TO " / " to " separator.
|
||||
_RANGE_RE = regex.compile(
|
||||
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
|
||||
r"|"
|
||||
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
|
||||
r"|"
|
||||
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FieldValue:
|
||||
field: str
|
||||
value: str
|
||||
|
||||
|
||||
# Produced by the comma-resolution pass (not by scan()).
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FieldValueList:
|
||||
field: str
|
||||
values: tuple[str, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FieldRange:
|
||||
field: str
|
||||
open: str
|
||||
lo: str
|
||||
hi: str
|
||||
close: str
|
||||
|
||||
|
||||
# Produced by the comma-resolution pass (not by scan()).
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Comma:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Passthrough:
|
||||
raw: str
|
||||
|
||||
|
||||
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
|
||||
|
||||
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
|
||||
|
||||
|
||||
def scan(query: str) -> list[Token]:
|
||||
"""
|
||||
Tokenize a raw query into date/comma-aware tokens, leaving everything else
|
||||
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
|
||||
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
|
||||
pass through verbatim on mismatch.
|
||||
"""
|
||||
tokens: list[Token] = []
|
||||
buf: list[str] = [] # accumulates passthrough chars
|
||||
i, n = 0, len(query)
|
||||
while i < n:
|
||||
matched = _match_field_token(query, i)
|
||||
if matched is None:
|
||||
buf.append(query[i])
|
||||
i += 1
|
||||
continue
|
||||
token, i = matched
|
||||
_flush(buf, tokens)
|
||||
tokens.append(token)
|
||||
i = _maybe_comma(query, i, tokens)
|
||||
_flush(buf, tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def _flush(buf: list[str], tokens: list[Token]) -> None:
|
||||
"""Emit any accumulated passthrough characters as a single token."""
|
||||
if buf:
|
||||
tokens.append(Passthrough("".join(buf)))
|
||||
buf.clear()
|
||||
|
||||
|
||||
def _at_word_boundary(query: str, i: int) -> bool:
|
||||
"""A field token may begin only at the start or after a non-word character."""
|
||||
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
|
||||
|
||||
|
||||
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
|
||||
"""
|
||||
If a known ``field:`` token starts at ``i``, consume it and return
|
||||
``(token, end_index)``; otherwise return None so the caller treats the
|
||||
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
|
||||
and returns None when the range/value cannot be consumed.
|
||||
"""
|
||||
m = _FIELD_RE.match(query, i)
|
||||
if m is None or m.group("field") not in KNOWN_FIELDS:
|
||||
return None
|
||||
if not _at_word_boundary(query, i):
|
||||
return None
|
||||
field = m.group("field")
|
||||
j = m.end()
|
||||
if j < len(query) and query[j] in "[{":
|
||||
return _consume_range(query, j, field)
|
||||
consumed = _consume_field_value(query, field, j)
|
||||
if consumed is None:
|
||||
return None
|
||||
value, end = consumed
|
||||
return FieldValue(field, value), end
|
||||
|
||||
|
||||
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
|
||||
"""
|
||||
Consume a field value starting at ``start``: a multi-word date keyword phrase
|
||||
(date fields only), or a bare/quoted value, then absorb any comma-joined
|
||||
continuation that is not a clause separator. ``resolve_commas`` later splits a
|
||||
multi-value field's joined value into a ``FieldValueList``; for other fields
|
||||
the comma stays literal.
|
||||
"""
|
||||
n = len(query)
|
||||
consumed = None
|
||||
if field in DATE_FIELDS:
|
||||
km = _KEYWORD_VALUE_RE.match(query, start)
|
||||
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
|
||||
consumed = (km.group(0), km.end())
|
||||
if consumed is None:
|
||||
consumed = _consume_value(query, start)
|
||||
if consumed is None:
|
||||
return None
|
||||
value, k = consumed
|
||||
while k < n and query[k] == ",":
|
||||
if _looks_like_known_field(query, k + 1):
|
||||
break # clause separator: left for _maybe_comma to emit a Comma()
|
||||
more = _consume_value(query, k + 1)
|
||||
if more is None:
|
||||
break
|
||||
value = f"{value},{more[0]}"
|
||||
k = more[1]
|
||||
return value, k
|
||||
|
||||
|
||||
def _consume_range(
|
||||
query: str,
|
||||
start: int,
|
||||
field: str,
|
||||
) -> tuple[FieldRange, int] | None:
|
||||
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
|
||||
open_br = query[start]
|
||||
close_br = _CLOSE[open_br]
|
||||
end = query.find(close_br, start + 1)
|
||||
if end == -1:
|
||||
return None
|
||||
inner = query[start + 1 : end]
|
||||
m = _RANGE_RE.match(inner)
|
||||
if m is not None:
|
||||
if m.group("lo") is not None or m.group("hi") is not None:
|
||||
# Middle form: "lo TO hi" (either may be empty string)
|
||||
lo = (m.group("lo") or "").strip()
|
||||
hi = (m.group("hi") or "").strip()
|
||||
elif m.group("lo2") is not None:
|
||||
# Trailing form: "lo TO"
|
||||
lo = m.group("lo2").strip()
|
||||
hi = ""
|
||||
else:
|
||||
# Leading form: "TO hi"
|
||||
lo = ""
|
||||
hi = (m.group("hi2") or "").strip()
|
||||
else:
|
||||
lo, hi = inner.strip(), ""
|
||||
return FieldRange(field, open_br, lo, hi, close_br), end + 1
|
||||
|
||||
|
||||
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
|
||||
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
|
||||
n = len(query)
|
||||
if start >= n or query[start] in " \t":
|
||||
return None
|
||||
if query[start] in "\"'":
|
||||
quote = query[start]
|
||||
end = query.find(quote, start + 1)
|
||||
if end == -1:
|
||||
return None
|
||||
return query[start : end + 1], end + 1
|
||||
j = start
|
||||
while j < n and query[j] not in " \t),":
|
||||
j += 1
|
||||
return query[start:j], j
|
||||
|
||||
|
||||
def _looks_like_known_field(query: str, pos: int) -> bool:
|
||||
"""True if a known ``field:`` token starts at ``pos``."""
|
||||
m = _FIELD_RE.match(query, pos)
|
||||
return bool(m and m.group("field") in KNOWN_FIELDS)
|
||||
|
||||
|
||||
def _maybe_comma(query: str, i: int, tokens: list) -> int:
|
||||
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
|
||||
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
|
||||
tokens.append(Comma())
|
||||
return i + 1
|
||||
return i
|
||||
|
||||
|
||||
def resolve_commas(tokens: list) -> list:
|
||||
"""
|
||||
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
|
||||
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
|
||||
the value-stop logic; this pass folds value-lists.)
|
||||
"""
|
||||
out: list = []
|
||||
for tok in tokens:
|
||||
if (
|
||||
isinstance(tok, FieldValue)
|
||||
and tok.field in MULTI_VALUE_FIELDS
|
||||
and "," in tok.value
|
||||
):
|
||||
values = tuple(v for v in tok.value.split(",") if v)
|
||||
out.append(FieldValueList(tok.field, values))
|
||||
else:
|
||||
out.append(tok)
|
||||
return out
|
||||
|
||||
|
||||
class SearchQueryError(ValueError):
|
||||
"""
|
||||
Base for user-fixable search query errors.
|
||||
|
||||
Carries a message safe to surface to the user (no internal details). The view
|
||||
layer catches this and returns an HTTP 400, so any future subclass (unknown
|
||||
field, malformed range, wrapped parser errors) gets the same treatment.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidDateQuery(SearchQueryError):
|
||||
"""Raised when a date field value or range bound cannot be parsed."""
|
||||
|
||||
def __init__(self, field: str, value: str) -> None:
|
||||
self.field = field
|
||||
self.value = value
|
||||
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
|
||||
|
||||
|
||||
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
|
||||
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
|
||||
|
||||
|
||||
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
|
||||
"""Translate a bare date-field value to a Tantivy range string."""
|
||||
bare = value.strip("\"'").lower()
|
||||
if bare in _DATE_KEYWORDS:
|
||||
if field in _DATE_ONLY_FIELDS:
|
||||
return f"{field}:{_date_only_range(bare, tz)}"
|
||||
return f"{field}:{_datetime_range(bare, tz)}"
|
||||
digits = value.replace("-", "")
|
||||
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
|
||||
bounds = _precision_bounds(digits)
|
||||
if bounds is None:
|
||||
raise InvalidDateQuery(field, value)
|
||||
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
|
||||
if regex.fullmatch(r"\d{14}", value):
|
||||
try:
|
||||
dt = datetime(
|
||||
int(value[0:4]),
|
||||
int(value[4:6]),
|
||||
int(value[6:8]),
|
||||
int(value[8:10]),
|
||||
int(value[10:12]),
|
||||
int(value[12:14]),
|
||||
tzinfo=UTC,
|
||||
)
|
||||
except ValueError:
|
||||
raise InvalidDateQuery(field, value) from None
|
||||
iso = _fmt(dt)
|
||||
return f"{field}:[{iso} TO {iso}]"
|
||||
# Unrecognized shape -> tell the user their date is malformed rather than
|
||||
# silently matching nothing or emitting invalid Tantivy syntax.
|
||||
raise InvalidDateQuery(field, value)
|
||||
|
||||
|
||||
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
|
||||
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
|
||||
# exposes Query.range_query(..., None) on Date fields (see module TODO).
|
||||
OPEN_LO = "0001-01-01T00:00:00Z"
|
||||
OPEN_HI = "9999-12-31T23:59:59Z"
|
||||
|
||||
|
||||
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
|
||||
_NOW_COMPACT_RE = regex.compile(
|
||||
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
|
||||
# Unit is singular or plural; sign prefix is mandatory.
|
||||
_NOW_SPACED_RE = regex.compile(
|
||||
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
|
||||
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_relative_bound(token: str) -> datetime | None:
|
||||
"""
|
||||
Resolve a relative bound token to an exact UTC instant, or return None.
|
||||
|
||||
Supported forms:
|
||||
- ``now`` -> current UTC instant
|
||||
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
|
||||
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
|
||||
"""
|
||||
stripped = token.strip()
|
||||
low = stripped.lower()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if low == "now":
|
||||
return now
|
||||
|
||||
m = _NOW_COMPACT_RE.match(stripped)
|
||||
if m:
|
||||
sign = 1 if m.group("sign") == "+" else -1
|
||||
n = int(m.group("n"))
|
||||
unit = m.group("unit").lower()
|
||||
delta = (
|
||||
sign
|
||||
* {
|
||||
"d": timedelta(days=n),
|
||||
"h": timedelta(hours=n),
|
||||
"m": timedelta(minutes=n),
|
||||
}[unit]
|
||||
)
|
||||
return now + delta
|
||||
|
||||
m = _NOW_SPACED_RE.match(stripped)
|
||||
if m:
|
||||
sign = 1 if m.group("sign") == "+" else -1
|
||||
n = int(m.group("n"))
|
||||
unit = m.group("unit").lower()
|
||||
delta_map: dict[str, timedelta | relativedelta] = {
|
||||
"second": timedelta(seconds=n),
|
||||
"minute": timedelta(minutes=n),
|
||||
"hour": timedelta(hours=n),
|
||||
"day": timedelta(days=n),
|
||||
"week": timedelta(weeks=n),
|
||||
"month": relativedelta(months=n),
|
||||
"year": relativedelta(years=n),
|
||||
}
|
||||
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _bound_datetimes(
|
||||
field: str,
|
||||
token: str,
|
||||
tz: tzinfo,
|
||||
) -> tuple[datetime, datetime] | None:
|
||||
"""
|
||||
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
|
||||
None if the token is unparsable. ``now`` and relative offsets resolve to the
|
||||
current instant (floor == ceil == that instant; no day-flooring).
|
||||
"""
|
||||
token = token.strip()
|
||||
|
||||
# Try relative/now forms first (before stripping hyphens which would mangle them).
|
||||
rel = _resolve_relative_bound(token)
|
||||
if rel is not None:
|
||||
return rel, rel
|
||||
|
||||
# Full ISO datetime token (contains "T"): parse directly and return an exact
|
||||
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
|
||||
if "T" in token:
|
||||
try:
|
||||
dt = datetime.fromisoformat(token)
|
||||
# Ensure timezone-aware UTC result.
|
||||
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
|
||||
return dt, dt
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
digits = token.replace("-", "")
|
||||
bounds = _precision_bounds(digits)
|
||||
if bounds is None:
|
||||
return None
|
||||
start, end = bounds
|
||||
return _utc_bounds_for_field(field, start, end, tz)
|
||||
|
||||
|
||||
def _render(tok: Token, tz: tzinfo) -> str:
|
||||
"""Render a single token back to a Tantivy query string fragment."""
|
||||
if isinstance(tok, Passthrough):
|
||||
return tok.raw
|
||||
if isinstance(tok, Comma):
|
||||
return " AND "
|
||||
if isinstance(tok, FieldValueList):
|
||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
||||
return " AND ".join(f"{field}:{v}" for v in tok.values)
|
||||
if isinstance(tok, FieldValue):
|
||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
||||
if field in DATE_FIELDS:
|
||||
return translate_scalar(field, tok.value, tz)
|
||||
return f"{field}:{tok.value}"
|
||||
if isinstance(tok, FieldRange):
|
||||
field = FIELD_ALIASES.get(tok.field, tok.field)
|
||||
if field in DATE_FIELDS:
|
||||
return translate_range(field, tok.lo, tok.hi, tz)
|
||||
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
|
||||
return "" # pragma: no cover
|
||||
|
||||
|
||||
# Post-render operator normalization patterns: collapse repeated whitespace and
|
||||
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
|
||||
_MULTI_SPACE_RE = regex.compile(r" {2,}")
|
||||
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
|
||||
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
|
||||
|
||||
|
||||
def _normalize_operators(text: str) -> str:
|
||||
"""
|
||||
Collapse multiple spaces, strip trailing dangling operators, and replace
|
||||
spaced operators (`` - `` / `` + ``) with a single space.
|
||||
|
||||
Applied only to Passthrough fragments (the rendered output is scanned for
|
||||
operator artifacts outside bracketed ranges) via a post-render pass on the
|
||||
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
|
||||
while cleaning natural-language separators in the surrounding text.
|
||||
"""
|
||||
text = _MULTI_SPACE_RE.sub(" ", text)
|
||||
text = _TRAILING_OP_RE.sub("", text).strip()
|
||||
text = _SPACED_OP_RE.sub(" ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def translate_query(raw: str, tz: tzinfo) -> str:
|
||||
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
|
||||
tokens = resolve_commas(scan(raw))
|
||||
rendered = "".join(_render(t, tz) for t in tokens)
|
||||
return _normalize_operators(rendered)
|
||||
|
||||
|
||||
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
|
||||
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
|
||||
|
||||
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
|
||||
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
|
||||
(swaps tokens before computing floor/ceil so the span is always correct).
|
||||
"""
|
||||
lo_s = lo.strip()
|
||||
hi_s = hi.strip()
|
||||
|
||||
# Parse both bounds to (floor, ceil) pairs when present.
|
||||
lo_pair: tuple[datetime, datetime] | None = None
|
||||
hi_pair: tuple[datetime, datetime] | None = None
|
||||
|
||||
if lo_s:
|
||||
lo_pair = _bound_datetimes(field, lo_s, tz)
|
||||
if lo_pair is None:
|
||||
raise InvalidDateQuery(field, lo_s)
|
||||
if hi_s:
|
||||
hi_pair = _bound_datetimes(field, hi_s, tz)
|
||||
if hi_pair is None:
|
||||
raise InvalidDateQuery(field, hi_s)
|
||||
|
||||
# Detect a reversed range: only swap when BOTH bounds are present.
|
||||
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
|
||||
lo_pair, hi_pair = hi_pair, lo_pair
|
||||
|
||||
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
|
||||
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
|
||||
|
||||
return f"{field}:[{lo_iso} TO {hi_iso}]"
|
||||
@@ -48,6 +48,7 @@ from rest_framework import serializers
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework.utils import model_meta
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.context import set_actor
|
||||
@@ -121,6 +122,45 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
|
||||
self.fields.pop(field_name)
|
||||
|
||||
|
||||
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
|
||||
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
|
||||
|
||||
def _get_update_fields(self, validated_data) -> list[str]:
|
||||
model_fields = {
|
||||
field.name
|
||||
for field in self.Meta.model._meta.concrete_fields
|
||||
if field.name not in self.stale_update_excluded_fields
|
||||
}
|
||||
update_fields = [
|
||||
field_name for field_name in validated_data if field_name in model_fields
|
||||
]
|
||||
if "modified" in model_fields and "modified" not in update_fields:
|
||||
update_fields.append("modified")
|
||||
return update_fields
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
serializers.raise_errors_on_nested_writes("update", self, validated_data)
|
||||
info = model_meta.get_field_info(instance)
|
||||
|
||||
m2m_fields = []
|
||||
for attr, value in validated_data.items():
|
||||
if attr in info.relations and info.relations[attr].to_many:
|
||||
m2m_fields.append((attr, value))
|
||||
else:
|
||||
setattr(instance, attr, value)
|
||||
|
||||
# File names are managed by post-save file handling. Saving only the
|
||||
# serializer-updated fields prevents stale in-memory path values from
|
||||
# overwriting a concurrent move.
|
||||
instance.save(update_fields=self._get_update_fields(validated_data))
|
||||
|
||||
for attr, value in m2m_fields:
|
||||
field = getattr(instance, attr)
|
||||
field.set(value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
|
||||
document_count = serializers.IntegerField(read_only=True)
|
||||
|
||||
@@ -989,7 +1029,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
|
||||
class DocumentSerializer(
|
||||
OwnedObjectSerializer,
|
||||
NestedUpdateMixin,
|
||||
DynamicFieldsModelSerializer,
|
||||
DocumentUpdateFieldsModelSerializer,
|
||||
):
|
||||
correspondent = CorrespondentField(allow_null=True)
|
||||
tags = TagsField(many=True)
|
||||
@@ -1128,10 +1168,9 @@ class DocumentSerializer(
|
||||
return super().validate(attrs)
|
||||
|
||||
def update(self, instance: Document, validated_data):
|
||||
if "created_date" in validated_data and "created" not in validated_data:
|
||||
instance.created = validated_data.get("created_date")
|
||||
instance.save()
|
||||
if "created_date" in validated_data:
|
||||
if "created" not in validated_data:
|
||||
validated_data["created"] = validated_data["created_date"]
|
||||
logger.warning(
|
||||
"created_date is deprecated, use created instead",
|
||||
)
|
||||
@@ -1201,11 +1240,13 @@ class DocumentSerializer(
|
||||
for tag in instance.tags.all()
|
||||
if tag not in inbox_tags_not_being_added
|
||||
]
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
with set_actor(self.user):
|
||||
super().update(instance, validated_data)
|
||||
else:
|
||||
super().update(instance, validated_data)
|
||||
|
||||
# hard delete custom field instances that were soft deleted
|
||||
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
|
||||
return instance
|
||||
@@ -2632,18 +2673,25 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
|
||||
|
||||
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
||||
tasks = serializers.ListField(
|
||||
required=True,
|
||||
required=False,
|
||||
label="Tasks",
|
||||
write_only=True,
|
||||
child=serializers.IntegerField(),
|
||||
)
|
||||
all = serializers.BooleanField(
|
||||
required=False,
|
||||
default=False,
|
||||
label="All",
|
||||
write_only=True,
|
||||
)
|
||||
|
||||
def _validate_task_id_list(self, tasks, name="tasks") -> None:
|
||||
if not isinstance(tasks, list):
|
||||
raise serializers.ValidationError(f"{name} must be a list")
|
||||
if not all(isinstance(i, int) for i in tasks):
|
||||
raise serializers.ValidationError(f"{name} must be a list of integers")
|
||||
count = PaperlessTask.objects.filter(id__in=tasks).count()
|
||||
queryset = self.context.get("queryset", PaperlessTask.objects.all())
|
||||
count = queryset.filter(id__in=tasks).count()
|
||||
if not count == len(tasks):
|
||||
raise serializers.ValidationError(
|
||||
f"Some tasks in {name} don't exist or were specified twice.",
|
||||
@@ -2653,6 +2701,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
|
||||
self._validate_task_id_list(tasks)
|
||||
return tasks
|
||||
|
||||
def validate(self, attrs):
|
||||
acknowledge_all = attrs.get("all", False)
|
||||
task_ids = attrs.get("tasks")
|
||||
|
||||
if acknowledge_all and task_ids is not None:
|
||||
raise serializers.ValidationError(
|
||||
"Set either all or tasks, not both.",
|
||||
)
|
||||
if not acknowledge_all and task_ids is None:
|
||||
raise serializers.ValidationError(
|
||||
"Either all must be true or tasks must be provided.",
|
||||
)
|
||||
|
||||
return attrs
|
||||
|
||||
|
||||
class ShareLinkSerializer(OwnedObjectSerializer):
|
||||
class Meta:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import hashlib
|
||||
import logging
|
||||
import shutil
|
||||
import traceback as _tb
|
||||
@@ -16,6 +15,7 @@ from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.signals import task_revoked
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_process_shutdown
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
@@ -54,6 +54,7 @@ from documents.models import WorkflowTrigger
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from documents.plugins.helpers import DocumentsStatusManager
|
||||
from documents.templating.utils import convert_format_str_to_template_format
|
||||
from documents.utils import compute_checksum
|
||||
from documents.workflows.actions import build_workflow_action_context
|
||||
from documents.workflows.actions import execute_email_action
|
||||
from documents.workflows.actions import execute_move_to_trash_action
|
||||
@@ -410,8 +411,7 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
|
||||
if checksum is None or not path.is_file():
|
||||
return False
|
||||
|
||||
with path.open("rb") as f:
|
||||
return hashlib.md5(f.read()).hexdigest() == checksum
|
||||
return compute_checksum(path) == checksum
|
||||
|
||||
|
||||
def _filename_template_uses_custom_fields(doc: Document) -> bool:
|
||||
@@ -1340,6 +1340,20 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
|
||||
conn.close_pool()
|
||||
|
||||
|
||||
@worker_process_shutdown.connect
|
||||
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
|
||||
"""
|
||||
Close the DB connection pool when a Celery child process exits.
|
||||
|
||||
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
|
||||
single task. Without closing the pool on shutdown, its connections linger
|
||||
on the server until TCP keepalive reaps them, accumulating over time.
|
||||
"""
|
||||
for conn in connections.all(initialized_only=True):
|
||||
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
|
||||
conn.close_pool()
|
||||
|
||||
|
||||
def add_or_update_document_in_llm_index(sender, document, **kwargs):
|
||||
"""
|
||||
Add or update a document in the LLM index when it is created or updated.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from collections.abc import Iterable
|
||||
from pathlib import PurePath
|
||||
|
||||
@@ -36,10 +37,12 @@ class FilePathTemplate(Template):
|
||||
def clean_filepath(value: str) -> str:
|
||||
"""
|
||||
Clean up a filepath by:
|
||||
1. Removing newlines and carriage returns
|
||||
2. Removing extra spaces before and after forward slashes
|
||||
3. Preserving spaces in other parts of the path
|
||||
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
|
||||
2. Removing newlines and carriage returns
|
||||
3. Removing extra spaces before and after forward slashes
|
||||
4. Preserving spaces in other parts of the path
|
||||
"""
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
value = value.replace("\n", "").replace("\r", "")
|
||||
value = re.sub(r"\s*/\s*", "/", value)
|
||||
|
||||
@@ -181,17 +184,17 @@ def get_basic_metadata_context(
|
||||
"""
|
||||
return {
|
||||
"title": pathvalidate.sanitize_filename(
|
||||
document.title,
|
||||
unicodedata.normalize("NFC", document.title),
|
||||
replacement_text="-",
|
||||
),
|
||||
"correspondent": pathvalidate.sanitize_filename(
|
||||
document.correspondent.name,
|
||||
unicodedata.normalize("NFC", document.correspondent.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
if document.correspondent
|
||||
else no_value_default,
|
||||
"document_type": pathvalidate.sanitize_filename(
|
||||
document.document_type.name,
|
||||
unicodedata.normalize("NFC", document.document_type.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
if document.document_type
|
||||
@@ -202,7 +205,10 @@ def get_basic_metadata_context(
|
||||
"owner_username": document.owner.username
|
||||
if document.owner
|
||||
else no_value_default,
|
||||
"original_name": PurePath(document.original_filename).with_suffix("").name
|
||||
"original_name": unicodedata.normalize(
|
||||
"NFC",
|
||||
PurePath(document.original_filename).with_suffix("").name,
|
||||
)
|
||||
if document.original_filename
|
||||
else no_value_default,
|
||||
"doc_pk": f"{document.pk:07}",
|
||||
@@ -269,12 +275,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
"tag_list": pathvalidate.sanitize_filename(
|
||||
",".join(
|
||||
sorted(tag.name for tag in tags),
|
||||
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
|
||||
),
|
||||
replacement_text="-",
|
||||
),
|
||||
# Assumed to be ordered, but a template could loop through to find what they want
|
||||
"tag_name_list": [x.name for x in tags],
|
||||
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
|
||||
}
|
||||
|
||||
|
||||
@@ -301,7 +307,7 @@ def get_custom_fields_context(
|
||||
CustomField.FieldDataType.LONG_TEXT,
|
||||
}:
|
||||
value = pathvalidate.sanitize_filename(
|
||||
field_instance.value,
|
||||
unicodedata.normalize("NFC", field_instance.value),
|
||||
replacement_text="-",
|
||||
)
|
||||
elif (
|
||||
@@ -310,10 +316,13 @@ def get_custom_fields_context(
|
||||
):
|
||||
options = field_instance.field.extra_data["select_options"]
|
||||
value = pathvalidate.sanitize_filename(
|
||||
next(
|
||||
option["label"]
|
||||
for option in options
|
||||
if option["id"] == field_instance.value
|
||||
unicodedata.normalize(
|
||||
"NFC",
|
||||
next(
|
||||
option["label"]
|
||||
for option in options
|
||||
if option["id"] == field_instance.value
|
||||
),
|
||||
),
|
||||
replacement_text="-",
|
||||
)
|
||||
@@ -321,7 +330,7 @@ def get_custom_fields_context(
|
||||
value = field_instance.value
|
||||
field_data["custom_fields"][
|
||||
pathvalidate.sanitize_filename(
|
||||
field_instance.field.name,
|
||||
unicodedata.normalize("NFC", field_instance.field.name),
|
||||
replacement_text="-",
|
||||
)
|
||||
] = {
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.core.management import call_command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
|
||||
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
|
||||
|
||||
|
||||
class TestDocumentLlmindexCommand:
|
||||
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
|
||||
mock_compact = mocker.patch(_COMPACT)
|
||||
call_command("document_llmindex", "compact")
|
||||
mock_compact.assert_called_once_with()
|
||||
|
||||
def test_rebuild_calls_llmindex_index_with_rebuild_true(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_index = mocker.patch(_INDEX)
|
||||
call_command("document_llmindex", "rebuild")
|
||||
mock_index.assert_called_once()
|
||||
assert mock_index.call_args.kwargs["rebuild"] is True
|
||||
|
||||
def test_update_calls_llmindex_index_with_rebuild_false(
|
||||
self,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mock_index = mocker.patch(_INDEX)
|
||||
call_command("document_llmindex", "update")
|
||||
mock_index.assert_called_once()
|
||||
assert mock_index.call_args.kwargs["rebuild"] is False
|
||||
@@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
import tantivy
|
||||
|
||||
from documents.search._backend import TantivyBackend
|
||||
from documents.search._backend import reset_backend
|
||||
from documents.search._schema import build_schema
|
||||
from documents.search._tokenizer import register_tokenizers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
@@ -31,3 +35,11 @@ def backend() -> Generator[TantivyBackend, None, None]:
|
||||
finally:
|
||||
b.close()
|
||||
reset_backend()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def index() -> tantivy.Index:
|
||||
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
|
||||
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
|
||||
register_tokenizers(idx, "english")
|
||||
return idx
|
||||
|
||||
@@ -13,7 +13,6 @@ import time_machine
|
||||
|
||||
from documents.search._query import _date_only_range
|
||||
from documents.search._query import _datetime_range
|
||||
from documents.search._query import _rewrite_compact_date
|
||||
from documents.search._query import build_permission_filter
|
||||
from documents.search._query import normalize_query
|
||||
from documents.search._query import parse_simple_text_highlight_query
|
||||
@@ -21,6 +20,7 @@ from documents.search._query import parse_user_query
|
||||
from documents.search._query import rewrite_natural_date_keywords
|
||||
from documents.search._schema import build_schema
|
||||
from documents.search._tokenizer import register_tokenizers
|
||||
from documents.search._translate import InvalidDateQuery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.contrib.auth.base_user import AbstractBaseUser
|
||||
@@ -405,12 +405,14 @@ class TestWhooshQueryRewriting:
|
||||
assert lo == "2023-12-01T05:00:00Z"
|
||||
assert hi == "2023-12-02T05:00:00Z"
|
||||
|
||||
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
|
||||
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
|
||||
|
||||
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
|
||||
# Month=13 makes datetime() raise ValueError; the token must be left as-is
|
||||
assert _rewrite_compact_date("20231300120000") == "20231300120000"
|
||||
def test_8digit_invalid_date_raises(self) -> None:
|
||||
# The translation pipeline raises InvalidDateQuery for unparsable dates
|
||||
# (e.g. month=13) so the API can surface a 400 telling the user the date
|
||||
# is malformed instead of silently returning zero results.
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
rewrite_natural_date_keywords("added:20231340", UTC)
|
||||
assert exc_info.value.field == "added"
|
||||
assert exc_info.value.value == "20231340"
|
||||
|
||||
|
||||
class TestParseUserQuery:
|
||||
@@ -463,6 +465,67 @@ class TestParseUserQuery:
|
||||
) -> None:
|
||||
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_query",
|
||||
[
|
||||
# Partial date scalar (year only)
|
||||
pytest.param("created:2020", id="created_year_scalar"),
|
||||
# 8-digit compact date range in brackets
|
||||
pytest.param(
|
||||
"created:[20200101 TO 20201231]",
|
||||
id="created_8digit_bracket_range",
|
||||
),
|
||||
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
|
||||
pytest.param(
|
||||
"title:x,created:[2020 TO 2021]",
|
||||
id="title_comma_created_range",
|
||||
),
|
||||
# Field alias: type -> document_type
|
||||
pytest.param("type:invoice", id="type_alias"),
|
||||
# Multi-word date keyword
|
||||
pytest.param("created:previous week", id="created_previous_week"),
|
||||
# Full ISO datetime range
|
||||
pytest.param(
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
|
||||
id="created_iso_range",
|
||||
),
|
||||
# Comma-separated ISO ranges (Whoosh v2 syntax)
|
||||
pytest.param(
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
|
||||
id="comma_iso_ranges",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_advanced_search_queries_do_not_raise(
|
||||
self,
|
||||
query_index: tantivy.Index,
|
||||
raw_query: str,
|
||||
) -> None:
|
||||
"""
|
||||
End-to-end: queries that the frontend sends must parse without raising.
|
||||
|
||||
This tests the full pipeline: translate_query -> tantivy parse_query.
|
||||
Equivalent to asserting HTTP 200 (not 400) for each query form.
|
||||
"""
|
||||
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
|
||||
assert isinstance(
|
||||
parse_user_query(query_index, raw_query, UTC),
|
||||
tantivy.Query,
|
||||
)
|
||||
|
||||
def test_invalid_date_propagates_not_swallowed(
|
||||
self,
|
||||
query_index: tantivy.Index,
|
||||
) -> None:
|
||||
# parse_user_query falls back to the raw query on unexpected translation
|
||||
# errors, but an InvalidDateQuery is intentional and must propagate so the
|
||||
# view can return a 400 instead of silently parsing the raw (invalid) date.
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
parse_user_query(query_index, "created:202023", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "202023"
|
||||
|
||||
|
||||
class TestYearRangeRewriting:
|
||||
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
|
||||
@@ -542,11 +605,16 @@ class TestYearRangeRewriting:
|
||||
assert rewrite_natural_date_keywords(original, UTC) == original
|
||||
|
||||
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
|
||||
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
|
||||
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
|
||||
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
|
||||
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
|
||||
# This is the correct and accepted behavior: old compact form becomes a
|
||||
# proper Tantivy-parseable ISO range.
|
||||
original = "created:[20200101 TO 20201231]"
|
||||
result = rewrite_natural_date_keywords(original, UTC)
|
||||
assert "20200101" in result or "2020-01-01" in result
|
||||
assert "20201231" in result or "2020-12-31" in result
|
||||
lo, hi = _range(result, "created")
|
||||
assert lo == "2020-01-01T00:00:00Z"
|
||||
assert hi == "2021-01-01T00:00:00Z"
|
||||
|
||||
|
||||
class TestNonDateFieldsNotRewritten:
|
||||
@@ -606,6 +674,16 @@ class TestNormalizeQuery:
|
||||
def test_normalize_expands_comma_separated_tags(self) -> None:
|
||||
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
|
||||
|
||||
def test_normalize_comma_between_range_expressions(self) -> None:
|
||||
# Comma-separated field range expressions (Whoosh v2 syntax) must be
|
||||
# converted to AND so Tantivy does not receive an invalid comma.
|
||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
assert normalize_query(q) == (
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
" AND "
|
||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
)
|
||||
|
||||
def test_normalize_expands_three_values(self) -> None:
|
||||
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
|
||||
|
||||
|
||||
@@ -0,0 +1,742 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
import time_machine
|
||||
|
||||
from documents.search._dates import _precision_bounds
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tantivy
|
||||
from documents.search._query import _FIELD_BOOSTS
|
||||
from documents.search._query import DEFAULT_SEARCH_FIELDS
|
||||
from documents.search._translate import OPEN_HI
|
||||
from documents.search._translate import OPEN_LO
|
||||
from documents.search._translate import Comma
|
||||
from documents.search._translate import FieldRange
|
||||
from documents.search._translate import FieldValue
|
||||
from documents.search._translate import FieldValueList
|
||||
from documents.search._translate import InvalidDateQuery
|
||||
from documents.search._translate import Passthrough
|
||||
from documents.search._translate import resolve_commas
|
||||
from documents.search._translate import scan
|
||||
from documents.search._translate import translate_query
|
||||
from documents.search._translate import translate_range
|
||||
from documents.search._translate import translate_scalar
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestPrecisionBounds:
|
||||
@pytest.mark.parametrize(
|
||||
("digits", "expected"),
|
||||
[
|
||||
("2020", ((2020, 1, 1), (2021, 1, 1))),
|
||||
("202003", ((2020, 3, 1), (2020, 4, 1))),
|
||||
("202012", ((2020, 12, 1), (2021, 1, 1))),
|
||||
("20200115", ((2020, 1, 15), (2020, 1, 16))),
|
||||
("20201231", ((2020, 12, 31), (2021, 1, 1))),
|
||||
],
|
||||
)
|
||||
def test_valid(self, digits, expected):
|
||||
lo, hi = _precision_bounds(digits)
|
||||
assert (lo.year, lo.month, lo.day) == expected[0]
|
||||
assert (hi.year, hi.month, hi.day) == expected[1]
|
||||
|
||||
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
|
||||
def test_invalid_returns_none(self, digits):
|
||||
assert _precision_bounds(digits) is None
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestScan:
|
||||
def test_plain_words_are_passthrough(self):
|
||||
assert scan("bank statement") == [Passthrough("bank statement")]
|
||||
|
||||
def test_field_value(self):
|
||||
assert scan("created:2020") == [FieldValue("created", "2020")]
|
||||
|
||||
def test_field_value_in_boolean(self):
|
||||
toks = scan("created:2020 OR foo")
|
||||
assert toks == [
|
||||
FieldValue("created", "2020"),
|
||||
Passthrough(" OR foo"),
|
||||
]
|
||||
|
||||
def test_field_value_in_parens(self):
|
||||
toks = scan("(created:2020 OR foo)")
|
||||
assert toks == [
|
||||
Passthrough("("),
|
||||
FieldValue("created", "2020"),
|
||||
Passthrough(" OR foo)"),
|
||||
]
|
||||
|
||||
def test_quoted_value(self):
|
||||
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
|
||||
|
||||
def test_field_range(self):
|
||||
assert scan("created:[2020 TO 2021]") == [
|
||||
FieldRange("created", "[", "2020", "2021", "]"),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("query", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
"created:[2020 to]",
|
||||
FieldRange("created", "[", "2020", "", "]"),
|
||||
id="open_upper",
|
||||
),
|
||||
pytest.param(
|
||||
"created:[to 2020]",
|
||||
FieldRange("created", "[", "", "2020", "]"),
|
||||
id="open_lower",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_open_range(self, query, expected):
|
||||
assert scan(query) == [expected]
|
||||
|
||||
def test_comma_inside_range_not_split(self):
|
||||
# No depth-0 comma here; the whole thing is one range token.
|
||||
toks = scan("created:[2020 TO 2021]")
|
||||
assert len(toks) == 1
|
||||
|
||||
# --- Edge-case / regression tests (scan must never raise) ---
|
||||
|
||||
def test_url_is_passthrough(self):
|
||||
# "http" is not a known field; the whole URL must pass through verbatim.
|
||||
assert scan("http://example.com") == [Passthrough("http://example.com")]
|
||||
|
||||
def test_unterminated_quote_is_passthrough(self):
|
||||
# title is a known field but the quoted value has no closing quote;
|
||||
# _consume_value returns None so the whole string falls into passthrough.
|
||||
assert scan('title:"abc') == [Passthrough('title:"abc')]
|
||||
|
||||
def test_unterminated_bracket_is_passthrough(self):
|
||||
# created is a known field but the range bracket is never closed;
|
||||
# _consume_range returns None so the whole string falls into passthrough.
|
||||
assert scan("created:[2020") == [Passthrough("created:[2020")]
|
||||
|
||||
def test_empty_value_at_end_is_passthrough(self):
|
||||
# created is a known field but there is no value after the colon
|
||||
# (_consume_value returns None for start >= n), so passthrough.
|
||||
assert scan("created:") == [Passthrough("created:")]
|
||||
|
||||
def test_value_containing_colon(self):
|
||||
# The bare-word value reader stops at whitespace/paren, not at colon,
|
||||
# so "2020:30" is consumed as a single value token.
|
||||
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
|
||||
|
||||
def test_comma_followed_by_unconsumable_value_stops(self):
|
||||
# A comma followed by whitespace is neither a value-list continuation nor a
|
||||
# clause separator: the value stops and the comma stays as passthrough.
|
||||
assert scan("tag:foo, bar") == [
|
||||
FieldValue("tag", "foo"),
|
||||
Passthrough(", bar"),
|
||||
]
|
||||
|
||||
def test_bracket_without_to_is_open_upper_bound(self):
|
||||
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
|
||||
assert scan("created:[2020]") == [
|
||||
FieldRange("created", "[", "2020", "", "]"),
|
||||
]
|
||||
|
||||
def test_known_field_name_midword_is_passthrough(self):
|
||||
# A known field name embedded mid-word is not a field token (the
|
||||
# word-boundary guard); the whole run stays passthrough.
|
||||
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestCommaResolution:
|
||||
def test_value_list_multi_value_field(self):
|
||||
toks = resolve_commas(scan("tag:foo,bar"))
|
||||
assert toks == [FieldValueList("tag", ("foo", "bar"))]
|
||||
|
||||
def test_value_list_three(self):
|
||||
toks = resolve_commas(scan("tag_id:1,2,3"))
|
||||
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
|
||||
|
||||
def test_text_field_comma_is_literal(self):
|
||||
# correspondent is not multi-value: comma stays inside the value.
|
||||
toks = resolve_commas(scan("correspondent:foo,bar"))
|
||||
assert toks == [FieldValue("correspondent", "foo,bar")]
|
||||
|
||||
def test_clause_separator_before_known_field(self):
|
||||
toks = resolve_commas(scan("tag:foo,type:bar"))
|
||||
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
|
||||
|
||||
def test_clause_separator_after_range(self):
|
||||
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
|
||||
assert toks == [
|
||||
FieldRange("created", "[", "2020", "2021", "]"),
|
||||
Comma(),
|
||||
FieldRange("added", "[", "2022", "2023", "]"),
|
||||
]
|
||||
|
||||
def test_clause_separator_after_quote(self):
|
||||
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
|
||||
assert toks == [
|
||||
FieldValue("correspondent", '"A B"'),
|
||||
Comma(),
|
||||
FieldRange("created", "[", "2020", "2021", "]"),
|
||||
]
|
||||
|
||||
def test_url_comma_is_literal_passthrough(self):
|
||||
toks = resolve_commas(scan("http://example.com/a,b"))
|
||||
assert toks == [Passthrough("http://example.com/a,b")]
|
||||
|
||||
def test_non_multi_value_comma_is_literal(self):
|
||||
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
|
||||
toks = resolve_commas(scan("title:10,20"))
|
||||
assert toks == [FieldValue("title", "10,20")]
|
||||
|
||||
def test_clause_separator_before_known_date_field(self):
|
||||
# The comma between a bare value and a known date field acts as a
|
||||
# clause separator; both sides survive as distinct tokens.
|
||||
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
|
||||
assert toks == [
|
||||
FieldValue("correspondent", "foo"),
|
||||
Comma(),
|
||||
FieldRange("created", "[", "2020", "2021", "]"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestTranslateScalar:
|
||||
@pytest.mark.parametrize(
|
||||
("field", "value", "expected"),
|
||||
[
|
||||
(
|
||||
"created",
|
||||
"2020",
|
||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"created",
|
||||
"202003",
|
||||
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"created",
|
||||
"20200115",
|
||||
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"created",
|
||||
"2020-01-15",
|
||||
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"created",
|
||||
"2020-03",
|
||||
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
|
||||
assert translate_scalar(field, value, UTC) == expected
|
||||
|
||||
def test_invalid_date_raises(self) -> None:
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_scalar("created", "202023", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "202023"
|
||||
|
||||
def test_keyword_delegates(self) -> None:
|
||||
# keyword path produces a range; just assert it is a created range
|
||||
out = translate_scalar("created", "today", UTC)
|
||||
assert out.startswith("created:[") and out.endswith("]")
|
||||
|
||||
def test_14digit_compact_datetime(self) -> None:
|
||||
out = translate_scalar("created", "20240115120000", UTC)
|
||||
assert "20240115120000" not in out
|
||||
assert out.startswith("created:")
|
||||
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
|
||||
|
||||
def test_14digit_invalid_month_raises(self) -> None:
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_scalar("created", "20231300120000", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "20231300120000"
|
||||
|
||||
def test_unrecognized_value_raises(self) -> None:
|
||||
# A value that is not a keyword, digits, ISO date, or compact timestamp
|
||||
# raises rather than producing invalid Tantivy syntax or silently matching
|
||||
# nothing.
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_scalar("created", "garbage", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "garbage"
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestTranslateRange:
|
||||
@pytest.mark.parametrize(
|
||||
("lo", "hi", "expected"),
|
||||
[
|
||||
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
|
||||
(
|
||||
"202001",
|
||||
"202006",
|
||||
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"20200101",
|
||||
"20201231",
|
||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
||||
),
|
||||
(
|
||||
"2020-01-01",
|
||||
"2020-12-31",
|
||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_absolute_ranges(self, lo, hi, expected):
|
||||
assert translate_range("created", lo, hi, UTC) == expected
|
||||
|
||||
def test_reversed_swaps(self):
|
||||
assert translate_range("created", "2009", "2005", UTC) == (
|
||||
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
|
||||
)
|
||||
|
||||
def test_open_upper(self):
|
||||
out = translate_range("created", "2020", "", UTC)
|
||||
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
|
||||
|
||||
def test_open_lower(self):
|
||||
out = translate_range("created", "", "2020", UTC)
|
||||
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
|
||||
|
||||
def test_invalid_bound_raises(self):
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_range("created", "202023", "2025", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "202023"
|
||||
|
||||
def test_invalid_high_bound_raises(self):
|
||||
# Low bound parses, high bound does not -> raise on the high bound.
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_range("created", "2020", "garbage", UTC)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "garbage"
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestTranslateQuery:
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
(
|
||||
"created:2020",
|
||||
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
|
||||
),
|
||||
("tag:foo,bar", "tag:foo AND tag:bar"),
|
||||
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
|
||||
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
|
||||
(
|
||||
"created:[2020 TO 2021],added:[2022 TO 2023]",
|
||||
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
|
||||
" AND "
|
||||
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
|
||||
),
|
||||
# correspondent is not multi-value: comma stays literal inside the value
|
||||
("correspondent:foo,bar", "correspondent:foo,bar"),
|
||||
],
|
||||
)
|
||||
def test_golden(self, raw: str, expected: str) -> None:
|
||||
assert translate_query(raw, UTC) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
[
|
||||
"created:2020",
|
||||
"created:202003",
|
||||
"created:[20200101 TO 20201231]",
|
||||
"created:[2020-01-01 TO 2020-12-31]",
|
||||
"created:[2020 to]",
|
||||
"created:[to 2020]",
|
||||
"title:x,created:[2020 TO 2021]",
|
||||
"created:2020 OR foo",
|
||||
"(created:2020 OR invoice)",
|
||||
"tag:foo,type:bar",
|
||||
"bank statement",
|
||||
],
|
||||
)
|
||||
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
|
||||
translated = translate_query(raw, UTC)
|
||||
# Must not raise:
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestFieldAliasing:
|
||||
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
|
||||
|
||||
def test_type_alias(self) -> None:
|
||||
assert translate_query("type:invoice", UTC) == "document_type:invoice"
|
||||
|
||||
def test_path_alias(self) -> None:
|
||||
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
|
||||
|
||||
def test_type_id_alias(self) -> None:
|
||||
assert translate_query("type_id:5", UTC) == "document_type_id:5"
|
||||
|
||||
def test_path_id_alias(self) -> None:
|
||||
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
|
||||
|
||||
def test_clause_separator_plus_alias(self) -> None:
|
||||
# Comma between known fields acts as AND separator; alias still applied.
|
||||
assert (
|
||||
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
|
||||
)
|
||||
|
||||
def test_type_range_alias(self) -> None:
|
||||
# type is not a date field; range passes through verbatim with alias applied.
|
||||
assert (
|
||||
translate_query("type:[2020 TO 2021]", UTC)
|
||||
== "document_type:[2020 TO 2021]"
|
||||
)
|
||||
|
||||
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
|
||||
# Translated output must be accepted by the real Tantivy parser.
|
||||
translated = translate_query("type:invoice", UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
|
||||
translated = translate_query("path:foo", UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
|
||||
# Freeze time so relative-date tests are deterministic.
|
||||
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestRelativeRanges:
|
||||
"""Relative date-range tokens resolved against a frozen clock."""
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_minus_7_days_to_now(self) -> None:
|
||||
assert translate_query("added:[-7 days to now]", UTC) == (
|
||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_minus_1_week_to_now(self) -> None:
|
||||
assert translate_query("added:[-1 week to now]", UTC) == (
|
||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_minus_1_month_to_now(self) -> None:
|
||||
assert translate_query("created:[-1 month to now]", UTC) == (
|
||||
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_minus_1_year_to_now(self) -> None:
|
||||
assert translate_query("modified:[-1 year to now]", UTC) == (
|
||||
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_minus_3_hours_to_now(self) -> None:
|
||||
assert translate_query("added:[-3 hours to now]", UTC) == (
|
||||
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_uppercase_units(self) -> None:
|
||||
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
|
||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_now_minus_7d_compact(self) -> None:
|
||||
assert translate_query("added:[now-7d TO now]", UTC) == (
|
||||
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
|
||||
)
|
||||
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_reversed_range_swapped(self) -> None:
|
||||
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
|
||||
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
|
||||
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
[
|
||||
"added:[-7 days to now]",
|
||||
"added:[-1 week to now]",
|
||||
"created:[-1 month to now]",
|
||||
"modified:[-1 year to now]",
|
||||
"added:[-3 hours to now]",
|
||||
"added:[now-7d TO now]",
|
||||
"added:[now+1h TO now-1h]",
|
||||
],
|
||||
)
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
|
||||
translated = translate_query(raw, UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestOperatorNormalization:
|
||||
"""Post-render operator normalization in translate_query."""
|
||||
|
||||
def test_spaced_dash_removed(self) -> None:
|
||||
assert (
|
||||
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
|
||||
)
|
||||
|
||||
def test_spaced_dash_simple(self) -> None:
|
||||
assert translate_query("bar - baz", UTC) == "bar baz"
|
||||
|
||||
def test_trailing_operator_stripped(self) -> None:
|
||||
assert translate_query("foo -", UTC) == "foo"
|
||||
|
||||
def test_date_range_preserved(self) -> None:
|
||||
out = translate_query("created:[2020 TO 2021]", UTC)
|
||||
# Must not corrupt the ISO range
|
||||
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
|
||||
|
||||
def test_date_scalar_with_or(self) -> None:
|
||||
out = translate_query("created:2020 OR foo", UTC)
|
||||
# The created scalar becomes a range; " OR foo" passes through verbatim.
|
||||
assert out.startswith("created:[")
|
||||
assert "OR foo" in out
|
||||
|
||||
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
|
||||
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
|
||||
translated = translate_query("foo -", UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestMultiWordDateKeywords:
|
||||
"""scan() must consume multi-word date keywords as a single value."""
|
||||
|
||||
def test_scan_previous_week_as_single_token(self) -> None:
|
||||
# "created:previous week" must produce one FieldValue with value "previous week",
|
||||
# not FieldValue("created","previous") + Passthrough(" week").
|
||||
toks = scan("created:previous week")
|
||||
assert toks == [FieldValue("created", "previous week")]
|
||||
|
||||
def test_scan_this_month_as_single_token(self) -> None:
|
||||
toks = scan("added:this month")
|
||||
assert toks == [FieldValue("added", "this month")]
|
||||
|
||||
def test_scan_previous_month_as_single_token(self) -> None:
|
||||
toks = scan("created:previous month")
|
||||
assert toks == [FieldValue("created", "previous month")]
|
||||
|
||||
def test_scan_this_year_as_single_token(self) -> None:
|
||||
toks = scan("added:this year")
|
||||
assert toks == [FieldValue("added", "this year")]
|
||||
|
||||
def test_scan_previous_year_as_single_token(self) -> None:
|
||||
toks = scan("created:previous year")
|
||||
assert toks == [FieldValue("created", "previous year")]
|
||||
|
||||
def test_scan_previous_quarter_as_single_token(self) -> None:
|
||||
toks = scan("created:previous quarter")
|
||||
assert toks == [FieldValue("created", "previous quarter")]
|
||||
|
||||
def test_quoted_multi_word_keyword_still_works(self) -> None:
|
||||
# The quoted form must continue to work as before.
|
||||
toks = scan('created:"previous week"')
|
||||
assert toks == [FieldValue("created", '"previous week"')]
|
||||
|
||||
def test_non_date_field_not_affected(self) -> None:
|
||||
# "previous" stops at the space for non-date fields; " week" passes through.
|
||||
toks = scan("correspondent:previous week")
|
||||
assert toks == [
|
||||
FieldValue("correspondent", "previous"),
|
||||
Passthrough(" week"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestKeywordDateResolution:
|
||||
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
|
||||
|
||||
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
|
||||
quarter and year rollovers are all exercised by a single anchor.
|
||||
"""
|
||||
|
||||
# created is a DateField: bounds are UTC midnight, no timezone offset.
|
||||
@pytest.mark.parametrize(
|
||||
("keyword", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
"today",
|
||||
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
|
||||
id="today",
|
||||
),
|
||||
pytest.param(
|
||||
"yesterday",
|
||||
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
|
||||
id="yesterday",
|
||||
),
|
||||
pytest.param(
|
||||
"previous week",
|
||||
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
|
||||
id="previous-week",
|
||||
),
|
||||
pytest.param(
|
||||
"this month",
|
||||
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
|
||||
id="this-month",
|
||||
),
|
||||
pytest.param(
|
||||
"previous month",
|
||||
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
|
||||
id="previous-month",
|
||||
),
|
||||
pytest.param(
|
||||
"this year",
|
||||
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
|
||||
id="this-year",
|
||||
),
|
||||
pytest.param(
|
||||
"previous year",
|
||||
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
|
||||
id="previous-year",
|
||||
),
|
||||
pytest.param(
|
||||
"previous quarter",
|
||||
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
|
||||
id="previous-quarter",
|
||||
),
|
||||
],
|
||||
)
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_date_only_field_keyword_ranges(
|
||||
self,
|
||||
keyword: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
assert translate_query(f"created:{keyword}", UTC) == expected
|
||||
|
||||
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
|
||||
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
|
||||
# before, so this also exercises the local-midnight offset path.
|
||||
@pytest.mark.parametrize(
|
||||
("keyword", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
"today",
|
||||
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
|
||||
id="today",
|
||||
),
|
||||
pytest.param(
|
||||
"yesterday",
|
||||
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
|
||||
id="yesterday",
|
||||
),
|
||||
pytest.param(
|
||||
"previous week",
|
||||
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
|
||||
id="previous-week",
|
||||
),
|
||||
pytest.param(
|
||||
"this month",
|
||||
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
|
||||
id="this-month",
|
||||
),
|
||||
pytest.param(
|
||||
"previous month",
|
||||
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
|
||||
id="previous-month",
|
||||
),
|
||||
pytest.param(
|
||||
"this year",
|
||||
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
|
||||
id="this-year",
|
||||
),
|
||||
pytest.param(
|
||||
"previous year",
|
||||
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
|
||||
id="previous-year",
|
||||
),
|
||||
pytest.param(
|
||||
"previous quarter",
|
||||
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
|
||||
id="previous-quarter",
|
||||
),
|
||||
],
|
||||
)
|
||||
@time_machine.travel(_FROZEN_NOW, tick=False)
|
||||
def test_datetime_field_keyword_ranges_local_tz(
|
||||
self,
|
||||
keyword: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
|
||||
|
||||
|
||||
@pytest.mark.search
|
||||
class TestISODatetimeBounds:
|
||||
"""Full ISO datetime tokens in range bounds must be parsed directly."""
|
||||
|
||||
def test_translate_range_iso_bounds_passthrough(self) -> None:
|
||||
# Already-ISO datetime bounds must pass through as-is (exact instant).
|
||||
result = translate_range(
|
||||
"created",
|
||||
"2020-01-01T00:00:00Z",
|
||||
"2021-01-01T00:00:00Z",
|
||||
UTC,
|
||||
)
|
||||
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
|
||||
|
||||
def test_translate_query_iso_range_preserved(self) -> None:
|
||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
assert translate_query(q, UTC) == q
|
||||
|
||||
def test_translate_query_comma_separated_iso_ranges(self) -> None:
|
||||
q = (
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
)
|
||||
result = translate_query(q, UTC)
|
||||
assert result == (
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
" AND "
|
||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
)
|
||||
|
||||
def test_invalid_iso_datetime_raises(self) -> None:
|
||||
# A token with "T" that is not valid ISO datetime -> raise.
|
||||
with pytest.raises(InvalidDateQuery) as exc_info:
|
||||
translate_range(
|
||||
"created",
|
||||
"2020-01-01T99:00:00Z",
|
||||
"2021-01-01T00:00:00Z",
|
||||
UTC,
|
||||
)
|
||||
assert exc_info.value.field == "created"
|
||||
assert exc_info.value.value == "2020-01-01T99:00:00Z"
|
||||
|
||||
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
|
||||
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
translated = translate_query(q, UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
|
||||
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
|
||||
q = (
|
||||
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
|
||||
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
|
||||
)
|
||||
translated = translate_query(q, UTC)
|
||||
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
|
||||
@@ -82,6 +82,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
"llm_api_key": None,
|
||||
"llm_endpoint": None,
|
||||
"llm_output_language": None,
|
||||
"llm_request_timeout": None,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -844,7 +845,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = False
|
||||
self.client.patch(
|
||||
@@ -869,7 +870,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
@@ -890,7 +891,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
@@ -928,7 +929,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
||||
|
||||
with (
|
||||
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
|
||||
patch("paperless.views.vector_store_file_exists") as mock_exists,
|
||||
patch("paperless.views.llm_index_exists") as mock_exists,
|
||||
):
|
||||
mock_exists.return_value = True
|
||||
self.client.patch(
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import unicodedata
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import celery.result
|
||||
import pytest
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from documents.data_models import ConsumableDocument
|
||||
from documents.data_models import DocumentMetadataOverrides
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def consume_file_mock():
|
||||
with mock.patch("documents.tasks.consume_file.apply_async") as m:
|
||||
m.return_value = celery.result.AsyncResult(id="test-task-id")
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def directories(tmp_path, settings, _media_settings):
|
||||
scratch = tmp_path / "scratch"
|
||||
scratch.mkdir()
|
||||
settings.SCRATCH_DIR = scratch
|
||||
return scratch
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestPostDocumentNFCNormalization:
|
||||
def test_nfd_filename_normalized_to_nfc(
|
||||
self,
|
||||
admin_client,
|
||||
consume_file_mock: mock.MagicMock,
|
||||
directories,
|
||||
):
|
||||
"""Uploaded file with NFD filename must have its name stored as NFC."""
|
||||
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
|
||||
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
|
||||
|
||||
# Verify our test strings actually differ at the byte level
|
||||
assert nfd != nfc
|
||||
|
||||
uploaded = SimpleUploadedFile(
|
||||
nfd,
|
||||
b"%PDF-1.4 test",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
response = admin_client.post(
|
||||
"/api/documents/post_document/",
|
||||
{"document": uploaded},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
||||
input_doc: ConsumableDocument = task_kwargs["input_doc"]
|
||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
||||
|
||||
# The temp file on disk must have an NFC name
|
||||
assert input_doc.original_file.name == nfc, (
|
||||
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
|
||||
)
|
||||
# The override filename stored for later use must also be NFC
|
||||
assert overrides.filename == nfc, (
|
||||
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
|
||||
)
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
|
||||
def test_already_nfc_filename_unchanged(
|
||||
self,
|
||||
admin_client,
|
||||
consume_file_mock: mock.MagicMock,
|
||||
directories,
|
||||
):
|
||||
"""Uploaded file with already-NFC filename must pass through unchanged."""
|
||||
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
|
||||
|
||||
uploaded = SimpleUploadedFile(
|
||||
nfc,
|
||||
b"%PDF-1.4 test",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
response = admin_client.post(
|
||||
"/api/documents/post_document/",
|
||||
{"document": uploaded},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
|
||||
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
|
||||
|
||||
assert overrides.filename == nfc
|
||||
assert unicodedata.is_normalized("NFC", overrides.filename)
|
||||
@@ -725,9 +725,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
GIVEN:
|
||||
- One document added right now
|
||||
WHEN:
|
||||
- Query with invalid added date
|
||||
- Query with an invalid added date
|
||||
THEN:
|
||||
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
|
||||
- 400 Bad Request with a message naming the malformed date, so the
|
||||
user knows their date is invalid rather than silently getting zero
|
||||
results
|
||||
"""
|
||||
d1 = Document.objects.create(
|
||||
title="invoice",
|
||||
@@ -740,8 +742,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
|
||||
response = self.client.get("/api/documents/?query=added:invalid-date")
|
||||
|
||||
# Tantivy rejects unparsable field queries with a 400
|
||||
# An unparsable date is reported as a malformed query, not silently empty.
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("invalid-date", str(response.data["query"]))
|
||||
|
||||
@override_settings(
|
||||
TIME_ZONE="UTC",
|
||||
|
||||
@@ -216,6 +216,77 @@ class TestSystemStatus(APITestCase):
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
||||
|
||||
@mock.patch("celery.app.control.Inspect.ping")
|
||||
def test_system_status_celery_ping_none(self, mock_ping) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Celery ping returns no worker responses
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains a warning celery status
|
||||
"""
|
||||
mock_ping.return_value = None
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
|
||||
self.assertEqual(
|
||||
response.data["tasks"]["celery_error"],
|
||||
"No celery workers responded to ping. This may be temporary.",
|
||||
)
|
||||
|
||||
@mock.patch("celery.app.control.Inspect.ping")
|
||||
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Celery ping returns an unexpected worker response
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains a warning celery status
|
||||
"""
|
||||
self.client.force_login(self.user)
|
||||
for ping_response in (
|
||||
{"hostname": {"ok": "not-pong"}},
|
||||
{"hostname": {}},
|
||||
{"hostname": "pong"},
|
||||
):
|
||||
with self.subTest(ping_response=ping_response):
|
||||
mock_ping.return_value = ping_response
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
|
||||
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
|
||||
self.assertEqual(
|
||||
response.data["tasks"]["celery_error"],
|
||||
"Celery worker responded unexpectedly.",
|
||||
)
|
||||
|
||||
@mock.patch("documents.views.sleep")
|
||||
@mock.patch("celery.app.control.Inspect.ping")
|
||||
def test_system_status_celery_ping_retry_success(
|
||||
self,
|
||||
mock_ping,
|
||||
mock_sleep,
|
||||
) -> None:
|
||||
"""
|
||||
GIVEN:
|
||||
- Celery ping fails once but succeeds on retry
|
||||
WHEN:
|
||||
- The user requests the system status
|
||||
THEN:
|
||||
- The response contains an OK celery status
|
||||
"""
|
||||
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(self.ENDPOINT)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
|
||||
self.assertIsNone(response.data["tasks"]["celery_error"])
|
||||
self.assertEqual(mock_ping.call_count, 2)
|
||||
mock_sleep.assert_called_once_with(0.25)
|
||||
|
||||
@mock.patch("documents.search.get_backend")
|
||||
def test_system_status_index_ok(self, mock_get_backend) -> None:
|
||||
"""
|
||||
|
||||
@@ -18,6 +18,7 @@ from guardian.shortcuts import assign_perm
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from documents.filters import PaperlessTaskFilterSet
|
||||
from documents.models import PaperlessTask
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import PaperlessTaskFactory
|
||||
@@ -169,6 +170,165 @@ class TestGetTasksV10:
|
||||
PaperlessTask.Status.STARTED,
|
||||
}
|
||||
|
||||
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
|
||||
"""?name= searches task filenames, task types, and trigger sources."""
|
||||
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
|
||||
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
|
||||
source_task = PaperlessTaskFactory(
|
||||
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
|
||||
)
|
||||
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "invoice"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == filename_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "sanity"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == type_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"name": "email"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == source_task.task_id
|
||||
|
||||
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
|
||||
"""?result= searches common structured task result messages."""
|
||||
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
|
||||
error_task = PaperlessTaskFactory(
|
||||
result_data={"error_message": "Duplicate detected"},
|
||||
)
|
||||
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
|
||||
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
|
||||
PaperlessTaskFactory(result_data={"reason": "unrelated"})
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "manual"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == reason_task.task_id
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
returned_ids = {task["task_id"] for task in response.data["results"]}
|
||||
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
|
||||
|
||||
response = admin_client.get(ENDPOINT, {"result": "321"})
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data["count"] == 1
|
||||
assert response.data["results"][0]["task_id"] == document_task.task_id
|
||||
|
||||
def test_empty_task_name_and_result_filters(self) -> None:
|
||||
"""Empty name/result values leave the queryset unchanged."""
|
||||
PaperlessTaskFactory.create_batch(2)
|
||||
queryset = PaperlessTask.objects.all()
|
||||
filterset = PaperlessTaskFilterSet()
|
||||
|
||||
assert filterset.filter_name(queryset, "name", "").count() == 2
|
||||
assert filterset.filter_result(queryset, "result", "").count() == 2
|
||||
|
||||
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
|
||||
"""status_counts/ returns section counts for the filtered task queryset."""
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "invoice-a.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.REVOKED,
|
||||
input_data={"filename": "invoice-b.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.PENDING,
|
||||
input_data={"filename": "invoice-c.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.STARTED,
|
||||
input_data={"filename": "invoice-d.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-e.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=True,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-acknowledged.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "unrelated.pdf"},
|
||||
)
|
||||
|
||||
response = admin_client.get(
|
||||
f"{ENDPOINT}status_counts/",
|
||||
{"acknowledged": "false", "name": "invoice"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {
|
||||
"all": 5,
|
||||
"needs_attention": 2,
|
||||
"in_progress": 2,
|
||||
"completed": 1,
|
||||
}
|
||||
|
||||
def test_status_counts_ignores_section_filters(
|
||||
self,
|
||||
admin_client: APIClient,
|
||||
) -> None:
|
||||
"""status_counts/ ignores status-like filters for the sections it counts."""
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "invoice-a.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.PENDING,
|
||||
input_data={"filename": "invoice-b.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.SUCCESS,
|
||||
input_data={"filename": "invoice-c.pdf"},
|
||||
)
|
||||
PaperlessTaskFactory(
|
||||
acknowledged=False,
|
||||
status=PaperlessTask.Status.FAILURE,
|
||||
input_data={"filename": "unrelated.pdf"},
|
||||
)
|
||||
|
||||
response = admin_client.get(
|
||||
f"{ENDPOINT}status_counts/",
|
||||
{
|
||||
"acknowledged": "false",
|
||||
"name": "invoice",
|
||||
"status": PaperlessTask.Status.FAILURE,
|
||||
"is_complete": "false",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {
|
||||
"all": 3,
|
||||
"needs_attention": 1,
|
||||
"in_progress": 1,
|
||||
"completed": 1,
|
||||
}
|
||||
|
||||
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
|
||||
"""Tasks are returned in descending date_created order (newest first)."""
|
||||
base = timezone.now()
|
||||
@@ -522,6 +682,27 @@ class TestAcknowledge:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {"result": 2}
|
||||
|
||||
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
|
||||
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
|
||||
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
|
||||
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
|
||||
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
|
||||
|
||||
response = admin_client.post(
|
||||
ENDPOINT + "acknowledge/",
|
||||
{"all": True},
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.data == {"result": 2}
|
||||
unacknowledged_task1.refresh_from_db()
|
||||
unacknowledged_task2.refresh_from_db()
|
||||
acknowledged_task.refresh_from_db()
|
||||
assert unacknowledged_task1.acknowledged
|
||||
assert unacknowledged_task2.acknowledged
|
||||
assert acknowledged_task.acknowledged
|
||||
|
||||
def test_acknowledged_tasks_excluded_from_unacked_filter(
|
||||
self,
|
||||
admin_client: APIClient,
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pikepdf
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import TestCase
|
||||
@@ -615,6 +616,18 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
self.img_doc.archive_filename = img_doc_archive
|
||||
self.img_doc.save()
|
||||
|
||||
@staticmethod
|
||||
def mock_password_required_pdf(
|
||||
mock_open: mock.Mock,
|
||||
fake_pdf: mock.Mock,
|
||||
) -> None:
|
||||
password_context = mock.MagicMock()
|
||||
password_context.__enter__.return_value = fake_pdf
|
||||
mock_open.side_effect = [
|
||||
pikepdf.PasswordError("password required"),
|
||||
password_context,
|
||||
]
|
||||
|
||||
@mock.patch("documents.tasks.consume_file.s")
|
||||
def test_merge(self, mock_consume_file) -> None:
|
||||
"""
|
||||
@@ -1466,6 +1479,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
|
||||
fake_pdf.is_encrypted = True
|
||||
|
||||
def save_side_effect(target_path):
|
||||
Path(target_path).write_bytes(b"new pdf content")
|
||||
@@ -1480,7 +1494,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
fake_pdf.remove_unreferenced_resources.assert_called_once()
|
||||
mock_update_document.assert_not_called()
|
||||
mock_consume_delay.assert_called_once()
|
||||
@@ -1494,6 +1514,33 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
|
||||
self.assertIsNotNone(task_kwargs["overrides"])
|
||||
|
||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@mock.patch("pikepdf.open")
|
||||
def test_remove_password_update_document_skips_unencrypted_pdf(
|
||||
self,
|
||||
mock_open,
|
||||
mock_mkdtemp,
|
||||
mock_consume_delay,
|
||||
) -> None:
|
||||
doc = self.doc1
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.is_encrypted = False
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
password="secret",
|
||||
update_document=True,
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path)
|
||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
||||
fake_pdf.save.assert_not_called()
|
||||
mock_mkdtemp.assert_not_called()
|
||||
mock_consume_delay.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
|
||||
@mock.patch("documents.tasks.consume_file.apply_async")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@@ -1513,12 +1560,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
mock_mkdtemp.return_value = str(temp_dir)
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path):
|
||||
Path(target_path).write_bytes(b"new pdf content")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
@@ -1528,7 +1575,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(source_file, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(source_file),
|
||||
mock.call(source_file, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_update_document.assert_not_called()
|
||||
mock_consume_delay.assert_called_once()
|
||||
|
||||
@@ -1547,7 +1600,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
root_document=self.doc1,
|
||||
)
|
||||
fake_pdf = mock.MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[self.doc1.id],
|
||||
@@ -1557,7 +1610,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(self.doc1.source_path),
|
||||
mock.call(self.doc1.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_delay.assert_called_once()
|
||||
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@@ -1580,12 +1639,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path: Path) -> None:
|
||||
target_path.write_bytes(b"password removed")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
mock_group.return_value.delay.return_value = None
|
||||
|
||||
user = User.objects.create(username="owner")
|
||||
@@ -1600,7 +1659,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_file.assert_called_once()
|
||||
call_kwargs = mock_consume_file.call_args.kwargs
|
||||
consumable_document = call_kwargs["input_doc"]
|
||||
@@ -1618,6 +1683,43 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
mock_group.return_value.delay.assert_called_once()
|
||||
mock_chord.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.delete")
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@mock.patch("documents.bulk_edit.group")
|
||||
@mock.patch("documents.tasks.consume_file.s")
|
||||
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
|
||||
@mock.patch("pikepdf.open")
|
||||
def test_remove_password_skips_unencrypted_pdf_without_queueing(
|
||||
self,
|
||||
mock_open: mock.Mock,
|
||||
mock_mkdtemp: mock.Mock,
|
||||
mock_consume_file: mock.Mock,
|
||||
mock_group: mock.Mock,
|
||||
mock_chord: mock.Mock,
|
||||
mock_delete: mock.Mock,
|
||||
) -> None:
|
||||
doc = self.doc2
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.is_encrypted = False
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
[doc.id],
|
||||
password="secret",
|
||||
update_document=False,
|
||||
delete_original=True,
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path)
|
||||
fake_pdf.remove_unreferenced_resources.assert_not_called()
|
||||
fake_pdf.save.assert_not_called()
|
||||
mock_mkdtemp.assert_not_called()
|
||||
mock_consume_file.assert_not_called()
|
||||
mock_group.assert_not_called()
|
||||
mock_chord.assert_not_called()
|
||||
mock_delete.si.assert_not_called()
|
||||
|
||||
@mock.patch("documents.bulk_edit.delete")
|
||||
@mock.patch("documents.bulk_edit.chord")
|
||||
@mock.patch("documents.bulk_edit.group")
|
||||
@@ -1640,12 +1742,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
|
||||
fake_pdf = mock.MagicMock()
|
||||
fake_pdf.pages = [mock.Mock(), mock.Mock()]
|
||||
self.mock_password_required_pdf(mock_open, fake_pdf)
|
||||
|
||||
def save_side_effect(target_path: Path) -> None:
|
||||
target_path.write_bytes(b"password removed")
|
||||
|
||||
fake_pdf.save.side_effect = save_side_effect
|
||||
mock_open.return_value.__enter__.return_value = fake_pdf
|
||||
mock_chord.return_value.delay.return_value = None
|
||||
|
||||
result = bulk_edit.remove_password(
|
||||
@@ -1657,7 +1759,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(result, "OK")
|
||||
mock_open.assert_called_once_with(doc.source_path, password="secret")
|
||||
self.assertEqual(
|
||||
mock_open.call_args_list,
|
||||
[
|
||||
mock.call(doc.source_path),
|
||||
mock.call(doc.source_path, password="secret"),
|
||||
],
|
||||
)
|
||||
mock_consume_file.assert_called_once()
|
||||
mock_group.assert_not_called()
|
||||
mock_chord.assert_called_once()
|
||||
|
||||
@@ -24,6 +24,7 @@ from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.serialisers import DocumentSerializer
|
||||
from documents.tasks import empty_trash
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
@@ -221,8 +222,8 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
doc = Document.objects.create(
|
||||
title="document",
|
||||
mime_type="application/pdf",
|
||||
checksum=hashlib.md5(original_bytes).hexdigest(),
|
||||
archive_checksum=hashlib.md5(archive_bytes).hexdigest(),
|
||||
checksum=hashlib.sha256(original_bytes).hexdigest(),
|
||||
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
|
||||
filename="old/document.pdf",
|
||||
archive_filename="old/document.pdf",
|
||||
storage_path=old_storage_path,
|
||||
@@ -251,6 +252,46 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
|
||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
|
||||
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
|
||||
|
||||
@override_settings(FILENAME_FORMAT="{title}")
|
||||
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
|
||||
old_path = settings.ORIGINALS_DIR / "original.pdf"
|
||||
old_path.touch()
|
||||
doc = Document.objects.create(
|
||||
title="original",
|
||||
mime_type="application/pdf",
|
||||
checksum=hashlib.sha256(b"").hexdigest(),
|
||||
filename="original.pdf",
|
||||
)
|
||||
|
||||
first_instance = Document.objects.get(pk=doc.pk)
|
||||
stale_instance = Document.objects.get(pk=doc.pk)
|
||||
|
||||
serializer = DocumentSerializer(
|
||||
first_instance,
|
||||
data={"title": "first"},
|
||||
partial=True,
|
||||
)
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
serializer.save()
|
||||
|
||||
doc.refresh_from_db()
|
||||
self.assertEqual(doc.filename, "first.pdf")
|
||||
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
|
||||
|
||||
serializer = DocumentSerializer(
|
||||
stale_instance,
|
||||
data={"title": "second"},
|
||||
partial=True,
|
||||
)
|
||||
self.assertTrue(serializer.is_valid(), serializer.errors)
|
||||
serializer.save()
|
||||
|
||||
doc.refresh_from_db()
|
||||
self.assertEqual(doc.filename, "second.pdf")
|
||||
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
|
||||
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
|
||||
self.assertIsNotFile(old_path)
|
||||
|
||||
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
|
||||
def test_document_delete(self) -> None:
|
||||
document = Document()
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
|
||||
|
||||
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
|
||||
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
|
||||
All paths produced by the templating system must be NFC-normalized.
|
||||
"""
|
||||
|
||||
import unicodedata
|
||||
|
||||
import pytest
|
||||
|
||||
from documents.file_handling import generate_filename
|
||||
from documents.models import CustomField
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.tests.factories import CorrespondentFactory
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import StoragePathFactory
|
||||
from documents.tests.factories import TagFactory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestGenerateFilenameNFCNormalization:
|
||||
@pytest.mark.parametrize(
|
||||
"raw,display",
|
||||
[
|
||||
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
|
||||
(unicodedata.normalize("NFD", "Café"), "Café"),
|
||||
(unicodedata.normalize("NFD", "naïve"), "naïve"),
|
||||
],
|
||||
)
|
||||
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
|
||||
"""NFD title must produce NFC path bytes."""
|
||||
settings.FILENAME_FORMAT = "{{ title }}"
|
||||
nfc = unicodedata.normalize("NFC", display)
|
||||
assert raw != nfc # confirm byte-level difference
|
||||
|
||||
doc = DocumentFactory(title=raw, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}.pdf"
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
def test_nfd_correspondent_normalized_to_nfc(self, settings):
|
||||
"""NFD correspondent name must produce NFC path component."""
|
||||
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Müller")
|
||||
nfc = unicodedata.normalize("NFC", "Müller")
|
||||
|
||||
correspondent = CorrespondentFactory(name=nfd)
|
||||
doc = DocumentFactory(
|
||||
title="invoice",
|
||||
correspondent=correspondent,
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}/invoice.pdf"
|
||||
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
|
||||
|
||||
def test_nfd_storage_path_normalized_to_nfc(self, settings):
|
||||
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
|
||||
settings.FILENAME_FORMAT = None
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
|
||||
# StoragePath.path is used directly as the format/template string.
|
||||
# Literal NFD characters in the template must survive rendering as NFC.
|
||||
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
|
||||
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
|
||||
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
|
||||
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
|
||||
settings.FILENAME_FORMAT = "{{ document.title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Café")
|
||||
nfc = unicodedata.normalize("NFC", "Café")
|
||||
|
||||
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result) == f"{nfc}.pdf"
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestContextBuilderNFCNormalization:
|
||||
"""
|
||||
Defense-in-depth: context builder functions must NFC-normalize string inputs
|
||||
before passing them to sanitize_filename(). Task 1 already normalizes the
|
||||
final rendered path via clean_filepath(), so these tests may already pass;
|
||||
they exist as regression guards for the context-builder layer.
|
||||
"""
|
||||
|
||||
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
|
||||
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
|
||||
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
assert nfd != nfc # confirm they differ at byte level
|
||||
|
||||
tag = TagFactory(name=nfd)
|
||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
||||
doc.tags.set([tag])
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
|
||||
def test_nfd_original_name_normalized_to_nfc(self, settings):
|
||||
settings.FILENAME_FORMAT = "{{ original_name }}"
|
||||
nfd = unicodedata.normalize("NFD", "Rechnung März")
|
||||
nfc = unicodedata.normalize("NFC", "Rechnung März")
|
||||
|
||||
doc = DocumentFactory(
|
||||
original_filename=f"{nfd}.pdf",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}.pdf".encode()
|
||||
|
||||
def test_nfd_custom_field_string_value_normalized(self, settings):
|
||||
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
|
||||
settings.FILENAME_FORMAT = (
|
||||
"{{ custom_fields['Location']['value'] }}/{{ title }}"
|
||||
)
|
||||
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
|
||||
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
|
||||
assert nfd_value != nfc_value
|
||||
|
||||
doc = DocumentFactory(title="report", mime_type="application/pdf")
|
||||
cf = CustomField.objects.create(
|
||||
name="Location",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomFieldInstance.objects.create(
|
||||
document=doc,
|
||||
field=cf,
|
||||
value_text=nfd_value,
|
||||
)
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
|
||||
|
||||
def test_nfd_custom_field_name_normalized_as_key(self, settings):
|
||||
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
|
||||
nfd_name = unicodedata.normalize("NFD", "Größe")
|
||||
nfc_name = unicodedata.normalize("NFC", "Größe")
|
||||
assert nfd_name != nfc_name
|
||||
|
||||
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
|
||||
|
||||
doc = DocumentFactory(title="letter", mime_type="application/pdf")
|
||||
cf = CustomField.objects.create(
|
||||
name=nfd_name,
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomFieldInstance.objects.create(
|
||||
document=doc,
|
||||
field=cf,
|
||||
value_text="Berlin",
|
||||
)
|
||||
|
||||
result = generate_filename(doc)
|
||||
|
||||
# If field name key is NFC-normalized, the template condition succeeds
|
||||
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
|
||||
assert str(result) == "Berlin/letter.pdf"
|
||||
|
||||
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
|
||||
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
|
||||
settings.FILENAME_FORMAT = (
|
||||
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
|
||||
)
|
||||
nfd = unicodedata.normalize("NFD", "Büro")
|
||||
nfc = unicodedata.normalize("NFC", "Büro")
|
||||
assert nfd != nfc # confirm byte-level difference
|
||||
|
||||
doc = DocumentFactory(title="doc", mime_type="application/pdf")
|
||||
doc.tags.add(TagFactory(name=nfd))
|
||||
result = generate_filename(doc)
|
||||
|
||||
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
|
||||
@@ -684,6 +684,7 @@ class ConsumerThread(Thread):
|
||||
subdirs_as_tags: bool = False,
|
||||
polling_interval: float = 0,
|
||||
stability_delay: float = 0.1,
|
||||
rescan_interval: float | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.consumption_dir = consumption_dir
|
||||
@@ -693,6 +694,8 @@ class ConsumerThread(Thread):
|
||||
self.polling_interval = polling_interval
|
||||
self.stability_delay = stability_delay
|
||||
self.cmd = Command()
|
||||
if rescan_interval is not None:
|
||||
self.cmd.rescan_interval_s = rescan_interval
|
||||
self.cmd.stop_flag.clear()
|
||||
# Non-daemon ensures finally block runs and connections are closed
|
||||
self.daemon = False
|
||||
@@ -1052,3 +1055,200 @@ class TestCommandWatchEdgeCases:
|
||||
thread.stop_and_wait(timeout=5.0)
|
||||
# Clean up any Tags created by the thread
|
||||
Tag.objects.all().delete()
|
||||
|
||||
|
||||
class TestRescanExistingFiles:
|
||||
"""
|
||||
Unit tests for the rescan safety net.
|
||||
|
||||
Each ``watch()`` recreation silently adopts the current directory contents
|
||||
as its baseline, so a file appearing between one batch and the next
|
||||
watcher's baseline is never reported and would sit in the consume directory
|
||||
forever. ``_rescan_existing_files`` re-injects such files into the
|
||||
stability tracker as a periodic safety net (see GH issue #13011).
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def pdf_only_filter(self) -> ConsumerFilter:
|
||||
return ConsumerFilter(
|
||||
supported_extensions=frozenset({".pdf"}),
|
||||
ignore_patterns=[],
|
||||
)
|
||||
|
||||
def _rescan(
|
||||
self,
|
||||
directory: Path,
|
||||
consumer_filter: ConsumerFilter,
|
||||
tracker: FileStabilityTracker,
|
||||
queued: set[Path],
|
||||
*,
|
||||
recursive: bool = False,
|
||||
) -> None:
|
||||
Command()._rescan_existing_files(
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
consumer_filter=consumer_filter,
|
||||
tracker=tracker,
|
||||
queued=queued,
|
||||
)
|
||||
|
||||
def test_tracks_stranded_file(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""A supported on-disk file the watcher never reported gets tracked."""
|
||||
target = consumption_dir / "stranded.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
||||
|
||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
||||
|
||||
assert tracker.is_tracking(target) is True
|
||||
assert tracker.pending_count == 1
|
||||
|
||||
def test_skips_already_tracked_file(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""A file already being tracked by the watcher is not double-tracked."""
|
||||
target = consumption_dir / "tracked.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
||||
tracker.track(target, Change.added)
|
||||
|
||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
||||
|
||||
assert tracker.pending_count == 1
|
||||
|
||||
def test_skips_queued_file(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""A file already queued and awaiting consumption is not re-tracked."""
|
||||
target = consumption_dir / "inflight.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
||||
queued = {target.resolve()}
|
||||
|
||||
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
|
||||
|
||||
assert tracker.pending_count == 0
|
||||
|
||||
def test_prunes_vanished_queued_paths(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""Queued paths no longer on disk are dropped so the name can recur."""
|
||||
gone = (consumption_dir / "gone.pdf").resolve()
|
||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
||||
queued = {gone}
|
||||
|
||||
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
|
||||
|
||||
assert gone not in queued
|
||||
|
||||
def test_skips_unsupported_extension(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""Files filtered out by the consumer filter are not tracked."""
|
||||
(consumption_dir / "notes.xyz").write_bytes(b"content")
|
||||
tracker = FileStabilityTracker(stability_delay=0.1)
|
||||
|
||||
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
|
||||
|
||||
assert tracker.pending_count == 0
|
||||
|
||||
def test_recursive_respects_flag(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
pdf_only_filter: ConsumerFilter,
|
||||
) -> None:
|
||||
"""Nested files are only found when recursive scanning is enabled."""
|
||||
subdir = consumption_dir / "nested"
|
||||
subdir.mkdir()
|
||||
target = subdir / "deep.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
|
||||
shallow = FileStabilityTracker(stability_delay=0.1)
|
||||
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
|
||||
assert shallow.pending_count == 0
|
||||
|
||||
deep = FileStabilityTracker(stability_delay=0.1)
|
||||
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
|
||||
assert deep.is_tracking(target) is True
|
||||
|
||||
|
||||
class TestProcessExistingFilesQueued:
|
||||
"""Tests that startup processing reports which paths it queued."""
|
||||
|
||||
@pytest.mark.usefixtures("mock_supported_extensions")
|
||||
def test_returns_queued_paths(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
mock_consume_file_delay: MagicMock,
|
||||
settings: SettingsWrapper,
|
||||
) -> None:
|
||||
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
|
||||
target = consumption_dir / "document.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
settings.CONSUMER_IGNORE_PATTERNS = []
|
||||
|
||||
queued = Command()._process_existing_files(
|
||||
directory=consumption_dir,
|
||||
recursive=False,
|
||||
subdirs_as_tags=False,
|
||||
consumer_filter=ConsumerFilter(ignore_patterns=[]),
|
||||
)
|
||||
|
||||
assert target.resolve() in queued
|
||||
|
||||
|
||||
@pytest.mark.management
|
||||
@pytest.mark.django_db
|
||||
class TestCommandRescanRecovery:
|
||||
"""End-to-end test that the rescan recovers files the watcher misses."""
|
||||
|
||||
def test_rescan_consumes_file_the_watcher_never_reports(
|
||||
self,
|
||||
consumption_dir: Path,
|
||||
sample_pdf: Path,
|
||||
mock_consume_file_delay: MagicMock,
|
||||
start_consumer: Callable[..., ConsumerThread],
|
||||
) -> None:
|
||||
"""
|
||||
Isolate the rescan path: a long polling interval guarantees the
|
||||
watcher cannot report the file within the test window, so only the
|
||||
periodic rescan can consume it.
|
||||
"""
|
||||
# poll interval far longer than the test window -> watcher stays silent
|
||||
thread = start_consumer(
|
||||
polling_interval=30.0,
|
||||
stability_delay=0.1,
|
||||
rescan_interval=0.5,
|
||||
)
|
||||
|
||||
# created after startup, so _process_existing_files did not see it
|
||||
target = consumption_dir / "stranded.pdf"
|
||||
shutil.copy(sample_pdf, target)
|
||||
|
||||
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
|
||||
|
||||
if thread.exception:
|
||||
raise thread.exception
|
||||
|
||||
mock_consume_file_delay.apply_async.assert_called()
|
||||
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
|
||||
"input_doc"
|
||||
]
|
||||
assert call_args.original_file.name == "stranded.pdf"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from documents.conditionals import metadata_etag
|
||||
from documents.conditionals import preview_etag
|
||||
@@ -29,10 +31,31 @@ class TestConditionals(DirectoriesMixin, TestCase):
|
||||
)
|
||||
request = SimpleNamespace(query_params={})
|
||||
|
||||
self.assertEqual(metadata_etag(request, root.id), latest.checksum)
|
||||
self.assertEqual(
|
||||
metadata_etag(request, root.id),
|
||||
f"{latest.checksum}:{latest.modified.isoformat()}",
|
||||
)
|
||||
self.assertEqual(preview_etag(request, root.id), latest.archive_checksum)
|
||||
self.assertEqual(thumbnail_etag(request, root.id), latest.checksum)
|
||||
|
||||
def test_metadata_etag_changes_when_document_modified_changes(self) -> None:
|
||||
doc = Document.objects.create(
|
||||
title="doc",
|
||||
checksum="same-checksum",
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
request = SimpleNamespace(query_params={})
|
||||
|
||||
original_etag = metadata_etag(request, doc.id)
|
||||
new_modified = timezone.now() + timedelta(seconds=5)
|
||||
Document.objects.filter(id=doc.id).update(modified=new_modified)
|
||||
|
||||
self.assertNotEqual(metadata_etag(request, doc.id), original_etag)
|
||||
self.assertEqual(
|
||||
metadata_etag(request, doc.id),
|
||||
f"{doc.checksum}:{new_modified.isoformat()}",
|
||||
)
|
||||
|
||||
def test_resolve_effective_doc_returns_none_for_invalid_or_unrelated_version(
|
||||
self,
|
||||
) -> None:
|
||||
|
||||
@@ -30,6 +30,7 @@ from documents.signals.handlers import update_llm_suggestions_cache
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from documents.tests.utils import read_streaming_response
|
||||
from paperless.models import ApplicationConfiguration
|
||||
from paperless_ai.exceptions import LLMTimeoutError
|
||||
|
||||
|
||||
class TestViews(DirectoriesMixin, TestCase):
|
||||
@@ -476,6 +477,33 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
|
||||
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
|
||||
)
|
||||
|
||||
@patch("documents.views.get_ai_document_classification")
|
||||
@override_settings(
|
||||
AI_ENABLED=True,
|
||||
LLM_BACKEND="openai-like",
|
||||
)
|
||||
def test_ai_suggestions_with_llm_timeout(
|
||||
self,
|
||||
mock_get_ai_classification,
|
||||
) -> None:
|
||||
mock_get_ai_classification.side_effect = LLMTimeoutError()
|
||||
|
||||
self.client.force_login(user=self.user)
|
||||
response = self.client.get(
|
||||
f"/api/documents/{self.document.pk}/ai_suggestions/",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
|
||||
self.assertEqual(
|
||||
response.json(),
|
||||
{
|
||||
"ai": ["AI backend request timed out."],
|
||||
},
|
||||
)
|
||||
self.assertIsNone(
|
||||
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
|
||||
)
|
||||
|
||||
def test_invalidate_suggestions_cache(self) -> None:
|
||||
self.client.force_login(user=self.user)
|
||||
suggestions = {
|
||||
|
||||
+114
-12
@@ -12,6 +12,7 @@ from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from time import mktime
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
@@ -240,6 +241,7 @@ from paperless.serialisers import UserSerializer
|
||||
from paperless.views import StandardPagination
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
from paperless_ai.exceptions import LLMTimeoutError
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
@@ -1400,7 +1402,7 @@ class DocumentViewSet(
|
||||
)
|
||||
if request.user is not None and not has_perms_owner_aware(
|
||||
request.user,
|
||||
"view_document",
|
||||
"change_document",
|
||||
doc,
|
||||
):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
@@ -1460,7 +1462,7 @@ class DocumentViewSet(
|
||||
)
|
||||
if request.user is not None and not has_perms_owner_aware(
|
||||
request.user,
|
||||
"view_document",
|
||||
"change_document",
|
||||
doc,
|
||||
):
|
||||
return HttpResponseForbidden("Insufficient permissions")
|
||||
@@ -1509,6 +1511,17 @@ class DocumentViewSet(
|
||||
exc_info=True,
|
||||
)
|
||||
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
|
||||
except LLMTimeoutError as exc:
|
||||
logger.exception(
|
||||
"AI backend timed out while generating suggestions for document %s: %s",
|
||||
doc.pk,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return Response(
|
||||
{"ai": [_("AI backend request timed out.")]},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
matched_tags = match_tags_by_name(
|
||||
llm_suggestions.get("tags", []),
|
||||
@@ -2276,6 +2289,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
||||
return super().list(request)
|
||||
|
||||
from documents.search import SearchHit
|
||||
from documents.search import SearchQueryError
|
||||
from documents.search import TantivyBackend
|
||||
from documents.search import TantivyRelevanceList
|
||||
from documents.search import get_backend
|
||||
@@ -2468,6 +2482,11 @@ class UnifiedSearchViewSet(DocumentViewSet):
|
||||
return HttpResponseForbidden(_("Insufficient permissions."))
|
||||
except ValidationError:
|
||||
raise
|
||||
except SearchQueryError as e:
|
||||
# User-fixable query error (e.g. an unparsable date): surface the
|
||||
# specific message so the user can correct it, rather than a generic
|
||||
# 400 or silently empty results.
|
||||
raise ValidationError({"query": [str(e)]}) from e
|
||||
except Exception as e:
|
||||
logger.warning(f"An error occurred listing search results: {e!s}")
|
||||
return HttpResponseBadRequest(
|
||||
@@ -3126,6 +3145,7 @@ class PostDocumentView(GenericAPIView[Any]):
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
doc_name, doc_data = serializer.validated_data.get("document")
|
||||
doc_name = normalize("NFC", doc_name)
|
||||
correspondent_id = serializer.validated_data.get("correspondent")
|
||||
document_type_id = serializer.validated_data.get("document_type")
|
||||
storage_path_id = serializer.validated_data.get("storage_path")
|
||||
@@ -4011,7 +4031,7 @@ class RemoteVersionView(GenericAPIView[Any]):
|
||||
|
||||
|
||||
class _TasksViewSetSchema(AutoSchema):
|
||||
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
|
||||
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
|
||||
|
||||
def _get_paginator(self):
|
||||
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
|
||||
@@ -4033,7 +4053,7 @@ class _TasksViewSetSchema(AutoSchema):
|
||||
),
|
||||
acknowledge=extend_schema(
|
||||
operation_id="acknowledge_tasks",
|
||||
description="Acknowledge a list of tasks",
|
||||
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
|
||||
request=AcknowledgeTasksViewSerializer,
|
||||
responses={
|
||||
(200, "application/json"): inline_serializer(
|
||||
@@ -4071,6 +4091,19 @@ class _TasksViewSetSchema(AutoSchema):
|
||||
),
|
||||
],
|
||||
),
|
||||
status_counts=extend_schema(
|
||||
responses={
|
||||
200: inline_serializer(
|
||||
name="TaskStatusCounts",
|
||||
fields={
|
||||
"all": serializers.IntegerField(),
|
||||
"needs_attention": serializers.IntegerField(),
|
||||
"in_progress": serializers.IntegerField(),
|
||||
"completed": serializers.IntegerField(),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
active=extend_schema(
|
||||
description="Currently pending and running tasks (capped at 50).",
|
||||
responses={200: TaskSerializerV10(many=True)},
|
||||
@@ -4124,6 +4157,7 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
|
||||
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
|
||||
}
|
||||
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
|
||||
|
||||
def get_serializer_class(self):
|
||||
# v9: use backwards-compatible serializer with old field names
|
||||
@@ -4164,16 +4198,38 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
queryset = queryset.filter(task_id=task_id)
|
||||
return queryset
|
||||
|
||||
def get_status_count_queryset(self):
|
||||
"""Apply task filters except the status dimensions represented by the counts."""
|
||||
query_params = self.request.query_params.copy()
|
||||
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
|
||||
query_params.pop(param, None)
|
||||
|
||||
filterset = self.filterset_class(
|
||||
data=query_params,
|
||||
queryset=self.get_queryset(),
|
||||
request=self.request,
|
||||
)
|
||||
if not filterset.is_valid():
|
||||
raise ValidationError(filterset.errors)
|
||||
return filterset.qs
|
||||
|
||||
@action(
|
||||
methods=["post"],
|
||||
detail=False,
|
||||
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
|
||||
)
|
||||
def acknowledge(self, request):
|
||||
serializer = AcknowledgeTasksViewSerializer(data=request.data)
|
||||
queryset = self.get_queryset()
|
||||
serializer = AcknowledgeTasksViewSerializer(
|
||||
data=request.data,
|
||||
context={"queryset": queryset},
|
||||
)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
task_ids = serializer.validated_data.get("tasks")
|
||||
tasks = self.get_queryset().filter(id__in=task_ids)
|
||||
if serializer.validated_data.get("all", False):
|
||||
tasks = queryset.filter(acknowledged=False)
|
||||
else:
|
||||
task_ids = serializer.validated_data.get("tasks")
|
||||
tasks = queryset.filter(id__in=task_ids)
|
||||
count = tasks.update(acknowledged=True)
|
||||
return Response({"result": count})
|
||||
|
||||
@@ -4226,6 +4282,34 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
|
||||
serializer = TaskSummarySerializer(data, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
@action(methods=["get"], detail=False)
|
||||
def status_counts(self, request):
|
||||
"""Aggregated task counts for task UI sections."""
|
||||
queryset = self.get_status_count_queryset()
|
||||
counts = queryset.aggregate(
|
||||
all=Count("id"),
|
||||
needs_attention=Count(
|
||||
"id",
|
||||
filter=Q(
|
||||
status__in=[
|
||||
PaperlessTask.Status.FAILURE,
|
||||
PaperlessTask.Status.REVOKED,
|
||||
],
|
||||
),
|
||||
),
|
||||
in_progress=Count(
|
||||
"id",
|
||||
filter=Q(
|
||||
status__in=[
|
||||
PaperlessTask.Status.PENDING,
|
||||
PaperlessTask.Status.STARTED,
|
||||
],
|
||||
),
|
||||
),
|
||||
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
|
||||
)
|
||||
return Response(counts)
|
||||
|
||||
@action(methods=["get"], detail=False)
|
||||
def active(self, request):
|
||||
"""Currently pending and running tasks (capped at 50)."""
|
||||
@@ -4925,11 +5009,29 @@ class SystemStatusView(PassUserMixin):
|
||||
celery_error = None
|
||||
celery_url = None
|
||||
try:
|
||||
celery_ping = celery_app.control.inspect().ping()
|
||||
celery_url = next(iter(celery_ping.keys()))
|
||||
first_worker_ping = celery_ping[celery_url]
|
||||
if first_worker_ping["ok"] == "pong":
|
||||
celery_active = "OK"
|
||||
celery_ping = None
|
||||
for ping_attempt in range(3):
|
||||
celery_ping = celery_app.control.inspect().ping()
|
||||
if celery_ping:
|
||||
break
|
||||
if ping_attempt < 2:
|
||||
sleep(0.25)
|
||||
|
||||
if not celery_ping:
|
||||
celery_active = "WARNING"
|
||||
celery_error = (
|
||||
"No celery workers responded to ping. This may be temporary."
|
||||
)
|
||||
else:
|
||||
celery_url, first_worker_ping = next(iter(celery_ping.items()))
|
||||
if (
|
||||
isinstance(first_worker_ping, dict)
|
||||
and first_worker_ping.get("ok") == "pong"
|
||||
):
|
||||
celery_active = "OK"
|
||||
else:
|
||||
celery_active = "WARNING"
|
||||
celery_error = "Celery worker responded unexpectedly."
|
||||
except Exception as e:
|
||||
celery_active = "ERROR"
|
||||
logger.exception(
|
||||
|
||||
@@ -197,6 +197,7 @@ class AIConfig(BaseConfig):
|
||||
llm_embedding_endpoint: str = dataclasses.field(init=False)
|
||||
llm_embedding_chunk_size: int = dataclasses.field(init=False)
|
||||
llm_context_size: int = dataclasses.field(init=False)
|
||||
llm_request_timeout: int = dataclasses.field(init=False)
|
||||
llm_backend: str = dataclasses.field(init=False)
|
||||
llm_model: str = dataclasses.field(init=False)
|
||||
llm_api_key: str = dataclasses.field(init=False)
|
||||
@@ -221,6 +222,9 @@ class AIConfig(BaseConfig):
|
||||
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
|
||||
)
|
||||
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
|
||||
self.llm_request_timeout = (
|
||||
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
|
||||
)
|
||||
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
|
||||
self.llm_model = app_config.llm_model or settings.LLM_MODEL
|
||||
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.2.14 on 2026-06-14 14:22
|
||||
|
||||
import django.core.validators
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("paperless", "0012_applicationconfiguration_llm_output_language"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="applicationconfiguration",
|
||||
name="llm_request_timeout",
|
||||
field=models.PositiveSmallIntegerField(
|
||||
null=True,
|
||||
validators=[django.core.validators.MinValueValidator(1)],
|
||||
verbose_name="Sets the LLM request timeout in seconds",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -366,6 +366,12 @@ class ApplicationConfiguration(AbstractSingletonModel):
|
||||
max_length=32,
|
||||
)
|
||||
|
||||
llm_request_timeout = models.PositiveSmallIntegerField(
|
||||
verbose_name=_("Sets the LLM timeout in seconds"),
|
||||
null=True,
|
||||
validators=[MinValueValidator(1)],
|
||||
)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("paperless application settings")
|
||||
permissions = [
|
||||
|
||||
@@ -20,6 +20,7 @@ from PIL import Image
|
||||
from PIL import ImageDraw
|
||||
from PIL import ImageFont
|
||||
|
||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
||||
from paperless.version import __full_version_str__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -183,7 +184,7 @@ class TextDocumentParser:
|
||||
documents.parsers.ParseError
|
||||
If the file cannot be read.
|
||||
"""
|
||||
self._text = self._read_text(document_path)
|
||||
self._text = read_file_handle_unicode_errors(document_path, log=logger)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Result accessors
|
||||
@@ -295,30 +296,3 @@ class TextDocumentParser:
|
||||
Always ``[]`` — plain text files carry no structured metadata.
|
||||
"""
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _read_text(self, filepath: Path) -> str:
|
||||
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filepath:
|
||||
Path to the file to read.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File content as a string.
|
||||
"""
|
||||
try:
|
||||
return filepath.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
logger.warning(
|
||||
"Unicode error reading %s, replacing bad bytes: %s",
|
||||
filepath,
|
||||
exc,
|
||||
)
|
||||
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||
|
||||
@@ -8,6 +8,7 @@ share implementation.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
@@ -114,7 +115,7 @@ def read_file_handle_unicode_errors(
|
||||
filepath: Path,
|
||||
log: logging.Logger | None = None,
|
||||
) -> str:
|
||||
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
|
||||
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -127,15 +128,27 @@ def read_file_handle_unicode_errors(
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File content as a string, with any invalid UTF-8 sequences replaced
|
||||
by the Unicode replacement character.
|
||||
File content as a string, with NUL bytes removed so the result is
|
||||
safe to store in PostgreSQL text fields.
|
||||
"""
|
||||
_log = log or logger
|
||||
raw = filepath.read_bytes()
|
||||
|
||||
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
|
||||
encoding = "utf-16"
|
||||
elif raw.startswith(codecs.BOM_UTF8):
|
||||
encoding = "utf-8-sig"
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
try:
|
||||
return filepath.read_text(encoding="utf-8")
|
||||
text = raw.decode(encoding)
|
||||
except UnicodeDecodeError as e:
|
||||
_log.warning("Unicode error during text reading, continuing: %s", e)
|
||||
return filepath.read_bytes().decode("utf-8", errors="replace")
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
|
||||
# PostgreSQL rejects NUL (0x00) bytes in text fields
|
||||
return text.replace("\x00", "")
|
||||
|
||||
|
||||
def get_page_count_for_pdf(
|
||||
|
||||
@@ -97,8 +97,14 @@ MODEL_FILE = get_path_from_env(
|
||||
DATA_DIR / "classification_model.pickle",
|
||||
)
|
||||
LLM_INDEX_DIR = DATA_DIR / "llm_index"
|
||||
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
|
||||
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
|
||||
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
|
||||
# Cross-process read/write lock guarding the LLM index compaction/migration
|
||||
# file swap. Readers hold it shared; the swap takes it exclusively so it never
|
||||
# runs while a reader connection is open. Must be a SQLite (.db) file.
|
||||
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
|
||||
# Seconds the compaction swap waits for active readers to drain before skipping
|
||||
# this cycle (it is a maintenance operation; the next run retries).
|
||||
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
|
||||
|
||||
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
|
||||
|
||||
@@ -644,6 +650,7 @@ LOGGING = {
|
||||
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
|
||||
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
|
||||
"httpx": {"level": "WARNING"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1199,6 +1206,9 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
|
||||
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
|
||||
if LLM_CONTEXT_SIZE < 1:
|
||||
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
|
||||
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
|
||||
if LLM_REQUEST_TIMEOUT < 1:
|
||||
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
|
||||
LLM_BACKEND = get_choice_from_env(
|
||||
"PAPERLESS_AI_LLM_BACKEND",
|
||||
{"ollama", "openai-like"},
|
||||
|
||||
@@ -252,6 +252,9 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
|
||||
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
|
||||
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
|
||||
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
|
||||
# Validate pooled connections so a connection closed server-side
|
||||
# is replaced rather than handed out as "the connection is closed".
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
}
|
||||
|
||||
base_options = {
|
||||
|
||||
@@ -398,6 +398,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "localhost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
@@ -426,6 +427,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "paperless-db-host",
|
||||
"PORT": 1111,
|
||||
"NAME": "customdb",
|
||||
@@ -455,6 +457,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "pghost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
@@ -485,6 +488,7 @@ class TestParseDbSettings:
|
||||
{
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"HOST": "pghost",
|
||||
"NAME": "paperless",
|
||||
"USER": "paperless",
|
||||
|
||||
@@ -2,13 +2,50 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
from pathlib import Path
|
||||
|
||||
from paperless.parsers.utils import is_tagged_pdf
|
||||
from paperless.parsers.utils import read_file_handle_unicode_errors
|
||||
|
||||
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
|
||||
|
||||
|
||||
class TestReadFileHandleUnicodeErrors:
|
||||
def test_plain_utf8(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "plain.txt"
|
||||
f.write_bytes(b"hello world")
|
||||
assert read_file_handle_unicode_errors(f) == "hello world"
|
||||
|
||||
def test_utf8_bom(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "bom.txt"
|
||||
f.write_bytes(codecs.BOM_UTF8 + b"hello")
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_utf16_le(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "utf16le.txt"
|
||||
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_utf16_be(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "utf16be.txt"
|
||||
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
|
||||
assert read_file_handle_unicode_errors(f) == "hello"
|
||||
|
||||
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "null-bytes.txt"
|
||||
f.write_bytes(b"foo\x00bar")
|
||||
assert read_file_handle_unicode_errors(f) == "foobar"
|
||||
|
||||
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
|
||||
f = tmp_path / "bad.txt"
|
||||
f.write_bytes(b"ok\x80\x81bad")
|
||||
result = read_file_handle_unicode_errors(f)
|
||||
assert "ok" in result
|
||||
assert "bad" in result
|
||||
assert "\x00" not in result
|
||||
|
||||
|
||||
class TestIsTaggedPdf:
|
||||
def test_tagged_pdf_returns_true(self) -> None:
|
||||
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
|
||||
|
||||
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import PaperlessAuthTokenSerializer
|
||||
from paperless.serialisers import ProfileSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
from paperless_ai.indexing import vector_store_file_exists
|
||||
from paperless_ai.indexing import llm_index_exists
|
||||
|
||||
|
||||
class PaperlessObtainAuthTokenView(ObtainAuthToken):
|
||||
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
|
||||
or old_llm_context_size != new_llm_context_size
|
||||
)
|
||||
rebuild_needed = new_ai_index_enabled and (
|
||||
not vector_store_file_exists() or embedding_config_changed
|
||||
not llm_index_exists() or embedding_config_changed
|
||||
)
|
||||
|
||||
if rebuild_needed:
|
||||
|
||||
@@ -8,6 +8,7 @@ from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
|
||||
@@ -24,9 +25,14 @@ def get_language_name(language_code: str) -> str:
|
||||
|
||||
def build_prompt_without_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
content = truncate_content(
|
||||
document.content[:4000] or "",
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
@@ -49,10 +55,15 @@ def build_prompt_without_rag(
|
||||
|
||||
def build_prompt_with_rag(
|
||||
document: Document,
|
||||
config: AIConfig,
|
||||
user: User | None = None,
|
||||
) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
base_prompt = build_prompt_without_rag(document, config)
|
||||
context = truncate_content(
|
||||
get_context_for_document(document, user),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
|
||||
return f"""{base_prompt}
|
||||
|
||||
@@ -130,26 +141,29 @@ def get_ai_document_classification(
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, user)
|
||||
build_prompt_with_rag(document, ai_config, user)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document)
|
||||
else build_prompt_without_rag(document, ai_config)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query(prompt)
|
||||
suggestions = parse_ai_response(result)
|
||||
if output_language:
|
||||
localized = client.run_llm_query(
|
||||
build_localization_prompt(suggestions, output_language),
|
||||
)
|
||||
localized_suggestions = parse_ai_response(localized)
|
||||
suggestions = {
|
||||
**suggestions,
|
||||
"title": localized_suggestions["title"] or suggestions["title"],
|
||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||
"document_types": localized_suggestions["document_types"]
|
||||
or suggestions["document_types"],
|
||||
"storage_paths": localized_suggestions["storage_paths"]
|
||||
or suggestions["storage_paths"],
|
||||
}
|
||||
# Hand the pooled DB connection back while the (slow) LLM query runs so it
|
||||
# is not pinned for the call's duration; see paperless_ai.db and #12976.
|
||||
with db_connection_released():
|
||||
result = client.run_llm_query(prompt)
|
||||
suggestions = parse_ai_response(result)
|
||||
if output_language:
|
||||
localized = client.run_llm_query(
|
||||
build_localization_prompt(suggestions, output_language),
|
||||
)
|
||||
localized_suggestions = parse_ai_response(localized)
|
||||
suggestions = {
|
||||
**suggestions,
|
||||
"title": localized_suggestions["title"] or suggestions["title"],
|
||||
"tags": localized_suggestions["tags"] or suggestions["tags"],
|
||||
"document_types": localized_suggestions["document_types"]
|
||||
or suggestions["document_types"],
|
||||
"storage_paths": localized_suggestions["storage_paths"]
|
||||
or suggestions["storage_paths"],
|
||||
}
|
||||
return suggestions
|
||||
|
||||
+57
-123
@@ -3,9 +3,13 @@ import logging
|
||||
import sys
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.indexing import _document_id_filters
|
||||
from paperless_ai.indexing import get_rag_prompt_helper
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
from paperless_ai.indexing import read_store
|
||||
|
||||
logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
@@ -75,148 +79,78 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
|
||||
)
|
||||
|
||||
|
||||
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from llama_index.core.vector_stores import VectorStoreQuery
|
||||
|
||||
class DocumentFilteredFaissRetriever(BaseRetriever):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._cached_query_str = None
|
||||
self._cached_nodes = []
|
||||
|
||||
def _retrieve(self, query_bundle):
|
||||
if query_bundle.query_str == self._cached_query_str:
|
||||
return self._cached_nodes
|
||||
|
||||
if query_bundle.embedding is None:
|
||||
query_bundle.embedding = (
|
||||
index._embed_model.get_agg_embedding_from_queries(
|
||||
query_bundle.embedding_strs,
|
||||
)
|
||||
)
|
||||
|
||||
faiss_index = index.vector_store._faiss_index
|
||||
max_top_k = faiss_index.ntotal
|
||||
if max_top_k == 0:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = []
|
||||
return []
|
||||
|
||||
query_top_k = min(max(similarity_top_k, 1), max_top_k)
|
||||
allowed_nodes: list[NodeWithScore] = []
|
||||
seen_node_ids: set[str] = set()
|
||||
|
||||
while query_top_k <= max_top_k:
|
||||
query_result = index.vector_store.query(
|
||||
VectorStoreQuery(
|
||||
query_embedding=query_bundle.embedding,
|
||||
similarity_top_k=query_top_k,
|
||||
),
|
||||
)
|
||||
|
||||
for vector_id, score in zip(
|
||||
query_result.ids or [],
|
||||
query_result.similarities or [],
|
||||
strict=False,
|
||||
):
|
||||
node_id = index.index_struct.nodes_dict.get(vector_id)
|
||||
if node_id is None or node_id in seen_node_ids:
|
||||
continue
|
||||
|
||||
node = index.docstore.docs.get(node_id)
|
||||
if node is None or node.metadata.get("document_id") not in doc_ids:
|
||||
continue
|
||||
|
||||
seen_node_ids.add(node_id)
|
||||
allowed_nodes.append(NodeWithScore(node=node, score=score))
|
||||
|
||||
if len(allowed_nodes) >= similarity_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
if query_top_k == max_top_k:
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
query_top_k = min(query_top_k * 2, max_top_k)
|
||||
|
||||
self._cached_query_str = query_bundle.query_str
|
||||
self._cached_nodes = allowed_nodes
|
||||
return allowed_nodes
|
||||
|
||||
return DocumentFilteredFaissRetriever()
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
try:
|
||||
yield from _stream_chat_with_documents(query_str, documents)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to stream document chat response: {e}", exc_info=True)
|
||||
logger.exception("Failed to stream document chat response: %s", e)
|
||||
yield CHAT_ERROR_MESSAGE
|
||||
|
||||
|
||||
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [str(doc.pk) for doc in documents]
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
if not documents:
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import get_response_synthesizer
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
index,
|
||||
set(doc_ids),
|
||||
CHAT_RETRIEVER_TOP_K,
|
||||
)
|
||||
config = AIConfig()
|
||||
filters = _document_id_filters(str(doc.pk) for doc in documents)
|
||||
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if len(top_nodes) == 0:
|
||||
logger.warning("Retriever returned no nodes for the given documents.")
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
# Hold the shared read lock for the whole operation: the query engine
|
||||
# retrieves from the vector store again during synthesis, so the connection
|
||||
# must stay open (and the swap must not run) until the stream finishes.
|
||||
with read_store() as store:
|
||||
index = load_or_build_index(config, store)
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=CHAT_RETRIEVER_TOP_K,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
# Slow query-embedding + vector search; no Django ORM access happens
|
||||
# during it, so release the pooled DB connection for its duration. See
|
||||
# #12976.
|
||||
with db_connection_released():
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if not top_nodes:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
yield CHAT_NO_CONTENT_MESSAGE
|
||||
return
|
||||
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
client = AIClient()
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
response_synthesizer=response_synthesizer,
|
||||
streaming=True,
|
||||
)
|
||||
references = _get_document_references(documents, top_nodes)
|
||||
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
|
||||
response_synthesizer = get_response_synthesizer(
|
||||
llm=client.llm,
|
||||
prompt_helper=get_rag_prompt_helper(
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
),
|
||||
text_qa_template=prompt_template,
|
||||
streaming=True,
|
||||
)
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
response_synthesizer=response_synthesizer,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
response_stream = query_engine.query(query_str)
|
||||
logger.debug("Document chat query: %s", query_str)
|
||||
# Release the pooled DB connection for the slow streaming LLM response
|
||||
# so it is not pinned for the whole stream; see paperless_ai.db and
|
||||
# #12976.
|
||||
with db_connection_released():
|
||||
response_stream = query_engine.query(query_str)
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
if references:
|
||||
yield _format_chat_metadata_trailer(references)
|
||||
if references:
|
||||
yield _format_chat_metadata_trailer(references)
|
||||
|
||||
+49
-28
@@ -1,11 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from paperless.models import LLMBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai_like import OpenAILike
|
||||
|
||||
@@ -16,6 +19,7 @@ from paperless.network import create_pinned_async_httpx_client
|
||||
from paperless.network import create_pinned_httpx_client
|
||||
from paperless.network import validate_outbound_http_url
|
||||
from paperless_ai.base_model import DocumentClassifierSchema
|
||||
from paperless_ai.exceptions import LLMTimeoutError
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
|
||||
@@ -61,16 +65,16 @@ class AIClient:
|
||||
model=self.settings.llm_model or "llama3.1",
|
||||
base_url=endpoint,
|
||||
context_window=self.settings.llm_context_size,
|
||||
request_timeout=120,
|
||||
request_timeout=self.settings.llm_request_timeout,
|
||||
system_prompt=LLM_SYSTEM_PROMPT,
|
||||
client=Client(
|
||||
host=endpoint,
|
||||
timeout=120,
|
||||
timeout=self.settings.llm_request_timeout,
|
||||
transport=transport,
|
||||
),
|
||||
async_client=AsyncClient(
|
||||
host=endpoint,
|
||||
timeout=120,
|
||||
timeout=self.settings.llm_request_timeout,
|
||||
transport=async_transport,
|
||||
),
|
||||
)
|
||||
@@ -84,15 +88,18 @@ class AIClient:
|
||||
http_client = create_pinned_httpx_client(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
timeout=self.settings.llm_request_timeout,
|
||||
)
|
||||
async_http_client = create_pinned_async_httpx_client(
|
||||
endpoint,
|
||||
allow_internal=self.settings.llm_allow_internal_endpoints,
|
||||
timeout=self.settings.llm_request_timeout,
|
||||
)
|
||||
return OpenAILike(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_base=endpoint,
|
||||
api_key=self.settings.llm_api_key,
|
||||
timeout=self.settings.llm_request_timeout,
|
||||
is_chat_model=True,
|
||||
is_function_calling_model=True,
|
||||
system_prompt=LLM_SYSTEM_PROMPT,
|
||||
@@ -113,11 +120,12 @@ class AIClient:
|
||||
|
||||
user_msg = ChatMessage(role="user", content=prompt)
|
||||
if self.settings.llm_backend == LLMBackend.OLLAMA:
|
||||
result = self.llm.chat(
|
||||
[user_msg],
|
||||
format=DocumentClassifierSchema.model_json_schema(),
|
||||
think=False,
|
||||
)
|
||||
with self._normalize_timeouts():
|
||||
result = self.llm.chat(
|
||||
[user_msg],
|
||||
format=DocumentClassifierSchema.model_json_schema(),
|
||||
think=False,
|
||||
)
|
||||
logger.debug("LLM query result: %s", result)
|
||||
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
|
||||
return parsed.model_dump()
|
||||
@@ -125,26 +133,39 @@ class AIClient:
|
||||
from llama_index.core.program.function_program import get_function_tool
|
||||
|
||||
tool = get_function_tool(DocumentClassifierSchema)
|
||||
result = self.llm.chat_with_tools(
|
||||
tools=[tool],
|
||||
user_msg=user_msg,
|
||||
chat_history=[],
|
||||
allow_parallel_tool_calls=True,
|
||||
)
|
||||
tool_calls = self.llm.get_tool_calls_from_response(
|
||||
result,
|
||||
error_on_no_tool_call=True,
|
||||
)
|
||||
with self._normalize_timeouts():
|
||||
result = self.llm.chat_with_tools(
|
||||
tools=[tool],
|
||||
user_msg=user_msg,
|
||||
chat_history=[],
|
||||
allow_parallel_tool_calls=True,
|
||||
tool_required=True,
|
||||
)
|
||||
tool_calls = self.llm.get_tool_calls_from_response(
|
||||
result,
|
||||
error_on_no_tool_call=True,
|
||||
)
|
||||
logger.debug("LLM query result: %s", tool_calls)
|
||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||
return parsed.model_dump()
|
||||
|
||||
def run_chat(self, messages: list["ChatMessage"]) -> str:
|
||||
logger.debug(
|
||||
"Running chat query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
result = self.llm.chat(messages)
|
||||
logger.debug("Chat result: %s", result)
|
||||
return result
|
||||
@contextmanager
|
||||
def _normalize_timeouts(self) -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except httpx.TimeoutException as exc:
|
||||
raise LLMTimeoutError from exc
|
||||
except Exception as exc:
|
||||
if self._is_openai_timeout(exc):
|
||||
raise LLMTimeoutError from exc
|
||||
raise
|
||||
|
||||
def _is_openai_timeout(self, exc: Exception) -> bool:
|
||||
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
|
||||
return False
|
||||
|
||||
# Keep OpenAI imports out of module import paths and only load the SDK
|
||||
# when translating an error from an OpenAI-backed request.
|
||||
from openai import APITimeoutError
|
||||
|
||||
return isinstance(exc, APITimeoutError)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.db import connections
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_connection_released():
|
||||
"""
|
||||
Return any checked-out DB connections to the pool for the duration of the
|
||||
wrapped block.
|
||||
|
||||
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
|
||||
or a streaming response (``chat``). Django keeps the request's database
|
||||
connection checked out for the entire request/response, so a blocking LLM
|
||||
call - which can take many seconds - pins a pooled connection the whole
|
||||
time. With connection pooling enabled, enough concurrent AI requests check
|
||||
out every slot and all other requests then fail with
|
||||
``psycopg_pool.PoolTimeout`` (see issue #12976).
|
||||
|
||||
No Django ORM access happens during the LLM call, so we hand the connection
|
||||
back to the pool first; Django transparently re-checks-out a connection on
|
||||
the next ORM use after the block.
|
||||
"""
|
||||
connections.close_all()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
connections.close_all()
|
||||
@@ -1,12 +1,9 @@
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
@@ -23,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
|
||||
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
|
||||
|
||||
|
||||
def get_embedding_model() -> "BaseEmbedding":
|
||||
config = AIConfig()
|
||||
|
||||
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI_LIKE:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
@@ -37,15 +32,18 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
http_client = create_pinned_httpx_client(
|
||||
endpoint,
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
timeout=config.llm_request_timeout,
|
||||
)
|
||||
async_http_client = create_pinned_async_httpx_client(
|
||||
endpoint,
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
timeout=config.llm_request_timeout,
|
||||
)
|
||||
return OpenAILikeEmbedding(
|
||||
model_name=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
api_base=endpoint,
|
||||
timeout=config.llm_request_timeout,
|
||||
http_client=http_client,
|
||||
async_http_client=async_http_client,
|
||||
)
|
||||
@@ -78,12 +76,14 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
)
|
||||
embedding._client = Client(
|
||||
host=endpoint,
|
||||
timeout=config.llm_request_timeout,
|
||||
transport=PinnedHostHTTPTransport(
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
),
|
||||
)
|
||||
embedding._async_client = AsyncClient(
|
||||
host=endpoint,
|
||||
timeout=config.llm_request_timeout,
|
||||
transport=PinnedHostAsyncHTTPTransport(
|
||||
allow_internal=config.llm_allow_internal_endpoints,
|
||||
),
|
||||
@@ -95,41 +95,24 @@ def get_embedding_model() -> "BaseEmbedding":
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_dim() -> int:
|
||||
"""
|
||||
Loads embedding dimension from meta.json if available, otherwise infers it
|
||||
from a dummy embedding and stores it for future use.
|
||||
"""
|
||||
config = AIConfig()
|
||||
default_model = {
|
||||
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
||||
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
||||
}.get(
|
||||
config.llm_embedding_backend,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
_DEFAULT_MODEL_NAMES = {
|
||||
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
|
||||
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
|
||||
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
|
||||
}
|
||||
|
||||
|
||||
def get_configured_model_name(config: AIConfig) -> str:
|
||||
"""Return the canonical name of the currently configured embedding model."""
|
||||
# dict.get(key, default) overload resolution fails for TextChoices keys in some
|
||||
# type checkers; use `or` fallback to avoid the ambiguity.
|
||||
default = (
|
||||
_DEFAULT_MODEL_NAMES.get(
|
||||
config.llm_embedding_backend,
|
||||
)
|
||||
or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
model = config.llm_embedding_model or default_model
|
||||
|
||||
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
|
||||
if meta_path.exists():
|
||||
with meta_path.open() as f:
|
||||
meta = json.load(f)
|
||||
if meta.get("embedding_model") != model:
|
||||
raise RuntimeError(
|
||||
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
|
||||
"You must rebuild the index.",
|
||||
)
|
||||
return meta["dim"]
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
test_embed = embedding_model.get_text_embedding("test")
|
||||
dim = len(test_embed)
|
||||
|
||||
with meta_path.open("w") as f:
|
||||
json.dump({"embedding_model": model, "dim": dim}, f)
|
||||
|
||||
return dim
|
||||
return config.llm_embedding_model or default
|
||||
|
||||
|
||||
def _normalize_llm_index_text(text: str) -> str:
|
||||
@@ -138,17 +121,11 @@ def _normalize_llm_index_text(text: str) -> str:
|
||||
|
||||
|
||||
def build_llm_index_text(doc: Document) -> str:
|
||||
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
|
||||
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
|
||||
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
|
||||
# text, Custom Fields are dynamic in count and best kept in the embedding.
|
||||
lines = [
|
||||
f"Title: {doc.title}",
|
||||
f"Filename: {doc.filename}",
|
||||
f"Created: {doc.created}",
|
||||
f"Added: {doc.added}",
|
||||
f"Modified: {doc.modified}",
|
||||
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
||||
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
||||
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
||||
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
||||
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
||||
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
class LLMTimeoutError(Exception):
|
||||
pass
|
||||
+280
-243
@@ -1,28 +1,30 @@
|
||||
import logging
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from filelock import FileLock
|
||||
from filelock import ReadWriteLock
|
||||
from filelock import Timeout
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
from documents.utils import IterWrapper
|
||||
from documents.utils import identity
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.db import db_connection_released
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_configured_model_name
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
|
||||
logger = logging.getLogger("paperless_ai.indexing")
|
||||
|
||||
@@ -30,21 +32,11 @@ RAG_NUM_OUTPUT = 512
|
||||
RAG_CHUNK_OVERLAP = 200
|
||||
|
||||
|
||||
def _index_lock_path() -> Path:
|
||||
"""Return the path used as the file lock for FAISS index mutations.
|
||||
|
||||
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
|
||||
rebuild — which calls shutil.rmtree(LLM_INDEX_DIR) — cannot delete the lock
|
||||
while another worker still holds it.
|
||||
"""
|
||||
return settings.LLM_INDEX_LOCK
|
||||
|
||||
|
||||
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
|
||||
# concurrent workers can both observe no running task and both enqueue a
|
||||
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
|
||||
# is itself protected by _index_lock_path(), so only one rebuild runs at a
|
||||
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
|
||||
# time and the second one is serialised after the first completes.
|
||||
from documents.tasks import llmindex_index
|
||||
|
||||
@@ -71,46 +63,110 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
import faiss
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
else:
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
return StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
vector_store=vector_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return PaperlessSqliteVecVectorStore(
|
||||
uri=str(settings.LLM_INDEX_DIR),
|
||||
)
|
||||
|
||||
|
||||
# --- LLM index locking ---------------------------------------------------
|
||||
#
|
||||
# Two locks guard the index; they answer different questions and are NOT
|
||||
# interchangeable:
|
||||
#
|
||||
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
|
||||
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
|
||||
# Taken by write_store(). Readers never take it, so it never blocks reads.
|
||||
#
|
||||
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
|
||||
# compaction/migration file swap. read_store() takes it SHARED (readers run
|
||||
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
|
||||
# the database file is never replaced while a reader connection is open (that
|
||||
# would alias the old WAL onto the new file and corrupt it).
|
||||
#
|
||||
# | vs another writer | vs a reader
|
||||
# -----------------+-------------------+----------------------------
|
||||
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
|
||||
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
|
||||
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
|
||||
#
|
||||
# They can't be merged into one ReadWriteLock: a normal write must exclude other
|
||||
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
|
||||
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
|
||||
# the swap needs to exclude readers.
|
||||
def _index_rwlock() -> ReadWriteLock:
|
||||
"""Return a fresh read/write lock instance for the index swap.
|
||||
|
||||
``is_singleton=False`` so reads and the swap always coordinate through
|
||||
SQLite (the actual cross-process case) rather than hitting the in-process
|
||||
reentrant-upgrade guard; callers must ``close()`` it (the context managers
|
||||
below do).
|
||||
"""
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def read_store():
|
||||
"""Acquire the shared read lock and yield the vector store for a read.
|
||||
|
||||
The shared lock is held for the whole lifetime of the connection (and
|
||||
closed on exit) so the compaction/migration swap, which takes the exclusive
|
||||
lock, never runs while this connection is open. Concurrent readers do not
|
||||
block each other; only the swap does.
|
||||
"""
|
||||
lock = _index_rwlock()
|
||||
try:
|
||||
with lock.read_lock(), get_vector_store() as store:
|
||||
yield store
|
||||
finally:
|
||||
lock.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _exclude_readers():
|
||||
"""Acquire exclusive index access, blocking until readers have drained.
|
||||
|
||||
The exclusive counterpart to ``read_store()``: a compaction or migration
|
||||
must not run while any reader connection is open. Raises
|
||||
:class:`filelock.Timeout` if active readers do not drain within
|
||||
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
|
||||
"""
|
||||
lock = _index_rwlock()
|
||||
try:
|
||||
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
|
||||
yield
|
||||
finally:
|
||||
lock.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def write_store(embed_model_name: str | None = None):
|
||||
"""Acquire the write lock and yield the vector store.
|
||||
|
||||
All mutating operations (upsert, delete, rebuild, compact) must go through
|
||||
this context manager to serialise concurrent Celery writers.
|
||||
Read paths use ``read_store()`` so they hold the shared read lock.
|
||||
|
||||
Pass ``embed_model_name`` whenever the operation may create the table so
|
||||
the model name is recorded in the schema metadata for future mismatch checks.
|
||||
"""
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
with (
|
||||
FileLock(settings.LLM_INDEX_LOCK),
|
||||
PaperlessSqliteVecVectorStore(
|
||||
uri=str(settings.LLM_INDEX_DIR),
|
||||
embed_model_name=embed_model_name,
|
||||
) as store,
|
||||
):
|
||||
yield store
|
||||
|
||||
|
||||
def build_document_node(
|
||||
document: Document,
|
||||
*,
|
||||
@@ -130,6 +186,9 @@ def build_document_node(
|
||||
"document_type": document.document_type.name
|
||||
if document.document_type
|
||||
else None,
|
||||
"filename": document.filename,
|
||||
"storage_path": document.storage_path.name if document.storage_path else None,
|
||||
"archive_serial_number": document.archive_serial_number,
|
||||
"created": document.created.isoformat() if document.created else None,
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
"modified": document.modified.isoformat(),
|
||||
@@ -142,9 +201,11 @@ def build_document_node(
|
||||
# the token count and exceed embedding models with small context windows
|
||||
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||
doc = LlamaDocument(
|
||||
id_=str(document.id),
|
||||
text=text,
|
||||
metadata=metadata,
|
||||
excluded_embed_metadata_keys=list(metadata.keys()),
|
||||
excluded_llm_metadata_keys=["document_id"],
|
||||
)
|
||||
chunk_size = chunk_size or get_rag_chunk_size()
|
||||
parser = SimpleNodeParser(
|
||||
@@ -154,76 +215,33 @@ def build_document_node(
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(nodes=None):
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
|
||||
"""Return a VectorStoreIndex backed by ``store``.
|
||||
|
||||
``store`` is supplied by the caller's ``read_store()`` context so the shared
|
||||
read lock and the connection stay alive for the whole retrieval.
|
||||
"""
|
||||
import llama_index.core.settings as llama_settings
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
embed_model = get_embedding_model(config)
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
try:
|
||||
return load_index_from_storage(storage_context=storage_context)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to load index from storage: %s", e)
|
||||
if not nodes:
|
||||
queue_llm_index_update_if_needed(
|
||||
rebuild=vector_store_file_exists(),
|
||||
reason="LLM index missing or invalid while loading.",
|
||||
)
|
||||
logger.info("No nodes provided for index creation.")
|
||||
raise
|
||||
return VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return VectorStoreIndex.from_vector_store(
|
||||
vector_store=store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
|
||||
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
|
||||
"""
|
||||
Removes existing documents from docstore for a given document from the index.
|
||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||
"""
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = [
|
||||
node.node_id
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
if node.metadata.get("document_id") == str(document.id)
|
||||
]
|
||||
for node_id in existing_nodes:
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node_id)
|
||||
# Also purge the FAISS position -> UUID mapping so subsequent similarity
|
||||
# queries don't raise KeyError on ghost vector positions.
|
||||
stale_keys = [
|
||||
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
|
||||
]
|
||||
for key in stale_keys:
|
||||
del index.index_struct.nodes_dict[key]
|
||||
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
|
||||
index.storage_context.index_store.add_index_struct(index.index_struct)
|
||||
|
||||
|
||||
def vector_store_file_exists():
|
||||
"""
|
||||
Check if the vector store file exists in the LLM index directory.
|
||||
"""
|
||||
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||
def llm_index_exists() -> bool:
|
||||
"""True when the index table exists on disk."""
|
||||
with read_store() as store:
|
||||
return store.table_exists()
|
||||
|
||||
|
||||
def get_rag_chunk_size() -> int:
|
||||
return AIConfig().llm_embedding_chunk_size
|
||||
|
||||
|
||||
def get_rag_context_size() -> int:
|
||||
return AIConfig().llm_context_size
|
||||
|
||||
|
||||
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
|
||||
chunk_size = chunk_size or get_rag_chunk_size()
|
||||
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
|
||||
@@ -249,123 +267,149 @@ def get_rag_prompt_helper(
|
||||
)
|
||||
|
||||
|
||||
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
|
||||
"""Embed ``nodes`` in place using ``embed_model``."""
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
|
||||
for node, emb in zip(
|
||||
nodes,
|
||||
embed_model.get_text_embedding_batch(texts),
|
||||
strict=True,
|
||||
):
|
||||
node.embedding = emb
|
||||
|
||||
|
||||
def _document_id_filters(doc_ids):
|
||||
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
|
||||
from llama_index.core.vector_stores.types import FilterOperator
|
||||
from llama_index.core.vector_stores.types import MetadataFilter
|
||||
from llama_index.core.vector_stores.types import MetadataFilters
|
||||
|
||||
return MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="document_id",
|
||||
operator=FilterOperator.IN,
|
||||
value=sorted(doc_ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def update_llm_index(
|
||||
*,
|
||||
iter_wrapper: IterWrapper[Document] = identity,
|
||||
rebuild=False,
|
||||
) -> str:
|
||||
"""
|
||||
Rebuild or update the LLM index.
|
||||
"""
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
nodes = []
|
||||
|
||||
"""Rebuild or incrementally update the LLM index."""
|
||||
with write_store() as store:
|
||||
try:
|
||||
with _exclude_readers():
|
||||
needs_reembed = store.check_and_run_migrations()
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index migration check: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
needs_reembed = False
|
||||
if needs_reembed:
|
||||
logger.warning(
|
||||
"LLM index migration requires re-embedding; forcing rebuild.",
|
||||
)
|
||||
rebuild = True
|
||||
documents = Document.objects.all()
|
||||
if not documents.exists():
|
||||
no_documents = not documents.exists()
|
||||
|
||||
# Fast exit before touching config: nothing to index and no existing index.
|
||||
if no_documents and not rebuild and not llm_index_exists():
|
||||
logger.warning("No documents found to index.")
|
||||
if not rebuild and not vector_store_file_exists():
|
||||
return "No documents found to index."
|
||||
return "No documents found to index."
|
||||
|
||||
config = AIConfig()
|
||||
model_name = get_configured_model_name(config)
|
||||
|
||||
if not rebuild and llm_index_exists():
|
||||
with read_store() as store:
|
||||
config_mismatch = store.config_mismatch(model_name)
|
||||
if config_mismatch:
|
||||
logger.warning("Embedding model changed; forcing LLM index rebuild.")
|
||||
rebuild = True
|
||||
|
||||
if no_documents:
|
||||
logger.warning("No documents found to index.")
|
||||
|
||||
chunk_size = config.llm_embedding_chunk_size
|
||||
embed_model = get_embedding_model(config)
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
if rebuild or not vector_store_file_exists():
|
||||
# remove meta.json to force re-detection of embedding dim
|
||||
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||
# Rebuild index from scratch
|
||||
with write_store(embed_model_name=model_name) as store:
|
||||
if rebuild or not store.table_exists():
|
||||
logger.info("Rebuilding LLM index.")
|
||||
import llama_index.core.settings as llama_settings
|
||||
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=True)
|
||||
store.drop_table()
|
||||
for document in iter_wrapper(documents):
|
||||
document_nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
nodes.extend(document_nodes)
|
||||
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
show_progress=False,
|
||||
)
|
||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
_embed_nodes(nodes, embed_model)
|
||||
store.add(nodes)
|
||||
msg = "LLM index rebuilt successfully."
|
||||
else:
|
||||
# Update existing index
|
||||
index = load_or_build_index()
|
||||
existing_nodes: defaultdict[str, list] = defaultdict(list)
|
||||
for node in index.docstore.docs.values():
|
||||
doc_id = node.metadata.get("document_id")
|
||||
if doc_id is not None:
|
||||
existing_nodes[doc_id].append(node)
|
||||
|
||||
existing = store.get_modified_times()
|
||||
changed = 0
|
||||
for document in iter_wrapper(documents):
|
||||
doc_id = str(document.id)
|
||||
document_modified = document.modified.isoformat()
|
||||
if existing.get(doc_id) == document.modified.isoformat():
|
||||
continue
|
||||
nodes = build_document_node(document, chunk_size=chunk_size)
|
||||
_embed_nodes(nodes, embed_model)
|
||||
store.upsert_document(doc_id, nodes)
|
||||
changed += 1
|
||||
msg = (
|
||||
"LLM index updated successfully."
|
||||
if changed
|
||||
else "No changes detected in LLM index."
|
||||
)
|
||||
|
||||
if doc_id in existing_nodes:
|
||||
doc_nodes = existing_nodes[doc_id]
|
||||
node_modified = doc_nodes[0].metadata.get("modified")
|
||||
|
||||
if node_modified == document_modified:
|
||||
continue
|
||||
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
for node in doc_nodes:
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
nodes.extend(build_document_node(document, chunk_size=chunk_size))
|
||||
|
||||
if nodes:
|
||||
msg = "LLM index updated successfully."
|
||||
logger.info(
|
||||
"Updating %d nodes in LLM index.",
|
||||
len(nodes),
|
||||
)
|
||||
index.insert_nodes(nodes)
|
||||
else:
|
||||
msg = "No changes detected in LLM index."
|
||||
logger.info(msg)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
try:
|
||||
with _exclude_readers():
|
||||
store.compact()
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index compaction: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
"""
|
||||
Adds or updates a document in the LLM index.
|
||||
If the document already exists, it will be replaced.
|
||||
"""
|
||||
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
|
||||
if not new_nodes:
|
||||
logger.warning(
|
||||
"No indexable content for document %s; skipping LLM index update.",
|
||||
document.pk,
|
||||
)
|
||||
return
|
||||
"""Add or atomically replace a document's chunks in the index."""
|
||||
config = AIConfig()
|
||||
new_nodes = build_document_node(
|
||||
document,
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
)
|
||||
if new_nodes:
|
||||
_embed_nodes(new_nodes, get_embedding_model(config))
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index(nodes=new_nodes)
|
||||
with write_store(embed_model_name=get_configured_model_name(config)) as store:
|
||||
store.upsert_document(str(document.id), new_nodes)
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.insert_nodes(new_nodes)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
def llm_index_compact() -> None:
|
||||
"""Compact the index immediately, rebuilding the table to reclaim space."""
|
||||
with write_store() as store:
|
||||
try:
|
||||
with _exclude_readers():
|
||||
store.compact(force=True)
|
||||
except Timeout:
|
||||
logger.info(
|
||||
"Skipping LLM index compaction: index readers are active; "
|
||||
"will retry next run.",
|
||||
)
|
||||
|
||||
|
||||
def llm_index_remove_document(document: Document):
|
||||
"""
|
||||
Removes a document from the LLM index.
|
||||
"""
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index()
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
"""Remove a document's chunks from the LLM index."""
|
||||
with write_store() as store:
|
||||
store.delete(str(document.id))
|
||||
|
||||
|
||||
def truncate_content(
|
||||
@@ -399,6 +443,18 @@ def truncate_content(
|
||||
return " ".join(truncated_chunks)
|
||||
|
||||
|
||||
def truncate_embedding_query(content: str, *, chunk_size: int) -> str:
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
|
||||
splitter = TokenTextSplitter(
|
||||
separator=" ",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=0,
|
||||
)
|
||||
content_chunks = splitter.split_text(content)
|
||||
return content_chunks[0] if content_chunks else ""
|
||||
|
||||
|
||||
def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str] | None:
|
||||
if document_ids is None:
|
||||
return None
|
||||
@@ -410,77 +466,58 @@ def query_similar_documents(
|
||||
top_k: int = 5,
|
||||
document_ids: Iterable[int | str] | None = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
"""Return up to ``top_k`` Documents most similar to ``document``."""
|
||||
allowed_document_ids = normalize_document_ids(document_ids)
|
||||
if allowed_document_ids is not None and not allowed_document_ids:
|
||||
return []
|
||||
|
||||
if not vector_store_file_exists():
|
||||
if not llm_index_exists():
|
||||
queue_llm_index_update_if_needed(
|
||||
rebuild=False,
|
||||
reason="LLM index not found for similarity query.",
|
||||
)
|
||||
return []
|
||||
|
||||
with FileLock(_index_lock_path()):
|
||||
index = load_or_build_index()
|
||||
config = AIConfig()
|
||||
|
||||
# constrain only the node(s) that match the document IDs, if given
|
||||
doc_node_ids = (
|
||||
[
|
||||
node.node_id
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in allowed_document_ids
|
||||
]
|
||||
if allowed_document_ids is not None
|
||||
else None
|
||||
)
|
||||
if doc_node_ids is not None and not doc_node_ids:
|
||||
return []
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
filters = (
|
||||
_document_id_filters(allowed_document_ids)
|
||||
if allowed_document_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
query_text = truncate_embedding_query(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
)
|
||||
# Hold the shared read lock for the whole retrieval so the connection is
|
||||
# never open across a compaction swap. The retrieve() call generates a
|
||||
# query embedding (a slow external request) and searches the vector store;
|
||||
# no Django ORM access happens during it, so release the pooled DB
|
||||
# connection for its duration. See #12976.
|
||||
with read_store() as store:
|
||||
index = load_or_build_index(config, store)
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=top_k,
|
||||
doc_ids=doc_node_ids,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
config = AIConfig()
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
chunk_size=config.llm_embedding_chunk_size,
|
||||
context_size=config.llm_context_size,
|
||||
)
|
||||
try:
|
||||
with db_connection_released():
|
||||
results = retriever.retrieve(query_text)
|
||||
except KeyError as e:
|
||||
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
|
||||
# append-only. Treat them as absent and return no results.
|
||||
logger.debug(
|
||||
"Skipping LLM similarity query for document %s due to a stale "
|
||||
"FAISS position with no docstore node: %s",
|
||||
document.pk,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
retrieved_document_ids: list[int] = []
|
||||
for node in results:
|
||||
document_id = node.metadata.get("document_id")
|
||||
if document_id is None:
|
||||
continue
|
||||
normalized_document_id = str(document_id)
|
||||
if (
|
||||
allowed_document_ids is not None
|
||||
and normalized_document_id not in allowed_document_ids
|
||||
):
|
||||
normalized = str(document_id)
|
||||
if allowed_document_ids is not None and normalized not in allowed_document_ids:
|
||||
continue
|
||||
try:
|
||||
retrieved_document_ids.append(int(normalized_document_id))
|
||||
except ValueError:
|
||||
retrieved_document_ids.append(int(normalized))
|
||||
except ValueError: # pragma: no cover
|
||||
logger.warning(
|
||||
"Skipping LLM index result with invalid document_id %r.",
|
||||
document_id,
|
||||
|
||||
@@ -1,10 +1,36 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from pytest_django.fixtures import SettingsWrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
|
||||
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
|
||||
settings.LLM_INDEX_DIR = tmp_path
|
||||
settings.LLM_INDEX_LOCK = tmp_path / "index.lock"
|
||||
settings.LLM_INDEX_RWLOCK = tmp_path / "llmindex.rwlock.db"
|
||||
return tmp_path
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
async def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model(mocker: pytest_mock.MockerFixture) -> pytest_mock.MockType:
|
||||
fake = FakeEmbedding()
|
||||
mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake)
|
||||
mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake)
|
||||
return fake
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.ai_classifier import build_localization_prompt
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
@@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document):
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
config = AIConfig()
|
||||
prompt = build_prompt_without_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" not in prompt
|
||||
assert "for generated" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
prompt = build_prompt_with_rag(mock_document, config)
|
||||
assert "Additional context from similar documents" in prompt
|
||||
|
||||
prompt = build_localization_prompt(
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_mock
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from faker import Faker
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
@@ -19,10 +16,12 @@ from documents.tests.factories import DocumentFactory
|
||||
from documents.tests.factories import PaperlessTaskFactory
|
||||
from paperless.models import ApplicationConfiguration
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.tests.conftest import FakeEmbedding
|
||||
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_document(db):
|
||||
def real_document(db: None) -> Document:
|
||||
return Document.objects.create(
|
||||
title="Test Document",
|
||||
content="This is some test content.",
|
||||
@@ -30,44 +29,39 @@ def real_document(db):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document) -> None:
|
||||
def test_build_document_node(real_document: Document) -> None:
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
assert nodes[0].metadata["filename"] == real_document.filename
|
||||
assert nodes[0].metadata["storage_path"] == (
|
||||
real_document.storage_path.name if real_document.storage_path else None
|
||||
)
|
||||
assert (
|
||||
nodes[0].metadata["archive_serial_number"]
|
||||
== real_document.archive_serial_number
|
||||
)
|
||||
assert "filename" in nodes[0].excluded_embed_metadata_keys
|
||||
assert "filename" not in nodes[0].excluded_llm_metadata_keys
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_metadata_from_embedding(real_document) -> None:
|
||||
def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None:
|
||||
"""Every node produced by build_document_node must carry the paperless document id
|
||||
as its ref_doc_id so that the vector store's delete(str(doc.id)) works correctly."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0, "Expected at least one node"
|
||||
for node in nodes:
|
||||
assert node.ref_doc_id == str(real_document.id), (
|
||||
f"Expected ref_doc_id={real_document.id!r}, got {node.ref_doc_id!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_metadata_from_embedding(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""Metadata keys must not be prepended to the embedding text.
|
||||
|
||||
build_llm_index_text already encodes all metadata in the body text, so
|
||||
@@ -75,8 +69,6 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
|
||||
double the token count and exceed embedding models with small context
|
||||
windows (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
|
||||
"""
|
||||
from llama_index.core.schema import MetadataMode
|
||||
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
for node in nodes:
|
||||
embed_text = node.get_content(metadata_mode=MetadataMode.EMBED)
|
||||
@@ -87,7 +79,36 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
|
||||
def test_build_document_node_structured_fields_in_metadata(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""Structured fields must be in node.metadata so the LLM receives them via metadata prepend."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
for node in nodes:
|
||||
assert "title" in node.metadata
|
||||
assert "tags" in node.metadata
|
||||
assert "correspondent" in node.metadata
|
||||
assert "document_type" in node.metadata
|
||||
assert "created" in node.metadata
|
||||
assert "added" in node.metadata
|
||||
assert "modified" in node.metadata
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_excludes_document_id_from_llm_context(
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
"""document_id is an internal key and must not appear in LLM context text."""
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
for node in nodes:
|
||||
assert "document_id" in node.excluded_llm_metadata_keys
|
||||
assert "document_id" not in node.get_content(metadata_mode=MetadataMode.LLM)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node_uses_rag_chunk_settings(real_document: Document) -> None:
|
||||
app_config, _ = ApplicationConfiguration.objects.get_or_create()
|
||||
app_config.llm_embedding_chunk_size = 512
|
||||
app_config.save()
|
||||
@@ -116,11 +137,21 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None:
|
||||
assert prompt_helper.context_window == 4096
|
||||
|
||||
|
||||
def test_truncate_embedding_query_returns_single_chunk() -> None:
|
||||
content = " ".join(f"word{i}" for i in range(200))
|
||||
|
||||
result = indexing.truncate_embedding_query(content, chunk_size=32)
|
||||
|
||||
assert result
|
||||
assert result != content
|
||||
assert "word199" not in result
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
mock_config = MagicMock()
|
||||
mock_config.llm_embedding_chunk_size = 512
|
||||
@@ -138,44 +169,49 @@ def test_update_llm_index(
|
||||
|
||||
ai_config.assert_called_once()
|
||||
build_document_node.assert_called_once_with(real_document, chunk_size=512)
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_removes_meta(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
def test_update_llm_index_rebuilds_on_model_name_change(
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
# Pre-create a meta.json with incorrect data
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 1}),
|
||||
)
|
||||
|
||||
# Build initial index with model "model-a".
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
with patch(
|
||||
"paperless_ai.indexing.get_configured_model_name",
|
||||
return_value="model-a",
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
from paperless.config import AIConfig
|
||||
# Simulate config change to "model-b"; the incremental run must force a rebuild.
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
with patch(
|
||||
"paperless_ai.indexing.get_configured_model_name",
|
||||
return_value="model-b",
|
||||
):
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
config = AIConfig()
|
||||
expected_model = config.llm_embedding_model or (
|
||||
"text-embedding-3-small"
|
||||
if config.llm_embedding_backend == "openai-like"
|
||||
else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
assert meta == {"embedding_model": expected_model, "dim": 384}
|
||||
with indexing.get_vector_store() as store:
|
||||
# Schema metadata only updates when the table is dropped and recreated, never
|
||||
# on incremental writes -- so "model-b" here proves a full rebuild happened.
|
||||
assert store.stored_model_name() == "model-b"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_partial_update(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
doc2 = Document.objects.create(
|
||||
title="Test Document 2",
|
||||
@@ -210,131 +246,34 @@ def test_update_llm_index_partial_update(
|
||||
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Updating %d nodes in LLM index.",
|
||||
2,
|
||||
)
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_get_or_create_storage_context_raises_exception(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with pytest.raises(Exception):
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Expected the vector store table to exist after incremental update"
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_succeeds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"llama_index.core.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"llama_index.core.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[MagicMock()],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_add_or_update_document_updates_existing_entry(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_remove_document_deletes_node_from_docstore(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 0
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Expected the vector store table to exist after add-or-update"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_after_remove_does_not_raise_key_error(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
@@ -352,8 +291,8 @@ def test_query_after_remove_does_not_raise_key_error(
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_no_documents(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
@@ -369,6 +308,22 @@ def test_update_llm_index_no_documents(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_no_documents_no_index_returns_early(
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""update with no documents and no existing index must return early."""
|
||||
mock_qs = MagicMock()
|
||||
mock_qs.exists.return_value = False
|
||||
mock_qs.__iter__ = MagicMock(return_value=iter([]))
|
||||
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
|
||||
|
||||
result = indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert result == "No documents found to index."
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
|
||||
# No existing tasks
|
||||
@@ -406,20 +361,17 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
) as mock_vector_store_exists,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
mock_vector_store_exists.return_value = True
|
||||
|
||||
mock_index = MagicMock()
|
||||
@@ -451,14 +403,50 @@ def test_query_similar_documents(
|
||||
assert result == mock_filtered_docs
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_EMBEDDING_CHUNK_SIZE=32,
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents_truncates_query_to_embedding_chunk_size(
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
real_document.content = " ".join(f"word{i}" for i in range(200))
|
||||
with (
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch(
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
) as mock_vector_store_exists,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
patch("paperless_ai.indexing.truncate_content") as mock_truncate_content,
|
||||
):
|
||||
mock_vector_store_exists.return_value = True
|
||||
mock_load_or_build_index.return_value = MagicMock()
|
||||
mock_truncate_content.return_value = "wrong helper"
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
mock_filter.return_value = []
|
||||
|
||||
indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
mock_truncate_content.assert_not_called()
|
||||
query_text = mock_retriever.retrieve.call_args.args[0]
|
||||
assert query_text
|
||||
assert "word199" not in query_text
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
temp_llm_index_dir: Path,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
@@ -479,120 +467,13 @@ def test_query_similar_documents_triggers_update_when_index_missing(
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_normalizes_and_post_filters_allowed_ids(
|
||||
real_document,
|
||||
) -> None:
|
||||
real_document.owner = User.objects.create_user(username="rag-owner")
|
||||
real_document.save()
|
||||
private_owner = User.objects.create_user(username="rag-private-owner")
|
||||
private_document = Document.objects.create(
|
||||
title="Private similar document",
|
||||
content="Similar private content that must not reach RAG.",
|
||||
owner=private_owner,
|
||||
added=timezone.now(),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
return_value=True,
|
||||
),
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
|
||||
):
|
||||
allowed_node = MagicMock()
|
||||
allowed_node.node_id = "allowed-node"
|
||||
allowed_node.metadata = {"document_id": str(real_document.pk)}
|
||||
private_node = MagicMock()
|
||||
private_node.node_id = "private-node"
|
||||
private_node.metadata = {"document_id": str(private_document.pk)}
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [allowed_node, private_node]
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [private_node, allowed_node]
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
result = indexing.query_similar_documents(
|
||||
real_document,
|
||||
top_k=2,
|
||||
document_ids=[real_document.pk],
|
||||
)
|
||||
|
||||
mock_retriever_cls.assert_called_once_with(
|
||||
index=mock_index,
|
||||
similarity_top_k=2,
|
||||
doc_ids=["allowed-node"],
|
||||
)
|
||||
assert result == [real_document]
|
||||
assert private_document not in result
|
||||
|
||||
|
||||
class TestUpdateLlmIndexStaleNodes:
|
||||
"""Tests that update_llm_index removes ALL nodes for a multi-chunk document."""
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_incremental_update_removes_all_old_nodes_for_multi_chunk_document(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model: MagicMock,
|
||||
) -> None:
|
||||
"""Ghost nodes from all chunks of a modified document must be removed.
|
||||
|
||||
When a document is split into multiple chunks (chunk_size=1024), the
|
||||
incremental update path must delete every old node, not just the last
|
||||
one captured by a dict comprehension keyed on document_id.
|
||||
"""
|
||||
# Content long enough to produce at least two chunks at chunk_size=1024.
|
||||
# Generate many paragraphs so the token count comfortably exceeds 1024.
|
||||
fake = Faker()
|
||||
long_content = "\n\n".join(fake.paragraph(nb_sentences=20) for _ in range(20))
|
||||
doc = DocumentFactory(content=long_content)
|
||||
|
||||
# Build the initial index (rebuild=True) so it has multiple nodes
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# Verify the initial index has more than one node for this document
|
||||
initial_index = indexing.load_or_build_index()
|
||||
initial_node_ids = [
|
||||
nid
|
||||
for nid, node in initial_index.docstore.docs.items()
|
||||
if node.metadata.get("document_id") == str(doc.id)
|
||||
]
|
||||
assert len(initial_node_ids) > 1, (
|
||||
f"Expected multiple chunks but got {len(initial_node_ids)}; "
|
||||
"increase long_content length"
|
||||
)
|
||||
|
||||
# Simulate a modification so the incremental path treats it as changed.
|
||||
# Use queryset.update() to bypass auto_now and actually change the DB value.
|
||||
new_modified = timezone.now()
|
||||
Document.objects.filter(pk=doc.pk).update(modified=new_modified)
|
||||
|
||||
# Run incremental update (rebuild=False) with the modified document
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
# Reload the persisted index and check that no OLD node ids remain
|
||||
updated_index = indexing.load_or_build_index()
|
||||
remaining_old_node_ids = [
|
||||
nid for nid in initial_node_ids if nid in updated_index.docstore.docs
|
||||
]
|
||||
assert remaining_old_node_ids == [], (
|
||||
f"Ghost nodes still present after incremental update: "
|
||||
f"{remaining_old_node_ids}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_query_similar_documents_empty_allow_list_fails_closed(
|
||||
real_document,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
"paperless_ai.indexing.llm_index_exists",
|
||||
return_value=True,
|
||||
) as mock_vector_store_exists,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
@@ -610,27 +491,25 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
|
||||
|
||||
|
||||
class TestUpdateLlmIndexEmptyDocumentSet:
|
||||
"""update_llm_index must persist an empty index when all documents are deleted.
|
||||
"""update_llm_index must clear the vector store table when all documents are deleted.
|
||||
|
||||
Without this, the stale on-disk FAISS vectors are never cleared and
|
||||
subsequent similarity searches return phantom hits for document IDs that
|
||||
no longer exist in the DB.
|
||||
Without this, the stale vectors are never cleared and subsequent similarity
|
||||
searches return phantom hits for document IDs that no longer exist in the DB.
|
||||
"""
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_rebuild_clears_stale_index_when_no_documents_exist(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
"""After deleting all documents, rebuild=True must persist an empty index.
|
||||
"""After deleting all documents, rebuild=True must produce a table with zero rows.
|
||||
|
||||
Steps:
|
||||
1. Build an index with one document so the on-disk state is non-empty.
|
||||
2. Delete all documents from the DB.
|
||||
3. Call update_llm_index(rebuild=True).
|
||||
4. Reload the index from disk.
|
||||
5. Assert the reloaded index has zero nodes (no phantom vectors).
|
||||
4. Open the LanceDB table directly and assert zero rows.
|
||||
"""
|
||||
# Step 1: create a document and build a non-empty index
|
||||
Document.objects.create(
|
||||
@@ -640,27 +519,26 @@ class TestUpdateLlmIndexEmptyDocumentSet:
|
||||
)
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
initial_index = indexing.load_or_build_index()
|
||||
assert len(initial_index.docstore.docs) > 0, (
|
||||
"Precondition failed: expected at least one node before deletion"
|
||||
)
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists(), (
|
||||
"Precondition failed: expected the vector store table to exist "
|
||||
"before deletion"
|
||||
)
|
||||
|
||||
# Step 2: delete all documents
|
||||
Document.objects.all().delete()
|
||||
assert not Document.objects.exists()
|
||||
|
||||
# Step 3: rebuild with no documents
|
||||
# Step 3: rebuild with no documents — drop_table is called so the table
|
||||
# is removed (no rows to re-insert, so it stays absent).
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# Step 4: reload the persisted index from disk
|
||||
reloaded_index = indexing.load_or_build_index()
|
||||
|
||||
# Step 5: phantom vectors must be gone
|
||||
assert len(reloaded_index.docstore.docs) == 0, (
|
||||
f"Expected 0 nodes after clearing all documents, "
|
||||
f"but found {len(reloaded_index.docstore.docs)}: "
|
||||
f"{list(reloaded_index.docstore.docs.keys())}"
|
||||
)
|
||||
# Step 4: the table must be absent (no rows) — phantom vectors gone
|
||||
with indexing.get_vector_store() as store2:
|
||||
assert not store2.table_exists(), (
|
||||
"Expected the vector store table to be absent after rebuilding "
|
||||
"with no documents"
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentUpdatedSignalTriggersLlmReindex:
|
||||
@@ -709,10 +587,14 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
def test_returns_without_error_when_build_document_node_returns_empty(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""When build_document_node returns [], the function must return without error
|
||||
and must not call load_or_build_index at all."""
|
||||
"""When build_document_node returns [], the function must return without error.
|
||||
|
||||
The store's upsert_document treats an empty node list as a removal (no-op
|
||||
delete), so load_or_build_index must not be called.
|
||||
"""
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.build_document_node",
|
||||
return_value=[],
|
||||
@@ -720,6 +602,7 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 42
|
||||
# Must not raise
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
@@ -727,172 +610,165 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestLlmIndexLocking:
|
||||
"""The FAISS index mutation functions must acquire the index lock before touching the index.
|
||||
def test_llm_index_compact_uses_force(
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""compact must use force=True to rebuild the table and reclaim space immediately."""
|
||||
mock_store = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
Without locking, two concurrent Celery workers can each load the same
|
||||
on-disk index, make independent modifications, and the last writer silently
|
||||
overwrites the first's changes.
|
||||
indexing.llm_index_compact()
|
||||
|
||||
mock_store.compact.assert_called_once_with(force=True)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestLlmIndexLocking:
|
||||
"""Index mutation functions must go through write_store(), which holds the lock.
|
||||
|
||||
Without locking, two concurrent Celery workers can open the same store,
|
||||
make independent modifications, and trigger CommitConflictError.
|
||||
"""
|
||||
|
||||
def test_add_or_update_document_acquires_lock(
|
||||
def test_add_or_update_document_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""llm_index_add_or_update_document must enter the file lock before touching the index."""
|
||||
call_order: list[str] = []
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
|
||||
mock_load = mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or MagicMock()
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
mock_node = MagicMock()
|
||||
mock_node.get_content.return_value = "fake node text"
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.build_document_node",
|
||||
return_value=[MagicMock()],
|
||||
return_value=[mock_node],
|
||||
)
|
||||
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 1
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_load.assert_called_once()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
)
|
||||
mock_store.upsert_document.assert_called_once()
|
||||
|
||||
def test_remove_document_acquires_lock(
|
||||
def test_remove_document_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""llm_index_remove_document must enter the file lock before loading the index."""
|
||||
call_order: list[str] = []
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
|
||||
mock_load = mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or MagicMock()
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
|
||||
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.id = 1
|
||||
indexing.llm_index_remove_document(doc)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_load.assert_called_once()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
)
|
||||
mock_store.delete.assert_called_once_with("1")
|
||||
|
||||
def test_update_llm_index_rebuild_acquires_lock(
|
||||
def test_update_llm_index_rebuild_uses_write_store(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: MagicMock,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
) -> None:
|
||||
"""update_llm_index must enter the file lock during the rebuild/persist cycle."""
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(return_value=None)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
mock_store = MagicMock()
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.write_store",
|
||||
return_value=mocker.MagicMock(
|
||||
__enter__=mocker.MagicMock(return_value=mock_store),
|
||||
__exit__=mocker.MagicMock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
# exists=True so the code reaches the lock; iterate over an empty
|
||||
# queryset so VectorStoreIndex is called with no nodes (still exercises
|
||||
# the lock path without needing heavy FAISS fixture data)
|
||||
mock_qs = MagicMock()
|
||||
mock_qs.exists.return_value = True
|
||||
mock_qs.__iter__ = MagicMock(return_value=iter([]))
|
||||
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
mock_file_lock_cls.assert_called_once()
|
||||
mock_lock_instance.__enter__.assert_called_once()
|
||||
mock_store.drop_table.assert_called_once()
|
||||
|
||||
def test_query_similar_documents_acquires_lock(
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.mark.django_db
|
||||
class TestVectorStoreIndexing:
|
||||
def test_get_vector_store_roundtrip(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
"""query_similar_documents must enter the file lock before loading the index."""
|
||||
call_order: list[str] = []
|
||||
with indexing.get_vector_store() as store:
|
||||
assert isinstance(store, PaperlessSqliteVecVectorStore)
|
||||
|
||||
mock_lock_instance = MagicMock()
|
||||
mock_lock_instance.__enter__ = MagicMock(
|
||||
side_effect=lambda *_: call_order.append("lock_acquired"),
|
||||
)
|
||||
mock_lock_instance.__exit__ = MagicMock(return_value=False)
|
||||
def test_add_then_remove_document(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
with indexing.get_vector_store() as store:
|
||||
assert store.table_exists()
|
||||
count_sql = "SELECT count(*) FROM documents"
|
||||
assert store.client.execute(count_sql).fetchone()[0] >= 1
|
||||
|
||||
mock_file_lock_cls = mocker.patch(
|
||||
"paperless_ai.indexing.FileLock",
|
||||
return_value=mock_lock_instance,
|
||||
)
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
assert store.client.execute(count_sql).fetchone()[0] == 0
|
||||
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.vector_store_file_exists",
|
||||
return_value=True,
|
||||
)
|
||||
def test_update_shrinks_chunks_without_orphans(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
real_document: Document,
|
||||
) -> None:
|
||||
real_document.content = "word " * 4000 # many chunks
|
||||
real_document.save()
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
count_sql = "SELECT count(*) FROM documents"
|
||||
with indexing.get_vector_store() as store:
|
||||
big = store.client.execute(count_sql).fetchone()[0]
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs = {}
|
||||
real_document.content = "short" # one chunk
|
||||
real_document.save()
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
mocker.patch(
|
||||
"paperless_ai.indexing.load_or_build_index",
|
||||
side_effect=lambda *_a, **_kw: (
|
||||
call_order.append("index_loaded") or mock_index
|
||||
),
|
||||
)
|
||||
rows = store.client.execute(count_sql).fetchone()[0]
|
||||
assert rows < big
|
||||
assert rows >= 1
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever,
|
||||
)
|
||||
|
||||
mocker.patch("paperless_ai.indexing.truncate_content", return_value="")
|
||||
@pytest.mark.django_db
|
||||
class TestQuerySimilarDocuments:
|
||||
def test_query_similar_documents_respects_allowed_ids(
|
||||
self,
|
||||
temp_llm_index_dir: Path,
|
||||
mock_embed_model: FakeEmbedding,
|
||||
) -> None:
|
||||
a = DocumentFactory.create(content="alpha shared content here")
|
||||
b = DocumentFactory.create(content="beta shared content here")
|
||||
c = DocumentFactory.create(content="gamma shared content here")
|
||||
for doc in (a, b, c):
|
||||
indexing.llm_index_add_or_update_document(doc)
|
||||
|
||||
indexing.query_similar_documents(MagicMock(spec=Document))
|
||||
results = indexing.query_similar_documents(a, document_ids=[b.id])
|
||||
|
||||
mock_file_lock_cls.assert_called()
|
||||
mock_lock_instance.__enter__.assert_called()
|
||||
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
|
||||
"Lock must be acquired before the index is loaded"
|
||||
)
|
||||
assert all(doc.id == b.id for doc in results)
|
||||
|
||||
+110
-130
@@ -3,19 +3,20 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core import settings as llama_settings
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from documents.tests.factories import DocumentFactory
|
||||
from paperless_ai import chat
|
||||
from paperless_ai import indexing
|
||||
from paperless_ai.chat import CHAT_ERROR_MESSAGE
|
||||
from paperless_ai.chat import CHAT_METADATA_DELIMITER
|
||||
from paperless_ai.chat import _get_document_filtered_retriever
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_model():
|
||||
from llama_index.core import settings as llama_settings
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
|
||||
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
|
||||
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
|
||||
yield
|
||||
@@ -58,91 +59,6 @@ def assert_chat_output(
|
||||
}
|
||||
|
||||
|
||||
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
node.node_id: node for node in nodes
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = len(nodes)
|
||||
mock_index.vector_store.query.return_value = MagicMock(
|
||||
ids=list(mock_index.index_struct.nodes_dict),
|
||||
similarities=[0.1] * len(nodes),
|
||||
)
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
|
||||
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
|
||||
allowed_node1 = TextNode(
|
||||
text="Allowed content 1.",
|
||||
metadata={"document_id": "1", "title": "Allowed 1"},
|
||||
)
|
||||
allowed_node2 = TextNode(
|
||||
text="Allowed content 2.",
|
||||
metadata={"document_id": "2", "title": "Allowed 2"},
|
||||
)
|
||||
foreign_node = TextNode(
|
||||
text="Foreign content.",
|
||||
metadata={"document_id": "3", "title": "Foreign"},
|
||||
)
|
||||
missing_node = TextNode(
|
||||
text="Missing content.",
|
||||
metadata={"document_id": "1", "title": "Missing"},
|
||||
)
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_index.index_struct.nodes_dict = {
|
||||
"0": foreign_node.node_id,
|
||||
"1": missing_node.node_id,
|
||||
"2": allowed_node1.node_id,
|
||||
"3": allowed_node2.node_id,
|
||||
}
|
||||
mock_index.docstore.docs.get.side_effect = {
|
||||
allowed_node1.node_id: allowed_node1,
|
||||
allowed_node2.node_id: allowed_node2,
|
||||
foreign_node.node_id: foreign_node,
|
||||
}.get
|
||||
mock_index.vector_store._faiss_index.ntotal = 4
|
||||
mock_index.vector_store.query.side_effect = [
|
||||
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
|
||||
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
|
||||
]
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1", "2"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
nodes = retriever.retrieve("question")
|
||||
cached_nodes = retriever.retrieve("question")
|
||||
|
||||
assert [node.node.node_id for node in nodes] == [
|
||||
allowed_node1.node_id,
|
||||
allowed_node2.node_id,
|
||||
]
|
||||
assert cached_nodes == nodes
|
||||
assert mock_index.vector_store.query.call_count == 2
|
||||
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
|
||||
|
||||
|
||||
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
|
||||
mock_index = MagicMock()
|
||||
mock_index.vector_store._faiss_index.ntotal = 0
|
||||
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
|
||||
|
||||
retriever = _get_document_filtered_retriever(
|
||||
mock_index,
|
||||
{"1"},
|
||||
similarity_top_k=2,
|
||||
)
|
||||
|
||||
assert retriever.retrieve("question") == []
|
||||
mock_index.vector_store.query.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_stream_chat_with_one_document_retrieval(
|
||||
mock_document,
|
||||
@@ -164,17 +80,31 @@ def test_stream_chat_with_one_document_retrieval(
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
add_vector_query_results(mock_index, [mock_node])
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(
|
||||
metadata={
|
||||
"document_id": str(mock_document.pk),
|
||||
"title": "Test Document",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What is this?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -196,12 +126,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
@@ -210,41 +138,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_duplicate_node = TextNode(
|
||||
text="More content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
|
||||
)
|
||||
mock_foreign_node = TextNode(
|
||||
text="Content for doc 3.",
|
||||
metadata={"document_id": "3", "title": "Document 3"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [
|
||||
mock_node1,
|
||||
mock_node2,
|
||||
mock_duplicate_node,
|
||||
mock_foreign_node,
|
||||
]
|
||||
add_vector_query_results(
|
||||
mock_index,
|
||||
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
|
||||
)
|
||||
# Simulate get_nodes returning nodes (content exists)
|
||||
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
# Mock response stream
|
||||
mock_retriever_instance = MagicMock()
|
||||
mock_retriever_instance.retrieve.return_value = [
|
||||
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
|
||||
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
|
||||
]
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
|
||||
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
return_value=mock_retriever_instance,
|
||||
):
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
mock_query_engine.query.assert_called_once_with("What's up?")
|
||||
patch_embed_nodes.assert_not_called()
|
||||
@@ -258,8 +177,16 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat_empty_document_list() -> None:
|
||||
with patch("paperless_ai.chat.load_or_build_index") as mock_load_index:
|
||||
output = list(stream_chat_with_documents("Any info?", []))
|
||||
mock_load_index.assert_not_called()
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes() -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
@@ -268,8 +195,8 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
# No matching nodes in the store
|
||||
mock_index.vector_store.get_nodes.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
@@ -279,30 +206,83 @@ def test_stream_chat_no_matching_nodes() -> None:
|
||||
|
||||
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
|
||||
with (
|
||||
patch("paperless_ai.chat.AIConfig"),
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat._get_document_filtered_retriever",
|
||||
) as mock_get_retriever,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": "1", "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
# Nodes found so we get past the pre-check
|
||||
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
|
||||
mock_get_retriever.return_value = mock_retriever
|
||||
with patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
) as mock_retriever_cls:
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.side_effect = RuntimeError(
|
||||
"private provider detail",
|
||||
)
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == [CHAT_ERROR_MESSAGE]
|
||||
assert "Failed to stream document chat response" in caplog.text
|
||||
assert "private provider detail" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestStreamChatRetrieval:
|
||||
def test_no_nodes_yields_no_content_message(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
) -> None:
|
||||
doc = DocumentFactory.create(content="hello world")
|
||||
# Nothing indexed for this document yet.
|
||||
out = list(chat.stream_chat_with_documents("question?", [doc]))
|
||||
assert chat.CHAT_NO_CONTENT_MESSAGE in out
|
||||
|
||||
def test_chat_filter_contains_only_requested_document_ids(
|
||||
self,
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
mocker,
|
||||
) -> None:
|
||||
"""The MetadataFilter passed to the retriever must be scoped to the
|
||||
requested documents only — content from other indexed documents must
|
||||
not be surfaced.
|
||||
"""
|
||||
included = DocumentFactory.create(content="included document content")
|
||||
excluded = DocumentFactory.create(content="excluded document content")
|
||||
indexing.llm_index_add_or_update_document(included)
|
||||
indexing.llm_index_add_or_update_document(excluded)
|
||||
|
||||
# VectorIndexRetriever is imported inside _stream_chat_with_documents;
|
||||
# patch it at the llama_index source so the lazy import picks it up.
|
||||
captured_filters = []
|
||||
mock_retriever = mocker.MagicMock()
|
||||
mock_retriever.retrieve.return_value = []
|
||||
|
||||
def capture_retriever(*args, **kwargs):
|
||||
captured_filters.append(kwargs.get("filters"))
|
||||
return mock_retriever
|
||||
|
||||
mocker.patch("paperless_ai.chat.AIClient")
|
||||
mocker.patch(
|
||||
"llama_index.core.retrievers.VectorIndexRetriever",
|
||||
side_effect=capture_retriever,
|
||||
)
|
||||
|
||||
list(chat.stream_chat_with_documents("question?", [included]))
|
||||
|
||||
assert captured_filters, "VectorIndexRetriever was never constructed"
|
||||
filt = captured_filters[0]
|
||||
assert filt is not None, "Retriever must receive a MetadataFilters"
|
||||
filter_values = filt.filters[0].value
|
||||
assert str(included.pk) in filter_values
|
||||
assert str(excluded.pk) not in filter_values
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user