From 84666faa2d648688ff24fb69c87ddebf310d2e1a Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:21:47 +0200 Subject: [PATCH] FIX: fixing spmd tests utilities for dpctl inputs (#1999) (#2020) * FIX: fixing spmd tests utilities for dpctl inputs (cherry picked from commit 18d042849ca6e7ca03713593840882bd1ff83483) Co-authored-by: Samir Nasibli --- sklearnex/tests/_utils_spmd.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearnex/tests/_utils_spmd.py b/sklearnex/tests/_utils_spmd.py index 172db788be..9c6f970300 100644 --- a/sklearnex/tests/_utils_spmd.py +++ b/sklearnex/tests/_utils_spmd.py @@ -146,8 +146,11 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar Raises: AssertionError: If results do not match. """ + np_spmd_result = _as_numpy(spmd_result) - sorted_spmd_result = spmd_result[np.argsort(np.linalg.norm(spmd_result, axis=1))] + sorted_spmd_result = np_spmd_result[ + np.argsort(np.linalg.norm(np_spmd_result, axis=1)) + ] if localize: local_batch_result = _get_local_tensor(batch_result) sorted_batch_result = local_batch_result[ @@ -158,7 +161,7 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar np.argsort(np.linalg.norm(batch_result, axis=1)) ] - assert_allclose(_as_numpy(sorted_spmd_result), sorted_batch_result, **kwargs) + assert_allclose(sorted_spmd_result, sorted_batch_result, **kwargs) def _assert_kmeans_labels_allclose( @@ -179,7 +182,11 @@ def _assert_kmeans_labels_allclose( AssertionError: If clusters are not correctly assigned. """ + np_spmd_labels = _as_numpy(spmd_labels) + np_spmd_centers = _as_numpy(spmd_centers) local_batch_labels = _get_local_tensor(batch_labels) assert_allclose( - spmd_centers[_as_numpy(spmd_labels)], batch_centers[local_batch_labels], **kwargs + np_spmd_centers[np_spmd_labels], + batch_centers[local_batch_labels], + **kwargs, )