Skip to content

Commit 736e30f

Browse files
Validation of the fields argument to bulk_update (#2808)
Co-authored-by: sobolevn <[email protected]>
1 parent e7da6a9 commit 736e30f

File tree

4 files changed

+291
-2
lines changed

4 files changed

+291
-2
lines changed

mypy_django_plugin/lib/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,9 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: "Djang
548548
if isinstance(attr_expr, StrExpr):
549549
return attr_expr.value
550550

551+
if isinstance(attr_expr, NameExpr) and isinstance(attr_expr.node, Var) and attr_expr.node.type is not None:
552+
return get_literal_str_type(attr_expr.node.type)
553+
551554
# support extracting from settings, in general case it's unresolvable yet
552555
if isinstance(attr_expr, MemberExpr):
553556
member_name = attr_expr.name

mypy_django_plugin/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext
169169
querysets.extract_prefetch_related_annotations, django_context=self.django_context
170170
),
171171
"select_related": partial(querysets.validate_select_related, django_context=self.django_context),
172+
"bulk_update": partial(
173+
querysets.validate_bulk_update, django_context=self.django_context, method="bulk_update"
174+
),
175+
"abulk_update": partial(
176+
querysets.validate_bulk_update, django_context=self.django_context, method="abulk_update"
177+
),
172178
}
173179

174180
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], MypyType] | None:

mypy_django_plugin/transformers/querysets.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Sequence
2+
from typing import Literal
23

