Try to use a batched method of tag M2M validation

This commit is contained in:
Trenton H
2026-04-21 10:19:52 -07:00
parent ffaa2bb77a
commit e6ffb003a6
+47
View File
@@ -48,6 +48,7 @@ from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied
from rest_framework.fields import SerializerMethodField
from rest_framework.filters import OrderingFilter
from rest_framework.relations import MANY_RELATION_KWARGS
if settings.AUDIT_LOG_ENABLED:
from auditlog.context import set_actor
@@ -687,10 +688,56 @@ class CorrespondentField(serializers.PrimaryKeyRelatedField[Correspondent]):
return Correspondent.objects.all()
class BatchedManyRelatedField(serializers.ManyRelatedField):
"""
ManyRelatedField that validates all PKs with a single filter(pk__in=pks)
instead of one queryset.get() per PK, eliminating N+1 on write validation.
"""
def to_internal_value(self, data):
if isinstance(data, str) or not hasattr(data, "__iter__"):
self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail("empty")
if not data:
return []
child = self.child_relation
pk_field = getattr(child, "pk_field", None)
pks = []
for item in data:
if pk_field is not None:
item = pk_field.to_internal_value(item)
if isinstance(item, bool):
child.fail("incorrect_type", data_type="bool")
pks.append(item)
try:
found = {obj.pk: obj for obj in child.get_queryset().filter(pk__in=pks)}
except (TypeError, ValueError):
child.fail("incorrect_type", data_type=type(pks[0]).__name__)
result = []
for pk in pks:
if pk not in found:
child.fail("does_not_exist", pk_value=pk)
result.append(found[pk])
return result
class TagsField(serializers.PrimaryKeyRelatedField[Tag]):
def get_queryset(self):
return Tag.objects.all()
@classmethod
def many_init(cls, *args, **kwargs):
list_kwargs = {"child_relation": cls(*args, **kwargs)}
for key, value in kwargs.items():
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = value
return BatchedManyRelatedField(**list_kwargs)
class DocumentTypeField(serializers.PrimaryKeyRelatedField[DocumentType]):
def get_queryset(self):