From 886f41103b365b5f103eba9aaf1fb9dd376d6d12 Mon Sep 17 00:00:00 2001 From: Chaluvadi <saketh.chaluvadi@intel.com> Date: Mon, 26 Feb 2024 10:15:46 -0500 Subject: [PATCH 1/2] added unit tests for the upper function --- tests/test_upper.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/test_upper.py diff --git a/tests/test_upper.py b/tests/test_upper.py new file mode 100644 index 0000000..46e5d82 --- /dev/null +++ b/tests/test_upper.py @@ -0,0 +1,48 @@ +import pytest + +import arrayfire_wrapper.dtypes as dtypes +from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant +from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract +from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper +from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar + + +@pytest.mark.parametrize( + "shape", + [ + (3, 3), + (3, 3, 3), + (3, 3, 3, 3), + ], +) +def test_diag_is_unit(shape: tuple) -> None: + """Test if when is_unit_diag in lower returns an array with a unit diagonal""" + dtype = dtypes.s64 + constant_array = constant(3, shape, dtype) + + lower_array = upper(constant_array, True) + diagonal = diag_extract(lower_array, 0) + diagonal_value = get_scalar(diagonal, dtype) + + assert diagonal_value == 1 + + +@pytest.mark.parametrize( + "shape", + [ + (3, 3), + (3, 3, 3), + (3, 3, 3, 3), + ], +) +def test_is_original(shape: tuple) -> None: + """Test if is_original keeps the diagonal the same as the original array""" + dtype = dtypes.s64 + constant_array = constant(3, shape, dtype) + original_value = get_scalar(constant_array, dtype) + + lower_array = upper(constant_array, False) + diagonal = diag_extract(lower_array, 0) + diagonal_value = get_scalar(diagonal, dtype) + + assert original_value == diagonal_value From 21f72a2c8b56163975ddeb63fbb7224c1db887c4 Mon Sep 17 00:00:00 2001 From: Chaluvadi <saketh.chaluvadi@intel.com> Date: Thu, 29 Feb 2024 14:49:14 -0500 Subject: [PATCH 2/2] fixed import formatting, black and flake8 checks --- tests/test_upper.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/test_upper.py b/tests/test_upper.py index 46e5d82..adeb41e 100644 --- a/tests/test_upper.py +++ b/tests/test_upper.py @@ -1,10 +1,7 @@ import pytest import arrayfire_wrapper.dtypes as dtypes -from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant -from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract -from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper -from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar +import arrayfire_wrapper.lib as wrapper @pytest.mark.parametrize( @@ -18,11 +15,11 @@ def test_diag_is_unit(shape: tuple) -> None: """Test if when is_unit_diag in lower returns an array with a unit diagonal""" dtype = dtypes.s64 - constant_array = constant(3, shape, dtype) + constant_array = wrapper.constant(3, shape, dtype) - lower_array = upper(constant_array, True) - diagonal = diag_extract(lower_array, 0) - diagonal_value = get_scalar(diagonal, dtype) + lower_array = wrapper.upper(constant_array, True) + diagonal = wrapper.diag_extract(lower_array, 0) + diagonal_value = wrapper.get_scalar(diagonal, dtype) assert diagonal_value == 1 @@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None: def test_is_original(shape: tuple) -> None: """Test if is_original keeps the diagonal the same as the original array""" dtype = dtypes.s64 - constant_array = constant(3, shape, dtype) - original_value = get_scalar(constant_array, dtype) + constant_array = wrapper.constant(3, shape, dtype) + original_value = wrapper.get_scalar(constant_array, dtype) - lower_array = upper(constant_array, False) - diagonal = diag_extract(lower_array, 0) - diagonal_value = get_scalar(diagonal, dtype) + lower_array = wrapper.upper(constant_array, False) + diagonal = wrapper.diag_extract(lower_array, 0) + diagonal_value = wrapper.get_scalar(diagonal, dtype) assert original_value == diagonal_value