Skip to content

Commit

Permalink
FIX: fixing spmd tests utilities for dpctl inputs (#1999) (#2020)
Browse files Browse the repository at this point in the history
* FIX: fixing spmd tests utilities for dpctl inputs

(cherry picked from commit 18d0428)

Co-authored-by: Samir Nasibli <[email protected]>
  • Loading branch information
mergify[bot] and samir-nasibli authored Sep 2, 2024
1 parent ae3c3a5 commit 84666fa
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions sklearnex/tests/_utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
Raises:
AssertionError: If results do not match.
"""
np_spmd_result = _as_numpy(spmd_result)

sorted_spmd_result = spmd_result[np.argsort(np.linalg.norm(spmd_result, axis=1))]
sorted_spmd_result = np_spmd_result[
np.argsort(np.linalg.norm(np_spmd_result, axis=1))
]
if localize:
local_batch_result = _get_local_tensor(batch_result)
sorted_batch_result = local_batch_result[
Expand All @@ -158,7 +161,7 @@ def _assert_unordered_allclose(spmd_result, batch_result, localize=False, **kwar
np.argsort(np.linalg.norm(batch_result, axis=1))
]

assert_allclose(_as_numpy(sorted_spmd_result), sorted_batch_result, **kwargs)
assert_allclose(sorted_spmd_result, sorted_batch_result, **kwargs)


def _assert_kmeans_labels_allclose(
Expand All @@ -179,7 +182,11 @@ def _assert_kmeans_labels_allclose(
AssertionError: If clusters are not correctly assigned.
"""

np_spmd_labels = _as_numpy(spmd_labels)
np_spmd_centers = _as_numpy(spmd_centers)
local_batch_labels = _get_local_tensor(batch_labels)
assert_allclose(
spmd_centers[_as_numpy(spmd_labels)], batch_centers[local_batch_labels], **kwargs
np_spmd_centers[np_spmd_labels],
batch_centers[local_batch_labels],
**kwargs,
)

0 comments on commit 84666fa

Please sign in to comment.