From 886f41103b365b5f103eba9aaf1fb9dd376d6d12 Mon Sep 17 00:00:00 2001
From: Chaluvadi <saketh.chaluvadi@intel.com>
Date: Mon, 26 Feb 2024 10:15:46 -0500
Subject: [PATCH 1/2] added unit tests for the upper function

---
 tests/test_upper.py | 48 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)
 create mode 100644 tests/test_upper.py

diff --git a/tests/test_upper.py b/tests/test_upper.py
new file mode 100644
index 0000000..46e5d82
--- /dev/null
+++ b/tests/test_upper.py
@@ -0,0 +1,48 @@
+import pytest
+
+import arrayfire_wrapper.dtypes as dtypes
+from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
+from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract
+from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper
+from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (3, 3),
+        (3, 3, 3),
+        (3, 3, 3, 3),
+    ],
+)
+def test_diag_is_unit(shape: tuple) -> None:
+    """Test if when is_unit_diag in lower returns an array with a unit diagonal"""
+    dtype = dtypes.s64
+    constant_array = constant(3, shape, dtype)
+
+    lower_array = upper(constant_array, True)
+    diagonal = diag_extract(lower_array, 0)
+    diagonal_value = get_scalar(diagonal, dtype)
+
+    assert diagonal_value == 1
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (3, 3),
+        (3, 3, 3),
+        (3, 3, 3, 3),
+    ],
+)
+def test_is_original(shape: tuple) -> None:
+    """Test if is_original keeps the diagonal the same as the original array"""
+    dtype = dtypes.s64
+    constant_array = constant(3, shape, dtype)
+    original_value = get_scalar(constant_array, dtype)
+
+    lower_array = upper(constant_array, False)
+    diagonal = diag_extract(lower_array, 0)
+    diagonal_value = get_scalar(diagonal, dtype)
+
+    assert original_value == diagonal_value

From 21f72a2c8b56163975ddeb63fbb7224c1db887c4 Mon Sep 17 00:00:00 2001
From: Chaluvadi <saketh.chaluvadi@intel.com>
Date: Thu, 29 Feb 2024 14:49:14 -0500
Subject: [PATCH 2/2] fixed import formatting, black and flake8 checks

---
 tests/test_upper.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/tests/test_upper.py b/tests/test_upper.py
index 46e5d82..adeb41e 100644
--- a/tests/test_upper.py
+++ b/tests/test_upper.py
@@ -1,10 +1,7 @@
 import pytest
 
 import arrayfire_wrapper.dtypes as dtypes
-from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
-from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract
-from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper
-from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar
+import arrayfire_wrapper.lib as wrapper
 
 
 @pytest.mark.parametrize(
@@ -18,11 +15,11 @@
 def test_diag_is_unit(shape: tuple) -> None:
     """Test if when is_unit_diag in lower returns an array with a unit diagonal"""
     dtype = dtypes.s64
-    constant_array = constant(3, shape, dtype)
+    constant_array = wrapper.constant(3, shape, dtype)
 
-    lower_array = upper(constant_array, True)
-    diagonal = diag_extract(lower_array, 0)
-    diagonal_value = get_scalar(diagonal, dtype)
+    lower_array = wrapper.upper(constant_array, True)
+    diagonal = wrapper.diag_extract(lower_array, 0)
+    diagonal_value = wrapper.get_scalar(diagonal, dtype)
 
     assert diagonal_value == 1
 
@@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None:
 def test_is_original(shape: tuple) -> None:
     """Test if is_original keeps the diagonal the same as the original array"""
     dtype = dtypes.s64
-    constant_array = constant(3, shape, dtype)
-    original_value = get_scalar(constant_array, dtype)
+    constant_array = wrapper.constant(3, shape, dtype)
+    original_value = wrapper.get_scalar(constant_array, dtype)
 
-    lower_array = upper(constant_array, False)
-    diagonal = diag_extract(lower_array, 0)
-    diagonal_value = get_scalar(diagonal, dtype)
+    lower_array = wrapper.upper(constant_array, False)
+    diagonal = wrapper.diag_extract(lower_array, 0)
+    diagonal_value = wrapper.get_scalar(diagonal, dtype)
 
     assert original_value == diagonal_value