Skip to content

Commit 2fd3b45

Browse files
committed
Fix unique together validator doesn't respect condition's fields
1 parent f30c0e2 commit 2fd3b45

File tree

4 files changed

+140
-40
lines changed

4 files changed

+140
-40
lines changed

rest_framework/compat.py

+36
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
versions of Django/Python, and compatibility wrappers around optional packages.
44
"""
55
import django
6+
from django.db import models
7+
from django.db.models.constants import LOOKUP_SEP
8+
from django.db.models.sql.query import Node
69
from django.views.generic import View
710

811

@@ -157,6 +160,10 @@ def md_filter_add_syntax_highlight(md):
157160
# 1) the list of validators and 2) the error message. Starting from
158161
# Django 5.1 ip_address_validators only returns the list of validators
159162
from django.core.validators import ip_address_validators
163+
164+
def get_referenced_base_fields_from_q(q):
165+
return q.referenced_base_fields
166+
160167
else:
161168
# Django <= 5.1: create a compatibility shim for ip_address_validators
162169
from django.core.validators import \
@@ -165,6 +172,35 @@ def md_filter_add_syntax_highlight(md):
165172
def ip_address_validators(protocol, unpack_ipv4):
166173
return _ip_address_validators(protocol, unpack_ipv4)[0]
167174

175+
# Django < 5.1: create a compatibility shim for Q.referenced_base_fields
176+
# https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
177+
def _get_paths_from_expression(expr):
178+
if isinstance(expr, models.F):
179+
yield expr.name
180+
elif hasattr(expr, 'flatten'):
181+
for child in expr.flatten():
182+
if isinstance(child, models.F):
183+
yield child.name
184+
elif isinstance(child, models.Q):
185+
yield from _get_children_from_q(child)
186+
187+
def _get_children_from_q(q):
188+
for child in q.children:
189+
if isinstance(child, Node):
190+
yield from _get_children_from_q(child)
191+
elif isinstance(child, tuple):
192+
lhs, rhs = child
193+
yield lhs
194+
if hasattr(rhs, 'resolve_expression'):
195+
yield from _get_paths_from_expression(rhs)
196+
elif hasattr(child, 'resolve_expression'):
197+
yield from _get_paths_from_expression(child)
198+
199+
def get_referenced_base_fields_from_q(q):
200+
return {
201+
child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
202+
}
203+
168204

169205
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
170206
# See: https://bugs.python.org/issue22767

rest_framework/serializers.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from django.utils.functional import cached_property
2727
from django.utils.translation import gettext_lazy as _
2828

29-
from rest_framework.compat import postgres_fields
29+
from rest_framework.compat import (
30+
get_referenced_base_fields_from_q, postgres_fields
31+
)
3032
from rest_framework.exceptions import ErrorDetail, ValidationError
3133
from rest_framework.fields import get_error_detail
3234
from rest_framework.settings import api_settings
@@ -1425,20 +1427,20 @@ def get_extra_kwargs(self):
14251427

14261428
def get_unique_together_constraints(self, model):
14271429
"""
1428-
Returns iterator of (fields, queryset), each entry describes an unique together
1429-
constraint on `fields` in `queryset`.
1430+
Returns iterator of (fields, queryset, condition_fields, condition),
1431+
each entry describes an unique together constraint on `fields` in `queryset`
1432+
with respect of constraint's `condition`.
14301433
"""
14311434
for parent_class in [model] + list(model._meta.parents):
14321435
for unique_together in parent_class._meta.unique_together:
1433-
yield unique_together, model._default_manager
1436+
yield unique_together, model._default_manager, [], None
14341437
for constraint in parent_class._meta.constraints:
14351438
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
1436-
yield (
1437-
constraint.fields,
1438-
model._default_manager
1439-
if constraint.condition is None
1440-
else model._default_manager.filter(constraint.condition)
1441-
)
1439+
if constraint.condition is None:
1440+
condition_fields = []
1441+
else:
1442+
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
1443+
yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
14421444

14431445
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
14441446
"""
@@ -1470,9 +1472,10 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14701472

14711473
# Include each of the `unique_together` and `UniqueConstraint` field names,
14721474
# so long as all the field names are included on the serializer.
1473-
for unique_together_list, queryset in self.get_unique_together_constraints(model):
1474-
if set(field_names).issuperset(unique_together_list):
1475-
unique_constraint_names |= set(unique_together_list)
1475+
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
1476+
unique_together_list_and_condition_fields = {*unique_together_list, *condition_fields}
1477+
if set(field_names).issuperset(unique_together_list_and_condition_fields):
1478+
unique_constraint_names |= set(unique_together_list_and_condition_fields)
14761479

