Skip to content

Commit 7a06c7b

Browse files
committed
Ensure reconstructed Enums are idompotent
Previously, mypy would infer that `Foo(Foo.x)` is of type `Foo`. This is problematic as it means Enums can not be reconstructed for runtime saftey. Hence, this commit extends the enum plugin and refines the return type of `Enum.__new__`, ensuring reconstructed Enums are idempotent Fixes #19669
1 parent 8bfecd4 commit 7a06c7b

File tree

3 files changed

+482
-2
lines changed

3 files changed

+482
-2
lines changed

mypy/plugins/default.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import mypy.errorcodes as codes
77
from mypy import message_registry
8-
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
8+
from mypy.nodes import DictExpr, IntExpr, StrExpr, TypeInfo, UnaryExpr
99
from mypy.plugin import (
1010
AttributeContext,
1111
ClassDefContext,
@@ -47,7 +47,12 @@
4747
dataclass_tag_callback,
4848
replace_function_sig_callback,
4949
)
50-
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
50+
from mypy.plugins.enums import (
51+
enum_member_callback,
52+
enum_name_callback,
53+
enum_new_callback,
54+
enum_value_callback,
55+
)
5156
from mypy.plugins.functools import (
5257
functools_total_ordering_maker_callback,
5358
functools_total_ordering_makers,
@@ -104,6 +109,12 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
104109
return partial_new_callback
105110
elif fullname == "enum.member":
106111
return enum_member_callback
112+
elif (
113+
(st := self.lookup_fully_qualified(fullname))
114+
and isinstance(st.node, TypeInfo)
115+
and getattr(st.node, "is_enum", False)
116+
):
117+
return enum_new_callback
107118
return None
108119

109120
def get_function_signature_hook(

mypy/plugins/enums.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import mypy.plugin # To avoid circular imports.
2020
from mypy.checker_shared import TypeCheckerSharedApi
2121
from mypy.nodes import TypeInfo, Var
22+
from mypy.plugin import FunctionContext # depending on where this runs
2223
from mypy.subtypes import is_equivalent
2324
from mypy.typeops import fixup_partial_type, make_simplified_union
2425
from mypy.types import (
@@ -28,6 +29,8 @@
2829
LiteralType,
2930
ProperType,
3031
Type,
32+
TypeVarType,
33+
UnionType,
3134
get_proper_type,
3235
is_named_instance,
3336
)
@@ -297,3 +300,60 @@ def _extract_underlying_field_name(typ: Type) -> str | None:
297300
# as a string.
298301
assert isinstance(underlying_literal.value, str)
299302
return underlying_literal.value
303+
304+
305+
def enum_new_callback(ctx: FunctionContext) -> Type:
306+
"""This plugin refines the return type of `__new__`, ensuring reconstructed
307+
Enums are idempotent.
308+
309+
By default, mypy will infer that `Foo(Foo.x)` is of type `Foo`. This plugin
310+
ensures types are not loosened, meaning with this plugin enabled
311+
`Foo(Foo.x)` is of type `Literal[Foo.x]?`.
312+
313+
This means with this plugin:
314+
```
315+
reveal_type(Foo(Foo.x)) # mypy reveals Literal[Foo.x]?
316+
```
317+
318+
This plugin works by adjusting the return type of `__new__` to be the given
319+
argument type, if and only if `__new__` comes from `enum.Enum`.
320+
321+
This plugin supports arguments that are Final, Literial, Union of Literials
322+
and generic TypeVars.
323+
"""
324+
base_ret = ctx.default_return_type
325+
enum_inst = get_proper_type(base_ret)
326+
if not isinstance(enum_inst, Instance):
327+
return base_ret
328+
329+
info: TypeInfo = enum_inst.type
330+
if not info.is_enum:
331+
return base_ret
332+
333+
if _implements_new(info):
334+
return base_ret
335+
336+
if not ctx.args or not ctx.args[0] or not ctx.arg_types or not ctx.arg_types[0]:
337+
return base_ret
338+
339+
arg0_t = get_proper_type(ctx.arg_types[0][0])
340+
341+
if isinstance(arg0_t, Instance) and arg0_t.type is info:
342+
return arg0_t
343+
elif isinstance(arg0_t, LiteralType) and arg0_t.fallback.type is info:
344+
return arg0_t
345+
elif isinstance(arg0_t, UnionType):
346+
347+
def is_memeber(given_t: Type) -> bool:
348+
return (isinstance(given_t, Instance) and given_t.type is info) or (
349+
isinstance(given_t, LiteralType) and given_t.fallback.type is info
350+
)
351+
352+
items = [get_proper_type(it) for it in arg0_t.items]
353+
if items and all(is_memeber(item) for item in items):
354+
return arg0_t
355+
elif (isinstance(arg0_t, TypeVarType)) and isinstance(arg0_t.upper_bound, Instance):
356+
if arg0_t.upper_bound.type is info:
357+
return arg0_t
358+
359+
return base_ret

0 commit comments

Comments
 (0)