Skip to content

Commit 9b7db30

Browse files
committed
Fix unique together validator doesn't respect condition's fields
1 parent e13688f commit 9b7db30

File tree

4 files changed

+129
-33
lines changed

4 files changed

+129
-33
lines changed

rest_framework/compat.py

+36
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
versions of Django/Python, and compatibility wrappers around optional packages.
44
"""
55
import django
6+
from django.db import models
7+
from django.db.models.constants import LOOKUP_SEP
8+
from django.db.models.sql.query import Node
69
from django.views.generic import View
710

811

@@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md):
157160
# 1) the list of validators and 2) the error message. Starting from
158161
# Django 5.1 ip_address_validators only returns the list of validators
159162
from django.core.validators import ip_address_validators
163+
164+
def get_referenced_base_fields_from_q(q):
165+
return q.referenced_base_fields
166+
160167
else:
161168
# Django <= 5.1: create a compatibility shim for ip_address_validators
162169
from django.core.validators import \
@@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md):
165172
def ip_address_validators(protocol, unpack_ipv4):
166173
return _ip_address_validators(protocol, unpack_ipv4)[0]
167174

175+
# Django <= 5.1: create a compatibility shim for Q.referenced_base_fields
176+
# https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
177+
def _get_paths_from_expression(expr):
178+
if isinstance(expr, models.F):
179+
yield expr.name
180+
elif hasattr(expr, "flatten"):
181+
for child in expr.flatten():
182+
if isinstance(child, models.F):
183+
yield child.name
184+
elif isinstance(child, models.Q):
185+
yield from _get_children_from_q(child)
186+
187+
def _get_children_from_q(q):
188+
for child in q.children:
189+
if isinstance(child, Node):
190+
yield from _get_children_from_q(child)
191+
elif isinstance(child, tuple):
192+
lhs, rhs = child
193+
yield lhs
194+
if hasattr(rhs, "resolve_expression"):
195+
yield from _get_paths_from_expression(rhs)
196+
elif hasattr(child, "resolve_expression"):
197+
yield from _get_paths_from_expression(child)
198+
199+
def get_referenced_base_fields_from_q(q):
200+
return {
201+
child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
202+
}
203+
168204

169205
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
170206
# See: https://bugs.python.org/issue22767

rest_framework/serializers.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from django.utils.functional import cached_property
2727
from django.utils.translation import gettext_lazy as _
2828

29-
from rest_framework.compat import postgres_fields
29+
from rest_framework.compat import (
30+
get_referenced_base_fields_from_q, postgres_fields
31+
)
3032
from rest_framework.exceptions import ErrorDetail, ValidationError
3133
from rest_framework.fields import get_error_detail
3234
from rest_framework.settings import api_settings
@@ -1430,15 +1432,15 @@ def get_unique_together_constraints(self, model):
14301432
"""
14311433
for parent_class in [model] + list(model._meta.parents):
14321434
for unique_together in parent_class._meta.unique_together:
1433-
yield unique_together, model._default_manager
1435+
yield unique_together, model._default_manager, [], None
14341436
for constraint in parent_class._meta.constraints:
14351437
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
1436-
yield (
1437-
constraint.fields,
1438-
model._default_manager
1439-
if constraint.condition is None
1440-
else model._default_manager.filter(constraint.condition)
1441-
)
1438+
queryset = model._default_manager
1439+
if constraint.condition is None:
1440+
condition_fields = []
1441+
else:
1442+
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
1443+
yield (constraint.fields, queryset, condition_fields, constraint.condition)
14421444

