Skip to content

Commit 37be0aa

Browse files
authored
Merge pull request #119 from honno/special-case-stat
Test special cases in statistical functions
2 parents 03ec5cf + e933be9 commit 37be0aa

File tree

3 files changed

+125
-96
lines changed

3 files changed

+125
-96
lines changed

Diff for: array_api_tests/pytest_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def assert_0d_equals(
198198
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
199199
):
200200
msg = (
201-
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
201+
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
202202
f"[{func_name}({fmt_kw(kw)})]"
203203
)
204204
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):

Diff for: array_api_tests/test_special_cases.py

+123-94
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
Tests for special cases.
3+
4+
Most test cases for special casing are built on runtime via the parametrized
5+
tests test_unary/test_binary/test_iop. Most of this file consists of utility
6+
classes and functions, all bought together to create the test cases (pytest
7+
params), to finally be run through generalised test logic.
8+
9+
TODO: test integer arrays for relevant special cases
10+
"""
111
# We use __future__ for forward reference type hints - this will work for even py3.8.0
212
# See https://stackoverflow.com/a/33533514/5193926
313
from __future__ import annotations
@@ -32,13 +42,6 @@
3242

3343
pytestmark = pytest.mark.ci
3444

35-
# The special case test casess are built on runtime via the parametrized
36-
# test_unary and test_binary functions. Most of this file consists of utility
37-
# classes and functions, all bought together to create the test cases (pytest
38-
# params), to finally be run through the general test logic of either test_unary
39-
# or test_binary.
40-
41-
4245
UnaryCheck = Callable[[float], bool]
4346
BinaryCheck = Callable[[float, float], bool]
4447

@@ -170,24 +173,6 @@ def parse_value(value_str: str) -> float:
170173
r_approx_value = re.compile(
171174
rf"an implementation-dependent approximation to {r_code.pattern}"
172175
)
173-
174-
175-
def parse_inline_code(inline_code: str) -> float:
176-
"""
177-
Parses a Sphinx code string to return a float, e.g.
178-
179-
>>> parse_value('``0``')
180-
0.
181-
>>> parse_value('``NaN``')
182-
float('nan')
183-
184-
"""
185-
if m := r_code.match(inline_code):
186-
return parse_value(m.group(1))
187-
else:
188-
raise ParseError(inline_code)
189-
190-
191176
r_not = re.compile("not (.+)")
192177
r_equal_to = re.compile(f"equal to {r_code.pattern}")
193178
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
@@ -526,6 +511,10 @@ def __repr__(self) -> str:
526511
return f"{self.__class__.__name__}(<{self}>)"
527512

528513

514+
r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
515+
r_case = re.compile(r"\s+-\s*(.*)\.")
516+
517+
529518
class UnaryCond(Protocol):
530519
def __call__(self, i: float) -> bool:
531520
...
@@ -546,12 +535,34 @@ class UnaryCase(Case):
546535

547536

548537
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
538+
r_already_int_case = re.compile(
539+
"If ``x_i`` is already integer-valued, the result is ``x_i``"
540+
)
549541
r_even_round_halves_case = re.compile(
550542
"If two integers are equally close to ``x_i``, "
551543
"the result is the even integer closest to ``x_i``"
552544
)
553545

554546

547+
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
548+
"""
549+
Returns a strategy that generates float-casted integers within the bounds of dtype.
550+
"""
551+
for k in kw.keys():
552+
# sanity check
553+
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
554+
m, M = dh.dtype_ranges[dtype]
555+
if "min_value" in kw.keys():
556+
m = kw["min_value"]
557+
if "exclude_min" in kw.keys():
558+
m += 1
559+
if "max_value" in kw.keys():
560+
M = kw["max_value"]
561+
if "exclude_max" in kw.keys():
562+
M -= 1
563+
return st.integers(math.ceil(m), math.floor(M)).map(float)
564+
565+
555566
def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
556567
"""
557568
Returns a strategy that generates floats that end with .5 and are within the
@@ -568,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
568579
)
569580

570581