14771480
# Now we have all the field names that have uniqueness constraints
14781481
# applied, we can add the extra 'required=...' or 'default=...'
@@ -1594,12 +1597,13 @@ def get_unique_together_validators(self):
15941597
# Note that we make sure to check `unique_together` both on the
15951598
# base model class, but also on any parent classes.
15961599
validators = []
1597-
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
1600+
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
15981601
# Skip if serializer does not map to all unique together sources
1599-
if not set(source_map).issuperset(unique_together):
1602+
unique_together_and_condition_fields = {*unique_together, *condition_fields}
1603+
if not set(source_map).issuperset(unique_together_and_condition_fields):
16001604
continue
16011605

1602-
for source in unique_together:
1606+
for source in unique_together_and_condition_fields:
16031607
assert len(source_map[source]) == 1, (
16041608
"Unable to create `UniqueTogetherValidator` for "
16051609
"`{model}.{field}` as `{serializer}` has multiple "
@@ -1618,7 +1622,9 @@ def get_unique_together_validators(self):
16181622
field_names = tuple(source_map[f][0] for f in unique_together)
16191623
validator = UniqueTogetherValidator(
16201624
queryset=queryset,
1621-
fields=field_names
1625+
fields=field_names,
1626+
condition_fields=tuple(source_map[f][0] for f in condition_fields),
1627+
condition=condition,
16221628
)
16231629
validators.append(validator)
16241630
return validators

rest_framework/validators.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
object creation, and makes it possible to switch between using the implicit
77
`ModelSerializer` class and an equivalent explicit `Serializer` class.
88
"""
9+
from django.core.exceptions import FieldError
910
from django.db import DataError
11+
from django.db.models import Exists
1012
from django.utils.translation import gettext_lazy as _
1113

1214
from rest_framework.exceptions import ValidationError
@@ -23,6 +25,17 @@ def qs_exists(queryset):
2325
return False
2426

2527

28+
def qs_exists_with_condition(queryset, condition, against):
29+
if condition is None:
30+
return qs_exists(queryset)
31+
try:
32+
# use the same query as UniqueConstraint.validate
33+
# https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
34+
return (condition & Exists(queryset.filter(condition))).check(against)
35+
except (TypeError, ValueError, DataError, FieldError):
36+
return False
37+
38+
2639
def qs_filter(queryset, **kwargs):
2740
try:
2841
return queryset.filter(**kwargs)
@@ -99,10 +112,12 @@ class UniqueTogetherValidator:
99112
missing_message = _('This field is required.')
100113
requires_context = True
101114

102-
def __init__(self, queryset, fields, message=None):
115+
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
103116
self.queryset = queryset
104117
self.fields = fields
105118
self.message = message or self.message
119+
self.condition_fields = [] if condition_fields is None else condition_fields
120+
self.condition = condition
106121

107122
def enforce_required_fields(self, attrs, serializer):
108123
"""
@@ -114,7 +129,7 @@ def enforce_required_fields(self, attrs, serializer):
114129

115130
missing_items = {
116131
field_name: self.missing_message
117-
for field_name in self.fields
132+
for field_name in (*self.fields, *self.condition_fields)
118133
if serializer.fields[field_name].source not in attrs
119134
}
120135
if missing_items:
@@ -173,16 +188,19 @@ def __call__(self, attrs, serializer):
173188
if attrs[field_name] != getattr(serializer.instance, field_name)
174189
]
175190

176-
if checked_values and None not in checked_values and qs_exists(queryset):
191+
condition_kwargs = {source: attrs[source] for source in self.condition_fields}
192+
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
177193
field_names = ', '.join(self.fields)
178194
message = self.message.format(field_names=field_names)
179195
raise ValidationError(message, code='unique')
180196

181197
def __repr__(self):
182-
return '<%s(queryset=%s, fields=%s)>' % (
198+
return '<{}({})>'.format(
183199
self.__class__.__name__,
184-
smart_repr(self.queryset),
185-
smart_repr(self.fields)
200+
', '.join(
201+
f'{attr}={smart_repr(getattr(self, attr))}'
202+
for attr in ('queryset', 'fields', 'condition')
203+
if getattr(self, attr) is not None)
186204
)
187205

188206
def __eq__(self, other):

tests/test_validators.py

+57-17
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ class UniqueConstraintModel(models.Model):
521521
race_name = models.CharField(max_length=100)
522522
position = models.IntegerField()
523523
global_id = models.IntegerField()
524-
fancy_conditions = models.IntegerField(null=True)
524+
fancy_conditions = models.IntegerField()
525525

526526
class Meta:
527527
constraints = [
@@ -543,7 +543,12 @@ class Meta:
543543
name="unique_constraint_model_together_uniq",
544544
fields=('race_name', 'position'),
545545
condition=models.Q(race_name='example'),
546-
)
546+
),
547+
models.UniqueConstraint(
548+
name='unique_constraint_model_together_uniq2',
549+
fields=('race_name', 'position'),
550+
condition=models.Q(fancy_conditions__gte=10),
551+
),
547552
]
548553

