Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions build/py3_8/float_int_stubs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from __future__ import annotations

import ast
import os
from collections.abc import Iterable
from pathlib import Path
from typing import Final

from typing_extensions import override

KEEP_FLOAT: Final[frozenset[str]] = frozenset((
# Example:
# If we didn't want to change the type of the `priority` parameter
# of the `register` function in the `Registry` class of markdown/util.pyi
# "stubs/Markdown/markdown/util.pyi/Registry.register.priority",
# See implementation `_node_path` for details on this identifier string.
))
""" identifiers for `float` that we don't want to change to `float | int` """


def name_for_target(node: ast.AnnAssign) -> str:
return (
node.target.id
if isinstance(node.target, ast.Name)
else node.target.attr
if isinstance(node.target, ast.Attribute)
else "subscript"
)


def name_for_node(
node: ast.ClassDef
| ast.FunctionDef
| ast.AsyncFunctionDef
| ast.arg
| ast.Name
| ast.AnnAssign,
) -> str:
return (
node.name
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
else node.arg
if isinstance(node, ast.arg)
else name_for_target(node)
if isinstance(node, ast.AnnAssign)
else node.id
)


def has_int_child(node: ast.BinOp) -> bool:
if isinstance(node.right, ast.Name) and node.right.id == "int":
return True
if isinstance(node.left, ast.Name) and node.left.id == "int":
return True
if isinstance(node.right, ast.BinOp):
# assuming "|" is the only BinOp in annotations
assert isinstance(node.right.op, ast.BitOr), node.right.op
if has_int_child(node.right):
return True
if isinstance(node.left, ast.BinOp):
assert isinstance(node.left.op, ast.BitOr), node.left.op
if has_int_child(node.left):
return True
return False


class AnnotationTrackingVisitor(ast.NodeVisitor):
parent_stack: list[ast.AST]
in_ann: str | None = None
floats: list[ast.Name]
module: str

def __init__(self, module: str) -> None:
self.parent_stack = []
self.floats = []
self.module = module

@override
def visit(self, node: ast.AST) -> None:
self.parent_stack.append(node)
super().visit(node)
_ = self.parent_stack.pop()

@override
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
assert len(list(ast.iter_fields(node))) == 4, list(ast.iter_fields(node))
# I don't know what the 4th field "simple" is, but it's not an AST.

self.visit(node.target)

self.in_ann = name_for_target(node)
self.visit(node.annotation)
self.in_ann = None

if node.value:
self.visit(node.value)

@override
def visit_arg(self, node: ast.arg) -> None:
# arg name str, annotation, type comment str
assert len(list(ast.iter_fields(node))) == 3, list(ast.iter_fields(node))

self.in_ann = node.arg
if node.annotation:
self.visit(node.annotation)
self.in_ann = None

# NOTE: assuming function return values are actually float if annotated as such.
# If we don't want to assume that, rename this to `visit_FunctionDef`:
# (probably would also want `visit_AsyncFunctionDef`)

# @override
def _unused(self, node: ast.FunctionDef) -> None:
# Copied from implementation of base generic_visit
# and modified for "returns"
for field, value in ast.iter_fields(node): # pyright: ignore[reportAny]
if isinstance(value, ast.AST):
if field == "returns":
self.in_ann = "returns"
self.visit(value)
if field == "returns":
self.in_ann = None
elif isinstance(value, Iterable):
for item in value:
if isinstance(item, ast.AST):
self.visit(item)

def _node_path(self) -> str:
"""a string that identifies the current node (from `self.parent_stack`)"""
strs = [
name_for_node(n)
for n in self.parent_stack
if isinstance(
n,
(
ast.ClassDef,
ast.FunctionDef,
ast.AsyncFunctionDef,
ast.arg,
ast.Name,
ast.AnnAssign,
),
)
]
if len(strs) > 0 and strs[-1] == "float":
_ = strs.pop()
return self.module + "/" + ".".join(strs)

def _with_int(self) -> bool:
"""`float | int` already"""
assert isinstance(self.parent_stack[-1], ast.Name), self.parent_stack
assert self.parent_stack[-1].id == "float", self.parent_stack[-1].id
index = len(self.parent_stack) - 2
while index >= 0:
traverse_node = self.parent_stack[index]
if not isinstance(traverse_node, ast.BinOp):
return False
# assuming "|" is the only BinOp in annotations
assert isinstance(traverse_node.op, ast.BitOr), traverse_node.op
if has_int_child(traverse_node):
return True
index -= 1
return False

def _is_final(self) -> bool:
assert isinstance(self.parent_stack[-1], ast.Name), self.parent_stack
assert self.parent_stack[-1].id == "float", self.parent_stack
if len(self.parent_stack) > 1:
parent = self.parent_stack[-2]
if (
isinstance(parent, ast.Subscript)
and isinstance(parent.value, ast.Name)
and parent.value.id == "Final"
):
assert parent.slice is self.parent_stack[-1], parent.slice
return True
return False

@override
def generic_visit(self, node: ast.AST) -> None:
if self.in_ann is not None and isinstance(node, ast.Name) and node.id == "float":
assert node is self.parent_stack[-1], self.parent_stack

if self._with_int() or self._is_final():
# There's already some already float | int in typeshed
# and assuming `Final[float]` is really float
pass
else:
node_path = self._node_path()
if node_path not in KEEP_FLOAT:
self.floats.append(node)
super().generic_visit(node)


def float_expand(stubs: Path) -> None:
"""change stubs in the given directory from `float` to `float | int`"""
for dir_path, _dir_names, file_names in os.walk(stubs):
for file_name in file_names:
if not file_name.endswith(".pyi"):
continue
file_path = Path(dir_path) / file_name
rel_path = Path(os.path.relpath(file_path, stubs)).as_posix()
file_bytes = Path(file_path).read_bytes()
file_parsed = ast.parse(file_bytes)
v = AnnotationTrackingVisitor(rel_path)
v.visit(file_parsed)

if len(v.floats) > 0:
print(file_path.as_posix())
# compute start offset of each line in file
lines = file_bytes.split(b"\n")
line_starts = [0]
for line in lines:
line_starts.append(line_starts[-1] + len(line) + 1) # +1 for newline

# process in reverse order to avoid offset changes affecting subsequent replacements
v.floats.sort(key=lambda n: (n.lineno, n.col_offset), reverse=True)
for fl in v.floats:
assert fl.end_lineno == fl.lineno, fl # always within 1 line
assert fl.end_col_offset is not None
assert fl.end_col_offset == fl.col_offset + 5 # always "float" 5 chars

# calculate offsets in file (from offsets in line)
line_start = line_starts[fl.lineno - 1]
start_offset = line_start + fl.col_offset
end_offset = line_start + fl.end_col_offset

assert file_bytes[start_offset:end_offset] == b"float", file_bytes[
start_offset:end_offset
]

file_bytes = (
file_bytes[:start_offset] + b"float | int" + file_bytes[end_offset:]
)

_ = Path(file_path).write_bytes(file_bytes)


if __name__ == "__main__":
stubs_with_docs_path = Path("docstubs")
float_expand(stubs_with_docs_path)
Loading