1
+ """
2
+ Tests for special cases.
3
+
4
+ Most test cases for special casing are built on runtime via the parametrized
5
+ tests test_unary/test_binary/test_iop. Most of this file consists of utility
6
+ classes and functions, all bought together to create the test cases (pytest
7
+ params), to finally be run through generalised test logic.
8
+
9
+ TODO: test integer arrays for relevant special cases
10
+ """
1
11
# We use __future__ for forward reference type hints - this will work for even py3.8.0
2
12
# See https://stackoverflow.com/a/33533514/5193926
3
13
from __future__ import annotations
32
42
33
43
pytestmark = pytest .mark .ci
34
44
35
- # The special case test casess are built on runtime via the parametrized
36
- # test_unary and test_binary functions. Most of this file consists of utility
37
- # classes and functions, all bought together to create the test cases (pytest
38
- # params), to finally be run through the general test logic of either test_unary
39
- # or test_binary.
40
-
41
-
42
45
UnaryCheck = Callable [[float ], bool ]
43
46
BinaryCheck = Callable [[float , float ], bool ]
44
47
@@ -170,24 +173,6 @@ def parse_value(value_str: str) -> float:
170
173
r_approx_value = re .compile (
171
174
rf"an implementation-dependent approximation to { r_code .pattern } "
172
175
)
173
-
174
-
175
- def parse_inline_code (inline_code : str ) -> float :
176
- """
177
- Parses a Sphinx code string to return a float, e.g.
178
-
179
- >>> parse_value('``0``')
180
- 0.
181
- >>> parse_value('``NaN``')
182
- float('nan')
183
-
184
- """
185
- if m := r_code .match (inline_code ):
186
- return parse_value (m .group (1 ))
187
- else :
188
- raise ParseError (inline_code )
189
-
190
-
191
176
r_not = re .compile ("not (.+)" )
192
177
r_equal_to = re .compile (f"equal to { r_code .pattern } " )
193
178
r_array_element = re .compile (r"``([+-]?)x([12])_i``" )
@@ -526,6 +511,10 @@ def __repr__(self) -> str:
526
511
return f"{ self .__class__ .__name__ } (<{ self } >)"
527
512
528
513
514
+ r_case_block = re .compile (r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters" )
515
+ r_case = re .compile (r"\s+-\s*(.*)\." )
516
+
517
+
529
518
class UnaryCond (Protocol ):
530
519
def __call__ (self , i : float ) -> bool :
531
520
...
@@ -546,12 +535,34 @@ class UnaryCase(Case):
546
535
547
536
548
537
r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
538
+ r_already_int_case = re .compile (
539
+ "If ``x_i`` is already integer-valued, the result is ``x_i``"
540
+ )
549
541
r_even_round_halves_case = re .compile (
550
542
"If two integers are equally close to ``x_i``, "
551
543
"the result is the even integer closest to ``x_i``"
552
544
)
553
545
554
546
547
+ def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
548
+ """
549
+ Returns a strategy that generates float-casted integers within the bounds of dtype.
550
+ """
551
+ for k in kw .keys ():
552
+ # sanity check
553
+ assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
554
+ m , M = dh .dtype_ranges [dtype ]
555
+ if "min_value" in kw .keys ():
556
+ m = kw ["min_value" ]
557
+ if "exclude_min" in kw .keys ():
558
+ m += 1
559
+ if "max_value" in kw .keys ():
560
+ M = kw ["max_value" ]
561
+ if "exclude_max" in kw .keys ():
562
+ M -= 1
563
+ return st .integers (math .ceil (m ), math .floor (M )).map (float )
564
+
565
+
555
566
def trailing_halves_from_dtype (dtype : DataType ) -> st .SearchStrategy [float ]:
556
567
"""
557
568
Returns a strategy that generates floats that end with .5 and are within the
@@ -568,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
568
579
)
569
580
570
581
582
+ already_int_case = UnaryCase (
583
+ cond_expr = "x_i.is_integer()" ,
584
+ cond = lambda i : i .is_integer (),
585
+ cond_from_dtype = integers_from_dtype ,
586
+ result_expr = "x_i" ,
587
+ check_result = lambda i , result : i == result ,
588
+ )
571
589
even_round_halves_case = UnaryCase (
572
590
cond_expr = "modf(i)[0] == 0.5" ,
573
591
cond = lambda i : math .modf (i )[0 ] == 0.5 ,
@@ -586,7 +604,7 @@ def check_result(i: float, result: float) -> bool:
586
604
return check_result
587
605
588
606
589
- def parse_unary_docstring ( docstring : str ) -> List [UnaryCase ]:
607
+ def parse_unary_case_block ( case_block : str ) -> List [UnaryCase ]:
590
608
"""
591
609
Parses a Sphinx-formatted docstring of a unary function to return a list of
592
610
codified unary cases, e.g.
@@ -616,7 +634,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
616
634
... an array containing the square root of each element in ``x``
617
635
... '''
618
636
...
619
- >>> unary_cases = parse_unary_docstring(sqrt.__doc__)
637
+ >>> case_block = r_case_block.search(sqrt.__doc__).group(1)
638
+ >>> unary_cases = parse_unary_case_block(case_block)
620
639
>>> for case in unary_cases:
621
640
... print(repr(case))
622
641
UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +650,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
631
650
True
632
651
633
652
"""
634
-
635
- match = r_special_cases .search (docstring )
636
- if match is None :
637
- return []
638
- lines = match .group (1 ).split ("\n " )[:- 1 ]
639
653
cases = []
640
- for line in lines :
641
- if m := r_case . match ( line ):
642
- case = m . group ( 1 )
643
- else :
644
- warn ( f"line not machine-readable: ' { line } '" )
645
- continue
646
- if m := r_unary_case .search (case ):
654
+ for case_m in r_case . finditer ( case_block ) :
655
+ case_str = case_m . group ( 1 )
656
+ if m := r_already_int_case . search ( case_str ):
657
+ cases . append ( already_int_case )
658
+ elif m := r_even_round_halves_case . search ( case_str ):
659
+ cases . append ( even_round_halves_case )
660
+ elif m := r_unary_case .search (case_str ):
647
661
try :
648
662
cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
649
663
_check_result , result_expr = parse_result (m .group (2 ))
@@ -662,11 +676,9 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
662
676
check_result = check_result ,
663
677
)
664
678
cases .append (case )
665
- elif m := r_even_round_halves_case .search (case ):
666
- cases .append (even_round_halves_case )
667
679
else :
668
- if not r_remaining_case .search (case ):
669
- warn (f"case not machine-readable: '{ case } '" )
680
+ if not r_remaining_case .search (case_str ):
681
+ warn (f"case not machine-readable: '{ case_str } '" )
670
682
return cases
671
683
672
684
@@ -690,12 +702,6 @@ class BinaryCase(Case):
690
702
check_result : BinaryResultCheck
691
703
692
704
693
- r_special_cases = re .compile (
694
- r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
695
- r"For floating-point operands,\n+"
696
- r"((?:\s*-\s*.*\n)+)"
697
- )
698
- r_case = re .compile (r"\s+-\s*(.*)\.\n?" )
699
705
r_binary_case = re .compile ("If (.+), the result (.+)" )
700
706
r_remaining_case = re .compile ("In the remaining cases.+" )
701
707
r_cond_sep = re .compile (r"(?<!``x1_i``),? and |(?<!i\.e\.), " )
@@ -843,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
843
849
return check_result
844
850
845
851
846
- def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
847
- """
848
- Returns a strategy that generates float-casted integers within the bounds of dtype.
849
- """
850
- for k in kw .keys ():
851
- # sanity check
852
- assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
853
- m , M = dh .dtype_ranges [dtype ]
854
- if "min_value" in kw .keys ():
855
- m = kw ["min_value" ]
856
- if "exclude_min" in kw .keys ():
857
- m += 1
858
- if "max_value" in kw .keys ():
859
- M = kw ["max_value" ]
860
- if "exclude_max" in kw .keys ():
861
- M -= 1
862
- return st .integers (math .ceil (m ), math .floor (M )).map (float )
863
-
864
-
865
852
def parse_binary_case (case_str : str ) -> BinaryCase :
866
853
"""
867
854
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
@@ -880,8 +867,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:
880
867
881
868
"""
882
869
case_m = r_binary_case .match (case_str )
883
- if case_m is None :
884
- raise ParseError (case_str )
870
+ assert case_m is not None # sanity check
885
871
cond_strs = r_cond_sep .split (case_m .group (1 ))
886
872
887
873
partial_conds = []
@@ -1078,7 +1064,7 @@ def cond(i1: float, i2: float) -> bool:
1078
1064
r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
1079
1065
1080
1066
1081
- def parse_binary_docstring ( docstring : str ) -> List [BinaryCase ]:
1067
+ def parse_binary_case_block ( case_block : str ) -> List [BinaryCase ]:
1082
1068
"""
1083
1069
Parses a Sphinx-formatted docstring of a binary function to return a list of
1084
1070
codified binary cases, e.g.
@@ -1108,29 +1094,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1108
1094
... an array containing the results
1109
1095
... '''
1110
1096
...
1111
- >>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1097
+ >>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
1098
+ >>> binary_cases = parse_binary_case_block(case_block)
1112
1099
>>> for case in binary_cases:
1113
1100
... print(repr(case))
1114
1101
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
1115
1102
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
1116
1103
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
1117
1104
1118
1105
"""
1119
-
1120
- match = r_special_cases .search (docstring )
1121
- if match is None :
1122
- return []
1123
- lines = match .group (1 ).split ("\n " )[:- 1 ]
1124
1106
cases = []
1125
- for line in lines :
1126
- if m := r_case .match (line ):
1127
- case_str = m .group (1 )
1128
- else :
1129
- warn (f"line not machine-readable: '{ line } '" )
1130
- continue
1107
+ for case_m in r_case .finditer (case_block ):
1108
+ case_str = case_m .group (1 )
1131
1109
if r_redundant_case .search (case_str ):
1132
1110
continue
1133
- if m := r_binary_case .match (case_str ):
1111
+ if r_binary_case .match (case_str ):
1134
1112
try :
1135
1113
case = parse_binary_case (case_str )
1136
1114
cases .append (case )
@@ -1150,6 +1128,10 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1150
1128
if stub .__doc__ is None :
1151
1129
warn (f"{ stub .__name__ } () stub has no docstring" )
1152
1130
continue
1131
+ if m := r_case_block .search (stub .__doc__ ):
1132
+ case_block = m .group (1 )
1133
+ else :
1134
+ continue
1153
1135
marks = []
1154
1136
try :
1155
1137
func = getattr (xp , stub .__name__ )
@@ -1164,40 +1146,44 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1164
1146
warn (f"{ func = } has no parameters" )
1165
1147
continue
1166
1148
if param_names [0 ] == "x" :
1167
- if cases := parse_unary_docstring ( stub . __doc__ ):
1168
- func_name_to_func = {stub .__name__ : func }
1149
+ if cases := parse_unary_case_block ( case_block ):
1150
+ name_to_func = {stub .__name__ : func }
1169
1151
if stub .__name__ in func_to_op .keys ():
1170
1152
op_name = func_to_op [stub .__name__ ]
1171
1153
op = getattr (operator , op_name )
1172
- func_name_to_func [op_name ] = op
1173
- for func_name , func in func_name_to_func .items ():
1154
+ name_to_func [op_name ] = op
1155
+ for func_name , func in name_to_func .items ():
1174
1156
for case in cases :
1175
1157
id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1176
1158
p = pytest .param (func_name , func , case , id = id_ )
1177
1159
unary_params .append (p )
1160
+ else :
1161
+ warn (f"Special cases found for { stub .__name__ } but none were parsed" )
1178
1162
continue
1179
1163
if len (sig .parameters ) == 1 :
1180
1164
warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1181
1165
continue
1182
1166
if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1183
- if cases := parse_binary_docstring ( stub . __doc__ ):
1184
- func_name_to_func = {stub .__name__ : func }
1167
+ if cases := parse_binary_case_block ( case_block ):
1168
+ name_to_func = {stub .__name__ : func }
1185
1169
if stub .__name__ in func_to_op .keys ():
1186
1170
op_name = func_to_op [stub .__name__ ]
1187
1171
op = getattr (operator , op_name )
1188
- func_name_to_func [op_name ] = op
1189
- # We collect inplaceoperator test cases seperately
1172
+ name_to_func [op_name ] = op
1173
+ # We collect inplace operator test cases seperately
1190
1174
iop_name = "__i" + op_name [2 :]
1191
1175
iop = getattr (operator , iop_name )
1192
1176
for case in cases :
1193
1177
id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1194
1178
p = pytest .param (iop_name , iop , case , id = id_ )
1195
1179
iop_params .append (p )
1196
- for func_name , func in func_name_to_func .items ():
1180
+ for func_name , func in name_to_func .items ():
1197
1181
for case in cases :
1198
1182
id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1199
1183
p = pytest .param (func_name , func , case , id = id_ )
1200
1184
binary_params .append (p )
1185
+ else :
1186
+ warn (f"Special cases found for { stub .__name__ } but none were parsed" )
1201
1187
continue
1202
1188
else :
1203
1189
warn (
@@ -1206,7 +1192,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1206
1192
)
1207
1193
1208
1194
1209
- # test_unary and test_binary naively generate arrays, i.e. arrays that might not
1195
+ # test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not
1210
1196
# meet the condition that is being test. We then forcibly make the array meet
1211
1197
# the condition by picking a random index to insert an acceptable element.
1212
1198
#
@@ -1343,3 +1329,46 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
1343
1329
)
1344
1330
break
1345
1331
assume (good_example )
1332
+
1333
+
1334
+ @pytest .mark .parametrize (
1335
+ "func_name, expected" ,
1336
+ [
1337
+ ("mean" , float ("nan" )),
1338
+ ("prod" , 1 ),
1339
+ ("std" , float ("nan" )),
1340
+ ("sum" , 0 ),
1341
+ ("var" , float ("nan" )),
1342
+ ],
1343
+ ids = ["mean" , "prod" , "std" , "sum" , "var" ],
1344
+ )
1345
+ def test_empty_arrays (func_name , expected ): # TODO: parse docstrings to get expected
1346
+ func = getattr (xp , func_name )
1347
+ out = func (xp .asarray ([], dtype = dh .default_float ))
1348
+ ph .assert_shape (func_name , out .shape , ()) # sanity check
1349
+ msg = f"{ out = !r} , but should be { expected } "
1350
+ if math .isnan (expected ):
1351
+ assert xp .isnan (out ), msg
1352
+ else :
1353
+ assert out == expected , msg
1354
+
1355
+
1356
+ @pytest .mark .parametrize (
1357
+ "func_name" , [f .__name__ for f in category_to_funcs ["statistical" ]]
1358
+ )
1359
+ @given (
1360
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
1361
+ data = st .data (),
1362
+ )
1363
+ def test_nan_propagation (func_name , x , data ):
1364
+ func = getattr (xp , func_name )
1365
+ set_idx = data .draw (
1366
+ xps .indices (x .shape , max_dims = 0 , allow_ellipsis = False ), label = "set idx"
1367
+ )
1368
+ x [set_idx ] = float ("nan" )
1369
+ note (f"{ x = } " )
1370
+
1371
+ out = func (x )
1372
+
1373
+ ph .assert_shape (func_name , out .shape , ()) # sanity check
1374
+ assert xp .isnan (out ), f"{ out = !r} , but should be NaN"
0 commit comments