14431445
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
14441446
"""
@@ -1470,9 +1472,9 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14701472

14711473
# Include each of the `unique_together` and `UniqueConstraint` field names,
14721474
# so long as all the field names are included on the serializer.
1473-
for unique_together_list, queryset in self.get_unique_together_constraints(model):
1474-
if set(field_names).issuperset(unique_together_list):
1475-
unique_constraint_names |= set(unique_together_list)
1475+
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
1476+
if set(field_names).issuperset((*unique_together_list, *condition_fields)):
1477+
unique_constraint_names |= set((*unique_together_list, *condition_fields))
14761478

14771479
# Now we have all the field names that have uniqueness constraints
14781480
# applied, we can add the extra 'required=...' or 'default=...'
@@ -1592,12 +1594,12 @@ def get_unique_together_validators(self):
15921594
# Note that we make sure to check `unique_together` both on the
15931595
# base model class, but also on any parent classes.
15941596
validators = []
1595-
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
1597+
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
15961598
# Skip if serializer does not map to all unique together sources
1597-
if not set(source_map).issuperset(unique_together):
1599+
if not set(source_map).issuperset((*unique_together, *condition_fields)):
15981600
continue
15991601

1600-
for source in unique_together:
1602+
for source in (*unique_together, *condition_fields):
16011603
assert len(source_map[source]) == 1, (
16021604
"Unable to create `UniqueTogetherValidator` for "
16031605
"`{model}.{field}` as `{serializer}` has multiple "
@@ -1614,9 +1616,9 @@ def get_unique_together_validators(self):
16141616
)
16151617

16161618
field_names = tuple(source_map[f][0] for f in unique_together)
1619+
condition_fields = tuple(source_map[f][0] for f in condition_fields)
16171620
validator = UniqueTogetherValidator(
1618-
queryset=queryset,
1619-
fields=field_names
1621+
queryset=queryset, fields=field_names, condition_fields=condition_fields, condition=condition
16201622
)
16211623
validators.append(validator)
16221624
return validators

rest_framework/validators.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
`ModelSerializer` class and an equivalent explicit `Serializer` class.
88
"""
99
from django.db import DataError
10+
from django.db.models import Exists
1011
from django.utils.translation import gettext_lazy as _
1112

1213
from rest_framework.exceptions import ValidationError
@@ -23,6 +24,16 @@ def qs_exists(queryset):
2324
return False
2425

2526

27+
def qs_exists_with_condition(queryset, condition, against):
28+
if condition is None:
29+
return qs_exists(queryset)
30+
try:
31+
# use the same query as UniqueConstraint.validate https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
32+
return (condition & Exists(queryset.filter(condition))).check(against)
33+
except (TypeError, ValueError, DataError):
34+
return False
35+
36+
2637
def qs_filter(queryset, **kwargs):
2738
try:
2839
return queryset.filter(**kwargs)
@@ -99,10 +110,12 @@ class UniqueTogetherValidator:
99110
missing_message = _('This field is required.')
100111
requires_context = True
101112

102-
def __init__(self, queryset, fields, message=None):
113+
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
103114
self.queryset = queryset
104115
self.fields = fields
105116
self.message = message or self.message
117+
self.condition_fields = [] if condition_fields is None else condition_fields
118+
self.condition = condition
106119

