Skip to content

Commit 03ec5cf

Browse files
authored
Merge pull request #110 from honno/try-inspect-sig
Revamped signature tests
2 parents a2f7bd5 + b172d84 commit 03ec5cf

File tree

4 files changed

+368
-341
lines changed

4 files changed

+368
-341
lines changed

array_api_tests/dtype_helpers.py

+39-62
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import re
12
from collections.abc import Mapping
23
from functools import lru_cache
3-
from typing import Any, NamedTuple, Sequence, Tuple, Union
4+
from inspect import signature
5+
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
46
from warnings import warn
57

68
from . import _array_module as xp
79
from ._array_module import _UndefinedStub
10+
from .stubs import name_to_func
811
from .typing import DataType, ScalarType
912

1013
__all__ = [
@@ -242,67 +245,31 @@ def result_type(*dtypes: DataType):
242245
return result
243246

244247

245-
func_in_dtypes = {
246-
# elementwise
247-
"abs": numeric_dtypes,
248-
"acos": float_dtypes,
249-
"acosh": float_dtypes,
250-
"add": numeric_dtypes,
251-
"asin": float_dtypes,
252-
"asinh": float_dtypes,
253-
"atan": float_dtypes,
254-
"atan2": float_dtypes,
255-
"atanh": float_dtypes,
256-
"bitwise_and": bool_and_all_int_dtypes,
257-
"bitwise_invert": bool_and_all_int_dtypes,
258-
"bitwise_left_shift": all_int_dtypes,
259-
"bitwise_or": bool_and_all_int_dtypes,
260-
"bitwise_right_shift": all_int_dtypes,
261-
"bitwise_xor": bool_and_all_int_dtypes,
262-
"ceil": numeric_dtypes,
263-
"cos": float_dtypes,
264-
"cosh": float_dtypes,
265-
"divide": float_dtypes,
266-
"equal": all_dtypes,
267-
"exp": float_dtypes,
268-
"expm1": float_dtypes,
269-
"floor": numeric_dtypes,
270-
"floor_divide": numeric_dtypes,
271-
"greater": numeric_dtypes,
272-
"greater_equal": numeric_dtypes,
273-
"isfinite": numeric_dtypes,
274-
"isinf": numeric_dtypes,
275-
"isnan": numeric_dtypes,
276-
"less": numeric_dtypes,
277-
"less_equal": numeric_dtypes,
278-
"log": float_dtypes,
279-
"logaddexp": float_dtypes,
280-
"log10": float_dtypes,
281-
"log1p": float_dtypes,
282-
"log2": float_dtypes,
283-
"logical_and": (xp.bool,),
284-
"logical_not": (xp.bool,),
285-
"logical_or": (xp.bool,),
286-
"logical_xor": (xp.bool,),
287-
"multiply": numeric_dtypes,
288-
"negative": numeric_dtypes,
289-
"not_equal": all_dtypes,
290-
"positive": numeric_dtypes,
291-
"pow": numeric_dtypes,
292-
"remainder": numeric_dtypes,
293-
"round": numeric_dtypes,
294-
"sign": numeric_dtypes,
295-
"sin": float_dtypes,
296-
"sinh": float_dtypes,
297-
"sqrt": float_dtypes,
298-
"square": numeric_dtypes,
299-
"subtract": numeric_dtypes,
300-
"tan": float_dtypes,
301-
"tanh": float_dtypes,
302-
"trunc": numeric_dtypes,
303-
# searching
304-
"where": all_dtypes,
248+
r_alias = re.compile("[aA]lias")
249+
r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.")
250+
r_int_note = re.compile(
251+
"If one or both of the input arrays have integer data types, "
252+
"the result is implementation-dependent"
253+
)
254+
category_to_dtypes = {
255+
"boolean": (xp.bool,),
256+
"integer": all_int_dtypes,
257+
"floating-point": float_dtypes,
258+
"numeric": numeric_dtypes,
259+
"integer or boolean": bool_and_all_int_dtypes,
305260
}
261+
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
262+
for name, func in name_to_func.items():
263+
if m := r_in_dtypes.search(func.__doc__):
264+
dtype_category = m.group(1)
265+
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
266+
dtype_category = "floating-point"
267+
dtypes = category_to_dtypes[dtype_category]
268+
func_in_dtypes[name] = dtypes
269+
elif any("x" in name for name in signature(func).parameters.keys()):
270+
func_in_dtypes[name] = all_dtypes
271+
# See https://github.com/data-apis/array-api/pull/413
272+
func_in_dtypes["expm1"] = float_dtypes
306273

307274

308275
func_returns_bool = {
@@ -365,6 +332,8 @@ def result_type(*dtypes: DataType):
365332
"trunc": False,
366333
# searching
367334
"where": False,
335+
# linalg
336+
"matmul": False,
368337
}
369338

370339

@@ -408,7 +377,7 @@ def result_type(*dtypes: DataType):
408377
"__gt__": "greater",
409378
"__le__": "less_equal",
410379
"__lt__": "less",
411-
# '__matmul__': 'matmul', # TODO: support matmul
380+
"__matmul__": "matmul",
412381
"__mod__": "remainder",
413382
"__mul__": "multiply",
414383
"__ne__": "not_equal",
@@ -440,6 +409,14 @@ def result_type(*dtypes: DataType):
440409
func_returns_bool[iop] = func_returns_bool[op]
441410

442411

412+
func_in_dtypes["__bool__"] = (xp.bool,)
413+
func_in_dtypes["__int__"] = all_int_dtypes
414+
func_in_dtypes["__index__"] = all_int_dtypes
415+
func_in_dtypes["__float__"] = float_dtypes
416+
func_in_dtypes["from_dlpack"] = numeric_dtypes
417+
func_in_dtypes["__dlpack__"] = numeric_dtypes
418+
419+
443420
@lru_cache
444421
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
445422
f_types = []
+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from inspect import Parameter, Signature, signature
2+
3+
import pytest
4+
5+
from ..test_signatures import _test_inspectable_func
6+
7+
8+
def stub(foo, /, bar=None, *, baz=None):
9+
pass
10+
11+
12+
stub_sig = signature(stub)
13+
14+
15+
@pytest.mark.parametrize(
16+
"sig",
17+
[
18+
Signature(
19+
[
20+
Parameter("foo", Parameter.POSITIONAL_ONLY),
21+
Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
22+
Parameter("baz", Parameter.KEYWORD_ONLY),
23+
]
24+
),
25+
Signature(
26+
[
27+
Parameter("foo", Parameter.POSITIONAL_ONLY),
28+
Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
29+
Parameter("baz", Parameter.POSITIONAL_OR_KEYWORD),
30+
]
31+
),
32+
Signature(
33+
[
34+
Parameter("foo", Parameter.POSITIONAL_ONLY),
35+
Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
36+
Parameter("qux", Parameter.KEYWORD_ONLY),
37+
Parameter("baz", Parameter.KEYWORD_ONLY),
38+
]
39+
),
40+
],
41+
)
42+
def test_good_sig_passes(sig):
43+
_test_inspectable_func(sig, stub_sig)
44+
45+
46+
@pytest.mark.parametrize(
47+
"sig",
48+
[
49+
Signature(
50+
[
51+
Parameter("foo", Parameter.POSITIONAL_ONLY),
52+
Parameter("bar", Parameter.POSITIONAL_ONLY),
53+
Parameter("baz", Parameter.KEYWORD_ONLY),
54+
]
55+
),
56+
Signature(
57+
[
58+
Parameter("foo", Parameter.POSITIONAL_ONLY),
59+
Parameter("bar", Parameter.KEYWORD_ONLY),
60+
Parameter("baz", Parameter.KEYWORD_ONLY),
61+
]
62+
),
63+
],
64+
)
65+
def test_raises_on_bad_sig(sig):
66+
with pytest.raises(AssertionError):
67+
_test_inspectable_func(sig, stub_sig)

array_api_tests/stubs.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,29 @@
4040
if name.endswith("_functions"):
4141
category = name.replace("_functions", "")
4242
objects = [getattr(mod, name) for name in mod.__all__]
43-
assert all(isinstance(o, FunctionType) for o in objects)
43+
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
4444
category_to_funcs[category] = objects
4545

46+
all_funcs = []
47+
for funcs in [array_methods, *category_to_funcs.values()]:
48+
all_funcs.extend(funcs)
49+
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
50+
4651
EXTENSIONS: str = ["linalg"]
4752
extension_to_funcs: Dict[str, List[FunctionType]] = {}
4853
for ext in EXTENSIONS:
4954
mod = name_to_mod[ext]
5055
objects = [getattr(mod, name) for name in mod.__all__]
51-
assert all(isinstance(o, FunctionType) for o in objects)
52-
extension_to_funcs[ext] = objects
56+
assert all(isinstance(o, FunctionType) for o in objects) # sanity check
57+
funcs = []
58+
for func in objects:
59+
if "Alias" in func.__doc__:
60+
funcs.append(name_to_func[func.__name__])
61+
else:
62+
funcs.append(func)
63+
extension_to_funcs[ext] = funcs
5364

54-
all_funcs = []
55-
for funcs in [array_methods, *category_to_funcs.values(), *extension_to_funcs.values()]:
56-
all_funcs.extend(funcs)
57-
name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
65+
for funcs in extension_to_funcs.values():
66+
for func in funcs:
67+
if func.__name__ not in name_to_func.keys():
68+
name_to_func[func.__name__] = func

0 commit comments

Comments
 (0)