Features/1243 Refactoring of communication: separate mpi4py wrappers from DNDarrays #1265

Hoppe committed Nov 21, 2023
commit e9eaf4e841df3b32abaf56f1201409e50a06a4ed
6 changes: 3 additions & 3 deletions heat/cluster/_kcluster.py
@@ -120,7 +120,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- xi.comm.Bcast(xi, root=proc)
+ xi.comm.Bcast(xi.larray, root=proc)
centroids[i, :] = xi

else:
@@ -155,7 +155,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
x0 = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- x0.comm.Bcast(x0, root=proc)
+ x0.comm.Bcast(x0.larray, root=proc)
centroids[0, :] = x0
for i in range(1, self.n_clusters):
distances = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
@@ -179,7 +179,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
- xi.comm.Bcast(xi, root=proc)
+ xi.comm.Bcast(xi.larray, root=proc)
centroids[i, :] = xi

else:
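
This hunk pattern is the heart of the commit: collective calls now receive the process-local torch tensor (`.larray`) instead of the DNDarray wrapper, so the communicator no longer needs to know about DNDarray internals. A minimal sketch of the idiom; the shapes and the root rank are illustrative, not taken from this diff:

```python
import heat as ht

x = ht.random.rand(100, 4, split=0)   # row-distributed input
xi = ht.zeros((4,), comm=x.comm)      # replicated receive buffer
if x.comm.rank == 0:
    xi.larray[:] = x.larray[0, :]     # root fills its local buffer
# the raw torch tensor is handed to MPI, not the DNDarray wrapper:
xi.comm.Bcast(xi.larray, root=0)
```

Keeping the wrappers tensor-only is what allows them to move into the new `communication_backends` package below.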
2 changes: 1 addition & 1 deletion heat/cluster/kmedoids.py
@@ -108,7 +108,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
if x.comm.rank == proc:
lidx = idx - displ[proc]
closest_point = ht.array(x.lloc[lidx, :], device=x.device, comm=x.comm)
- closest_point.comm.Bcast(closest_point, root=proc)
+ closest_point.comm.Bcast(closest_point.larray, root=proc)
new_cluster_centers[i, :] = closest_point

return new_cluster_centers
2 changes: 1 addition & 1 deletion heat/communication_backends/__init__.py
@@ -1,5 +1,5 @@
"""
- Add the communication_backends functions to the ht.communication namespace
+ Add the communication_backends functions to the ht.communication_backends namespace
"""

from .communication import *
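
The docstring now matches the package rename: what used to be reachable as `ht.communication` lives under `ht.communication_backends`. A small usage sketch, assuming the renamed module re-exports `MPI_WORLD` as before:

```python
import heat as ht

comm = ht.communication_backends.MPI_WORLD  # world communicator wrapper
print(f"rank {comm.rank} of {comm.size}")
```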
1 change: 0 additions & 1 deletion heat/communication_backends/communication.py
@@ -5,7 +5,6 @@
import torch
from typing import Optional, Tuple
from ..core.stride_tricks import sanitize_axis
- from ..core.dndarray import DNDarray


class Communication:
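
Dropping the `DNDarray` import is what breaks the circular dependency: the backend module no longer knows the array type, so unwrapping happens at every call site instead. A hypothetical illustration of the resulting layering (`bcast_local` is an invented name, not part of this diff):

```python
import torch
from mpi4py import MPI

def bcast_local(buf: torch.Tensor, root: int, comm: MPI.Comm = MPI.COMM_WORLD) -> None:
    # the backend sees only buffers; callers pass dndarray.larray themselves.
    # mpi4py accepts anything supporting the buffer protocol; a contiguous
    # CPU tensor exposes one through its numpy view without copying.
    comm.Bcast(buf.numpy(), root=root)
```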
4 changes: 2 additions & 2 deletions heat/core/linalg/basics.py
@@ -530,7 +530,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm)
c.larray[slice_0.start : slice_0.stop, :] += hold
- c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
+ c.comm.Allreduce(MPI.IN_PLACE, c.larray, MPI.SUM)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
@@ -707,7 +707,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c_idx = c.comm.chunk(c.shape, c.split)[2]
c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop)
c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop)
- c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM)
+ c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map.larray, MPI.SUM)

if a.split == 0:
a_block_map = torch.zeros(
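
Reductions follow the same rule as broadcasts: `MPI.IN_PLACE` operations now act on the local tensor behind the result array. A minimal sketch, assuming the renamed `heat.communication_backends` package exports `MPI` as the old communication module did:

```python
import heat as ht
from heat.communication_backends import MPI

c = ht.zeros((4, 4))        # replicated result buffer on every rank
c.larray += c.comm.rank     # stand-in for each rank's partial result
c.comm.Allreduce(MPI.IN_PLACE, c.larray, MPI.SUM)  # sum partials in place
```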
32 changes: 24 additions & 8 deletions heat/core/linalg/solver.py
@@ -183,8 +183,12 @@ def lanczos(
vi_loc = V._DNDarray__array[:, j]
a = torch.dot(vr.larray, torch.conj(vi_loc))
b = torch.dot(vi_loc, torch.conj(vi_loc))
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM
+ )
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM
+ )
vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc
# normalize v_r to Euclidean norm 1 and set as ith vector v
vi = vr / ht.norm(vr)
@@ -196,8 +200,12 @@ def lanczos(
vi_loc = V.larray[:, j]
a = torch.dot(vr._DNDarray__array, torch.conj(vi_loc))
b = torch.dot(vi_loc, torch.conj(vi_loc))
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM
+ )
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM
+ )
vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc

vi = vr / ht.norm(vr)
@@ -235,8 +243,12 @@ def lanczos(
vi_loc = V._DNDarray__array[:, j]
a = torch.dot(vr.larray, vi_loc)
b = torch.dot(vi_loc, vi_loc)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM
+ )
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM
+ )
vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc
# normalize v_r to Euclidean norm 1 and set as ith vector v
vi = vr / ht.norm(vr)
@@ -248,8 +260,12 @@ def lanczos(
vi_loc = V.larray[:, j]
a = torch.dot(vr._DNDarray__array, vi_loc)
b = torch.dot(vi_loc, vi_loc)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM)
- A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM)
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM
+ )
+ A.comm.Allreduce(
+     ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM
+ )
vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc

vi = vr / ht.norm(vr)
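
All four solver hunks reformat the same step for the longer module name: local dot products for the Gram-Schmidt reorthogonalization, followed by scalar `Allreduce` calls. A condensed sketch mirroring those calls (array contents are illustrative):

```python
import torch
import heat as ht

V = ht.random.rand(16, 4, split=0)  # basis vectors, rows distributed
vr = ht.random.rand(16, split=0)    # vector being orthogonalized
vi_loc = V.larray[:, 0]             # local slice of basis vector 0
a = torch.dot(vr.larray, vi_loc)    # local partial of <vr, vi>
b = torch.dot(vi_loc, vi_loc)       # local partial of <vi, vi>
V.comm.Allreduce(ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM)
V.comm.Allreduce(ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM)
vr.larray -= (a / b) * vi_loc       # subtract the projection locally
```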
2 changes: 1 addition & 1 deletion heat/core/linalg/tests/test_solver.py
@@ -9,7 +9,7 @@

class TestSolver(TestCase):
def test_cg(self):
- size = ht.communication.MPI_WORLD.size * 3
+ size = ht.communication_backends.MPI_WORLD.size * 3
b = ht.arange(1, size + 1, dtype=ht.float32, split=0)
A = ht.manipulations.diag(b)
x0 = ht.random.rand(size, dtype=b.dtype, split=b.split)
2 changes: 1 addition & 1 deletion heat/core/statistics.py
@@ -1637,7 +1637,7 @@ def _local_percentile(data: torch.Tensor, axis: int, indices: torch.Tensor) -> t
comm=x.comm,
balanced=True,
)
- x.comm.Bcast(local_p, root=r)
+ x.comm.Bcast(local_p.larray, root=r)
percentile[perc_slice] = local_p
else:
if x.comm.is_distributed() and split is not None: