Merge branch 'beta' into dev

This commit is contained in:
shamoon
2026-06-23 07:32:33 -07:00
109 changed files with 6532 additions and 1774 deletions
+1
View File
@@ -58,6 +58,7 @@ repos:
rev: "v2.24.1"
hooks:
- id: pyproject-fmt
additional_dependencies: [tomli]
# Dockerfile hooks
- repo: https://github.com/AleksaC/hadolint-py
rev: v2.14.0
@@ -30,7 +30,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
+1 -1
View File
@@ -26,7 +26,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
+1 -1
View File
@@ -27,7 +27,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
@@ -30,7 +30,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
+1 -1
View File
@@ -26,7 +26,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
@@ -30,7 +30,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
+1 -1
View File
@@ -23,7 +23,7 @@
# documentation.
services:
broker:
image: docker.io/library/redis:8
image: docker.io/valkey/valkey:9-alpine
restart: unless-stopped
volumes:
- redisdata:/data
+32
View File
@@ -65,6 +65,11 @@ copies you created in the steps above.
Please review the [migration instructions](migration-v3.md) before upgrading Paperless-ngx to v3.0, it includes some breaking changes that require manual intervention before upgrading.
!!! note
Upgrading to v3 clears the existing task history; previously completed, failed, or
acknowledged tasks will no longer appear in the task list afterward. No action is required.
### Docker Route {#docker-updating}
If a new release of paperless-ngx is available, upgrading depends on how
@@ -500,6 +505,33 @@ task scheduler.
python3 manage.py document_index reindex --if-needed
```
### Managing the LLM (AI) index {#llm-index}
When the [AI features](advanced_usage.md#ai-features) are enabled with an embedding
backend, Paperless-ngx maintains a vector index of your documents used for
Retrieval-Augmented Generation (RAG), similar-document retrieval, and document chat. The
index is updated automatically on the schedule set by
[`PAPERLESS_LLM_INDEX_TASK_CRON`](configuration.md#PAPERLESS_LLM_INDEX_TASK_CRON), but you
can manage it manually:
```
document_llmindex {rebuild,update,compact}
```
Specify `rebuild` to build the index from scratch from all documents in the database. Use
this the first time you enable the feature, or after changing the embedding backend or
model.
Specify `update` to incrementally index new and changed documents. This is what the
scheduled task runs.
Specify `compact` to reclaim space and optimize the on-disk vector store.
!!! note
These commands have no effect unless AI is enabled and an embedding backend is
configured.
### Clearing the database read cache
If the database read cache is enabled, **you must run this command** after making any changes to the database outside the application context.
+83 -2
View File
@@ -97,6 +97,85 @@ when using this feature:
of these correspondents to ANY new document, if both are set to
automatic matching.
## AI features {#ai-features}
Paperless-ngx includes a set of optional features backed by a large language model
(LLM): AI-assisted suggestions, similar-document retrieval, and a document chat. They
are **off by default** and never replace the built-in, non-LLM
[matching and suggestions](#matching).
!!! warning
Enabling these features sends document content (and metadata) to the LLM backend you
configure. If that backend is a remote/hosted provider, your documents leave your
server and may incur usage charges. Consider the privacy implications before enabling,
and prefer a local backend (Ollama, or a self-hosted OpenAI-compatible gateway) if that
matters to you.
All AI settings can be supplied as `PAPERLESS_AI_*` environment variables (see
[configuration](configuration.md#ai)) or set in the admin under
**Settings → Application Configuration**; the database value takes precedence over the
environment.
### Enabling the AI features
At a minimum you need to enable AI and choose an LLM backend:
- [`PAPERLESS_AI_ENABLED`](configuration.md#PAPERLESS_AI_ENABLED) — master switch.
- [`PAPERLESS_AI_LLM_BACKEND`](configuration.md#PAPERLESS_AI_LLM_BACKEND) — `ollama`
(runs locally) or `openai-like` (OpenAI itself or any OpenAI-compatible API).
- [`PAPERLESS_AI_LLM_MODEL`](configuration.md#PAPERLESS_AI_LLM_MODEL), and for
`openai-like` usually [`PAPERLESS_AI_LLM_API_KEY`](configuration.md#PAPERLESS_AI_LLM_API_KEY)
and/or [`PAPERLESS_AI_LLM_ENDPOINT`](configuration.md#PAPERLESS_AI_LLM_ENDPOINT). Ollama
requires `PAPERLESS_AI_LLM_ENDPOINT` pointing at your Ollama server.
### AI-assisted suggestions
With AI enabled, Paperless-ngx can suggest a title, tags, correspondent, document type,
storage path and dates by sending the document to the LLM. This is **opt-in per request**
and surfaces through the "Suggest" control on the document detail page, alongside the
classic classifier-based suggestions — it does not disable them. Suggestion output
language can be steered with
[`PAPERLESS_AI_LLM_OUTPUT_LANGUAGE`](configuration.md#PAPERLESS_AI_LLM_OUTPUT_LANGUAGE)
(otherwise it follows the user's UI language).
### The LLM index (RAG) and similar documents
Setting an embedding backend turns on the **LLM index**, a vector index of your documents
that enables Retrieval-Augmented Generation (RAG). When enabled, suggestions are grounded
in similar existing documents, and the document chat can retrieve relevant context.
Enable it by setting
[`PAPERLESS_AI_LLM_EMBEDDING_BACKEND`](configuration.md#PAPERLESS_AI_LLM_EMBEDDING_BACKEND)
(`huggingface` for fully-local embeddings, or `ollama` / `openai-like`). The index is only
built when AI is enabled **and** an embedding backend is set.
The index is updated automatically on a schedule controlled by
[`PAPERLESS_LLM_INDEX_TASK_CRON`](configuration.md#PAPERLESS_LLM_INDEX_TASK_CRON) (daily by
default), and can be rebuilt or compacted manually — see
[Managing the LLM index](administration.md#llm-index).
!!! note
Local embeddings via `huggingface` download the embedding model on first use into the
Paperless data directory. The first run therefore needs network access and some disk
space.
### Document chat
When the LLM index is enabled, the chat control in the top app toolbar answers questions
about your documents. It operates over a single document or across multiple documents
depending on the current view, and its answers include links to the source documents it
drew from.
### AI Security notes
- Document content is passed to the LLM as **untrusted data**.
- By default Paperless-ngx allows AI endpoints that resolve to private/loopback addresses
(for local backends). Set
[`PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS`](configuration.md#PAPERLESS_AI_LLM_ALLOW_INTERNAL_ENDPOINTS)
to `false` to block them.
## Hooking into the consumption process {#consume-hooks}
Sometimes you may want to do something arbitrary whenever a document is
@@ -846,7 +925,7 @@ Paperless is able to utilize barcodes for automatically performing some tasks. B
At this time, the library utilized for detection of barcodes supports the following types:
- AN-13/UPC-A
- EAN-13/UPC-A
- UPC-E
- EAN-8
- Code 128
@@ -855,7 +934,9 @@ At this time, the library utilized for detection of barcodes supports the follow
- Codabar
- Interleaved 2 of 5
- QR Code
- SQ Code
- Data Matrix
- Aztec
- PDF417
For usage in Paperless, the type of barcode does not matter, only the contents of it.
+7
View File
@@ -227,6 +227,7 @@ Version-aware endpoints:
- `PATCH /api/documents/{id}/`: content updates target the selected version (`?version={version_id}`) or latest version by default; non-content metadata updates target the root document.
- `GET /api/documents/{id}/download/`, `GET /api/documents/{id}/preview/`, `GET /api/documents/{id}/thumb/`, `GET /api/documents/{id}/metadata/`: accept `?version={version_id}`.
- `POST /api/documents/{id}/update_version/`: uploads a new version using multipart form field `document` and optional `version_label`.
- `PATCH /api/documents/{id}/versions/{version_id}/`: updates the `version_label` of a specific version.
- `DELETE /api/documents/{root_id}/versions/{version_id}/`: deletes a non-root version.
## Permissions
@@ -445,3 +446,9 @@ Initial API version.
large lists of object IDs for operations affecting many objects.
- The legacy `title_content` document search parameter is deprecated and will be removed in a future version.
Clients should use `text` for simple title-and-content search and `title_search` for title-only search.
- The task tracking system was redesigned. The tasks list (`/api/tasks/`) is now paginated, and the
task object exposes `task_type` (formerly `task_name`) and `trigger_source` (formerly `type`). New
read-only endpoints `/api/tasks/summary/`, `/api/tasks/status_counts/`, and `/api/tasks/active/`
provide aggregate views, and `POST /api/tasks/run/` lets privileged users dispatch supported tasks.
API v9 continues to serve the unpaginated list with the legacy field names until support for v9 is
dropped.
+28 -17
View File
@@ -22,7 +22,11 @@ or applicable default will be utilized instead.
## Required services
### Redis Broker
### Message Broker
Paperless-ngx uses a Redis-compatible message broker. Any broker that
speaks the Redis protocol works here, including [Valkey](https://valkey.io/)
(the default in the bundled Docker Compose files) and Redis itself.
#### [`PAPERLESS_REDIS=<url>`](#PAPERLESS_REDIS) {#PAPERLESS_REDIS}
@@ -30,21 +34,21 @@ or applicable default will be utilized instead.
fetching, index optimization and for training the automatic document
matcher.
- If your Redis server needs login credentials PAPERLESS_REDIS =
- If your broker needs login credentials PAPERLESS_REDIS =
`redis://<username>:<password>@<host>:<port>`
- With the requirepass option PAPERLESS_REDIS =
`redis://:<password>@<host>:<port>`
- To include the redis database index PAPERLESS_REDIS =
- To include the database index PAPERLESS_REDIS =
`redis://<username>:<password>@<host>:<port>/<DBIndex>`
[More information on securing your Redis
Instance](https://redis.io/docs/latest/operate/oss_and_stack/management/security).
[More information on securing your broker
instance](https://valkey.io/topics/security/).
Defaults to `redis://localhost:6379`.
#### [`PAPERLESS_REDIS_PREFIX=<prefix>`](#PAPERLESS_REDIS_PREFIX) {#PAPERLESS_REDIS_PREFIX}
: Prefix to be used in Redis for keys and channels. Useful for sharing one Redis server among multiple Paperless instances.
: Prefix to be used in the broker for keys and channels. Useful for sharing one broker among multiple Paperless instances.
Defaults to no prefix.
@@ -58,14 +62,14 @@ and the relevant connection variables.
#### [`PAPERLESS_DBENGINE=<engine>`](#PAPERLESS_DBENGINE) {#PAPERLESS_DBENGINE}
: Specifies the database engine to use. Accepted values are `sqlite`, `postgresql`,
and `mariadb`.
Defaults to `sqlite` if not set.
and `mariadb`. PostgreSQL and MariaDB users must set this explicitly.
PostgreSQL and MariaDB both require [`PAPERLESS_DBHOST`](#PAPERLESS_DBHOST) to be
set. SQLite does not use any other connection variables; the database file is always
located at `<PAPERLESS_DATA_DIR>/db.sqlite3`.
Defaults to `sqlite`.
!!! warning
Using MariaDB comes with some caveats.
See [MySQL Caveats](advanced_usage.md#mysql-caveats).
@@ -238,7 +242,7 @@ dictionaries; for example, `pool.max_size=20` sets
#### [`PAPERLESS_DB_READ_CACHE_ENABLED=<bool>`](#PAPERLESS_DB_READ_CACHE_ENABLED) {#PAPERLESS_DB_READ_CACHE_ENABLED}
: Caches the database read query results into Redis. This can significantly improve application response times by caching database queries, at the cost of slightly increased memory usage.
: Caches the database read query results into the broker. This can significantly improve application response times by caching database queries, at the cost of slightly increased memory usage.
Defaults to `false`.
@@ -258,18 +262,18 @@ dictionaries; for example, `pool.max_size=20` sets
A high TTL increases memory usage over time. Memory may be used until end of TTL, even if the cache is invalidated with the `invalidate_cachalot` command.
In case of an out-of-memory (OOM) situation, Redis may stop accepting new data — including cache entries, scheduled tasks, and documents to consume.
If your system has limited RAM, consider configuring a dedicated Redis instance for the read cache, with a memory limit and the eviction policy set to `allkeys-lru`.
For more details, refer to the [Redis eviction policy documentation](https://redis.io/docs/latest/develop/reference/eviction/), and see the `PAPERLESS_READ_CACHE_REDIS_URL` setting to specify a separate Redis broker.
In case of an out-of-memory (OOM) situation, the broker may stop accepting new data — including cache entries, scheduled tasks, and documents to consume.
If your system has limited RAM, consider configuring a dedicated broker instance for the read cache, with a memory limit and the eviction policy set to `allkeys-lru`.
For more details, refer to the [Redis eviction policy documentation](https://redis.io/docs/latest/develop/reference/eviction/), and see the `PAPERLESS_READ_CACHE_REDIS_URL` setting to specify a separate broker.
#### [`PAPERLESS_READ_CACHE_REDIS_URL=<url>`](#PAPERLESS_READ_CACHE_REDIS_URL) {#PAPERLESS_READ_CACHE_REDIS_URL}
: Defines the Redis instance used for the read cache.
: Defines the broker instance used for the read cache.
Defaults to `None`.
!!! Note
If this value is not set, the same Redis instance used for scheduled tasks will be used for caching as well.
If this value is not set, the same broker instance used for scheduled tasks will be used for caching as well.
## Optional Services
@@ -888,7 +892,7 @@ modes are available:
The default is `auto`.
For the `skip`, `redo`, and `force` modes, read more about OCR
For the `redo` and `force` modes, read more about OCR
behaviour in the [OCRmyPDF
documentation](https://ocrmypdf.readthedocs.io/en/latest/advanced.html#when-ocr-is-skipped).
@@ -2068,6 +2072,13 @@ context by default.
Defaults to 8192.
#### [`PAPERLESS_AI_LLM_REQUEST_TIMEOUT=<int>`](#PAPERLESS_AI_LLM_REQUEST_TIMEOUT) {#PAPERLESS_AI_LLM_REQUEST_TIMEOUT}
: The timeout, in seconds, for requests to the configured AI backend. Increase this when using
local or slow inference servers that need more time to generate responses.
Defaults to 120.
#### [`PAPERLESS_AI_LLM_BACKEND=<str>`](#PAPERLESS_AI_LLM_BACKEND) {#PAPERLESS_AI_LLM_BACKEND}
: The AI backend to use. This can be either "openai-like" or "ollama". If set to "ollama", the AI
@@ -2120,7 +2131,7 @@ used with the OpenAI-compatible backend to target a custom provider or local gat
Defaults to true, which allows internal endpoints.
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}
#### [`PAPERLESS_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_LLM_INDEX_TASK_CRON) {#PAPERLESS_LLM_INDEX_TASK_CRON}
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
AI is enabled and the LLM embedding backend is set.
+13 -12
View File
@@ -94,16 +94,16 @@ first-time setup.
```
7. You can now either ...
- install Redis or
- install a Redis-compatible broker (e.g. Valkey or Redis) or
- use the included `scripts/start_services.sh` to use Docker to fire
up a Redis instance (and some other services such as Tika,
up a broker instance (and some other services such as Tika,
Gotenberg and a database server) or
- spin up a bare Redis container
- spin up a bare broker container
```bash
docker run -d -p 6379:6379 --restart unless-stopped redis:latest
docker run -d -p 6379:6379 --restart unless-stopped docker.io/valkey/valkey:9-alpine
```
8. Continue with either back-end or front-end development or both :-).
@@ -132,7 +132,7 @@ uv run manage.py runserver & \
```
You might need the front end to test your back end code.
This assumes that you have AngularJS installed on your system.
This assumes that you have Angular installed on your system.
Go to the [Front end development](#front-end-development) section for further details.
To build the front end once use this command:
@@ -174,7 +174,7 @@ To add a new development package `uv add --dev <package>`
## Front end development
The front end is built using AngularJS. In order to get started, you need Node.js (version 24+) and
The front end is built using Angular. In order to get started, you need Node.js (version 24+) and
`pnpm`.
!!! note
@@ -248,12 +248,12 @@ that authentication is working.
## Localization
Paperless-ngx is available in many different languages. Since Paperless-ngx
consists both of a Django application and an AngularJS front end, both
consists both of a Django application and an Angular front end, both
these parts have to be translated separately.
### Front end localization
- The AngularJS front end does localization according to the [Angular
- The Angular front end does localization according to the [Angular
documentation](https://angular.io/guide/i18n).
- The source language of the project is "en_US".
- The source strings end up in the file `src-ui/messages.xlf`.
@@ -495,7 +495,7 @@ class MyCustomParser:
self._tempdir = Path(
tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR)
)
self._text: str | None = None
self._text: str = ""
self._archive_path: Path | None = None
def __enter__(self) -> Self:
@@ -553,7 +553,8 @@ def parse(
**Result accessors**
```python
def get_text(self) -> str | None:
def get_text(self) -> str:
# Return the extracted text, or an empty string if none was found.
return self._text
def get_date(self) -> "datetime.datetime | None":
@@ -684,7 +685,7 @@ class XmlDocumentParser:
def __init__(self, logging_group: object = None) -> None:
settings.SCRATCH_DIR.mkdir(parents=True, exist_ok=True)
self._tempdir = Path(tempfile.mkdtemp(prefix="paperless-", dir=settings.SCRATCH_DIR))
self._text: str | None = None
self._text: str = ""
def __enter__(self) -> Self:
return self
@@ -702,7 +703,7 @@ class XmlDocumentParser:
except ET.ParseError as e:
raise ParseError(f"XML parse error: {e}") from e
def get_text(self) -> str | None:
def get_text(self) -> str:
return self._text
def get_date(self):
+29 -6
View File
@@ -70,7 +70,16 @@ elsewhere. Here are a couple notes about that.
Paperless-ngx determines the type of a file by inspecting its content
rather than its file extensions. However, files processed via the
consumption directory will be rejected if they have a file extension that
not supported by any of the available parsers.
is not supported by any of the available parsers.
## _Are duplicate documents rejected?_
**A:** Not by default. As of v3, a file whose contents match an existing document is still
consumed, and the duplicate is flagged in the UI — open the document and check the
**Duplicates** tab to review documents that share the same content. If you prefer the old
behavior of rejecting duplicates during consumption, set
[`PAPERLESS_CONSUMER_DELETE_DUPLICATES`](configuration.md#PAPERLESS_CONSUMER_DELETE_DUPLICATES)
to `true`.
## _Will paperless-ngx run on Raspberry Pi?_
@@ -118,10 +127,24 @@ able to run paperless, you're a bit on your own. If you can't run the
docker image, the documentation has instructions for bare metal
installs.
## _What about the Redis licensing change and using one of the open source forks_?
## _Does Paperless-ngx use AI, and is my data private?_
Currently (October 2024), forks of Redis such as Valkey or Redirect are not officially supported by our upstream
libraries, so using one of these to replace Redis is not officially supported.
**A:** Paperless-ngx includes optional AI features — LLM-based suggestions, document chat,
and similar-document retrieval — that are **disabled by default**. They only run when you
enable them and configure an LLM backend. The built-in tag/correspondent suggestions use a
local, non-LLM machine-learning model and do not send your data anywhere. If you enable the
LLM features, document content is sent to whichever backend you configure — this can be a
fully local backend (e.g. Ollama) or a remote provider. See
[AI features](advanced_usage.md#ai-features) for details.
However, they do claim to be compatible with the Redis protocol and will likely work, but we will
not be updating from using Redis as the broker officially just yet.
## _Which message broker should I use_?
Paperless-ngx talks to a Redis-compatible message broker, so any broker that
implements the Redis protocol will work. The bundled Docker Compose files
default to [Valkey](https://valkey.io/), the open-source fork created after
Redis' licensing change, but Redis itself and other wire-compatible brokers
(such as Microsoft's Garnet) are equally fine.
Existing installs can switch broker implementations in place: point
[`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) at the new instance and
reuse the same data volume.
+2 -1
View File
@@ -35,9 +35,10 @@ physical documents into a searchable online archive so you can keep, well, _less
- _New!_ Supports remote OCR with Azure AI (opt-in).
- Documents are saved as PDF/A format which is designed for long term storage, alongside the unaltered originals.
- Uses machine-learning to automatically add tags, correspondents and document types to your documents.
- **New**: Paperless-ngx can now leverage AI (Large Language Models or LLMs) for document suggestions. This is an optional feature that can be enabled (and is disabled by default).
- **New**: Paperless-ngx can optionally leverage AI (Large Language Models or LLMs) for document suggestions, chatting with your documents, and similar-document retrieval. These features are opt-in and disabled by default.
- Supports PDF documents, images, plain text files, Office documents (Word, Excel, PowerPoint, and LibreOffice equivalents)[^1] and more.
- Paperless stores your documents plain on disk. Filenames and folders are managed by paperless and their format can be configured freely with different configurations assigned to different documents.
- Keep multiple **versions** of a document's file under a single entry, sharing one set of metadata.
- **Beautiful, modern web application** that features:
- Customizable dashboard with statistics.
- Filtering by tags, correspondents, types, and more.
+19 -12
View File
@@ -178,7 +178,7 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
- `fonts-liberation` for generating thumbnails for plain text
files
- `imagemagick` >= 6 for PDF conversion
- `gnupg` for handling encrypted documents
- `gnupg` for decrypting GPG-encrypted email
- `libpq-dev` for PostgreSQL
- `libmagic-dev` for mime type detection
- `mariadb-client` for MariaDB compile time
@@ -226,7 +226,8 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
build-essential python3-setuptools python3-wheel
```
2. Install `redis` >= 6.0 and configure it to start automatically.
2. Install a Redis-compatible broker (a current release of Valkey or
Redis) and configure it to start automatically.
3. Optional: Install `postgresql` and configure a database, user, and
password for Paperless-ngx. If you do not wish to use PostgreSQL,
@@ -268,10 +269,10 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
6. Configure Paperless-ngx. See [configuration](configuration.md) for details.
Edit the included `paperless.conf` and adjust the settings to your
needs. Required settings for getting Paperless-ngx running are:
- [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) should point to your Redis server, such as
- [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) should point to your broker, such as
`redis://localhost:6379`.
- [`PAPERLESS_DBENGINE`](configuration.md#PAPERLESS_DBENGINE) is optional, and should be one of `postgres`,
`mariadb`, or `sqlite`
- [`PAPERLESS_DBENGINE`](configuration.md#PAPERLESS_DBENGINE) should be one of `postgresql`,
`mariadb`, or `sqlite`. PostgreSQL and MariaDB users must set this explicitly.
- [`PAPERLESS_DBHOST`](configuration.md#PAPERLESS_DBHOST) should be the hostname on which your
PostgreSQL server is running. Do not configure this to use
SQLite instead. Also configure port, database name, user and
@@ -297,7 +298,7 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
!!! warning
Ensure your Redis instance [is secured](https://redis.io/docs/latest/operate/oss_and_stack/management/security/).
Ensure your broker instance [is secured](https://valkey.io/topics/security/).
7. Create the following directories if they do not already exist:
- `/opt/paperless/media`
@@ -389,9 +390,9 @@ to enable polling and disable inotify. See [here](configuration.md#polling).
`Require=paperless-webserver.socket` in the `webserver` script
and configure `granian` to listen on port 80 (set `GRANIAN_PORT`).
These services rely on Redis and optionally the database server, but
These services rely on the broker and optionally the database server, but
don't need to be started in any particular order. The example files
depend on Redis being started. If you use a database server, you
depend on the broker being started. If you use a database server, you
should add additional dependencies.
!!! note
@@ -449,6 +450,12 @@ development documentation.
You can migrate to Paperless-ngx from Paperless-ng or from the original
Paperless project.
!!! note
Upgrading an existing Paperless-ngx installation from v2 to v3 has its own
breaking changes and required steps. See the [v3 migration guide](migration-v3.md)
before upgrading.
<h3 id="migration_ng">Migrating from Paperless-ng</h3>
Paperless-ngx is meant to be a drop-in replacement for Paperless-ng, and
@@ -494,7 +501,7 @@ installation. Keep these points in mind:
for other services, you might as well use it for Paperless as well.
- The task scheduler of Paperless, which is used to execute periodic
tasks such as email checking and maintenance, requires a
[Redis](https://redis.io/) message broker instance. The
Redis-compatible message broker instance (such as Valkey or Redis). The
Docker Compose route takes care of that.
- The layout of the folder structure for your documents and data
remains the same, so you can plug your old Docker volumes into
@@ -582,16 +589,16 @@ commands as well.
1. Stop and remove the Paperless container.
2. If using an external database, stop that container.
3. Update Redis configuration.
3. Update broker configuration.
1. If `REDIS_URL` is already set, change it to [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS)
and continue to step 4.
1. Otherwise, add a new Redis service in `docker-compose.yml`,
1. Otherwise, add a new broker service in `docker-compose.yml`,
following [the example compose
files](https://github.com/paperless-ngx/paperless-ngx/tree/main/docker/compose)
1. Set the environment variable [`PAPERLESS_REDIS`](configuration.md#PAPERLESS_REDIS) so it points to
the new Redis container.
the new broker container.
4. Update user mapping.
1. If set, change the environment variable `PUID` to `USERMAP_UID`.
+2 -33
View File
@@ -10,9 +10,9 @@ Check for the following issues:
`CONSUMPTION_DIR` setting. Don't adjust this setting if you're
using docker.
- Ensure that redis is up and running. Paperless does its task
- Ensure that the broker is up and running. Paperless does its task
processing asynchronously, and for documents to arrive at the task
processor, it needs redis to run.
processor, it needs the broker to run.
- Ensure that the task processor is running. Docker does this
automatically. Manually invoke the task processor by executing
@@ -149,37 +149,6 @@ operating system, if these are different from `1000`. See [Docker setup](setup.m
Also ensure that you are able to read and write to the consumption
directory on the host.
## OSError: \[Errno 19\] No such device when consuming files
If you experience errors such as:
```shell-session
File "/usr/local/lib/python3.7/site-packages/whoosh/codec/base.py", line 570, in open_compound_file
return CompoundStorage(dbfile, use_mmap=storage.supports_mmap)
File "/usr/local/lib/python3.7/site-packages/whoosh/filedb/compound.py", line 75, in __init__
self._source = mmap.mmap(fileno, 0, access=mmap.ACCESS_READ)
OSError: [Errno 19] No such device
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/site-packages/django_q/cluster.py", line 436, in worker
res = f(*task["args"], **task["kwargs"])
File "/usr/src/paperless/src/documents/tasks.py", line 73, in consume_file
override_tag_ids=override_tag_ids)
File "/usr/src/paperless/src/documents/consumer.py", line 271, in try_consume_file
raise ConsumerError(e)
```
Paperless uses a search index to provide better and faster full text
searching. This search index is stored inside the `data` folder. The
search index uses memory-mapped files (mmap). The above error indicates
that paperless was unable to create and open these files.
This happens when you're trying to store the data directory on certain
file systems (mostly network shares) that don't support memory-mapped
files.
## Web-UI stuck at "Loading\..."
This might have multiple reasons.
+21 -2
View File
@@ -292,6 +292,23 @@ Once setup, navigating to the email settings page in Paperless-ngx will allow yo
You can also submit a document using the REST API, see [POSTing documents](api.md#file-uploads)
for details.
### Duplicate documents
By default, Paperless-ngx **does not reject duplicates**. If you consume a file whose
contents exactly match an existing document (same checksum), the new copy is still
consumed and a warning is logged. The task entry for the upload also flags that a
duplicate was detected and links to the existing document(s).
To review duplicates, open a document and switch to the **Duplicates** tab on the
document detail page. It lists other documents that share the same content, including any
that are in the trash (shown with a badge), and links to each so you can decide which to
keep.
If you would rather reject duplicates at consumption time (the pre-v3 behavior), set
[`PAPERLESS_CONSUMER_DELETE_DUPLICATES`](configuration.md#PAPERLESS_CONSUMER_DELETE_DUPLICATES)
to `true`. The duplicate file is then deleted instead of consumed, and the task fails with
a "document already exists" message.
## Document Suggestions
Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a (non-LLM) machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user.
@@ -306,7 +323,9 @@ Paperless-ngx includes several features that use AI to enhance the document mana
so consider the privacy implications of using these features, especially if using a remote
model or API provider instead of the default local model.
The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering. This uses the FAISS vector store.
The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering.
See [AI features](advanced_usage.md#ai-features) for how to enable and configure these features, including choosing an LLM backend and setting up the LLM index for RAG.
### AI-Enhanced Suggestions
@@ -1097,7 +1116,7 @@ Paperless-ngx consists of the following components:
errors (i.e., wrong email credentials, errors during consuming a
specific file, etc).
- A [redis](https://redis.io/) message broker: This is a really
- A message broker (such as Valkey or Redis): This is a really
lightweight service that is responsible for getting the tasks from
the webserver and the consumer to the task scheduler. These run in a
different process (maybe even on different machines!), and
+1 -2
View File
@@ -42,7 +42,6 @@ dependencies = [
"drf-spectacular~=0.28",
"drf-spectacular-sidecar~=2026.5.1",
"drf-writable-nested~=0.7.1",
"faiss-cpu>=1.10",
"filelock~=3.29.0",
"flower~=2.0.1",
"gotenberg-client~=0.14.0",
@@ -57,7 +56,6 @@ dependencies = [
"llama-index-embeddings-openai-like>=0.2.2",
"llama-index-llms-ollama>=0.9.1",
"llama-index-llms-openai-like>=0.7.1",
"llama-index-vector-stores-faiss>=0.5.2",
"nltk~=3.9.1",
"ocrmypdf~=17.4.2",
"openai>=2.32",
@@ -74,6 +72,7 @@ dependencies = [
"scikit-learn~=1.8.0",
"sentence-transformers>=5.4.1",
"setproctitle~=1.3.4",
"sqlite-vec==0.1.9",
"tantivy~=0.26.0",
"tika-client~=0.11.0",
"torch~=2.12.0",
+1 -1
View File
@@ -26,7 +26,7 @@ module.exports = {
'abstract-paperless-service',
],
transformIgnorePatterns: [
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|@angular/common/locales/.*\\.js$))',
'node_modules/(?!.*(\\.mjs$|tslib|lodash-es|normalize-diacritics|@angular/common/locales/.*\\.js$))',
],
moduleNameMapper: {
...esmPreset.moduleNameMapper,
+1
View File
@@ -32,6 +32,7 @@
"ngx-cookie-service": "^21.3.1",
"ngx-device-detector": "^11.0.0",
"ngx-ui-tour-ng-bootstrap": "^18.0.0",
"normalize-diacritics": "^5.0.0",
"pdfjs-dist": "^5.7.284",
"rxjs": "^7.8.2",
"tslib": "^2.8.1",
+11
View File
@@ -71,6 +71,9 @@ importers:
ngx-ui-tour-ng-bootstrap:
specifier: ^18.0.0
version: 18.0.0(4ccfccfbcf381a309618492b31e99276)
normalize-diacritics:
specifier: ^5.0.0
version: 5.0.0
pdfjs-dist:
specifier: ^5.7.284
version: 5.7.284
@@ -5565,6 +5568,10 @@ packages:
engines: {node: ^20.17.0 || >=22.9.0}
hasBin: true
normalize-diacritics@5.0.0:
resolution: {integrity: sha512-t6czCJOpbAtckN1wCC2qPWnO3GQvNANb9bcUNbiOLEqojVuP31+ELIs5KhEG8jyz0TH7iD9BWxWz8O3ic2/rMQ==}
engines: {node: '>= 14.x', npm: '>= 6.x'}
normalize-path@3.0.0:
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
engines: {node: '>=0.10.0'}
@@ -12985,6 +12992,10 @@ snapshots:
dependencies:
abbrev: 4.0.0
normalize-diacritics@5.0.0:
dependencies:
tslib: 2.8.1
normalize-path@3.0.0: {}
npm-bundled@5.0.0:
@@ -11,6 +11,9 @@
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="visibleTasks.length === 0">
<i-bs name="check2-all" class="me-1"></i-bs>{{dismissButtonText}}
</button>
<button class="btn btn-sm btn-outline-primary me-2" (click)="dismissAllTasks()" *pngxIfPermissions="{ action: PermissionAction.Change, type: PermissionType.PaperlessTask }" [disabled]="totalTasks === 0">
<i-bs name="check2-all" class="me-1"></i-bs><ng-container i18n>Dismiss all</ng-container>
</button>
<div class="form-check form-switch mb-0 ms-2">
<input class="form-check-input" type="checkbox" role="switch" [(ngModel)]="autoRefreshEnabled">
<label class="form-check-label" for="autoRefreshSwitch" i18n>Auto refresh</label>
@@ -81,7 +84,7 @@
<button class="btn btn-sm btn-outline-primary" ngbDropdownToggle>{{filterTargetName}}</button>
<div class="dropdown-menu shadow" ngbDropdownMenu>
@for (t of filterTargets; track t.id) {
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="filterTargetID = t.id">{{t.name}}</button>
<button ngbDropdownItem [class.active]="filterTargetID === t.id" (click)="setFilterTarget(t.id)">{{t.name}}</button>
}
</div>
</div>
@@ -11,7 +11,7 @@ import { Router } from '@angular/router'
import { RouterTestingModule } from '@angular/router/testing'
import { NgbModal, NgbModalRef, NgbModule } from '@ng-bootstrap/ng-bootstrap'
import { allIcons, NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { throwError } from 'rxjs'
import { of, throwError } from 'rxjs'
import { routes } from 'src/app/app-routing.module'
import {
PaperlessTask,
@@ -29,7 +29,11 @@ import { ToastService } from 'src/app/services/toast.service'
import { environment } from 'src/environments/environment'
import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component'
import { PageHeaderComponent } from '../../common/page-header/page-header.component'
import { TasksComponent, TaskSection } from './tasks.component'
import {
TaskFilterTargetID,
TasksComponent,
TaskSection,
} from './tasks.component'
const tasks: PaperlessTask[] = [
{
@@ -154,6 +158,13 @@ const paginatedTasks: Results<PaperlessTask> = {
results: tasks,
}
const sectionCountResponse = {
all: 7,
needs_attention: 2,
in_progress: 3,
completed: 2,
}
describe('TasksComponent', () => {
let component: TasksComponent
let fixture: ComponentFixture<TasksComponent>
@@ -221,6 +232,15 @@ describe('TasksComponent', () => {
req.params.get('page') === '1'
)
.flush(paginatedTasks)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
!req.params.has('status')
)
.flush(sectionCountResponse)
})
it('should display task sections with counts', () => {
@@ -295,6 +315,7 @@ describe('TasksComponent', () => {
const headerText = header.nativeElement.textContent
expect(headerText).toContain('Dismiss visible')
expect(headerText).toContain('Dismiss all')
expect(headerText).toContain('Auto refresh')
expect(headerText).not.toContain('All types')
expect(headerText).not.toContain('All sources')
@@ -327,6 +348,74 @@ describe('TasksComponent', () => {
expect(pagination).not.toBeNull()
})
it('should apply the selected section to the server-side task query', () => {
component.setSection(TaskSection.NeedsAttention)
const req = httpTestingController.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page') === '1' &&
request.params.get('page_size') === '25' &&
request.params.get('acknowledged') === 'false' &&
request.params.getAll('status').includes(PaperlessTaskStatus.Failure) &&
request.params.getAll('status').includes(PaperlessTaskStatus.Revoked)
)
req.flush({ count: 2, results: [tasks[0], tasks[1]] })
expect(component.totalTasks).toBe(2)
})
it('should apply task type and trigger source filters to the server-side task query', () => {
component.setTaskType(PaperlessTaskType.SanityCheck)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck
)
.flush({ count: 1, results: [tasks[6]] })
component.setTriggerSource(PaperlessTaskTriggerSource.System)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('task_type') === PaperlessTaskType.SanityCheck &&
request.params.get('trigger_source') ===
PaperlessTaskTriggerSource.System
)
.flush({ count: 1, results: [tasks[6]] })
})
it('should apply text filters to the server-side task query', () => {
component.filterText = 'invoice'
jest.advanceTimersByTime(150)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('name') === 'invoice'
)
.flush({ count: 1, results: [tasks[0]] })
component.setFilterTarget(TaskFilterTargetID.Result)
httpTestingController
.expectOne(
(request) =>
request.url === `${environment.apiBaseUrl}tasks/` &&
request.params.get('page_size') === '25' &&
request.params.get('result') === 'invoice'
)
.flush({ count: 0, results: [] })
})
it('should load a different task page when pagination changes', () => {
component.setPage(2)
@@ -350,6 +439,27 @@ describe('TasksComponent', () => {
expect(component.pagedTasks).toEqual([tasks[0]])
})
it('should not replace section counts with current-page counts', () => {
component.setPage(2)
httpTestingController
.expectOne(
(req) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '25' &&
req.params.get('page') === '2'
)
.flush({
count: 30,
results: [tasks[0]],
})
expect(component.sectionCount(TaskSection.NeedsAttention)).toBe(2)
expect(component.sectionCount(TaskSection.InProgress)).toBe(3)
expect(component.sectionCount(TaskSection.Completed)).toBe(2)
})
it('should expose stable task type options and disable empty ones', () => {
expect(component.taskTypeOptions.map((option) => option.value)).toContain(
PaperlessTaskType.TrainClassifier
@@ -495,6 +605,46 @@ describe('TasksComponent', () => {
expect(dismissSpy).toHaveBeenCalledWith(new Set([467, 466]))
})
it('should support dismiss all tasks', () => {
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(of({}))
const reloadPageSpy = jest
.spyOn(component as any, 'reloadPage')
.mockImplementation(() => undefined)
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
expect(modal.componentInstance.messageBold).toBe('Dismiss all 7 tasks?')
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(reloadPageSpy).toHaveBeenCalledWith(false)
expect(component.selectedTasks.size).toBe(0)
})
it('should show an error and re-enable modal buttons when dismissing all tasks fails', () => {
const error = new Error('dismiss all failed')
const toastSpy = jest.spyOn(toastService, 'showError')
const dismissSpy = jest
.spyOn(tasksService, 'dismissAllTasks')
.mockReturnValue(throwError(() => error))
let modal: NgbModalRef
modalService.activeInstances.subscribe((m) => (modal = m[m.length - 1]))
component.dismissAllTasks()
expect(modal).not.toBeUndefined()
modal.componentInstance.confirmClicked.emit()
expect(dismissSpy).toHaveBeenCalled()
expect(toastSpy).toHaveBeenCalledWith('Error dismissing tasks', error)
expect(modal.componentInstance.buttonsEnabled).toBe(true)
})
it('should dismiss the currently visible scoped and filtered tasks', () => {
component.setSection(TaskSection.InProgress)
component.setTaskType(PaperlessTaskType.SanityCheck)
@@ -673,6 +823,9 @@ describe('TasksComponent', () => {
})
it('should keep clearing selection independent from resetting filters', () => {
component.resetFilter()
expect(component.filterText).toBe('')
component.setTaskType(PaperlessTaskType.ConsumeFile)
component.toggleSelected(tasks[0])
expect(component.selectedTasks.size).toBe(1)
@@ -40,7 +40,7 @@ export enum TaskSection {
Completed = 'completed',
}
enum TaskFilterTargetID {
export enum TaskFilterTargetID {
Name,
Result,
}
@@ -167,6 +167,12 @@ export class TasksComponent
public readonly pageSize = 25
public page: number = 1
public totalTasks: number = 0
public sectionCounts: Record<TaskSection, number> = {
[TaskSection.All]: 0,
[TaskSection.NeedsAttention]: 0,
[TaskSection.InProgress]: 0,
[TaskSection.Completed]: 0,
}
public pagedTasks: PaperlessTask[] = []
public selectedSection: TaskSection = TaskSection.All
public selectedTaskType: PaperlessTaskType | null = null
@@ -282,6 +288,7 @@ export class TasksComponent
.subscribe((query) => {
this._filterText = query
this.clearSelection()
this.reloadPage(true)
})
}
@@ -334,6 +341,30 @@ export class TasksComponent
}
}
dismissAllTasks() {
let modal = this.modalService.open(ConfirmDialogComponent, {
backdrop: 'static',
})
modal.componentInstance.title = $localize`Confirm Dismiss All`
modal.componentInstance.messageBold = $localize`Dismiss all ${this.totalTasks} tasks?`
modal.componentInstance.btnClass = 'btn-warning'
modal.componentInstance.btnCaption = $localize`Dismiss`
modal.componentInstance.confirmClicked.pipe(first()).subscribe(() => {
modal.componentInstance.buttonsEnabled = false
modal.close()
this.tasksService.dismissAllTasks().subscribe({
next: () => {
this.reloadPage(false)
},
error: (e) => {
this.toastService.showError($localize`Error dismissing tasks`, e)
modal.componentInstance.buttonsEnabled = true
},
})
this.clearSelection()
})
}
expandTask(task: PaperlessTask) {
this.expandedTask = this.expandedTask == task.id ? undefined : task.id
}
@@ -446,9 +477,7 @@ export class TasksComponent
}
sectionCount(section: TaskSection): number {
return this.pagedTasks.filter((task) =>
this.taskBelongsToSection(task, section)
).length
return this.sectionCounts[section]
}
sectionShowsResults(section: TaskSection): boolean {
@@ -458,16 +487,27 @@ export class TasksComponent
setSection(section: TaskSection) {
this.selectedSection = section
this.clearSelection()
this.reloadPage(true)
}
setTaskType(taskType: PaperlessTaskType | null) {
this.selectedTaskType = taskType
this.clearSelection()
this.reloadPage(true)
}
setTriggerSource(triggerSource: PaperlessTaskTriggerSource | null) {
this.selectedTriggerSource = triggerSource
this.clearSelection()
this.reloadPage(true)
}
setFilterTarget(filterTargetID: TaskFilterTargetID) {
this.filterTargetID = filterTargetID
if (this._filterText.length) {
this.clearSelection()
this.reloadPage(true)
}
}
taskTypeOptionCount(taskType: PaperlessTaskType | null): number {
@@ -505,19 +545,32 @@ export class TasksComponent
}
public resetFilter() {
if (!this._filterText.length) {
return
}
this._filterText = ''
this.clearSelection()
this.reloadPage(true)
}
public resetFilters() {
const hadFilter = this.isFiltered
this.selectedTaskType = null
this.selectedTriggerSource = null
this.resetFilter()
this._filterText = ''
this.clearSelection()
if (hadFilter) {
this.reloadPage(true)
}
}
filterInputKeyup(event: KeyboardEvent) {
if (event.key == 'Enter') {
this._filterText = (event.target as HTMLInputElement).value
this.clearSelection()
this.reloadPage(true)
} else if (event.key === 'Escape') {
this.resetFilter()
}
@@ -606,19 +659,86 @@ export class TasksComponent
)
}
private reloadSectionCounts() {
this.tasksService
.statusCounts(this.getParamsForSection(TaskSection.All))
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe((counts) => {
this.sectionCounts[TaskSection.All] = counts.all
this.sectionCounts[TaskSection.NeedsAttention] = counts.needs_attention
this.sectionCounts[TaskSection.InProgress] = counts.in_progress
this.sectionCounts[TaskSection.Completed] = counts.completed
})
}
private getParamsForSection(
section: TaskSection
): Record<string, string | number | boolean | readonly string[]> {
const params: Record<
string,
string | number | boolean | readonly string[]
> = {
acknowledged: false,
}
const statuses = this.statusesForSection(section)
if (statuses.length) {
params.status = statuses
}
if (this.selectedTaskType !== null) {
params.task_type = this.selectedTaskType
}
if (this.selectedTriggerSource !== null) {
params.trigger_source = this.selectedTriggerSource
}
if (this._filterText.length) {
params[
this.filterTargetID === TaskFilterTargetID.Name ? 'name' : 'result'
] = this._filterText
}
return params
}
private statusesForSection(section: TaskSection): PaperlessTaskStatus[] {
switch (section) {
case TaskSection.NeedsAttention:
return [PaperlessTaskStatus.Failure, PaperlessTaskStatus.Revoked]
case TaskSection.InProgress:
return [PaperlessTaskStatus.Pending, PaperlessTaskStatus.Started]
case TaskSection.Completed:
return [PaperlessTaskStatus.Success]
default:
return []
}
}
private reloadPage(resetToFirstPage: boolean = false) {
if (resetToFirstPage) {
this.page = 1
}
this.reloadSectionCounts()
this.loading = true
this.tasksService
.list(this.page, this.pageSize, { acknowledged: false })
.list(
this.page,
this.pageSize,
this.getParamsForSection(this.selectedSection)
)
.pipe(first(), takeUntil(this.unsubscribeNotifier))
.subscribe({
next: (result) => {
this.pagedTasks = result.results
this.totalTasks = result.count
this.sectionCounts[TaskSection.All] = result.count
if (this.selectedSection !== TaskSection.All) {
this.sectionCounts[this.selectedSection] = result.count
}
this.loading = false
if (
this.page > 1 &&
@@ -8,7 +8,7 @@
<div class="chat-messages font-monospace small">
@for (message of messages; track message) {
<div class="message d-flex flex-row small" [class.justify-content-end]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-dark]="message.role === 'user'">
<div class="p-2 m-2" [class.bg-body]="message.role === 'user'">
<span>
{{ message.content }}
@if (message.isStreaming) { <span class="blinking-cursor">|</span> }
@@ -188,4 +188,14 @@ describe('ChatComponent', () => {
component.searchInputKeyDown(event)
expect(component.sendMessage).toHaveBeenCalled()
})
it('should not send message on Enter key press while composing with IME', () => {
jest.spyOn(component, 'sendMessage')
const event = new KeyboardEvent('keydown', {
key: 'Enter',
isComposing: true,
})
component.searchInputKeyDown(event)
expect(component.sendMessage).not.toHaveBeenCalled()
})
})
@@ -155,7 +155,10 @@ export class ChatComponent implements OnInit {
}
public searchInputKeyDown(event: KeyboardEvent) {
if (event.key === 'Enter') {
if (
event.key === 'Enter' &&
!(event.isComposing || event.keyCode === 229)
) {
event.preventDefault()
this.sendMessage()
}
@@ -5,10 +5,10 @@
</div>
<div class="modal-body">
@if (messageBold) {
<p><b>{{messageBold}}</b></p>
<p class="text-break"><b>{{messageBold}}</b></p>
}
@if (message) {
<p class="mb-0" [innerHTML]="message"></p>
<p class="mb-0 text-break" [innerHTML]="message"></p>
}
</div>
<div class="modal-footer">
@@ -9,8 +9,11 @@
<label class="form-label" for="metadataDocumentID" i18n>Documents:</label>
<ul class="list-group"
cdkDropList
[cdkDropListData]="documentIDs"
(cdkDropListDropped)="onDrop($event)">
@for (document of documents; track document.id) {
@for (documentID of documentIDs; track documentID) {
@let document = getDocument(documentID);
@if (document) {
<li class="list-group-item d-flex align-items-center" cdkDrag>
<i-bs name="grip-vertical" class="me-2"></i-bs>
<div class="d-flex flex-column">
@@ -27,6 +30,7 @@
</small>
</div>
</li>
}
}
</ul>
</div>
@@ -23,6 +23,7 @@ import {
import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service'
import { ToastService } from 'src/app/services/toast.service'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { CustomFieldEditDialogComponent } from '../edit-dialog/custom-field-edit-dialog/custom-field-edit-dialog.component'
@@ -69,9 +70,7 @@ export class CustomFieldsDropdownComponent extends LoadingComponentWithPermissio
public get filteredFields(): CustomField[] {
return this.unusedFields.filter(
(f) =>
!this.filterText ||
f.name.toLowerCase().includes(this.filterText.toLowerCase())
(f) => !this.filterText || matchesSearchText(f.name, this.filterText)
)
}
@@ -63,6 +63,7 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[virtualScroll]="getSelectOptionsForField(atom.field)?.length > 100"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
} @else if (getCustomFieldByID(atom.field)?.data_type === CustomFieldDataType.DocumentLink) {
@@ -81,6 +82,7 @@
[disabled]="disabled"
bindLabel="name"
bindValue="id"
[searchFn]="customFieldSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
<select class="w-25 form-select" [(ngModel)]="atom.operator" [disabled]="disabled">
@@ -125,6 +127,7 @@
[(ngModel)]="atom.value"
[disabled]="disabled"
[multiple]="true"
[searchFn]="selectOptionSearchFn"
(mousedown)="$event.stopImmediatePropagation()"
></ng-select>
}
@@ -36,6 +36,7 @@ import {
CustomFieldQueryExpression,
} from 'src/app/utils/custom-field-query-element'
import { pngxPopperOptions } from 'src/app/utils/popper-options'
import { matchesSearchText } from 'src/app/utils/text-search'
import { LoadingComponentWithPermissions } from '../../loading-component/loading.component'
import { ClearableBadgeComponent } from '../clearable-badge/clearable-badge.component'
import { DocumentLinkComponent } from '../input/document-link/document-link.component'
@@ -281,6 +282,14 @@ export class CustomFieldsQueryDropdownComponent extends LoadingComponentWithPerm
public readonly today: string = new Date().toLocaleDateString('en-CA')
public customFieldSearchFn = (term: string, field: CustomField): boolean =>
matchesSearchText(field?.name, term)
public selectOptionSearchFn = (
term: string,
option: { id: string; label: string }
): boolean => matchesSearchText(option?.label, term)
constructor() {
super()
this.selectionModel = new CustomFieldQueriesModel()
@@ -28,6 +28,7 @@
[notFoundText]="notFoundText"
[multiple]="multiple"
[bindLabel]="bindLabel"
[searchFn]="searchFn"
bindValue="id"
[virtualScroll]="items?.length > 100"
(change)="onChange(value)"
@@ -112,6 +112,15 @@ describe('SelectComponent', () => {
expect(createNewVal).toEqual('baz')
})
it('should search items by independent normalized terms', () => {
expect(
component.searchFn('tax 26', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeTruthy()
expect(
component.searchFn('tax receipt', { id: 11, name: 'Tax\u00e9s 2026' })
).toBeFalsy()
})
it('should clear search term on blur after delay', fakeAsync(() => {
const clearSpy = jest.spyOn(component, 'clearLastSearchTerm')
component.onBlur()
@@ -13,6 +13,7 @@ import {
import { RouterModule } from '@angular/router'
import { NgSelectModule } from '@ng-select/ng-select'
import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { matchesSearchText } from 'src/app/utils/text-search'
import { AbstractInputComponent } from '../abstract-input'
@Component({
@@ -99,6 +100,9 @@ export class SelectComponent extends AbstractInputComponent<number> {
@Input()
bindLabel: string = 'name'
public searchFn = (term: string, item: any): boolean =>
matchesSearchText(item?.[this.bindLabel], term)
@Input()
showFilter: boolean = false
@@ -14,6 +14,7 @@
[clearSearchOnAdd]="true"
[hideSelected]="tags.length > 0"
[addTag]="allowCreate ? createTagRef : false"
[searchFn]="searchFn"
addTagText="Add tag"
i18n-addTagText
(add)="onAdd($event)"
@@ -171,6 +171,15 @@ describe('TagsComponent', () => {
expect(component.getTag(4)).toBeUndefined()
})
it('should search tags by independent normalized terms including parents', () => {
const parent: Tag = { id: 11, name: 'Financ\u00e9' }
const child: Tag = { id: 12, name: 'Taxes 2026', parent: parent.id }
component.tags = [parent, child]
expect(component.searchFn('finance 26', child)).toBeTruthy()
expect(component.searchFn('finance receipt', child)).toBeFalsy()
})
it('should emit filtered documents', () => {
component.value = [10]
component.tags = tags
@@ -21,6 +21,7 @@ import { NgxBootstrapIconsModule } from 'ngx-bootstrap-icons'
import { first, firstValueFrom, tap } from 'rxjs'
import { Tag } from 'src/app/data/tag'
import { TagService } from 'src/app/services/rest/tag.service'
import { matchesSearchText } from 'src/app/utils/text-search'
import { EditDialogMode } from '../../edit-dialog/edit-dialog.component'
import { TagEditDialogComponent } from '../../edit-dialog/tag-edit-dialog/tag-edit-dialog.component'
import { TagComponent } from '../../tag/tag.component'
@@ -114,6 +115,14 @@ export class TagsComponent implements OnInit, ControlValueAccessor {
public createTagRef: (name) => void
public searchFn = (term: string, tag: Tag): boolean =>
matchesSearchText(
[this.getParentChain(tag?.id).map((parent) => parent.name), tag?.name]
.flat()
.join(' '),
term
)
getTag(id: number) {
if (this.tags) {
return this.tags.find((tag) => tag.id == id)
@@ -1,5 +1,5 @@
<div class="btn-group">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="loading || (suggestions && !aiEnabled)">
<button type="button" class="btn btn-sm btn-outline-primary" (click)="clickSuggest()" [disabled]="disabled || loading || (suggestions && !aiEnabled)">
@if (loading) {
<div class="spinner-border spinner-border-sm" role="status"></div>
} @else {
@@ -13,7 +13,7 @@
@if (aiEnabled) {
<div class="btn-group" ngbDropdown #dropdown="ngbDropdown" [popperOptions]="popperOptions">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<button type="button" class="btn btn-sm btn-outline-primary" ngbDropdownToggle [disabled]="disabled || loading || !suggestions" aria-expanded="false" aria-controls="suggestionsDropdown" aria-label="Suggestions dropdown">
<span class="visually-hidden" i18n>Show suggestions</span>
</button>
@@ -37,6 +37,18 @@ describe('SuggestionsDropdownComponent', () => {
expect(component.getSuggestions.emit).toHaveBeenCalled()
})
it('should not emit getSuggestions when disabled', () => {
jest.spyOn(component.getSuggestions, 'emit')
component.disabled = true
component.suggestions = null
fixture.detectChanges()
component.clickSuggest()
expect(component.getSuggestions.emit).not.toHaveBeenCalled()
expect(fixture.nativeElement.querySelector('button').disabled).toBeTruthy()
})
it('should toggle dropdown when clickSuggest is called and suggestions are not null', () => {
component.aiEnabled = true
fixture.detectChanges()
@@ -47,6 +47,14 @@ export class SuggestionsDropdownComponent {
addCorrespondent: EventEmitter<string> = new EventEmitter()
public clickSuggest(): void {
if (
this.disabled ||
this.loading ||
(this.suggestions && !this.aiEnabled)
) {
return
}
if (!this.suggestions) {
this.getSuggestions.emit(this)
} else {
@@ -131,7 +131,9 @@
@if (status.tasks.celery_status === 'OK') {
<i-bs name="check-circle-fill" class="text-primary ms-2 lh-1"></i-bs>
} @else {
<i-bs name="exclamation-triangle-fill" class="text-danger ms-2 lh-1"></i-bs>
<i-bs name="exclamation-triangle-fill" class="ms-2 lh-1"
[class.text-danger]="status.tasks.celery_status === SystemStatusItemStatus.ERROR"
[class.text-warning]="status.tasks.celery_status === SystemStatusItemStatus.WARNING"></i-bs>
}
</button>
<ng-template #celeryStatus>
+9
View File
@@ -360,6 +360,14 @@ export const PaperlessConfigOptions: ConfigOption[] = [
category: ConfigCategory.AI,
note: $localize`Language to use for generated AI suggestions. When unset, AI suggestions use the user's display language if explicitly set.`,
},
{
key: 'llm_request_timeout',
title: $localize`LLM Request Timeout`,
type: ConfigOptionType.Number,
config_key: 'PAPERLESS_AI_LLM_REQUEST_TIMEOUT',
category: ConfigCategory.AI,
note: $localize`Timeout in seconds for LLM requests.`,
},
]
export interface PaperlessConfig extends ObjectWithId {
@@ -401,4 +409,5 @@ export interface PaperlessConfig extends ObjectWithId {
llm_api_key: string
llm_endpoint: string
llm_output_language: string
llm_request_timeout: number
}
+7
View File
@@ -64,3 +64,10 @@ export interface PaperlessTaskSummary {
last_success: Date | null
last_failure: Date | null
}
export interface PaperlessTaskStatusCounts {
all: number
needs_attention: number
in_progress: number
completed: number
}
+2 -3
View File
@@ -1,5 +1,6 @@
import { Pipe, PipeTransform } from '@angular/core'
import { MatchingModel } from '../data/matching-model'
import { matchesSearchText } from '../utils/text-search'
@Pipe({
name: 'filter',
@@ -21,9 +22,7 @@ export class FilterPipe implements PipeTransform {
typeof item[key] === 'string' || typeof item[key] === 'number'
)
return keys.some((key) => {
return String(item[key])
.toLowerCase()
.includes(searchText.toLowerCase())
return matchesSearchText(item[key], searchText)
})
})
}
@@ -80,6 +80,27 @@ describe('TasksService', () => {
.flush({ count: 0, results: [] })
})
it('calls acknowledge_tasks api endpoint on dismiss all and reloads', () => {
tasksService.dismissAllTasks().subscribe()
const req = httpTestingController.expectOne(
`${environment.apiBaseUrl}tasks/acknowledge/`
)
expect(req.request.method).toEqual('POST')
expect(req.request.body).toEqual({
all: true,
})
req.flush([])
// reload is then called
httpTestingController
.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('page_size') === '1000'
)
.flush({ count: 0, results: [] })
})
it('groups mixed task types by status when reloading', () => {
expect(tasksService.total).toEqual(0)
const mockTasks = [
@@ -221,4 +242,34 @@ describe('TasksService', () => {
task_id: 'abc-123',
})
})
it('loads filtered task status counts', () => {
tasksService
.statusCounts({
acknowledged: false,
task_type: PaperlessTaskType.ConsumeFile,
})
.subscribe((res) => {
expect(res).toEqual({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
const req = httpTestingController.expectOne(
(req: HttpRequest<unknown>) =>
req.url === `${environment.apiBaseUrl}tasks/status_counts/` &&
req.params.get('acknowledged') === 'false' &&
req.params.get('task_type') === PaperlessTaskType.ConsumeFile
)
expect(req.request.method).toEqual('GET')
req.flush({
all: 10,
needs_attention: 2,
in_progress: 3,
completed: 5,
})
})
})
+27 -1
View File
@@ -5,6 +5,7 @@ import { first, map, takeUntil, tap } from 'rxjs/operators'
import {
PaperlessTask,
PaperlessTaskStatus,
PaperlessTaskStatusCounts,
PaperlessTaskType,
} from 'src/app/data/paperless-task'
import { Results } from 'src/app/data/results'
@@ -88,7 +89,7 @@ export class TasksService {
public list(
page: number,
pageSize: number,
extraParams?: Record<string, string | number | boolean>
extraParams?: Record<string, string | number | boolean | readonly string[]>
): Observable<Results<PaperlessTask>> {
return this.http.get<Results<PaperlessTask>>(
`${this.baseUrl}${this.endpoint}/`,
@@ -102,6 +103,17 @@ export class TasksService {
)
}
public statusCounts(
extraParams?: Record<string, string | number | boolean | readonly string[]>
): Observable<PaperlessTaskStatusCounts> {
return this.http.get<PaperlessTaskStatusCounts>(
`${this.baseUrl}${this.endpoint}/status_counts/`,
{
params: extraParams,
}
)
}
public dismissTasks(task_ids: Set<number>): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
@@ -116,6 +128,20 @@ export class TasksService {
)
}
public dismissAllTasks(): Observable<any> {
return this.http
.post(`${this.baseUrl}tasks/acknowledge/`, {
all: true,
})
.pipe(
first(),
takeUntil(this.unsubscribeNotifer),
tap(() => {
this.reload()
})
)
}
public cancelPending(): void {
this.unsubscribeNotifer.next(true)
}
+17
View File
@@ -0,0 +1,17 @@
import { matchesSearchText } from './text-search'
describe('text search utilities', () => {
it('matches text accent-insensitively', () => {
expect(matchesSearchText('R\u00e9sum\u00e9', 'resume')).toBeTruthy()
expect(matchesSearchText('S\u00f8ren', 'soren')).toBeTruthy()
expect(matchesSearchText('\u0152uvre', 'oeuvre')).toBeTruthy()
expect(matchesSearchText('Invoice', 'receipt')).toBeFalsy()
})
it('matches all whitespace-separated search terms independently', () => {
expect(matchesSearchText('taxes 2026', 'tax 26')).toBeTruthy()
expect(matchesSearchText('2026 taxes', 'tax 26')).toBeTruthy()
expect(matchesSearchText('Tax\u00e9s 2026', 'taxe 26')).toBeTruthy()
expect(matchesSearchText('taxes 2026', 'tax receipt')).toBeFalsy()
})
})
+23
View File
@@ -0,0 +1,23 @@
import { normalizeSync } from 'normalize-diacritics'
export type SearchTextValue =
| string
| number
| boolean
| bigint
| null
| undefined
export function normalizeSearchText(value: SearchTextValue): string {
return normalizeSync(String(value ?? '')).toLocaleLowerCase()
}
export function matchesSearchText(
value: SearchTextValue,
searchText: SearchTextValue
): boolean {
const normalizedValue = normalizeSearchText(value)
const searchTerms = normalizeSearchText(searchText).trim().split(/\s+/)
return searchTerms.every((term) => normalizedValue.includes(term))
}
+13
View File
@@ -904,6 +904,19 @@ def remove_password(
doc.id,
pair.source_doc.source_path,
)
try:
with pikepdf.open(source_path) as pdf:
if not pdf.is_encrypted:
logger.info(
"Skipping password removal for document %s because the "
"source PDF is not encrypted",
pair.root_doc.id,
)
continue
except pikepdf.PasswordError:
# Password-protected PDFs need the supplied password below.
pass
with pikepdf.open(source_path, password=password) as pdf:
filepath: Path = (
Path(tempfile.mkdtemp(dir=settings.SCRATCH_DIR))
+3 -3
View File
@@ -70,13 +70,13 @@ def suggestions_last_modified(request, pk: int) -> datetime | None:
def metadata_etag(request, pk: int) -> str | None:
"""
Metadata is extracted from the original file, so use its checksum as the
ETag
Metadata responses include metadata as well as document fields, so include
the modification time with the checksum so metadata-only changes invalidate cache.
"""
doc = resolve_effective_document_by_pk(pk, request).document
if doc is None:
return None
return doc.checksum
return f"{doc.checksum}:{doc.modified.isoformat()}"
def metadata_last_modified(request, pk: int) -> datetime | None:
+63 -1
View File
@@ -28,6 +28,7 @@ from django.db.models.functions import Cast
from django.utils.translation import gettext_lazy as _
from django_filters import DateFilter
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import CharFilter
from django_filters.rest_framework import DateTimeFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
@@ -900,6 +901,16 @@ class ShareLinkBundleFilterSet(FilterSet):
class PaperlessTaskFilterSet(FilterSet):
name = CharFilter(
method="filter_name",
label="Name",
)
result = CharFilter(
method="filter_result",
label="Result",
)
task_type = MultipleChoiceFilter(
choices=PaperlessTask.TaskType.choices,
label="Task Type",
@@ -939,7 +950,58 @@ class PaperlessTaskFilterSet(FilterSet):
class Meta:
model = PaperlessTask
fields = ["task_type", "trigger_source", "status", "acknowledged", "owner"]
fields = [
"task_type",
"trigger_source",
"status",
"acknowledged",
"owner",
"name",
"result",
]
def filter_name(self, queryset, name, value):
if not value:
return queryset
matching_task_types = [
task_type
for task_type, label in PaperlessTask.TaskType.choices
if value.lower() in str(label).lower()
]
matching_trigger_sources = [
trigger_source
for trigger_source, label in PaperlessTask.TriggerSource.choices
if value.lower() in str(label).lower()
]
return queryset.filter(
Q(input_data__filename__icontains=value)
| Q(task_type__in=matching_task_types)
| Q(trigger_source__in=matching_trigger_sources),
)
def filter_result(self, queryset, name, value):
if not value:
return queryset
query = Q(result_data__reason__icontains=value) | Q(
result_data__error_message__icontains=value,
)
try:
numeric_value = int(value)
except (TypeError, ValueError):
pass
else:
query |= Q(result_data__document_id=numeric_value) | Q(
result_data__duplicate_of=numeric_value,
)
if "duplicate" in value.lower():
query |= Q(result_data__duplicate_of__isnull=False)
return queryset.filter(query)
def filter_is_complete(self, queryset, name, value):
if value:
@@ -169,6 +169,10 @@ class FileStabilityTracker:
self._tracked.pop(path, None)
yield path
def is_tracking(self, path: Path) -> bool:
"""Check whether a path is currently being tracked for stability."""
return path.resolve() in self._tracked
def has_pending_files(self) -> bool:
"""Check if there are files waiting for stability check."""
return len(self._tracked) > 0
@@ -370,6 +374,16 @@ class Command(BaseCommand):
# Testing timeout in seconds
testing_timeout_s: Final[float] = 0.5
# How often to perform a full-glob rescan of the consume directory as a
# safety net. Each watchfiles watcher is torn down and recreated on every
# batch to reconfigure its timeout, and a fresh watcher silently adopts the
# current directory contents as its baseline. A file that appears between
# one batch and the next watcher's baseline is therefore never reported and
# would sit in the consume directory forever. This periodic rescan re-injects
# such files into the stability tracker (see GH issue #13011). Not currently
# user-configurable; instances may override for testing.
rescan_interval_s: float = 300.0
def add_arguments(self, parser) -> None:
parser.add_argument(
"directory",
@@ -425,7 +439,7 @@ class Command(BaseCommand):
)
# Process existing files
self._process_existing_files(
queued = self._process_existing_files(
directory=directory,
recursive=recursive,
subdirs_as_tags=subdirs_as_tags,
@@ -445,6 +459,7 @@ class Command(BaseCommand):
polling_interval=polling_interval,
stability_delay=stability_delay,
is_testing=is_testing,
queued=queued,
)
logger.debug("Consumer exiting")
@@ -456,11 +471,18 @@ class Command(BaseCommand):
recursive: bool,
subdirs_as_tags: bool,
consumer_filter: ConsumerFilter,
) -> None:
"""Process any existing files in the consumption directory."""
) -> set[Path]:
"""
Process any existing files in the consumption directory.
Returns the set of resolved paths that were queued, so the watch loop
can seed its in-flight set and avoid re-queuing them on the first
rescan before the consume tasks have removed them from disk.
"""
logger.info(f"Processing existing files in {directory}")
glob_pattern = "**/*" if recursive else "*"
queued: set[Path] = set()
for filepath in directory.glob(glob_pattern):
# Use filter to check if file should be processed
@@ -475,6 +497,48 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
queued.add(filepath.resolve())
return queued
def _rescan_existing_files(
self,
*,
directory: Path,
recursive: bool,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
) -> None:
"""
Re-inject on-disk files the watcher never reported into the tracker.
Acts as a safety net for files stranded by the watcher-recreation gap
(see ``rescan_interval_s``). Files already being tracked or already
queued and awaiting consumption are skipped, so a file is never queued
twice. Queued paths that have since left the directory are pruned so a
later file reusing the same name is not skipped forever.
"""
# Prune in-flight paths that have left the directory
for path in list(queued):
if not path.exists():
queued.discard(path)
glob_pattern = "**/*" if recursive else "*"
for filepath in directory.glob(glob_pattern):
if not filepath.is_file():
continue
if not consumer_filter(Change.added, str(filepath)):
continue
resolved = filepath.resolve()
if tracker.is_tracking(resolved) or resolved in queued:
continue
logger.debug(f"Rescan found untracked file: {resolved}")
tracker.track(resolved, Change.added)
def _watch_directory(
self,
@@ -486,11 +550,24 @@ class Command(BaseCommand):
polling_interval: float,
stability_delay: float,
is_testing: bool,
queued: set[Path] | None = None,
) -> None:
"""Watch directory for changes and process stable files."""
use_polling = polling_interval > 0
poll_delay_ms = int(polling_interval * 1000) if use_polling else 0
# Resolved paths that have been queued and are awaiting consumption.
# Seeded from the startup scan so the first rescan does not re-queue
# files whose consume tasks have not yet removed them from disk.
queued = set() if queued is None else queued
# Full-glob safety net cadence (0 disables)
rescan_interval_s = self.rescan_interval_s
rescan_timeout_ms = (
int(rescan_interval_s * 1000) if rescan_interval_s > 0 else 0
)
last_rescan = monotonic()
if use_polling:
logger.info(
f"Watching {directory} using polling (interval: {polling_interval}s)",
@@ -505,6 +582,20 @@ class Command(BaseCommand):
stability_timeout_ms = int(stability_delay * 1000)
testing_timeout_ms = int(self.testing_timeout_s * 1000)
def cap_for_rescan(ms: int) -> int:
"""
Ensure the watch loop wakes often enough to run the rescan.
``watch()`` blocks for up to ``rust_timeout``, so the rescan can
only run that often. A timeout of 0 means "wait indefinitely",
which would never wake to rescan; cap it at the rescan interval.
"""
if rescan_timeout_ms <= 0:
return ms
if ms <= 0:
return rescan_timeout_ms
return min(ms, rescan_timeout_ms)
# Calculate appropriate timeout for watch loop
# In polling mode, rust_timeout must be significantly longer than poll_delay_ms
# to ensure poll cycles can complete before timing out
@@ -522,6 +613,8 @@ class Command(BaseCommand):
# Not testing, wait indefinitely for first event
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
self.stop_flag.clear()
while not self.stop_flag.is_set():
@@ -551,10 +644,26 @@ class Command(BaseCommand):
consumption_dir=directory,
subdirs_as_tags=subdirs_as_tags,
)
# Remember it so the rescan does not re-queue it while
# the consume task has yet to remove it from disk
queued.add(stable_path)
# Exit watch loop to reconfigure timeout
break
# Periodic full-glob safety net for files the watcher missed
if rescan_timeout_ms > 0 and (
monotonic() - last_rescan >= rescan_interval_s
):
self._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
last_rescan = monotonic()
# Determine next timeout
if tracker.has_pending_files():
# Check pending files at stability interval
@@ -572,6 +681,8 @@ class Command(BaseCommand):
# No pending files, wait indefinitely
timeout_ms = 0
timeout_ms = cap_for_rescan(timeout_ms)
except KeyboardInterrupt: # pragma: nocover
logger.info("Received interrupt, stopping consumer")
self.stop_flag.set()
@@ -2,6 +2,7 @@ from typing import Any
from documents.management.commands.base import PaperlessCommand
from documents.tasks import llmindex_index
from paperless_ai.indexing import llm_index_compact
class Command(PaperlessCommand):
@@ -12,9 +13,12 @@ class Command(PaperlessCommand):
def add_arguments(self, parser: Any) -> None:
super().add_arguments(parser)
parser.add_argument("command", choices=["rebuild", "update"])
parser.add_argument("command", choices=["rebuild", "update", "compact"])
def handle(self, *args: Any, **options: Any) -> None:
if options["command"] == "compact":
llm_index_compact()
return
llmindex_index(
rebuild=options["command"] == "rebuild",
iter_wrapper=lambda docs: self.track(
+4
View File
@@ -8,11 +8,15 @@ from documents.search._backend import get_backend
from documents.search._backend import reset_backend
from documents.search._schema import needs_rebuild
from documents.search._schema import wipe_index
from documents.search._translate import InvalidDateQuery
from documents.search._translate import SearchQueryError
__all__ = [
"InvalidDateQuery",
"SearchHit",
"SearchIndexLockError",
"SearchMode",
"SearchQueryError",
"TantivyBackend",
"TantivyRelevanceList",
"WriteBatch",
+18 -2
View File
@@ -866,8 +866,24 @@ class TantivyBackend:
final_query = self._apply_permission_filter(mlt_query, user)
effective_limit = limit if limit is not None else searcher.num_docs
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
try:
# Fetch one extra to account for excluding the original document
results = searcher.search(final_query, limit=effective_limit + 1)
except BaseException: # pragma: no cover
# Tantivy 0.26 panics in BM25 idf scoring when the index holds
# soft-deleted documents (doc_freq can exceed the alive doc count),
# which only surfaces for the More Like This query. The panic crosses
# the pyo3 boundary as a `pyo3_runtime.PanicException` — a
# BaseException, not an Exception — so catch BaseException and degrade
# to "no similar documents" instead of bubbling a 500 to the client.
# Fixed upstream: https://github.com/quickwit-oss/tantivy/pull/2964
# Remove once the bundled tantivy includes that fix.
logger.warning(
"More Like This scoring panicked (likely stale tantivy segment "
"stats after deletions); returning no results. A search index "
"reindex will rebuild consistent statistics.",
)
return []
addrs = [addr for _score, addr in results.hits]
all_ids = cast("list[int]", searcher.fast_field_values("id", addrs))
+163
View File
@@ -0,0 +1,163 @@
from __future__ import annotations
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
from dateutil.relativedelta import relativedelta
if TYPE_CHECKING:
from datetime import tzinfo
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _quarter_start(d: date) -> date:
"""Return the first day of the calendar quarter containing ``d``."""
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
def _midnight(d: date, tz: tzinfo) -> datetime:
"""Convert a calendar date at local-timezone midnight to a UTC datetime."""
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _keyword_bounds(keyword: str, tz: tzinfo) -> tuple[date, date]:
"""
Map a relative date keyword to ``(start, exclusive_end)`` calendar dates.
``tz`` only determines what "today" is; the caller decides how the returned
dates become UTC datetime boundaries (date-only vs. local-midnight offset).
"""
today = datetime.now(tz).date()
if keyword == _TODAY:
return today, today + timedelta(days=1)
if keyword == _YESTERDAY:
return today - timedelta(days=1), today
if keyword == _PREVIOUS_WEEK:
this_monday = today - timedelta(days=today.weekday())
return this_monday - timedelta(weeks=1), this_monday
if keyword == _THIS_MONTH:
first = today.replace(day=1)
return first, first + relativedelta(months=1)
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
return this_first - relativedelta(months=1), this_first
if keyword == _THIS_YEAR:
return date(today.year, 1, 1), date(today.year + 1, 1, 1)
if keyword == _PREVIOUS_YEAR:
return date(today.year - 1, 1, 1), date(today.year, 1, 1)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
return this_quarter - relativedelta(months=3), this_quarter
raise ValueError(f"Unknown keyword: {keyword}")
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
start, end = _keyword_bounds(keyword, tz)
lo = datetime(start.year, start.month, start.day, tzinfo=UTC)
hi = datetime(end.year, end.month, end.day, tzinfo=UTC)
return _iso_range(lo, hi)
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
start, end = _keyword_bounds(keyword, tz)
return _iso_range(_midnight(start, tz), _midnight(end, tz))
def _precision_bounds(digits: str) -> tuple[date, date] | None:
"""
Map a 4/6/8-digit date token to (start, exclusive_end) calendar dates.
YYYY -> whole year, YYYYMM -> whole month, YYYYMMDD -> single day.
Returns None for any unparsable or out-of-range value (e.g. month 23),
so callers can emit a no-match clause instead of erroring (Whoosh parity).
"""
try:
if len(digits) == 4:
year = int(digits)
return date(year, 1, 1), date(year + 1, 1, 1)
if len(digits) == 6:
year, month = int(digits[:4]), int(digits[4:6])
start = date(year, month, 1)
end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
return start, end
if len(digits) == 8:
start = date(int(digits[:4]), int(digits[4:6]), int(digits[6:8]))
return start, start + timedelta(days=1)
except ValueError:
return None
return None
def _utc_bounds_for_field(
field: str,
start: date,
end: date,
tz: tzinfo,
) -> tuple[datetime, datetime]:
"""
Convert calendar-date bounds to UTC datetimes per the field's storage type.
For DateField (``created``) the bounds are UTC midnight (no offset). For
DateTimeField (``added``/``modified``) the bounds are local-tz midnight
converted to UTC, matching how each field is indexed.
"""
if field in _DATE_ONLY_FIELDS:
return (
datetime(start.year, start.month, start.day, tzinfo=UTC),
datetime(end.year, end.month, end.day, tzinfo=UTC),
)
return (
datetime(start.year, start.month, start.day, tzinfo=tz).astimezone(UTC),
datetime(end.year, end.month, end.day, tzinfo=tz).astimezone(UTC),
)
def _field_range_from_dates(field: str, start: date, end: date, tz: tzinfo) -> str:
"""Build a Tantivy ``field:[lo TO hi]`` ISO range from calendar-date bounds."""
lo, hi = _utc_bounds_for_field(field, start, end, tz)
return f"{field}:{_iso_range(lo, hi)}"
+27 -405
View File
@@ -1,88 +1,35 @@
from __future__ import annotations
import logging
from datetime import UTC
from datetime import date
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Final
import regex
import tantivy
from dateutil.relativedelta import relativedelta
from django.conf import settings
from documents.search._dates import (
_date_only_range, # noqa: F401 — re-exported for test imports
)
from documents.search._dates import (
_datetime_range, # noqa: F401 — re-exported for test imports
)
from documents.search._tokenizer import simple_search_tokens
from documents.search._translate import SearchQueryError
from documents.search._translate import translate_query
if TYPE_CHECKING:
from datetime import tzinfo
from django.contrib.auth.base_user import AbstractBaseUser
logger = logging.getLogger("paperless.search")
# Maximum seconds any single regex substitution may run.
# Prevents ReDoS on adversarial user-supplied query strings.
_REGEX_TIMEOUT: Final[float] = 1.0
_DATE_ONLY_FIELDS = frozenset({"created"})
_TODAY: Final[str] = "today"
_YESTERDAY: Final[str] = "yesterday"
_PREVIOUS_WEEK: Final[str] = "previous week"
_THIS_MONTH: Final[str] = "this month"
_PREVIOUS_MONTH: Final[str] = "previous month"
_THIS_YEAR: Final[str] = "this year"
_PREVIOUS_YEAR: Final[str] = "previous year"
_PREVIOUS_QUARTER: Final[str] = "previous quarter"
_DATE_KEYWORDS = frozenset(
{
_TODAY,
_YESTERDAY,
_PREVIOUS_WEEK,
_THIS_MONTH,
_PREVIOUS_MONTH,
_THIS_YEAR,
_PREVIOUS_YEAR,
_PREVIOUS_QUARTER,
},
)
_DATE_KEYWORD_PATTERN = "|".join(
sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True),
)
_FIELD_DATE_RE = regex.compile(
rf"""(?<!\w)(?P<field>created|modified|added)\s*:\s*(?:
(?P<quote>["'])(?P<quoted>{_DATE_KEYWORD_PATTERN})(?P=quote)
|
(?P<bare>{_DATE_KEYWORD_PATTERN})(?![\w-])
)""",
regex.IGNORECASE | regex.VERBOSE,
)
_COMPACT_DATE_RE = regex.compile(r"\b(\d{14})\b")
_RELATIVE_RANGE_RE = regex.compile(
r"\[now([+-]\d+[dhm])?\s+TO\s+now([+-]\d+[dhm])?\]",
regex.IGNORECASE,
)
# Whoosh-style relative date range: e.g. [-1 week to now], [-7 days to now]
_WHOOSH_REL_RANGE_RE = regex.compile(
r"\[-(?P<n>\d+)\s+(?P<unit>second|minute|hour|day|week|month|year)s?\s+to\s+now\]",
regex.IGNORECASE,
)
# Whoosh-style 8-digit date: field:YYYYMMDD — field-aware so timezone can be applied correctly.
# Scoped to date fields only; numeric fields (asn, id, page_count, ...) must not be rewritten.
_DATE8_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):(?P<date8>\d{8})\b",
)
_YEAR_RANGE_RE = regex.compile(
r"(?<!\w)(?P<field>created|modified|added):\[(?P<y1>\d{4})\s+TO\s+(?P<y2>\d{4})\]",
regex.IGNORECASE,
)
# Tantivy syntax error: " - " and " + " with spaces on both sides are invalid because
# the NOT/MUST operators require no space between the operator and the term.
# In natural-language queries (e.g., "H52.1 - Kurzsichtigkeit"), the dash is a separator.
_SPACED_OPERATOR_RE = regex.compile(r"\s+[-+]\s+")
_TRAILING_OPERATOR_RE = regex.compile(r"\s+[-+]+\s*$")
# Matches CJK/Hangul characters so queries can be routed to bigram fields.
# Uses Unicode properties to cover all blocks including Extension B+ planes.
_CJK_RE: Final = regex.compile(r"[\p{Han}\p{Hiragana}\p{Katakana}\p{Hangul}]+")
@@ -117,303 +64,12 @@ def _build_cjk_query(
return None
def _fmt(dt: datetime) -> str:
"""Format a datetime as an ISO 8601 UTC string for use in Tantivy range queries."""
return dt.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
def _iso_range(lo: datetime, hi: datetime) -> str:
"""Format a [lo TO hi] range string in ISO 8601 for Tantivy query syntax."""
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
def _date_only_range(keyword: str, tz: tzinfo) -> str:
"""
For `created` (DateField): use the local calendar date, converted to
midnight UTC boundaries. No offset arithmetic — date only.
"""
today = datetime.now(tz).date()
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
lo = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, lo + timedelta(days=1))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
lo = datetime(y.year, y.month, y.day, tzinfo=UTC)
hi = datetime(today.year, today.month, today.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
lo = datetime(last_mon.year, last_mon.month, last_mon.day, tzinfo=UTC)
hi = datetime(this_mon.year, this_mon.month, this_mon.day, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_MONTH:
lo = datetime(today.year, today.month, 1, tzinfo=UTC)
if today.month == 12:
hi = datetime(today.year + 1, 1, 1, tzinfo=UTC)
else:
hi = datetime(today.year, today.month + 1, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _PREVIOUS_MONTH:
if today.month == 1:
lo = datetime(today.year - 1, 12, 1, tzinfo=UTC)
else:
lo = datetime(today.year, today.month - 1, 1, tzinfo=UTC)
hi = datetime(today.year, today.month, 1, tzinfo=UTC)
return _iso_range(lo, hi)
if keyword == _THIS_YEAR:
lo = datetime(today.year, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year + 1, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_YEAR:
lo = datetime(today.year - 1, 1, 1, tzinfo=UTC)
return _iso_range(lo, datetime(today.year, 1, 1, tzinfo=UTC))
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
lo = datetime(
last_quarter.year,
last_quarter.month,
last_quarter.day,
tzinfo=UTC,
)
hi = datetime(
this_quarter.year,
this_quarter.month,
this_quarter.day,
tzinfo=UTC,
)
return _iso_range(lo, hi)
raise ValueError(f"Unknown keyword: {keyword}")
def _datetime_range(keyword: str, tz: tzinfo) -> str:
"""
For `added` / `modified` (DateTimeField, stored as UTC): convert local day
boundaries to UTC — full offset arithmetic required.
"""
now_local = datetime.now(tz)
today = now_local.date()
def _midnight(d: date) -> datetime:
return datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
def _quarter_start(d: date) -> date:
return date(d.year, ((d.month - 1) // 3) * 3 + 1, 1)
if keyword == _TODAY:
return _iso_range(_midnight(today), _midnight(today + timedelta(days=1)))
if keyword == _YESTERDAY:
y = today - timedelta(days=1)
return _iso_range(_midnight(y), _midnight(today))
if keyword == _PREVIOUS_WEEK:
this_mon = today - timedelta(days=today.weekday())
last_mon = this_mon - timedelta(weeks=1)
return _iso_range(_midnight(last_mon), _midnight(this_mon))
if keyword == _THIS_MONTH:
first = today.replace(day=1)
if today.month == 12:
next_first = date(today.year + 1, 1, 1)
else:
next_first = date(today.year, today.month + 1, 1)
return _iso_range(_midnight(first), _midnight(next_first))
if keyword == _PREVIOUS_MONTH:
this_first = today.replace(day=1)
if today.month == 1:
last_first = date(today.year - 1, 12, 1)
else:
last_first = date(today.year, today.month - 1, 1)
return _iso_range(_midnight(last_first), _midnight(this_first))
if keyword == _THIS_YEAR:
return _iso_range(
_midnight(date(today.year, 1, 1)),
_midnight(date(today.year + 1, 1, 1)),
)
if keyword == _PREVIOUS_YEAR:
return _iso_range(
_midnight(date(today.year - 1, 1, 1)),
_midnight(date(today.year, 1, 1)),
)
if keyword == _PREVIOUS_QUARTER:
this_quarter = _quarter_start(today)
last_quarter = this_quarter - relativedelta(months=3)
return _iso_range(_midnight(last_quarter), _midnight(this_quarter))
raise ValueError(f"Unknown keyword: {keyword}")
def _rewrite_compact_date(query: str) -> str:
"""Rewrite Whoosh compact date tokens (14-digit YYYYMMDDHHmmss) to ISO 8601."""
def _sub(m: regex.Match[str]) -> str:
raw = m.group(1)
try:
dt = datetime(
int(raw[0:4]),
int(raw[4:6]),
int(raw[6:8]),
int(raw[8:10]),
int(raw[10:12]),
int(raw[12:14]),
tzinfo=UTC,
)
return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
except ValueError:
return str(m.group(0))
try:
return _COMPACT_DATE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (compact date rewrite timed out)",
)
def _rewrite_relative_range(query: str) -> str:
"""Rewrite Whoosh relative ranges ([now-7d TO now]) to concrete ISO 8601 UTC boundaries."""
def _sub(m: regex.Match[str]) -> str:
now = datetime.now(UTC)
def _offset(s: str | None) -> timedelta:
if not s:
return timedelta(0)
sign = 1 if s[0] == "+" else -1
n, unit = int(s[1:-1]), s[-1]
return (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
lo, hi = now + _offset(m.group(1)), now + _offset(m.group(2))
if lo > hi:
lo, hi = hi, lo
return f"[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _RELATIVE_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (relative range rewrite timed out)",
)
def _rewrite_whoosh_relative_range(query: str) -> str:
"""Rewrite Whoosh-style relative date ranges ([-N unit to now]) to ISO 8601.
Supports: second, minute, hour, day, week, month, year (singular and plural).
Example: ``added:[-1 week to now]`` → ``added:[2025-01-01T… TO 2025-01-08T…]``
"""
now = datetime.now(UTC)
def _sub(m: regex.Match[str]) -> str:
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
lo = now - delta_map[unit]
return f"[{_fmt(lo)} TO {_fmt(now)}]"
try:
return _WHOOSH_REL_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (Whoosh relative range rewrite timed out)",
)
def _rewrite_8digit_date(query: str, tz: tzinfo) -> str:
"""Rewrite field:YYYYMMDD date tokens to an ISO 8601 day range.
Runs after ``_rewrite_compact_date`` so 14-digit timestamps are already
converted and won't spuriously match here.
For DateField fields (e.g. ``created``) uses UTC midnight boundaries.
For DateTimeField fields (e.g. ``added``, ``modified``) uses local TZ
midnight boundaries converted to UTC — matching the ``_datetime_range``
behaviour for keyword dates.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
raw = m.group("date8")
try:
year, month, day = int(raw[0:4]), int(raw[4:6]), int(raw[6:8])
d = date(year, month, day)
if field in _DATE_ONLY_FIELDS:
lo = datetime(d.year, d.month, d.day, tzinfo=UTC)
hi = lo + timedelta(days=1)
else:
# DateTimeField: use local-timezone midnight → UTC
lo = datetime(d.year, d.month, d.day, tzinfo=tz).astimezone(UTC)
hi = datetime(
(d + timedelta(days=1)).year,
(d + timedelta(days=1)).month,
(d + timedelta(days=1)).day,
tzinfo=tz,
).astimezone(UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
except ValueError:
return m.group(0)
try:
return _DATE8_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (8-digit date rewrite timed out)",
)
def _rewrite_year_range(query: str) -> str:
"""Rewrite Whoosh-style year-only date ranges to ISO 8601 UTC boundaries.
Converts ``field:[YYYY TO YYYY]`` to a full ISO 8601 datetime range.
The upper bound is the start of the year after the end year (exclusive),
matching the Whoosh convention of treating year-only ranges as full-year spans.
"""
def _sub(m: regex.Match[str]) -> str:
field = m.group("field")
y1, y2 = int(m.group("y1")), int(m.group("y2"))
# Whoosh swaps a reversed range when both years are explicit
# (whoosh.util.times.timespan.disambiguated); match that so a backwards
# range spans the intended years instead of matching nothing.
lo_year, hi_year = min(y1, y2), max(y1, y2)
lo = datetime(lo_year, 1, 1, tzinfo=UTC)
hi = datetime(hi_year + 1, 1, 1, tzinfo=UTC)
return f"{field}:[{_fmt(lo)} TO {_fmt(hi)}]"
try:
return _YEAR_RANGE_RE.sub(_sub, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (year range rewrite timed out)")
def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
"""
Rewrite natural date syntax to ISO 8601 format for Tantivy compatibility.
Performs the first stage of query preprocessing, converting various date
formats and keywords to ISO 8601 datetime ranges that Tantivy can parse:
- Compact 14-digit dates (YYYYMMDDHHmmss)
- Whoosh relative ranges ([-7 days to now], [now-1h TO now+2h])
- 8-digit dates with field awareness (created:20240115)
- Natural keywords (field:today, field:"previous quarter", etc.)
Delegates to ``translate_query`` which handles all date forms, comma
expansion, field aliasing, relative ranges, and operator normalization.
Args:
query: Raw user query string
@@ -425,35 +81,15 @@ def rewrite_natural_date_keywords(query: str, tz: tzinfo) -> str:
Note:
Bare keywords without field prefixes pass through unchanged.
"""
query = _rewrite_compact_date(query)
query = _rewrite_whoosh_relative_range(query)
query = _rewrite_year_range(query)
query = _rewrite_8digit_date(query, tz)
query = _rewrite_relative_range(query)
def _replace(m: regex.Match[str]) -> str:
field = m.group("field")
keyword = (m.group("quoted") or m.group("bare")).lower()
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(keyword, tz)}"
return f"{field}:{_datetime_range(keyword, tz)}"
try:
return _FIELD_DATE_RE.sub(_replace, query, timeout=_REGEX_TIMEOUT)
except TimeoutError: # pragma: no cover
raise ValueError(
"Query too complex to process (date keyword rewrite timed out)",
)
return translate_query(query, tz)
def normalize_query(query: str) -> str:
"""
Normalize query syntax for better search behavior.
Expands comma-separated field values to explicit AND clauses and
collapses excessive whitespace for cleaner parsing:
- tag:foo,bar → tag:foo AND tag:bar
- multiple spaces → single spaces
Delegates to ``translate_query`` which handles comma expansion, whitespace
collapsing, operator normalization, and field aliasing.
Args:
query: Query string after date rewriting
@@ -461,29 +97,7 @@ def normalize_query(query: str) -> str:
Returns:
Normalized query string ready for Tantivy parsing
"""
def _expand(m: regex.Match[str]) -> str:
field = m.group(1)
values = [v.strip() for v in m.group(2).split(",") if v.strip()]
return " AND ".join(f"{field}:{v}" for v in values)
try:
query = regex.sub(
r"(\w+):([^\s\[\]]+(?:,[^\s\[\]]+)+)",
_expand,
query,
timeout=_REGEX_TIMEOUT,
)
query = regex.sub(r" {2,}", " ", query, timeout=_REGEX_TIMEOUT).strip()
# Strip trailing dangling operators before Tantivy sees them.
query = _TRAILING_OPERATOR_RE.sub("", query, timeout=_REGEX_TIMEOUT).strip()
# Replace " - " / " + " with a space: Tantivy requires no space between
# the operator and its operand (-term / +term), so spaces on both sides
# means this is a natural-language separator, not a query operator.
query = _SPACED_OPERATOR_RE.sub(" ", query, timeout=_REGEX_TIMEOUT).strip()
return query
except TimeoutError: # pragma: no cover
raise ValueError("Query too complex to process (normalization timed out)")
return translate_query(query, UTC)
def build_permission_filter(
@@ -603,8 +217,16 @@ def parse_user_query(
as a post-search score filter, not during query construction.
"""
query_str = rewrite_natural_date_keywords(raw_query, tz)
query_str = normalize_query(query_str)
try:
query_str = translate_query(raw_query, tz)
except SearchQueryError:
# Intentional, user-fixable error (e.g. an unparsable date). Propagate so
# the view can return a 400 with a helpful message rather than falling
# back to the raw (still-invalid) query.
raise
except Exception: # pragma: no cover - defensive
logger.warning("Query translation failed; using raw query", exc_info=True)
query_str = raw_query
exact = index.parse_query(
query_str,
+566
View File
@@ -0,0 +1,566 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC
from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TypeAlias
import regex
from dateutil.relativedelta import relativedelta
from documents.search._dates import _DATE_KEYWORDS
from documents.search._dates import _DATE_ONLY_FIELDS
from documents.search._dates import _date_only_range
from documents.search._dates import _datetime_range
from documents.search._dates import _field_range_from_dates
from documents.search._dates import _fmt
from documents.search._dates import _precision_bounds
from documents.search._dates import _utc_bounds_for_field
# Compiled regex that matches any known multi-word (or single-word) date keyword
# at the start of a match position, longest alternatives first so "previous week"
# wins over a hypothetical shorter "previous".
_KEYWORD_VALUE_RE = regex.compile(
"|".join(sorted((regex.escape(k) for k in _DATE_KEYWORDS), key=len, reverse=True)),
regex.IGNORECASE,
)
if TYPE_CHECKING:
from datetime import tzinfo
# TODO: this module translates date queries into Tantivy *string* syntax, which
# forces a workaround for something Tantivy's string parser cannot express on
# date fields: open-ended ranges use far-past/far-future string sentinels
# (OPEN_LO/OPEN_HI). These can be replaced with a real tantivy.Query object
# (Query.range_query(..., None) for open bounds) once tantivy-py accepts Python
# datetimes in range_query/term_query on Date fields. That support exists on
# tantivy-py master (PRs #655 + #666) but postdates the pinned 0.26.0 wheel, so
# it is blocked only on a published release > 0.26.0 and a dependency bump.
# (Unparsable dates now raise InvalidDateQuery -> HTTP 400 rather than using a
# no-match string sentinel.)
# Fields that store exact, non-analyzed comma-joined tokens in the index and so
# need explicit comma->AND expansion (Whoosh KEYWORD(commas=True) set).
MULTI_VALUE_FIELDS = frozenset({"tag", "tag_id", "viewer_id"})
# Date fields whose values/ranges get rewritten to RFC3339 Tantivy ranges.
DATE_FIELDS = frozenset({"created", "modified", "added"})
# Field aliases: Whoosh (v2) field names that were renamed in the Tantivy schema.
# Preserved here so v2 queries using the old names continue to work without 400
# errors instead of silently failing. Applied by _render to non-date field tokens.
FIELD_ALIASES: dict[str, str] = {
"type": "document_type",
"type_id": "document_type_id",
"path": "storage_path",
"path_id": "storage_path_id",
}
# Known schema fields: a comma immediately followed by ``<known>:`` is a clause
# separator. Restricting to known fields prevents URL-like ``http:`` misfires.
KNOWN_FIELDS = frozenset(
{
"title",
"content",
"correspondent",
"document_type",
"type", # v2 alias -> document_type
"storage_path",
"path", # v2 alias -> storage_path
"tag",
"tag_id",
"correspondent_id",
"document_type_id",
"type_id", # v2 alias -> document_type_id
"storage_path_id",
"path_id", # v2 alias -> storage_path_id
"owner_id",
"viewer_id",
"asn",
"page_count",
"num_notes",
"created",
"modified",
"added",
"original_filename",
"checksum",
"notes",
"custom_fields",
},
)
_FIELD_RE = regex.compile(r"(?P<field>\w+):")
# Matches the TO separator inside a range bracket. Handles three forms:
# middle: "lo TO hi" (either lo or hi may be empty)
# trailing: "lo TO" (open upper bound)
# leading: "TO hi" (open lower bound)
# Bounds MAY contain internal spaces (e.g. "-7 days"), so we use .*? / .+?
# and split on the whitespace-delimited " TO " / " to " separator.
_RANGE_RE = regex.compile(
r"^\s*(?P<lo>.*?)\s+[Tt][Oo]\s+(?P<hi>.+?)\s*$"
r"|"
r"^\s*(?P<lo2>.+?)\s+[Tt][Oo]\s*$"
r"|"
r"^\s*[Tt][Oo]\s+(?P<hi2>.+?)\s*$",
)
@dataclass(frozen=True, slots=True)
class FieldValue:
field: str
value: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class FieldValueList:
field: str
values: tuple[str, ...]
@dataclass(frozen=True, slots=True)
class FieldRange:
field: str
open: str
lo: str
hi: str
close: str
# Produced by the comma-resolution pass (not by scan()).
@dataclass(frozen=True, slots=True)
class Comma:
pass
@dataclass(frozen=True, slots=True)
class Passthrough:
raw: str
Token: TypeAlias = FieldValue | FieldValueList | FieldRange | Comma | Passthrough
_CLOSE: dict[str, str] = {"[": "]", "{": "}"}
def scan(query: str) -> list[Token]:
"""
Tokenize a raw query into date/comma-aware tokens, leaving everything else
as verbatim ``Passthrough`` runs. Non-recursive: finds the first matching
close bracket/quote. Nested brackets are not valid Tantivy range syntax and
pass through verbatim on mismatch.
"""
tokens: list[Token] = []
buf: list[str] = [] # accumulates passthrough chars
i, n = 0, len(query)
while i < n:
matched = _match_field_token(query, i)
if matched is None:
buf.append(query[i])
i += 1
continue
token, i = matched
_flush(buf, tokens)
tokens.append(token)
i = _maybe_comma(query, i, tokens)
_flush(buf, tokens)
return tokens
def _flush(buf: list[str], tokens: list[Token]) -> None:
"""Emit any accumulated passthrough characters as a single token."""
if buf:
tokens.append(Passthrough("".join(buf)))
buf.clear()
def _at_word_boundary(query: str, i: int) -> bool:
"""A field token may begin only at the start or after a non-word character."""
return i == 0 or not (query[i - 1].isalnum() or query[i - 1] == "_")
def _match_field_token(query: str, i: int) -> tuple[Token, int] | None:
"""
If a known ``field:`` token starts at ``i``, consume it and return
``(token, end_index)``; otherwise return None so the caller treats the
character as passthrough. Handles both ``field:[range]`` and ``field:value``,
and returns None when the range/value cannot be consumed.
"""
m = _FIELD_RE.match(query, i)
if m is None or m.group("field") not in KNOWN_FIELDS:
return None
if not _at_word_boundary(query, i):
return None
field = m.group("field")
j = m.end()
if j < len(query) and query[j] in "[{":
return _consume_range(query, j, field)
consumed = _consume_field_value(query, field, j)
if consumed is None:
return None
value, end = consumed
return FieldValue(field, value), end
def _consume_field_value(query: str, field: str, start: int) -> tuple[str, int] | None:
"""
Consume a field value starting at ``start``: a multi-word date keyword phrase
(date fields only), or a bare/quoted value, then absorb any comma-joined
continuation that is not a clause separator. ``resolve_commas`` later splits a
multi-value field's joined value into a ``FieldValueList``; for other fields
the comma stays literal.
"""
n = len(query)
consumed = None
if field in DATE_FIELDS:
km = _KEYWORD_VALUE_RE.match(query, start)
if km is not None and (km.end() >= n or query[km.end()] in " \t),"):
consumed = (km.group(0), km.end())
if consumed is None:
consumed = _consume_value(query, start)
if consumed is None:
return None
value, k = consumed
while k < n and query[k] == ",":
if _looks_like_known_field(query, k + 1):
break # clause separator: left for _maybe_comma to emit a Comma()
more = _consume_value(query, k + 1)
if more is None:
break
value = f"{value},{more[0]}"
k = more[1]
return value, k
def _consume_range(
query: str,
start: int,
field: str,
) -> tuple[FieldRange, int] | None:
"""Consume ``[lo TO hi]`` / ``{lo TO hi}`` from ``start`` (the bracket)."""
open_br = query[start]
close_br = _CLOSE[open_br]
end = query.find(close_br, start + 1)
if end == -1:
return None
inner = query[start + 1 : end]
m = _RANGE_RE.match(inner)
if m is not None:
if m.group("lo") is not None or m.group("hi") is not None:
# Middle form: "lo TO hi" (either may be empty string)
lo = (m.group("lo") or "").strip()
hi = (m.group("hi") or "").strip()
elif m.group("lo2") is not None:
# Trailing form: "lo TO"
lo = m.group("lo2").strip()
hi = ""
else:
# Leading form: "TO hi"
lo = ""
hi = (m.group("hi2") or "").strip()
else:
lo, hi = inner.strip(), ""
return FieldRange(field, open_br, lo, hi, close_br), end + 1
def _consume_value(query: str, start: int) -> tuple[str, int] | None:
"""Consume a bare or quoted field value from ``start``, stopping at comma."""
n = len(query)
if start >= n or query[start] in " \t":
return None
if query[start] in "\"'":
quote = query[start]
end = query.find(quote, start + 1)
if end == -1:
return None
return query[start : end + 1], end + 1
j = start
while j < n and query[j] not in " \t),":
j += 1
return query[start:j], j
def _looks_like_known_field(query: str, pos: int) -> bool:
"""True if a known ``field:`` token starts at ``pos``."""
m = _FIELD_RE.match(query, pos)
return bool(m and m.group("field") in KNOWN_FIELDS)
def _maybe_comma(query: str, i: int, tokens: list) -> int:
"""If a clause-separator comma follows at ``i``, emit ``Comma()`` and advance."""
if i < len(query) and query[i] == "," and _looks_like_known_field(query, i + 1):
tokens.append(Comma())
return i + 1
return i
def resolve_commas(tokens: list) -> list:
"""
Collapse value-list commas into ``FieldValueList`` and keep clause-separator
commas as ``Comma``. (Clause-sep commas are already emitted by ``scan`` via
the value-stop logic; this pass folds value-lists.)
"""
out: list = []
for tok in tokens:
if (
isinstance(tok, FieldValue)
and tok.field in MULTI_VALUE_FIELDS
and "," in tok.value
):
values = tuple(v for v in tok.value.split(",") if v)
out.append(FieldValueList(tok.field, values))
else:
out.append(tok)
return out
class SearchQueryError(ValueError):
"""
Base for user-fixable search query errors.
Carries a message safe to surface to the user (no internal details). The view
layer catches this and returns an HTTP 400, so any future subclass (unknown
field, malformed range, wrapped parser errors) gets the same treatment.
"""
class InvalidDateQuery(SearchQueryError):
"""Raised when a date field value or range bound cannot be parsed."""
def __init__(self, field: str, value: str) -> None:
self.field = field
self.value = value
super().__init__(f"Invalid date value {value!r} for field {field!r}.")
_DIGITS_RE = regex.compile(r"^\d{4}(?:\d{2}){0,2}$")
_ISO_RE = regex.compile(r"^\d{4}(?:-\d{2}(?:-\d{2})?)?$")
def translate_scalar(field: str, value: str, tz: tzinfo) -> str:
"""Translate a bare date-field value to a Tantivy range string."""
bare = value.strip("\"'").lower()
if bare in _DATE_KEYWORDS:
if field in _DATE_ONLY_FIELDS:
return f"{field}:{_date_only_range(bare, tz)}"
return f"{field}:{_datetime_range(bare, tz)}"
digits = value.replace("-", "")
if _DIGITS_RE.match(value) or _ISO_RE.match(value):
bounds = _precision_bounds(digits)
if bounds is None:
raise InvalidDateQuery(field, value)
return _field_range_from_dates(field, bounds[0], bounds[1], tz)
if regex.fullmatch(r"\d{14}", value):
try:
dt = datetime(
int(value[0:4]),
int(value[4:6]),
int(value[6:8]),
int(value[8:10]),
int(value[10:12]),
int(value[12:14]),
tzinfo=UTC,
)
except ValueError:
raise InvalidDateQuery(field, value) from None
iso = _fmt(dt)
return f"{field}:[{iso} TO {iso}]"
# Unrecognized shape -> tell the user their date is malformed rather than
# silently matching nothing or emitting invalid Tantivy syntax.
raise InvalidDateQuery(field, value)
# Open-bound sentinels for date ranges. These far-past/far-future strings allow
# open-ended ranges to be expressed as Tantivy string queries until tantivy-py
# exposes Query.range_query(..., None) on Date fields (see module TODO).
OPEN_LO = "0001-01-01T00:00:00Z"
OPEN_HI = "9999-12-31T23:59:59Z"
# Matches compact now-offset tokens like now-7d, now+1h, now-30m.
_NOW_COMPACT_RE = regex.compile(
r"^now(?P<sign>[+-])(?P<n>\d+)(?P<unit>[dhm])$",
regex.IGNORECASE,
)
# Matches "±N <unit>" Whoosh-style offsets (e.g. -7 days, -1 week, +3 hours)
# Unit is singular or plural; sign prefix is mandatory.
_NOW_SPACED_RE = regex.compile(
r"^(?P<sign>[+-])(?P<n>\d+)\s*"
r"(?P<unit>second|minute|hour|day|week|month|year)s?$",
regex.IGNORECASE,
)
def _resolve_relative_bound(token: str) -> datetime | None:
"""
Resolve a relative bound token to an exact UTC instant, or return None.
Supported forms:
- ``now`` -> current UTC instant
- ``now+/-<n>d/h/m`` -> now +/- timedelta (d=days, h=hours, m=minutes)
- ``±N <unit>`` -> now +/- delta; month/year use relativedelta
"""
stripped = token.strip()
low = stripped.lower()
now = datetime.now(UTC)
if low == "now":
return now
m = _NOW_COMPACT_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta = (
sign
* {
"d": timedelta(days=n),
"h": timedelta(hours=n),
"m": timedelta(minutes=n),
}[unit]
)
return now + delta
m = _NOW_SPACED_RE.match(stripped)
if m:
sign = 1 if m.group("sign") == "+" else -1
n = int(m.group("n"))
unit = m.group("unit").lower()
delta_map: dict[str, timedelta | relativedelta] = {
"second": timedelta(seconds=n),
"minute": timedelta(minutes=n),
"hour": timedelta(hours=n),
"day": timedelta(days=n),
"week": timedelta(weeks=n),
"month": relativedelta(months=n),
"year": relativedelta(years=n),
}
return now - delta_map[unit] if sign == -1 else now + delta_map[unit]
return None
def _bound_datetimes(
field: str,
token: str,
tz: tzinfo,
) -> tuple[datetime, datetime] | None:
"""
Return (floor_dt, ceil_dt) UTC datetimes for a single range bound token, or
None if the token is unparsable. ``now`` and relative offsets resolve to the
current instant (floor == ceil == that instant; no day-flooring).
"""
token = token.strip()
# Try relative/now forms first (before stripping hyphens which would mangle them).
rel = _resolve_relative_bound(token)
if rel is not None:
return rel, rel
# Full ISO datetime token (contains "T"): parse directly and return an exact
# instant (floor == ceil). Python 3.11+ datetime.fromisoformat accepts trailing Z.
if "T" in token:
try:
dt = datetime.fromisoformat(token)
# Ensure timezone-aware UTC result.
dt = dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
return dt, dt
except ValueError:
return None
digits = token.replace("-", "")
bounds = _precision_bounds(digits)
if bounds is None:
return None
start, end = bounds
return _utc_bounds_for_field(field, start, end, tz)
def _render(tok: Token, tz: tzinfo) -> str:
"""Render a single token back to a Tantivy query string fragment."""
if isinstance(tok, Passthrough):
return tok.raw
if isinstance(tok, Comma):
return " AND "
if isinstance(tok, FieldValueList):
field = FIELD_ALIASES.get(tok.field, tok.field)
return " AND ".join(f"{field}:{v}" for v in tok.values)
if isinstance(tok, FieldValue):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_scalar(field, tok.value, tz)
return f"{field}:{tok.value}"
if isinstance(tok, FieldRange):
field = FIELD_ALIASES.get(tok.field, tok.field)
if field in DATE_FIELDS:
return translate_range(field, tok.lo, tok.hi, tz)
return f"{field}:{tok.open}{tok.lo} TO {tok.hi}{tok.close}"
return "" # pragma: no cover
# Post-render operator normalization patterns: collapse repeated whitespace and
# strip spaced/trailing Tantivy boolean operators that would otherwise be invalid.
_MULTI_SPACE_RE = regex.compile(r" {2,}")
_TRAILING_OP_RE = regex.compile(r"\s+[-+]+\s*$")
_SPACED_OP_RE = regex.compile(r"\s+[-+]\s+")
def _normalize_operators(text: str) -> str:
"""
Collapse multiple spaces, strip trailing dangling operators, and replace
spaced operators (`` - `` / `` + ``) with a single space.
Applied only to Passthrough fragments (the rendered output is scanned for
operator artifacts outside bracketed ranges) via a post-render pass on the
full rendered string. This preserves date ranges (``[... TO ...]``) verbatim
while cleaning natural-language separators in the surrounding text.
"""
text = _MULTI_SPACE_RE.sub(" ", text)
text = _TRAILING_OP_RE.sub("", text).strip()
text = _SPACED_OP_RE.sub(" ", text).strip()
return text
def translate_query(raw: str, tz: tzinfo) -> str:
"""Translate a raw Whoosh-style query into Tantivy-compatible syntax."""
tokens = resolve_commas(scan(raw))
rendered = "".join(_render(t, tz) for t in tokens)
return _normalize_operators(rendered)
def translate_range(field: str, lo: str, hi: str, tz: tzinfo) -> str:
"""Translate a date-field ``[lo TO hi]`` range to a Tantivy ISO range string.
Handles partial-date bounds (YYYY, YYYYMM, YYYYMMDD, ISO dash variants),
open bounds (empty string -> OPEN_LO/OPEN_HI), ``now``, and reversed ranges
(swaps tokens before computing floor/ceil so the span is always correct).
"""
lo_s = lo.strip()
hi_s = hi.strip()
# Parse both bounds to (floor, ceil) pairs when present.
lo_pair: tuple[datetime, datetime] | None = None
hi_pair: tuple[datetime, datetime] | None = None
if lo_s:
lo_pair = _bound_datetimes(field, lo_s, tz)
if lo_pair is None:
raise InvalidDateQuery(field, lo_s)
if hi_s:
hi_pair = _bound_datetimes(field, hi_s, tz)
if hi_pair is None:
raise InvalidDateQuery(field, hi_s)
# Detect a reversed range: only swap when BOTH bounds are present.
if lo_pair is not None and hi_pair is not None and lo_pair[0] > hi_pair[0]:
lo_pair, hi_pair = hi_pair, lo_pair
lo_iso = _fmt(lo_pair[0]) if lo_pair is not None else OPEN_LO
hi_iso = _fmt(hi_pair[1]) if hi_pair is not None else OPEN_HI
return f"{field}:[{lo_iso} TO {hi_iso}]"
+69 -6
View File
@@ -48,6 +48,7 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied
from rest_framework.fields import SerializerMethodField
from rest_framework.filters import OrderingFilter
from rest_framework.utils import model_meta
if settings.AUDIT_LOG_ENABLED:
from auditlog.context import set_actor
@@ -121,6 +122,45 @@ class DynamicFieldsModelSerializer(serializers.ModelSerializer[Any]):
self.fields.pop(field_name)
class DocumentUpdateFieldsModelSerializer(DynamicFieldsModelSerializer):
stale_update_excluded_fields = frozenset({"filename", "archive_filename"})
def _get_update_fields(self, validated_data) -> list[str]:
model_fields = {
field.name
for field in self.Meta.model._meta.concrete_fields
if field.name not in self.stale_update_excluded_fields
}
update_fields = [
field_name for field_name in validated_data if field_name in model_fields
]
if "modified" in model_fields and "modified" not in update_fields:
update_fields.append("modified")
return update_fields
def update(self, instance, validated_data):
serializers.raise_errors_on_nested_writes("update", self, validated_data)
info = model_meta.get_field_info(instance)
m2m_fields = []
for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many:
m2m_fields.append((attr, value))
else:
setattr(instance, attr, value)
# File names are managed by post-save file handling. Saving only the
# serializer-updated fields prevents stale in-memory path values from
# overwriting a concurrent move.
instance.save(update_fields=self._get_update_fields(validated_data))
for attr, value in m2m_fields:
field = getattr(instance, attr)
field.set(value)
return instance
class MatchingModelSerializer(serializers.ModelSerializer[Any]):
document_count = serializers.IntegerField(read_only=True)
@@ -989,7 +1029,7 @@ class DocumentVersionInfoSerializer(serializers.Serializer[_DocumentVersionInfo]
class DocumentSerializer(
OwnedObjectSerializer,
NestedUpdateMixin,
DynamicFieldsModelSerializer,
DocumentUpdateFieldsModelSerializer,
):
correspondent = CorrespondentField(allow_null=True)
tags = TagsField(many=True)
@@ -1128,10 +1168,9 @@ class DocumentSerializer(
return super().validate(attrs)
def update(self, instance: Document, validated_data):
if "created_date" in validated_data and "created" not in validated_data:
instance.created = validated_data.get("created_date")
instance.save()
if "created_date" in validated_data:
if "created" not in validated_data:
validated_data["created"] = validated_data["created_date"]
logger.warning(
"created_date is deprecated, use created instead",
)
@@ -1201,11 +1240,13 @@ class DocumentSerializer(
for tag in instance.tags.all()
if tag not in inbox_tags_not_being_added
]
if settings.AUDIT_LOG_ENABLED:
with set_actor(self.user):
super().update(instance, validated_data)
else:
super().update(instance, validated_data)
# hard delete custom field instances that were soft deleted
CustomFieldInstance.deleted_objects.filter(document=instance).delete()
return instance
@@ -2632,18 +2673,25 @@ class RunTaskSerializer(serializers.Serializer[dict[str, str]]):
class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
tasks = serializers.ListField(
required=True,
required=False,
label="Tasks",
write_only=True,
child=serializers.IntegerField(),
)
all = serializers.BooleanField(
required=False,
default=False,
label="All",
write_only=True,
)
def _validate_task_id_list(self, tasks, name="tasks") -> None:
if not isinstance(tasks, list):
raise serializers.ValidationError(f"{name} must be a list")
if not all(isinstance(i, int) for i in tasks):
raise serializers.ValidationError(f"{name} must be a list of integers")
count = PaperlessTask.objects.filter(id__in=tasks).count()
queryset = self.context.get("queryset", PaperlessTask.objects.all())
count = queryset.filter(id__in=tasks).count()
if not count == len(tasks):
raise serializers.ValidationError(
f"Some tasks in {name} don't exist or were specified twice.",
@@ -2653,6 +2701,21 @@ class AcknowledgeTasksViewSerializer(serializers.Serializer[dict[str, Any]]):
self._validate_task_id_list(tasks)
return tasks
def validate(self, attrs):
acknowledge_all = attrs.get("all", False)
task_ids = attrs.get("tasks")
if acknowledge_all and task_ids is not None:
raise serializers.ValidationError(
"Set either all or tasks, not both.",
)
if not acknowledge_all and task_ids is None:
raise serializers.ValidationError(
"Either all must be true or tasks must be provided.",
)
return attrs
class ShareLinkSerializer(OwnedObjectSerializer):
class Meta:
+17 -3
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import datetime
import hashlib
import logging
import shutil
import traceback as _tb
@@ -16,6 +15,7 @@ from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.signals import task_revoked
from celery.signals import worker_process_init
from celery.signals import worker_process_shutdown
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
@@ -54,6 +54,7 @@ from documents.models import WorkflowTrigger
from documents.permissions import get_objects_for_user_owner_aware
from documents.plugins.helpers import DocumentsStatusManager
from documents.templating.utils import convert_format_str_to_template_format
from documents.utils import compute_checksum
from documents.workflows.actions import build_workflow_action_context
from documents.workflows.actions import execute_email_action
from documents.workflows.actions import execute_move_to_trash_action
@@ -410,8 +411,7 @@ def _path_matches_checksum(path: Path, checksum: str | None) -> bool:
if checksum is None or not path.is_file():
return False
with path.open("rb") as f:
return hashlib.md5(f.read()).hexdigest() == checksum
return compute_checksum(path) == checksum
def _filename_template_uses_custom_fields(doc: Document) -> bool:
@@ -1340,6 +1340,20 @@ def close_connection_pool_on_worker_init(**kwargs) -> None:
conn.close_pool()
@worker_process_shutdown.connect
def close_connection_pool_on_worker_shutdown(**kwargs) -> None: # pragma: no cover
"""
Close the DB connection pool when a Celery child process exits.
With CELERY_WORKER_MAX_TASKS_PER_CHILD=1 each child is replaced after a
single task. Without closing the pool on shutdown, its connections linger
on the server until TCP keepalive reaps them, accumulating over time.
"""
for conn in connections.all(initialized_only=True):
if conn.alias == "default" and hasattr(conn, "pool") and conn.pool:
conn.close_pool()
def add_or_update_document_in_llm_index(sender, document, **kwargs):
"""
Add or update a document in the LLM index when it is created or updated.
+24 -15
View File
@@ -1,6 +1,7 @@
import logging
import os
import re
import unicodedata
from collections.abc import Iterable
from pathlib import PurePath
@@ -36,10 +37,12 @@ class FilePathTemplate(Template):
def clean_filepath(value: str) -> str:
"""
Clean up a filepath by:
1. Removing newlines and carriage returns
2. Removing extra spaces before and after forward slashes
3. Preserving spaces in other parts of the path
1. Normalizing Unicode to NFC form to prevent byte-level mismatches
2. Removing newlines and carriage returns
3. Removing extra spaces before and after forward slashes
4. Preserving spaces in other parts of the path
"""
value = unicodedata.normalize("NFC", value)
value = value.replace("\n", "").replace("\r", "")
value = re.sub(r"\s*/\s*", "/", value)
@@ -181,17 +184,17 @@ def get_basic_metadata_context(
"""
return {
"title": pathvalidate.sanitize_filename(
document.title,
unicodedata.normalize("NFC", document.title),
replacement_text="-",
),
"correspondent": pathvalidate.sanitize_filename(
document.correspondent.name,
unicodedata.normalize("NFC", document.correspondent.name),
replacement_text="-",
)
if document.correspondent
else no_value_default,
"document_type": pathvalidate.sanitize_filename(
document.document_type.name,
unicodedata.normalize("NFC", document.document_type.name),
replacement_text="-",
)
if document.document_type
@@ -202,7 +205,10 @@ def get_basic_metadata_context(
"owner_username": document.owner.username
if document.owner
else no_value_default,
"original_name": PurePath(document.original_filename).with_suffix("").name
"original_name": unicodedata.normalize(
"NFC",
PurePath(document.original_filename).with_suffix("").name,
)
if document.original_filename
else no_value_default,
"doc_pk": f"{document.pk:07}",
@@ -269,12 +275,12 @@ def get_tags_context(tags: Iterable[Tag]) -> dict[str, str | list[str]]:
return {
"tag_list": pathvalidate.sanitize_filename(
",".join(
sorted(tag.name for tag in tags),
sorted(unicodedata.normalize("NFC", tag.name) for tag in tags),
),
replacement_text="-",
),
# Assumed to be ordered, but a template could loop through to find what they want
"tag_name_list": [x.name for x in tags],
"tag_name_list": [unicodedata.normalize("NFC", x.name) for x in tags],
}
@@ -301,7 +307,7 @@ def get_custom_fields_context(
CustomField.FieldDataType.LONG_TEXT,
}:
value = pathvalidate.sanitize_filename(
field_instance.value,
unicodedata.normalize("NFC", field_instance.value),
replacement_text="-",
)
elif (
@@ -310,10 +316,13 @@ def get_custom_fields_context(
):
options = field_instance.field.extra_data["select_options"]
value = pathvalidate.sanitize_filename(
next(
option["label"]
for option in options
if option["id"] == field_instance.value
unicodedata.normalize(
"NFC",
next(
option["label"]
for option in options
if option["id"] == field_instance.value
),
),
replacement_text="-",
)
@@ -321,7 +330,7 @@ def get_custom_fields_context(
value = field_instance.value
field_data["custom_fields"][
pathvalidate.sanitize_filename(
field_instance.field.name,
unicodedata.normalize("NFC", field_instance.field.name),
replacement_text="-",
)
] = {
@@ -0,0 +1,36 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from django.core.management import call_command
if TYPE_CHECKING:
from pytest_mock import MockerFixture
_COMPACT = "documents.management.commands.document_llmindex.llm_index_compact"
_INDEX = "documents.management.commands.document_llmindex.llmindex_index"
class TestDocumentLlmindexCommand:
def test_compact_calls_llm_index_compact(self, mocker: MockerFixture) -> None:
mock_compact = mocker.patch(_COMPACT)
call_command("document_llmindex", "compact")
mock_compact.assert_called_once_with()
def test_rebuild_calls_llmindex_index_with_rebuild_true(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "rebuild")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is True
def test_update_calls_llmindex_index_with_rebuild_false(
self,
mocker: MockerFixture,
) -> None:
mock_index = mocker.patch(_INDEX)
call_command("document_llmindex", "update")
mock_index.assert_called_once()
assert mock_index.call_args.kwargs["rebuild"] is False
+12
View File
@@ -1,11 +1,15 @@
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pytest
import tantivy
from documents.search._backend import TantivyBackend
from documents.search._backend import reset_backend
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
if TYPE_CHECKING:
from collections.abc import Generator
@@ -31,3 +35,11 @@ def backend() -> Generator[TantivyBackend, None, None]:
finally:
b.close()
reset_backend()
@pytest.fixture(scope="module")
def index() -> tantivy.Index:
"""A real Tantivy index for parse-acceptance tests (module scope for speed)."""
idx = tantivy.Index(build_schema(), path=tempfile.mkdtemp())
register_tokenizers(idx, "english")
return idx
+88 -10
View File
@@ -13,7 +13,6 @@ import time_machine
from documents.search._query import _date_only_range
from documents.search._query import _datetime_range
from documents.search._query import _rewrite_compact_date
from documents.search._query import build_permission_filter
from documents.search._query import normalize_query
from documents.search._query import parse_simple_text_highlight_query
@@ -21,6 +20,7 @@ from documents.search._query import parse_user_query
from documents.search._query import rewrite_natural_date_keywords
from documents.search._schema import build_schema
from documents.search._tokenizer import register_tokenizers
from documents.search._translate import InvalidDateQuery
if TYPE_CHECKING:
from django.contrib.auth.base_user import AbstractBaseUser
@@ -405,12 +405,14 @@ class TestWhooshQueryRewriting:
assert lo == "2023-12-01T05:00:00Z"
assert hi == "2023-12-02T05:00:00Z"
def test_8digit_invalid_date_passes_through_unchanged(self) -> None:
assert rewrite_natural_date_keywords("added:20231340", UTC) == "added:20231340"
def test_compact_14digit_invalid_date_passes_through_unchanged(self) -> None:
# Month=13 makes datetime() raise ValueError; the token must be left as-is
assert _rewrite_compact_date("20231300120000") == "20231300120000"
def test_8digit_invalid_date_raises(self) -> None:
# The translation pipeline raises InvalidDateQuery for unparsable dates
# (e.g. month=13) so the API can surface a 400 telling the user the date
# is malformed instead of silently returning zero results.
with pytest.raises(InvalidDateQuery) as exc_info:
rewrite_natural_date_keywords("added:20231340", UTC)
assert exc_info.value.field == "added"
assert exc_info.value.value == "20231340"
class TestParseUserQuery:
@@ -463,6 +465,67 @@ class TestParseUserQuery:
) -> None:
assert isinstance(parse_user_query(query_index, raw_query, UTC), tantivy.Query)
@pytest.mark.parametrize(
"raw_query",
[
# Partial date scalar (year only)
pytest.param("created:2020", id="created_year_scalar"),
# 8-digit compact date range in brackets
pytest.param(
"created:[20200101 TO 20201231]",
id="created_8digit_bracket_range",
),
# Comma-separated field + date range (Whoosh v2 multi-clause syntax)
pytest.param(
"title:x,created:[2020 TO 2021]",
id="title_comma_created_range",
),
# Field alias: type -> document_type
pytest.param("type:invoice", id="type_alias"),
# Multi-word date keyword
pytest.param("created:previous week", id="created_previous_week"),
# Full ISO datetime range
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="created_iso_range",
),
# Comma-separated ISO ranges (Whoosh v2 syntax)
pytest.param(
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]",
id="comma_iso_ranges",
),
],
)
def test_advanced_search_queries_do_not_raise(
self,
query_index: tantivy.Index,
raw_query: str,
) -> None:
"""
End-to-end: queries that the frontend sends must parse without raising.
This tests the full pipeline: translate_query -> tantivy parse_query.
Equivalent to asserting HTTP 200 (not 400) for each query form.
"""
with time_machine.travel(datetime(2026, 6, 15, 12, 0, tzinfo=UTC), tick=False):
assert isinstance(
parse_user_query(query_index, raw_query, UTC),
tantivy.Query,
)
def test_invalid_date_propagates_not_swallowed(
self,
query_index: tantivy.Index,
) -> None:
# parse_user_query falls back to the raw query on unexpected translation
# errors, but an InvalidDateQuery is intentional and must propagate so the
# view can return a 400 instead of silently parsing the raw (invalid) date.
with pytest.raises(InvalidDateQuery) as exc_info:
parse_user_query(query_index, "created:202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
class TestYearRangeRewriting:
"""Whoosh-style year-only date ranges must be rewritten to ISO 8601."""
@@ -542,11 +605,16 @@ class TestYearRangeRewriting:
assert rewrite_natural_date_keywords(original, UTC) == original
def test_8digit_in_brackets_not_matched_as_year_range(self) -> None:
# [YYYYMMDD TO YYYYMMDD] has 8-digit values - must not be caught by year rewriter
# [YYYYMMDD TO YYYYMMDD]: the translation layer converts 8-digit bounds to
# ISO day ranges. 20200101 -> 2020-01-01T00:00:00Z (lo of that day);
# 20201231 -> the ceil of Dec 31 = 2021-01-01T00:00:00Z (exclusive end).
# This is the correct and accepted behavior: old compact form becomes a
# proper Tantivy-parseable ISO range.
original = "created:[20200101 TO 20201231]"
result = rewrite_natural_date_keywords(original, UTC)
assert "20200101" in result or "2020-01-01" in result
assert "20201231" in result or "2020-12-31" in result
lo, hi = _range(result, "created")
assert lo == "2020-01-01T00:00:00Z"
assert hi == "2021-01-01T00:00:00Z"
class TestNonDateFieldsNotRewritten:
@@ -606,6 +674,16 @@ class TestNormalizeQuery:
def test_normalize_expands_comma_separated_tags(self) -> None:
assert normalize_query("tag:foo,bar") == "tag:foo AND tag:bar"
def test_normalize_comma_between_range_expressions(self) -> None:
# Comma-separated field range expressions (Whoosh v2 syntax) must be
# converted to AND so Tantivy does not receive an invalid comma.
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert normalize_query(q) == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_normalize_expands_three_values(self) -> None:
assert normalize_query("tag:foo,bar,baz") == "tag:foo AND tag:bar AND tag:baz"
@@ -0,0 +1,742 @@
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING
from zoneinfo import ZoneInfo
import pytest
import time_machine
from documents.search._dates import _precision_bounds
if TYPE_CHECKING:
import tantivy
from documents.search._query import _FIELD_BOOSTS
from documents.search._query import DEFAULT_SEARCH_FIELDS
from documents.search._translate import OPEN_HI
from documents.search._translate import OPEN_LO
from documents.search._translate import Comma
from documents.search._translate import FieldRange
from documents.search._translate import FieldValue
from documents.search._translate import FieldValueList
from documents.search._translate import InvalidDateQuery
from documents.search._translate import Passthrough
from documents.search._translate import resolve_commas
from documents.search._translate import scan
from documents.search._translate import translate_query
from documents.search._translate import translate_range
from documents.search._translate import translate_scalar
@pytest.mark.search
class TestPrecisionBounds:
@pytest.mark.parametrize(
("digits", "expected"),
[
("2020", ((2020, 1, 1), (2021, 1, 1))),
("202003", ((2020, 3, 1), (2020, 4, 1))),
("202012", ((2020, 12, 1), (2021, 1, 1))),
("20200115", ((2020, 1, 15), (2020, 1, 16))),
("20201231", ((2020, 12, 31), (2021, 1, 1))),
],
)
def test_valid(self, digits, expected):
lo, hi = _precision_bounds(digits)
assert (lo.year, lo.month, lo.day) == expected[0]
assert (hi.year, hi.month, hi.day) == expected[1]
@pytest.mark.parametrize("digits", ["202023", "20200230", "20201301", "20", "abcd"])
def test_invalid_returns_none(self, digits):
assert _precision_bounds(digits) is None
@pytest.mark.search
class TestScan:
def test_plain_words_are_passthrough(self):
assert scan("bank statement") == [Passthrough("bank statement")]
def test_field_value(self):
assert scan("created:2020") == [FieldValue("created", "2020")]
def test_field_value_in_boolean(self):
toks = scan("created:2020 OR foo")
assert toks == [
FieldValue("created", "2020"),
Passthrough(" OR foo"),
]
def test_field_value_in_parens(self):
toks = scan("(created:2020 OR foo)")
assert toks == [
Passthrough("("),
FieldValue("created", "2020"),
Passthrough(" OR foo)"),
]
def test_quoted_value(self):
assert scan('correspondent:"A B"') == [FieldValue("correspondent", '"A B"')]
def test_field_range(self):
assert scan("created:[2020 TO 2021]") == [
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.parametrize(
("query", "expected"),
[
pytest.param(
"created:[2020 to]",
FieldRange("created", "[", "2020", "", "]"),
id="open_upper",
),
pytest.param(
"created:[to 2020]",
FieldRange("created", "[", "", "2020", "]"),
id="open_lower",
),
],
)
def test_open_range(self, query, expected):
assert scan(query) == [expected]
def test_comma_inside_range_not_split(self):
# No depth-0 comma here; the whole thing is one range token.
toks = scan("created:[2020 TO 2021]")
assert len(toks) == 1
# --- Edge-case / regression tests (scan must never raise) ---
def test_url_is_passthrough(self):
# "http" is not a known field; the whole URL must pass through verbatim.
assert scan("http://example.com") == [Passthrough("http://example.com")]
def test_unterminated_quote_is_passthrough(self):
# title is a known field but the quoted value has no closing quote;
# _consume_value returns None so the whole string falls into passthrough.
assert scan('title:"abc') == [Passthrough('title:"abc')]
def test_unterminated_bracket_is_passthrough(self):
# created is a known field but the range bracket is never closed;
# _consume_range returns None so the whole string falls into passthrough.
assert scan("created:[2020") == [Passthrough("created:[2020")]
def test_empty_value_at_end_is_passthrough(self):
# created is a known field but there is no value after the colon
# (_consume_value returns None for start >= n), so passthrough.
assert scan("created:") == [Passthrough("created:")]
def test_value_containing_colon(self):
# The bare-word value reader stops at whitespace/paren, not at colon,
# so "2020:30" is consumed as a single value token.
assert scan("created:2020:30") == [FieldValue("created", "2020:30")]
def test_comma_followed_by_unconsumable_value_stops(self):
# A comma followed by whitespace is neither a value-list continuation nor a
# clause separator: the value stops and the comma stays as passthrough.
assert scan("tag:foo, bar") == [
FieldValue("tag", "foo"),
Passthrough(", bar"),
]
def test_bracket_without_to_is_open_upper_bound(self):
# A bracketed value with no TO falls back to (value, "") -> open upper bound.
assert scan("created:[2020]") == [
FieldRange("created", "[", "2020", "", "]"),
]
def test_known_field_name_midword_is_passthrough(self):
# A known field name embedded mid-word is not a field token (the
# word-boundary guard); the whole run stays passthrough.
assert scan("xtag:foo") == [Passthrough("xtag:foo")]
@pytest.mark.search
class TestCommaResolution:
def test_value_list_multi_value_field(self):
toks = resolve_commas(scan("tag:foo,bar"))
assert toks == [FieldValueList("tag", ("foo", "bar"))]
def test_value_list_three(self):
toks = resolve_commas(scan("tag_id:1,2,3"))
assert toks == [FieldValueList("tag_id", ("1", "2", "3"))]
def test_text_field_comma_is_literal(self):
# correspondent is not multi-value: comma stays inside the value.
toks = resolve_commas(scan("correspondent:foo,bar"))
assert toks == [FieldValue("correspondent", "foo,bar")]
def test_clause_separator_before_known_field(self):
toks = resolve_commas(scan("tag:foo,type:bar"))
assert toks == [FieldValue("tag", "foo"), Comma(), FieldValue("type", "bar")]
def test_clause_separator_after_range(self):
toks = resolve_commas(scan("created:[2020 TO 2021],added:[2022 TO 2023]"))
assert toks == [
FieldRange("created", "[", "2020", "2021", "]"),
Comma(),
FieldRange("added", "[", "2022", "2023", "]"),
]
def test_clause_separator_after_quote(self):
toks = resolve_commas(scan('correspondent:"A B",created:[2020 TO 2021]'))
assert toks == [
FieldValue("correspondent", '"A B"'),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
def test_url_comma_is_literal_passthrough(self):
toks = resolve_commas(scan("http://example.com/a,b"))
assert toks == [Passthrough("http://example.com/a,b")]
def test_non_multi_value_comma_is_literal(self):
# title is not in MULTI_VALUE_FIELDS: comma stays inside the value.
toks = resolve_commas(scan("title:10,20"))
assert toks == [FieldValue("title", "10,20")]
def test_clause_separator_before_known_date_field(self):
# The comma between a bare value and a known date field acts as a
# clause separator; both sides survive as distinct tokens.
toks = resolve_commas(scan("correspondent:foo,created:[2020 TO 2021]"))
assert toks == [
FieldValue("correspondent", "foo"),
Comma(),
FieldRange("created", "[", "2020", "2021", "]"),
]
@pytest.mark.search
class TestTranslateScalar:
@pytest.mark.parametrize(
("field", "value", "expected"),
[
(
"created",
"2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"created",
"202003",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
(
"created",
"20200115",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-01-15",
"created:[2020-01-15T00:00:00Z TO 2020-01-16T00:00:00Z]",
),
(
"created",
"2020-03",
"created:[2020-03-01T00:00:00Z TO 2020-04-01T00:00:00Z]",
),
],
)
def test_partial_and_iso_dates(self, field: str, value: str, expected: str) -> None:
assert translate_scalar(field, value, UTC) == expected
def test_invalid_date_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "202023", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_keyword_delegates(self) -> None:
# keyword path produces a range; just assert it is a created range
out = translate_scalar("created", "today", UTC)
assert out.startswith("created:[") and out.endswith("]")
def test_14digit_compact_datetime(self) -> None:
out = translate_scalar("created", "20240115120000", UTC)
assert "20240115120000" not in out
assert out.startswith("created:")
assert out == "created:[2024-01-15T12:00:00Z TO 2024-01-15T12:00:00Z]"
def test_14digit_invalid_month_raises(self) -> None:
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "20231300120000", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "20231300120000"
def test_unrecognized_value_raises(self) -> None:
# A value that is not a keyword, digits, ISO date, or compact timestamp
# raises rather than producing invalid Tantivy syntax or silently matching
# nothing.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_scalar("created", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateRange:
@pytest.mark.parametrize(
("lo", "hi", "expected"),
[
("2005", "2009", "created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"),
(
"202001",
"202006",
"created:[2020-01-01T00:00:00Z TO 2020-07-01T00:00:00Z]",
),
(
"20200101",
"20201231",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
(
"2020-01-01",
"2020-12-31",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
],
)
def test_absolute_ranges(self, lo, hi, expected):
assert translate_range("created", lo, hi, UTC) == expected
def test_reversed_swaps(self):
assert translate_range("created", "2009", "2005", UTC) == (
"created:[2005-01-01T00:00:00Z TO 2010-01-01T00:00:00Z]"
)
def test_open_upper(self):
out = translate_range("created", "2020", "", UTC)
assert out == f"created:[2020-01-01T00:00:00Z TO {OPEN_HI}]"
def test_open_lower(self):
out = translate_range("created", "", "2020", UTC)
assert out == f"created:[{OPEN_LO} TO 2021-01-01T00:00:00Z]"
def test_invalid_bound_raises(self):
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "202023", "2025", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "202023"
def test_invalid_high_bound_raises(self):
# Low bound parses, high bound does not -> raise on the high bound.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range("created", "2020", "garbage", UTC)
assert exc_info.value.field == "created"
assert exc_info.value.value == "garbage"
@pytest.mark.search
class TestTranslateQuery:
@pytest.mark.parametrize(
("raw", "expected"),
[
(
"created:2020",
"created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]",
),
("tag:foo,bar", "tag:foo AND tag:bar"),
# 'type' is a user-facing alias rewritten to 'document_type' (the real schema field)
("tag:foo,type:bar", "tag:foo AND document_type:bar"),
(
"created:[2020 TO 2021],added:[2022 TO 2023]",
"created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
" AND "
"added:[2022-01-01T00:00:00Z TO 2024-01-01T00:00:00Z]",
),
# correspondent is not multi-value: comma stays literal inside the value
("correspondent:foo,bar", "correspondent:foo,bar"),
],
)
def test_golden(self, raw: str, expected: str) -> None:
assert translate_query(raw, UTC) == expected
@pytest.mark.parametrize(
"raw",
[
"created:2020",
"created:202003",
"created:[20200101 TO 20201231]",
"created:[2020-01-01 TO 2020-12-31]",
"created:[2020 to]",
"created:[to 2020]",
"title:x,created:[2020 TO 2021]",
"created:2020 OR foo",
"(created:2020 OR invoice)",
"tag:foo,type:bar",
"bank statement",
],
)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
# Must not raise:
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestFieldAliasing:
"""Whoosh->Tantivy field-name aliasing (type/path -> document_type/storage_path)."""
def test_type_alias(self) -> None:
assert translate_query("type:invoice", UTC) == "document_type:invoice"
def test_path_alias(self) -> None:
assert translate_query("path:/foo/bar", UTC) == "storage_path:/foo/bar"
def test_type_id_alias(self) -> None:
assert translate_query("type_id:5", UTC) == "document_type_id:5"
def test_path_id_alias(self) -> None:
assert translate_query("path_id:7", UTC) == "storage_path_id:7"
def test_clause_separator_plus_alias(self) -> None:
# Comma between known fields acts as AND separator; alias still applied.
assert (
translate_query("tag:foo,type:bar", UTC) == "tag:foo AND document_type:bar"
)
def test_type_range_alias(self) -> None:
# type is not a date field; range passes through verbatim with alias applied.
assert (
translate_query("type:[2020 TO 2021]", UTC)
== "document_type:[2020 TO 2021]"
)
def test_parse_acceptance_type(self, index: tantivy.Index) -> None:
# Translated output must be accepted by the real Tantivy parser.
translated = translate_query("type:invoice", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_path(self, index: tantivy.Index) -> None:
translated = translate_query("path:foo", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
# Freeze time so relative-date tests are deterministic.
_FROZEN_NOW = datetime(2026, 3, 28, 12, 0, 0, tzinfo=UTC)
@pytest.mark.search
class TestRelativeRanges:
"""Relative date-range tokens resolved against a frozen clock."""
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_7_days_to_now(self) -> None:
assert translate_query("added:[-7 days to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_week_to_now(self) -> None:
assert translate_query("added:[-1 week to now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_month_to_now(self) -> None:
assert translate_query("created:[-1 month to now]", UTC) == (
"created:[2026-02-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_1_year_to_now(self) -> None:
assert translate_query("modified:[-1 year to now]", UTC) == (
"modified:[2025-03-28T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_minus_3_hours_to_now(self) -> None:
assert translate_query("added:[-3 hours to now]", UTC) == (
"added:[2026-03-28T09:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_uppercase_units(self) -> None:
assert translate_query("added:[-1 WEEK TO NOW]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_now_minus_7d_compact(self) -> None:
assert translate_query("added:[now-7d TO now]", UTC) == (
"added:[2026-03-21T12:00:00Z TO 2026-03-28T12:00:00Z]"
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_reversed_range_swapped(self) -> None:
# now+1h TO now-1h is reversed; translate_range swaps -> lo=now-1h, hi=now+1h
assert translate_query("added:[now+1h TO now-1h]", UTC) == (
"added:[2026-03-28T11:00:00Z TO 2026-03-28T13:00:00Z]"
)
@pytest.mark.parametrize(
"raw",
[
"added:[-7 days to now]",
"added:[-1 week to now]",
"created:[-1 month to now]",
"modified:[-1 year to now]",
"added:[-3 hours to now]",
"added:[now-7d TO now]",
"added:[now+1h TO now-1h]",
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_parse_acceptance(self, index: tantivy.Index, raw: str) -> None:
translated = translate_query(raw, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestOperatorNormalization:
"""Post-render operator normalization in translate_query."""
def test_spaced_dash_removed(self) -> None:
assert (
translate_query("H52.1 - Kurzsichtigkeit", UTC) == "H52.1 Kurzsichtigkeit"
)
def test_spaced_dash_simple(self) -> None:
assert translate_query("bar - baz", UTC) == "bar baz"
def test_trailing_operator_stripped(self) -> None:
assert translate_query("foo -", UTC) == "foo"
def test_date_range_preserved(self) -> None:
out = translate_query("created:[2020 TO 2021]", UTC)
# Must not corrupt the ISO range
assert out == "created:[2020-01-01T00:00:00Z TO 2022-01-01T00:00:00Z]"
def test_date_scalar_with_or(self) -> None:
out = translate_query("created:2020 OR foo", UTC)
# The created scalar becomes a range; " OR foo" passes through verbatim.
assert out.startswith("created:[")
assert "OR foo" in out
def test_parse_acceptance_spaced_dash(self, index: tantivy.Index) -> None:
translated = translate_query("H52.1 - Kurzsichtigkeit", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_trailing_op(self, index: tantivy.Index) -> None:
translated = translate_query("foo -", UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
@pytest.mark.search
class TestMultiWordDateKeywords:
"""scan() must consume multi-word date keywords as a single value."""
def test_scan_previous_week_as_single_token(self) -> None:
# "created:previous week" must produce one FieldValue with value "previous week",
# not FieldValue("created","previous") + Passthrough(" week").
toks = scan("created:previous week")
assert toks == [FieldValue("created", "previous week")]
def test_scan_this_month_as_single_token(self) -> None:
toks = scan("added:this month")
assert toks == [FieldValue("added", "this month")]
def test_scan_previous_month_as_single_token(self) -> None:
toks = scan("created:previous month")
assert toks == [FieldValue("created", "previous month")]
def test_scan_this_year_as_single_token(self) -> None:
toks = scan("added:this year")
assert toks == [FieldValue("added", "this year")]
def test_scan_previous_year_as_single_token(self) -> None:
toks = scan("created:previous year")
assert toks == [FieldValue("created", "previous year")]
def test_scan_previous_quarter_as_single_token(self) -> None:
toks = scan("created:previous quarter")
assert toks == [FieldValue("created", "previous quarter")]
def test_quoted_multi_word_keyword_still_works(self) -> None:
# The quoted form must continue to work as before.
toks = scan('created:"previous week"')
assert toks == [FieldValue("created", '"previous week"')]
def test_non_date_field_not_affected(self) -> None:
# "previous" stops at the space for non-date fields; " week" passes through.
toks = scan("correspondent:previous week")
assert toks == [
FieldValue("correspondent", "previous"),
Passthrough(" week"),
]
@pytest.mark.search
class TestKeywordDateResolution:
"""Relative date keywords resolve to exact ISO ranges against a frozen clock.
Frozen at 2026-03-28 12:00 UTC (a Saturday in Q1) so the week, month,
quarter and year rollovers are all exercised by a single anchor.
"""
# created is a DateField: bounds are UTC midnight, no timezone offset.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"created:[2026-03-28T00:00:00Z TO 2026-03-29T00:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"created:[2026-03-27T00:00:00Z TO 2026-03-28T00:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"created:[2026-03-16T00:00:00Z TO 2026-03-23T00:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"created:[2026-03-01T00:00:00Z TO 2026-04-01T00:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"created:[2026-02-01T00:00:00Z TO 2026-03-01T00:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"created:[2026-01-01T00:00:00Z TO 2027-01-01T00:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"created:[2025-01-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"created:[2025-10-01T00:00:00Z TO 2026-01-01T00:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_date_only_field_keyword_ranges(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"created:{keyword}", UTC) == expected
# added is a DateTimeField: local-tz midnight converted to UTC. Tokyo
# (+09:00, no DST) shifts each midnight boundary back to 15:00Z the day
# before, so this also exercises the local-midnight offset path.
@pytest.mark.parametrize(
("keyword", "expected"),
[
pytest.param(
"today",
"added:[2026-03-27T15:00:00Z TO 2026-03-28T15:00:00Z]",
id="today",
),
pytest.param(
"yesterday",
"added:[2026-03-26T15:00:00Z TO 2026-03-27T15:00:00Z]",
id="yesterday",
),
pytest.param(
"previous week",
"added:[2026-03-15T15:00:00Z TO 2026-03-22T15:00:00Z]",
id="previous-week",
),
pytest.param(
"this month",
"added:[2026-02-28T15:00:00Z TO 2026-03-31T15:00:00Z]",
id="this-month",
),
pytest.param(
"previous month",
"added:[2026-01-31T15:00:00Z TO 2026-02-28T15:00:00Z]",
id="previous-month",
),
pytest.param(
"this year",
"added:[2025-12-31T15:00:00Z TO 2026-12-31T15:00:00Z]",
id="this-year",
),
pytest.param(
"previous year",
"added:[2024-12-31T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-year",
),
pytest.param(
"previous quarter",
"added:[2025-09-30T15:00:00Z TO 2025-12-31T15:00:00Z]",
id="previous-quarter",
),
],
)
@time_machine.travel(_FROZEN_NOW, tick=False)
def test_datetime_field_keyword_ranges_local_tz(
self,
keyword: str,
expected: str,
) -> None:
assert translate_query(f"added:{keyword}", ZoneInfo("Asia/Tokyo")) == expected
@pytest.mark.search
class TestISODatetimeBounds:
"""Full ISO datetime tokens in range bounds must be parsed directly."""
def test_translate_range_iso_bounds_passthrough(self) -> None:
# Already-ISO datetime bounds must pass through as-is (exact instant).
result = translate_range(
"created",
"2020-01-01T00:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert result == "created:[2020-01-01T00:00:00Z TO 2021-01-01T00:00:00Z]"
def test_translate_query_iso_range_preserved(self) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
assert translate_query(q, UTC) == q
def test_translate_query_comma_separated_iso_ranges(self) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
result = translate_query(q, UTC)
assert result == (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
" AND "
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
def test_invalid_iso_datetime_raises(self) -> None:
# A token with "T" that is not valid ISO datetime -> raise.
with pytest.raises(InvalidDateQuery) as exc_info:
translate_range(
"created",
"2020-01-01T99:00:00Z",
"2021-01-01T00:00:00Z",
UTC,
)
assert exc_info.value.field == "created"
assert exc_info.value.value == "2020-01-01T99:00:00Z"
def test_parse_acceptance_iso_bounds(self, index: tantivy.Index) -> None:
q = "created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
def test_parse_acceptance_comma_iso_ranges(self, index: tantivy.Index) -> None:
q = (
"created:[2026-01-01T00:00:00Z TO 2026-06-01T00:00:00Z],"
"added:[2026-05-01T00:00:00Z TO 2026-06-01T00:00:00Z]"
)
translated = translate_query(q, UTC)
index.parse_query(translated, DEFAULT_SEARCH_FIELDS, field_boosts=_FIELD_BOOSTS)
+5 -4
View File
@@ -82,6 +82,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"llm_api_key": None,
"llm_endpoint": None,
"llm_output_language": None,
"llm_request_timeout": None,
},
)
@@ -844,7 +845,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = False
self.client.patch(
@@ -869,7 +870,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -890,7 +891,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -928,7 +929,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
with (
patch("documents.tasks.llmindex_index.apply_async") as mock_update,
patch("paperless.views.vector_store_file_exists") as mock_exists,
patch("paperless.views.llm_index_exists") as mock_exists,
):
mock_exists.return_value = True
self.client.patch(
@@ -0,0 +1,95 @@
import unicodedata
from typing import TYPE_CHECKING
from unittest import mock
import celery.result
import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
if TYPE_CHECKING:
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
@pytest.fixture()
def consume_file_mock():
with mock.patch("documents.tasks.consume_file.apply_async") as m:
m.return_value = celery.result.AsyncResult(id="test-task-id")
yield m
@pytest.fixture()
def directories(tmp_path, settings, _media_settings):
scratch = tmp_path / "scratch"
scratch.mkdir()
settings.SCRATCH_DIR = scratch
return scratch
@pytest.mark.django_db
class TestPostDocumentNFCNormalization:
def test_nfd_filename_normalized_to_nfc(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with NFD filename must have its name stored as NFC."""
nfd = unicodedata.normalize("NFD", "Rechnung März.pdf")
nfc = unicodedata.normalize("NFC", "Rechnung März.pdf")
# Verify our test strings actually differ at the byte level
assert nfd != nfc
uploaded = SimpleUploadedFile(
nfd,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
input_doc: ConsumableDocument = task_kwargs["input_doc"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
# The temp file on disk must have an NFC name
assert input_doc.original_file.name == nfc, (
f"Expected NFC filename {nfc!r}, got {input_doc.original_file.name!r}"
)
# The override filename stored for later use must also be NFC
assert overrides.filename == nfc, (
f"Expected NFC override filename {nfc!r}, got {overrides.filename!r}"
)
assert unicodedata.is_normalized("NFC", overrides.filename)
def test_already_nfc_filename_unchanged(
self,
admin_client,
consume_file_mock: mock.MagicMock,
directories,
):
"""Uploaded file with already-NFC filename must pass through unchanged."""
nfc = unicodedata.normalize("NFC", "Invoice_2024.pdf")
uploaded = SimpleUploadedFile(
nfc,
b"%PDF-1.4 test",
content_type="application/pdf",
)
response = admin_client.post(
"/api/documents/post_document/",
{"document": uploaded},
)
assert response.status_code == 200
task_kwargs = consume_file_mock.call_args.kwargs["kwargs"]
overrides: DocumentMetadataOverrides = task_kwargs["overrides"]
assert overrides.filename == nfc
assert unicodedata.is_normalized("NFC", overrides.filename)
+6 -3
View File
@@ -725,9 +725,11 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
GIVEN:
- One document added right now
WHEN:
- Query with invalid added date
- Query with an invalid added date
THEN:
- 400 Bad Request returned (Tantivy rejects invalid date field syntax)
- 400 Bad Request with a message naming the malformed date, so the
user knows their date is invalid rather than silently getting zero
results
"""
d1 = Document.objects.create(
title="invoice",
@@ -740,8 +742,9 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?query=added:invalid-date")
# Tantivy rejects unparsable field queries with a 400
# An unparsable date is reported as a malformed query, not silently empty.
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("invalid-date", str(response.data["query"]))
@override_settings(
TIME_ZONE="UTC",
+71
View File
@@ -216,6 +216,77 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_none(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns no worker responses
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
mock_ping.return_value = None
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(
response.data["tasks"]["celery_error"],
"No celery workers responded to ping. This may be temporary.",
)
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_unexpected_responses(self, mock_ping) -> None:
"""
GIVEN:
- Celery ping returns an unexpected worker response
WHEN:
- The user requests the system status
THEN:
- The response contains a warning celery status
"""
self.client.force_login(self.user)
for ping_response in (
{"hostname": {"ok": "not-pong"}},
{"hostname": {}},
{"hostname": "pong"},
):
with self.subTest(ping_response=ping_response):
mock_ping.return_value = ping_response
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "WARNING")
self.assertEqual(response.data["tasks"]["celery_url"], "hostname")
self.assertEqual(
response.data["tasks"]["celery_error"],
"Celery worker responded unexpectedly.",
)
@mock.patch("documents.views.sleep")
@mock.patch("celery.app.control.Inspect.ping")
def test_system_status_celery_ping_retry_success(
self,
mock_ping,
mock_sleep,
) -> None:
"""
GIVEN:
- Celery ping fails once but succeeds on retry
WHEN:
- The user requests the system status
THEN:
- The response contains an OK celery status
"""
mock_ping.side_effect = [None, {"hostname": {"ok": "pong"}}]
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["celery_status"], "OK")
self.assertIsNone(response.data["tasks"]["celery_error"])
self.assertEqual(mock_ping.call_count, 2)
mock_sleep.assert_called_once_with(0.25)
@mock.patch("documents.search.get_backend")
def test_system_status_index_ok(self, mock_get_backend) -> None:
"""
+181
View File
@@ -18,6 +18,7 @@ from guardian.shortcuts import assign_perm
from rest_framework import status
from rest_framework.test import APIClient
from documents.filters import PaperlessTaskFilterSet
from documents.models import PaperlessTask
from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
@@ -169,6 +170,165 @@ class TestGetTasksV10:
PaperlessTask.Status.STARTED,
}
def test_filter_by_task_name(self, admin_client: APIClient) -> None:
"""?name= searches task filenames, task types, and trigger sources."""
filename_task = PaperlessTaskFactory(input_data={"filename": "invoice-123.pdf"})
type_task = PaperlessTaskFactory(task_type=PaperlessTask.TaskType.SANITY_CHECK)
source_task = PaperlessTaskFactory(
trigger_source=PaperlessTask.TriggerSource.EMAIL_CONSUME,
)
PaperlessTaskFactory(input_data={"filename": "unrelated.pdf"})
response = admin_client.get(ENDPOINT, {"name": "invoice"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == filename_task.task_id
response = admin_client.get(ENDPOINT, {"name": "sanity"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == type_task.task_id
response = admin_client.get(ENDPOINT, {"name": "email"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == source_task.task_id
def test_filter_by_task_result(self, admin_client: APIClient) -> None:
"""?result= searches common structured task result messages."""
reason_task = PaperlessTaskFactory(result_data={"reason": "Manual review"})
error_task = PaperlessTaskFactory(
result_data={"error_message": "Duplicate detected"},
)
document_task = PaperlessTaskFactory(result_data={"document_id": 321})
duplicate_task = PaperlessTaskFactory(result_data={"duplicate_of": 123})
PaperlessTaskFactory(result_data={"reason": "unrelated"})
response = admin_client.get(ENDPOINT, {"result": "manual"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == reason_task.task_id
response = admin_client.get(ENDPOINT, {"result": "duplicate"})
assert response.status_code == status.HTTP_200_OK
returned_ids = {task["task_id"] for task in response.data["results"]}
assert returned_ids == {error_task.task_id, duplicate_task.task_id}
response = admin_client.get(ENDPOINT, {"result": "321"})
assert response.status_code == status.HTTP_200_OK
assert response.data["count"] == 1
assert response.data["results"][0]["task_id"] == document_task.task_id
def test_empty_task_name_and_result_filters(self) -> None:
"""Empty name/result values leave the queryset unchanged."""
PaperlessTaskFactory.create_batch(2)
queryset = PaperlessTask.objects.all()
filterset = PaperlessTaskFilterSet()
assert filterset.filter_name(queryset, "name", "").count() == 2
assert filterset.filter_result(queryset, "result", "").count() == 2
def test_status_counts_respects_filters(self, admin_client: APIClient) -> None:
"""status_counts/ returns section counts for the filtered task queryset."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.REVOKED,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.STARTED,
input_data={"filename": "invoice-d.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-e.pdf"},
)
PaperlessTaskFactory(
acknowledged=True,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-acknowledged.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{"acknowledged": "false", "name": "invoice"},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 5,
"needs_attention": 2,
"in_progress": 2,
"completed": 1,
}
def test_status_counts_ignores_section_filters(
self,
admin_client: APIClient,
) -> None:
"""status_counts/ ignores status-like filters for the sections it counts."""
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "invoice-a.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.PENDING,
input_data={"filename": "invoice-b.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.SUCCESS,
input_data={"filename": "invoice-c.pdf"},
)
PaperlessTaskFactory(
acknowledged=False,
status=PaperlessTask.Status.FAILURE,
input_data={"filename": "unrelated.pdf"},
)
response = admin_client.get(
f"{ENDPOINT}status_counts/",
{
"acknowledged": "false",
"name": "invoice",
"status": PaperlessTask.Status.FAILURE,
"is_complete": "false",
},
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
"all": 3,
"needs_attention": 1,
"in_progress": 1,
"completed": 1,
}
def test_default_ordering_is_newest_first(self, admin_client: APIClient) -> None:
"""Tasks are returned in descending date_created order (newest first)."""
base = timezone.now()
@@ -522,6 +682,27 @@ class TestAcknowledge:
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
def test_acknowledge_all_returns_count(self, admin_client: APIClient) -> None:
"""POST acknowledge/ with all=true acknowledges all unacknowledged tasks."""
unacknowledged_task1 = PaperlessTaskFactory(acknowledged=False)
unacknowledged_task2 = PaperlessTaskFactory(acknowledged=False)
acknowledged_task = PaperlessTaskFactory(acknowledged=True)
response = admin_client.post(
ENDPOINT + "acknowledge/",
{"all": True},
format="json",
)
assert response.status_code == status.HTTP_200_OK
assert response.data == {"result": 2}
unacknowledged_task1.refresh_from_db()
unacknowledged_task2.refresh_from_db()
acknowledged_task.refresh_from_db()
assert unacknowledged_task1.acknowledged
assert unacknowledged_task2.acknowledged
assert acknowledged_task.acknowledged
def test_acknowledged_tasks_excluded_from_unacked_filter(
self,
admin_client: APIClient,
+117 -9
View File
@@ -3,6 +3,7 @@ from datetime import date
from pathlib import Path
from unittest import mock
import pikepdf
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.test import TestCase
@@ -615,6 +616,18 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.img_doc.archive_filename = img_doc_archive
self.img_doc.save()
@staticmethod
def mock_password_required_pdf(
mock_open: mock.Mock,
fake_pdf: mock.Mock,
) -> None:
password_context = mock.MagicMock()
password_context.__enter__.return_value = fake_pdf
mock_open.side_effect = [
pikepdf.PasswordError("password required"),
password_context,
]
@mock.patch("documents.tasks.consume_file.s")
def test_merge(self, mock_consume_file) -> None:
"""
@@ -1466,6 +1479,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock(), mock.Mock()]
fake_pdf.is_encrypted = True
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
@@ -1480,7 +1494,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
fake_pdf.remove_unreferenced_resources.assert_called_once()
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1494,6 +1514,33 @@ class TestPDFActions(DirectoriesMixin, TestCase):
self.assertEqual(task_kwargs["input_doc"].root_document_id, doc.id)
self.assertIsNotNone(task_kwargs["overrides"])
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_update_document_skips_unencrypted_pdf(
self,
mock_open,
mock_mkdtemp,
mock_consume_delay,
) -> None:
doc = self.doc1
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_delay.assert_not_called()
@mock.patch("documents.bulk_edit.update_document_content_maybe_archive_file.delay")
@mock.patch("documents.tasks.consume_file.apply_async")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@@ -1513,12 +1560,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_mkdtemp.return_value = str(temp_dir)
fake_pdf = mock.MagicMock()
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path):
Path(target_path).write_bytes(b"new pdf content")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
@@ -1528,7 +1575,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(source_file, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(source_file),
mock.call(source_file, password="secret"),
],
)
mock_update_document.assert_not_called()
mock_consume_delay.assert_called_once()
@@ -1547,7 +1600,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
root_document=self.doc1,
)
fake_pdf = mock.MagicMock()
mock_open.return_value.__enter__.return_value = fake_pdf
self.mock_password_required_pdf(mock_open, fake_pdf)
result = bulk_edit.remove_password(
[self.doc1.id],
@@ -1557,7 +1610,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(self.doc1.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(self.doc1.source_path),
mock.call(self.doc1.source_path, password="secret"),
],
)
mock_consume_delay.assert_called_once()
@mock.patch("documents.bulk_edit.chord")
@@ -1580,12 +1639,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_group.return_value.delay.return_value = None
user = User.objects.create(username="owner")
@@ -1600,7 +1659,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_consume_file.assert_called_once()
call_kwargs = mock_consume_file.call_args.kwargs
consumable_document = call_kwargs["input_doc"]
@@ -1618,6 +1683,43 @@ class TestPDFActions(DirectoriesMixin, TestCase):
mock_group.return_value.delay.assert_called_once()
mock_chord.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@mock.patch("documents.tasks.consume_file.s")
@mock.patch("documents.bulk_edit.tempfile.mkdtemp")
@mock.patch("pikepdf.open")
def test_remove_password_skips_unencrypted_pdf_without_queueing(
self,
mock_open: mock.Mock,
mock_mkdtemp: mock.Mock,
mock_consume_file: mock.Mock,
mock_group: mock.Mock,
mock_chord: mock.Mock,
mock_delete: mock.Mock,
) -> None:
doc = self.doc2
fake_pdf = mock.MagicMock()
fake_pdf.is_encrypted = False
mock_open.return_value.__enter__.return_value = fake_pdf
result = bulk_edit.remove_password(
[doc.id],
password="secret",
update_document=False,
delete_original=True,
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path)
fake_pdf.remove_unreferenced_resources.assert_not_called()
fake_pdf.save.assert_not_called()
mock_mkdtemp.assert_not_called()
mock_consume_file.assert_not_called()
mock_group.assert_not_called()
mock_chord.assert_not_called()
mock_delete.si.assert_not_called()
@mock.patch("documents.bulk_edit.delete")
@mock.patch("documents.bulk_edit.chord")
@mock.patch("documents.bulk_edit.group")
@@ -1640,12 +1742,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
fake_pdf = mock.MagicMock()
fake_pdf.pages = [mock.Mock(), mock.Mock()]
self.mock_password_required_pdf(mock_open, fake_pdf)
def save_side_effect(target_path: Path) -> None:
target_path.write_bytes(b"password removed")
fake_pdf.save.side_effect = save_side_effect
mock_open.return_value.__enter__.return_value = fake_pdf
mock_chord.return_value.delay.return_value = None
result = bulk_edit.remove_password(
@@ -1657,7 +1759,13 @@ class TestPDFActions(DirectoriesMixin, TestCase):
)
self.assertEqual(result, "OK")
mock_open.assert_called_once_with(doc.source_path, password="secret")
self.assertEqual(
mock_open.call_args_list,
[
mock.call(doc.source_path),
mock.call(doc.source_path, password="secret"),
],
)
mock_consume_file.assert_called_once()
mock_group.assert_not_called()
mock_chord.assert_called_once()
+43 -2
View File
@@ -24,6 +24,7 @@ from documents.models import CustomFieldInstance
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.serialisers import DocumentSerializer
from documents.tasks import empty_trash
from documents.tests.factories import DocumentFactory
from documents.tests.utils import DirectoriesMixin
@@ -221,8 +222,8 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
doc = Document.objects.create(
title="document",
mime_type="application/pdf",
checksum=hashlib.md5(original_bytes).hexdigest(),
archive_checksum=hashlib.md5(archive_bytes).hexdigest(),
checksum=hashlib.sha256(original_bytes).hexdigest(),
archive_checksum=hashlib.sha256(archive_bytes).hexdigest(),
filename="old/document.pdf",
archive_filename="old/document.pdf",
storage_path=old_storage_path,
@@ -251,6 +252,46 @@ class TestFileHandling(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self.assertIsNotFile(settings.ORIGINALS_DIR / "old" / "document.pdf")
self.assertIsNotFile(settings.ARCHIVE_DIR / "old" / "document.pdf")
@override_settings(FILENAME_FORMAT="{title}")
def test_serializer_stale_update_does_not_clobber_filename(self) -> None:
old_path = settings.ORIGINALS_DIR / "original.pdf"
old_path.touch()
doc = Document.objects.create(
title="original",
mime_type="application/pdf",
checksum=hashlib.sha256(b"").hexdigest(),
filename="original.pdf",
)
first_instance = Document.objects.get(pk=doc.pk)
stale_instance = Document.objects.get(pk=doc.pk)
serializer = DocumentSerializer(
first_instance,
data={"title": "first"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "first.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "first.pdf")
serializer = DocumentSerializer(
stale_instance,
data={"title": "second"},
partial=True,
)
self.assertTrue(serializer.is_valid(), serializer.errors)
serializer.save()
doc.refresh_from_db()
self.assertEqual(doc.filename, "second.pdf")
self.assertIsFile(settings.ORIGINALS_DIR / "second.pdf")
self.assertIsNotFile(settings.ORIGINALS_DIR / "first.pdf")
self.assertIsNotFile(old_path)
@override_settings(FILENAME_FORMAT="{correspondent}/{correspondent}")
def test_document_delete(self) -> None:
document = Document()
+187
View File
@@ -0,0 +1,187 @@
"""
Tests for NFC Unicode normalization in generate_filename / FilePathTemplate.render().
NFC `ü` (UTF-8: c3 bc) and NFD `ü` (UTF-8: 75 cc 88) are visually identical but
produce different byte sequences. On Linux (ext4, ZFS) these are distinct filenames.
All paths produced by the templating system must be NFC-normalized.
"""
import unicodedata
import pytest
from documents.file_handling import generate_filename
from documents.models import CustomField
from documents.models import CustomFieldInstance
from documents.tests.factories import CorrespondentFactory
from documents.tests.factories import DocumentFactory
from documents.tests.factories import StoragePathFactory
from documents.tests.factories import TagFactory
@pytest.mark.django_db
class TestGenerateFilenameNFCNormalization:
@pytest.mark.parametrize(
"raw,display",
[
(unicodedata.normalize("NFD", "Gemüse"), "Gemüse"),
(unicodedata.normalize("NFD", "Café"), "Café"),
(unicodedata.normalize("NFD", "naïve"), "naïve"),
],
)
def test_nfd_title_normalized_to_nfc(self, settings, raw, display):
"""NFD title must produce NFC path bytes."""
settings.FILENAME_FORMAT = "{{ title }}"
nfc = unicodedata.normalize("NFC", display)
assert raw != nfc # confirm byte-level difference
doc = DocumentFactory(title=raw, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_correspondent_normalized_to_nfc(self, settings):
"""NFD correspondent name must produce NFC path component."""
settings.FILENAME_FORMAT = "{{ correspondent }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Müller")
nfc = unicodedata.normalize("NFC", "Müller")
correspondent = CorrespondentFactory(name=nfd)
doc = DocumentFactory(
title="invoice",
correspondent=correspondent,
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result) == f"{nfc}/invoice.pdf"
assert str(result).encode() == f"{nfc}/invoice.pdf".encode()
def test_nfd_storage_path_normalized_to_nfc(self, settings):
"""NFD literal in StoragePath.path template must produce NFC path bytes."""
settings.FILENAME_FORMAT = None
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
# StoragePath.path is used directly as the format/template string.
# Literal NFD characters in the template must survive rendering as NFC.
sp = StoragePathFactory(path=f"{nfd}/{{{{ title }}}}")
doc = DocumentFactory(title="doc", storage_path=sp, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_raw_document_title_normalized_to_nfc(self, settings):
"""NFD title accessed via document.title (unsanitized context) must also be NFC."""
settings.FILENAME_FORMAT = "{{ document.title }}"
nfd = unicodedata.normalize("NFD", "Café")
nfc = unicodedata.normalize("NFC", "Café")
doc = DocumentFactory(title=nfd, mime_type="application/pdf")
result = generate_filename(doc)
assert str(result) == f"{nfc}.pdf"
assert str(result).encode() == f"{nfc}.pdf".encode()
@pytest.mark.django_db
class TestContextBuilderNFCNormalization:
"""
Defense-in-depth: context builder functions must NFC-normalize string inputs
before passing them to sanitize_filename(). Task 1 already normalizes the
final rendered path via clean_filepath(), so these tests may already pass;
they exist as regression guards for the context-builder layer.
"""
def test_nfd_tag_name_normalized_in_tag_list(self, settings):
"""NFD tag name must appear as NFC bytes in the {{ tag_list }} shorthand."""
settings.FILENAME_FORMAT = "{{ tag_list }}/{{ title }}"
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm they differ at byte level
tag = TagFactory(name=nfd)
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.set([tag])
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
def test_nfd_original_name_normalized_to_nfc(self, settings):
settings.FILENAME_FORMAT = "{{ original_name }}"
nfd = unicodedata.normalize("NFD", "Rechnung März")
nfc = unicodedata.normalize("NFC", "Rechnung März")
doc = DocumentFactory(
original_filename=f"{nfd}.pdf",
mime_type="application/pdf",
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}.pdf".encode()
def test_nfd_custom_field_string_value_normalized(self, settings):
"""NFD value in a STRING-type custom field must appear as NFC in the context."""
settings.FILENAME_FORMAT = (
"{{ custom_fields['Location']['value'] }}/{{ title }}"
)
nfd_value = unicodedata.normalize("NFD", "Düsseldorf")
nfc_value = unicodedata.normalize("NFC", "Düsseldorf")
assert nfd_value != nfc_value
doc = DocumentFactory(title="report", mime_type="application/pdf")
cf = CustomField.objects.create(
name="Location",
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text=nfd_value,
)
result = generate_filename(doc)
assert str(result).encode() == f"{nfc_value}/report.pdf".encode()
def test_nfd_custom_field_name_normalized_as_key(self, settings):
"""NFD characters in a custom field name must appear as NFC in the context dict key."""
nfd_name = unicodedata.normalize("NFD", "Größe")
nfc_name = unicodedata.normalize("NFC", "Größe")
assert nfd_name != nfc_name
settings.FILENAME_FORMAT = f"{{% if custom_fields['{nfc_name}'] %}}{{{{ custom_fields['{nfc_name}']['value'] }}}}/{{{{ title }}}}{{% else %}}{{{{ title }}}}{{% endif %}}"
doc = DocumentFactory(title="letter", mime_type="application/pdf")
cf = CustomField.objects.create(
name=nfd_name,
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=doc,
field=cf,
value_text="Berlin",
)
result = generate_filename(doc)
# If field name key is NFC-normalized, the template condition succeeds
# and result is "Berlin/letter.pdf"; otherwise it falls back to "letter.pdf"
assert str(result) == "Berlin/letter.pdf"
def test_nfd_tag_name_list_normalized_to_nfc(self, settings):
"""NFD tag names in tag_name_list must appear as NFC bytes when iterated."""
settings.FILENAME_FORMAT = (
"{% for t in tag_name_list %}{{ t }}{% endfor %}/{{ title }}"
)
nfd = unicodedata.normalize("NFD", "Büro")
nfc = unicodedata.normalize("NFC", "Büro")
assert nfd != nfc # confirm byte-level difference
doc = DocumentFactory(title="doc", mime_type="application/pdf")
doc.tags.add(TagFactory(name=nfd))
result = generate_filename(doc)
assert str(result).encode() == f"{nfc}/doc.pdf".encode()
@@ -684,6 +684,7 @@ class ConsumerThread(Thread):
subdirs_as_tags: bool = False,
polling_interval: float = 0,
stability_delay: float = 0.1,
rescan_interval: float | None = None,
) -> None:
super().__init__()
self.consumption_dir = consumption_dir
@@ -693,6 +694,8 @@ class ConsumerThread(Thread):
self.polling_interval = polling_interval
self.stability_delay = stability_delay
self.cmd = Command()
if rescan_interval is not None:
self.cmd.rescan_interval_s = rescan_interval
self.cmd.stop_flag.clear()
# Non-daemon ensures finally block runs and connections are closed
self.daemon = False
@@ -1052,3 +1055,200 @@ class TestCommandWatchEdgeCases:
thread.stop_and_wait(timeout=5.0)
# Clean up any Tags created by the thread
Tag.objects.all().delete()
class TestRescanExistingFiles:
"""
Unit tests for the rescan safety net.
Each ``watch()`` recreation silently adopts the current directory contents
as its baseline, so a file appearing between one batch and the next
watcher's baseline is never reported and would sit in the consume directory
forever. ``_rescan_existing_files`` re-injects such files into the
stability tracker as a periodic safety net (see GH issue #13011).
"""
@pytest.fixture
def pdf_only_filter(self) -> ConsumerFilter:
return ConsumerFilter(
supported_extensions=frozenset({".pdf"}),
ignore_patterns=[],
)
def _rescan(
self,
directory: Path,
consumer_filter: ConsumerFilter,
tracker: FileStabilityTracker,
queued: set[Path],
*,
recursive: bool = False,
) -> None:
Command()._rescan_existing_files(
directory=directory,
recursive=recursive,
consumer_filter=consumer_filter,
tracker=tracker,
queued=queued,
)
def test_tracks_stranded_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A supported on-disk file the watcher never reported gets tracked."""
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.is_tracking(target) is True
assert tracker.pending_count == 1
def test_skips_already_tracked_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already being tracked by the watcher is not double-tracked."""
target = consumption_dir / "tracked.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
tracker.track(target, Change.added)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 1
def test_skips_queued_file(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""A file already queued and awaiting consumption is not re-tracked."""
target = consumption_dir / "inflight.pdf"
shutil.copy(sample_pdf, target)
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {target.resolve()}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert tracker.pending_count == 0
def test_prunes_vanished_queued_paths(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Queued paths no longer on disk are dropped so the name can recur."""
gone = (consumption_dir / "gone.pdf").resolve()
tracker = FileStabilityTracker(stability_delay=0.1)
queued = {gone}
self._rescan(consumption_dir, pdf_only_filter, tracker, queued)
assert gone not in queued
def test_skips_unsupported_extension(
self,
consumption_dir: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Files filtered out by the consumer filter are not tracked."""
(consumption_dir / "notes.xyz").write_bytes(b"content")
tracker = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, tracker, set())
assert tracker.pending_count == 0
def test_recursive_respects_flag(
self,
consumption_dir: Path,
sample_pdf: Path,
pdf_only_filter: ConsumerFilter,
) -> None:
"""Nested files are only found when recursive scanning is enabled."""
subdir = consumption_dir / "nested"
subdir.mkdir()
target = subdir / "deep.pdf"
shutil.copy(sample_pdf, target)
shallow = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, shallow, set())
assert shallow.pending_count == 0
deep = FileStabilityTracker(stability_delay=0.1)
self._rescan(consumption_dir, pdf_only_filter, deep, set(), recursive=True)
assert deep.is_tracking(target) is True
class TestProcessExistingFilesQueued:
"""Tests that startup processing reports which paths it queued."""
@pytest.mark.usefixtures("mock_supported_extensions")
def test_returns_queued_paths(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
settings: SettingsWrapper,
) -> None:
"""The set returned seeds the rescan's queued set, avoiding re-queue."""
target = consumption_dir / "document.pdf"
shutil.copy(sample_pdf, target)
settings.CONSUMER_IGNORE_PATTERNS = []
queued = Command()._process_existing_files(
directory=consumption_dir,
recursive=False,
subdirs_as_tags=False,
consumer_filter=ConsumerFilter(ignore_patterns=[]),
)
assert target.resolve() in queued
@pytest.mark.management
@pytest.mark.django_db
class TestCommandRescanRecovery:
"""End-to-end test that the rescan recovers files the watcher misses."""
def test_rescan_consumes_file_the_watcher_never_reports(
self,
consumption_dir: Path,
sample_pdf: Path,
mock_consume_file_delay: MagicMock,
start_consumer: Callable[..., ConsumerThread],
) -> None:
"""
Isolate the rescan path: a long polling interval guarantees the
watcher cannot report the file within the test window, so only the
periodic rescan can consume it.
"""
# poll interval far longer than the test window -> watcher stays silent
thread = start_consumer(
polling_interval=30.0,
stability_delay=0.1,
rescan_interval=0.5,
)
# created after startup, so _process_existing_files did not see it
target = consumption_dir / "stranded.pdf"
shutil.copy(sample_pdf, target)
wait_for_mock_call(mock_consume_file_delay.apply_async, timeout_s=5.0)
if thread.exception:
raise thread.exception
mock_consume_file_delay.apply_async.assert_called()
call_args = mock_consume_file_delay.apply_async.call_args.kwargs["kwargs"][
"input_doc"
]
assert call_args.original_file.name == "stranded.pdf"
@@ -1,7 +1,9 @@
from datetime import timedelta
from types import SimpleNamespace
from unittest import mock
from django.test import TestCase
from django.utils import timezone
from documents.conditionals import metadata_etag
from documents.conditionals import preview_etag
@@ -29,10 +31,31 @@ class TestConditionals(DirectoriesMixin, TestCase):
)
request = SimpleNamespace(query_params={})
self.assertEqual(metadata_etag(request, root.id), latest.checksum)
self.assertEqual(
metadata_etag(request, root.id),
f"{latest.checksum}:{latest.modified.isoformat()}",
)
self.assertEqual(preview_etag(request, root.id), latest.archive_checksum)
self.assertEqual(thumbnail_etag(request, root.id), latest.checksum)
def test_metadata_etag_changes_when_document_modified_changes(self) -> None:
doc = Document.objects.create(
title="doc",
checksum="same-checksum",
mime_type="application/pdf",
)
request = SimpleNamespace(query_params={})
original_etag = metadata_etag(request, doc.id)
new_modified = timezone.now() + timedelta(seconds=5)
Document.objects.filter(id=doc.id).update(modified=new_modified)
self.assertNotEqual(metadata_etag(request, doc.id), original_etag)
self.assertEqual(
metadata_etag(request, doc.id),
f"{doc.checksum}:{new_modified.isoformat()}",
)
def test_resolve_effective_doc_returns_none_for_invalid_or_unrelated_version(
self,
) -> None:
+28
View File
@@ -30,6 +30,7 @@ from documents.signals.handlers import update_llm_suggestions_cache
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import read_streaming_response
from paperless.models import ApplicationConfiguration
from paperless_ai.exceptions import LLMTimeoutError
class TestViews(DirectoriesMixin, TestCase):
@@ -476,6 +477,33 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
@patch("documents.views.get_ai_document_classification")
@override_settings(
AI_ENABLED=True,
LLM_BACKEND="openai-like",
)
def test_ai_suggestions_with_llm_timeout(
self,
mock_get_ai_classification,
) -> None:
mock_get_ai_classification.side_effect = LLMTimeoutError()
self.client.force_login(user=self.user)
response = self.client.get(
f"/api/documents/{self.document.pk}/ai_suggestions/",
)
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
self.assertEqual(
response.json(),
{
"ai": ["AI backend request timed out."],
},
)
self.assertIsNone(
get_llm_suggestion_cache(self.document.pk, backend="openai-like"),
)
def test_invalidate_suggestions_cache(self) -> None:
self.client.force_login(user=self.user)
suggestions = {
+114 -12
View File
@@ -12,6 +12,7 @@ from datetime import timedelta
from http import HTTPStatus
from pathlib import Path
from time import mktime
from time import sleep
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -240,6 +241,7 @@ from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.chat import stream_chat_with_documents
from paperless_ai.exceptions import LLMTimeoutError
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
@@ -1400,7 +1402,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"view_document",
"change_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1460,7 +1462,7 @@ class DocumentViewSet(
)
if request.user is not None and not has_perms_owner_aware(
request.user,
"view_document",
"change_document",
doc,
):
return HttpResponseForbidden("Insufficient permissions")
@@ -1509,6 +1511,17 @@ class DocumentViewSet(
exc_info=True,
)
raise ValidationError({"ai": [_("Invalid AI configuration.")]}) from exc
except LLMTimeoutError as exc:
logger.exception(
"AI backend timed out while generating suggestions for document %s: %s",
doc.pk,
exc,
exc_info=True,
)
return Response(
{"ai": [_("AI backend request timed out.")]},
status=status.HTTP_503_SERVICE_UNAVAILABLE,
)
matched_tags = match_tags_by_name(
llm_suggestions.get("tags", []),
@@ -2276,6 +2289,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
return super().list(request)
from documents.search import SearchHit
from documents.search import SearchQueryError
from documents.search import TantivyBackend
from documents.search import TantivyRelevanceList
from documents.search import get_backend
@@ -2468,6 +2482,11 @@ class UnifiedSearchViewSet(DocumentViewSet):
return HttpResponseForbidden(_("Insufficient permissions."))
except ValidationError:
raise
except SearchQueryError as e:
# User-fixable query error (e.g. an unparsable date): surface the
# specific message so the user can correct it, rather than a generic
# 400 or silently empty results.
raise ValidationError({"query": [str(e)]}) from e
except Exception as e:
logger.warning(f"An error occurred listing search results: {e!s}")
return HttpResponseBadRequest(
@@ -3126,6 +3145,7 @@ class PostDocumentView(GenericAPIView[Any]):
serializer.is_valid(raise_exception=True)
doc_name, doc_data = serializer.validated_data.get("document")
doc_name = normalize("NFC", doc_name)
correspondent_id = serializer.validated_data.get("correspondent")
document_type_id = serializer.validated_data.get("document_type")
storage_path_id = serializer.validated_data.get("storage_path")
@@ -4011,7 +4031,7 @@ class RemoteVersionView(GenericAPIView[Any]):
class _TasksViewSetSchema(AutoSchema):
_UNPAGINATED_ACTIONS = frozenset({"summary", "active"})
_UNPAGINATED_ACTIONS = frozenset({"summary", "active", "status_counts"})
def _get_paginator(self):
if getattr(self.view, "action", None) in self._UNPAGINATED_ACTIONS:
@@ -4033,7 +4053,7 @@ class _TasksViewSetSchema(AutoSchema):
),
acknowledge=extend_schema(
operation_id="acknowledge_tasks",
description="Acknowledge a list of tasks",
description="Acknowledge a list of tasks, or all visible unacknowledged tasks",
request=AcknowledgeTasksViewSerializer,
responses={
(200, "application/json"): inline_serializer(
@@ -4071,6 +4091,19 @@ class _TasksViewSetSchema(AutoSchema):
),
],
),
status_counts=extend_schema(
responses={
200: inline_serializer(
name="TaskStatusCounts",
fields={
"all": serializers.IntegerField(),
"needs_attention": serializers.IntegerField(),
"in_progress": serializers.IntegerField(),
"completed": serializers.IntegerField(),
},
),
},
),
active=extend_schema(
description="Currently pending and running tasks (capped at 50).",
responses={200: TaskSerializerV10(many=True)},
@@ -4124,6 +4157,7 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
PaperlessTask.TaskType.SANITY_CHECK: (sanity_check, {"raise_on_error": False}),
PaperlessTask.TaskType.LLM_INDEX: (llmindex_index, {"rebuild": False}),
}
_STATUS_COUNT_EXCLUDED_FILTERS = frozenset({"status", "is_complete"})
def get_serializer_class(self):
# v9: use backwards-compatible serializer with old field names
@@ -4164,16 +4198,38 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
queryset = queryset.filter(task_id=task_id)
return queryset
def get_status_count_queryset(self):
"""Apply task filters except the status dimensions represented by the counts."""
query_params = self.request.query_params.copy()
for param in self._STATUS_COUNT_EXCLUDED_FILTERS:
query_params.pop(param, None)
filterset = self.filterset_class(
data=query_params,
queryset=self.get_queryset(),
request=self.request,
)
if not filterset.is_valid():
raise ValidationError(filterset.errors)
return filterset.qs
@action(
methods=["post"],
detail=False,
permission_classes=[IsAuthenticated, AcknowledgeTasksPermissions],
)
def acknowledge(self, request):
serializer = AcknowledgeTasksViewSerializer(data=request.data)
queryset = self.get_queryset()
serializer = AcknowledgeTasksViewSerializer(
data=request.data,
context={"queryset": queryset},
)
serializer.is_valid(raise_exception=True)
task_ids = serializer.validated_data.get("tasks")
tasks = self.get_queryset().filter(id__in=task_ids)
if serializer.validated_data.get("all", False):
tasks = queryset.filter(acknowledged=False)
else:
task_ids = serializer.validated_data.get("tasks")
tasks = queryset.filter(id__in=task_ids)
count = tasks.update(acknowledged=True)
return Response({"result": count})
@@ -4226,6 +4282,34 @@ class TasksViewSet(ReadOnlyModelViewSet[PaperlessTask]):
serializer = TaskSummarySerializer(data, many=True)
return Response(serializer.data)
@action(methods=["get"], detail=False)
def status_counts(self, request):
"""Aggregated task counts for task UI sections."""
queryset = self.get_status_count_queryset()
counts = queryset.aggregate(
all=Count("id"),
needs_attention=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.FAILURE,
PaperlessTask.Status.REVOKED,
],
),
),
in_progress=Count(
"id",
filter=Q(
status__in=[
PaperlessTask.Status.PENDING,
PaperlessTask.Status.STARTED,
],
),
),
completed=Count("id", filter=Q(status=PaperlessTask.Status.SUCCESS)),
)
return Response(counts)
@action(methods=["get"], detail=False)
def active(self, request):
"""Currently pending and running tasks (capped at 50)."""
@@ -4925,11 +5009,29 @@ class SystemStatusView(PassUserMixin):
celery_error = None
celery_url = None
try:
celery_ping = celery_app.control.inspect().ping()
celery_url = next(iter(celery_ping.keys()))
first_worker_ping = celery_ping[celery_url]
if first_worker_ping["ok"] == "pong":
celery_active = "OK"
celery_ping = None
for ping_attempt in range(3):
celery_ping = celery_app.control.inspect().ping()
if celery_ping:
break
if ping_attempt < 2:
sleep(0.25)
if not celery_ping:
celery_active = "WARNING"
celery_error = (
"No celery workers responded to ping. This may be temporary."
)
else:
celery_url, first_worker_ping = next(iter(celery_ping.items()))
if (
isinstance(first_worker_ping, dict)
and first_worker_ping.get("ok") == "pong"
):
celery_active = "OK"
else:
celery_active = "WARNING"
celery_error = "Celery worker responded unexpectedly."
except Exception as e:
celery_active = "ERROR"
logger.exception(
+4
View File
@@ -197,6 +197,7 @@ class AIConfig(BaseConfig):
llm_embedding_endpoint: str = dataclasses.field(init=False)
llm_embedding_chunk_size: int = dataclasses.field(init=False)
llm_context_size: int = dataclasses.field(init=False)
llm_request_timeout: int = dataclasses.field(init=False)
llm_backend: str = dataclasses.field(init=False)
llm_model: str = dataclasses.field(init=False)
llm_api_key: str = dataclasses.field(init=False)
@@ -221,6 +222,9 @@ class AIConfig(BaseConfig):
app_config.llm_embedding_chunk_size or settings.LLM_EMBEDDING_CHUNK_SIZE
)
self.llm_context_size = app_config.llm_context_size or settings.LLM_CONTEXT_SIZE
self.llm_request_timeout = (
app_config.llm_request_timeout or settings.LLM_REQUEST_TIMEOUT
)
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
@@ -0,0 +1,23 @@
# Generated by Django 5.2.14 on 2026-06-14 14:22
import django.core.validators
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("paperless", "0012_applicationconfiguration_llm_output_language"),
]
operations = [
migrations.AddField(
model_name="applicationconfiguration",
name="llm_request_timeout",
field=models.PositiveSmallIntegerField(
null=True,
validators=[django.core.validators.MinValueValidator(1)],
verbose_name="Sets the LLM request timeout in seconds",
),
),
]
+6
View File
@@ -366,6 +366,12 @@ class ApplicationConfiguration(AbstractSingletonModel):
max_length=32,
)
llm_request_timeout = models.PositiveSmallIntegerField(
verbose_name=_("Sets the LLM timeout in seconds"),
null=True,
validators=[MinValueValidator(1)],
)
class Meta:
verbose_name = _("paperless application settings")
permissions = [
+2 -28
View File
@@ -20,6 +20,7 @@ from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from paperless.parsers.utils import read_file_handle_unicode_errors
from paperless.version import __full_version_str__
if TYPE_CHECKING:
@@ -183,7 +184,7 @@ class TextDocumentParser:
documents.parsers.ParseError
If the file cannot be read.
"""
self._text = self._read_text(document_path)
self._text = read_file_handle_unicode_errors(document_path, log=logger)
# ------------------------------------------------------------------
# Result accessors
@@ -295,30 +296,3 @@ class TextDocumentParser:
Always ``[]`` plain text files carry no structured metadata.
"""
return []
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _read_text(self, filepath: Path) -> str:
"""Read file content, replacing invalid UTF-8 bytes rather than failing.
Parameters
----------
filepath:
Path to the file to read.
Returns
-------
str
File content as a string.
"""
try:
return filepath.read_text(encoding="utf-8")
except UnicodeDecodeError as exc:
logger.warning(
"Unicode error reading %s, replacing bad bytes: %s",
filepath,
exc,
)
return filepath.read_bytes().decode("utf-8", errors="replace")
+18 -5
View File
@@ -8,6 +8,7 @@ share implementation.
from __future__ import annotations
import codecs
import logging
import re
import tempfile
@@ -114,7 +115,7 @@ def read_file_handle_unicode_errors(
filepath: Path,
log: logging.Logger | None = None,
) -> str:
"""Read a file as UTF-8 text, replacing invalid bytes rather than raising.
"""Read a file as text, detecting encoding via BOM and stripping NUL bytes.
Parameters
----------
@@ -127,15 +128,27 @@ def read_file_handle_unicode_errors(
Returns
-------
str
File content as a string, with any invalid UTF-8 sequences replaced
by the Unicode replacement character.
File content as a string, with NUL bytes removed so the result is
safe to store in PostgreSQL text fields.
"""
_log = log or logger
raw = filepath.read_bytes()
if raw.startswith((codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)):
encoding = "utf-16"
elif raw.startswith(codecs.BOM_UTF8):
encoding = "utf-8-sig"
else:
encoding = "utf-8"
try:
return filepath.read_text(encoding="utf-8")
text = raw.decode(encoding)
except UnicodeDecodeError as e:
_log.warning("Unicode error during text reading, continuing: %s", e)
return filepath.read_bytes().decode("utf-8", errors="replace")
text = raw.decode("utf-8", errors="replace")
# PostgreSQL rejects NUL (0x00) bytes in text fields
return text.replace("\x00", "")
def get_page_count_for_pdf(
+12 -2
View File
@@ -97,8 +97,14 @@ MODEL_FILE = get_path_from_env(
DATA_DIR / "classification_model.pickle",
)
LLM_INDEX_DIR = DATA_DIR / "llm_index"
LLM_INDEX_LOCK = DATA_DIR / "locks" / "llm_index.lock"
(DATA_DIR / "locks").mkdir(parents=True, exist_ok=True)
LLM_INDEX_LOCK = LLM_INDEX_DIR / "index.lock"
# Cross-process read/write lock guarding the LLM index compaction/migration
# file swap. Readers hold it shared; the swap takes it exclusively so it never
# runs while a reader connection is open. Must be a SQLite (.db) file.
LLM_INDEX_RWLOCK = LLM_INDEX_DIR / "llmindex.rwlock.db"
# Seconds the compaction swap waits for active readers to drain before skipping
# this cycle (it is a maintenance operation; the next run retries).
LLM_INDEX_COMPACTION_LOCK_TIMEOUT = 30
LOGGING_DIR = get_path_from_env("PAPERLESS_LOGGING_DIR", DATA_DIR / "log")
@@ -644,6 +650,7 @@ LOGGING = {
"kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
"_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
"granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
"httpx": {"level": "WARNING"},
},
}
@@ -1199,6 +1206,9 @@ if LLM_EMBEDDING_CHUNK_SIZE < 1:
LLM_CONTEXT_SIZE = get_int_from_env("PAPERLESS_AI_LLM_CONTEXT_SIZE", 8192)
if LLM_CONTEXT_SIZE < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_CONTEXT_SIZE must be >= 1")
LLM_REQUEST_TIMEOUT = get_int_from_env("PAPERLESS_AI_LLM_REQUEST_TIMEOUT", 120)
if LLM_REQUEST_TIMEOUT < 1:
raise ImproperlyConfigured("PAPERLESS_AI_LLM_REQUEST_TIMEOUT must be >= 1")
LLM_BACKEND = get_choice_from_env(
"PAPERLESS_AI_LLM_BACKEND",
{"ollama", "openai-like"},
+3
View File
@@ -252,6 +252,9 @@ def parse_db_settings(data_dir: Path) -> dict[str, dict[str, Any]]:
"NAME": os.getenv("PAPERLESS_DBNAME", "paperless"),
"USER": os.getenv("PAPERLESS_DBUSER", "paperless"),
"PASSWORD": os.getenv("PAPERLESS_DBPASS", "paperless"),
# Validate pooled connections so a connection closed server-side
# is replaced rather than handed out as "the connection is closed".
"CONN_HEALTH_CHECKS": True,
}
base_options = {
@@ -398,6 +398,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "localhost",
"NAME": "paperless",
"USER": "paperless",
@@ -426,6 +427,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "paperless-db-host",
"PORT": 1111,
"NAME": "customdb",
@@ -455,6 +457,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
@@ -485,6 +488,7 @@ class TestParseDbSettings:
{
"default": {
"ENGINE": "django.db.backends.postgresql",
"CONN_HEALTH_CHECKS": True,
"HOST": "pghost",
"NAME": "paperless",
"USER": "paperless",
+37
View File
@@ -2,13 +2,50 @@
from __future__ import annotations
import codecs
from pathlib import Path
from paperless.parsers.utils import is_tagged_pdf
from paperless.parsers.utils import read_file_handle_unicode_errors
SAMPLES = Path(__file__).parent / "samples" / "tesseract"
class TestReadFileHandleUnicodeErrors:
def test_plain_utf8(self, tmp_path: Path) -> None:
f = tmp_path / "plain.txt"
f.write_bytes(b"hello world")
assert read_file_handle_unicode_errors(f) == "hello world"
def test_utf8_bom(self, tmp_path: Path) -> None:
f = tmp_path / "bom.txt"
f.write_bytes(codecs.BOM_UTF8 + b"hello")
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_le(self, tmp_path: Path) -> None:
f = tmp_path / "utf16le.txt"
f.write_bytes(codecs.BOM_UTF16_LE + "hello".encode("utf-16-le"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_utf16_be(self, tmp_path: Path) -> None:
f = tmp_path / "utf16be.txt"
f.write_bytes(codecs.BOM_UTF16_BE + "hello".encode("utf-16-be"))
assert read_file_handle_unicode_errors(f) == "hello"
def test_nul_bytes_stripped(self, tmp_path: Path) -> None:
f = tmp_path / "null-bytes.txt"
f.write_bytes(b"foo\x00bar")
assert read_file_handle_unicode_errors(f) == "foobar"
def test_invalid_utf8_replaced(self, tmp_path: Path) -> None:
f = tmp_path / "bad.txt"
f.write_bytes(b"ok\x80\x81bad")
result = read_file_handle_unicode_errors(f)
assert "ok" in result
assert "bad" in result
assert "\x00" not in result
class TestIsTaggedPdf:
def test_tagged_pdf_returns_true(self) -> None:
assert is_tagged_pdf(SAMPLES / "simple-digital.pdf") is True
+2 -2
View File
@@ -49,7 +49,7 @@ from paperless.serialisers import GroupSerializer
from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer
from paperless_ai.indexing import vector_store_file_exists
from paperless_ai.indexing import llm_index_exists
class PaperlessObtainAuthTokenView(ObtainAuthToken):
@@ -467,7 +467,7 @@ class ApplicationConfigurationViewSet(ModelViewSet[ApplicationConfiguration]):
or old_llm_context_size != new_llm_context_size
)
rebuild_needed = new_ai_index_enabled and (
not vector_store_file_exists() or embedding_config_changed
not llm_index_exists() or embedding_config_changed
)
if rebuild_needed:
+35 -21
View File
@@ -8,6 +8,7 @@ from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import query_similar_documents
from paperless_ai.indexing import truncate_content
@@ -24,9 +25,14 @@ def get_language_name(language_code: str) -> str:
def build_prompt_without_rag(
document: Document,
config: AIConfig,
) -> str:
filename = document.filename or ""
content = truncate_content(document.content[:4000] or "")
content = truncate_content(
document.content[:4000] or "",
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""
You are a document classification assistant.
@@ -49,10 +55,15 @@ def build_prompt_without_rag(
def build_prompt_with_rag(
document: Document,
config: AIConfig,
user: User | None = None,
) -> str:
base_prompt = build_prompt_without_rag(document)
context = truncate_content(get_context_for_document(document, user))
base_prompt = build_prompt_without_rag(document, config)
context = truncate_content(
get_context_for_document(document, user),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
return f"""{base_prompt}
@@ -130,26 +141,29 @@ def get_ai_document_classification(
ai_config = AIConfig()
prompt = (
build_prompt_with_rag(document, user)
build_prompt_with_rag(document, ai_config, user)
if ai_config.llm_embedding_backend
else build_prompt_without_rag(document)
else build_prompt_without_rag(document, ai_config)
)
client = AIClient()
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
# Hand the pooled DB connection back while the (slow) LLM query runs so it
# is not pinned for the call's duration; see paperless_ai.db and #12976.
with db_connection_released():
result = client.run_llm_query(prompt)
suggestions = parse_ai_response(result)
if output_language:
localized = client.run_llm_query(
build_localization_prompt(suggestions, output_language),
)
localized_suggestions = parse_ai_response(localized)
suggestions = {
**suggestions,
"title": localized_suggestions["title"] or suggestions["title"],
"tags": localized_suggestions["tags"] or suggestions["tags"],
"document_types": localized_suggestions["document_types"]
or suggestions["document_types"],
"storage_paths": localized_suggestions["storage_paths"]
or suggestions["storage_paths"],
}
return suggestions
+57 -123
View File
@@ -3,9 +3,13 @@ import logging
import sys
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.db import db_connection_released
from paperless_ai.indexing import _document_id_filters
from paperless_ai.indexing import get_rag_prompt_helper
from paperless_ai.indexing import load_or_build_index
from paperless_ai.indexing import read_store
logger = logging.getLogger("paperless_ai.chat")
@@ -75,148 +79,78 @@ def _format_chat_metadata_trailer(references: list[dict[str, int | str]]) -> str
)
def _get_document_filtered_retriever(index, doc_ids: set[str], similarity_top_k: int):
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.vector_stores import VectorStoreQuery
class DocumentFilteredFaissRetriever(BaseRetriever):
def __init__(self):
super().__init__()
self._cached_query_str = None
self._cached_nodes = []
def _retrieve(self, query_bundle):
if query_bundle.query_str == self._cached_query_str:
return self._cached_nodes
if query_bundle.embedding is None:
query_bundle.embedding = (
index._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs,
)
)
faiss_index = index.vector_store._faiss_index
max_top_k = faiss_index.ntotal
if max_top_k == 0:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = []
return []
query_top_k = min(max(similarity_top_k, 1), max_top_k)
allowed_nodes: list[NodeWithScore] = []
seen_node_ids: set[str] = set()
while query_top_k <= max_top_k:
query_result = index.vector_store.query(
VectorStoreQuery(
query_embedding=query_bundle.embedding,
similarity_top_k=query_top_k,
),
)
for vector_id, score in zip(
query_result.ids or [],
query_result.similarities or [],
strict=False,
):
node_id = index.index_struct.nodes_dict.get(vector_id)
if node_id is None or node_id in seen_node_ids:
continue
node = index.docstore.docs.get(node_id)
if node is None or node.metadata.get("document_id") not in doc_ids:
continue
seen_node_ids.add(node_id)
allowed_nodes.append(NodeWithScore(node=node, score=score))
if len(allowed_nodes) >= similarity_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
if query_top_k == max_top_k:
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
query_top_k = min(query_top_k * 2, max_top_k)
self._cached_query_str = query_bundle.query_str
self._cached_nodes = allowed_nodes
return allowed_nodes
return DocumentFilteredFaissRetriever()
def stream_chat_with_documents(query_str: str, documents: list[Document]):
try:
yield from _stream_chat_with_documents(query_str, documents)
except Exception as e:
logger.exception(f"Failed to stream document chat response: {e}", exc_info=True)
logger.exception("Failed to stream document chat response: %s", e)
yield CHAT_ERROR_MESSAGE
def _stream_chat_with_documents(query_str: str, documents: list[Document]):
client = AIClient()
index = load_or_build_index()
doc_ids = [str(doc.pk) for doc in documents]
# Filter only the node(s) that match the document IDs
nodes = [
node
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in doc_ids
]
if len(nodes) == 0:
logger.warning("No nodes found for the given documents.")
if not documents:
yield CHAT_NO_CONTENT_MESSAGE
return
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
retriever = _get_document_filtered_retriever(
index,
set(doc_ids),
CHAT_RETRIEVER_TOP_K,
)
config = AIConfig()
filters = _document_id_filters(str(doc.pk) for doc in documents)
top_nodes = retriever.retrieve(query_str)
if len(top_nodes) == 0:
logger.warning("Retriever returned no nodes for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
# Hold the shared read lock for the whole operation: the query engine
# retrieves from the vector store again during synthesis, so the connection
# must stay open (and the swap must not run) until the stream finishes.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CHAT_RETRIEVER_TOP_K,
filters=filters,
)
references = _get_document_references(documents, top_nodes)
# Slow query-embedding + vector search; no Django ORM access happens
# during it, so release the pooled DB connection for its duration. See
# #12976.
with db_connection_released():
top_nodes = retriever.retrieve(query_str)
if not top_nodes:
logger.warning("No nodes found for the given documents.")
yield CHAT_NO_CONTENT_MESSAGE
return
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(),
text_qa_template=prompt_template,
streaming=True,
)
client = AIClient()
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
references = _get_document_references(documents, top_nodes)
logger.debug("Document chat query: %s", query_str)
prompt_template = PromptTemplate(template=CHAT_PROMPT_TMPL)
response_synthesizer = get_response_synthesizer(
llm=client.llm,
prompt_helper=get_rag_prompt_helper(
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
),
text_qa_template=prompt_template,
streaming=True,
)
query_engine = RetrieverQueryEngine.from_args(
retriever=retriever,
llm=client.llm,
response_synthesizer=response_synthesizer,
streaming=True,
)
response_stream = query_engine.query(query_str)
logger.debug("Document chat query: %s", query_str)
# Release the pooled DB connection for the slow streaming LLM response
# so it is not pinned for the whole stream; see paperless_ai.db and
# #12976.
with db_connection_released():
response_stream = query_engine.query(query_str)
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
for chunk in response_stream.response_gen:
yield chunk
sys.stdout.flush()
if references:
yield _format_chat_metadata_trailer(references)
if references:
yield _format_chat_metadata_trailer(references)
+49 -28
View File
@@ -1,11 +1,14 @@
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
import httpx
from paperless.models import LLMBackend
if TYPE_CHECKING:
from llama_index.core.llms import ChatMessage
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai_like import OpenAILike
@@ -16,6 +19,7 @@ from paperless.network import create_pinned_async_httpx_client
from paperless.network import create_pinned_httpx_client
from paperless.network import validate_outbound_http_url
from paperless_ai.base_model import DocumentClassifierSchema
from paperless_ai.exceptions import LLMTimeoutError
logger = logging.getLogger("paperless_ai.client")
@@ -61,16 +65,16 @@ class AIClient:
model=self.settings.llm_model or "llama3.1",
base_url=endpoint,
context_window=self.settings.llm_context_size,
request_timeout=120,
request_timeout=self.settings.llm_request_timeout,
system_prompt=LLM_SYSTEM_PROMPT,
client=Client(
host=endpoint,
timeout=120,
timeout=self.settings.llm_request_timeout,
transport=transport,
),
async_client=AsyncClient(
host=endpoint,
timeout=120,
timeout=self.settings.llm_request_timeout,
transport=async_transport,
),
)
@@ -84,15 +88,18 @@ class AIClient:
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=self.settings.llm_allow_internal_endpoints,
timeout=self.settings.llm_request_timeout,
)
return OpenAILike(
model=self.settings.llm_model or "gpt-3.5-turbo",
api_base=endpoint,
api_key=self.settings.llm_api_key,
timeout=self.settings.llm_request_timeout,
is_chat_model=True,
is_function_calling_model=True,
system_prompt=LLM_SYSTEM_PROMPT,
@@ -113,11 +120,12 @@ class AIClient:
user_msg = ChatMessage(role="user", content=prompt)
if self.settings.llm_backend == LLMBackend.OLLAMA:
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
with self._normalize_timeouts():
result = self.llm.chat(
[user_msg],
format=DocumentClassifierSchema.model_json_schema(),
think=False,
)
logger.debug("LLM query result: %s", result)
parsed = DocumentClassifierSchema(**json.loads(result.message.content))
return parsed.model_dump()
@@ -125,26 +133,39 @@ class AIClient:
from llama_index.core.program.function_program import get_function_tool
tool = get_function_tool(DocumentClassifierSchema)
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
with self._normalize_timeouts():
result = self.llm.chat_with_tools(
tools=[tool],
user_msg=user_msg,
chat_history=[],
allow_parallel_tool_calls=True,
tool_required=True,
)
tool_calls = self.llm.get_tool_calls_from_response(
result,
error_on_no_tool_call=True,
)
logger.debug("LLM query result: %s", tool_calls)
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
return parsed.model_dump()
def run_chat(self, messages: list["ChatMessage"]) -> str:
logger.debug(
"Running chat query against %s with model %s",
self.settings.llm_backend,
self.settings.llm_model,
)
result = self.llm.chat(messages)
logger.debug("Chat result: %s", result)
return result
@contextmanager
def _normalize_timeouts(self) -> Iterator[None]:
try:
yield
except httpx.TimeoutException as exc:
raise LLMTimeoutError from exc
except Exception as exc:
if self._is_openai_timeout(exc):
raise LLMTimeoutError from exc
raise
def _is_openai_timeout(self, exc: Exception) -> bool:
if self.settings.llm_backend != LLMBackend.OPENAI_LIKE:
return False
# Keep OpenAI imports out of module import paths and only load the SDK
# when translating an error from an OpenAI-backed request.
from openai import APITimeoutError
return isinstance(exc, APITimeoutError)
+30
View File
@@ -0,0 +1,30 @@
from __future__ import annotations
from contextlib import contextmanager
from django.db import connections
@contextmanager
def db_connection_released():
"""
Return any checked-out DB connections to the pool for the duration of the
wrapped block.
The AI endpoints run inside a synchronous web request (``ai_suggestions``)
or a streaming response (``chat``). Django keeps the request's database
connection checked out for the entire request/response, so a blocking LLM
call - which can take many seconds - pins a pooled connection the whole
time. With connection pooling enabled, enough concurrent AI requests check
out every slot and all other requests then fail with
``psycopg_pool.PoolTimeout`` (see issue #12976).
No Django ORM access happens during the LLM call, so we hand the connection
back to the pool first; Django transparently re-checks-out a connection on
the next ORM use after the block.
"""
connections.close_all()
try:
yield
finally:
connections.close_all()
+27 -50
View File
@@ -1,12 +1,9 @@
import json
import re
from typing import TYPE_CHECKING
from django.conf import settings
if TYPE_CHECKING:
from pathlib import Path
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
@@ -23,9 +20,7 @@ OCR_LEADER_REGEX = re.compile(r"[._\-\u00b7]{4,}")
HORIZONTAL_WHITESPACE_REGEX = re.compile(r"[ \t\u00a0]+")
def get_embedding_model() -> "BaseEmbedding":
config = AIConfig()
def get_embedding_model(config: AIConfig) -> "BaseEmbedding":
match config.llm_embedding_backend:
case LLMEmbeddingBackend.OPENAI_LIKE:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
@@ -37,15 +32,18 @@ def get_embedding_model() -> "BaseEmbedding":
http_client = create_pinned_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
async_http_client = create_pinned_async_httpx_client(
endpoint,
allow_internal=config.llm_allow_internal_endpoints,
timeout=config.llm_request_timeout,
)
return OpenAILikeEmbedding(
model_name=config.llm_embedding_model or "text-embedding-3-small",
api_key=config.llm_api_key,
api_base=endpoint,
timeout=config.llm_request_timeout,
http_client=http_client,
async_http_client=async_http_client,
)
@@ -78,12 +76,14 @@ def get_embedding_model() -> "BaseEmbedding":
)
embedding._client = Client(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
)
embedding._async_client = AsyncClient(
host=endpoint,
timeout=config.llm_request_timeout,
transport=PinnedHostAsyncHTTPTransport(
allow_internal=config.llm_allow_internal_endpoints,
),
@@ -95,41 +95,24 @@ def get_embedding_model() -> "BaseEmbedding":
)
def get_embedding_dim() -> int:
"""
Loads embedding dimension from meta.json if available, otherwise infers it
from a dummy embedding and stores it for future use.
"""
config = AIConfig()
default_model = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}.get(
config.llm_embedding_backend,
"sentence-transformers/all-MiniLM-L6-v2",
_DEFAULT_MODEL_NAMES = {
LLMEmbeddingBackend.OPENAI_LIKE: "text-embedding-3-small",
LLMEmbeddingBackend.HUGGINGFACE: "sentence-transformers/all-MiniLM-L6-v2",
LLMEmbeddingBackend.OLLAMA: "embeddinggemma",
}
def get_configured_model_name(config: AIConfig) -> str:
"""Return the canonical name of the currently configured embedding model."""
# dict.get(key, default) overload resolution fails for TextChoices keys in some
# type checkers; use `or` fallback to avoid the ambiguity.
default = (
_DEFAULT_MODEL_NAMES.get(
config.llm_embedding_backend,
)
or "sentence-transformers/all-MiniLM-L6-v2"
)
model = config.llm_embedding_model or default_model
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
if meta_path.exists():
with meta_path.open() as f:
meta = json.load(f)
if meta.get("embedding_model") != model:
raise RuntimeError(
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
"You must rebuild the index.",
)
return meta["dim"]
embedding_model = get_embedding_model()
test_embed = embedding_model.get_text_embedding("test")
dim = len(test_embed)
with meta_path.open("w") as f:
json.dump({"embedding_model": model, "dim": dim}, f)
return dim
return config.llm_embedding_model or default
def _normalize_llm_index_text(text: str) -> str:
@@ -138,17 +121,11 @@ def _normalize_llm_index_text(text: str) -> str:
def build_llm_index_text(doc: Document) -> str:
# Short structured fields (filename, storage path, ASN, title, tags, ...) live
# in node.metadata: excluded from embeddings, shown to the LLM via metadata
# prepend. Notes and Custom Fields stay in the body: Notes can be long free
# text, Custom Fields are dynamic in count and best kept in the embedding.
lines = [
f"Title: {doc.title}",
f"Filename: {doc.filename}",
f"Created: {doc.created}",
f"Added: {doc.added}",
f"Modified: {doc.modified}",
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
f"Archive Serial Number: {doc.archive_serial_number or ''}",
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
]
+2
View File
@@ -0,0 +1,2 @@
class LLMTimeoutError(Exception):
pass
+280 -243
View File
@@ -1,28 +1,30 @@
import logging
import shutil
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from django.conf import settings
from django.utils import timezone
from filelock import FileLock
from filelock import ReadWriteLock
from filelock import Timeout
from documents.models import Document
from documents.models import PaperlessTask
from documents.utils import IterWrapper
from documents.utils import identity
from paperless.config import AIConfig
from paperless_ai.db import db_connection_released
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_configured_model_name
from paperless_ai.embedding import get_embedding_model
if TYPE_CHECKING:
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import BaseNode
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
logger = logging.getLogger("paperless_ai.indexing")
@@ -30,21 +32,11 @@ RAG_NUM_OUTPUT = 512
RAG_CHUNK_OVERLAP = 200
def _index_lock_path() -> Path:
"""Return the path used as the file lock for FAISS index mutations.
The lock file lives in DATA_DIR/locks/ (not inside LLM_INDEX_DIR) so that a
rebuild which calls shutil.rmtree(LLM_INDEX_DIR) cannot delete the lock
while another worker still holds it.
"""
return settings.LLM_INDEX_LOCK
def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
# NOTE: The check-then-enqueue sequence below is non-atomic (TOCTOU): two
# concurrent workers can both observe no running task and both enqueue a
# full rebuild. This is wasteful but not data-corrupting — update_llm_index
# is itself protected by _index_lock_path(), so only one rebuild runs at a
# is itself protected by settings.LLM_INDEX_LOCK, so only one rebuild runs at a
# time and the second one is serialised after the first completes.
from documents.tasks import llmindex_index
@@ -71,46 +63,110 @@ def queue_llm_index_update_if_needed(*, rebuild: bool, reason: str) -> bool:
return True
def get_or_create_storage_context(*, rebuild=False):
"""
Loads or creates the StorageContext (vector store, docstore, index store).
If rebuild=True, deletes and recreates everything.
"""
if rebuild:
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
def get_vector_store() -> "PaperlessSqliteVecVectorStore":
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
if rebuild or not settings.LLM_INDEX_DIR.exists():
import faiss
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
embedding_dim = get_embedding_dim()
faiss_index = faiss.IndexFlatL2(embedding_dim)
vector_store = FaissVectorStore(faiss_index=faiss_index)
docstore = SimpleDocumentStore()
index_store = SimpleIndexStore()
else:
from llama_index.core import StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
return StorageContext.from_defaults(
docstore=docstore,
index_store=index_store,
vector_store=vector_store,
persist_dir=settings.LLM_INDEX_DIR,
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
)
# --- LLM index locking ---------------------------------------------------
#
# Two locks guard the index; they answer different questions and are NOT
# interchangeable:
#
# * settings.LLM_INDEX_LOCK (FileLock, exclusive) -- serializes WRITERS against
# each other, so only one rebuild/upsert/delete/compaction runs at a time.
# Taken by write_store(). Readers never take it, so it never blocks reads.
#
# * settings.LLM_INDEX_RWLOCK (ReadWriteLock) -- coordinates readers against the
# compaction/migration file swap. read_store() takes it SHARED (readers run
# concurrently); _exclude_readers() takes it EXCLUSIVE, only for the swap, so
# the database file is never replaced while a reader connection is open (that
# would alias the old WAL onto the new file and corrupt it).
#
# | vs another writer | vs a reader
# -----------------+-------------------+----------------------------
# normal write | LLM_INDEX_LOCK | nothing (WAL gives MVCC)
# compaction/swap | LLM_INDEX_LOCK | LLM_INDEX_RWLOCK (exclusive)
# reader | nothing (WAL) | LLM_INDEX_RWLOCK (shared)
#
# They can't be merged into one ReadWriteLock: a normal write must exclude other
# writers WITHOUT blocking readers (WAL already gives reader/writer concurrency),
# and ReadWriteLock has no "exclusive vs writers, shared vs readers" mode. Only
# the swap needs to exclude readers.
def _index_rwlock() -> ReadWriteLock:
"""Return a fresh read/write lock instance for the index swap.
``is_singleton=False`` so reads and the swap always coordinate through
SQLite (the actual cross-process case) rather than hitting the in-process
reentrant-upgrade guard; callers must ``close()`` it (the context managers
below do).
"""
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
return ReadWriteLock(str(settings.LLM_INDEX_RWLOCK), is_singleton=False)
@contextmanager
def read_store():
"""Acquire the shared read lock and yield the vector store for a read.
The shared lock is held for the whole lifetime of the connection (and
closed on exit) so the compaction/migration swap, which takes the exclusive
lock, never runs while this connection is open. Concurrent readers do not
block each other; only the swap does.
"""
lock = _index_rwlock()
try:
with lock.read_lock(), get_vector_store() as store:
yield store
finally:
lock.close()
@contextmanager
def _exclude_readers():
"""Acquire exclusive index access, blocking until readers have drained.
The exclusive counterpart to ``read_store()``: a compaction or migration
must not run while any reader connection is open. Raises
:class:`filelock.Timeout` if active readers do not drain within
``LLM_INDEX_COMPACTION_LOCK_TIMEOUT``; callers skip the operation on timeout.
"""
lock = _index_rwlock()
try:
with lock.write_lock(timeout=settings.LLM_INDEX_COMPACTION_LOCK_TIMEOUT):
yield
finally:
lock.close()
@contextmanager
def write_store(embed_model_name: str | None = None):
"""Acquire the write lock and yield the vector store.
All mutating operations (upsert, delete, rebuild, compact) must go through
this context manager to serialise concurrent Celery writers.
Read paths use ``read_store()`` so they hold the shared read lock.
Pass ``embed_model_name`` whenever the operation may create the table so
the model name is recorded in the schema metadata for future mismatch checks.
"""
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
with (
FileLock(settings.LLM_INDEX_LOCK),
PaperlessSqliteVecVectorStore(
uri=str(settings.LLM_INDEX_DIR),
embed_model_name=embed_model_name,
) as store,
):
yield store
def build_document_node(
document: Document,
*,
@@ -130,6 +186,9 @@ def build_document_node(
"document_type": document.document_type.name
if document.document_type
else None,
"filename": document.filename,
"storage_path": document.storage_path.name if document.storage_path else None,
"archive_serial_number": document.archive_serial_number,
"created": document.created.isoformat() if document.created else None,
"added": document.added.isoformat() if document.added else None,
"modified": document.modified.isoformat(),
@@ -142,9 +201,11 @@ def build_document_node(
# the token count and exceed embedding models with small context windows
# (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
doc = LlamaDocument(
id_=str(document.id),
text=text,
metadata=metadata,
excluded_embed_metadata_keys=list(metadata.keys()),
excluded_llm_metadata_keys=["document_id"],
)
chunk_size = chunk_size or get_rag_chunk_size()
parser = SimpleNodeParser(
@@ -154,76 +215,33 @@ def build_document_node(
return parser.get_nodes_from_documents([doc])
def load_or_build_index(nodes=None):
"""
Load an existing VectorStoreIndex if present,
or build a new one using provided nodes if storage is empty.
def load_or_build_index(config: AIConfig, store: "PaperlessSqliteVecVectorStore"):
"""Return a VectorStoreIndex backed by ``store``.
``store`` is supplied by the caller's ``read_store()`` context so the shared
read lock and the connection stay alive for the whole retrieval.
"""
import llama_index.core.settings as llama_settings
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
embed_model = get_embedding_model()
embed_model = get_embedding_model(config)
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context()
try:
return load_index_from_storage(storage_context=storage_context)
except ValueError as e:
logger.warning("Failed to load index from storage: %s", e)
if not nodes:
queue_llm_index_update_if_needed(
rebuild=vector_store_file_exists(),
reason="LLM index missing or invalid while loading.",
)
logger.info("No nodes provided for index creation.")
raise
return VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
)
return VectorStoreIndex.from_vector_store(
vector_store=store,
embed_model=embed_model,
)
def remove_document_docstore_nodes(document: Document, index: "VectorStoreIndex"):
"""
Removes existing documents from docstore for a given document from the index.
This is necessary because FAISS IndexFlatL2 is append-only.
"""
all_node_ids = list(index.docstore.docs.keys())
existing_nodes = [
node.node_id
for node in index.docstore.get_nodes(all_node_ids)
if node.metadata.get("document_id") == str(document.id)
]
for node_id in existing_nodes:
# Delete from docstore, FAISS IndexFlatL2 are append-only
index.docstore.delete_document(node_id)
# Also purge the FAISS position -> UUID mapping so subsequent similarity
# queries don't raise KeyError on ghost vector positions.
stale_keys = [
k for k, v in index.index_struct.nodes_dict.items() if v == node_id
]
for key in stale_keys:
del index.index_struct.nodes_dict[key]
# Re-sync the mutated index_struct so persist() writes the updated nodes_dict.
index.storage_context.index_store.add_index_struct(index.index_struct)
def vector_store_file_exists():
"""
Check if the vector store file exists in the LLM index directory.
"""
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
def llm_index_exists() -> bool:
"""True when the index table exists on disk."""
with read_store() as store:
return store.table_exists()
def get_rag_chunk_size() -> int:
return AIConfig().llm_embedding_chunk_size
def get_rag_context_size() -> int:
return AIConfig().llm_context_size
def get_rag_chunk_overlap(chunk_size: int | None = None) -> int:
chunk_size = chunk_size or get_rag_chunk_size()
return min(RAG_CHUNK_OVERLAP, chunk_size - 1)
@@ -249,123 +267,149 @@ def get_rag_prompt_helper(
)
def _embed_nodes(nodes: list["BaseNode"], embed_model) -> None:
"""Embed ``nodes`` in place using ``embed_model``."""
from llama_index.core.schema import MetadataMode
texts = [n.get_content(metadata_mode=MetadataMode.EMBED) for n in nodes]
for node, emb in zip(
nodes,
embed_model.get_text_embedding_batch(texts),
strict=True,
):
node.embedding = emb
def _document_id_filters(doc_ids):
"""Return a MetadataFilters IN filter scoped to ``doc_ids``."""
from llama_index.core.vector_stores.types import FilterOperator
from llama_index.core.vector_stores.types import MetadataFilter
from llama_index.core.vector_stores.types import MetadataFilters
return MetadataFilters(
filters=[
MetadataFilter(
key="document_id",
operator=FilterOperator.IN,
value=sorted(doc_ids),
),
],
)
def update_llm_index(
*,
iter_wrapper: IterWrapper[Document] = identity,
rebuild=False,
) -> str:
"""
Rebuild or update the LLM index.
"""
from llama_index.core import VectorStoreIndex
nodes = []
"""Rebuild or incrementally update the LLM index."""
with write_store() as store:
try:
with _exclude_readers():
needs_reembed = store.check_and_run_migrations()
except Timeout:
logger.info(
"Skipping LLM index migration check: index readers are active; "
"will retry next run.",
)
needs_reembed = False
if needs_reembed:
logger.warning(
"LLM index migration requires re-embedding; forcing rebuild.",
)
rebuild = True
documents = Document.objects.all()
if not documents.exists():
no_documents = not documents.exists()
# Fast exit before touching config: nothing to index and no existing index.
if no_documents and not rebuild and not llm_index_exists():
logger.warning("No documents found to index.")
if not rebuild and not vector_store_file_exists():
return "No documents found to index."
return "No documents found to index."
config = AIConfig()
model_name = get_configured_model_name(config)
if not rebuild and llm_index_exists():
with read_store() as store:
config_mismatch = store.config_mismatch(model_name)
if config_mismatch:
logger.warning("Embedding model changed; forcing LLM index rebuild.")
rebuild = True
if no_documents:
logger.warning("No documents found to index.")
chunk_size = config.llm_embedding_chunk_size
embed_model = get_embedding_model(config)
with FileLock(_index_lock_path()):
if rebuild or not vector_store_file_exists():
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
with write_store(embed_model_name=model_name) as store:
if rebuild or not store.table_exists():
logger.info("Rebuilding LLM index.")
import llama_index.core.settings as llama_settings
embed_model = get_embedding_model()
llama_settings.Settings.embed_model = embed_model
storage_context = get_or_create_storage_context(rebuild=True)
store.drop_table()
for document in iter_wrapper(documents):
document_nodes = build_document_node(document, chunk_size=chunk_size)
nodes.extend(document_nodes)
index = VectorStoreIndex(
nodes=nodes,
storage_context=storage_context,
embed_model=embed_model,
show_progress=False,
)
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.add(nodes)
msg = "LLM index rebuilt successfully."
else:
# Update existing index
index = load_or_build_index()
existing_nodes: defaultdict[str, list] = defaultdict(list)
for node in index.docstore.docs.values():
doc_id = node.metadata.get("document_id")
if doc_id is not None:
existing_nodes[doc_id].append(node)
existing = store.get_modified_times()
changed = 0
for document in iter_wrapper(documents):
doc_id = str(document.id)
document_modified = document.modified.isoformat()
if existing.get(doc_id) == document.modified.isoformat():
continue
nodes = build_document_node(document, chunk_size=chunk_size)
_embed_nodes(nodes, embed_model)
store.upsert_document(doc_id, nodes)
changed += 1
msg = (
"LLM index updated successfully."
if changed
else "No changes detected in LLM index."
)
if doc_id in existing_nodes:
doc_nodes = existing_nodes[doc_id]
node_modified = doc_nodes[0].metadata.get("modified")
if node_modified == document_modified:
continue
# Delete from docstore, FAISS IndexFlatL2 are append-only
for node in doc_nodes:
remove_document_docstore_nodes(document, index)
nodes.extend(build_document_node(document, chunk_size=chunk_size))
if nodes:
msg = "LLM index updated successfully."
logger.info(
"Updating %d nodes in LLM index.",
len(nodes),
)
index.insert_nodes(nodes)
else:
msg = "No changes detected in LLM index."
logger.info(msg)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
try:
with _exclude_readers():
store.compact()
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
return msg
def llm_index_add_or_update_document(document: Document):
"""
Adds or updates a document in the LLM index.
If the document already exists, it will be replaced.
"""
new_nodes = build_document_node(document, chunk_size=get_rag_chunk_size())
if not new_nodes:
logger.warning(
"No indexable content for document %s; skipping LLM index update.",
document.pk,
)
return
"""Add or atomically replace a document's chunks in the index."""
config = AIConfig()
new_nodes = build_document_node(
document,
chunk_size=config.llm_embedding_chunk_size,
)
if new_nodes:
_embed_nodes(new_nodes, get_embedding_model(config))
with FileLock(_index_lock_path()):
index = load_or_build_index(nodes=new_nodes)
with write_store(embed_model_name=get_configured_model_name(config)) as store:
store.upsert_document(str(document.id), new_nodes)
remove_document_docstore_nodes(document, index)
index.insert_nodes(new_nodes)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_compact() -> None:
"""Compact the index immediately, rebuilding the table to reclaim space."""
with write_store() as store:
try:
with _exclude_readers():
store.compact(force=True)
except Timeout:
logger.info(
"Skipping LLM index compaction: index readers are active; "
"will retry next run.",
)
def llm_index_remove_document(document: Document):
"""
Removes a document from the LLM index.
"""
with FileLock(_index_lock_path()):
index = load_or_build_index()
remove_document_docstore_nodes(document, index)
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
"""Remove a document's chunks from the LLM index."""
with write_store() as store:
store.delete(str(document.id))
def truncate_content(
@@ -399,6 +443,18 @@ def truncate_content(
return " ".join(truncated_chunks)
def truncate_embedding_query(content: str, *, chunk_size: int) -> str:
from llama_index.core.text_splitter import TokenTextSplitter
splitter = TokenTextSplitter(
separator=" ",
chunk_size=chunk_size,
chunk_overlap=0,
)
content_chunks = splitter.split_text(content)
return content_chunks[0] if content_chunks else ""
def normalize_document_ids(document_ids: Iterable[int | str] | None) -> set[str] | None:
if document_ids is None:
return None
@@ -410,77 +466,58 @@ def query_similar_documents(
top_k: int = 5,
document_ids: Iterable[int | str] | None = None,
) -> list[Document]:
"""
Runs a similarity query and returns top-k similar Document objects.
"""
"""Return up to ``top_k`` Documents most similar to ``document``."""
allowed_document_ids = normalize_document_ids(document_ids)
if allowed_document_ids is not None and not allowed_document_ids:
return []
if not vector_store_file_exists():
if not llm_index_exists():
queue_llm_index_update_if_needed(
rebuild=False,
reason="LLM index not found for similarity query.",
)
return []
with FileLock(_index_lock_path()):
index = load_or_build_index()
config = AIConfig()
# constrain only the node(s) that match the document IDs, if given
doc_node_ids = (
[
node.node_id
for node in index.docstore.docs.values()
if node.metadata.get("document_id") in allowed_document_ids
]
if allowed_document_ids is not None
else None
)
if doc_node_ids is not None and not doc_node_ids:
return []
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.retrievers import VectorIndexRetriever
filters = (
_document_id_filters(allowed_document_ids)
if allowed_document_ids is not None
else None
)
query_text = truncate_embedding_query(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
)
# Hold the shared read lock for the whole retrieval so the connection is
# never open across a compaction swap. The retrieve() call generates a
# query embedding (a slow external request) and searches the vector store;
# no Django ORM access happens during it, so release the pooled DB
# connection for its duration. See #12976.
with read_store() as store:
index = load_or_build_index(config, store)
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=top_k,
doc_ids=doc_node_ids,
filters=filters,
)
config = AIConfig()
query_text = truncate_content(
(document.title or "") + "\n" + (document.content or ""),
chunk_size=config.llm_embedding_chunk_size,
context_size=config.llm_context_size,
)
try:
with db_connection_released():
results = retriever.retrieve(query_text)
except KeyError as e:
# Ghost FAISS positions remain after deletion because IndexFlatL2 is
# append-only. Treat them as absent and return no results.
logger.debug(
"Skipping LLM similarity query for document %s due to a stale "
"FAISS position with no docstore node: %s",
document.pk,
e,
)
return []
retrieved_document_ids: list[int] = []
for node in results:
document_id = node.metadata.get("document_id")
if document_id is None:
continue
normalized_document_id = str(document_id)
if (
allowed_document_ids is not None
and normalized_document_id not in allowed_document_ids
):
normalized = str(document_id)
if allowed_document_ids is not None and normalized not in allowed_document_ids:
continue
try:
retrieved_document_ids.append(int(normalized_document_id))
except ValueError:
retrieved_document_ids.append(int(normalized))
except ValueError: # pragma: no cover
logger.warning(
"Skipping LLM index result with invalid document_id %r.",
document_id,
+27 -1
View File
@@ -1,10 +1,36 @@
from pathlib import Path
import pytest
import pytest_mock
from llama_index.core.base.embeddings.base import BaseEmbedding
from pytest_django.fixtures import SettingsWrapper
@pytest.fixture
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper):
def temp_llm_index_dir(tmp_path: Path, settings: SettingsWrapper) -> Path:
settings.LLM_INDEX_DIR = tmp_path
settings.LLM_INDEX_LOCK = tmp_path / "index.lock"
settings.LLM_INDEX_RWLOCK = tmp_path / "llmindex.rwlock.db"
return tmp_path
class FakeEmbedding(BaseEmbedding):
async def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384
@pytest.fixture
def mock_embed_model(mocker: pytest_mock.MockerFixture) -> pytest_mock.MockType:
fake = FakeEmbedding()
mocker.patch("paperless_ai.indexing.get_embedding_model", return_value=fake)
mocker.patch("paperless_ai.embedding.get_embedding_model", return_value=fake)
return fake
+4 -2
View File
@@ -6,6 +6,7 @@ import pytest
from django.test import override_settings
from documents.models import Document
from paperless.config import AIConfig
from paperless_ai.ai_classifier import build_localization_prompt
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
@@ -211,11 +212,12 @@ def test_prompt_with_without_rag(mock_document):
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
config = AIConfig()
prompt = build_prompt_without_rag(mock_document, config)
assert "Additional context from similar documents" not in prompt
assert "for generated" not in prompt
prompt = build_prompt_with_rag(mock_document)
prompt = build_prompt_with_rag(mock_document, config)
assert "Additional context from similar documents" in prompt
prompt = build_localization_prompt(
+311 -435
View File
@@ -1,15 +1,12 @@
import json
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
import pytest_mock
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
from faker import Faker
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.schema import MetadataMode
from documents.models import Document
from documents.models import PaperlessTask
@@ -19,10 +16,12 @@ from documents.tests.factories import DocumentFactory
from documents.tests.factories import PaperlessTaskFactory
from paperless.models import ApplicationConfiguration
from paperless_ai import indexing
from paperless_ai.tests.conftest import FakeEmbedding
from paperless_ai.vector_store import PaperlessSqliteVecVectorStore
@pytest.fixture
def real_document(db):
def real_document(db: None) -> Document:
return Document.objects.create(
title="Test Document",
content="This is some test content.",
@@ -30,44 +29,39 @@ def real_document(db):
)
@pytest.fixture
def mock_embed_model():
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384 # Match your real FAISS config
@pytest.mark.django_db
def test_build_document_node(real_document) -> None:
def test_build_document_node(real_document: Document) -> None:
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == str(real_document.id)
assert nodes[0].metadata["filename"] == real_document.filename
assert nodes[0].metadata["storage_path"] == (
real_document.storage_path.name if real_document.storage_path else None
)
assert (
nodes[0].metadata["archive_serial_number"]
== real_document.archive_serial_number
)
assert "filename" in nodes[0].excluded_embed_metadata_keys
assert "filename" not in nodes[0].excluded_llm_metadata_keys
@pytest.mark.django_db
def test_build_document_node_excludes_metadata_from_embedding(real_document) -> None:
def test_build_document_node_sets_ref_doc_id(real_document: Document) -> None:
"""Every node produced by build_document_node must carry the paperless document id
as its ref_doc_id so that the vector store's delete(str(doc.id)) works correctly."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0, "Expected at least one node"
for node in nodes:
assert node.ref_doc_id == str(real_document.id), (
f"Expected ref_doc_id={real_document.id!r}, got {node.ref_doc_id!r}"
)
@pytest.mark.django_db
def test_build_document_node_excludes_metadata_from_embedding(
real_document: Document,
) -> None:
"""Metadata keys must not be prepended to the embedding text.
build_llm_index_text already encodes all metadata in the body text, so
@@ -75,8 +69,6 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
double the token count and exceed embedding models with small context
windows (e.g. nomic-embed-text via Ollama defaults to num_ctx=2048).
"""
from llama_index.core.schema import MetadataMode
nodes = indexing.build_document_node(real_document)
for node in nodes:
embed_text = node.get_content(metadata_mode=MetadataMode.EMBED)
@@ -87,7 +79,36 @@ def test_build_document_node_excludes_metadata_from_embedding(real_document) ->
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document) -> None:
def test_build_document_node_structured_fields_in_metadata(
real_document: Document,
) -> None:
"""Structured fields must be in node.metadata so the LLM receives them via metadata prepend."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
for node in nodes:
assert "title" in node.metadata
assert "tags" in node.metadata
assert "correspondent" in node.metadata
assert "document_type" in node.metadata
assert "created" in node.metadata
assert "added" in node.metadata
assert "modified" in node.metadata
@pytest.mark.django_db
def test_build_document_node_excludes_document_id_from_llm_context(
real_document: Document,
) -> None:
"""document_id is an internal key and must not appear in LLM context text."""
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
for node in nodes:
assert "document_id" in node.excluded_llm_metadata_keys
assert "document_id" not in node.get_content(metadata_mode=MetadataMode.LLM)
@pytest.mark.django_db
def test_build_document_node_uses_rag_chunk_settings(real_document: Document) -> None:
app_config, _ = ApplicationConfiguration.objects.get_or_create()
app_config.llm_embedding_chunk_size = 512
app_config.save()
@@ -116,11 +137,21 @@ def test_get_rag_prompt_helper_uses_context_setting() -> None:
assert prompt_helper.context_window == 4096
def test_truncate_embedding_query_returns_single_chunk() -> None:
content = " ".join(f"word{i}" for i in range(200))
result = indexing.truncate_embedding_query(content, chunk_size=32)
assert result
assert result != content
assert "word199" not in result
@pytest.mark.django_db
def test_update_llm_index(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
mock_config = MagicMock()
mock_config.llm_embedding_chunk_size = 512
@@ -138,44 +169,49 @@ def test_update_llm_index(
ai_config.assert_called_once()
build_document_node.assert_called_once_with(real_document, chunk_size=512)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
def test_update_llm_index_rebuilds_on_model_name_change(
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)
# Build initial index with model "model-a".
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
with patch(
"paperless_ai.indexing.get_configured_model_name",
return_value="model-a",
):
indexing.update_llm_index(rebuild=True)
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig
# Simulate config change to "model-b"; the incremental run must force a rebuild.
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
with patch(
"paperless_ai.indexing.get_configured_model_name",
return_value="model-b",
):
indexing.update_llm_index(rebuild=False)
config = AIConfig()
expected_model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai-like"
else "sentence-transformers/all-MiniLM-L6-v2"
)
assert meta == {"embedding_model": expected_model, "dim": 384}
with indexing.get_vector_store() as store:
# Schema metadata only updates when the table is dropped and recreated, never
# on incremental writes -- so "model-b" here proves a full rebuild happened.
assert store.stored_model_name() == "model-b"
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
doc2 = Document.objects.create(
title="Test Document 2",
@@ -210,131 +246,34 @@ def test_update_llm_index_partial_update(
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
mock_all.return_value = mock_queryset
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.",
2,
)
indexing.update_llm_index(rebuild=False)
assert any(temp_llm_index_dir.glob("*.json"))
def test_get_or_create_storage_context_raises_exception(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with pytest.raises(Exception):
indexing.get_or_create_storage_context(rebuild=False)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[indexing.build_document_node(real_document)],
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after incremental update"
)
mock_index_cls.assert_called_once()
def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
),
):
with pytest.raises(Exception):
indexing.load_or_build_index()
@pytest.mark.django_db
def test_load_or_build_index_succeeds_when_nodes_given(
temp_llm_index_dir,
mock_embed_model,
) -> None:
with (
patch(
"llama_index.core.load_index_from_storage",
side_effect=ValueError("Index not found"),
),
patch(
"llama_index.core.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls,
patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage,
):
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[MagicMock()],
)
mock_index_cls.assert_called_once()
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
temp_llm_index_dir,
real_document,
mock_embed_model,
) -> None:
indexing.update_llm_index(rebuild=True)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 1
indexing.llm_index_remove_document(real_document)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 0
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Expected the vector store table to exist after add-or-update"
)
@pytest.mark.django_db
def test_query_after_remove_does_not_raise_key_error(
temp_llm_index_dir,
real_document,
mock_embed_model,
temp_llm_index_dir: Path,
real_document: Document,
mock_embed_model: FakeEmbedding,
) -> None:
indexing.update_llm_index(rebuild=True)
@@ -352,8 +291,8 @@ def test_query_after_remove_does_not_raise_key_error(
@pytest.mark.django_db
def test_update_llm_index_no_documents(
temp_llm_index_dir,
mock_embed_model,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
) -> None:
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
@@ -369,6 +308,22 @@ def test_update_llm_index_no_documents(
)
@pytest.mark.django_db
def test_update_no_documents_no_index_returns_early(
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""update with no documents and no existing index must return early."""
mock_qs = MagicMock()
mock_qs.exists.return_value = False
mock_qs.__iter__ = MagicMock(return_value=iter([]))
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
result = indexing.update_llm_index(rebuild=False)
assert result == "No documents found to index."
@pytest.mark.django_db
def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -> None:
# No existing tasks
@@ -406,20 +361,17 @@ def test_queue_llm_index_update_if_needed_enqueues_when_idle_or_skips_recent() -
LLM_BACKEND="ollama",
)
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
) -> None:
with (
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
) as mock_vector_store_exists,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
mock_vector_store_exists.return_value = True
mock_index = MagicMock()
@@ -451,14 +403,50 @@ def test_query_similar_documents(
assert result == mock_filtered_docs
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_EMBEDDING_CHUNK_SIZE=32,
LLM_BACKEND="ollama",
)
def test_query_similar_documents_truncates_query_to_embedding_chunk_size(
temp_llm_index_dir: Path,
real_document: Document,
) -> None:
real_document.content = " ".join(f"word{i}" for i in range(200))
with (
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch(
"paperless_ai.indexing.llm_index_exists",
) as mock_vector_store_exists,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
patch("paperless_ai.indexing.truncate_content") as mock_truncate_content,
):
mock_vector_store_exists.return_value = True
mock_load_or_build_index.return_value = MagicMock()
mock_truncate_content.return_value = "wrong helper"
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = []
mock_retriever_cls.return_value = mock_retriever
mock_filter.return_value = []
indexing.query_similar_documents(real_document, top_k=3)
mock_truncate_content.assert_not_called()
query_text = mock_retriever.retrieve.call_args.args[0]
assert query_text
assert "word199" not in query_text
@pytest.mark.django_db
def test_query_similar_documents_triggers_update_when_index_missing(
temp_llm_index_dir,
real_document,
temp_llm_index_dir: Path,
real_document: Document,
) -> None:
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
return_value=False,
),
patch(
@@ -479,120 +467,13 @@ def test_query_similar_documents_triggers_update_when_index_missing(
assert result == []
@pytest.mark.django_db
def test_query_similar_documents_normalizes_and_post_filters_allowed_ids(
real_document,
) -> None:
real_document.owner = User.objects.create_user(username="rag-owner")
real_document.save()
private_owner = User.objects.create_user(username="rag-private-owner")
private_document = Document.objects.create(
title="Private similar document",
content="Similar private content that must not reach RAG.",
owner=private_owner,
added=timezone.now(),
)
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
return_value=True,
),
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("llama_index.core.retrievers.VectorIndexRetriever") as mock_retriever_cls,
):
allowed_node = MagicMock()
allowed_node.node_id = "allowed-node"
allowed_node.metadata = {"document_id": str(real_document.pk)}
private_node = MagicMock()
private_node.node_id = "private-node"
private_node.metadata = {"document_id": str(private_document.pk)}
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [allowed_node, private_node]
mock_load_or_build_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [private_node, allowed_node]
mock_retriever_cls.return_value = mock_retriever
result = indexing.query_similar_documents(
real_document,
top_k=2,
document_ids=[real_document.pk],
)
mock_retriever_cls.assert_called_once_with(
index=mock_index,
similarity_top_k=2,
doc_ids=["allowed-node"],
)
assert result == [real_document]
assert private_document not in result
class TestUpdateLlmIndexStaleNodes:
"""Tests that update_llm_index removes ALL nodes for a multi-chunk document."""
@pytest.mark.django_db
def test_incremental_update_removes_all_old_nodes_for_multi_chunk_document(
self,
temp_llm_index_dir,
mock_embed_model: MagicMock,
) -> None:
"""Ghost nodes from all chunks of a modified document must be removed.
When a document is split into multiple chunks (chunk_size=1024), the
incremental update path must delete every old node, not just the last
one captured by a dict comprehension keyed on document_id.
"""
# Content long enough to produce at least two chunks at chunk_size=1024.
# Generate many paragraphs so the token count comfortably exceeds 1024.
fake = Faker()
long_content = "\n\n".join(fake.paragraph(nb_sentences=20) for _ in range(20))
doc = DocumentFactory(content=long_content)
# Build the initial index (rebuild=True) so it has multiple nodes
indexing.update_llm_index(rebuild=True)
# Verify the initial index has more than one node for this document
initial_index = indexing.load_or_build_index()
initial_node_ids = [
nid
for nid, node in initial_index.docstore.docs.items()
if node.metadata.get("document_id") == str(doc.id)
]
assert len(initial_node_ids) > 1, (
f"Expected multiple chunks but got {len(initial_node_ids)}; "
"increase long_content length"
)
# Simulate a modification so the incremental path treats it as changed.
# Use queryset.update() to bypass auto_now and actually change the DB value.
new_modified = timezone.now()
Document.objects.filter(pk=doc.pk).update(modified=new_modified)
# Run incremental update (rebuild=False) with the modified document
indexing.update_llm_index(rebuild=False)
# Reload the persisted index and check that no OLD node ids remain
updated_index = indexing.load_or_build_index()
remaining_old_node_ids = [
nid for nid in initial_node_ids if nid in updated_index.docstore.docs
]
assert remaining_old_node_ids == [], (
f"Ghost nodes still present after incremental update: "
f"{remaining_old_node_ids}"
)
@pytest.mark.django_db
def test_query_similar_documents_empty_allow_list_fails_closed(
real_document,
real_document: Document,
) -> None:
with (
patch(
"paperless_ai.indexing.vector_store_file_exists",
"paperless_ai.indexing.llm_index_exists",
return_value=True,
) as mock_vector_store_exists,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
@@ -610,27 +491,25 @@ def test_query_similar_documents_empty_allow_list_fails_closed(
class TestUpdateLlmIndexEmptyDocumentSet:
"""update_llm_index must persist an empty index when all documents are deleted.
"""update_llm_index must clear the vector store table when all documents are deleted.
Without this, the stale on-disk FAISS vectors are never cleared and
subsequent similarity searches return phantom hits for document IDs that
no longer exist in the DB.
Without this, the stale vectors are never cleared and subsequent similarity
searches return phantom hits for document IDs that no longer exist in the DB.
"""
@pytest.mark.django_db
def test_rebuild_clears_stale_index_when_no_documents_exist(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mock_embed_model: FakeEmbedding,
) -> None:
"""After deleting all documents, rebuild=True must persist an empty index.
"""After deleting all documents, rebuild=True must produce a table with zero rows.
Steps:
1. Build an index with one document so the on-disk state is non-empty.
2. Delete all documents from the DB.
3. Call update_llm_index(rebuild=True).
4. Reload the index from disk.
5. Assert the reloaded index has zero nodes (no phantom vectors).
4. Open the LanceDB table directly and assert zero rows.
"""
# Step 1: create a document and build a non-empty index
Document.objects.create(
@@ -640,27 +519,26 @@ class TestUpdateLlmIndexEmptyDocumentSet:
)
indexing.update_llm_index(rebuild=True)
initial_index = indexing.load_or_build_index()
assert len(initial_index.docstore.docs) > 0, (
"Precondition failed: expected at least one node before deletion"
)
with indexing.get_vector_store() as store:
assert store.table_exists(), (
"Precondition failed: expected the vector store table to exist "
"before deletion"
)
# Step 2: delete all documents
Document.objects.all().delete()
assert not Document.objects.exists()
# Step 3: rebuild with no documents
# Step 3: rebuild with no documents — drop_table is called so the table
# is removed (no rows to re-insert, so it stays absent).
indexing.update_llm_index(rebuild=True)
# Step 4: reload the persisted index from disk
reloaded_index = indexing.load_or_build_index()
# Step 5: phantom vectors must be gone
assert len(reloaded_index.docstore.docs) == 0, (
f"Expected 0 nodes after clearing all documents, "
f"but found {len(reloaded_index.docstore.docs)}: "
f"{list(reloaded_index.docstore.docs.keys())}"
)
# Step 4: the table must be absent (no rows) — phantom vectors gone
with indexing.get_vector_store() as store2:
assert not store2.table_exists(), (
"Expected the vector store table to be absent after rebuilding "
"with no documents"
)
class TestDocumentUpdatedSignalTriggersLlmReindex:
@@ -709,10 +587,14 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
def test_returns_without_error_when_build_document_node_returns_empty(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mocker: pytest_mock.MockerFixture,
) -> None:
"""When build_document_node returns [], the function must return without error
and must not call load_or_build_index at all."""
"""When build_document_node returns [], the function must return without error.
The store's upsert_document treats an empty node list as a removal (no-op
delete), so load_or_build_index must not be called.
"""
mocker.patch(
"paperless_ai.indexing.build_document_node",
return_value=[],
@@ -720,6 +602,7 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
mock_load = mocker.patch("paperless_ai.indexing.load_or_build_index")
doc = MagicMock(spec=Document)
doc.id = 42
# Must not raise
indexing.llm_index_add_or_update_document(doc)
@@ -727,172 +610,165 @@ class TestLlmIndexAddOrUpdateDocumentEmptyContent:
@pytest.mark.django_db
class TestLlmIndexLocking:
"""The FAISS index mutation functions must acquire the index lock before touching the index.
def test_llm_index_compact_uses_force(
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""compact must use force=True to rebuild the table and reclaim space immediately."""
mock_store = mocker.MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
Without locking, two concurrent Celery workers can each load the same
on-disk index, make independent modifications, and the last writer silently
overwrites the first's changes.
indexing.llm_index_compact()
mock_store.compact.assert_called_once_with(force=True)
@pytest.mark.django_db
class TestLlmIndexLocking:
"""Index mutation functions must go through write_store(), which holds the lock.
Without locking, two concurrent Celery workers can open the same store,
make independent modifications, and trigger CommitConflictError.
"""
def test_add_or_update_document_acquires_lock(
def test_add_or_update_document_uses_write_store(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
mocker: pytest_mock.MockerFixture,
) -> None:
"""llm_index_add_or_update_document must enter the file lock before touching the index."""
call_order: list[str] = []
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
mock_load = mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or MagicMock()
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
mock_node = MagicMock()
mock_node.get_content.return_value = "fake node text"
mocker.patch(
"paperless_ai.indexing.build_document_node",
return_value=[MagicMock()],
return_value=[mock_node],
)
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
doc = MagicMock(spec=Document)
doc.id = 1
indexing.llm_index_add_or_update_document(doc)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_load.assert_called_once()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
mock_store.upsert_document.assert_called_once()
def test_remove_document_acquires_lock(
def test_remove_document_uses_write_store(
self,
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
) -> None:
"""llm_index_remove_document must enter the file lock before loading the index."""
call_order: list[str] = []
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
mock_load = mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or MagicMock()
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
mocker.patch("paperless_ai.indexing.remove_document_docstore_nodes")
doc = MagicMock(spec=Document)
doc.id = 1
indexing.llm_index_remove_document(doc)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_load.assert_called_once()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
mock_store.delete.assert_called_once_with("1")
def test_update_llm_index_rebuild_acquires_lock(
def test_update_llm_index_rebuild_uses_write_store(
self,
temp_llm_index_dir: Path,
mock_embed_model: MagicMock,
mock_embed_model: FakeEmbedding,
mocker: pytest_mock.MockerFixture,
) -> None:
"""update_llm_index must enter the file lock during the rebuild/persist cycle."""
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(return_value=None)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
mock_store = MagicMock()
mocker.patch(
"paperless_ai.indexing.write_store",
return_value=mocker.MagicMock(
__enter__=mocker.MagicMock(return_value=mock_store),
__exit__=mocker.MagicMock(return_value=False),
),
)
# exists=True so the code reaches the lock; iterate over an empty
# queryset so VectorStoreIndex is called with no nodes (still exercises
# the lock path without needing heavy FAISS fixture data)
mock_qs = MagicMock()
mock_qs.exists.return_value = True
mock_qs.__iter__ = MagicMock(return_value=iter([]))
mocker.patch("paperless_ai.indexing.Document.objects.all", return_value=mock_qs)
mocker.patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
)
indexing.update_llm_index(rebuild=True)
mock_file_lock_cls.assert_called_once()
mock_lock_instance.__enter__.assert_called_once()
mock_store.drop_table.assert_called_once()
def test_query_similar_documents_acquires_lock(
@pytest.mark.django_db
@pytest.mark.django_db
class TestVectorStoreIndexing:
def test_get_vector_store_roundtrip(
self,
temp_llm_index_dir: Path,
mocker: pytest_mock.MockerFixture,
mock_embed_model: FakeEmbedding,
) -> None:
"""query_similar_documents must enter the file lock before loading the index."""
call_order: list[str] = []
with indexing.get_vector_store() as store:
assert isinstance(store, PaperlessSqliteVecVectorStore)
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__ = MagicMock(
side_effect=lambda *_: call_order.append("lock_acquired"),
)
mock_lock_instance.__exit__ = MagicMock(return_value=False)
def test_add_then_remove_document(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
real_document: Document,
) -> None:
indexing.llm_index_add_or_update_document(real_document)
with indexing.get_vector_store() as store:
assert store.table_exists()
count_sql = "SELECT count(*) FROM documents"
assert store.client.execute(count_sql).fetchone()[0] >= 1
mock_file_lock_cls = mocker.patch(
"paperless_ai.indexing.FileLock",
return_value=mock_lock_instance,
)
indexing.llm_index_remove_document(real_document)
assert store.client.execute(count_sql).fetchone()[0] == 0
mocker.patch(
"paperless_ai.indexing.vector_store_file_exists",
return_value=True,
)
def test_update_shrinks_chunks_without_orphans(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
real_document: Document,
) -> None:
real_document.content = "word " * 4000 # many chunks
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
count_sql = "SELECT count(*) FROM documents"
with indexing.get_vector_store() as store:
big = store.client.execute(count_sql).fetchone()[0]
mock_index = MagicMock()
mock_index.docstore.docs = {}
real_document.content = "short" # one chunk
real_document.save()
indexing.llm_index_add_or_update_document(real_document)
mocker.patch(
"paperless_ai.indexing.load_or_build_index",
side_effect=lambda *_a, **_kw: (
call_order.append("index_loaded") or mock_index
),
)
rows = store.client.execute(count_sql).fetchone()[0]
assert rows < big
assert rows >= 1
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = []
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever,
)
mocker.patch("paperless_ai.indexing.truncate_content", return_value="")
@pytest.mark.django_db
class TestQuerySimilarDocuments:
def test_query_similar_documents_respects_allowed_ids(
self,
temp_llm_index_dir: Path,
mock_embed_model: FakeEmbedding,
) -> None:
a = DocumentFactory.create(content="alpha shared content here")
b = DocumentFactory.create(content="beta shared content here")
c = DocumentFactory.create(content="gamma shared content here")
for doc in (a, b, c):
indexing.llm_index_add_or_update_document(doc)
indexing.query_similar_documents(MagicMock(spec=Document))
results = indexing.query_similar_documents(a, document_ids=[b.id])
mock_file_lock_cls.assert_called()
mock_lock_instance.__enter__.assert_called()
assert call_order.index("lock_acquired") < call_order.index("index_loaded"), (
"Lock must be acquired before the index is loaded"
)
assert all(doc.id == b.id for doc in results)
+110 -130
View File
@@ -3,19 +3,20 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.schema import TextNode
from documents.tests.factories import DocumentFactory
from paperless_ai import chat
from paperless_ai import indexing
from paperless_ai.chat import CHAT_ERROR_MESSAGE
from paperless_ai.chat import CHAT_METADATA_DELIMITER
from paperless_ai.chat import _get_document_filtered_retriever
from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
def patch_embed_model():
from llama_index.core import settings as llama_settings
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
# Use a real BaseEmbedding subclass to satisfy llama-index 0.14 validation
llama_settings.Settings.embed_model = MockEmbedding(embed_dim=1536)
yield
@@ -58,91 +59,6 @@ def assert_chat_output(
}
def add_vector_query_results(mock_index, nodes: list[TextNode]) -> None:
mock_index.index_struct.nodes_dict = {
str(vector_id): node.node_id for vector_id, node in enumerate(nodes)
}
mock_index.docstore.docs.get.side_effect = {
node.node_id: node for node in nodes
}.get
mock_index.vector_store._faiss_index.ntotal = len(nodes)
mock_index.vector_store.query.return_value = MagicMock(
ids=list(mock_index.index_struct.nodes_dict),
similarities=[0.1] * len(nodes),
)
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
def test_document_filtered_retriever_expands_filters_and_caches() -> None:
allowed_node1 = TextNode(
text="Allowed content 1.",
metadata={"document_id": "1", "title": "Allowed 1"},
)
allowed_node2 = TextNode(
text="Allowed content 2.",
metadata={"document_id": "2", "title": "Allowed 2"},
)
foreign_node = TextNode(
text="Foreign content.",
metadata={"document_id": "3", "title": "Foreign"},
)
missing_node = TextNode(
text="Missing content.",
metadata={"document_id": "1", "title": "Missing"},
)
mock_index = MagicMock()
mock_index.index_struct.nodes_dict = {
"0": foreign_node.node_id,
"1": missing_node.node_id,
"2": allowed_node1.node_id,
"3": allowed_node2.node_id,
}
mock_index.docstore.docs.get.side_effect = {
allowed_node1.node_id: allowed_node1,
allowed_node2.node_id: allowed_node2,
foreign_node.node_id: foreign_node,
}.get
mock_index.vector_store._faiss_index.ntotal = 4
mock_index.vector_store.query.side_effect = [
MagicMock(ids=["0", "2"], similarities=[0.9, 0.8]),
MagicMock(ids=["0", "1", "3"], similarities=[0.9, 0.7, 0.6]),
]
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1", "2"},
similarity_top_k=2,
)
nodes = retriever.retrieve("question")
cached_nodes = retriever.retrieve("question")
assert [node.node.node_id for node in nodes] == [
allowed_node1.node_id,
allowed_node2.node_id,
]
assert cached_nodes == nodes
assert mock_index.vector_store.query.call_count == 2
assert mock_index._embed_model.get_agg_embedding_from_queries.call_count == 1
def test_document_filtered_retriever_handles_empty_faiss_index() -> None:
mock_index = MagicMock()
mock_index.vector_store._faiss_index.ntotal = 0
mock_index._embed_model.get_agg_embedding_from_queries.return_value = [0.1] * 1536
retriever = _get_document_filtered_retriever(
mock_index,
{"1"},
similarity_top_k=2,
)
assert retriever.retrieve("question") == []
mock_index.vector_store.query.assert_not_called()
@pytest.mark.django_db
def test_stream_chat_with_one_document_retrieval(
mock_document,
@@ -164,17 +80,31 @@ def test_stream_chat_with_one_document_retrieval(
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
add_vector_query_results(mock_index, [mock_node])
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(
metadata={
"document_id": str(mock_document.pk),
"title": "Test Document",
},
),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
output = list(stream_chat_with_documents("What is this?", [mock_document]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What is this?", [mock_document]))
mock_query_engine.query.assert_called_once_with("What is this?")
patch_embed_nodes.assert_not_called()
@@ -196,12 +126,10 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
"llama_index.core.query_engine.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
@@ -210,41 +138,32 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_duplicate_node = TextNode(
text="More content for doc 1.",
metadata={"document_id": "1", "title": "Document 1 Duplicate"},
)
mock_foreign_node = TextNode(
text="Content for doc 3.",
metadata={"document_id": "3", "title": "Document 3"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [
mock_node1,
mock_node2,
mock_duplicate_node,
mock_foreign_node,
]
add_vector_query_results(
mock_index,
[mock_node1, mock_duplicate_node, mock_node2, mock_foreign_node],
)
# Simulate get_nodes returning nodes (content exists)
mock_index.vector_store.get_nodes.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
# Mock response stream
mock_retriever_instance = MagicMock()
mock_retriever_instance.retrieve.return_value = [
MagicMock(metadata={"document_id": "1", "title": "Document 1"}),
MagicMock(metadata={"document_id": "2", "title": "Document 2"}),
]
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1, title="Document 1", filename="doc1.pdf")
doc2 = MagicMock(pk=2, title="Document 2", filename="doc2.pdf")
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
return_value=mock_retriever_instance,
):
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
mock_query_engine.query.assert_called_once_with("What's up?")
patch_embed_nodes.assert_not_called()
@@ -258,8 +177,16 @@ def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes) -> Non
)
def test_stream_chat_empty_document_list() -> None:
with patch("paperless_ai.chat.load_or_build_index") as mock_load_index:
output = list(stream_chat_with_documents("Any info?", []))
mock_load_index.assert_not_called()
assert output == ["Sorry, I couldn't find any content to answer your question."]
def test_stream_chat_no_matching_nodes() -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
@@ -268,8 +195,8 @@ def test_stream_chat_no_matching_nodes() -> None:
mock_client.llm = MagicMock()
mock_index = MagicMock()
# No matching nodes
mock_index.docstore.docs.values.return_value = []
# No matching nodes in the store
mock_index.vector_store.get_nodes.return_value = []
mock_load_index.return_value = mock_index
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
@@ -279,30 +206,83 @@ def test_stream_chat_no_matching_nodes() -> None:
def test_stream_chat_unexpected_failure_returns_generic_error(caplog) -> None:
with (
patch("paperless_ai.chat.AIConfig"),
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat._get_document_filtered_retriever",
) as mock_get_retriever,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": "1", "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
# Nodes found so we get past the pre-check
mock_index.vector_store.get_nodes.return_value = [MagicMock()]
mock_load_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError("private provider detail")
mock_get_retriever.return_value = mock_retriever
with patch(
"llama_index.core.retrievers.VectorIndexRetriever",
) as mock_retriever_cls:
mock_retriever = MagicMock()
mock_retriever.retrieve.side_effect = RuntimeError(
"private provider detail",
)
mock_retriever_cls.return_value = mock_retriever
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
assert output == [CHAT_ERROR_MESSAGE]
assert "Failed to stream document chat response" in caplog.text
assert "private provider detail" in caplog.text
@pytest.mark.django_db
class TestStreamChatRetrieval:
def test_no_nodes_yields_no_content_message(
self,
temp_llm_index_dir,
mock_embed_model,
) -> None:
doc = DocumentFactory.create(content="hello world")
# Nothing indexed for this document yet.
out = list(chat.stream_chat_with_documents("question?", [doc]))
assert chat.CHAT_NO_CONTENT_MESSAGE in out
def test_chat_filter_contains_only_requested_document_ids(
self,
temp_llm_index_dir,
mock_embed_model,
mocker,
) -> None:
"""The MetadataFilter passed to the retriever must be scoped to the
requested documents only content from other indexed documents must
not be surfaced.
"""
included = DocumentFactory.create(content="included document content")
excluded = DocumentFactory.create(content="excluded document content")
indexing.llm_index_add_or_update_document(included)
indexing.llm_index_add_or_update_document(excluded)
# VectorIndexRetriever is imported inside _stream_chat_with_documents;
# patch it at the llama_index source so the lazy import picks it up.
captured_filters = []
mock_retriever = mocker.MagicMock()
mock_retriever.retrieve.return_value = []
def capture_retriever(*args, **kwargs):
captured_filters.append(kwargs.get("filters"))
return mock_retriever
mocker.patch("paperless_ai.chat.AIClient")
mocker.patch(
"llama_index.core.retrievers.VectorIndexRetriever",
side_effect=capture_retriever,
)
list(chat.stream_chat_with_documents("question?", [included]))
assert captured_filters, "VectorIndexRetriever was never constructed"
filt = captured_filters[0]
assert filt is not None, "Retriever must receive a MetadataFilters"
filter_values = filt.filters[0].value
assert str(included.pk) in filter_values
assert str(excluded.pk) not in filter_values

Some files were not shown because too many files have changed in this diff Show More