582+
already_int_case = UnaryCase(
583+
cond_expr="x_i.is_integer()",
584+
cond=lambda i: i.is_integer(),
585+
cond_from_dtype=integers_from_dtype,
586+
result_expr="x_i",
587+
check_result=lambda i, result: i == result,
588+
)
571589
even_round_halves_case = UnaryCase(
572590
cond_expr="modf(i)[0] == 0.5",
573591
cond=lambda i: math.modf(i)[0] == 0.5,
@@ -586,7 +604,7 @@ def check_result(i: float, result: float) -> bool:
586604
return check_result
587605

588606

589-
def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
607+
def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
590608
"""
591609
Parses a Sphinx-formatted docstring of a unary function to return a list of
592610
codified unary cases, e.g.
@@ -616,7 +634,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
616634
... an array containing the square root of each element in ``x``
617635
... '''
618636
...
619-
>>> unary_cases = parse_unary_docstring(sqrt.__doc__)
637+
>>> case_block = r_case_block.search(sqrt.__doc__).group(1)
638+
>>> unary_cases = parse_unary_case_block(case_block)
620639
>>> for case in unary_cases:
621640
... print(repr(case))
622641
UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +650,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
631650
True
632651
633652
"""
634-
635-
match = r_special_cases.search(docstring)
636-
if match is None:
637-
return []
638-
lines = match.group(1).split("\n")[:-1]
639653
cases = []
640-
for line in lines:
641-
if m := r_case.match(line):
642-
case = m.group(1)
643-
else:
644-
warn(f"line not machine-readable: '{line}'")
645-
continue
646-
if m := r_unary_case.search(case):
654+
for case_m in r_case.finditer(case_block):
655+
case_str = case_m.group(1)
656+
if m := r_already_int_case.search(case_str):
657+
cases.append(already_int_case)
658+
elif m := r_even_round_halves_case.search(case_str):
659+
cases.append(even_round_halves_case)
660+
elif m := r_unary_case.search(case_str):
647661
try:
648662
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
649663
_check_result, result_expr = parse_result(m.group(2))
@@ -662,11 +676,9 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
662676
check_result=check_result,
663677
)
664678
cases.append(case)
665-
elif m := r_even_round_halves_case.search(case):
666-
cases.append(even_round_halves_case)
667679
else:
668-
if not r_remaining_case.search(case):
669-
warn(f"case not machine-readable: '{case}'")
680+
if not r_remaining_case.search(case_str):
681+
warn(f"case not machine-readable: '{case_str}'")
670682
return cases
671683

672684

@@ -690,12 +702,6 @@ class BinaryCase(Case):
690702
check_result: BinaryResultCheck
691703

692704

693-
r_special_cases = re.compile(
694-
r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
695-
r"For floating-point operands,\n+"
696-
r"((?:\s*-\s*.*\n)+)"
697-
)
698-
r_case = re.compile(r"\s+-\s*(.*)\.\n?")
699705
r_binary_case = re.compile("If (.+), the result (.+)")
700706
r_remaining_case = re.compile("In the remaining cases.+")
701707
r_cond_sep = re.compile(r"(?<!``x1_i``),? and |(?<!i\.e\.), ")
@@ -843,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
843849
return check_result
844850

845851

846-
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
847-
"""
848-
Returns a strategy that generates float-casted integers within the bounds of dtype.
849-
"""
850-
for k in kw.keys():
851-
# sanity check
852-
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
853-
m, M = dh.dtype_ranges[dtype]
854-
if "min_value" in kw.keys():
855-
m = kw["min_value"]
856-
if "exclude_min" in kw.keys():
857-
m += 1
858-
if "max_value" in kw.keys():
859-
M = kw["max_value"]
860-
if "exclude_max" in kw.keys():
861-
M -= 1
862-
return st.integers(math.ceil(m), math.floor(M)).map(float)
863-
864-
865852
def parse_binary_case(case_str: str) -> BinaryCase:
866853
"""
867854
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
@@ -880,8 +867,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:
880867
881868
"""
882869
case_m = r_binary_case.match(case_str)
883-
if case_m is None:
884-
raise ParseError(case_str)
870+
assert case_m is not None # sanity check
885871
cond_strs = r_cond_sep.split(case_m.group(1))
886872

887873
partial_conds = []
@@ -1078,7 +1064,7 @@ def cond(i1: float, i2: float) -> bool:
10781064
r_redundant_case = re.compile("result.+determined by the rule already stated above")
10791065

10801066

