Skip to content

Commit 2453222

Browse files
committed
Ensure reconstructed Enums are idompotent
Previously, mypy would infer that `Foo(Foo.x)` is of type `Foo`. This is problematic as it means Enums can not be reconstructed for runtime saftey. Hence, this commit extends the enum plugin and refines the return type of `Enum.__new__`, ensuring reconstructed Enums are idempotent Fixes #19669
1 parent 8bfecd4 commit 2453222

File tree

3 files changed

+483
-2
lines changed

3 files changed

+483
-2
lines changed

mypy/plugins/default.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import mypy.errorcodes as codes
77
from mypy import message_registry
8-
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
8+
from mypy.nodes import DictExpr, IntExpr, StrExpr, TypeInfo, UnaryExpr
99
from mypy.plugin import (
1010
AttributeContext,
1111
ClassDefContext,
@@ -47,7 +47,12 @@
4747
dataclass_tag_callback,
4848
replace_function_sig_callback,
4949
)
50-
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
50+
from mypy.plugins.enums import (
51+
enum_member_callback,
52+
enum_name_callback,
53+
enum_new_callback,
54+
enum_value_callback,
55+
)
5156
from mypy.plugins.functools import (
5257
functools_total_ordering_maker_callback,
5358
functools_total_ordering_makers,
@@ -104,6 +109,12 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
104109
return partial_new_callback
105110
elif fullname == "enum.member":
106111
return enum_member_callback
112+
elif (
113+
(st := self.lookup_fully_qualified(fullname))
114+
and isinstance(st.node, TypeInfo)
115+
and getattr(st.node, "is_enum", False)
116+
):
117+
return enum_new_callback
107118
return None
108119

109120
def get_function_signature_hook(

mypy/plugins/enums.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
LiteralType,
2929
ProperType,
3030
Type,
31+
TypeVarType,
32+
UnionType,
3133
get_proper_type,
3234
is_named_instance,
3335
)
@@ -297,3 +299,62 @@ def _extract_underlying_field_name(typ: Type) -> str | None:
297299
# as a string.
298300
assert isinstance(underlying_literal.value, str)
299301
return underlying_literal.value
302+
303+
304+
def enum_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
305+
"""This plugin refines the return type of `__new__`, ensuring reconstructed
306+
Enums are idempotent.
307+
308+
By default, mypy will infer that `Foo(Foo.x)` is of type `Foo`. This plugin
309+
ensures types are not loosened, meaning with this plugin enabled
310+
`Foo(Foo.x)` is of type `Literal[Foo.x]?`.
311+
312+
This means with this plugin:
313+
```
314+
reveal_type(Foo(Foo.x)) # mypy reveals Literal[Foo.x]?
315+
```
316+
317+
This plugin works by adjusting the return type of `__new__` to be the given
318+
argument type, if and only if `__new__` comes from `enum.Enum`.
319+
320+
This plugin supports arguments that are Final, Literial, Union of Literials
321+
and generic TypeVars.
322+
"""
323+
base_ret = ctx.default_return_type
324+
enum_inst = get_proper_type(base_ret)
325+
if not isinstance(enum_inst, Instance):
326+
return base_ret
327+
328+
info: TypeInfo = enum_inst.type
329+
if not info.is_enum:
330+
return base_ret
331+
332+
if _implements_new(info):
333+
return base_ret
334+
335+
if not ctx.args or not ctx.args[0] or not ctx.arg_types or not ctx.arg_types[0]:
336+
return base_ret
337+
338+
arg0_t = get_proper_type(ctx.arg_types[0][0])
339+
340+
if isinstance(arg0_t, Instance) and arg0_t.type is info:
341+
return arg0_t
342+
elif isinstance(arg0_t, LiteralType) and arg0_t.fallback.type is info:
343+
return arg0_t
344+
elif isinstance(arg0_t, UnionType):
345+
346+
def is_memeber(given_t: ProperType) -> bool:
347+
return (isinstance(given_t, Instance) and given_t.type is info) or (
348+
isinstance(given_t, LiteralType) and given_t.fallback.type is info
349+
)
350+
351+
items = [get_proper_type(it) for it in arg0_t.items]
352+
if items and all(is_memeber(item) for item in items):
353+
return arg0_t
354+
elif (isinstance(arg0_t, TypeVarType)) and isinstance(
355+
get_proper_type(arg0_t.upper_bound), Instance
356+
):
357+
if arg0_t.upper_bound.type is info:
358+
return arg0_t
359+
360+
return base_ret

0 commit comments

Comments
 (0)