diff --git a/heat/__init__.py b/heat/__init__.py
index 84c4afc11b..79c0ec2b3b 100644
--- a/heat/__init__.py
+++ b/heat/__init__.py
@@ -19,3 +19,4 @@
 from . import spatial
 from . import utils
 from . import preprocessing
+from . import communication_backends
diff --git a/heat/cluster/_kcluster.py b/heat/cluster/_kcluster.py
index d3f0bdae19..abee882396 100644
--- a/heat/cluster/_kcluster.py
+++ b/heat/cluster/_kcluster.py
@@ -120,7 +120,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
                 if x.comm.rank == proc:
                     idx = sample - displ[proc]
                     xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
-                    xi.comm.Bcast(xi, root=proc)
+                    xi.comm.Bcast(xi.larray, root=proc)
                     centroids[i, :] = xi

         else:
@@ -155,7 +155,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
                 if x.comm.rank == proc:
                     idx = sample - displ[proc]
                     x0 = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
-                    x0.comm.Bcast(x0, root=proc)
+                    x0.comm.Bcast(x0.larray, root=proc)
                     centroids[0, :] = x0
             for i in range(1, self.n_clusters):
                 distances = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
@@ -179,7 +179,7 @@ def _initialize_cluster_centers(self, x: DNDarray):
                 if x.comm.rank == proc:
                     idx = sample - displ[proc]
                     xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
-                    xi.comm.Bcast(xi, root=proc)
+                    xi.comm.Bcast(xi.larray, root=proc)
                     centroids[i, :] = xi

         else:
diff --git a/heat/cluster/kmedoids.py b/heat/cluster/kmedoids.py
index c24c5287ed..582da86924 100644
--- a/heat/cluster/kmedoids.py
+++ b/heat/cluster/kmedoids.py
@@ -108,7 +108,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
             if x.comm.rank == proc:
                 lidx = idx - displ[proc]
                 closest_point = ht.array(x.lloc[lidx, :], device=x.device, comm=x.comm)
-                closest_point.comm.Bcast(closest_point, root=proc)
+                closest_point.comm.Bcast(closest_point.larray, root=proc)
                 new_cluster_centers[i, :] = closest_point

         return new_cluster_centers
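All four call sites above share the same fix: ``comm.Bcast`` operates on raw buffers, so the process-local ``torch.Tensor`` behind a ``DNDarray`` (its ``.larray``) has to be passed rather than the ``DNDarray`` wrapper itself. A minimal sketch of the corrected pattern, assuming a 4-element array and root rank 0 for illustration:

# sketch: broadcast the underlying torch.Tensor, not the DNDarray wrapper
import heat as ht

x = ht.zeros((4,))
if x.comm.rank == 0:
    x.larray[:] = 42.0
x.comm.Bcast(x.larray, root=0)  # every rank now holds the root's values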
diff --git a/heat/communication_backends/__init__.py b/heat/communication_backends/__init__.py
new file mode 100644
index 0000000000..d8bbcf8e81
--- /dev/null
+++ b/heat/communication_backends/__init__.py
@@ -0,0 +1,6 @@
+"""
+Add the communication_backends functions to the ht.communication_backends namespace
+"""
+
+from .communication import *
+from .mpi4py4torch import *
diff --git a/heat/communication_backends/communication.py b/heat/communication_backends/communication.py
new file mode 100644
index 0000000000..f1d9f21096
--- /dev/null
+++ b/heat/communication_backends/communication.py
@@ -0,0 +1,169 @@
+"""
+Module implementing the communication layer of HeAT
+"""
+from __future__ import annotations
+import torch
+from typing import Optional, Tuple
+from ..core.stride_tricks import sanitize_axis
+
+
+class Communication:
+    """
+    Base class for Communications (intended for other backends)
+    """
+
+    @staticmethod
+    def is_distributed() -> NotImplementedError:
+        """
+        Whether or not the Communication is distributed
+        """
+        raise NotImplementedError()
+
+    def __init__(self) -> NotImplementedError:
+        raise NotImplementedError()
+
+    def chunk(
+        self,
+        shape: Tuple[int],
+        split: int,
+        rank: int = None,
+        w_size: int = None,
+        sparse: bool = False,
+    ) -> Tuple[int, Tuple[int], Tuple[slice]]:
+        """
+        Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split
+        axis.
+        Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the
+        global input shape is chunked on the split axis and the chunk slices with respect to the given shape
+
+        Parameters
+        ----------
+        shape : Tuple[int,...]
+            The global shape of the data to be split
+        split : int
+            The axis along which to chunk the data
+        rank : int, optional
+            Process for which the chunking is calculated, defaults to ``self.rank``.
+            Intended for creating chunk maps without communication
+        w_size : int, optional
+            The MPI world size, defaults to ``self.size``.
+            Intended for creating chunk maps without communication
+        sparse : bool, optional
+            Specifies whether the array is a sparse matrix
+        """
+        # ensure the split axis is valid
+        split = sanitize_axis(shape, split)
+        if split is None:
+            return 0, shape, tuple(slice(0, end) for end in shape)
+        rank = self.rank if rank is None else rank
+        w_size = self.size if w_size is None else w_size
+        if not isinstance(rank, int) or not isinstance(w_size, int):
+            raise TypeError("rank and size must be integers")
+
+        dims = len(shape)
+        size = shape[split]
+        chunk = size // w_size
+        remainder = size % w_size
+
+        if remainder > rank:
+            chunk += 1
+            start = rank * chunk
+        else:
+            start = rank * chunk + remainder
+        end = start + chunk
+
+        if sparse:
+            return start, end
+
+        return (
+            start,
+            tuple(shape[i] if i != split else end - start for i in range(dims)),
+            tuple(slice(0, shape[i]) if i != split else slice(start, end) for i in range(dims)),
+        )
+
+    def counts_displs_shape(
+        self, shape: Tuple[int], axis: int
+    ) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
+        """
+        Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g.
+        ``MPI_Alltoallv``). The passed shape is regularly chunked along the given axis across all nodes.
+
+        Parameters
+        ----------
+        shape : Tuple[int,...]
+            The object for which to calculate the chunking.
+        axis : int
+            The axis along which the chunking is performed.
+
+        """
+        # the elements send/received by all nodes
+        counts = torch.full((self.size,), shape[axis] // self.size)
+        counts[: shape[axis] % self.size] += 1
+
+        # the displacements into the buffer
+        displs = torch.zeros((self.size,), dtype=counts.dtype)
+        torch.cumsum(counts[:-1], out=displs[1:], dim=0)
+
+        # helper that calculates the output shape for a receiving buffer under the assumption all nodes have an equally
+        # sized input compared to this node
+        output_shape = list(shape)
+        output_shape[axis] = self.size * counts[self.rank].item()
+
+        return tuple(counts.tolist()), tuple(displs.tolist()), tuple(output_shape)
+
+
+# create duplicates of the MPI communicators, so HeAT traffic does not interfere with user-level MPI calls
+from .mpi4py4torch import MPICommunication
+from mpi4py import MPI
+
+comm = MPI.COMM_WORLD
+dup_comm = comm.Dup()
+
+MPI_WORLD = MPICommunication(dup_comm)
+MPI_SELF = MPICommunication(MPI.COMM_SELF.Dup())
+
+# set the default communicator to be MPI_WORLD
+__default_comm = MPI_WORLD
+
+
+def get_comm() -> Communication:
+    """
+    Retrieves the currently globally set default communication.
+    """
+    return __default_comm
+
+
+def sanitize_comm(comm: Optional[Communication]) -> Communication:
+    """
+    Sanitizes a communication object, i.e. checks whether it is already an instance of
+    :class:`Communication` and falls back to the global default communication if ``None`` is passed.
+
+    Parameters
+    ----------
+    comm : Communication, optional
+        The communication object to be sanitized
+
+    Raises
+    ------
+    TypeError
+        If the given communication is not the proper type
+    """
+    if comm is None:
+        return get_comm()
+    elif isinstance(comm, Communication):
+        return comm
+
+    raise TypeError(f"Unknown communication, must be instance of {Communication}")
+
+
+def use_comm(comm: Communication = None):
+    """
+    Sets the globally used default communicator.
+
+    Parameters
+    ----------
+    comm : Communication or None
+        The communication to be set
+    """
+    global __default_comm
+    __default_comm = sanitize_comm(comm)
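The new module gives every process a deterministic, communication-free view of how a global shape is partitioned. A usage sketch of ``chunk`` together with the default-communicator helpers (the shape, split axis, and three-process run are illustrative assumptions):

# sketch: run under e.g. `mpirun -n 3 python chunk_example.py` (hypothetical file name)
from heat.communication_backends.communication import MPI_WORLD, get_comm, use_comm

use_comm(MPI_WORLD)  # redundant here, MPI_WORLD is already the default
comm = get_comm()

# a (10, 4) array split along axis 0 over 3 processes yields local shapes
# (4, 4), (3, 4) and (3, 4); offset and slices locate each chunk globally
offset, local_shape, slices = comm.chunk((10, 4), split=0)
print(comm.rank, offset, local_shape, slices)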
diff --git a/heat/core/communication.py b/heat/communication_backends/mpi4py4torch.py
similarity index 79%
rename from heat/core/communication.py
rename to heat/communication_backends/mpi4py4torch.py
index d505364b03..1260e6df8b 100644
--- a/heat/core/communication.py
+++ b/heat/communication_backends/mpi4py4torch.py
@@ -1,2063 +1,1850 @@
-"""
-Module implementing the communication layer of HeAT
-"""
-from __future__ import annotations
-
-import numpy as np
-import os
-import subprocess
-import torch
-from mpi4py import MPI
-
-from typing import Any, Callable, Optional, List, Tuple, Union
-from .stride_tricks import sanitize_axis
-
-CUDA_AWARE_MPI = False
-# check whether OpenMPI support CUDA-aware MPI
-if "openmpi" in os.environ.get("MPI_SUFFIX", "").lower():
-    buffer = subprocess.check_output(["ompi_info", "--parsable", "--all"])
-    CUDA_AWARE_MPI = b"mpi_built_with_cuda_support:value:true" in buffer
-# MVAPICH
-CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MV2_USE_CUDA") == "1"
-# MPICH
-CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MPIR_CVAR_ENABLE_HCOLL") == "1"
-# ParaStationMPI
-CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("PSP_CUDA") == "1"
-
-
-class MPIRequest:
-    """
-    Represents a handle on a non-blocking operation
-
-    Parameters
-    ----------
-    handle: MPI.Communicator
-        Handle for the mpi4py Communicator
-    sendbuf: DNDarray or torch.Tensor or Any
-        The buffer for the data to be send
-    recvbuf: DNDarray or torch.Tensor or Any
-        The buffer to the receive data
-    tensor: torch.Tensor
-        Internal Data
-    permutation: Tuple[int,...]
-        Permutation of the tensor axes
-    """
-
-    def __init__(
-        self,
-        handle,
-        sendbuf: Union[DNDarray, torch.Tensor, Any] = None,
-        recvbuf: Union[DNDarray, torch.Tensor, Any] = None,
-        tensor: torch.Tensor = None,
-        permutation: Tuple[int, ...] = None,
-    ):
-        self.handle = handle
-        self.tensor = tensor
-        self.recvbuf = recvbuf
-        self.sendbuf = sendbuf
-        self.permutation = permutation
-
-    def Wait(self, status: MPI.Status = None):
-        """
-        Waits for an MPI request to complete
-        """
-        self.handle.Wait(status)
-        if self.tensor is not None and isinstance(self.tensor, torch.Tensor):
-            if self.permutation is not None:
-                self.recvbuf = self.recvbuf.permute(self.permutation)
-            if self.tensor is not None and self.tensor.is_cuda and not CUDA_AWARE_MPI:
-                self.tensor.copy_(self.recvbuf)
-
-    def __getattr__(self, name: str) -> Callable:
-        """
-        Default pass-through for the communicator methods.
-
-        Parameters
-        ----------
-        name : str
-            The name of the method to be called.
-        """
-        return getattr(self.handle, name)
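The ``Wait`` logic above is the reason the non-blocking wrappers return an ``MPIRequest`` rather than a bare mpi4py request: without CUDA-aware MPI, a receive into a GPU tensor is staged through a CPU copy, and only ``Wait`` moves the data back into the original tensor. A sketch of the intended call pattern (buffer size and ranks are illustrative assumptions):

# sketch: non-blocking receive through the HeAT communicator
import torch
import heat as ht

comm = ht.MPI_WORLD
if comm.size > 1:
    if comm.rank == 0:
        comm.Send(torch.arange(5.0), dest=1)
    elif comm.rank == 1:
        buf = torch.zeros(5)
        req = comm.Irecv(buf, source=0)
        req.Wait()  # for CUDA buffers on non-CUDA-aware MPI, the copy-back happens here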
- """ - return getattr(self.handle, name) - - -class Communication: - """ - Base class for Communications (inteded for other backends) - """ - - @staticmethod - def is_distributed() -> NotImplementedError: - """ - Whether or not the Communication is distributed - """ - raise NotImplementedError() - - def __init__(self) -> NotImplementedError: - raise NotImplementedError() - - def chunk(self, shape, split) -> NotImplementedError: - """ - Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split - axis. Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the - global input shape is chunked on the split axis and the chunk slices with respect to the given shape - - Parameters - ---------- - shape : Tuple[int,...] - The global shape of the data to be split - split : int - The axis along which to chunk the data - - """ - raise NotImplementedError() - - -class MPICommunication(Communication): - """ - Class encapsulating all MPI Communication - - Parameters - ---------- - handle: MPI.Communicator - Handle for the mpi4py Communicator - """ - - __mpi_type_mappings = { - torch.bool: MPI.BOOL, - torch.uint8: MPI.UNSIGNED_CHAR, - torch.int8: MPI.SIGNED_CHAR, - torch.int16: MPI.SHORT, - torch.int32: MPI.INT, - torch.int64: MPI.LONG, - torch.bfloat16: MPI.INT16_T, - torch.float16: MPI.INT16_T, - torch.float32: MPI.FLOAT, - torch.float64: MPI.DOUBLE, - torch.complex64: MPI.COMPLEX, - torch.complex128: MPI.DOUBLE_COMPLEX, - } - - def __init__(self, handle=MPI.COMM_WORLD): - self.handle = handle - try: - self.rank = handle.Get_rank() - self.size = handle.Get_size() - except MPI.Exception: - # ranks not within the group will fail with an MPI.Exception, this is expected - self.rank = None - self.size = None - - def is_distributed(self) -> bool: - """ - Determines whether the communicator is distributed, i.e. handles more than one node. - """ - return self.size > 1 - - def chunk( - self, - shape: Tuple[int], - split: int, - rank: int = None, - w_size: int = None, - sparse: bool = False, - ) -> Tuple[int, Tuple[int], Tuple[slice]]: - """ - Calculates the chunk of data that will be assigned to this compute node given a global data shape and a split - axis. - Returns ``(offset, local_shape, slices)``: the offset in the split dimension, the resulting local shape if the - global input shape is chunked on the split axis and the chunk slices with respect to the given shape - - Parameters - ---------- - shape : Tuple[int,...] - The global shape of the data to be split - split : int - The axis along which to chunk the data - rank : int, optional - Process for which the chunking is calculated for, defaults to ``self.rank``. - Intended for creating chunk maps without communication - w_size : int, optional - The MPI world size, defaults to ``self.size``. 
- Intended for creating chunk maps without communication - sparse : bool, optional - Specifies whether the array is a sparse matrix - """ - # ensure the split axis is valid, we actually do not need it - split = sanitize_axis(shape, split) - if split is None: - return 0, shape, tuple(slice(0, end) for end in shape) - rank = self.rank if rank is None else rank - w_size = self.size if w_size is None else w_size - if not isinstance(rank, int) or not isinstance(w_size, int): - raise TypeError("rank and size must be integers") - - dims = len(shape) - size = shape[split] - chunk = size // w_size - remainder = size % w_size - - if remainder > rank: - chunk += 1 - start = rank * chunk - else: - start = rank * chunk + remainder - end = start + chunk - - if sparse: - return start, end - - return ( - start, - tuple(shape[i] if i != split else end - start for i in range(dims)), - tuple(slice(0, shape[i]) if i != split else slice(start, end) for i in range(dims)), - ) - - def counts_displs_shape( - self, shape: Tuple[int], axis: int - ) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]: - """ - Calculates the item counts, displacements and output shape for a variable sized all-to-all MPI-call (e.g. - ``MPI_Alltoallv``). The passed shape is regularly chunk along the given axis and for all nodes. - - Parameters - ---------- - shape : Tuple[int,...] - The object for which to calculate the chunking. - axis : int - The axis along which the chunking is performed. - - """ - # the elements send/received by all nodes - counts = torch.full((self.size,), shape[axis] // self.size) - counts[: shape[axis] % self.size] += 1 - - # the displacements into the buffer - displs = torch.zeros((self.size,), dtype=counts.dtype) - torch.cumsum(counts[:-1], out=displs[1:], dim=0) - - # helper that calculates the output shape for a receiving buffer under the assumption all nodes have an equally - # sized input compared to this node - output_shape = list(shape) - output_shape[axis] = self.size * counts[self.rank].item() - - return tuple(counts.tolist()), tuple(displs.tolist()), tuple(output_shape) - - @classmethod - def mpi_type_and_elements_of( - cls, - obj: Union[DNDarray, torch.Tensor], - counts: Tuple[int], - displs: Tuple[int], - is_contiguous: Optional[bool], - ) -> Tuple[MPI.Datatype, Tuple[int, ...]]: - """ - Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray` - or ``torch.Tensor). In case the tensor is contiguous in memory, a native MPI data type can be used. - Otherwise, a derived data type is automatically constructed using the storage information of the passed object. - - Parameters - ---------- - obj : DNDarray or torch.Tensor - The object for which to construct the MPI data type and number of elements - counts : Tuple[ints,...], optional - Optional counts arguments for variable MPI-calls (e.g. Alltoallv) - displs : Tuple[ints,...], optional - Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) - is_contiguous: bool - Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``. 
- # ToDo: The option to explicitely specify the counts and displacements to be send still needs propper implementation - """ - mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj) - - # simple case, contiguous memory can be transmitted as is - if is_contiguous is None: - # determine local contiguity - is_contiguous = obj.is_contiguous() - - if is_contiguous: - if counts is None: - return mpi_type, elements - factor = np.prod(obj.shape[1:]) - return ( - mpi_type, - ( - tuple(factor * ele for ele in counts), - (tuple(factor * ele for ele in displs)), - ), - ) - - # non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types - elements = obj.shape[0] - shape = obj.shape[1:] - strides = [1] * len(shape) - strides[0] = obj.stride()[-1] - strides = strides[::-1] - offsets = [obj.element_size() * stride for stride in obj.stride()[:-1]] - - # chain the types based on the - for i in range(len(shape) - 1, -1, -1): - mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) - mpi_type.Commit() - - if counts is not None: - return mpi_type, (counts, displs) - - return mpi_type, elements - - @classmethod - def as_mpi_memory(cls, obj) -> MPI.memory: - """ - Converts the passed ``torch.Tensor`` into an MPI compatible memory view. - - Parameters - ---------- - obj : torch.Tensor - The tensor to be converted into a MPI memory view. - """ - return MPI.memory.fromaddress(obj.data_ptr(), 0) - - @classmethod - def as_buffer( - cls, - obj: torch.Tensor, - counts: Tuple[int] = None, - displs: Tuple[int] = None, - is_contiguous: Optional[bool] = None, - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: - """ - Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. - - Parameters - ---------- - obj : torch.Tensor - The object to be converted into a buffer representation. - counts : Tuple[int,...], optional - Optional counts arguments for variable MPI-calls (e.g. Alltoallv) - displs : Tuple[int,...], optional - Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) - is_contiguous: bool, optional - Optional information on global contiguity of the memory-distributed object. - """ - squ = False - if not obj.is_contiguous() and obj.ndim == 1: - # this makes the math work below this function. - obj.unsqueeze_(-1) - squ = True - - mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous) - mpi_mem = cls.as_mpi_memory(obj) - if squ: - # the squeeze happens in the mpi_type_and_elements_of function in the case of a - # non-contiguous 1D tensor. Squeezing it puts the memory back to where it should be - obj.squeeze_(-1) - return [mpi_mem, elements, mpi_type] - - def alltoall_sendbuffer( - self, obj: torch.Tensor - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: - """ - Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. - XXX: might not work for all MPI stacks. 
Might require multiple type commits or so - - Parameters - ---------- - obj: torch.Tensor - The object to be transformed into a custom MPI datatype - """ - mpi_type = self.__mpi_type_mappings[obj.dtype] - - nproc = self.size - shape = obj.shape - strides = [1] * len(shape) - strides[-1] = obj.stride()[-1] - offsets = [0] * len(shape) - offsets[1:] = [obj.element_size() * stride for stride in obj.stride()[:-1]] - - # Step 1: Wrap along axes > 1 (all axes except send_axis and recv_axis - for i in range(len(shape) - 1, 1, -1): - mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) - mpi_type.Commit() - - # Step 2: Create Custom sized vector datatypes, according to rank-specific size along send_axis - # send_elements has nproc entries, defining how many vectors of mpi_type are stacked together for each process to receive along the send_axis - send_elements = np.full((nproc,), obj.shape[1] // nproc) - send_elements[: obj.shape[1] % nproc] += 1 - - # Create short_Type from the last entry of send_elements - mpi_short_type = mpi_type.Create_vector(send_elements[-1], 1, strides[1]).Create_resized( - 0, offsets[1] - ) - mpi_short_type.Commit() - # Create long_Type from the first entry of send_elements (wraps one more mpi_type vector than short_Type - mpi_long_type = mpi_type.Create_vector(send_elements[0], 1, strides[1]).Create_resized( - 0, offsets[1] - ) - mpi_long_type.Commit() - - # Step 3: Pack short_type and long_type along the recv_axis - mpi_short_type = mpi_short_type.Create_vector(shape[0], 1, strides[0]).Create_resized( - 0, send_elements[-1] * obj.stride()[1] * obj.element_size() - ) - mpi_short_type.Commit() - mpi_long_type = mpi_long_type.Create_vector(shape[0], 1, strides[0]).Create_resized( - 0, send_elements[0] * obj.stride()[1] * obj.element_size() - ) - mpi_long_type.Commit() - - # Step 4: Prepare sencounts, senddispls and sendtypes for alltoallw - # to each process 1 element (=sendcount) of the custom prepared long or short type will be send - sendcount = [1] * nproc - tmp_displs = [0] * nproc - tmp_displs[1:] = np.cumsum(send_elements[:-1]) - element_size = obj.element_size() - senddispls = [element_size * obj.stride()[1] * d for d in tmp_displs] - sendtypes = [mpi_short_type] * nproc - for i in range(obj.shape[1] % nproc): - sendtypes[i] = mpi_long_type - - return self.as_mpi_memory(obj), (sendcount, senddispls), sendtypes - - def alltoall_recvbuffer( - self, obj: torch.Tensor - ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: - """ - Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. - XXX: might not work for all MPI stacks. 
Might require multiple type commits or so - - Parameters - ---------- - obj: torch.Tensor - The object to be transformed into a custom MPI datatype - """ - mpi_type, _ = self.__mpi_type_mappings[obj.dtype], torch.numel(obj) - - nproc = self.size - shape = obj.shape[1:] - strides = [1] * len(shape) - strides[0] = obj.stride()[-1] - strides = strides[::-1] - offsets = [obj.element_size() * stride for stride in obj.stride()[:-1]] - - # Step 1: Wrap along axes > 0 (all axes except recv_axis) - for i in range(len(shape) - 1, -1, -1): - mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) - mpi_type.Commit() - - # Step 2: Receive blocks along the recv axis - # Prepare recvcount, senddispls and sendtypes for alltoallw - recvcount = np.full((nproc,), obj.shape[0] // nproc) - recvcount[: obj.shape[0] % nproc] += 1 - # size/extent of mpitype = offsets[0] - tmp_displs = [0] * nproc - tmp_displs[1:] = np.cumsum(recvcount[:-1]) - recvdispls = [offsets[0] * d for d in tmp_displs] - recvtypes = [mpi_type] * nproc - - return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes - - def Free(self) -> None: - """ - Free a communicator. - """ - self.handle.Free() - - def Split(self, color: int = 0, key: int = 0) -> MPICommunication: - """ - Split communicator by color and key. - - Parameters - ---------- - color : int, optional - Determines the new communicator for a process. - key: int, optional - Ordering within the new communicator. - """ - return MPICommunication(self.handle.Split(color, key)) - - def Irecv( - self, - buf: Union[DNDarray, torch.Tensor, Any], - source: int = MPI.ANY_SOURCE, - tag: int = MPI.ANY_TAG, - ) -> MPIRequest: - """ - Nonblocking receive - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to place the received message - source: int, optional - Rank of source process, that send the message - tag: int, optional - A Tag to identify the message - """ - if isinstance(buf, DNDarray): - buf = buf.larray - if not isinstance(buf, torch.Tensor): - return MPIRequest(self.handle.Irecv(buf, source, tag)) - - rbuf = buf if CUDA_AWARE_MPI else buf.cpu() - return MPIRequest(self.handle.Irecv(self.as_buffer(rbuf), source, tag), None, rbuf, buf) - - Irecv.__doc__ = MPI.Comm.Irecv.__doc__ - - def Recv( - self, - buf: Union[DNDarray, torch.Tensor, Any], - source: int = MPI.ANY_SOURCE, - tag: int = MPI.ANY_TAG, - status: MPI.Status = None, - ): - """ - Blocking receive - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to place the received message - source: int, optional - Rank of the source process, that send the message - tag: int, optional - A Tag to identify the message - status: MPI.Status, optional - Details on the communication - """ - if isinstance(buf, DNDarray): - buf = buf.larray - if not isinstance(buf, torch.Tensor): - return self.handle.Recv(buf, source, tag, status) - - rbuf = buf if CUDA_AWARE_MPI else buf.cpu() - ret = self.handle.Recv(self.as_buffer(rbuf), source, tag, status) - - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Recv.__doc__ = MPI.Comm.Recv.__doc__ - - def __send_like( - self, func: Callable, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int - ) -> Tuple[Optional[Union[DNDarray, torch.Tensor]]]: - """ - Generic function for sending a message to process with rank "dest" - - Parameters - ------------ - func: Callable - The respective MPI sending function - buf: 
Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - if isinstance(buf, DNDarray): - buf = buf.larray - if not isinstance(buf, torch.Tensor): - return func(buf, dest, tag), None - - # in case of GPUs, the memory has to be copied to host memory if CUDA-aware MPI is not supported - sbuf = buf if CUDA_AWARE_MPI else buf.cpu() - return func(self.as_buffer(sbuf), dest, tag), sbuf - - def Bsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): - """ - Blocking buffered send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Index of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return self.__send_like(self.handle.Bsend, buf, dest, tag)[0] - - Bsend.__doc__ = MPI.Comm.Bsend.__doc__ - - def Ibsend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: - """ - Nonblocking buffered send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return MPIRequest(*self.__send_like(self.handle.Ibsend, buf, dest, tag)) - - Ibsend.__doc__ = MPI.Comm.Ibsend.__doc__ - - def Irsend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: - """ - Nonblocking ready send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return MPIRequest(*self.__send_like(self.handle.Irsend, buf, dest, tag)) - - Irsend.__doc__ = MPI.Comm.Irsend.__doc__ - - def Isend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: - """ - Nonblocking send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return MPIRequest(*self.__send_like(self.handle.Isend, buf, dest, tag)) - - Isend.__doc__ = MPI.Comm.Isend.__doc__ - - def Issend( - self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0 - ) -> MPIRequest: - """ - Nonblocking synchronous send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return MPIRequest(*self.__send_like(self.handle.Issend, buf, dest, tag)) - - Issend.__doc__ = MPI.Comm.Issend.__doc__ - - def Rsend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): - """ - Blocking ready send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return self.__send_like(self.handle.Rsend, buf, dest, tag)[0] - - Rsend.__doc__ = 
MPI.Comm.Rsend.__doc__ - - def Ssend(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): - """ - Blocking synchronous send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return self.__send_like(self.handle.Ssend, buf, dest, tag)[0] - - Ssend.__doc__ = MPI.Comm.Ssend.__doc__ - - def Send(self, buf: Union[DNDarray, torch.Tensor, Any], dest: int, tag: int = 0): - """ - Blocking send - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be send - dest: int, optional - Rank of the destination process, that receives the message - tag: int, optional - A Tag to identify the message - """ - return self.__send_like(self.handle.Send, buf, dest, tag)[0] - - Send.__doc__ = MPI.Comm.Send.__doc__ - - def __broadcast_like( - self, func: Callable, buf: Union[DNDarray, torch.Tensor, Any], root: int - ) -> Tuple[Optional[DNDarray, torch.Tensor]]: - """ - Generic function for broadcasting a message from the process with rank "root" to all other processes of the - communicator - - Parameters - ------------ - func: Callable - The respective MPI broadcast function - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be broadcasted - root: int - Rank of the root process, that broadcasts the message - """ - # unpack the buffer if it is a HeAT tensor - if isinstance(buf, DNDarray): - buf = buf.larray - # convert torch tensors to MPI memory buffers - if not isinstance(buf, torch.Tensor): - return func(buf, root), None, None, None - - srbuf = buf if CUDA_AWARE_MPI else buf.cpu() - - return func(self.as_buffer(srbuf), root), srbuf, srbuf, buf - - def Bcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> None: - """ - Blocking Broadcast - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be broadcasted - root: int - Rank of the root process, that broadcasts the message - """ - ret, sbuf, rbuf, buf = self.__broadcast_like(self.handle.Bcast, buf, root) - if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Bcast.__doc__ = MPI.Comm.Bcast.__doc__ - - def Ibcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> MPIRequest: - """ - Nonblocking Broadcast - - Parameters - ------------ - buf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the message to be broadcasted - root: int - Rank of the root process, that broadcasts the message - """ - return MPIRequest(*self.__broadcast_like(self.handle.Ibcast, buf, root)) - - Ibcast.__doc__ = MPI.Comm.Ibcast.__doc__ - - def __reduce_like( - self, - func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - *args, - **kwargs, - ) -> Tuple[Optional[DNDarray, torch.Tensor]]: - """ - Generic function for reduction operations. 
- - Parameters - ------------ - func: Callable - The respective MPI reduction operation - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - """ - sbuf = None - rbuf = None - buf = None - # unpack the send buffer if it is a HeAT tensor - if isinstance(sendbuf, DNDarray): - sendbuf = sendbuf.larray - # unpack the receive buffer if it is a HeAT tensor - if isinstance(recvbuf, DNDarray): - recvbuf = recvbuf.larray - - # harmonize the input and output buffers - # MPI requires send and receive buffers to be of same type and length. If the torch tensors are either not both - # contiguous or differently strided, they have to be made matching (if possible) first. - if isinstance(sendbuf, torch.Tensor): - # convert the send buffer to a pointer, number of elements and type are identical to the receive buffer - dummy = ( - sendbuf.contiguous() - ) # make a contiguous copy and reassign the storage, old will be collected - # In PyTorch Version >= 2.0.0 we can use untyped_storage() instead of storage - # to keep backward compatibility with earlier PyTorch versions (where no untyped_storage() exists) we use a try/except - # (this applies to all places of Heat where untyped_storage() is used without further comment) - try: - sendbuf.set_( - dummy.untyped_storage(), - dummy.storage_offset(), - size=dummy.shape, - stride=dummy.stride(), - ) - except AttributeError: - sendbuf.set_( - dummy.storage(), - dummy.storage_offset(), - size=dummy.shape, - stride=dummy.stride(), - ) - sbuf = sendbuf if CUDA_AWARE_MPI else sendbuf.cpu() - sendbuf = self.as_buffer(sbuf) - if isinstance(recvbuf, torch.Tensor): - buf = recvbuf - # nothing matches, the buffers have to be made contiguous - dummy = recvbuf.contiguous() - try: - recvbuf.set_( - dummy.untyped_storage(), - dummy.storage_offset(), - size=dummy.shape, - stride=dummy.stride(), - ) - except AttributeError: - recvbuf.set_( - dummy.storage(), - dummy.storage_offset(), - size=dummy.shape, - stride=dummy.stride(), - ) - rbuf = recvbuf if CUDA_AWARE_MPI else recvbuf.cpu() - if sendbuf is MPI.IN_PLACE: - recvbuf = self.as_buffer(rbuf) - else: - recvbuf = (self.as_mpi_memory(rbuf), sendbuf[1], sendbuf[2]) - - # perform the actual reduction operation - return func(sendbuf, recvbuf, *args, **kwargs), sbuf, rbuf, buf - - def Allreduce( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ): - """ - Combines values from all processes and distributes the result back to all processes - - Parameters - --------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Allreduce, sendbuf, recvbuf, op) - if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Allreduce.__doc__ = MPI.Comm.Allreduce.__doc__ - - def Exscan( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ): - """ - Computes the exclusive scan (partial reductions) of data on a collection of processes - - Parameters - ------------ - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - 
recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Exscan, sendbuf, recvbuf, op) - if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Exscan.__doc__ = MPI.COMM_WORLD.Exscan.__doc__ - - def Iallreduce( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ) -> MPIRequest: - """ - Nonblocking allreduce reducing values on all processes to a single value - - Parameters - --------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - return MPIRequest(*self.__reduce_like(self.handle.Iallreduce, sendbuf, recvbuf, op)) - - Iallreduce.__doc__ = MPI.Comm.Iallreduce.__doc__ - - def Iexscan( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ) -> MPIRequest: - """ - Nonblocking Exscan - - Parameters - ------------ - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - return MPIRequest(*self.__reduce_like(self.handle.Iexscan, sendbuf, recvbuf, op)) - - Iexscan.__doc__ = MPI.COMM_WORLD.Iexscan.__doc__ - - def Iscan( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ) -> MPIRequest: - """ - Nonblocking Scan - - Parameters - ------------ - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - return MPIRequest(*self.__reduce_like(self.handle.Iscan, sendbuf, recvbuf, op)) - - Iscan.__doc__ = MPI.COMM_WORLD.Iscan.__doc__ - - def Ireduce( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - root: int = 0, - ) -> MPIRequest: - """ - Nonblocking reduction operation - - Parameters - --------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - root: int - Rank of the root process - """ - return MPIRequest(*self.__reduce_like(self.handle.Ireduce, sendbuf, recvbuf, op, root)) - - Ireduce.__doc__ = MPI.Comm.Ireduce.__doc__ - - def Reduce( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - root: int = 0, - ): - """ - Reduce values from all processes to a single value on process "root" - - Parameters - --------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - root: int - Rank of the root process - """ - ret, sbuf, rbuf, buf = 
self.__reduce_like(self.handle.Reduce, sendbuf, recvbuf, op, root) - if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Reduce.__doc__ = MPI.Comm.Reduce.__doc__ - - def Scan( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - op: MPI.Op = MPI.SUM, - ): - """ - Computes the scan (partial reductions) of data on a collection of processes in a nonblocking way - - Parameters - ------------ - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result of the reduction - op: MPI.Op - The operation to perform upon reduction - """ - ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Scan, sendbuf, recvbuf, op) - if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Scan.__doc__ = MPI.COMM_WORLD.Scan.__doc__ - - def __allgather_like( - self, - func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - axis: int, - **kwargs, - ): - """ - Generic function for allgather operations. - - Parameters - ---------- - func: Callable - Type of MPI Allgather function (i.e. allgather, allgatherv, iallgather) - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - axis: int - Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks - """ - # dummy allocation for *v calls - # ToDO: Propper implementation of usage - send_counts, send_displs, recv_counts, recv_displs = None, None, None, None - - # unpack the send buffer - if isinstance(sendbuf, tuple): - sendbuf, send_counts, send_displs = sendbuf - if isinstance(sendbuf, DNDarray): - sendbuf = sendbuf.larray - if not isinstance(sendbuf, torch.Tensor) and axis != 0: - raise TypeError( - f"sendbuf of type {type(sendbuf)} does not support concatenation axis != 0" - ) - # unpack the receive buffer - if isinstance(recvbuf, tuple): - recvbuf, recv_counts, recv_displs = recvbuf - if isinstance(recvbuf, DNDarray): - recvbuf = recvbuf.larray - if not isinstance(recvbuf, torch.Tensor) and axis != 0: - raise TypeError( - f"recvbuf of type {type(recvbuf)} does not support concatenation axis != 0" - ) - - # keep a reference to the original buffer object - original_recvbuf = recvbuf - sbuf_is_contiguous, rbuf_is_contiguous = None, None - # permute the send_axis order so that the split send_axis is the first to be transmitted - if axis != 0: - send_axis_permutation = list(range(sendbuf.ndimension())) - send_axis_permutation[0], send_axis_permutation[axis] = axis, 0 - sendbuf = sendbuf.permute(*send_axis_permutation) - sbuf_is_contiguous = False - - recv_axis_permutation = list(range(recvbuf.ndimension())) - recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0 - recvbuf = recvbuf.permute(*recv_axis_permutation) - rbuf_is_contiguous = False - else: - recv_axis_permutation = None - - sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu() - rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu() - - # prepare buffer objects - if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor): - mpi_sendbuf = sbuf - else: - mpi_sendbuf = 
self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous) - if send_counts is not None: - mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank] - - if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor): - mpi_recvbuf = rbuf - else: - mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous) - if recv_counts is None: - mpi_recvbuf[1] //= self.size - # perform the scatter operation - exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) - return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation - - def Allgather( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - recv_axis: int = 0, - ): - """ - Gathers data from all tasks and distribute the combined data to all tasks - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - recv_axis: int - Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks - """ - ret, sbuf, rbuf, buf, permutation = self.__allgather_like( - self.handle.Allgather, sendbuf, recvbuf, recv_axis - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Allgather.__doc__ = MPI.Comm.Allgather.__doc__ - - def Allgatherv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - recv_axis: int = 0, - ): - """ - v-call of Allgather: Each process may contribute a different amount of data. - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - recv_axis: int - Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks - """ - ret, sbuf, rbuf, buf, permutation = self.__allgather_like( - self.handle.Allgatherv, sendbuf, recvbuf, recv_axis - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Allgatherv.__doc__ = MPI.Comm.Allgatherv.__doc__ - - def Iallgather( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - recv_axis: int = 0, - ) -> MPIRequest: - """ - Nonblocking Allgather. - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - recv_axis: int - Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks - """ - return MPIRequest( - *self.__allgather_like(self.handle.Iallgather, sendbuf, recvbuf, recv_axis) - ) - - Iallgather.__doc__ = MPI.Comm.Iallgather.__doc__ - - def Iallgatherv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - recv_axis: int = 0, - ): - """ - Nonblocking v-call of Allgather: Each process may contribute a different amount of data. 
- - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - recv_axis: int - Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks - """ - return MPIRequest( - *self.__allgather_like(self.handle.Iallgatherv, sendbuf, recvbuf, recv_axis) - ) - - Iallgatherv.__doc__ = MPI.Comm.Iallgatherv.__doc__ - - def __alltoall_like( - self, - func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int, - recv_axis: int, - **kwargs, - ): - """ - Generic function for alltoall operations. - - Parameters - ---------- - func: Callable - Specific alltoall function - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - Future split axis, along which data blocks will be created that will be send to individual ranks - - - if ``send_axis==recv_axis``, an error will be thrown - - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown - recv_axis: int - Prior split axis, along which blocks are received from the individual ranks - """ - if send_axis is None: - raise NotImplementedError( - f"AllToAll needs send_axis and recv_axis to be specified but was send_axis = {send_axis}, recv_axis = {recv_axis}. Please set send_axis and recv_axis" - ) - # align the output buffer in the same way as the input buffer by default - if recv_axis is None: - recv_axis = send_axis - - # dummy allocation for *v calls - send_counts, send_displs, recv_counts, recv_displs = None, None, None, None - - # unpack the send buffer - if isinstance(sendbuf, tuple): - sendbuf, send_counts, send_displs = sendbuf - if isinstance(sendbuf, DNDarray): - sendbuf = sendbuf.larray - if not isinstance(sendbuf, torch.Tensor) and send_axis != 0: - raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0") - - # unpack the receive buffer - if isinstance(recvbuf, tuple): - recvbuf, recv_counts, recv_displs = recvbuf - if isinstance(recvbuf, DNDarray): - recvbuf = recvbuf.larray - if not isinstance(recvbuf, torch.Tensor) and send_axis != 0: - raise TypeError(f"recvbuf of type {type(recvbuf)} does not support send_axis != 0") - - # keep a reference to the original buffer object - original_recvbuf = recvbuf - - # Simple case, contiguous buffers can be transmitted as is - if send_axis < 2 and recv_axis < 2: - send_axis_permutation = list(range(recvbuf.ndimension())) - recv_axis_permutation = list(range(recvbuf.ndimension())) - - # Minimal Fix; Could possibly be improved when reworking counts, displs algorithmics - if self.size > 1: - send_axis_permutation[0], send_axis_permutation[send_axis] = (send_axis, 0) - recv_axis_permutation[0], recv_axis_permutation[recv_axis] = (recv_axis, 0) - - else: - recv_counts = send_counts - - sendbuf = sendbuf.permute(*send_axis_permutation) - recvbuf = recvbuf.permute(*recv_axis_permutation) - - # prepare buffer objects - sbuf = ( - sendbuf - if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) - else sendbuf.cpu() - ) - mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) - if send_counts is None: - mpi_sendbuf[1] //= self.size - - rbuf = ( - recvbuf - if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) - else recvbuf.cpu() - ) - mpi_recvbuf = 
self.as_buffer(rbuf, recv_counts, recv_displs) - if recv_counts is None: - mpi_recvbuf[1] //= self.size - - # perform the scatter operation - exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) - # slightly more difficult situation, send and receive buffer need custom datatype preparation; - # operation is performed via alltoallw - else: - if recv_axis == send_axis: - raise NotImplementedError( - "AllToAll for same axes not supported. Please choose send_axis and recv_axis to be different." - ) - - # Send_axis-Permutation: [recv_axis, send_axis, rest ...] - axis_permutation = list(range(recvbuf.ndimension())) - if send_axis == 0: - axis_permutation[1], axis_permutation[send_axis] = send_axis, 1 - axis_permutation[recv_axis] = axis_permutation[0] - axis_permutation[0] = recv_axis - - else: - axis_permutation[0], axis_permutation[recv_axis] = recv_axis, 0 - axis_permutation[send_axis] = axis_permutation[1] - axis_permutation[1] = send_axis - - sendbuf = sendbuf.permute(*axis_permutation) - recvbuf = recvbuf.permute(*axis_permutation) - - # prepare buffer objects - sbuf = ( - sendbuf - if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) - else sendbuf.cpu() - ) - rbuf = ( - recvbuf - if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) - else recvbuf.cpu() - ) - mpi_sendbuf = self.alltoall_sendbuffer(sbuf) - mpi_recvbuf = self.alltoall_recvbuffer(rbuf) - - exit_code = self.handle.Alltoallw(mpi_sendbuf, mpi_recvbuf, **kwargs) - # original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), original_recvbuf.shape, original_recvbuf.stride()) - recv_axis_permutation = list(np.argsort(np.array(axis_permutation))) - - return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation - - def Alltoall( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int = 0, - recv_axis: int = None, - ): - """ - All processes send data to all processes: The jth block sent from process i is received by process j and is - placed in the ith block of recvbuf. 
- - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - Future split axis, along which data blocks will be created that will be send to individual ranks - - - if ``send_axis==recv_axis``, an error will be thrown - - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown - recv_axis: int - Prior split axis, along which blocks are received from the individual ranks - """ - ret, sbuf, rbuf, buf, permutation = self.__alltoall_like( - self.handle.Alltoall, sendbuf, recvbuf, send_axis, recv_axis - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Alltoall.__doc__ = MPI.Comm.Alltoall.__doc__ - - def Alltoallv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int = 0, - recv_axis: int = None, - ): - """ - v-call of Alltoall: All processes send different amount of data to, and receive different amount of data - from, all processes - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - Future split axis, along which data blocks will be created that will be send to individual ranks - - - if ``send_axis==recv_axis``, an error will be thrown - - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown - recv_axis: int - Prior split axis, along which blocks are received from the individual ranks - """ - ret, sbuf, rbuf, buf, permutation = self.__alltoall_like( - self.handle.Alltoallv, sendbuf, recvbuf, send_axis, recv_axis - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Alltoallv.__doc__ = MPI.Comm.Alltoallv.__doc__ - - def Ialltoall( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Nonblocking Alltoall - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - Future split axis, along which data blocks will be created that will be send to individual ranks - - - if ``send_axis==recv_axis``, an error will be thrown - - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown - recv_axis: int - Prior split axis, along which blocks are received from the individual ranks - """ - return MPIRequest( - *self.__alltoall_like(self.handle.Ialltoall, sendbuf, recvbuf, send_axis, recv_axis) - ) - - Ialltoall.__doc__ = MPI.Comm.Ialltoall.__doc__ - - def Ialltoallv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Nonblocking v-call of Alltoall: All processes send different amount of data to, and receive different amount of - data from, all processes - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, 
Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - Future split axis, along which data blocks will be created that will be send to individual ranks - - - if ``send_axis==recv_axis``, an error will be thrown - - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown - recv_axis: int - Prior split axis, along which blocks are received from the individual ranks - """ - return MPIRequest( - *self.__alltoall_like(self.handle.Ialltoallv, sendbuf, recvbuf, send_axis, recv_axis) - ) - - Ialltoallv.__doc__ = MPI.Comm.Ialltoallv.__doc__ - - def __gather_like( - self, - func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int, - recv_axis: int, - send_factor: int = 1, - recv_factor: int = 1, - **kwargs, - ): - """ - Generic function for gather operations. - - Parameters - ---------- - func: Callable - Type of MPI Scatter/Gather function - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - send_factor: int - Number of elements to be scattered (vor non-v-calls) - recv_factor: int - Number of elements to be gathered (vor non-v-calls) - """ - sbuf, rbuf, recv_axis_permutation = None, None, None - - # align the output buffer in the same way as the input buffer by default - if recv_axis is None: - recv_axis = send_axis - - # dummy allocation for *v calls - send_counts, send_displs, recv_counts, recv_displs = None, None, None, None - - # unpack the send buffer - # if isinstance(sendbuf, tuple): - # sendbuf, send_counts, send_displs = sendbuf - if isinstance(sendbuf, DNDarray): - sendbuf = sendbuf.larray - if not isinstance(sendbuf, torch.Tensor) and send_axis != 0: - raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0") - - # unpack the receive buffer - if isinstance(recvbuf, tuple): - recvbuf, recv_counts, recv_displs = recvbuf - if isinstance(recvbuf, DNDarray): - recvbuf = recvbuf.larray - if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0: - raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0") - - # keep a reference to the original buffer object - original_recvbuf = recvbuf - - # permute the send_axis order so that the split send_axis is the first to be transmitted - send_axis_permutation = list(range(sendbuf.ndimension())) - send_axis_permutation[0], send_axis_permutation[send_axis] = send_axis, 0 - sendbuf = sendbuf.permute(*send_axis_permutation) - - if self.rank == kwargs.get("root"): - recv_axis_permutation = list(range(recvbuf.ndimension())) - recv_axis_permutation[0], recv_axis_permutation[recv_axis] = recv_axis, 0 - recvbuf = recvbuf.permute(*recv_axis_permutation) - - # prepare buffer objects - sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu() - rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu() - - if sendbuf is not MPI.IN_PLACE: - mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) - if send_counts is None: - mpi_sendbuf[1] //= send_factor - else: - mpi_sendbuf = sbuf - if recvbuf is not MPI.IN_PLACE: - mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs) - if recv_counts is None: - mpi_recvbuf[1] //= 
recv_factor - else: - mpi_recvbuf = rbuf - - # perform the scatter operation - exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) - - # undo the recvbuf permutation and assign the temporary buffer to the original recvbuf - # if recv_axis != 0: - # recvbuf = recvbuf.permute(*recv_axis_permutation) - # original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), recvbuf.shape, recvbuf.stride()) - - return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation - - def Gather( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ): - """ - Gathers together values from a group of processes - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of receiving process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - ret, sbuf, rbuf, buf, permutation = self.__gather_like( - self.handle.Gather, sendbuf, recvbuf, axis, recv_axis, root=root, recv_factor=self.size - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Gather.__doc__ = MPI.Comm.Gather.__doc__ - - def Gatherv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ): - """ - v-call for Gather: All processes send different amount of data - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of receiving process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - ret, sbuf, rbuf, buf, permutation = self.__gather_like( - self.handle.Gatherv, sendbuf, recvbuf, axis, recv_axis, root=root - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Gatherv.__doc__ = MPI.Comm.Gatherv.__doc__ - - def Igather( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Non-blocking Gather - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of receiving process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - return MPIRequest( - *self.__gather_like( - self.handle.Igather, - sendbuf, - recvbuf, - axis, - recv_axis, - root=root, - recv_factor=self.size, - ) - ) - - Igather.__doc__ = MPI.Comm.Igather.__doc__ - - def Igatherv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Non-blocking v-call 
for Gather: All processes send different amount of data - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of receiving process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - return MPIRequest( - *self.__gather_like( - self.handle.Igatherv, - sendbuf, - recvbuf, - axis, - recv_axis, - root=root, - recv_factor=self.size, - ) - ) - - Igatherv.__doc__ = MPI.Comm.Igatherv.__doc__ - - def __scatter_like( - self, - func: Callable, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - send_axis: int, - recv_axis: int, - send_factor: int = 1, - recv_factor: int = 1, - **kwargs, - ): - """ - Generic function for scatter operations. - - Parameters - ---------- - func: Callable - Type of MPI Scatter/Gather function - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - send_axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - send_factor: int - Number of elements to be scattered (vor non-v-calls) - recv_factor: int - Number of elements to be gathered (vor non-v-calls) - """ - sbuf, rbuf, recv_axis_permutation = None, None, None - - # align the output buffer in the same way as the input buffer by default - if recv_axis is None: - recv_axis = send_axis - - # dummy allocation for *v calls - send_counts, send_displs, recv_counts, recv_displs = None, None, None, None - - # unpack the send buffer - if isinstance(sendbuf, tuple): - sendbuf, send_counts, send_displs = sendbuf - if isinstance(sendbuf, DNDarray): - sendbuf = sendbuf.larray - if not isinstance(sendbuf, torch.Tensor) and send_axis != 0: - raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0") - - # unpack the receive buffer - # if isinstance(recvbuf, tuple): - # recvbuf, recv_counts, recv_displs = recvbuf - if isinstance(recvbuf, DNDarray): - recvbuf = recvbuf.larray - if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0: - raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0") - - # keep a reference to the original buffer object - original_recvbuf = recvbuf - - # permute the send_axis order so that the split send_axis is the first to be transmitted - if self.rank == kwargs.get("root"): - send_axis_permutation = list(range(sendbuf.ndimension())) - send_axis_permutation[0], send_axis_permutation[send_axis] = send_axis, 0 - sendbuf = sendbuf.permute(*send_axis_permutation) - - recv_axis_permutation = list(range(recvbuf.ndimension())) - recv_axis_permutation[0], recv_axis_permutation[recv_axis] = recv_axis, 0 - recvbuf = recvbuf.permute(*recv_axis_permutation) - - # prepare buffer objects - sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu() - rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu() - - if sendbuf is not MPI.IN_PLACE: - mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) - if send_counts is None: - mpi_sendbuf[1] //= send_factor - else: - mpi_sendbuf = sbuf - if recvbuf is not MPI.IN_PLACE: - mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs) - if recv_counts is None: - mpi_recvbuf[1] //= recv_factor - else: 
- mpi_recvbuf = rbuf - - # perform the scatter operation - exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) - - # undo the recvbuf permutation and assign the temporary buffer to the original recvbuf - # if recv_axis != 0: - # recvbuf = recvbuf.permute(*recv_axis_permutation) - # original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), recvbuf.shape, recvbuf.stride()) - - return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation - - def Iscatter( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Non-blocking Scatter - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of sending process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - return MPIRequest( - *self.__scatter_like( - self.handle.Iscatter, - sendbuf, - recvbuf, - axis, - recv_axis, - root=root, - send_factor=self.size, - ) - ) - - Iscatter.__doc__ = MPI.Comm.Iscatter.__doc__ - - def Iscatterv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ) -> MPIRequest: - """ - Non-blocking v-call for Scatter: Sends different amounts of data to different processes - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of sending process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - return MPIRequest( - *self.__scatter_like( - self.handle.Iscatterv, - sendbuf, - recvbuf, - axis, - recv_axis, - root=root, - send_factor=self.size, - ) - ) - - Iscatterv.__doc__ = MPI.Comm.Iscatterv.__doc__ - - def Scatter( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: Union[DNDarray, torch.Tensor, Any], - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ): - """ - Sends data parts from one process to all other processes in a communicator - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of sending process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - ret, sbuf, rbuf, buf, permutation = self.__scatter_like( - self.handle.Scatter, sendbuf, recvbuf, axis, recv_axis, root=root, send_factor=self.size - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Scatter.__doc__ = MPI.Comm.Scatter.__doc__ - - def Scatterv( - self, - sendbuf: Union[DNDarray, torch.Tensor, Any], - recvbuf: int, - root: int = 0, - axis: int = 0, - recv_axis: int = None, - ): - """ - v-call for Scatter: Sends different amounts of data to different processes - - Parameters - ---------- - sendbuf: Union[DNDarray, torch.Tensor, Any] - Buffer address of the send message - recvbuf: 
Union[DNDarray, torch.Tensor, Any] - Buffer address where to store the result - root: int - Rank of sending process - axis: int - The axis along which ``sendbuf`` is packed - recv_axis: int - The axis along which ``recvbuf`` is packed - """ - ret, sbuf, rbuf, buf, permutation = self.__scatter_like( - self.handle.Scatterv, - sendbuf, - recvbuf, - axis, - recv_axis, - root=root, - send_factor=self.size, - ) - if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: - rbuf = rbuf.permute(permutation) - if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: - buf.copy_(rbuf) - return ret - - Scatterv.__doc__ = MPI.Comm.Scatterv.__doc__ - - def __getattr__(self, name: str): - """ - Default pass-through for the communicator methods. - - Parameters - ---------- - name : str - The name of the method to be called. - """ - return getattr(self.handle, name) - - -# creating a duplicate COMM -comm = MPI.COMM_WORLD -dup_comm = comm.Dup() - -MPI_WORLD = MPICommunication(dup_comm) -MPI_SELF = MPICommunication(MPI.COMM_SELF.Dup()) - -# set the default communicator to be MPI_WORLD -__default_comm = MPI_WORLD - - -def get_comm() -> Communication: - """ - Retrieves the currently globally set default communication. - """ - return __default_comm - - -def sanitize_comm(comm: Optional[Communication]) -> Communication: - """ - Sanitizes a device or device identifier, i.e. checks whether it is already an instance of :class:`heat.core.devices.Device` - or a string with known device identifier and maps it to a proper ``Device``. - - Parameters - ---------- - comm : Communication - The comm to be sanitized - - Raises - ------ - TypeError - If the given communication is not the proper type - """ - if comm is None: - return get_comm() - elif isinstance(comm, Communication): - return comm - - raise TypeError(f"Unknown communication, must be instance of {Communication}") - - -def use_comm(comm: Communication = None): - """ - Sets the globally used default communicator. - - Parameters - ---------- - comm : Communication or None - The communication to be set - """ - global __default_comm - __default_comm = sanitize_comm(comm) - - -# import at the end of file to break circular dependencies -from .dndarray import DNDarray +""" +This implements wrappers for mpi4py for PyTorch and Heat +""" +from __future__ import annotations + +import numpy as np +import os +import subprocess +import torch +from mpi4py import MPI + +from typing import Any, Callable, Optional, List, Tuple, Union +from heat.communication_backends.communication import Communication + + +CUDA_AWARE_MPI = False +# check whether OpenMPI support CUDA-aware MPI +if "openmpi" in os.environ.get("MPI_SUFFIX", "").lower(): + buffer = subprocess.check_output(["ompi_info", "--parsable", "--all"]) + CUDA_AWARE_MPI = b"mpi_built_with_cuda_support:value:true" in buffer +# MVAPICH +CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MV2_USE_CUDA") == "1" +# MPICH +CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("MPIR_CVAR_ENABLE_HCOLL") == "1" +# ParaStationMPI +CUDA_AWARE_MPI = CUDA_AWARE_MPI or os.environ.get("PSP_CUDA") == "1" + + +class MPIRequest: + """ + Represents a handle on a non-blocking operation + + Parameters + ---------- + handle: MPI.Communicator + Handle for the mpi4py Communicator + sendbuf: torch.Tensor or Any + The buffer for the data to be send + recvbuf: torch.Tensor or Any + The buffer to the receive data + tensor: torch.Tensor + Internal Data + permutation: Tuple[int,...] 
+ Permutation of the tensor axes + """ + + def __init__( + self, + handle, + sendbuf: Union[torch.Tensor, Any] = None, + recvbuf: Union[torch.Tensor, Any] = None, + tensor: torch.Tensor = None, + permutation: Tuple[int, ...] = None, + ): + self.handle = handle + self.tensor = tensor + self.recvbuf = recvbuf + self.sendbuf = sendbuf + self.permutation = permutation + + def Wait(self, status: MPI.Status = None): + """ + Waits for an MPI request to complete + """ + self.handle.Wait(status) + if self.tensor is not None and isinstance(self.tensor, torch.Tensor): + if self.permutation is not None: + self.recvbuf = self.recvbuf.permute(self.permutation) + if self.tensor.is_cuda and not CUDA_AWARE_MPI: + self.tensor.copy_(self.recvbuf) + + def __getattr__(self, name: str) -> Callable: + """ + Default pass-through for the communicator methods. + + Parameters + ---------- + name : str + The name of the method to be called. + """ + return getattr(self.handle, name) + + +class MPICommunication(Communication): + """ + Class encapsulating all MPI Communication + + Parameters + ---------- + handle: MPI.Communicator + Handle for the mpi4py Communicator + """ + + __mpi_type_mappings = { + torch.bool: MPI.BOOL, + torch.uint8: MPI.UNSIGNED_CHAR, + torch.int8: MPI.SIGNED_CHAR, + torch.int16: MPI.SHORT, + torch.int32: MPI.INT, + torch.int64: MPI.LONG, + # MPI predefines no half-precision types: bfloat16/float16 are shipped bit-wise as + # 16-bit integers (safe for data movement, not for arithmetic reductions) + torch.bfloat16: MPI.INT16_T, + torch.float16: MPI.INT16_T, + torch.float32: MPI.FLOAT, + torch.float64: MPI.DOUBLE, + torch.complex64: MPI.COMPLEX, + torch.complex128: MPI.DOUBLE_COMPLEX, + } + + def __init__(self, handle=MPI.COMM_WORLD): + self.handle = handle + try: + self.rank = handle.Get_rank() + self.size = handle.Get_size() + except MPI.Exception: + # ranks not within the group will fail with an MPI.Exception, this is expected + self.rank = None + self.size = None + + def is_distributed(self) -> bool: + """ + Determines whether the communicator is distributed, i.e. handles more than one node. + """ + return self.size > 1 + + @classmethod + def mpi_type_and_elements_of( + cls, + obj: torch.Tensor, + counts: Tuple[int], + displs: Tuple[int], + is_contiguous: Optional[bool], + ) -> Tuple[MPI.Datatype, Tuple[int, ...]]: + """ + Determines the MPI data type and number of respective elements for the given ``torch.Tensor``. In case the tensor is contiguous in memory, a native MPI data type can be used. + Otherwise, a derived data type is automatically constructed using the storage information of the passed object. + + Parameters + ---------- + obj : torch.Tensor + The object for which to construct the MPI data type and number of elements + counts : Tuple[int,...], optional + Optional counts arguments for variable MPI-calls (e.g. Alltoallv) + displs : Tuple[int,...], optional + Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) + is_contiguous: bool + Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``.
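For intuition, a minimal PyTorch-only sketch of the contiguity distinction this helper branches on; illustrative only, not part of the patch:

```python
import torch

base = torch.zeros(4, 5)
view = base.T  # transpose: same storage, permuted strides

# contiguous tensors can be shipped with a native MPI type and a flat element count
assert base.is_contiguous() and torch.numel(base) == 20
# views such as the transpose are not contiguous and need a derived datatype
# that encodes the strides of the underlying storage
assert not view.is_contiguous()
assert view.stride() == (1, 5)
```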
+ # TODO: the option to explicitly specify the counts and displacements to be sent still needs proper implementation + """ + mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj) + + # simple case, contiguous memory can be transmitted as is + if is_contiguous is None: + # determine local contiguity + is_contiguous = obj.is_contiguous() + + if is_contiguous: + if counts is None: + return mpi_type, elements + factor = np.prod(obj.shape[1:]) + return ( + mpi_type, + ( + tuple(factor * ele for ele in counts), + (tuple(factor * ele for ele in displs)), + ), + ) + + # non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types + elements = obj.shape[0] + shape = obj.shape[1:] + strides = [1] * len(shape) + strides[0] = obj.stride()[-1] + strides = strides[::-1] + offsets = [obj.element_size() * stride for stride in obj.stride()[:-1]] + + # chain the derived types along the remaining dimensions, from innermost to outermost + for i in range(len(shape) - 1, -1, -1): + mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) + mpi_type.Commit() + + if counts is not None: + return mpi_type, (counts, displs) + + return mpi_type, elements + + @classmethod + def as_mpi_memory(cls, obj) -> MPI.memory: + """ + Converts the passed ``torch.Tensor`` into an MPI compatible memory view. + + Parameters + ---------- + obj : torch.Tensor + The tensor to be converted into a MPI memory view. + """ + return MPI.memory.fromaddress(obj.data_ptr(), 0) + + @classmethod + def as_buffer( + cls, + obj: torch.Tensor, + counts: Tuple[int] = None, + displs: Tuple[int] = None, + is_contiguous: Optional[bool] = None, + ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + """ + Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type. + + Parameters + ---------- + obj : torch.Tensor + The object to be converted into a buffer representation. + counts : Tuple[int,...], optional + Optional counts arguments for variable MPI-calls (e.g. Alltoallv) + displs : Tuple[int,...], optional + Optional displacements arguments for variable MPI-calls (e.g. Alltoallv) + is_contiguous: bool, optional + Optional information on global contiguity of the memory-distributed object. + """ + squ = False + if not obj.is_contiguous() and obj.ndim == 1: + # temporarily append a unit axis so the stride arithmetic in mpi_type_and_elements_of works for 1-D tensors + obj.unsqueeze_(-1) + squ = True + + mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous) + mpi_mem = cls.as_mpi_memory(obj) + if squ: + # the unsqueeze above handled the non-contiguous 1-D case; squeezing again + # restores the original memory layout + obj.squeeze_(-1) + return [mpi_mem, elements, mpi_type] + + def alltoall_sendbuffer( + self, obj: torch.Tensor + ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + """ + Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
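The wrapping performed in the alltoall buffer helpers follows the usual mpi4py derived-datatype pattern of ``Create_vector`` plus ``Create_resized``; a standalone toy example, assuming nothing beyond mpi4py:

```python
from mpi4py import MPI

# pick every other double out of a flat buffer: 4 blocks of length 1, stride 2 elements;
# resizing sets the extent (in bytes) so consecutive instances of the new type interleave
stride_bytes = 2 * MPI.DOUBLE.Get_size()
strided = MPI.DOUBLE.Create_vector(4, 1, 2).Create_resized(0, stride_bytes)
strided.Commit()
# ... the committed type can now describe a send/recv buffer ...
strided.Free()
```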
+ XXX: might not work for all MPI stacks; might require multiple type commits. + + Parameters + ---------- + obj: torch.Tensor + The object to be transformed into a custom MPI datatype + """ + mpi_type = self.__mpi_type_mappings[obj.dtype] + + nproc = self.size + shape = obj.shape + strides = [1] * len(shape) + strides[-1] = obj.stride()[-1] + offsets = [0] * len(shape) + offsets[1:] = [obj.element_size() * stride for stride in obj.stride()[:-1]] + + # Step 1: Wrap along axes > 1 (all axes except send_axis and recv_axis) + for i in range(len(shape) - 1, 1, -1): + mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) + mpi_type.Commit() + + # Step 2: Create custom-sized vector datatypes, according to the rank-specific size along send_axis. + # send_elements has nproc entries, defining how many vectors of mpi_type are stacked together for each process to receive along the send_axis + send_elements = np.full((nproc,), obj.shape[1] // nproc) + send_elements[: obj.shape[1] % nproc] += 1 + + # create short_type from the last entry of send_elements + mpi_short_type = mpi_type.Create_vector(send_elements[-1], 1, strides[1]).Create_resized( + 0, offsets[1] + ) + mpi_short_type.Commit() + # create long_type from the first entry of send_elements (wraps one more mpi_type vector than short_type) + mpi_long_type = mpi_type.Create_vector(send_elements[0], 1, strides[1]).Create_resized( + 0, offsets[1] + ) + mpi_long_type.Commit() + + # Step 3: Pack short_type and long_type along the recv_axis + mpi_short_type = mpi_short_type.Create_vector(shape[0], 1, strides[0]).Create_resized( + 0, send_elements[-1] * obj.stride()[1] * obj.element_size() + ) + mpi_short_type.Commit() + mpi_long_type = mpi_long_type.Create_vector(shape[0], 1, strides[0]).Create_resized( + 0, send_elements[0] * obj.stride()[1] * obj.element_size() + ) + mpi_long_type.Commit() + + # Step 4: Prepare sendcounts, senddispls and sendtypes for alltoallw; + # to each process, 1 element (= sendcount) of the custom prepared long or short type will be sent + sendcount = [1] * nproc + tmp_displs = [0] * nproc + tmp_displs[1:] = np.cumsum(send_elements[:-1]) + element_size = obj.element_size() + senddispls = [element_size * obj.stride()[1] * d for d in tmp_displs] + sendtypes = [mpi_short_type] * nproc + for i in range(obj.shape[1] % nproc): + sendtypes[i] = mpi_long_type + + return self.as_mpi_memory(obj), (sendcount, senddispls), sendtypes + + def alltoall_recvbuffer( + self, obj: torch.Tensor + ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]: + """ + Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
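Both ``alltoall_sendbuffer`` above and ``alltoall_recvbuffer`` below derive their counts and displacements from the same balanced block distribution; a self-contained sketch of that arithmetic (the helper name is illustrative):

```python
import numpy as np

def balanced_counts_displs(extent: int, nproc: int):
    # the first (extent % nproc) ranks get one extra slice, as in Step 2 above
    counts = np.full((nproc,), extent // nproc)
    counts[: extent % nproc] += 1
    displs = np.zeros_like(counts)
    displs[1:] = np.cumsum(counts[:-1])
    return counts, displs

# 10 columns over 4 ranks: counts [3 3 2 2], displacements [0 3 6 8]
print(balanced_counts_displs(10, 4))
```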
+ XXX: might not work for all MPI stacks; might require multiple type commits. + + Parameters + ---------- + obj: torch.Tensor + The object to be transformed into a custom MPI datatype + """ + mpi_type = self.__mpi_type_mappings[obj.dtype] + + nproc = self.size + shape = obj.shape[1:] + strides = [1] * len(shape) + strides[0] = obj.stride()[-1] + strides = strides[::-1] + offsets = [obj.element_size() * stride for stride in obj.stride()[:-1]] + + # Step 1: Wrap along axes > 0 (all axes except recv_axis) + for i in range(len(shape) - 1, -1, -1): + mpi_type = mpi_type.Create_vector(shape[i], 1, strides[i]).Create_resized(0, offsets[i]) + mpi_type.Commit() + + # Step 2: Receive blocks along the recv axis + # Prepare recvcount, recvdispls and recvtypes for alltoallw + recvcount = np.full((nproc,), obj.shape[0] // nproc) + recvcount[: obj.shape[0] % nproc] += 1 + # size/extent of mpitype = offsets[0] + tmp_displs = [0] * nproc + tmp_displs[1:] = np.cumsum(recvcount[:-1]) + recvdispls = [offsets[0] * d for d in tmp_displs] + recvtypes = [mpi_type] * nproc + + return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes + + def Free(self) -> None: + """ + Free a communicator. + """ + self.handle.Free() + + def Split(self, color: int = 0, key: int = 0) -> MPICommunication: + """ + Split communicator by color and key. + + Parameters + ---------- + color : int, optional + Determines the new communicator for a process. + key: int, optional + Ordering within the new communicator. + """ + return MPICommunication(self.handle.Split(color, key)) + + def Irecv( + self, + buf: Union[torch.Tensor, Any], + source: int = MPI.ANY_SOURCE, + tag: int = MPI.ANY_TAG, + ) -> MPIRequest: + """ + Nonblocking receive + + Parameters + ---------- + buf: Union[torch.Tensor, Any] + Buffer address where to place the received message + source: int, optional + Rank of the source process that sends the message + tag: int, optional + A Tag to identify the message + """ + if not isinstance(buf, torch.Tensor): + return MPIRequest(self.handle.Irecv(buf, source, tag)) + + rbuf = buf if CUDA_AWARE_MPI else buf.cpu() + return MPIRequest(self.handle.Irecv(self.as_buffer(rbuf), source, tag), None, rbuf, buf) + + Irecv.__doc__ = MPI.Comm.Irecv.__doc__ + + def Recv( + self, + buf: Union[torch.Tensor, Any], + source: int = MPI.ANY_SOURCE, + tag: int = MPI.ANY_TAG, + status: MPI.Status = None, + ): + """ + Blocking receive + + Parameters + ---------- + buf: Union[torch.Tensor, Any] + Buffer address where to place the received message + source: int, optional + Rank of the source process that sends the message + tag: int, optional + A Tag to identify the message + status: MPI.Status, optional + Details on the communication + """ + if not isinstance(buf, torch.Tensor): + return self.handle.Recv(buf, source, tag, status) + + rbuf = buf if CUDA_AWARE_MPI else buf.cpu() + ret = self.handle.Recv(self.as_buffer(rbuf), source, tag, status) + + if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Recv.__doc__ = MPI.Comm.Recv.__doc__ + + def __send_like( + self, func: Callable, buf: Union[torch.Tensor, Any], dest: int, tag: int + ) -> Tuple[Optional[Union[torch.Tensor]]]: + """ + Generic function for sending a message to the process with rank "dest" + + Parameters + ---------- + func: Callable + The respective MPI sending function + buf: Union[torch.Tensor, Any] + Buffer address of the message to be sent + dest: int + Rank of the destination process that receives the message + tag: int +
A Tag to identify the message + """ + if not isinstance(buf, torch.Tensor): + return func(buf, dest, tag), None + + # in case of GPUs, the memory has to be copied to host memory if CUDA-aware MPI is not supported + sbuf = buf if CUDA_AWARE_MPI else buf.cpu() + return func(self.as_buffer(sbuf), dest, tag), sbuf + + def Bsend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0): + """ + Blocking buffered send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Index of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return self.__send_like(self.handle.Bsend, buf, dest, tag)[0] + + Bsend.__doc__ = MPI.Comm.Bsend.__doc__ + + def Ibsend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: + """ + Nonblocking buffered send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return MPIRequest(*self.__send_like(self.handle.Ibsend, buf, dest, tag)) + + Ibsend.__doc__ = MPI.Comm.Ibsend.__doc__ + + def Irsend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: + """ + Nonblocking ready send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return MPIRequest(*self.__send_like(self.handle.Irsend, buf, dest, tag)) + + Irsend.__doc__ = MPI.Comm.Irsend.__doc__ + + def Isend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: + """ + Nonblocking send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return MPIRequest(*self.__send_like(self.handle.Isend, buf, dest, tag)) + + Isend.__doc__ = MPI.Comm.Isend.__doc__ + + def Issend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0) -> MPIRequest: + """ + Nonblocking synchronous send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return MPIRequest(*self.__send_like(self.handle.Issend, buf, dest, tag)) + + Issend.__doc__ = MPI.Comm.Issend.__doc__ + + def Rsend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0): + """ + Blocking ready send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return self.__send_like(self.handle.Rsend, buf, dest, tag)[0] + + Rsend.__doc__ = MPI.Comm.Rsend.__doc__ + + def Ssend(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0): + """ + Blocking synchronous send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + 
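# staging is handled by __send_like: CUDA tensors are copied through host memory unless CUDA_AWARE_MPI is set +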
return self.__send_like(self.handle.Ssend, buf, dest, tag)[0] + + Ssend.__doc__ = MPI.Comm.Ssend.__doc__ + + def Send(self, buf: Union[torch.Tensor, Any], dest: int, tag: int = 0): + """ + Blocking send + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be send + dest: int, optional + Rank of the destination process, that receives the message + tag: int, optional + A Tag to identify the message + """ + return self.__send_like(self.handle.Send, buf, dest, tag)[0] + + Send.__doc__ = MPI.Comm.Send.__doc__ + + def __broadcast_like( + self, func: Callable, buf: Union[torch.Tensor, Any], root: int + ) -> Tuple[Optional[torch.Tensor]]: + """ + Generic function for broadcasting a message from the process with rank "root" to all other processes of the + communicator + + Parameters + ------------ + func: Callable + The respective MPI broadcast function + buf: Union[torch.Tensor, Any] + Buffer address of the message to be broadcasted + root: int + Rank of the root process, that broadcasts the message + """ + # convert torch tensors to MPI memory buffers + if not isinstance(buf, torch.Tensor): + return func(buf, root), None, None, None + + srbuf = buf if CUDA_AWARE_MPI else buf.cpu() + + return func(self.as_buffer(srbuf), root), srbuf, srbuf, buf + + def Bcast(self, buf: Union[torch.Tensor, Any], root: int = 0) -> None: + """ + Blocking Broadcast + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be broadcasted + root: int + Rank of the root process, that broadcasts the message + """ + ret, sbuf, rbuf, buf = self.__broadcast_like(self.handle.Bcast, buf, root) + if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Bcast.__doc__ = MPI.Comm.Bcast.__doc__ + + def Ibcast(self, buf: Union[torch.Tensor, Any], root: int = 0) -> MPIRequest: + """ + Nonblocking Broadcast + + Parameters + ------------ + buf: Union[torch.Tensor, Any] + Buffer address of the message to be broadcasted + root: int + Rank of the root process, that broadcasts the message + """ + return MPIRequest(*self.__broadcast_like(self.handle.Ibcast, buf, root)) + + Ibcast.__doc__ = MPI.Comm.Ibcast.__doc__ + + def __reduce_like( + self, + func: Callable, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + *args, + **kwargs, + ) -> Tuple[Optional[torch.Tensor]]: + """ + Generic function for reduction operations. + + Parameters + ------------ + func: Callable + The respective MPI reduction operation + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + """ + sbuf = None + rbuf = None + buf = None + + # harmonize the input and output buffers + # MPI requires send and receive buffers to be of same type and length. If the torch tensors are either not both + # contiguous or differently strided, they have to be made matching (if possible) first. 
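+ # e.g. a transposed tensor shares storage with its base but is not contiguous; the contiguous() + # copies below materialize matching layouts before set_() swaps the new storage back in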
+ if isinstance(sendbuf, torch.Tensor): + # convert the send buffer to a pointer, number of elements and type are identical to the receive buffer + dummy = ( + sendbuf.contiguous() + ) # make a contiguous copy and reassign the storage, old will be collected + # In PyTorch versions >= 2.0.0 we can use untyped_storage() instead of storage(); + # to keep backward compatibility with earlier PyTorch versions (where no untyped_storage() exists) we use a try/except + # (this applies to all places of Heat where untyped_storage() is used without further comment) + try: + sendbuf.set_( + dummy.untyped_storage(), + dummy.storage_offset(), + size=dummy.shape, + stride=dummy.stride(), + ) + except AttributeError: + sendbuf.set_( + dummy.storage(), + dummy.storage_offset(), + size=dummy.shape, + stride=dummy.stride(), + ) + sbuf = sendbuf if CUDA_AWARE_MPI else sendbuf.cpu() + sendbuf = self.as_buffer(sbuf) + if isinstance(recvbuf, torch.Tensor): + buf = recvbuf + # the receive buffer has to be made contiguous as well + dummy = recvbuf.contiguous() + try: + recvbuf.set_( + dummy.untyped_storage(), + dummy.storage_offset(), + size=dummy.shape, + stride=dummy.stride(), + ) + except AttributeError: + recvbuf.set_( + dummy.storage(), + dummy.storage_offset(), + size=dummy.shape, + stride=dummy.stride(), + ) + rbuf = recvbuf if CUDA_AWARE_MPI else recvbuf.cpu() + if sendbuf is MPI.IN_PLACE: + recvbuf = self.as_buffer(rbuf) + else: + recvbuf = (self.as_mpi_memory(rbuf), sendbuf[1], sendbuf[2]) + + # perform the actual reduction operation + return func(sendbuf, recvbuf, *args, **kwargs), sbuf, rbuf, buf + + def Allreduce( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ): + """ + Combines values from all processes and distributes the result back to all processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + """ + ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Allreduce, sendbuf, recvbuf, op) + if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Allreduce.__doc__ = MPI.Comm.Allreduce.__doc__ + + def Exscan( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ): + """ + Computes the exclusive scan (partial reductions) of data on a collection of processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + """ + ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Exscan, sendbuf, recvbuf, op) + if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Exscan.__doc__ = MPI.COMM_WORLD.Exscan.__doc__ + + def Iallreduce( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ) -> MPIRequest: + """ + Nonblocking allreduce reducing values on all processes to a single value + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the
reduction + op: MPI.Op + The operation to perform upon reduction + """ + return MPIRequest(*self.__reduce_like(self.handle.Iallreduce, sendbuf, recvbuf, op)) + + Iallreduce.__doc__ = MPI.Comm.Iallreduce.__doc__ + + def Iexscan( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ) -> MPIRequest: + """ + Nonblocking Exscan + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + """ + return MPIRequest(*self.__reduce_like(self.handle.Iexscan, sendbuf, recvbuf, op)) + + Iexscan.__doc__ = MPI.COMM_WORLD.Iexscan.__doc__ + + def Iscan( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ) -> MPIRequest: + """ + Nonblocking Scan + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + """ + return MPIRequest(*self.__reduce_like(self.handle.Iscan, sendbuf, recvbuf, op)) + + Iscan.__doc__ = MPI.COMM_WORLD.Iscan.__doc__ + + def Ireduce( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + root: int = 0, + ) -> MPIRequest: + """ + Nonblocking reduction operation + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + root: int + Rank of the root process + """ + return MPIRequest(*self.__reduce_like(self.handle.Ireduce, sendbuf, recvbuf, op, root)) + + Ireduce.__doc__ = MPI.Comm.Ireduce.__doc__ + + def Reduce( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + root: int = 0, + ): + """ + Reduce values from all processes to a single value on process "root" + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + root: int + Rank of the root process + """ + ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Reduce, sendbuf, recvbuf, op, root) + if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Reduce.__doc__ = MPI.Comm.Reduce.__doc__ + + def Scan( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + op: MPI.Op = MPI.SUM, + ): + """ + Computes the scan (partial reductions) of data on a collection of processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result of the reduction + op: MPI.Op + The operation to perform upon reduction + """ + ret, sbuf, rbuf, buf = self.__reduce_like(self.handle.Scan, sendbuf, recvbuf, op) + if buf is not None and isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Scan.__doc__ = MPI.COMM_WORLD.Scan.__doc__ + + def __allgather_like( + self, + func: Callable, + sendbuf:
Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + axis: int, + **kwargs, + ): + """ + Generic function for allgather operations. + + Parameters + ---------- + func: Callable + Type of MPI Allgather function (i.e. allgather, allgatherv, iallgather) + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + axis: int + Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks + """ + # dummy allocation for *v calls + # TODO: proper implementation of usage + send_counts, send_displs, recv_counts, recv_displs = None, None, None, None + + # unpack the send buffer + if isinstance(sendbuf, tuple): + sendbuf, send_counts, send_displs = sendbuf + if not isinstance(sendbuf, torch.Tensor) and axis != 0: + raise TypeError( + f"sendbuf of type {type(sendbuf)} does not support concatenation axis != 0" + ) + # unpack the receive buffer + if isinstance(recvbuf, tuple): + recvbuf, recv_counts, recv_displs = recvbuf + if not isinstance(recvbuf, torch.Tensor) and axis != 0: + raise TypeError( + f"recvbuf of type {type(recvbuf)} does not support concatenation axis != 0" + ) + + # keep a reference to the original buffer object + original_recvbuf = recvbuf + sbuf_is_contiguous, rbuf_is_contiguous = None, None + # permute the send_axis order so that the split send_axis is the first to be transmitted + if axis != 0: + send_axis_permutation = list(range(sendbuf.ndimension())) + send_axis_permutation[0], send_axis_permutation[axis] = axis, 0 + sendbuf = sendbuf.permute(*send_axis_permutation) + sbuf_is_contiguous = False + + recv_axis_permutation = list(range(recvbuf.ndimension())) + recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0 + recvbuf = recvbuf.permute(*recv_axis_permutation) + rbuf_is_contiguous = False + else: + recv_axis_permutation = None + + sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu() + rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu() + + # prepare buffer objects + if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor): + mpi_sendbuf = sbuf + else: + mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous) + if send_counts is not None: + mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank] + + if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor): + mpi_recvbuf = rbuf + else: + mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous) + if recv_counts is None: + mpi_recvbuf[1] //= self.size + # perform the allgather operation + exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) + return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation + + def Allgather( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + recv_axis: int = 0, + ): + """ + Gathers data from all tasks and distributes the combined data to all tasks + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + recv_axis: int + Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks + """ + ret, sbuf, rbuf, buf, permutation = self.__allgather_like( + self.handle.Allgather, sendbuf, recvbuf, recv_axis + ) + if buf is not None and isinstance(buf, torch.Tensor) and
permutation is not None: + rbuf = rbuf.permute(permutation) + if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Allgather.__doc__ = MPI.Comm.Allgather.__doc__ + + def Allgatherv( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + recv_axis: int = 0, + ): + """ + v-call of Allgather: Each process may contribute a different amount of data. + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + recv_axis: int + Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks + """ + ret, sbuf, rbuf, buf, permutation = self.__allgather_like( + self.handle.Allgatherv, sendbuf, recvbuf, recv_axis + ) + if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: + rbuf = rbuf.permute(permutation) + if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Allgatherv.__doc__ = MPI.Comm.Allgatherv.__doc__ + + def Iallgather( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + recv_axis: int = 0, + ) -> MPIRequest: + """ + Nonblocking Allgather. + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + recv_axis: int + Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks + """ + return MPIRequest( + *self.__allgather_like(self.handle.Iallgather, sendbuf, recvbuf, recv_axis) + ) + + Iallgather.__doc__ = MPI.Comm.Iallgather.__doc__ + + def Iallgatherv( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + recv_axis: int = 0, + ): + """ + Nonblocking v-call of Allgather: Each process may contribute a different amount of data. + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + recv_axis: int + Concatenation axis: The axis along which ``sendbuf`` is packed and along which ``recvbuf`` puts together individual chunks + """ + return MPIRequest( + *self.__allgather_like(self.handle.Iallgatherv, sendbuf, recvbuf, recv_axis) + ) + + Iallgatherv.__doc__ = MPI.Comm.Iallgatherv.__doc__ + + def __alltoall_like( + self, + func: Callable, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int, + recv_axis: int, + **kwargs, + ): + """ + Generic function for alltoall operations. + + Parameters + ---------- + func: Callable + Specific alltoall function + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + Future split axis, along which data blocks will be created that will be send to individual ranks + + - if ``send_axis==recv_axis``, an error will be thrown + - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown + recv_axis: int + Prior split axis, along which blocks are received from the individual ranks + """ + if send_axis is None: + raise NotImplementedError( + f"AllToAll needs send_axis and recv_axis to be specified but was send_axis = {send_axis}, recv_axis = {recv_axis}. 
Please set send_axis and recv_axis" + ) + # align the output buffer in the same way as the input buffer by default + if recv_axis is None: + recv_axis = send_axis + + # dummy allocation for *v calls + send_counts, send_displs, recv_counts, recv_displs = None, None, None, None + + # unpack the send buffer + if isinstance(sendbuf, tuple): + sendbuf, send_counts, send_displs = sendbuf + if not isinstance(sendbuf, torch.Tensor) and send_axis != 0: + raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0") + + # unpack the receive buffer + if isinstance(recvbuf, tuple): + recvbuf, recv_counts, recv_displs = recvbuf + if not isinstance(recvbuf, torch.Tensor) and send_axis != 0: + raise TypeError(f"recvbuf of type {type(recvbuf)} does not support send_axis != 0") + + # keep a reference to the original buffer object + original_recvbuf = recvbuf + + # simple case, contiguous buffers can be transmitted as is + if send_axis < 2 and recv_axis < 2: + send_axis_permutation = list(range(recvbuf.ndimension())) + recv_axis_permutation = list(range(recvbuf.ndimension())) + + # minimal fix; could possibly be improved when reworking the counts/displs algorithmics + if self.size > 1: + send_axis_permutation[0], send_axis_permutation[send_axis] = (send_axis, 0) + recv_axis_permutation[0], recv_axis_permutation[recv_axis] = (recv_axis, 0) + + else: + recv_counts = send_counts + + sendbuf = sendbuf.permute(*send_axis_permutation) + recvbuf = recvbuf.permute(*recv_axis_permutation) + + # prepare buffer objects + sbuf = ( + sendbuf + if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) + else sendbuf.cpu() + ) + mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) + if send_counts is None: + mpi_sendbuf[1] //= self.size + + rbuf = ( + recvbuf + if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) + else recvbuf.cpu() + ) + mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs) + if recv_counts is None: + mpi_recvbuf[1] //= self.size + + # perform the alltoall operation + exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) + # slightly more difficult situation, send and receive buffer need custom datatype preparation; + # operation is performed via alltoallw + else: + if recv_axis == send_axis: + raise NotImplementedError( + "AllToAll for same axes not supported. Please choose send_axis and recv_axis to be different." + ) + + # Send_axis-Permutation: [recv_axis, send_axis, rest ...]
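+ # e.g. for ndim = 3, send_axis = 2, recv_axis = 0 this yields axis_permutation = [0, 2, 1]; + # np.argsort below recovers the inverse permutation so the result can be un-permuted after the Alltoallw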
+ axis_permutation = list(range(recvbuf.ndimension())) + if send_axis == 0: + axis_permutation[1], axis_permutation[send_axis] = send_axis, 1 + axis_permutation[recv_axis] = axis_permutation[0] + axis_permutation[0] = recv_axis + + else: + axis_permutation[0], axis_permutation[recv_axis] = recv_axis, 0 + axis_permutation[send_axis] = axis_permutation[1] + axis_permutation[1] = send_axis + + sendbuf = sendbuf.permute(*axis_permutation) + recvbuf = recvbuf.permute(*axis_permutation) + + # prepare buffer objects + sbuf = ( + sendbuf + if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) + else sendbuf.cpu() + ) + rbuf = ( + recvbuf + if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) + else recvbuf.cpu() + ) + mpi_sendbuf = self.alltoall_sendbuffer(sbuf) + mpi_recvbuf = self.alltoall_recvbuffer(rbuf) + + exit_code = self.handle.Alltoallw(mpi_sendbuf, mpi_recvbuf, **kwargs) + # original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), original_recvbuf.shape, original_recvbuf.stride()) + recv_axis_permutation = list(np.argsort(np.array(axis_permutation))) + + return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation + + def Alltoall( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int = 0, + recv_axis: int = None, + ): + """ + All processes send data to all processes: The jth block sent from process i is received by process j and is + placed in the ith block of recvbuf. + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + Future split axis, along which data blocks will be created that will be send to individual ranks + + - if ``send_axis==recv_axis``, an error will be thrown + - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown + recv_axis: int + Prior split axis, along which blocks are received from the individual ranks + """ + ret, sbuf, rbuf, buf, permutation = self.__alltoall_like( + self.handle.Alltoall, sendbuf, recvbuf, send_axis, recv_axis + ) + if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: + rbuf = rbuf.permute(permutation) + if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + buf.copy_(rbuf) + return ret + + Alltoall.__doc__ = MPI.Comm.Alltoall.__doc__ + + def Alltoallv( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int = 0, + recv_axis: int = None, + ): + """ + v-call of Alltoall: All processes send different amount of data to, and receive different amount of data + from, all processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + Future split axis, along which data blocks will be created that will be send to individual ranks + + - if ``send_axis==recv_axis``, an error will be thrown + - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown + recv_axis: int + Prior split axis, along which blocks are received from the individual ranks + """ + ret, sbuf, rbuf, buf, permutation = self.__alltoall_like( + self.handle.Alltoallv, sendbuf, recvbuf, send_axis, recv_axis + ) + if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None: + rbuf = rbuf.permute(permutation) + if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI: + 
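# copy the un-permuted host staging buffer back into the original CUDA tensor +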
buf.copy_(rbuf) + return ret + + Alltoallv.__doc__ = MPI.Comm.Alltoallv.__doc__ + + def Ialltoall( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int = 0, + recv_axis: int = None, + ) -> MPIRequest: + """ + Nonblocking Alltoall + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + Future split axis, along which data blocks will be created that will be send to individual ranks + + - if ``send_axis==recv_axis``, an error will be thrown + - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown + recv_axis: int + Prior split axis, along which blocks are received from the individual ranks + """ + return MPIRequest( + *self.__alltoall_like(self.handle.Ialltoall, sendbuf, recvbuf, send_axis, recv_axis) + ) + + Ialltoall.__doc__ = MPI.Comm.Ialltoall.__doc__ + + def Ialltoallv( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int = 0, + recv_axis: int = None, + ) -> MPIRequest: + """ + Nonblocking v-call of Alltoall: All processes send different amount of data to, and receive different amount of + data from, all processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + Future split axis, along which data blocks will be created that will be send to individual ranks + + - if ``send_axis==recv_axis``, an error will be thrown + - if ``send_axis`` or ``recv_axis`` are ``None``, an error will be thrown + recv_axis: int + Prior split axis, along which blocks are received from the individual ranks + """ + return MPIRequest( + *self.__alltoall_like(self.handle.Ialltoallv, sendbuf, recvbuf, send_axis, recv_axis) + ) + + Ialltoallv.__doc__ = MPI.Comm.Ialltoallv.__doc__ + + def __gather_like( + self, + func: Callable, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + send_axis: int, + recv_axis: int, + send_factor: int = 1, + recv_factor: int = 1, + **kwargs, + ): + """ + Generic function for gather operations. 
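+ Backs Gather, Gatherv, Igather and Igatherv: the send buffer is permuted so that the gather axis is transmitted first; on the root, the receive buffer is permuted accordingly and the permutation is undone after the call.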
+ + Parameters + ---------- + func: Callable + Type of MPI Scatter/Gather function + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + send_axis: int + The axis along which ``sendbuf`` is packed + recv_axis: int + The axis along which ``recvbuf`` is packed + send_factor: int + Number of elements to be scattered (for non-v-calls) + recv_factor: int + Number of elements to be gathered (for non-v-calls) + """ + sbuf, rbuf, recv_axis_permutation = None, None, None + + # align the output buffer in the same way as the input buffer by default + if recv_axis is None: + recv_axis = send_axis + + # dummy allocation for *v calls + send_counts, send_displs, recv_counts, recv_displs = None, None, None, None + + # unpack the send buffer + # if isinstance(sendbuf, tuple): + # sendbuf, send_counts, send_displs = sendbuf + if not isinstance(sendbuf, torch.Tensor) and send_axis != 0: + raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0") + + # unpack the receive buffer + if isinstance(recvbuf, tuple): + recvbuf, recv_counts, recv_displs = recvbuf + if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0: + raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0") + + # keep a reference to the original buffer object + original_recvbuf = recvbuf + + # permute the send_axis order so that the split send_axis is the first to be transmitted + send_axis_permutation = list(range(sendbuf.ndimension())) + send_axis_permutation[0], send_axis_permutation[send_axis] = send_axis, 0 + sendbuf = sendbuf.permute(*send_axis_permutation) + + if self.rank == kwargs.get("root"): + recv_axis_permutation = list(range(recvbuf.ndimension())) + recv_axis_permutation[0], recv_axis_permutation[recv_axis] = recv_axis, 0 + recvbuf = recvbuf.permute(*recv_axis_permutation) + + # prepare buffer objects + sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu() + rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu() + + if sendbuf is not MPI.IN_PLACE: + mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs) + if send_counts is None: + mpi_sendbuf[1] //= send_factor + else: + mpi_sendbuf = sbuf + if recvbuf is not MPI.IN_PLACE: + mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs) + if recv_counts is None: + mpi_recvbuf[1] //= recv_factor + else: + mpi_recvbuf = rbuf + + # perform the gather operation + exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs) + + # undo the recvbuf permutation and assign the temporary buffer to the original recvbuf + # if recv_axis != 0: + # recvbuf = recvbuf.permute(*recv_axis_permutation) + # original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), recvbuf.shape, recvbuf.stride()) + + return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation + + def Gather( + self, + sendbuf: Union[torch.Tensor, Any], + recvbuf: Union[torch.Tensor, Any], + root: int = 0, + axis: int = 0, + recv_axis: int = None, + ): + """ + Gathers together values from a group of processes + + Parameters + ---------- + sendbuf: Union[torch.Tensor, Any] + Buffer address of the send message + recvbuf: Union[torch.Tensor, Any] + Buffer address where to store the result + root: int + Rank of receiving process + axis: int + The axis along which ``sendbuf`` is packed + recv_axis: int + The axis along which ``recvbuf`` is packed + """ + ret, sbuf, rbuf, buf, permutation
+    def Gather(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ):
+        """
+        Gathers together values from a group of processes
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of receiving process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        ret, sbuf, rbuf, buf, permutation = self.__gather_like(
+            self.handle.Gather, sendbuf, recvbuf, axis, recv_axis, root=root, recv_factor=self.size
+        )
+        if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
+            rbuf = rbuf.permute(permutation)
+        if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI:
+            buf.copy_(rbuf)
+        return ret
+
+    Gather.__doc__ = MPI.Comm.Gather.__doc__
+
+    def Gatherv(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ):
+        """
+        v-call for Gather: All processes send different amounts of data
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of receiving process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        ret, sbuf, rbuf, buf, permutation = self.__gather_like(
+            self.handle.Gatherv, sendbuf, recvbuf, axis, recv_axis, root=root
+        )
+        if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
+            rbuf = rbuf.permute(permutation)
+        if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI:
+            buf.copy_(rbuf)
+        return ret
+
+    Gatherv.__doc__ = MPI.Comm.Gatherv.__doc__
+
+    def Igather(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ) -> MPIRequest:
+        """
+        Non-blocking Gather
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of receiving process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        return MPIRequest(
+            *self.__gather_like(
+                self.handle.Igather,
+                sendbuf,
+                recvbuf,
+                axis,
+                recv_axis,
+                root=root,
+                recv_factor=self.size,
+            )
+        )
+
+    Igather.__doc__ = MPI.Comm.Igather.__doc__
+
+    def Igatherv(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ) -> MPIRequest:
+        """
+        Non-blocking v-call for Gather: All processes send different amounts of data
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of receiving process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        return MPIRequest(
+            *self.__gather_like(
+                self.handle.Igatherv,
+                sendbuf,
+                recvbuf,
+                axis,
+                recv_axis,
+                root=root,
+                recv_factor=self.size,
+            )
+        )
+
+    Igatherv.__doc__ = MPI.Comm.Igatherv.__doc__
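
For the v-calls, the receive buffer is passed as a ``(tensor, counts, displs)`` tuple in mpi4py style, which is exactly the pattern the reworked tests further down in this diff exercise. A condensed sketch of a rank-dependent Gatherv (must be launched under mpirun; the shapes are illustrative):

    import numpy as np
    import heat as ht

    comm = ht.MPI_WORLD
    # rank r contributes r + 1 rows, so every process sends a different amount
    data = ht.ones((comm.rank + 1, 10))
    output = ht.zeros((comm.size * (comm.size + 1) // 2, 10))

    counts = tuple(range(1, comm.size + 1))        # rows received from each rank
    displs = tuple(np.cumsum((0,) + counts[:-1]))  # row offset of each block
    comm.Gatherv(data.larray, (output.larray, counts, displs), root=0)
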
+    def __scatter_like(
+        self,
+        func: Callable,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        send_axis: int,
+        recv_axis: int,
+        send_factor: int = 1,
+        recv_factor: int = 1,
+        **kwargs,
+    ):
+        """
+        Generic function for scatter operations.
+
+        Parameters
+        ----------
+        func: Callable
+            Type of MPI Scatter/Gather function
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        send_axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        send_factor: int
+            Number of elements to be scattered (for non-v-calls)
+        recv_factor: int
+            Number of elements to be gathered (for non-v-calls)
+        """
+        sbuf, rbuf, recv_axis_permutation = None, None, None
+
+        # align the output buffer in the same way as the input buffer by default
+        if recv_axis is None:
+            recv_axis = send_axis
+
+        # dummy allocation for *v calls
+        send_counts, send_displs, recv_counts, recv_displs = None, None, None, None
+
+        # unpack the send buffer
+        if isinstance(sendbuf, tuple):
+            sendbuf, send_counts, send_displs = sendbuf
+        if not isinstance(sendbuf, torch.Tensor) and send_axis != 0:
+            raise TypeError(f"sendbuf of type {type(sendbuf)} does not support send_axis != 0")
+
+        # unpack the receive buffer
+        # if isinstance(recvbuf, tuple):
+        #     recvbuf, recv_counts, recv_displs = recvbuf
+        if not isinstance(recvbuf, torch.Tensor) and recv_axis != 0:
+            raise TypeError(f"recvbuf of type {type(recvbuf)} does not support recv_axis != 0")
+
+        # keep a reference to the original buffer object
+        original_recvbuf = recvbuf
+
+        # permute the send_axis order so that the split send_axis is the first to be transmitted
+        if self.rank == kwargs.get("root"):
+            send_axis_permutation = list(range(sendbuf.ndimension()))
+            send_axis_permutation[0], send_axis_permutation[send_axis] = send_axis, 0
+            sendbuf = sendbuf.permute(*send_axis_permutation)
+
+        recv_axis_permutation = list(range(recvbuf.ndimension()))
+        recv_axis_permutation[0], recv_axis_permutation[recv_axis] = recv_axis, 0
+        recvbuf = recvbuf.permute(*recv_axis_permutation)
+
+        # prepare buffer objects
+        sbuf = sendbuf if CUDA_AWARE_MPI or not isinstance(sendbuf, torch.Tensor) else sendbuf.cpu()
+        rbuf = recvbuf if CUDA_AWARE_MPI or not isinstance(recvbuf, torch.Tensor) else recvbuf.cpu()
+
+        if sendbuf is not MPI.IN_PLACE:
+            mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs)
+            if send_counts is None:
+                mpi_sendbuf[1] //= send_factor
+        else:
+            mpi_sendbuf = sbuf
+        if recvbuf is not MPI.IN_PLACE:
+            mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs)
+            if recv_counts is None:
+                mpi_recvbuf[1] //= recv_factor
+        else:
+            mpi_recvbuf = rbuf
+
+        # perform the scatter operation
+        exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)
+
+        # undo the recvbuf permutation and assign the temporary buffer to the original recvbuf
+        # if recv_axis != 0:
+        #     recvbuf = recvbuf.permute(*recv_axis_permutation)
+        #     original_recvbuf.set_(recvbuf.untyped_storage(), recvbuf.storage_offset(), recvbuf.shape, recvbuf.stride())
+
+        return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation
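
The scatter path mirrors the gather path with the roles inverted: only the root permutes and chunks its send buffer, while every rank prepares a receive buffer of its own. In terms of the public API this enables, a sketch to be launched under mpirun (buffer shapes are illustrative; raw torch tensors are passed directly, as the reworked tests below do):

    import torch
    import heat as ht

    comm = ht.MPI_WORLD
    # every rank allocates the same (size, 5) send buffer; only root's content matters
    data = torch.arange(comm.size * 5).reshape(comm.size, 5)
    output = torch.empty((1, 5), dtype=data.dtype)

    comm.Scatter(data, output, root=0)
    # rank r now holds row r of the root's buffer
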
+    def Iscatter(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ) -> MPIRequest:
+        """
+        Non-blocking Scatter
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of sending process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        return MPIRequest(
+            *self.__scatter_like(
+                self.handle.Iscatter,
+                sendbuf,
+                recvbuf,
+                axis,
+                recv_axis,
+                root=root,
+                send_factor=self.size,
+            )
+        )
+
+    Iscatter.__doc__ = MPI.Comm.Iscatter.__doc__
+
+    def Iscatterv(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ) -> MPIRequest:
+        """
+        Non-blocking v-call for Scatter: Sends different amounts of data to different processes
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of sending process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        return MPIRequest(
+            *self.__scatter_like(
+                self.handle.Iscatterv,
+                sendbuf,
+                recvbuf,
+                axis,
+                recv_axis,
+                root=root,
+                send_factor=self.size,
+            )
+        )
+
+    Iscatterv.__doc__ = MPI.Comm.Iscatterv.__doc__
+
+    def Scatter(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ):
+        """
+        Sends data parts from one process to all other processes in a communicator
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of sending process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        ret, sbuf, rbuf, buf, permutation = self.__scatter_like(
+            self.handle.Scatter, sendbuf, recvbuf, axis, recv_axis, root=root, send_factor=self.size
+        )
+        if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
+            rbuf = rbuf.permute(permutation)
+        if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI:
+            buf.copy_(rbuf)
+        return ret
+
+    Scatter.__doc__ = MPI.Comm.Scatter.__doc__
+
+    def Scatterv(
+        self,
+        sendbuf: Union[torch.Tensor, Any],
+        recvbuf: Union[torch.Tensor, Any],
+        root: int = 0,
+        axis: int = 0,
+        recv_axis: int = None,
+    ):
+        """
+        v-call for Scatter: Sends different amounts of data to different processes
+
+        Parameters
+        ----------
+        sendbuf: Union[torch.Tensor, Any]
+            Buffer address of the send message
+        recvbuf: Union[torch.Tensor, Any]
+            Buffer address where to store the result
+        root: int
+            Rank of sending process
+        axis: int
+            The axis along which ``sendbuf`` is packed
+        recv_axis: int
+            The axis along which ``recvbuf`` is packed
+        """
+        ret, sbuf, rbuf, buf, permutation = self.__scatter_like(
+            self.handle.Scatterv,
+            sendbuf,
+            recvbuf,
+            axis,
+            recv_axis,
+            root=root,
+            send_factor=self.size,
+        )
+        if buf is not None and isinstance(buf, torch.Tensor) and permutation is not None:
+            rbuf = rbuf.permute(permutation)
+        if isinstance(buf, torch.Tensor) and buf.is_cuda and not CUDA_AWARE_MPI:
+            buf.copy_(rbuf)
+        return ret
+
+    Scatterv.__doc__ = MPI.Comm.Scatterv.__doc__
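
Scatterv is the mirror image of Gatherv: the ``(tensor, counts, displs)`` tuple moves to the send side, so the root can hand each rank a differently sized block. A condensed sketch matching the test patterns below (launch under mpirun; sizes are illustrative):

    import numpy as np
    import torch
    import heat as ht

    comm = ht.MPI_WORLD
    total = comm.size * (comm.size + 1) // 2

    # the root hands rank r exactly r + 1 rows of a (total, 4) buffer
    data = torch.arange(total * 4, dtype=torch.float32).reshape(total, 4)
    output = torch.empty((comm.rank + 1, 4), dtype=torch.float32)

    counts = tuple(range(1, comm.size + 1))
    displs = tuple(np.cumsum((0,) + counts[:-1]))
    comm.Scatterv((data, counts, displs), output, root=0)
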
+    def __getattr__(self, name: str):
+        """
+        Default pass-through for the communicator methods.
+
+        Parameters
+        ----------
+        name : str
+            The name of the method to be called.
+        """
+        return getattr(self.handle, name)
+
+
+# import at the end of file to break circular dependencies
diff --git a/heat/communication_backends/tests/__init__.py b/heat/communication_backends/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/heat/core/__init__.py b/heat/core/__init__.py
index 28d548439a..b4667c4601 100644
--- a/heat/core/__init__.py
+++ b/heat/core/__init__.py
@@ -4,7 +4,7 @@
 from .arithmetics import *
 from .base import *
-from .communication import *
+from ..communication_backends.communication import *
 from .constants import *
 from .complex_math import *
 from .devices import *
diff --git a/heat/core/_operations.py b/heat/core/_operations.py
index 1a9d6766e5..375cbf316a 100644
--- a/heat/core/_operations.py
+++ b/heat/core/_operations.py
@@ -5,7 +5,7 @@
 import torch
 import warnings
-from .communication import MPI, MPI_WORLD
+from ..communication_backends.communication import MPI, MPI_WORLD
 from . import factories
 from . import stride_tricks
 from . import sanitation
diff --git a/heat/core/arithmetics.py b/heat/core/arithmetics.py
index eed92c58c9..ed24b1c143 100644
--- a/heat/core/arithmetics.py
+++ b/heat/core/arithmetics.py
@@ -15,7 +15,7 @@
 from . import types
 from . import logical
-from .communication import MPI
+from ..communication_backends.communication import MPI
 from .dndarray import DNDarray
 from .types import (
     canonical_heat_type,
diff --git a/heat/core/devices.py b/heat/core/devices.py
index dfb69d2224..076ef68756 100644
--- a/heat/core/devices.py
+++ b/heat/core/devices.py
@@ -8,7 +8,7 @@
 from typing import Any, Optional, Union
-from . import communication
+from ..communication_backends import communication
 __all__ = ["Device", "cpu", "get_device", "sanitize_device", "use_device"]
diff --git a/heat/core/factories.py b/heat/core/factories.py
index 34213eefba..2bafa9b626 100644
--- a/heat/core/factories.py
+++ b/heat/core/factories.py
@@ -6,7 +6,7 @@
 from typing import Callable, Iterable, Optional, Sequence, Tuple, Type, Union, List
-from .communication import MPI, sanitize_comm, Communication
+from ..communication_backends.communication import MPI, sanitize_comm, Communication
 from .devices import Device
 from .dndarray import DNDarray
 from .memory import sanitize_memory_layout
diff --git a/heat/core/indexing.py b/heat/core/indexing.py
index 33d94c04d0..203163d4ec 100644
--- a/heat/core/indexing.py
+++ b/heat/core/indexing.py
@@ -5,7 +5,7 @@
 import torch
 from typing import List, Dict, Any, TypeVar, Union, Tuple, Sequence
-from .communication import MPI
+from ..communication_backends.communication import MPI
 from .dndarray import DNDarray
 from . import sanitation
 from . import types
diff --git a/heat/core/io.py b/heat/core/io.py
index c615a821b5..e4a01f5598 100644
--- a/heat/core/io.py
+++ b/heat/core/io.py
@@ -13,7 +13,7 @@
 from . import factories
 from . import types
-from .communication import Communication, MPI, MPI_WORLD, sanitize_comm
+from ..communication_backends.communication import Communication, MPI, MPI_WORLD, sanitize_comm
 from .dndarray import DNDarray
 from .manipulations import hsplit, vsplit
 from .statistics import max as smax, min as smin
diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py
index f71bf4b2d4..6e7fe513e4 100644
--- a/heat/core/linalg/basics.py
+++ b/heat/core/linalg/basics.py
@@ -10,7 +10,7 @@
 from torch._C import Value
-from ..communication import MPI
+from ...communication_backends.communication import MPI
 from .. import arithmetics
 from .. import complex_math
 from .. 
import constants @@ -530,7 +530,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm) c.larray[slice_0.start : slice_0.stop, :] += hold - c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM) + c.comm.Allreduce(MPI.IN_PLACE, c.larray, MPI.SUM) if gpu_int_flag: c = og_type(c, device=a.device) return c @@ -707,7 +707,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: c_idx = c.comm.chunk(c.shape, c.split)[2] c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop) c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop) - c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map, MPI.SUM) + c_wait = c.comm.Iallreduce(MPI.IN_PLACE, c_index_map.larray, MPI.SUM) if a.split == 0: a_block_map = torch.zeros( diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py index 4e3f0cea28..cc61b0fad5 100644 --- a/heat/core/linalg/qr.py +++ b/heat/core/linalg/qr.py @@ -5,7 +5,7 @@ import torch from typing import Type, Callable, Dict, Any, TypeVar, Union, Tuple -from ..communication import MPICommunication +from ...communication_backends.communication import MPICommunication from ..types import datatype from ..tiling import SquareDiagTiles from ..dndarray import DNDarray diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 3273fc739c..6dfc192c46 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -183,8 +183,12 @@ def lanczos( vi_loc = V._DNDarray__array[:, j] a = torch.dot(vr.larray, torch.conj(vi_loc)) b = torch.dot(vi_loc, torch.conj(vi_loc)) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM + ) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM + ) vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc # normalize v_r to Euclidean norm 1 and set as ith vector v vi = vr / ht.norm(vr) @@ -196,8 +200,12 @@ def lanczos( vi_loc = V.larray[:, j] a = torch.dot(vr._DNDarray__array, torch.conj(vi_loc)) b = torch.dot(vi_loc, torch.conj(vi_loc)) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM + ) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM + ) vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc vi = vr / ht.norm(vr) @@ -235,8 +243,12 @@ def lanczos( vi_loc = V._DNDarray__array[:, j] a = torch.dot(vr.larray, vi_loc) b = torch.dot(vi_loc, vi_loc) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM + ) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM + ) vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc # normalize v_r to Euclidean norm 1 and set as ith vector v vi = vr / ht.norm(vr) @@ -248,8 +260,12 @@ def lanczos( vi_loc = V.larray[:, j] a = torch.dot(vr._DNDarray__array, vi_loc) b = torch.dot(vi_loc, vi_loc) - 
A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM) - A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, a, ht.communication_backends.MPI.SUM + ) + A.comm.Allreduce( + ht.communication_backends.MPI.IN_PLACE, b, ht.communication_backends.MPI.SUM + ) vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc vi = vr / ht.norm(vr) diff --git a/heat/core/linalg/svdtools.py b/heat/core/linalg/svdtools.py index fb90406384..b60fea8077 100644 --- a/heat/core/linalg/svdtools.py +++ b/heat/core/linalg/svdtools.py @@ -6,7 +6,7 @@ import torch from typing import Type, Callable, Dict, Any, TypeVar, Union, Tuple, Optional -from ..communication import MPICommunication +from ...communication_backends.communication import MPICommunication from ..dndarray import DNDarray from .. import factories from .. import types diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py index f8f9889a9d..b10cadfa31 100644 --- a/heat/core/linalg/tests/test_solver.py +++ b/heat/core/linalg/tests/test_solver.py @@ -9,7 +9,7 @@ class TestSolver(TestCase): def test_cg(self): - size = ht.communication.MPI_WORLD.size * 3 + size = ht.communication_backends.MPI_WORLD.size * 3 b = ht.arange(1, size + 1, dtype=ht.float32, split=0) A = ht.manipulations.diag(b) x0 = ht.random.rand(size, dtype=b.dtype, split=b.split) diff --git a/heat/core/logical.py b/heat/core/logical.py index 49e2f332ac..84b8a22c20 100644 --- a/heat/core/logical.py +++ b/heat/core/logical.py @@ -14,7 +14,7 @@ from . import stride_tricks from . import types -from .communication import MPI +from ..communication_backends.communication import MPI from .dndarray import DNDarray __all__ = [ diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7ae8d9db86..c86483ee39 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -9,7 +9,7 @@ from typing import Iterable, Type, List, Callable, Union, Tuple, Sequence, Optional -from .communication import MPI +from ..communication_backends.communication import MPI from .dndarray import DNDarray from . import arithmetics diff --git a/heat/core/printing.py b/heat/core/printing.py index e06db65c50..0a03bf3f5b 100644 --- a/heat/core/printing.py +++ b/heat/core/printing.py @@ -3,7 +3,7 @@ import builtins import copy import torch -from .communication import MPI_WORLD +from ..communication_backends.communication import MPI_WORLD from .dndarray import DNDarray diff --git a/heat/core/random.py b/heat/core/random.py index c7accc3d5a..71146c040a 100644 --- a/heat/core/random.py +++ b/heat/core/random.py @@ -6,14 +6,14 @@ from typing import List, Optional, Tuple, Type, Union -from . import communication +from ..communication_backends import communication from . import devices from . import factories from . import logical from . import stride_tricks from . import types -from .communication import Communication +from ..communication_backends.communication import Communication from .devices import Device from .dndarray import DNDarray from .types import datatype diff --git a/heat/core/relational.py b/heat/core/relational.py index 8167dcd18a..6cfc790d2d 100644 --- a/heat/core/relational.py +++ b/heat/core/relational.py @@ -8,7 +8,7 @@ from typing import Union -from .communication import MPI +from ..communication_backends.communication import MPI from .dndarray import DNDarray from . import _operations from . 
import dndarray diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 6485e4139d..10a3bf8273 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -8,7 +8,7 @@ import warnings from typing import Any, Union, Sequence, List, Tuple -from .communication import MPI, Communication +from ..communication_backends.communication import MPI, Communication from .dndarray import DNDarray from . import factories diff --git a/heat/core/signal.py b/heat/core/signal.py index b5427705d4..31b1f79f4d 100644 --- a/heat/core/signal.py +++ b/heat/core/signal.py @@ -3,7 +3,7 @@ import torch import numpy as np -from .communication import MPI +from ..communication_backends.communication import MPI from .dndarray import DNDarray from .types import promote_types from .manipulations import pad, flip diff --git a/heat/core/statistics.py b/heat/core/statistics.py index bdb6765a1e..1fc1b1d064 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -5,7 +5,7 @@ import torch from typing import Any, Callable, Union, Tuple, List, Optional -from .communication import MPI +from ..communication_backends.communication import MPI from . import arithmetics from . import exponential from . import factories @@ -672,7 +672,7 @@ def histc( out = factories.empty( hist.size(), dtype=types.canonical_heat_type(hist.dtype), device=input.device ) - input.comm.Allreduce(hist, out, op=MPI.SUM) + input.comm.Allreduce(hist, out.larray, op=MPI.SUM) return out @@ -966,8 +966,8 @@ def reduce_means_elementwise(output_shape_i: torch.Tensor) -> DNDarray: n_tot = factories.zeros(x.comm.size, device=x.device) n_tot[x.comm.rank] = float(x.lshape[x.split]) mu_tot[x.comm.rank, :] = mu - x.comm.Allreduce(MPI.IN_PLACE, n_tot, MPI.SUM) - x.comm.Allreduce(MPI.IN_PLACE, mu_tot, MPI.SUM) + x.comm.Allreduce(MPI.IN_PLACE, n_tot.larray, MPI.SUM) + x.comm.Allreduce(MPI.IN_PLACE, mu_tot.larray, MPI.SUM) for i in range(1, x.comm.size): mu_tot[0, :], n_tot[0] = __merge_moments( @@ -999,7 +999,7 @@ def reduce_means_elementwise(output_shape_i: torch.Tensor) -> DNDarray: mu_tot = factories.zeros((x.comm.size, 2), device=x.device) mu_proc = factories.zeros((x.comm.size, 2), device=x.device) mu_proc[x.comm.rank] = mu_in, float(n) - x.comm.Allreduce(mu_proc, mu_tot, MPI.SUM) + x.comm.Allreduce(mu_proc.larray, mu_tot.larray, MPI.SUM) for i in range(1, x.comm.size): mu_tot[0, 0], mu_tot[0, 1] = __merge_moments( @@ -1637,7 +1637,7 @@ def _local_percentile(data: torch.Tensor, axis: int, indices: torch.Tensor) -> t comm=x.comm, balanced=True, ) - x.comm.Bcast(local_p, root=r) + x.comm.Bcast(local_p.larray, root=r) percentile[perc_slice] = local_p else: if x.comm.is_distributed() and split is not None: @@ -1943,7 +1943,7 @@ def reduce_vars_elementwise(output_shape_i: torch.Tensor) -> DNDarray: var_tot[x.comm.rank, 0, :] = var var_tot[x.comm.rank, 1, :] = mu var_tot[x.comm.rank, 2, :] = float(x.lshape[x.split]) - x.comm.Allreduce(MPI.IN_PLACE, var_tot, MPI.SUM) + x.comm.Allreduce(MPI.IN_PLACE, var_tot.larray, MPI.SUM) for i in range(1, x.comm.size): var_tot[0, 0, :], var_tot[0, 1, :], var_tot[0, 2, :] = __merge_moments( @@ -1974,7 +1974,7 @@ def reduce_vars_elementwise(output_shape_i: torch.Tensor) -> DNDarray: var_tot = factories.zeros((x.comm.size, 3), dtype=x.dtype, device=x.device) var_proc = factories.zeros((x.comm.size, 3), dtype=x.dtype, device=x.device) var_proc[x.comm.rank] = var_in, mu_in, float(n) - x.comm.Allreduce(var_proc, var_tot, MPI.SUM) + x.comm.Allreduce(var_proc.larray, var_tot.larray, MPI.SUM) for i in range(1, 
x.comm.size): var_tot[0, 0], var_tot[0, 1], var_tot[0, 2] = __merge_moments( diff --git a/heat/core/tests/test_communication.py b/heat/core/tests/test_communication.py index 48187a591b..3a20e4ccef 100644 --- a/heat/core/tests/test_communication.py +++ b/heat/core/tests/test_communication.py @@ -21,7 +21,7 @@ def setUpClass(cls): ) def test_self_communicator(self): - comm = ht.core.communication.MPI_SELF + comm = ht.communication_backends.MPI_SELF with self.assertRaises(ValueError): comm.chunk(self.data.shape, split=2) @@ -44,7 +44,7 @@ def test_self_communicator(self): self.assertEqual(1, (self.data == self.data[chunks]).all().item()) def test_mpi_communicator(self): - comm = ht.core.communication.MPI_WORLD + comm = ht.communication_backends.MPI_WORLD self.assertLess(comm.rank, comm.size) with self.assertRaises(ValueError): @@ -66,8 +66,8 @@ def test_mpi_communicator(self): self.assertEqual(len(chunks), len(self.data.shape)) def test_cuda_aware_mpi(self): - self.assertTrue(hasattr(ht.communication, "CUDA_AWARE_MPI")) - self.assertIsInstance(ht.communication.CUDA_AWARE_MPI, bool) + self.assertTrue(hasattr(ht.communication_backends, "CUDA_AWARE_MPI")) + self.assertIsInstance(ht.communication_backends.CUDA_AWARE_MPI, bool) def test_contiguous_memory_buffer(self): # vector heat tensor @@ -80,8 +80,8 @@ def test_contiguous_memory_buffer(self): self.assertTrue(vector_out.larray.is_contiguous()) # send message to self that is received into a separate buffer afterwards - req = vector_data.comm.Isend(vector_data, dest=vector_data.comm.rank) - vector_out.comm.Recv(vector_out, source=vector_out.comm.rank) + req = vector_data.comm.Isend(vector_data.larray, dest=vector_data.comm.rank) + vector_out.comm.Recv(vector_out.larray, source=vector_out.comm.rank) req.Wait() @@ -101,7 +101,7 @@ def test_contiguous_memory_buffer(self): self.assertTrue(tensor_out.is_contiguous()) # send message to self that is received into a separate buffer afterwards - comm = ht.core.communication.MPI_WORLD + comm = ht.communication_backends.MPI_WORLD req = comm.Isend(tensor_data, dest=comm.rank) comm.Recv(tensor_out, source=comm.rank) @@ -123,15 +123,15 @@ def test_non_contiguous_memory_buffer(self): # send message to self that is received into a separate buffer afterwards req = non_contiguous_data.comm.Isend( - non_contiguous_data, dest=non_contiguous_data.comm.rank + non_contiguous_data.larray, dest=non_contiguous_data.comm.rank ) - contiguous_out.comm.Recv(contiguous_out, source=contiguous_out.comm.rank) + contiguous_out.comm.Recv(contiguous_out.larray, source=contiguous_out.comm.rank) req.Wait() # check that after sending the data everything is equal self.assertTrue((non_contiguous_data.larray == contiguous_out.larray).all()) - if ht.get_device().device_type == "cpu" or ht.communication.CUDA_AWARE_MPI: + if ht.get_device().device_type == "cpu" or ht.communication_backends.CUDA_AWARE_MPI: self.assertTrue(contiguous_out.larray.is_contiguous()) # non-contiguous destination @@ -144,13 +144,13 @@ def test_non_contiguous_memory_buffer(self): self.assertFalse(non_contiguous_out.larray.is_contiguous()) # send message to self that is received into a separate buffer afterwards - req = contiguous_data.comm.Isend(contiguous_data, dest=contiguous_data.comm.rank) - non_contiguous_out.comm.Recv(non_contiguous_out, source=non_contiguous_out.comm.rank) + req = contiguous_data.comm.Isend(contiguous_data.larray, dest=contiguous_data.comm.rank) + non_contiguous_out.comm.Recv(non_contiguous_out.larray, source=non_contiguous_out.comm.rank) 
req.Wait() # check that after sending the data everything is equal self.assertTrue((contiguous_data.larray == non_contiguous_out.larray).all()) - if ht.get_device().device_type == "cpu" or ht.communication.CUDA_AWARE_MPI: + if ht.get_device().device_type == "cpu" or ht.communication_backends.CUDA_AWARE_MPI: self.assertFalse(non_contiguous_out.larray.is_contiguous()) # non-contiguous destination @@ -164,16 +164,16 @@ def test_non_contiguous_memory_buffer(self): # send message to self that is received into a separate buffer afterwards req = both_non_contiguous_data.comm.Isend( - both_non_contiguous_data, dest=both_non_contiguous_data.comm.rank + both_non_contiguous_data.larray, dest=both_non_contiguous_data.comm.rank ) both_non_contiguous_out.comm.Recv( - both_non_contiguous_out, source=both_non_contiguous_out.comm.rank + both_non_contiguous_out.larray, source=both_non_contiguous_out.comm.rank ) req.Wait() # check that after sending the data everything is equal self.assertTrue((both_non_contiguous_data.larray == both_non_contiguous_out.larray).all()) - if ht.get_device().device_type == "cpu" or ht.communication.CUDA_AWARE_MPI: + if ht.get_device().device_type == "cpu" or ht.communication_backends.CUDA_AWARE_MPI: self.assertFalse(both_non_contiguous_out.larray.is_contiguous()) def test_default_comm(self): @@ -219,7 +219,7 @@ def test_allgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Allgather(data, output) + data.comm.Allgather(data.larray, output.larray) # check result self.assertTrue(data.larray.is_contiguous()) @@ -237,7 +237,7 @@ def test_allgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Allgather(data, output, recv_axis=1) + data.comm.Allgather(data.larray, output.larray, recv_axis=1) # check result self.assertTrue(data.larray.is_contiguous()) @@ -256,7 +256,7 @@ def test_allgather(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Allgather(data, output) + data.comm.Allgather(data.larray, output.larray) # check result self.assertFalse(data.larray.is_contiguous()) @@ -275,7 +275,7 @@ def test_allgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - data.comm.Allgather(data, output, recv_axis=1) + data.comm.Allgather(data.larray, output.larray, recv_axis=1) # check result self.assertTrue(data.larray.is_contiguous()) @@ -296,7 +296,7 @@ def test_allgather(self): self.assertTrue(output.larray.is_contiguous()) # perform the allgather operation - data.comm.Allgather(data, output, recv_axis=0) + data.comm.Allgather(data.larray, output.larray, recv_axis=0) # check result result = ht.array([np.arange(0, ht.MPI_WORLD.size)] * 10).T @@ -311,7 +311,7 @@ def test_allgather(self): self.assertTrue(output.larray.is_contiguous()) # perform the allgather operation - data.comm.Allgather(data, output, recv_axis=1) + data.comm.Allgather(data.larray, output.larray, recv_axis=1) # check result result = ht.array([np.arange(0, ht.MPI_WORLD.size)] * 10) @@ -322,7 +322,7 @@ def test_allgather(self): output = ht.array([[0] * 3] * ht.MPI_WORLD.size) # perform the allgather operation - ht.MPI_WORLD.Allgatherv(data, output) + ht.MPI_WORLD.Allgatherv(data, output.larray) # check result result = ht.array([np.arange(0, ht.MPI_WORLD.size)] * 3).T @@ -332,7 +332,7 @@ 
def test_allgather(self): output = np.array([[0] * 3] * ht.MPI_WORLD.size) # perform the allgather operation - ht.MPI_WORLD.Allgatherv(data, output) + ht.MPI_WORLD.Allgatherv(data.larray, output) # check result result = np.array([np.arange(0, ht.MPI_WORLD.size)] * 3).T @@ -341,11 +341,11 @@ def test_allgather(self): with self.assertRaises(TypeError): data = np.array([ht.MPI_WORLD.rank] * 3) output = ht.array([[0] * 3 * ht.MPI_WORLD.size]) - ht.MPI_WORLD.Allgatherv(data, output, recv_axis=1) + ht.MPI_WORLD.Allgatherv(data, output.larray, recv_axis=1) with self.assertRaises(TypeError): data = ht.array([ht.MPI_WORLD.rank] * 3) output = np.array([[0] * 3 * ht.MPI_WORLD.size]) - ht.MPI_WORLD.Allgatherv(data, output, recv_axis=1) + ht.MPI_WORLD.Allgatherv(data.larray, output, recv_axis=1) def test_allgatherv(self): # contiguous data buffer, contiguous output buffer @@ -360,7 +360,7 @@ def test_allgatherv(self): # perform the allgather operation counts = tuple(range(1, ht.MPI_WORLD.size + 1)) displs = tuple(np.cumsum(range(ht.MPI_WORLD.size))) - data.comm.Allgatherv(data, (output, counts, displs)) + data.comm.Allgatherv(data.larray, (output.larray, counts, displs)) # check result self.assertTrue(data.larray.is_contiguous()) @@ -381,7 +381,7 @@ def test_allgatherv(self): # perform the allgather operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Allgatherv(data, (output, counts, displs)) + data.comm.Allgatherv(data.larray, (output.larray, counts, displs)) # check result self.assertFalse(data.larray.is_contiguous()) @@ -402,7 +402,7 @@ def test_allgatherv(self): # perform the allgather operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Allgatherv(data, (output, counts, displs)) + data.comm.Allgatherv(data.larray, (output.larray, counts, displs)) # check result self.assertTrue(data.larray.is_contiguous()) @@ -423,7 +423,7 @@ def test_allgatherv(self): # perform the allgather operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Allgatherv(data, (output, counts, displs)) + data.comm.Allgatherv(data.larray, (output.larray, counts, displs)) # check result self.assertFalse(data.larray.is_contiguous()) @@ -446,7 +446,9 @@ def test_allgatherv(self): # perform allgather operation send_counts, send_displs, _ = data.comm.counts_displs_shape(data.lshape, 0) recv_counts, recv_displs, _ = data.comm.counts_displs_shape(output.lshape, 0) - data.comm.Allgatherv((data, send_counts, send_displs), (output, recv_counts, recv_displs)) + data.comm.Allgatherv( + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) + ) # check result self.assertTrue(data.larray.is_contiguous()) @@ -462,7 +464,7 @@ def test_allreduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Allreduce(data, out, op=ht.MPI.SUM) + data.comm.Allreduce(data.larray, out.larray, op=ht.MPI.SUM) # check the reduction result self.assertTrue(data.larray.is_contiguous()) @@ -476,7 +478,7 @@ def test_allreduce(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Allreduce(data, out, op=ht.MPI.SUM) + data.comm.Allreduce(data.larray, out.larray, op=ht.MPI.SUM) # check the reduction result # the data tensor 
will be contiguous after the reduction @@ -494,7 +496,7 @@ def test_allreduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - data.comm.Allreduce(data, out, op=ht.MPI.SUM) + data.comm.Allreduce(data.larray, out.larray, op=ht.MPI.SUM) # check the reduction result # the data tensor will be contiguous after the reduction @@ -513,7 +515,7 @@ def test_alltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Alltoall(data, output) + data.comm.Alltoall(data.larray, output.larray) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -532,7 +534,7 @@ def test_alltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Alltoall(data, output, send_axis=1) + data.comm.Alltoall(data.larray, output.larray, send_axis=1) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -551,7 +553,7 @@ def test_alltoall(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Alltoall(data, output) + data.comm.Alltoall(data.larray, output.larray) # check scatter result self.assertFalse(data.larray.is_contiguous()) @@ -568,7 +570,7 @@ def test_alltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - data.comm.Alltoall(data, output, send_axis=1) + data.comm.Alltoall(data.larray, output.larray, send_axis=1) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -581,11 +583,11 @@ def test_alltoall(self): with self.assertRaises(TypeError): data = np.array([ht.MPI_WORLD.rank] * 3) output = ht.array([[0] * 3 * ht.MPI_WORLD.size]) - ht.MPI_WORLD.Alltoall(data, output, send_axis=1) + ht.MPI_WORLD.Alltoall(data, output.larray, send_axis=1) with self.assertRaises(TypeError): data = ht.array([ht.MPI_WORLD.rank] * 3) output = np.array([[0] * 3 * ht.MPI_WORLD.size]) - ht.MPI_WORLD.Alltoall(data, output, send_axis=1) + ht.MPI_WORLD.Alltoall(data.larray, output, send_axis=1) def test_alltoallv(self): # contiguous data buffer @@ -604,7 +606,9 @@ def test_alltoallv(self): else: self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) - data.comm.Alltoallv((data, send_counts, send_displs), (output, recv_counts, recv_displs)) + data.comm.Alltoallv( + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) + ) self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) stack_count = output_shape[0] // ht.MPI_WORLD.size * 10 @@ -632,7 +636,9 @@ def test_alltoallv(self): else: self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) - data.comm.Alltoallv((data, send_counts, send_displs), (output, recv_counts, recv_displs)) + data.comm.Alltoallv( + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) + ) self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) stack_count = output_shape[0] // ht.MPI_WORLD.size * 10 @@ -661,7 +667,9 @@ def test_alltoallv(self): else: self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) - data.comm.Alltoallv((data, send_counts, send_displs), (output, recv_counts, recv_displs)) + data.comm.Alltoallv( + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) + ) self.assertTrue(data.larray.is_contiguous()) 
self.assertFalse(output.larray.is_contiguous()) stack_count = output_shape[1] // ht.MPI_WORLD.size * 10 @@ -690,7 +698,9 @@ def test_alltoallv(self): else: self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) - data.comm.Alltoallv((data, send_counts, send_displs), (output, recv_counts, recv_displs)) + data.comm.Alltoallv( + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) + ) self.assertFalse(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) stack_count = output_shape[1] // ht.MPI_WORLD.size * 10 @@ -710,7 +720,7 @@ def test_bcast(self): # broadcast data to all nodes self.assertTrue(data.larray.is_contiguous()) - data.comm.Bcast(data, root=0) + data.comm.Bcast(data.larray, root=0) # assert output is equal self.assertTrue(data.larray.is_contiguous()) @@ -723,7 +733,7 @@ def test_bcast(self): # broadcast data to all nodes self.assertFalse(data.larray.is_contiguous()) - data.comm.Bcast(data, root=0) + data.comm.Bcast(data.larray, root=0) # assert output is equal self.assertFalse(data.larray.is_contiguous()) @@ -742,7 +752,7 @@ def test_exscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Exscan(data, out) + data.comm.Exscan(data.larray, out.larray) # check the reduction result self.assertTrue(data.larray.is_contiguous()) @@ -756,7 +766,7 @@ def test_exscan(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Exscan(data, out) + data.comm.Exscan(data.larray, out.larray) # check the reduction result # the data tensor will be contiguous after the reduction @@ -774,7 +784,7 @@ def test_exscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - data.comm.Exscan(data, out) + data.comm.Exscan(data.larray, out.larray) # check the reduction result # the data tensor will be contiguous after the reduction @@ -793,7 +803,7 @@ def test_gather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Gather(data, output, root=0) + data.comm.Gather(data.larray, output.larray, root=0) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -813,7 +823,7 @@ def test_gather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Gather(data, output, root=0, axis=1) + data.comm.Gather(data.larray, output.larray, root=0, axis=1) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -833,7 +843,7 @@ def test_gather(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - data.comm.Gather(data, output, root=0) + data.comm.Gather(data.larray, output.larray, root=0) # check scatter result self.assertFalse(data.larray.is_contiguous()) @@ -853,7 +863,7 @@ def test_gather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - data.comm.Gather(data, output, root=0, axis=1) + data.comm.Gather(data.larray, output.larray, root=0, axis=1) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -879,7 +889,7 @@ def test_gatherv(self): # perform the scatter operation counts = tuple(range(1, ht.MPI_WORLD.size + 1)) displs = tuple(np.cumsum(range(ht.MPI_WORLD.size))) - 
data.comm.Gatherv(data, (output, counts, displs), root=0) + data.comm.Gatherv(data.larray, (output.larray, counts, displs), root=0) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -903,7 +913,7 @@ def test_gatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Gatherv(data, (output, counts, displs), root=0) + data.comm.Gatherv(data.larray, (output.larray, counts, displs), root=0) # check scatter result self.assertFalse(data.larray.is_contiguous()) @@ -927,7 +937,7 @@ def test_gatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Gatherv(data, (output, counts, displs), root=0) + data.comm.Gatherv(data.larray, (output.larray, counts, displs), root=0) # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -951,7 +961,7 @@ def test_gatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - data.comm.Gatherv(data, (output, counts, displs), root=0) + data.comm.Gatherv(data.larray, (output.larray, counts, displs), root=0) # check scatter result self.assertFalse(data.larray.is_contiguous()) @@ -972,7 +982,7 @@ def test_iallgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iallgather(data, output) + req = data.comm.Iallgather(data.larray, output.larray) req.Wait() # check scatter result @@ -992,7 +1002,7 @@ def test_iallgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iallgather(data, output, recv_axis=1) + req = data.comm.Iallgather(data.larray, output.larray, recv_axis=1) req.Wait() # check scatter result self.assertTrue(data.larray.is_contiguous()) @@ -1011,7 +1021,7 @@ def test_iallgather(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iallgather(data, output) + req = data.comm.Iallgather(data.larray, output.larray) req.Wait() # check scatter result @@ -1031,7 +1041,7 @@ def test_iallgather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - req = data.comm.Iallgather(data, output, recv_axis=1) + req = data.comm.Iallgather(data.larray, output.larray, recv_axis=1) req.Wait() # check scatter result @@ -1062,7 +1072,7 @@ def test_iallgatherv(self): # perform the scatter operation counts = tuple(range(1, ht.MPI_WORLD.size + 1)) displs = tuple(np.cumsum(range(ht.MPI_WORLD.size))) - req = data.comm.Iallgatherv(data, (output, counts, displs)) + req = data.comm.Iallgatherv(data.larray, (output.larray, counts, displs)) req.Wait() # check scatter result @@ -1086,7 +1096,7 @@ def test_iallgatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iallgatherv(data, (output, counts, displs)) + req = data.comm.Iallgatherv(data.larray, (output.larray, counts, displs)) req.Wait() # check scatter result @@ -1110,7 +1120,7 @@ def test_iallgatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 
2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iallgatherv(data, (output, counts, displs)) + req = data.comm.Iallgatherv(data.larray, (output.larray, counts, displs)) req.Wait() # check scatter result @@ -1134,7 +1144,7 @@ def test_iallgatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iallgatherv(data, (output, counts, displs)) + req = data.comm.Iallgatherv(data.larray, (output.larray, counts, displs)) req.Wait() # check scatter result @@ -1159,7 +1169,7 @@ def test_iallreduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iallreduce(data, out, op=ht.MPI.SUM) + req = data.comm.Iallreduce(data.larray, out.larray, op=ht.MPI.SUM) req.Wait() # check the reduction result @@ -1174,7 +1184,7 @@ def test_iallreduce(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iallreduce(data, out, op=ht.MPI.SUM) + req = data.comm.Iallreduce(data.larray, out.larray, op=ht.MPI.SUM) req.Wait() # check the reduction result @@ -1193,7 +1203,7 @@ def test_iallreduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - req = data.comm.Iallreduce(data, out, op=ht.MPI.SUM) + req = data.comm.Iallreduce(data.larray, out.larray, op=ht.MPI.SUM) req.Wait() # check the reduction result @@ -1218,7 +1228,7 @@ def test_ialltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Ialltoall(data, output) + req = data.comm.Ialltoall(data.larray, output.larray) req.Wait() # check scatter result @@ -1238,7 +1248,7 @@ def test_ialltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Ialltoall(data, output, send_axis=1) + req = data.comm.Ialltoall(data.larray, output.larray, send_axis=1) req.Wait() # check scatter result @@ -1258,7 +1268,7 @@ def test_ialltoall(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Ialltoall(data, output) + req = data.comm.Ialltoall(data.larray, output.larray) req.Wait() # check scatter result @@ -1276,7 +1286,7 @@ def test_ialltoall(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - req = data.comm.Ialltoall(data, output, send_axis=1) + req = data.comm.Ialltoall(data.larray, output.larray, send_axis=1) req.Wait() # check scatter result @@ -1310,7 +1320,7 @@ def test_ialltoallv(self): self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) req = data.comm.Ialltoallv( - (data, send_counts, send_displs), (output, recv_counts, recv_displs) + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) ) req.Wait() @@ -1342,7 +1352,7 @@ def test_ialltoallv(self): self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) req = data.comm.Ialltoallv( - (data, send_counts, send_displs), (output, recv_counts, recv_displs) + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) ) req.Wait() @@ -1375,7 +1385,7 @@ def test_ialltoallv(self): self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) req = 
data.comm.Ialltoallv( - (data, send_counts, send_displs), (output, recv_counts, recv_displs) + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) ) req.Wait() @@ -1408,7 +1418,7 @@ def test_ialltoallv(self): self.assertEqual(data.shape[0] % ht.MPI_WORLD.size, 0) req = data.comm.Ialltoallv( - (data, send_counts, send_displs), (output, recv_counts, recv_displs) + (data.larray, send_counts, send_displs), (output.larray, recv_counts, recv_displs) ) req.Wait() @@ -1436,7 +1446,7 @@ def test_ibcast(self): # broadcast data to all nodes self.assertTrue(data.larray.is_contiguous()) - req = data.comm.Ibcast(data, root=0) + req = data.comm.Ibcast(data.larray, root=0) req.Wait() # assert output is equal @@ -1452,7 +1462,7 @@ def test_ibcast(self): # broadcast data to all nodes self.assertFalse(data.larray.is_contiguous()) - req = data.comm.Ibcast(data, root=0) + req = data.comm.Ibcast(data.larray, root=0) req.Wait() # assert output is equal @@ -1477,7 +1487,7 @@ def test_iexscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iexscan(data, out) + req = data.comm.Iexscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1492,7 +1502,7 @@ def test_iexscan(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iexscan(data, out) + req = data.comm.Iexscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1511,7 +1521,7 @@ def test_iexscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - req = data.comm.Iexscan(data, out) + req = data.comm.Iexscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1536,7 +1546,7 @@ def test_igather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Igather(data, output, root=0) + req = data.comm.Igather(data.larray, output.larray, root=0) req.Wait() # check scatter result @@ -1561,7 +1571,7 @@ def test_igather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Igather(data, output, root=0, axis=1) + req = data.comm.Igather(data.larray, output.larray, root=0, axis=1) req.Wait() # check scatter result @@ -1586,7 +1596,7 @@ def test_igather(self): # ensure prior invariants self.assertFalse(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Igather(data, output, root=0) + req = data.comm.Igather(data.larray, output.larray, root=0) req.Wait() # check scatter result @@ -1611,7 +1621,7 @@ def test_igather(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - req = data.comm.Igather(data, output, root=0, axis=1) + req = data.comm.Igather(data.larray, output.larray, root=0, axis=1) req.Wait() # check scatter result @@ -1647,7 +1657,7 @@ def test_igatherv(self): # perform the scatter operation counts = tuple(range(1, ht.MPI_WORLD.size + 1)) displs = tuple(np.cumsum(range(ht.MPI_WORLD.size))) - req = data.comm.Igatherv(data, (output, counts, displs), root=0) + req = data.comm.Igatherv(data.larray, (output.larray, counts, displs), root=0) req.Wait() # check scatter result @@ -1673,7 +1683,7 @@ def test_igatherv(self): # perform the scatter operation 
counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Igatherv(data, (output, counts, displs), root=0) + req = data.comm.Igatherv(data.larray, (output.larray, counts, displs), root=0) req.Wait() # check scatter result @@ -1699,7 +1709,7 @@ def test_igatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Igatherv(data, (output, counts, displs), root=0) + req = data.comm.Igatherv(data.larray, (output.larray, counts, displs), root=0) req.Wait() # check scatter result @@ -1725,7 +1735,7 @@ def test_igatherv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Igatherv(data, (output, counts, displs), root=0) + req = data.comm.Igatherv(data.larray, (output.larray, counts, displs), root=0) req.Wait() # check scatter result @@ -1752,7 +1762,7 @@ def test_ireduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Ireduce(data, out, op=ht.MPI.SUM, root=0) + req = data.comm.Ireduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) req.Wait() # check the reduction result @@ -1768,7 +1778,7 @@ def test_ireduce(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Ireduce(data, out, op=ht.MPI.SUM, root=0) + req = data.comm.Ireduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) req.Wait() # check the reduction result @@ -1788,7 +1798,7 @@ def test_ireduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - req = data.comm.Ireduce(data, out, op=ht.MPI.SUM, root=0) + req = data.comm.Ireduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) req.Wait() # check the reduction result @@ -1814,7 +1824,7 @@ def test_iscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iscan(data, out) + req = data.comm.Iscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1829,7 +1839,7 @@ def test_iscan(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - req = data.comm.Iscan(data, out) + req = data.comm.Iscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1848,7 +1858,7 @@ def test_iscan(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - req = data.comm.Iscan(data, out) + req = data.comm.Iscan(data.larray, out.larray) req.Wait() # check the reduction result @@ -1876,7 +1886,7 @@ def test_iscatter(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iscatter(data, output, root=0) + req = data.comm.Iscatter(data.larray, output.larray, root=0) req.Wait() # check scatter result @@ -1896,7 +1906,7 @@ def test_iscatter(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iscatter(data, output, root=0, axis=1) + req = data.comm.Iscatter(data.larray, output.larray, root=0, axis=1) req.Wait() # check scatter 
result @@ -1917,7 +1927,7 @@ def test_iscatter(self): # ensure prior invariants self.assertTrue(output.larray.is_contiguous()) - req = data.comm.Iscatter(data, output, root=0) + req = data.comm.Iscatter(data.larray, output.larray, root=0) req.Wait() # check scatter result @@ -1940,7 +1950,7 @@ def test_iscatter(self): # ensure prior invariants self.assertTrue(data.larray.is_contiguous()) self.assertFalse(output.larray.is_contiguous()) - req = data.comm.Iscatter(data, output, root=0, axis=1) + req = data.comm.Iscatter(data.larray, output.larray, root=0, axis=1) req.Wait() # check scatter result @@ -1969,7 +1979,7 @@ def test_iscatterv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iscatterv((data, counts, displs), output, root=0) + req = data.comm.Iscatterv((data.larray, counts, displs), output.larray, root=0) req.Wait() # check scatter result @@ -1994,7 +2004,7 @@ def test_iscatterv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iscatterv((data, counts, displs), output, root=0) + req = data.comm.Iscatterv((data.larray, counts, displs), output.larray, root=0) req.Wait() # check scatter result @@ -2019,7 +2029,7 @@ def test_iscatterv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iscatterv((data, counts, displs), output, root=0) + req = data.comm.Iscatterv((data.larray, counts, displs), output.larray, root=0) req.Wait() # check scatter result @@ -2044,7 +2054,7 @@ def test_iscatterv(self): # perform the scatter operation counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2)) displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2))) - req = data.comm.Iscatterv((data, counts, displs), output, root=0) + req = data.comm.Iscatterv((data.larray, counts, displs), output.larray, root=0) req.Wait() # check scatter result @@ -2063,7 +2073,7 @@ def test_iscatterv(self): def test_mpi_in_place(self): size = ht.MPI_WORLD.size data = ht.ones((size, size), dtype=ht.int32) - data.comm.Allreduce(ht.MPI.IN_PLACE, data, op=ht.MPI.SUM) + data.comm.Allreduce(ht.MPI.IN_PLACE, data.larray, op=ht.MPI.SUM) self.assertTrue((data.larray == size).all()) # MPI Inplace is not allowed for AllToAll @@ -2076,7 +2086,7 @@ def test_reduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0) + data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) # check the reduction result self.assertTrue(data.larray.is_contiguous()) @@ -2091,7 +2101,7 @@ def test_reduce(self): # reduce across all nodes self.assertFalse(data.larray.is_contiguous()) self.assertTrue(out.larray.is_contiguous()) - data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0) + data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) # check the reduction result # the data tensor will be contiguous after the reduction @@ -2110,7 +2120,7 @@ def test_reduce(self): # reduce across all nodes self.assertTrue(data.larray.is_contiguous()) self.assertFalse(out.larray.is_contiguous()) - data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0) + data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0) # check the reduction result # the data tensor will be 
@@ -2063,7 +2073,7 @@ def test_iscatterv(self):
     def test_mpi_in_place(self):
         size = ht.MPI_WORLD.size
         data = ht.ones((size, size), dtype=ht.int32)
-        data.comm.Allreduce(ht.MPI.IN_PLACE, data, op=ht.MPI.SUM)
+        data.comm.Allreduce(ht.MPI.IN_PLACE, data.larray, op=ht.MPI.SUM)
         self.assertTrue((data.larray == size).all())

         # MPI Inplace is not allowed for AllToAll
@@ -2076,7 +2086,7 @@ def test_reduce(self):
         # reduce across all nodes
         self.assertTrue(data.larray.is_contiguous())
         self.assertTrue(out.larray.is_contiguous())
-        data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0)
+        data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0)

         # check the reduction result
         self.assertTrue(data.larray.is_contiguous())
@@ -2091,7 +2101,7 @@ def test_reduce(self):
         # reduce across all nodes
         self.assertFalse(data.larray.is_contiguous())
         self.assertTrue(out.larray.is_contiguous())
-        data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0)
+        data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0)

         # check the reduction result
         # the data tensor will be contiguous after the reduction
@@ -2110,7 +2120,7 @@ def test_reduce(self):
         # reduce across all nodes
         self.assertTrue(data.larray.is_contiguous())
         self.assertFalse(out.larray.is_contiguous())
-        data.comm.Reduce(data, out, op=ht.MPI.SUM, root=0)
+        data.comm.Reduce(data.larray, out.larray, op=ht.MPI.SUM, root=0)

         # check the reduction result
         # the data tensor will be contiguous after the reduction
@@ -2130,7 +2140,7 @@ def test_scan(self):
         # reduce across all nodes
         self.assertTrue(data.larray.is_contiguous())
         self.assertTrue(out.larray.is_contiguous())
-        data.comm.Scan(data, out)
+        data.comm.Scan(data.larray, out.larray)

         # check the reduction result
         self.assertTrue(data.larray.is_contiguous())
@@ -2144,7 +2154,7 @@ def test_scan(self):
         # reduce across all nodes
         self.assertFalse(data.larray.is_contiguous())
         self.assertTrue(out.larray.is_contiguous())
-        data.comm.Scan(data, out)
+        data.comm.Scan(data.larray, out.larray)

         # check the reduction result
         # the data tensor will be contiguous after the reduction
@@ -2162,7 +2172,7 @@ def test_scan(self):
         # reduce across all nodes
         self.assertTrue(data.larray.is_contiguous())
         self.assertFalse(out.larray.is_contiguous())
-        data.comm.Scan(data, out)
+        data.comm.Scan(data.larray, out.larray)

         # check the reduction result
         # the data tensor will be contiguous after the reduction
@@ -2184,7 +2194,7 @@ def test_scatter(self):
         # ensure prior invariants
         self.assertTrue(data.larray.is_contiguous())
         self.assertTrue(output.larray.is_contiguous())
-        data.comm.Scatter(data, output, root=0)
+        data.comm.Scatter(data.larray, output.larray, root=0)

         # check scatter result
         self.assertTrue(data.larray.is_contiguous())
@@ -2201,7 +2211,7 @@ def test_scatter(self):
         # ensure prior invariants
         self.assertTrue(data.larray.is_contiguous())
         self.assertTrue(output.larray.is_contiguous())
-        data.comm.Scatter(data, output, root=0, axis=1)
+        data.comm.Scatter(data.larray, output.larray, root=0, axis=1)

         # check scatter result
         self.assertTrue(data.larray.is_contiguous())
@@ -2219,7 +2229,7 @@ def test_scatter(self):

         # ensure prior invariants
         self.assertTrue(output.larray.is_contiguous())
-        data.comm.Scatter(data, output, root=0)
+        data.comm.Scatter(data.larray, output.larray, root=0)

         # check scatter result
         if ht.MPI_WORLD.rank == 0:
@@ -2239,7 +2249,7 @@ def test_scatter(self):
         # ensure prior invariants
         self.assertTrue(data.larray.is_contiguous())
         self.assertFalse(output.larray.is_contiguous())
-        data.comm.Scatter(data, output, root=0, axis=1)
+        data.comm.Scatter(data.larray, output.larray, root=0, axis=1)

         # check scatter result
         self.assertTrue(data.larray.is_contiguous())
@@ -2265,7 +2275,7 @@ def test_scatter_like_axes(self):
         output = ht.zeros_like(data)

         # main axis send buffer, main axis receive buffer
-        data.comm.Alltoall(data, output, send_axis=0)
+        data.comm.Alltoall(data.larray, output.larray, send_axis=0)
         comparison = (
             torch.arange(ht.MPI_WORLD.size, device=self.device.torch_device)
             .reshape(-1, 1)
@@ -2274,7 +2284,7 @@
         self.assertTrue((output.larray == comparison).all())

         # minor axis send buffer, main axis receive buffer
-        data.comm.Alltoall(data, output, send_axis=1)
+        data.comm.Alltoall(data.larray, output.larray, send_axis=1)
         comparison = (
             torch.arange(ht.MPI_WORLD.size, device=self.device.torch_device)
             .reshape(1, -1)
@@ -2285,7 +2295,7 @@
         # main axis send buffer, minor axis receive buffer
         data = ht.array([[ht.MPI_WORLD.rank] * (2 * ht.MPI_WORLD.size)] * ht.MPI_WORLD.size)
         output = ht.zeros((2 * ht.MPI_WORLD.size, ht.MPI_WORLD.size), dtype=data.dtype)
-        data.comm.Alltoall(data, output, send_axis=0, recv_axis=1)
+        data.comm.Alltoall(data.larray, output.larray, send_axis=0, recv_axis=1)
         comparison = (
             torch.arange(ht.MPI_WORLD.size, device=self.device.torch_device)
             .reshape(1, -1)
@@ -2296,7 +2306,7 @@
         # minor axis send buffer, minor axis receive buffer
         data = ht.array([range(ht.MPI_WORLD.size)] * ht.MPI_WORLD.size)
         output = ht.zeros((ht.MPI_WORLD.size, ht.MPI_WORLD.size), dtype=data.dtype)
-        data.comm.Alltoall(data, output, send_axis=0, recv_axis=1)
+        data.comm.Alltoall(data.larray, output.larray, send_axis=0, recv_axis=1)
         comparison = (
             torch.arange(ht.MPI_WORLD.size, device=self.device.torch_device)
             .reshape(-1, 1)
@@ -2319,7 +2329,7 @@ def test_scatterv(self):
         # perform the scatter operation
         counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2))
         displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2)))
-        data.comm.Scatterv((data, counts, displs), output, root=0)
+        data.comm.Scatterv((data.larray, counts, displs), output.larray, root=0)

         # check scatter result
         self.assertTrue(data.larray.is_contiguous())
@@ -2341,7 +2351,7 @@ def test_scatterv(self):
         # perform the scatter operation
         counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2))
         displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2)))
-        data.comm.Scatterv((data, counts, displs), output, root=0)
+        data.comm.Scatterv((data.larray, counts, displs), output.larray, root=0)

         # check scatter result
         self.assertFalse(data.larray.is_contiguous())
@@ -2363,7 +2373,7 @@ def test_scatterv(self):
         # perform the scatter operation
         counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2))
         displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2)))
-        data.comm.Scatterv((data, counts, displs), output, root=0)
+        data.comm.Scatterv((data.larray, counts, displs), output.larray, root=0)

         # check scatter result
         self.assertTrue(data.larray.is_contiguous())
@@ -2385,7 +2395,7 @@ def test_scatterv(self):
         # perform the scatter operation
         counts = tuple(range(2, 2 * (ht.MPI_WORLD.size + 1), 2))
         displs = tuple(np.cumsum(range(0, 2 * ht.MPI_WORLD.size, 2)))
-        data.comm.Scatterv((data, counts, displs), output, root=0)
+        data.comm.Scatterv((data.larray, counts, displs), output.larray, root=0)

         # check scatter result
         self.assertFalse(data.larray.is_contiguous())
@@ -2418,17 +2428,17 @@ def test_allgathervSorting(self):
         gathered3 = torch.empty(self.sorted3Dtensor.shape, device=self.device.torch_device)

         test1.comm.Allgatherv(
-            test1, (gathered1, gathered1_counts, gathered1_displs), recv_axis=test1.split
+            test1.larray, (gathered1, gathered1_counts, gathered1_displs), recv_axis=test1.split
         )
         self.assertTrue(torch.equal(gathered1, result.larray))

         test2.comm.Allgatherv(
-            test2, (gathered2, gathered2_counts, gathered2_displs), recv_axis=test2.split
+            test2.larray, (gathered2, gathered2_counts, gathered2_displs), recv_axis=test2.split
         )
         self.assertTrue(torch.equal(gathered2, result.larray))

         test3.comm.Allgatherv(
-            test3, (gathered3, gathered3_counts, gathered3_displs), recv_axis=test3.split
+            test3.larray, (gathered3, gathered3_counts, gathered3_displs), recv_axis=test3.split
         )
         self.assertTrue(torch.equal(gathered3, result.larray))

diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 1ba1c45608..b295d627b4 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -1611,7 +1611,7 @@ def test_stride_and_strides(self):
         self.assertEqual(heat_float64_F.strides, numpy_float64_F.strides)

         # Distributed, int16, row-major memory layout
-        size = ht.communication.MPI_WORLD.size
+        size = ht.communication_backends.MPI_WORLD.size
         split = 2
         torch_int16 = torch.arange(
             6 * 5 * 3 * size * 4 * 5 * 7, dtype=torch.int16, device=self.device.torch_device
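The blocking collectives (Reduce, Scan, Scatter, Scatterv, Alltoall, Allgatherv) are updated the same way, including the in-place variant: MPI.IN_PLACE is still passed as the send buffer, but the receive buffer is now the local torch tensor. A minimal sketch of the in-place all-reduce, following the test_mpi_in_place hunk above and assuming a heat build with this patch applied:

import heat as ht

size = ht.MPI_WORLD.size
data = ht.ones((size, size), dtype=ht.int32)

# reduce into the local buffer itself; every entry ends up equal to `size`
data.comm.Allreduce(ht.MPI.IN_PLACE, data.larray, op=ht.MPI.SUM)
assert (data.larray == size).all()
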
diff --git a/heat/core/tests/test_factories.py b/heat/core/tests/test_factories.py
index b304d6b231..4ac55bbfe0 100644
--- a/heat/core/tests/test_factories.py
+++ b/heat/core/tests/test_factories.py
@@ -201,7 +201,7 @@ def test_array(self):
         )

         # distributed array, partial data (is_split)
-        if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.rank == 0:
             split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
         else:
             split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
@@ -209,7 +209,7 @@ def test_array(self):

         self.assertIsInstance(e, ht.DNDarray)
         self.assertEqual(e.dtype, ht.float32)
-        if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.rank == 0:
             self.assertEqual(e.lshape, (3, 3, 1))
         else:
             self.assertEqual(e.lshape, (2, 3, 1))
@@ -221,8 +221,8 @@ def test_array(self):
             self.assertGreaterEqual(ele, e.lshape[index])

         # exception distributed shapes do not fit
-        if ht.communication.MPI_WORLD.size > 1:
-            if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.size > 1:
+            if ht.communication_backends.MPI_WORLD.rank == 0:
                 split_data = [4.0, 5.0, 6.0]
             else:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
@@ -232,8 +232,8 @@ def test_array(self):
                 ht.array(split_data, is_split=0)

         # exception distributed shapes do not fit
-        if ht.communication.MPI_WORLD.size > 1:
-            if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.size > 1:
+            if ht.communication_backends.MPI_WORLD.rank == 0:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
             else:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
@@ -250,7 +250,7 @@ def test_array(self):

         self.assertIsInstance(e, ht.DNDarray)
         self.assertEqual(e.dtype, ht.float32)
-        if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.rank == 0:
             self.assertEqual(e.lshape, (1, 3, 3))
         else:
             self.assertEqual(e.lshape, (1, 2, 3))
@@ -262,8 +262,8 @@ def test_array(self):
             self.assertGreaterEqual(ele, e.lshape[index])

         # exception distributed shapes do not fit
-        if ht.communication.MPI_WORLD.size > 1:
-            if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.size > 1:
+            if ht.communication_backends.MPI_WORLD.rank == 0:
                 split_data = [4.0, 5.0, 6.0]
             else:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
@@ -273,8 +273,8 @@ def test_array(self):
                 ht.array(split_data, is_split=0)

         # exception distributed shapes do not fit
-        if ht.communication.MPI_WORLD.size > 1:
-            if ht.communication.MPI_WORLD.rank == 0:
+        if ht.communication_backends.MPI_WORLD.size > 1:
+            if ht.communication_backends.MPI_WORLD.rank == 0:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
             else:
                 split_data = [[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]
diff --git a/heat/core/tests/test_logical.py b/heat/core/tests/test_logical.py
index 3e46fd144e..7e138d8e9a 100644
--- a/heat/core/tests/test_logical.py
+++ b/heat/core/tests/test_logical.py
@@ -267,7 +267,7 @@ def test_any(self):
         self.assertEqual(keepdims_any.split, None)

     def test_isclose(self):
-        size = ht.communication.MPI_WORLD.size
+        size = ht.communication_backends.MPI_WORLD.size
         a = ht.float32([[2, 2], [2, 2]])
         b = ht.float32([[2.00005, 2.00005], [2.00005, 2.00005]])
         c = ht.zeros((4 * size, 6), split=0)
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 9825d333e9..69f1a9e6ce 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -473,8 +473,8 @@ def test_concatenate(self):
         with self.assertRaises(ValueError):
             ht.concatenate((x, ht.zeros((2, 2))), axis=0)
         with self.assertRaises(RuntimeError):
-            a = ht.zeros((10,), comm=ht.communication.MPI_WORLD)
-            b = ht.zeros((10,), comm=ht.communication.MPI_SELF)
+            a = ht.zeros((10,), comm=ht.communication_backends.MPI_WORLD)
+            b = ht.zeros((10,), comm=ht.communication_backends.MPI_SELF)
             ht.concatenate([a, b])
         with self.assertRaises(ValueError):
             ht.concatenate((ht.zeros((12, 12)), ht.zeros((2, 2))), axis=0)
diff --git a/heat/core/tests/test_memory.py b/heat/core/tests/test_memory.py
index bdff40ac4b..c7abeea0ad 100644
--- a/heat/core/tests/test_memory.py
+++ b/heat/core/tests/test_memory.py
@@ -37,7 +37,7 @@ def test_sanitize_memory_layout(self):
         a_torch_5d_sum = a_torch_5d.sum(-2)
         self.assert_array_equal(a_heat_5d_F_sum, a_torch_5d_sum)
         # distributed, split, 2D
-        size = ht.communication.MPI_WORLD.size
+        size = ht.communication_backends.MPI_WORLD.size
         a_torch_2d = torch.arange(4 * size * 3 * size, device=self.device.torch_device).reshape(
             4 * size, 3 * size
         )
diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py
index 98e3c459fa..c3e1a0a213 100644
--- a/heat/core/tests/test_rounding.py
+++ b/heat/core/tests/test_rounding.py
@@ -169,7 +169,7 @@ def test_floor(self):
             ht.floor(object())

     def test_modf(self):
-        size = ht.communication.MPI_WORLD.size
+        size = ht.communication_backends.MPI_WORLD.size
         start, end = -5.0, 5.0
         step = (end - start) / (2 * size)
         npArray = np.arange(start, end, step, dtype=np.float32)
@@ -248,7 +248,7 @@ def test_modf(self):
         self.assert_array_equal(float64_modf_distrbd[1], comparison[1])

     def test_round(self):
-        size = ht.communication.MPI_WORLD.size
+        size = ht.communication_backends.MPI_WORLD.size
         start, end = -5.7, 5.1
         step = (end - start) / (2 * size)
         comparison = torch.arange(start, end, step, dtype=torch.float32).round()
diff --git a/heat/core/types.py b/heat/core/types.py
index 7f6159adc9..8a0739c988 100644
--- a/heat/core/types.py
+++ b/heat/core/types.py
@@ -9,7 +9,7 @@
 import numpy as np
 import torch

-from . import communication
+from ..communication_backends import communication
 from . import devices
 from . import factories
 from . import _operations
diff --git a/heat/graph/tests/test_laplacian.py b/heat/graph/tests/test_laplacian.py
index 1c21764861..8ed153df5f 100644
--- a/heat/graph/tests/test_laplacian.py
+++ b/heat/graph/tests/test_laplacian.py
@@ -8,8 +8,8 @@

 class TestLaplacian(TestCase):
     def test_laplacian(self):
-        size = ht.communication.MPI_WORLD.size
-        rank = ht.communication.MPI_WORLD.rank
+        size = ht.communication_backends.MPI_WORLD.size
+        rank = ht.communication_backends.MPI_WORLD.rank
         X = ht.ones((size * 2, 4), split=0)
         X.larray[0, :] *= rank
         X.larray[1, :] *= rank + 0.5
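The test-suite hunks above are mechanical renames: the communicator singletons now live in ht.communication_backends rather than ht.communication. A before/after sketch of the new spelling, assuming this patch is applied:

import heat as ht

size = ht.communication_backends.MPI_WORLD.size  # previously ht.communication.MPI_WORLD.size
rank = ht.communication_backends.MPI_WORLD.rank  # previously ht.communication.MPI_WORLD.rank

# explicit communicators are still passed to factories the same way
a = ht.zeros((10,), comm=ht.communication_backends.MPI_WORLD)
b = ht.zeros((10,), comm=ht.communication_backends.MPI_SELF)
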
diff --git a/heat/nn/data_parallel.py b/heat/nn/data_parallel.py
index 9eeda8b0f7..3428ff07c1 100644
--- a/heat/nn/data_parallel.py
+++ b/heat/nn/data_parallel.py
@@ -10,9 +10,9 @@
 from typing import Any, Callable, Dict, List, Union, Tuple

 from .. import optim
-from ..core.communication import MPI
-from ..core.communication import MPI_WORLD
-from ..core.communication import MPICommunication
+from ..communication_backends.communication import MPI
+from ..communication_backends.communication import MPI_WORLD
+from ..communication_backends.communication import MPICommunication

 __all__ = ["DataParallel", "DataParallelMultiGPU"]
diff --git a/heat/optim/dp_optimizer.py b/heat/optim/dp_optimizer.py
index 5e45545349..cfe45f9cc6 100644
--- a/heat/optim/dp_optimizer.py
+++ b/heat/optim/dp_optimizer.py
@@ -9,9 +9,9 @@
 from torch.nn.parallel import DistributedDataParallel as tDDP
 from typing import Union, List, Tuple, Dict

-from ..core.communication import MPICommunication
-from ..core.communication import MPI
-from ..core.communication import MPI_WORLD
+from ..communication_backends.communication import MPICommunication
+from ..communication_backends.communication import MPI
+from ..communication_backends.communication import MPI_WORLD
 from .utils import DetectMetricPlateau
diff --git a/heat/sparse/_operations.py b/heat/sparse/_operations.py
index 1d38114955..062ef719ea 100644
--- a/heat/sparse/_operations.py
+++ b/heat/sparse/_operations.py
@@ -5,7 +5,7 @@
 from heat.sparse.dcsr_matrix import DCSR_matrix

 from . import factories
-from ..core.communication import MPI
+from ..communication_backends.communication import MPI
 from ..core.dndarray import DNDarray
 from ..core import types
diff --git a/heat/sparse/factories.py b/heat/sparse/factories.py
index d1e545e957..37b5ea1a65 100644
--- a/heat/sparse/factories.py
+++ b/heat/sparse/factories.py
@@ -9,7 +9,7 @@
 from ..core import devices
 from ..core import types
-from ..core.communication import MPI, sanitize_comm, Communication
+from ..communication_backends.communication import MPI, sanitize_comm, Communication
 from ..core.devices import Device
 from ..core.types import datatype
diff --git a/heat/sparse/tests/test_arithmetics.py b/heat/sparse/tests/test_arithmetics.py
index 20517541b6..fd49d0be07 100644
--- a/heat/sparse/tests/test_arithmetics.py
+++ b/heat/sparse/tests/test_arithmetics.py
@@ -63,12 +63,12 @@ def setUpClass(self):
             device=self.device.torch_device,
         )

-        self.world_size = ht.communication.MPI_WORLD.size
-        self.rank = ht.communication.MPI_WORLD.rank
+        self.world_size = ht.communication_backends.MPI_WORLD.size
+        self.rank = ht.communication_backends.MPI_WORLD.rank

         self.scalar = np.array(random.randint(1, 100))
         if self.world_size > 0:
-            ht.communication.MPI_WORLD.Bcast(self.scalar, root=0)
+            ht.communication_backends.MPI_WORLD.Bcast(self.scalar, root=0)
         self.scalar = self.scalar.item()

     def test_add(self):
diff --git a/heat/sparse/tests/test_dcsrmatrix.py b/heat/sparse/tests/test_dcsrmatrix.py
index 6cf86ebf87..6a92176ba3 100644
--- a/heat/sparse/tests/test_dcsrmatrix.py
+++ b/heat/sparse/tests/test_dcsrmatrix.py
@@ -35,8 +35,8 @@ def setUpClass(self):
             self.ref_indptr, self.ref_indices, self.ref_data, device=self.device.torch_device
         )

-        self.world_size = ht.communication.MPI_WORLD.size
-        self.rank = ht.communication.MPI_WORLD.rank
+        self.world_size = ht.communication_backends.MPI_WORLD.size
+        self.rank = ht.communication_backends.MPI_WORLD.rank

     def test_larray(self):
         heat_sparse_csr = ht.sparse.sparse_csr_matrix(self.ref_torch_sparse_csr)
diff --git a/heat/sparse/tests/test_factories.py b/heat/sparse/tests/test_factories.py
index 5728534605..32ebbb58ed 100644
--- a/heat/sparse/tests/test_factories.py
+++ b/heat/sparse/tests/test_factories.py
@@ -51,8 +51,8 @@ def setUpClass(self):
             )
         )

-        self.world_size = ht.communication.MPI_WORLD.size
-        self.rank = ht.communication.MPI_WORLD.rank
+        self.world_size = ht.communication_backends.MPI_WORLD.size
+        self.rank = ht.communication_backends.MPI_WORLD.rank

     def test_sparse_csr_matrix(self):
         """
diff --git a/heat/spatial/tests/test_distances.py b/heat/spatial/tests/test_distances.py
index d8ce1a44ca..7c093e529a 100644
--- a/heat/spatial/tests/test_distances.py
+++ b/heat/spatial/tests/test_distances.py
@@ -12,7 +12,7 @@

 class TestDistances(TestCase):
     def test_cdist(self):
-        n = ht.communication.MPI_WORLD.size
+        n = ht.communication_backends.MPI_WORLD.size
         X = ht.ones((n * 2, 4), dtype=ht.float32, split=None)
         Y = ht.zeros((n * 2, 4), dtype=ht.float32, split=None)
         res_XX_cdist = ht.zeros((n * 2, n * 2), dtype=ht.float32, split=None)
@@ -204,7 +204,7 @@ def test_cdist(self):
         with self.assertRaises(NotImplementedError):
             ht.spatial.cdist(X, Z, quadratic_expansion=False)

-        n = ht.communication.MPI_WORLD.size
+        n = ht.communication_backends.MPI_WORLD.size
         A = ht.ones((n * 2, 6), dtype=ht.float32, split=None)
         for i in range(n):
             A[2 * i, :] = A[2 * i, :] * (2 * i)
@@ -221,7 +221,7 @@ def test_cdist(self):
         result = ht.array(res, dtype=ht.float32, split=0)
         self.assertTrue(ht.allclose(d, result, atol=1e-5))

-        n = ht.communication.MPI_WORLD.size
+        n = ht.communication_backends.MPI_WORLD.size
         A = ht.ones((n * 2, 6), dtype=ht.float32, split=None)
         for i in range(n):
             A[2 * i, :] = A[2 * i, :] * (2 * i)
diff --git a/heat/utils/data/datatools.py b/heat/utils/data/datatools.py
index 91b9a98f81..9c502cf517 100644
--- a/heat/utils/data/datatools.py
+++ b/heat/utils/data/datatools.py
@@ -7,7 +7,7 @@
 from typing import Callable, List, Iterator, Union, Optional, Sized

 from ...core.dndarray import DNDarray
-from ...core.communication import MPI_WORLD
+from ...communication_backends.communication import MPI_WORLD
 from . import partial_dataset

 __all__ = ["DataLoader", "Dataset", "dataset_shuffle", "dataset_ishuffle"]
diff --git a/heat/utils/data/matrixgallery.py b/heat/utils/data/matrixgallery.py
index 5937da869a..31ab135210 100644
--- a/heat/utils/data/matrixgallery.py
+++ b/heat/utils/data/matrixgallery.py
@@ -4,7 +4,7 @@
 from heat import core

 from ...core.dndarray import DNDarray
-from ...core.communication import Communication
+from ...communication_backends.communication import Communication
 from ...core.devices import Device
 from ...core.types import datatype, heat_type_is_complexfloating, heat_type_is_exact
 from ...core.random import randn, rand
diff --git a/heat/utils/data/partial_dataset.py b/heat/utils/data/partial_dataset.py
index 5b48d72efa..ac87263abe 100644
--- a/heat/utils/data/partial_dataset.py
+++ b/heat/utils/data/partial_dataset.py
@@ -11,8 +11,8 @@
 from torch.utils import data as torch_data
 from typing import Callable, List, Iterator, Union

-from ...core.communication import MPICommunication
-from ...core.communication import MPI_WORLD
+from ...communication_backends.communication import MPICommunication
+from ...communication_backends.communication import MPI_WORLD

 __all__ = ["PartialH5Dataset", "PartialH5DataLoaderIter"]
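Inside the package, the hunks above apply the same relocation to relative imports: every `from ..core.communication import ...` becomes `from ..communication_backends.communication import ...`. For code outside the package, the equivalent absolute form would be the following sketch, assuming this patch is applied; the listed names are the ones the hunks above import from the relocated module:

from heat.communication_backends.communication import (
    MPI,
    MPI_WORLD,
    MPICommunication,
    Communication,
    sanitize_comm,
)
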