1081-
def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1067+
def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
10821068
"""
10831069
Parses a Sphinx-formatted docstring of a binary function to return a list of
10841070
codified binary cases, e.g.
@@ -1108,29 +1094,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11081094
... an array containing the results
11091095
... '''
11101096
...
1111-
>>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1097+
>>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
1098+
>>> binary_cases = parse_binary_case_block(case_block)
11121099
>>> for case in binary_cases:
11131100
... print(repr(case))
11141101
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
11151102
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
11161103
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
11171104
11181105
"""
1119-
1120-
match = r_special_cases.search(docstring)
1121-
if match is None:
1122-
return []
1123-
lines = match.group(1).split("\n")[:-1]
11241106
cases = []
1125-
for line in lines:
1126-
if m := r_case.match(line):
1127-
case_str = m.group(1)
1128-
else:
1129-
warn(f"line not machine-readable: '{line}'")
1130-
continue
1107+
for case_m in r_case.finditer(case_block):
1108+
case_str = case_m.group(1)
11311109
if r_redundant_case.search(case_str):
11321110
continue
1133-
if m := r_binary_case.match(case_str):
1111+
if r_binary_case.match(case_str):
11341112
try:
11351113
case = parse_binary_case(case_str)
11361114
cases.append(case)
@@ -1150,6 +1128,10 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11501128
if stub.__doc__ is None:
11511129
warn(f"{stub.__name__}() stub has no docstring")
11521130
continue
1131+
if m := r_case_block.search(stub.__doc__):
1132+
case_block = m.group(1)
1133+
else:
1134+
continue
11531135
marks = []
11541136
try:
11551137
func = getattr(xp, stub.__name__)
@@ -1164,40 +1146,44 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11641146
warn(f"{func=} has no parameters")
11651147
continue
11661148
if param_names[0] == "x":
1167-
if cases := parse_unary_docstring(stub.__doc__):
1168-
func_name_to_func = {stub.__name__: func}
1149+
if cases := parse_unary_case_block(case_block):
1150+
name_to_func = {stub.__name__: func}
11691151
if stub.__name__ in func_to_op.keys():
11701152
op_name = func_to_op[stub.__name__]
11711153
op = getattr(operator, op_name)
1172-
func_name_to_func[op_name] = op
1173-
for func_name, func in func_name_to_func.items():
1154+
name_to_func[op_name] = op
1155+
for func_name, func in name_to_func.items():
11741156
for case in cases:
11751157
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
11761158
p = pytest.param(func_name, func, case, id=id_)
11771159
unary_params.append(p)
1160+
else:
1161+
warn(f"Special cases found for {stub.__name__} but none were parsed")
11781162
continue
11791163
if len(sig.parameters) == 1:
11801164
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
11811165
continue
11821166
if param_names[0] == "x1" and param_names[1] == "x2":
1183-
if cases := parse_binary_docstring(stub.__doc__):
1184-
func_name_to_func = {stub.__name__: func}
1167+
if cases := parse_binary_case_block(case_block):
1168+
name_to_func = {stub.__name__: func}
11851169
if stub.__name__ in func_to_op.keys():
11861170
op_name = func_to_op[stub.__name__]
11871171
op = getattr(operator, op_name)
1188-
func_name_to_func[op_name] = op
1189-
# We collect inplaceoperator test cases seperately
1172+
name_to_func[op_name] = op
1173+
# We collect inplace operator test cases seperately
11901174
iop_name = "__i" + op_name[2:]
11911175
iop = getattr(operator, iop_name)
11921176
for case in cases:
11931177
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
11941178
p = pytest.param(iop_name, iop, case, id=id_)
11951179
iop_params.append(p)
1196-
for func_name, func in func_name_to_func.items():
1180+
for func_name, func in name_to_func.items():
11971181
for case in cases:
11981182
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
11991183
p = pytest.param(func_name, func, case, id=id_)
12001184
binary_params.append(p)
1185+
else:
1186+
warn(f"Special cases found for {stub.__name__} but none were parsed")
12011187
continue
12021188
else:
12031189
warn(
@@ -1206,7 +1192,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
12061192
)
12071193

12081194

1209-
# test_unary and test_binary naively generate arrays, i.e. arrays that might not
1195+
# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not
12101196
# meet the condition that is being test. We then forcibly make the array meet
12111197
# the condition by picking a random index to insert an acceptable element.
12121198
#
@@ -1343,3 +1329,46 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
13431329
)
13441330
break
13451331
assume(good_example)
1332+
1333+
1334+
@pytest.mark.parametrize(
1335+
"func_name, expected",
1336+
[
1337+
("mean", float("nan")),
1338+
("prod", 1),
1339+
("std", float("nan")),
1340+
("sum", 0),
1341+
("var", float("nan")),
1342+
],
1343+
ids=["mean", "prod", "std", "sum", "var"],
1344+
)
1345+
def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected
1346+
func = getattr(xp, func_name)
1347+
out = func(xp.asarray([], dtype=dh.default_float))
1348+
ph.assert_shape(func_name, out.shape, ()) # sanity check
1349+
msg = f"{out=!r}, but should be {expected}"
1350+
if math.isnan(expected):
1351+
assert xp.isnan(out), msg
1352+
else:
1353+
assert out == expected, msg
1354+
1355+
1356+
@pytest.mark.parametrize(
1357+
"func_name", [f.__name__ for f in category_to_funcs["statistical"]]
1358+
)
1359+
@given(
1360+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
1361+
data=st.data(),
1362+
)
1363+
def test_nan_propagation(func_name, x, data):
1364+
func = getattr(xp, func_name)
1365+
set_idx = data.draw(
1366+
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
1367+
)
1368+
x[set_idx] = float("nan")
1369+
note(f"{x=}")
1370+
1371+
out = func(x)
1372+
1373+
ph.assert_shape(func_name, out.shape, ()) # sanity check
1374+
assert xp.isnan(out), f"{out=!r}, but should be NaN"

0 commit comments

Comments
 (0)