107120
def enforce_required_fields(self, attrs, serializer):
108121
"""
@@ -114,7 +127,7 @@ def enforce_required_fields(self, attrs, serializer):
114127

115128
missing_items = {
116129
field_name: self.missing_message
117-
for field_name in self.fields
130+
for field_name in (*self.fields, *self.condition_fields)
118131
if serializer.fields[field_name].source not in attrs
119132
}
120133
if missing_items:
@@ -172,16 +185,22 @@ def __call__(self, attrs, serializer):
172185
if field in self.fields and value != getattr(serializer.instance, field)
173186
]
174187

175-
if checked_values and None not in checked_values and qs_exists(queryset):
188+
condition_kwargs = {
189+
source: attrs[source]
190+
for source in self.condition_fields
191+
}
192+
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
176193
field_names = ', '.join(self.fields)
177194
message = self.message.format(field_names=field_names)
178195
raise ValidationError(message, code='unique')
179196

180197
def __repr__(self):
181-
return '<%s(queryset=%s, fields=%s)>' % (
198+
return '<%s(%s)>' % (
182199
self.__class__.__name__,
183-
smart_repr(self.queryset),
184-
smart_repr(self.fields)
200+
', '.join(
201+
f'{attr}={smart_repr(getattr(self, attr))}'
202+
for attr in ('queryset', 'fields', 'condition')
203+
if getattr(self, attr) is not None)
185204
)
186205

187206
def __eq__(self, other):

tests/test_validators.py

+50-11
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,11 @@ class Meta:
513513
name="unique_constraint_model_together_uniq",
514514
fields=('race_name', 'position'),
515515
condition=models.Q(race_name='example'),
516+
),
517+
models.UniqueConstraint(
518+
name="unique_constraint_model_together_uniq2",
519+
fields=('race_name', 'position'),
520+
condition=models.Q(fancy_conditions__gte=10),
516521
)
517522
]
518523

@@ -553,22 +558,56 @@ def test_repr(self):
553558
position = IntegerField\(.*required=True\)
554559
global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\)
555560
class Meta:
556-
validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\]
561+
validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\]
557562
""")
563+
print(repr(serializer))
558564
assert re.search(expected, repr(serializer)) is not None
559565

560-
def test_unique_together_field(self):
566+
def test_unique_together_condition(self):
561567
"""
562-
UniqueConstraint fields and condition attributes must be passed
563-
to UniqueTogetherValidator as fields and queryset
568+
Fields used in UniqueConstraint's condition must be included
569+
into queryset existence check
564570
"""
565-
serializer = UniqueConstraintSerializer()
566-
assert len(serializer.validators) == 1
567-
validator = serializer.validators[0]
568-
assert validator.fields == ('race_name', 'position')
569-
assert set(validator.queryset.values_list(flat=True)) == set(
570-
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
571+
UniqueConstraintModel.objects.create(
572+
race_name='condition',
573+
position=1,
574+
global_id=10,
575+
fancy_conditions=10
571576
)
577+
serializer = UniqueConstraintSerializer(data={
578+
'race_name': 'condition',
579+
'position': 1,
580+
'global_id': 11,
581+
'fancy_conditions': 9,
582+
})
583+
assert serializer.is_valid()
584+
serializer = UniqueConstraintSerializer(data={
585+
'race_name': 'condition',
586+
'position': 1,
587+
'global_id': 11,
588+
'fancy_conditions': 11,
589+
})
590+
assert not serializer.is_valid()
591+
592+
def test_unique_together_condition_fields_required(self):
593+
"""
594+
Fields used in UniqueConstraint's condition must be present in serializer
595+
"""
596+
serializer = UniqueConstraintSerializer(data={
597+
'race_name': 'condition',
598+
'position': 1,
599+
'global_id': 11,
600+
})
601+
assert not serializer.is_valid()
602+
assert serializer.errors == {'fancy_conditions': ['This field is required.']}
603+
604+
class NoFieldsSerializer(serializers.ModelSerializer):
605+
class Meta:
606+
model = UniqueConstraintModel
607+
fields = ('race_name', 'position', 'global_id')
608+
609+
serializer = NoFieldsSerializer()
610+
assert len(serializer.validators) == 1
572611

573612
def test_single_field_uniq_validators(self):
574613
"""
@@ -579,7 +618,7 @@ def test_single_field_uniq_validators(self):
579618
extra_validators_qty = 2 if django_version[0] >= 5 else 0
580619
#
581620
serializer = UniqueConstraintSerializer()
582-
assert len(serializer.validators) == 1
621+
assert len(serializer.validators) == 2
583622
validators = serializer.fields['global_id'].validators
584623
assert len(validators) == 1 + extra_validators_qty
585624
assert validators[0].queryset == UniqueConstraintModel.objects

0 commit comments

Comments
 (0)