Skip to content

Commit 43fedaf

Browse files
author
AzeezIsh
committed
Resolved hypotenuse test issue, missing param.
Was missing batch param, and out pointer, appended the same for clamp.
1 parent 903e14d commit 43fedaf

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ def floor(arr: AFArray, /) -> AFArray:
4848
return unary_op(floor.__name__, arr)
4949

5050

51-
def hypot(lhs: AFArray, rhs: AFArray, /) -> AFArray:
51+
def hypot(lhs: AFArray, rhs: AFArray, batch: bool, /) -> AFArray:
5252
"""
5353
source:
5454
"""
5555
out = AFArray.create_null_pointer()
56-
call_from_clib(hypot.__name__, lhs, rhs)
56+
call_from_clib(hypot.__name__, ctypes.pointer(out), lhs, rhs, ctypes.c_bool(batch))
5757
return out
5858

5959

tests/test_numeric.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -385,15 +385,27 @@ def test_trunc_shapes_invalid(invdtypes: dtype.Dtype) -> None:
385385
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
386386
],
387387
)
388-
@pytest.mark.parametrize("dtype_name", util.get_all_types())
389-
def test_hypot_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
388+
def test_hypot_shape_dtypes(shape: tuple) -> None:
390389
"""Test hypotenuse operation between two arrays of the same shape"""
391-
util.check_type_supported(dtype_name)
392-
lhs = wrapper.randu(shape, dtype_name)
393-
rhs = wrapper.randu(shape, dtype_name)
390+
lhs = wrapper.randu(shape, dtype.f32)
391+
rhs = wrapper.randu(shape, dtype.f32)
394392

395-
result = wrapper.hypot(lhs, rhs)
393+
result = wrapper.hypot(lhs, rhs, True)
396394

397395
assert (
398396
wrapper.get_dims(result)[0 : len(shape)] == shape # noqa
399-
), f"failed for shape: {shape} and dtype {dtype_name}"
397+
), f"failed for shape: {shape} and dtype {dtype.f32}"
398+
@pytest.mark.parametrize(
399+
"invdtypes",
400+
[
401+
dtype.int32,
402+
dtype.uint32,
403+
],
404+
)
405+
def test_hypot_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
406+
"""Test division operation for unsupported data types."""
407+
with pytest.raises(RuntimeError):
408+
shape = (5, 5)
409+
lhs = wrapper.randu(shape, invdtypes)
410+
rhs = wrapper.randu(shape, invdtypes)
411+
wrapper.hypot(rhs, lhs, True)

0 commit comments

Comments
 (0)