Skip to content

Commit a2f7bd5

Browse files
authored
Merge pull request #116 from honno/minor-fixes
Minor fixes
2 parents 63ebadb + 4d3405a commit a2f7bd5

File tree

2 files changed

+6
-29
lines changed

2 files changed

+6
-29
lines changed

array_api_tests/test_creation_functions.py

+5-28
Original file line numberDiff line numberDiff line change
@@ -409,43 +409,20 @@ def test_full_like(x, fill_value, kw):
409409
finite_kw = {"allow_nan": False, "allow_infinity": False}
410410

411411

412-
def int_stops(
413-
start: int, num, dtype: DataType, endpoint: bool
414-
) -> st.SearchStrategy[int]:
415-
min_gap = num
416-
if endpoint:
417-
min_gap += 1
418-
m, M = dh.dtype_ranges[dtype]
419-
max_pos_gap = M - start
420-
max_neg_gap = start - m
421-
max_pos_mul = max_pos_gap // min_gap
422-
max_neg_mul = max_neg_gap // min_gap
423-
return st.one_of(
424-
st.integers(0, max_pos_mul).map(lambda n: start + min_gap * n),
425-
st.integers(0, max_neg_mul).map(lambda n: start - min_gap * n),
426-
)
427-
428-
429412
@given(
430413
num=hh.sizes,
431-
dtype=st.none() | xps.numeric_dtypes(),
414+
dtype=st.none() | xps.floating_dtypes(),
432415
endpoint=st.booleans(),
433416
data=st.data(),
434417
)
435418
def test_linspace(num, dtype, endpoint, data):
436419
_dtype = dh.default_float if dtype is None else dtype
437420

438421
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
439-
if dh.is_float_dtype(_dtype):
440-
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
441-
# avoid overflow errors
442-
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
443-
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
444-
else:
445-
if num == 0:
446-
stop = start
447-
else:
448-
stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop")
422+
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
423+
# avoid overflow errors
424+
assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype)))
425+
assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype)))
449426

450427
kw = data.draw(
451428
hh.specified_kwargs(

array_api_tests/test_linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_matmul(x1, x2):
315315
@given(
316316
x=finite_matrices(),
317317
kw=kwargs(keepdims=booleans(),
318-
ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc']))
318+
ord=sampled_from([-float('inf'), -2, -1, 1, 2, float('inf'), 'fro', 'nuc']))
319319
)
320320
def test_matrix_norm(x, kw):
321321
res = linalg.matrix_norm(x, **kw)

0 commit comments

Comments
 (0)