Skip to content

Commit 21adeb2

Browse files
bulk_update field validation
1 parent ed65629 commit 21adeb2

File tree

4 files changed

+208
-2
lines changed

4 files changed

+208
-2
lines changed

mypy_django_plugin/lib/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,9 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: "Djang
545545
if isinstance(attr_expr, StrExpr):
546546
return attr_expr.value
547547

548+
if isinstance(attr_expr, NameExpr) and isinstance(attr_expr.node, Var) and attr_expr.node.type is not None:
549+
return get_literal_str_type(attr_expr.node.type)
550+
548551
# support extracting from settings, in general case it's unresolvable yet
549552
if isinstance(attr_expr, MemberExpr):
550553
member_name = attr_expr.name

mypy_django_plugin/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext
169169
querysets.extract_prefetch_related_annotations, django_context=self.django_context
170170
),
171171
"select_related": partial(querysets.validate_select_related, django_context=self.django_context),
172+
"bulk_update": partial(querysets.validate_bulk_update, django_context=self.django_context),
173+
"abulk_update": partial(querysets.validate_bulk_update, django_context=self.django_context),
172174
}
173175

174176
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], MypyType] | None:

mypy_django_plugin/transformers/querysets.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Sequence
22

3-
from django.core.exceptions import FieldError
3+
from django.core.exceptions import FieldDoesNotExist, FieldError
44
from django.db.models.base import Model
55
from django.db.models.fields.related import RelatedField
66
from django.db.models.fields.related_descriptors import (
@@ -12,7 +12,7 @@
1212
from django.db.models.fields.reverse_related import ForeignObjectRel
1313
from mypy.checker import TypeChecker
1414
from mypy.errorcodes import NO_REDEF
15-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression
15+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr
1616
from mypy.plugin import FunctionContext, MethodContext
1717
from mypy.types import AnyType, Instance, LiteralType, ProperType, TupleType, TypedDictType, TypeOfAny, get_proper_type
1818
from mypy.types import Type as MypyType
@@ -695,3 +695,58 @@ def validate_select_related(ctx: MethodContext, django_context: DjangoContext) -
695695
_validate_select_related_lookup(ctx, django_context, django_model.cls, lookup_value)
696696

697697
return ctx.default_return_type
698+
699+
700+
def _validate_bulk_update_field(
701+
ctx: MethodContext,
702+
model_cls: type[Model],
703+
field_name: str,
704+
) -> bool:
705+
opts = model_cls._meta
706+
try:
707+
field = opts.get_field(field_name)
708+
except FieldDoesNotExist as e:
709+
ctx.api.fail(str(e), ctx.context)
710+
return False
711+
712+
if not field.concrete or field.many_to_many:
713+
ctx.api.fail(f'bulk_update() can only be used with concrete fields. Got "{field_name}"', ctx.context)
714+
return False
715+
716+
all_pk_fields = set(opts.pk_fields)
717+
for parent in opts.all_parents:
718+
all_pk_fields.update(parent._meta.pk_fields)
719+
720+
if field in all_pk_fields:
721+
ctx.api.fail(f'bulk_update() cannot be used with primary key fields. Got "{field_name}"', ctx.context)
722+
return False
723+
724+
return True
725+
726+
727+
def validate_bulk_update(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
728+
"""
729+
Type check the `fields` argument passed to `QuerySet.bulk_update(...)`.
730+
731+
Extracted and adapted from `django.db.models.query.QuerySet.bulk_update`
732+
Mirrors tests from `django/tests/queries/test_bulk_update.py`
733+
"""
734+
if not (
735+
isinstance(ctx.type, Instance)
736+
and (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is not None
737+
and len(ctx.args) >= 2
738+
and ctx.args[1]
739+
and isinstance((fields_args := ctx.args[1][0]), ListExpr)
740+
):
741+
return ctx.default_return_type
742+
743+
if len(fields_args.items) == 0:
744+
ctx.api.fail("Field names must be given to bulk_update()", ctx.context)
745+
return ctx.default_return_type
746+
747+
for field_arg in fields_args.items:
748+
field_name = helpers.resolve_string_attribute_value(field_arg, django_context)
749+
if field_name is not None:
750+
_validate_bulk_update_field(ctx, django_model.cls, field_name)
751+
752+
return ctx.default_return_type
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
- case: bulk_update_valid_fields
2+
installed_apps:
3+
- myapp
4+
main: |
5+
from myapp.models import Article, Author, Category
6+
7+
# Valid single field updates
8+
articles = Article.objects.all()
9+
Article.objects.bulk_update(articles, ["title"])
10+
Article.objects.bulk_update(articles, ["content"])
11+
Article.objects.bulk_update(articles, ["published"])
12+
13+
# Valid multiple field updates
14+
Article.objects.bulk_update(articles, ["title", "content"])
15+
Article.objects.bulk_update(articles, ["title", "content", "published"])
16+
17+
# Valid foreign key field updates (by field name and attname)
18+
Article.objects.bulk_update(articles, ["author"])
19+
Article.objects.bulk_update(articles, ["author_id"])
20+
Article.objects.bulk_update(articles, ["category"])
21+
Article.objects.bulk_update(articles, ["category_id"])
22+
23+
# Valid updates on different models
24+
authors = Author.objects.all()
25+
Author.objects.bulk_update(authors, ["name"])
26+
Author.objects.bulk_update(authors, ["email"])
27+
28+
categories = Category.objects.all()
29+
Category.objects.bulk_update(categories, ["name"])
30+
Category.objects.bulk_update(categories, ["parent"])
31+
Category.objects.bulk_update(categories, ["parent_id"])
32+
33+
# Variables containing field names
34+
field_name = "title"
35+
Article.objects.bulk_update(articles, [field_name])
36+
37+
# Dynamic field lists
38+
def get_fields() -> list[str]:
39+
return ["title"]
40+
41+
Article.objects.bulk_update(articles, get_fields())
42+
43+
files:
44+
- path: myapp/__init__.py
45+
- path: myapp/models.py
46+
content: |
47+
from django.db import models
48+
49+
class Category(models.Model):
50+
name = models.CharField(max_length=100)
51+
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True)
52+
53+
class Author(models.Model):
54+
name = models.CharField(max_length=100)
55+
email = models.EmailField()
56+
57+
class Tag(models.Model):
58+
name = models.CharField(max_length=50)
59+
60+
class Article(models.Model):
61+
title = models.CharField(max_length=200)
62+
content = models.TextField()
63+
published = models.BooleanField(default=False)
64+
author = models.ForeignKey(Author, on_delete=models.CASCADE)
65+
category = models.ForeignKey(Category, on_delete=models.CASCADE)
66+
tags = models.ManyToManyField(Tag)
67+
68+
69+
70+
- case: bulk_update_invalid_fields
71+
installed_apps:
72+
- myapp
73+
main: |
74+
from myapp.models import Article, Author, Category
75+
from typing import Literal
76+
77+
articles = Article.objects.all()
78+
79+
# Empty fields list
80+
Article.objects.bulk_update() # E: Missing positional arguments "objs", "fields" in call to "bulk_update" of "Manager" [call-arg]
81+
Article.objects.bulk_update(articles, []) # E: Field names must be given to bulk_update() [misc]
82+
83+
# Invalid field names (Django's FieldError)
84+
Article.objects.bulk_update(articles, ["nonexistent"]) # E: Article has no field named 'nonexistent' [misc]
85+
Article.objects.bulk_update(articles, ["invalid_field"]) # E: Article has no field named 'invalid_field' [misc]
86+
87+
# Cannot update primary key fields
88+
Article.objects.bulk_update(articles, ["id"]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
89+
90+
# Mixed valid and invalid fields
91+
Article.objects.bulk_update(articles, ["title", "nonexistent"]) # E: Article has no field named 'nonexistent' [misc]
92+
Article.objects.bulk_update(articles, ["id", "title"]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
93+
94+
# Whitespace-only field names
95+
Article.objects.bulk_update(articles, [""]) # E: Article has no field named '' [misc]
96+
Article.objects.bulk_update(articles, [" "]) # E: Article has no field named ' ' [misc]
97+
98+
# ManyToMany is not a concrete updatable field
99+
Article.objects.bulk_update(articles, ["tags"]) # E: bulk_update() can only be used with concrete fields. Got "tags" [misc]
100+
101+
# Multiple invalid fields
102+
Article.objects.bulk_update(articles, ["nonexistent1", "nonexistent2"]) # E: Article has no field named 'nonexistent1' [misc] # E: Article has no field named 'nonexistent2' [misc]
103+
104+
# Primary key with valid fields
105+
Article.objects.bulk_update(articles, ["title", "id", "content"]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
106+
107+
# Literal type variables are validated
108+
invalid_field: Literal["nonexistent"] = "nonexistent"
109+
Article.objects.bulk_update(articles, [invalid_field]) # E: Article has no field named 'nonexistent' [misc]
110+
111+
pk_field: Literal["id"] = "id"
112+
Article.objects.bulk_update(articles, [pk_field]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
113+
114+
# Test with different models
115+
authors = Author.objects.all()
116+
Author.objects.bulk_update(authors, ["id"]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
117+
Author.objects.bulk_update(authors, ["invalid"]) # E: Author has no field named 'invalid' [misc]
118+
119+
categories = Category.objects.all()
120+
Category.objects.bulk_update(categories, ["id"]) # E: bulk_update() cannot be used with primary key fields. Got "id" [misc]
121+
Category.objects.bulk_update(categories, ["invalid"]) # E: Category has no field named 'invalid' [misc]
122+
123+
files:
124+
- path: myapp/__init__.py
125+
- path: myapp/models.py
126+
content: |
127+
from django.db import models
128+
129+
class Category(models.Model):
130+
name = models.CharField(max_length=100)
131+
parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True)
132+
133+
class Author(models.Model):
134+
name = models.CharField(max_length=100)
135+
email = models.EmailField()
136+
137+
class Article(models.Model):
138+
title = models.CharField(max_length=200)
139+
content = models.TextField()
140+
published = models.BooleanField(default=False)
141+
author = models.ForeignKey(Author, on_delete=models.CASCADE)
142+
category = models.ForeignKey(Category, on_delete=models.CASCADE)
143+
tags = models.ManyToManyField('myapp.Tag')
144+
145+
class Tag(models.Model):
146+
name = models.CharField(max_length=50)

0 commit comments

Comments
 (0)