549554

@@ -576,17 +581,20 @@ def setUp(self):
576581
self.instance = UniqueConstraintModel.objects.create(
577582
race_name='example',
578583
position=1,
579-
global_id=1
584+
global_id=1,
585+
fancy_conditions=1
580586
)
581587
UniqueConstraintModel.objects.create(
582588
race_name='example',
583589
position=2,
584-
global_id=2
590+
global_id=2,
591+
fancy_conditions=1
585592
)
586593
UniqueConstraintModel.objects.create(
587594
race_name='other',
588595
position=1,
589-
global_id=3
596+
global_id=3,
597+
fancy_conditions=1
590598
)
591599

592600
def test_repr(self):
@@ -601,22 +609,55 @@ def test_repr(self):
601609
position = IntegerField\(.*required=True\)
602610
global_id = IntegerField\(.*validators=\[<UniqueValidator\(queryset=UniqueConstraintModel.objects.all\(\)\)>\]\)
603611
class Meta:
604-
validators = \[<UniqueTogetherValidator\(queryset=<QuerySet \[<UniqueConstraintModel: UniqueConstraintModel object \(1\)>, <UniqueConstraintModel: UniqueConstraintModel object \(2\)>\]>, fields=\('race_name', 'position'\)\)>\]
612+
validators = \[<UniqueTogetherValidator\(queryset=UniqueConstraintModel.objects.all\(\), fields=\('race_name', 'position'\), condition=<Q: \(AND: \('race_name', 'example'\)\)>\)>\]
605613
""")
606614
assert re.search(expected, repr(serializer)) is not None
607615

608-
def test_unique_together_field(self):
616+
def test_unique_together_condition(self):
609617
"""
610-
UniqueConstraint fields and condition attributes must be passed
611-
to UniqueTogetherValidator as fields and queryset
618+
Fields used in UniqueConstraint's condition must be included
619+
into queryset existence check
612620
"""
613-
serializer = UniqueConstraintSerializer()
614-
assert len(serializer.validators) == 1
615-
validator = serializer.validators[0]
616-
assert validator.fields == ('race_name', 'position')
617-
assert set(validator.queryset.values_list(flat=True)) == set(
618-
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
621+
UniqueConstraintModel.objects.create(
622+
race_name='condition',
623+
position=1,
624+
global_id=10,
625+
fancy_conditions=10,
619626
)
627+
serializer = UniqueConstraintSerializer(data={
628+
'race_name': 'condition',
629+
'position': 1,
630+
'global_id': 11,
631+
'fancy_conditions': 9,
632+
})
633+
assert serializer.is_valid()
634+
serializer = UniqueConstraintSerializer(data={
635+
'race_name': 'condition',
636+
'position': 1,
637+
'global_id': 11,
638+
'fancy_conditions': 11,
639+
})
640+
assert not serializer.is_valid()
641+
642+
def test_unique_together_condition_fields_required(self):
643+
"""
644+
Fields used in UniqueConstraint's condition must be present in serializer
645+
"""
646+
serializer = UniqueConstraintSerializer(data={
647+
'race_name': 'condition',
648+
'position': 1,
649+
'global_id': 11,
650+
})
651+
assert not serializer.is_valid()
652+
assert serializer.errors == {'fancy_conditions': ['This field is required.']}
653+
654+
class NoFieldsSerializer(serializers.ModelSerializer):
655+
class Meta:
656+
model = UniqueConstraintModel
657+
fields = ('race_name', 'position', 'global_id')
658+
659+
serializer = NoFieldsSerializer()
660+
assert len(serializer.validators) == 1
620661

621662
def test_single_field_uniq_validators(self):
622663
"""
@@ -625,9 +666,8 @@ def test_single_field_uniq_validators(self):
625666
"""
626667
# Django 5 includes Max and Min values validators for IntergerField
627668
extra_validators_qty = 2 if django_version[0] >= 5 else 0
628-
#
629669
serializer = UniqueConstraintSerializer()
630-
assert len(serializer.validators) == 1
670+
assert len(serializer.validators) == 2
631671
validators = serializer.fields['global_id'].validators
632672
assert len(validators) == 1 + extra_validators_qty
633673
assert validators[0].queryset == UniqueConstraintModel.objects

0 commit comments

Comments
 (0)