forked from data-apis/array-api-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstubs.py
72 lines (62 loc) · 2.36 KB
/
stubs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import inspect
import sys
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from types import FunctionType, ModuleType
from typing import Dict, List
# Names exported by a star-import of this module; other module-level names
# (name_to_mod, array, all_funcs, spec_dir, ...) are not star-exported.
__all__ = [
    "name_to_func", "array_methods", "array_attributes",
    "category_to_funcs", "EXTENSIONS", "extension_to_funcs",
]
# Locate the array API specification checked out (as a git submodule) two
# directories above this file, and make its ``signatures`` package importable.
spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification"
# Explicit raises rather than ``assert``: assertions are stripped under
# ``python -O``, after which a missing checkout would only surface later as a
# confusing KeyError on name_to_mod.
if not spec_dir.exists():
    raise FileNotFoundError(f"{spec_dir} not found - try `git submodule update --init`")
sigs_dir = spec_dir / "signatures"
if not sigs_dir.exists():
    raise FileNotFoundError(f"{sigs_dir} not found")
spec_abs_path: str = str(spec_dir.resolve())
# Avoid piling up duplicate sys.path entries if this module is executed again
# (e.g. via importlib.reload).
if spec_abs_path not in sys.path:
    sys.path.append(spec_abs_path)
# NOTE(review): the path is appended, so a different top-level package named
# "signatures" earlier on sys.path would win; confirm this cannot happen.
if find_spec("signatures") is None:
    raise ImportError(f"cannot find package 'signatures' under {spec_abs_path}")

# Import every stub module in signatures/, keyed by file stem
# (e.g. "array_object"). Sorted so that iteration order -- and therefore the
# order of every table later derived from name_to_mod -- is deterministic
# instead of depending on the filesystem's directory listing order.
name_to_mod: Dict[str, ModuleType] = {}
for path in sorted(sigs_dir.glob("*.py")):
    name = path.stem
    name_to_mod[name] = import_module(f"signatures.{name}")
# The stub ``array`` class defined in the array_object signatures module.
array = name_to_mod["array_object"].array

# Function members of the array stub (as returned by inspect.getmembers, i.e.
# sorted by name), skipping ``__init__``, which probably exists for Sphinx.
array_methods = [
    member
    for member_name, member in inspect.getmembers(array, inspect.isfunction)
    if member_name != "__init__"
]

# Names of the array stub's members that are ``property`` objects.
array_attributes = [
    member_name
    for member_name, _member in inspect.getmembers(
        array, lambda obj: isinstance(obj, property)
    )
]
# Map each "<category>_functions" stub module to the list of functions it
# exports via ``__all__``, keyed by "<category>".
category_to_funcs: Dict[str, List[FunctionType]] = {}
for name, mod in name_to_mod.items():
    if not name.endswith("_functions"):
        continue
    category = name.replace("_functions", "")
    objects = [getattr(mod, attr) for attr in mod.__all__]
    # Sanity check: every exported member should be a plain Python function.
    assert all(isinstance(o, FunctionType) for o in objects)
    category_to_funcs[category] = objects
# Flat list of the array methods followed by each category's functions, in
# category_to_funcs iteration order.
all_funcs = []
for funcs in (array_methods, *category_to_funcs.values()):
    all_funcs += funcs
# Lookup by ``__name__``; if two entries share a name, the later one wins.
name_to_func: Dict[str, FunctionType] = {func.__name__: func for func in all_funcs}
# Extension stub modules (looked up in name_to_mod by this name) whose
# exported functions are collected into extension_to_funcs.
EXTENSIONS: List[str] = ["linalg"]  # TODO: add "fft" once stubs available
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
    mod = name_to_mod[ext]
    objects = [getattr(mod, attr) for attr in mod.__all__]
    assert all(isinstance(o, FunctionType) for o in objects)  # sanity check
    funcs = []
    for func in objects:
        # A docstring containing "Alias" presumably marks an alias of a
        # function already in name_to_func; reuse that existing object so both
        # names refer to one stub (a missing target raises KeyError).
        # ``__doc__`` is None for a stub without a docstring and for every
        # function under ``python -OO``; treat that as "not an alias" instead
        # of raising TypeError from ``"Alias" in None``.
        if "Alias" in (func.__doc__ or ""):
            funcs.append(name_to_func[func.__name__])
        else:
            funcs.append(func)
    extension_to_funcs[ext] = funcs
# Add every extension function to name_to_func under its ``__name__``,
# leaving any entry that is already present untouched (first one wins).
for funcs in extension_to_funcs.values():
    for func in funcs:
        name_to_func.setdefault(func.__name__, func)