From e6ffb003a6982c27e7a2ab849f05c5a83e772c18 Mon Sep 17 00:00:00 2001
From: Trenton H <797416+stumpylog@users.noreply.github.com>
Date: Tue, 21 Apr 2026 10:19:52 -0700
Subject: [PATCH] Try to use a batched method of tag M2M validation

---
 src/documents/serialisers.py | 71 ++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py
index 3cba0fafa..cdb3e0613 100644
--- a/src/documents/serialisers.py
+++ b/src/documents/serialisers.py
@@ -48,6 +48,7 @@ from rest_framework import serializers
 from rest_framework.exceptions import PermissionDenied
 from rest_framework.fields import SerializerMethodField
 from rest_framework.filters import OrderingFilter
+from rest_framework.relations import MANY_RELATION_KWARGS
 
 if settings.AUDIT_LOG_ENABLED:
     from auditlog.context import set_actor
@@ -687,10 +688,80 @@ class CorrespondentField(serializers.PrimaryKeyRelatedField[Correspondent]):
         return Correspondent.objects.all()
 
 
+class BatchedManyRelatedField(serializers.ManyRelatedField):
+    """
+    ManyRelatedField that validates all PKs with a single filter(pk__in=pks)
+    instead of one queryset.get() per PK, eliminating N+1 on write validation.
+
+    Submitted values are normalised through the model's primary-key field
+    before the lookup, so equivalent representations (e.g. "1" and 1) resolve
+    to the same object, as they do with the stock per-item queryset.get().
+    """
+
+    def to_internal_value(self, data):
+        """
+        Resolve a list of primary keys to model instances with one query.
+
+        Returns the instances in submission order (duplicates preserved).
+        Fails with the list-level "not_a_list" / "empty" errors, or with the
+        child relation's "incorrect_type" / "does_not_exist" error for the
+        first offending item.
+        """
+        if isinstance(data, str) or not hasattr(data, "__iter__"):
+            self.fail("not_a_list", input_type=type(data).__name__)
+        if not self.allow_empty and len(data) == 0:
+            self.fail("empty")
+        if not data:
+            return []
+
+        child = self.child_relation
+        pk_field = getattr(child, "pk_field", None)
+        queryset = child.get_queryset()
+        model_pk = queryset.model._meta.pk
+
+        # (submitted value, normalised pk) pairs, in submission order.
+        lookups = []
+        for item in data:
+            if pk_field is not None:
+                item = pk_field.to_internal_value(item)
+            if isinstance(item, bool):
+                child.fail("incorrect_type", data_type="bool")
+            try:
+                # Coerce to the PK column's Python type so that "1" matches
+                # obj.pk == 1 below; report the offending item's own type.
+                lookups.append((item, model_pk.get_prep_value(item)))
+            except (TypeError, ValueError):
+                child.fail("incorrect_type", data_type=type(item).__name__)
+
+        pks = [pk for _, pk in lookups]
+        found = {obj.pk: obj for obj in queryset.filter(pk__in=pks)}
+
+        result = []
+        for submitted, pk in lookups:
+            if pk not in found:
+                child.fail("does_not_exist", pk_value=submitted)
+            result.append(found[pk])
+        return result
+
+
 class TagsField(serializers.PrimaryKeyRelatedField[Tag]):
     def get_queryset(self):
         return Tag.objects.all()
 
+    @classmethod
+    def many_init(cls, *args, **kwargs):
+        """
+        Build the many=True field as a BatchedManyRelatedField.
+
+        Forwards the kwargs named in MANY_RELATION_KWARGS to the list field,
+        following the same pattern as DRF's RelatedField.many_init.
+        """
+        list_kwargs = {"child_relation": cls(*args, **kwargs)}
+        for key, value in kwargs.items():
+            if key in MANY_RELATION_KWARGS:
+                list_kwargs[key] = value
+        return BatchedManyRelatedField(**list_kwargs)
+
 
 class DocumentTypeField(serializers.PrimaryKeyRelatedField[DocumentType]):
     def get_queryset(self):