3-
from django.core.exceptions import FieldError
4+
from django.core.exceptions import FieldDoesNotExist, FieldError
45
from django.db.models.base import Model
56
from django.db.models.fields.related import RelatedField
67
from django.db.models.fields.related_descriptors import (
@@ -12,7 +13,7 @@
1213
from django.db.models.fields.reverse_related import ForeignObjectRel
1314
from mypy.checker import TypeChecker
1415
from mypy.errorcodes import NO_REDEF
15-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression
16+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr, SetExpr, TupleExpr
1617
from mypy.plugin import FunctionContext, MethodContext
1718
from mypy.types import AnyType, Instance, LiteralType, ProperType, TupleType, TypedDictType, TypeOfAny, get_proper_type
1819
from mypy.types import Type as MypyType
@@ -768,3 +769,58 @@ def validate_select_related(ctx: MethodContext, django_context: DjangoContext) -
768769
_validate_select_related_lookup(ctx, django_context, django_model.cls, lookup_value)
769770

770771
return ctx.default_return_type
772+
773+
774+
def _validate_bulk_update_field(
775+
ctx: MethodContext, model_cls: type[Model], field_name: str, method: Literal["bulk_update", "abulk_update"]
776+
) -> bool:
777+
opts = model_cls._meta
778+
try:
779+
field = opts.get_field(field_name)
780+
except FieldDoesNotExist as e:
781+
ctx.api.fail(str(e), ctx.context)
782+
return False
783+
784+
if not field.concrete or field.many_to_many:
785+
ctx.api.fail(f'"{method}()" can only be used with concrete fields. Got "{field_name}"', ctx.context)
786+
return False
787+
788+
all_pk_fields = set(opts.pk_fields)
789+
for parent in opts.all_parents:
790+
all_pk_fields.update(parent._meta.pk_fields)
791+
792+
if field in all_pk_fields:
793+
ctx.api.fail(f'"{method}()" cannot be used with primary key fields. Got "{field_name}"', ctx.context)
794+
return False
795+
796+
return True
797+
798+
799+
def validate_bulk_update(
800+
ctx: MethodContext, django_context: DjangoContext, method: Literal["bulk_update", "abulk_update"]
801+
) -> MypyType:
802+
"""
803+
Type check the `fields` argument passed to `QuerySet.bulk_update(...)`.
804+
805+
Extracted and adapted from `django.db.models.query.QuerySet.bulk_update`
806+
Mirrors tests from `django/tests/queries/test_bulk_update.py`
807+
"""
808+
if not (
809+
isinstance(ctx.type, Instance)
810+
and (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is not None
811+
and len(ctx.args) >= 2
812+
and ctx.args[1]
813+
and isinstance((fields_args := ctx.args[1][0]), (ListExpr, TupleExpr, SetExpr))
814+
):
815+
return ctx.default_return_type
816+
817+
if len(fields_args.items) == 0:
818+
ctx.api.fail(f'Field names must be given to "{method}()"', ctx.context)
819+
return ctx.default_return_type
820+
821+
for field_arg in fields_args.items:
822+
field_name = helpers.resolve_string_attribute_value(field_arg, django_context)
823+
if field_name is not None:
824+
_validate_bulk_update_field(ctx, django_model.cls, field_name, method)
825+
826+
return ctx.default_return_type
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
- case: bulk_update_valid_fields
2+
installed_apps:
3+
- myapp
4+
main: |
5+
from myapp.models import Article, Author, Category
6+
7+
# Valid single field updates
8+
articles = Article.objects.all()
9+
Article.objects.bulk_update(articles, ["title"])
10+
Article.objects.bulk_update(articles, ["content"])
11+
Article.objects.bulk_update(articles, ["published"])
12+
13+
# Valid multiple field updates
14+
Article.objects.bulk_update(articles, ("title", "content"))
15+
Article.objects.bulk_update(articles, {"title", "content", "published"})
16+
17+
# Valid foreign key field updates (by field name and attname)
18+
Article.objects.bulk_update(articles, ["author"])
19+
Article.objects.bulk_update(articles, ["author_id"])
20+
Article.objects.bulk_update(articles, ["category"])
21+
Article.objects.bulk_update(articles, ["category_id"])
22+
23+
# Valid updates on different models
24+
authors = Author.objects.all()
25+
Author.objects.bulk_update(authors, ["name"])
26+
Author.objects.bulk_update(authors, ["email"])
27+
28+
categories = Category.objects.all()
29+
Category.objects.bulk_update(categories, ["name"])
30+
Category.objects.bulk_update(categories, ["parent"])
31+
Category.objects.bulk_update(categories, ["parent_id"])
32+
33+
# Variables containing field names
34+
field_name = "title"
35+
Article.objects.bulk_update(articles, [field_name])
36+
37+
# Dynamic field lists
38+
def get_fields() -> list[str]:
39+
return ["title"]
40+
41+
Article.objects.bulk_update(articles, get_fields())
42+
43+
async def test_async_bulk_update() -> None:
44+
# Valid single field updates
45+
articles = Article.objects.all()
46+
await Article.objects.abulk_update(articles, {"published"})
47+
48+
# Valid multiple field updates
49+
await Article.objects.abulk_update(articles, ("title", "content", "published"))
50+
51+
# Valid foreign key field updates (by field name and attname)
52+
await Article.objects.abulk_update(articles, ["category_id"])
53+
54+
# Valid updates on different models
55+
authors = Author.objects.all()
56+
await Author.objects.abulk_update(authors, ["email"])
57+
58+
categories = Category.objects.all()
59+
await Category.objects.abulk_update(categories, ["parent_id"])
60+
61+
# Variables containing field names
62+
field_name = "title"
63+
await Article.objects.abulk_update(articles, [field_name])
64+
65+
# Dynamic field lists
66+
def get_fields() -> list[str]:
67+
return ["title"]
68+
69+
await Article.objects.abulk_update(articles, get_fields())
70+
files:
71+
- path: myapp/__init__.py
72+
- path: myapp/models.py
73+
content: |
74+
from django.db import models
75+
76+
class Category(models.Model):
77+
name = models.CharField(max_length=100)
78+
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True)
79+
80+
class Author(models.Model):
81+
name = models.CharField(max_length=100)
82+
email = models.EmailField()
83+
84+
class Tag(models.Model):
85+
name = models.CharField(max_length=50)
86+
87+
class Article(models.Model):
88+
title = models.CharField(max_length=200)
89+
content = models.TextField()
90+
published = models.BooleanField(default=False)
91+
author = models.ForeignKey(Author, on_delete=models.CASCADE)
92+
category = models.ForeignKey(Category, on_delete=models.CASCADE)
93+
tags = models.ManyToManyField(Tag)
94+
95+
96+
97+
- case: bulk_update_invalid_fields
98+
installed_apps:
99+
- myapp
100+
main: |
101+
from myapp.models import Article, Author, Category
102+
from typing import Literal
103+
104+
articles = Article.objects.all()
105+
106+
# Empty fields list
107+
Article.objects.bulk_update() # E: Missing positional arguments "objs", "fields" in call to "bulk_update" of "Manager" [call-arg]
108+
Article.objects.bulk_update(articles, []) # E: Field names must be given to "bulk_update()" [misc]
109+
110+
# Invalid field names (Django's FieldError)
111+
Article.objects.bulk_update(articles, ["nonexistent"]) # E: Article has no field named 'nonexistent' [misc]
112+
Article.objects.bulk_update(articles, ["invalid_field"]) # E: Article has no field named 'invalid_field' [misc]
113+
114+
# Cannot update primary key fields
115+
Article.objects.bulk_update(articles, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
116+
117+
# Mixed valid and invalid fields
118+
Article.objects.bulk_update(articles, {"title", "nonexistent"}) # E: Article has no field named 'nonexistent' [misc]
119+
Article.objects.bulk_update(articles, ("id", "title")) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
120+
121+
# Whitespace-only field names
122+
Article.objects.bulk_update(articles, [""]) # E: Article has no field named '' [misc]
123+
Article.objects.bulk_update(articles, [" "]) # E: Article has no field named ' ' [misc]
124+
125+
# ManyToMany is not a concrete updatable field
126+
Article.objects.bulk_update(articles, {"tags"}) # E: "bulk_update()" can only be used with concrete fields. Got "tags" [misc]
127+
128+
# Nested field lookups are not supported
129+
Article.objects.bulk_update(articles, ["author__name"]) # E: Article has no field named 'author__name' [misc]
130+
Article.objects.bulk_update(articles, ["category__parent__name"]) # E: Article has no field named 'category__parent__name' [misc]
131+
132+
# Multiple invalid fields
133+
Article.objects.bulk_update(articles, ["nonexistent1", "nonexistent2"]) # E: Article has no field named 'nonexistent1' [misc] # E: Article has no field named 'nonexistent2' [misc]
134+
135+
# Primary key with valid fields
136+
Article.objects.bulk_update(articles, ["title", "id", "content"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
137+
138+
# Literal type variables are validated
139+
invalid_field: Literal["nonexistent"] = "nonexistent"
140+
Article.objects.bulk_update(articles, [invalid_field]) # E: Article has no field named 'nonexistent' [misc]
141+
142+
pk_field: Literal["id"] = "id"
143+
Article.objects.bulk_update(articles, [pk_field]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
144+
145+
# Test with different models
146+
authors = Author.objects.all()
147+
Author.objects.bulk_update(authors, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
148+
Author.objects.bulk_update(authors, ["invalid"]) # E: Author has no field named 'invalid' [misc]
149+
150+
categories = Category.objects.all()
151+
Category.objects.bulk_update(categories, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc]
152+
Category.objects.bulk_update(categories, ["invalid"]) # E: Category has no field named 'invalid' [misc]
153+
154+
# Async version
155+
async def test_async_bulk_update_invalid() -> None:
156+
articles = Article.objects.all()
157+
158+
# Empty fields list
159+
await Article.objects.abulk_update() # E: Missing positional arguments "objs", "fields" in call to "abulk_update" of "Manager" [call-arg]
160+
await Article.objects.abulk_update(articles, []) # E: Field names must be given to "abulk_update()" [misc]
161+
162+
# Invalid field names (Django's FieldError)
163+
await Article.objects.abulk_update(articles, ["invalid_field"]) # E: Article has no field named 'invalid_field' [misc]
164+
165+
# Cannot update primary key fields
166+
await Article.objects.abulk_update(articles, ["id"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc]
167+
168+
# Mixed valid and invalid fields
169+
await Article.objects.abulk_update(articles, ["id", "title"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc]
170+
171+
# Whitespace-only field names
172+
await Article.objects.abulk_update(articles, [" "]) # E: Article has no field named ' ' [misc]
173+
174+
# ManyToMany is not a concrete updatable field
175+
await Article.objects.abulk_update(articles, ["tags"]) # E: "abulk_update()" can only be used with concrete fields. Got "tags" [misc]
176+
177+
# Nested field lookups are not supported
178+
await Article.objects.abulk_update(articles, ["author__name"]) # E: Article has no field named 'author__name' [misc]
179+
await Article.objects.abulk_update(articles, ["category__parent__name"]) # E: Article has no field named 'category__parent__name' [misc]
180+
181+
# Multiple invalid fields
182+
await Article.objects.abulk_update(articles, ("nonexistent1", "nonexistent2")) # E: Article has no field named 'nonexistent1' [misc] # E: Article has no field named 'nonexistent2' [misc]
183+
184+
# Primary key with valid fields
185+
await Article.objects.abulk_update(articles, ["title", "id", "content"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc]
186+
187+
# Literal type variables are validated
188+
invalid_field: Literal["nonexistent"] = "nonexistent"
189+
await Article.objects.abulk_update(articles, {invalid_field}) # E: Article has no field named 'nonexistent' [misc]
190+
191+
pk_field: Literal["id"] = "id"
192+
await Article.objects.abulk_update(articles, [pk_field]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc]
193+
194+
# Test with different models
195+
authors = Author.objects.all()
196+
await Author.objects.abulk_update(authors, ["invalid"]) # E: Author has no field named 'invalid' [misc]
197+
198+
categories = Category.objects.all()
199+
await Category.objects.abulk_update(categories, ["invalid"]) # E: Category has no field named 'invalid' [misc]
200+
201+
files:
202+
- path: myapp/__init__.py
203+
- path: myapp/models.py
204+
content: |
205+
from django.db import models
206+
207+
class Category(models.Model):
208+
name = models.CharField(max_length=100)
209+
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True)
210+
211+
class Author(models.Model):
212+
name = models.CharField(max_length=100)
213+
email = models.EmailField()
214+
215+
class Article(models.Model):
216+
title = models.CharField(max_length=200)
217+
content = models.TextField()
218+
published = models.BooleanField(default=False)
219+
author = models.ForeignKey(Author, on_delete=models.CASCADE)
220+
category = models.ForeignKey(Category, on_delete=models.CASCADE)
221+
tags = models.ManyToManyField('myapp.Tag')
222+
223+
class Tag(models.Model):
224+
name = models.CharField(max_length=50)

0 commit comments

Comments
 (0)