Skip to content

Commit 46ebaca

Browse files
authored
stubgen: Replace obsolete typing aliases with builtin containers (python#16780)
Addresses part of python#16737 This only replaces typing symbols that have equivalents in the `builtins` module. Replacing other symbols, like those from the `collections.abc` module, are a bit more complicated so I suggest we handle them separately. I also changed the default `TypedDict` module from `typing_extensions` to `typing` as typeshed dropped support for Python 3.7.
1 parent eb84794 commit 46ebaca

File tree

7 files changed

+148
-50
lines changed

7 files changed

+148
-50
lines changed

mypy/stubgen.py

+33-18
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
import os.path
4848
import sys
4949
import traceback
50-
from typing import Final, Iterable
50+
from typing import Final, Iterable, Iterator
5151

5252
import mypy.build
5353
import mypy.mixedtraverser
@@ -114,6 +114,7 @@
114114
from mypy.stubdoc import ArgSig, FunctionSig
115115
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
116116
from mypy.stubutil import (
117+
TYPING_BUILTIN_REPLACEMENTS,
117118
BaseStubGenerator,
118119
CantImport,
119120
ClassInfo,
@@ -289,20 +290,19 @@ def visit_call_expr(self, node: CallExpr) -> str:
289290
raise ValueError(f"Unknown argument kind {kind} in call")
290291
return f"{callee}({', '.join(args)})"
291292

293+
def _visit_ref_expr(self, node: NameExpr | MemberExpr) -> str:
294+
fullname = self.stubgen.get_fullname(node)
295+
if fullname in TYPING_BUILTIN_REPLACEMENTS:
296+
return self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=False)
297+
qualname = get_qualified_name(node)
298+
self.stubgen.import_tracker.require_name(qualname)
299+
return qualname
300+
292301
def visit_name_expr(self, node: NameExpr) -> str:
293-
self.stubgen.import_tracker.require_name(node.name)
294-
return node.name
302+
return self._visit_ref_expr(node)
295303

296304
def visit_member_expr(self, o: MemberExpr) -> str:
297-
node: Expression = o
298-
trailer = ""
299-
while isinstance(node, MemberExpr):
300-
trailer = "." + node.name + trailer
301-
node = node.expr
302-
if not isinstance(node, NameExpr):
303-
return ERROR_MARKER
304-
self.stubgen.import_tracker.require_name(node.name)
305-
return node.name + trailer
305+
return self._visit_ref_expr(o)
306306

307307
def visit_str_expr(self, node: StrExpr) -> str:
308308
return repr(node.value)
@@ -351,11 +351,17 @@ def find_defined_names(file: MypyFile) -> set[str]:
351351
return finder.names
352352

353353

354+
def get_assigned_names(lvalues: Iterable[Expression]) -> Iterator[str]:
355+
for lvalue in lvalues:
356+
if isinstance(lvalue, NameExpr):
357+
yield lvalue.name
358+
elif isinstance(lvalue, TupleExpr):
359+
yield from get_assigned_names(lvalue.items)
360+
361+
354362
class DefinitionFinder(mypy.traverser.TraverserVisitor):
355363
"""Find names of things defined at the top level of a module."""
356364

357-
# TODO: Assignment statements etc.
358-
359365
def __init__(self) -> None:
360366
# Short names of things defined at the top level.
361367
self.names: set[str] = set()
@@ -368,6 +374,10 @@ def visit_func_def(self, o: FuncDef) -> None:
368374
# Don't recurse, as we only keep track of top-level definitions.
369375
self.names.add(o.name)
370376

377+
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
378+
for name in get_assigned_names(o.lvalues):
379+
self.names.add(name)
380+
371381

372382
def find_referenced_names(file: MypyFile) -> set[str]:
373383
finder = ReferenceFinder()
@@ -1023,10 +1033,15 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
10231033
and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
10241034
or isinstance(expr.node, TypeInfo)
10251035
) and not self.is_private_member(expr.node.fullname)
1026-
elif (
1027-
isinstance(expr, IndexExpr)
1028-
and isinstance(expr.base, NameExpr)
1029-
and not self.is_private_name(expr.base.name)
1036+
elif isinstance(expr, IndexExpr) and (
1037+
(isinstance(expr.base, NameExpr) and not self.is_private_name(expr.base.name))
1038+
or ( # Also some known aliases that could be member expression
1039+
isinstance(expr.base, MemberExpr)
1040+
and not self.is_private_member(get_qualified_name(expr.base))
1041+
and self.get_fullname(expr.base).startswith(
1042+
("builtins.", "typing.", "typing_extensions.", "collections.abc.")
1043+
)
1044+
)
10301045
):
10311046
if isinstance(expr.index, TupleExpr):
10321047
indices = expr.index.items

mypy/stubutil.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@
2222
# Modules that may fail when imported, or that may have side effects (fully qualified).
2323
NOT_IMPORTABLE_MODULES = ()
2424

25+
# Typing constructs to be replaced by their builtin equivalents.
26+
TYPING_BUILTIN_REPLACEMENTS: Final = {
27+
# From typing
28+
"typing.Text": "builtins.str",
29+
"typing.Tuple": "builtins.tuple",
30+
"typing.List": "builtins.list",
31+
"typing.Dict": "builtins.dict",
32+
"typing.Set": "builtins.set",
33+
"typing.FrozenSet": "builtins.frozenset",
34+
"typing.Type": "builtins.type",
35+
# From typing_extensions
36+
"typing_extensions.Text": "builtins.str",
37+
"typing_extensions.Tuple": "builtins.tuple",
38+
"typing_extensions.List": "builtins.list",
39+
"typing_extensions.Dict": "builtins.dict",
40+
"typing_extensions.Set": "builtins.set",
41+
"typing_extensions.FrozenSet": "builtins.frozenset",
42+
"typing_extensions.Type": "builtins.type",
43+
}
44+
2545

2646
class CantImport(Exception):
2747
def __init__(self, module: str, message: str) -> None:
@@ -229,6 +249,8 @@ def visit_unbound_type(self, t: UnboundType) -> str:
229249
return " | ".join([item.accept(self) for item in t.args])
230250
if fullname == "typing.Optional":
231251
return f"{t.args[0].accept(self)} | None"
252+
if fullname in TYPING_BUILTIN_REPLACEMENTS:
253+
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
232254
if self.known_modules is not None and "." in s:
233255
# see if this object is from any of the modules that we're currently processing.
234256
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
@@ -476,7 +498,7 @@ def reexport(self, name: str) -> None:
476498
def import_lines(self) -> list[str]:
477499
"""The list of required import lines (as strings with python code).
478500
479-
In order for a module be included in this output, an indentifier must be both
501+
In order for a module be included in this output, an identifier must be both
480502
'required' via require_name() and 'imported' via add_import_from()
481503
or add_import()
482504
"""
@@ -585,9 +607,9 @@ def __init__(
585607
# a corresponding import statement.
586608
self.known_imports = {
587609
"_typeshed": ["Incomplete"],
588-
"typing": ["Any", "TypeVar", "NamedTuple"],
610+
"typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
589611
"collections.abc": ["Generator"],
590-
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
612+
"typing_extensions": ["ParamSpec", "TypeVarTuple"],
591613
}
592614

593615
def get_sig_generators(self) -> list[SignatureGenerator]:
@@ -613,7 +635,10 @@ def add_name(self, fullname: str, require: bool = True) -> str:
613635
"""
614636
module, name = fullname.rsplit(".", 1)
615637
alias = "_" + name if name in self.defined_names else None
616-
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
638+
while alias in self.defined_names:
639+
alias = "_" + alias
640+
if module != "builtins" or alias: # don't import from builtins unless needed
641+
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
617642
return alias or name
618643

619644
def add_import_line(self, line: str) -> None:

test-data/pybind11_fixtures/expected_stubs_no_docs/pybind11_fixtures/__init__.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from . import demo as demo
3-
from typing import List, Tuple, overload
3+
from typing import overload
44

55
class StaticMethods:
66
def __init__(self, *args, **kwargs) -> None: ...
@@ -22,6 +22,6 @@ class TestStruct:
2222

2323
def func_incomplete_signature(*args, **kwargs): ...
2424
def func_returning_optional() -> int | None: ...
25-
def func_returning_pair() -> Tuple[int, float]: ...
25+
def func_returning_pair() -> tuple[int, float]: ...
2626
def func_returning_path() -> os.PathLike: ...
27-
def func_returning_vector() -> List[float]: ...
27+
def func_returning_vector() -> list[float]: ...

test-data/pybind11_fixtures/expected_stubs_no_docs/pybind11_fixtures/demo.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import ClassVar, List, overload
1+
from typing import ClassVar, overload
22

33
PI: float
44
__version__: str
@@ -47,7 +47,7 @@ class Point:
4747
def __init__(self) -> None: ...
4848
@overload
4949
def __init__(self, x: float, y: float) -> None: ...
50-
def as_list(self) -> List[float]: ...
50+
def as_list(self) -> list[float]: ...
5151
@overload
5252
def distance_to(self, x: float, y: float) -> float: ...
5353
@overload

test-data/pybind11_fixtures/expected_stubs_with_docs/pybind11_fixtures/__init__.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from . import demo as demo
3-
from typing import List, Tuple, overload
3+
from typing import overload
44

55
class StaticMethods:
66
def __init__(self, *args, **kwargs) -> None:
@@ -44,9 +44,9 @@ def func_incomplete_signature(*args, **kwargs):
4444
"""func_incomplete_signature() -> dummy_sub_namespace::HasNoBinding"""
4545
def func_returning_optional() -> int | None:
4646
"""func_returning_optional() -> Optional[int]"""
47-
def func_returning_pair() -> Tuple[int, float]:
47+
def func_returning_pair() -> tuple[int, float]:
4848
"""func_returning_pair() -> Tuple[int, float]"""
4949
def func_returning_path() -> os.PathLike:
5050
"""func_returning_path() -> os.PathLike"""
51-
def func_returning_vector() -> List[float]:
51+
def func_returning_vector() -> list[float]:
5252
"""func_returning_vector() -> List[float]"""

test-data/pybind11_fixtures/expected_stubs_with_docs/pybind11_fixtures/demo.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import ClassVar, List, overload
1+
from typing import ClassVar, overload
22

33
PI: float
44
__version__: str
@@ -73,7 +73,7 @@ class Point:
7373
7474
2. __init__(self: pybind11_fixtures.demo.Point, x: float, y: float) -> None
7575
"""
76-
def as_list(self) -> List[float]:
76+
def as_list(self) -> list[float]:
7777
"""as_list(self: pybind11_fixtures.demo.Point) -> List[float]"""
7878
@overload
7979
def distance_to(self, x: float, y: float) -> float:

0 commit comments

Comments
 (0)