diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index 5117891e4..5e723828b 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -7,14 +7,15 @@
 
 # pyre-strict
 
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from typing import Any, List, Optional, Tuple, TypeVar
+from typing import Any, List, Optional, Tuple, TypeVar, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives
-
 from torch import Tensor
 from torch.autograd import Function
 from torch.autograd.profiler import record_function
@@ -334,6 +335,93 @@ def _get_split_lengths_by_len(
     return (my_len, splits)
 
 
+class Comm(ABC):
+    """
+    Interface for communication primitives.
+    A primitive primarily needs to handle 3 tasks, namely:
+
+    1. How to allocate memory for communication
+       Depending on the goal, an implementation can choose to:
+       a. associate each call with a temporary buffer
+          (best for flexibility and simplicity)
+       b. reuse a persistent buffer for efficiency reasons
+
+    2. Where to allocate memory
+       (e.g. NCCL mem pool or regular cuda caching allocator)
+
+    3. What to do/call when the comm is invoked
+       (see `AllGather` interface as an example)
+    """
+
+    @abstractmethod
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        This handles the "how to allocate memory" part.
+
+        A default implementation could be simply:
+
+        .. code-block:: python
+            with self.mem_pool:
+                return torch.empty(...)
+
+        Args:
+            size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
+            dtype (torch.dtype): dtype of the tensor buffer
+            device (torch.device): which device to allocate the tensor onto
+        """
+        ...
+
+
+class All2AllSingle(Comm):
+    @abstractmethod
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]: ...
+
+
+class DefaultAllocMixin:
+    def allocate(
+        self,
+        size: Sequence[Union[int, torch.SymInt]],
+        *,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        return torch.empty(*size, dtype=dtype, device=device)
+
+
+class DefaultAll2AllSingle(DefaultAllocMixin, All2AllSingle):
+    def __call__(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        output_split_sizes: Optional[list[int]] = None,
+        input_split_sizes: Optional[list[int]] = None,
+        group: Optional[dist.ProcessGroup] = None,
+        async_op: bool = False,
+    ) -> Optional[dist.Work]:
+        return dist.all_to_all_single(
+            output_tensor,
+            input_tensor,
+            output_split_sizes,
+            input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+
+
 def alltoall_pooled(
     a2a_pooled_embs_tensor: Tensor,
     batch_size_per_rank: List[int],
@@ -342,6 +430,7 @@ def alltoall_pooled(
     cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None,
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
     """
     Performs AlltoAll operation for a single pooled embedding tensor. Each process
@@ -391,7 +480,15 @@ def alltoall_pooled(
         return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    All2All_Pooled_Req.apply(
+        group,
+        myreq,
+        a2ai,
+        a2a_pooled_embs_tensor,
+        all_to_all_single_comm,
+    )
     return myreq
 
 
@@ -476,6 +573,7 @@ def variable_batch_alltoall_pooled(
     emb_dim_per_rank_per_feature: List[List[int]],
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
+    all_to_all_single_comm: Optional[All2AllSingle] = None,
 ) -> Awaitable[Tensor]:
 
     if group is None:
@@ -497,7 +595,11 @@ def variable_batch_alltoall_pooled(
     )
 
     myreq = Request(group, device=a2a_pooled_embs_tensor.device)
-    Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
+    if all_to_all_single_comm is None:
+        all_to_all_single_comm = DefaultAll2AllSingle()
+    Variable_Batch_All2All_Pooled_Req.apply(
+        group, myreq, a2ai, a2a_pooled_embs_tensor, all_to_all_single_comm
+    )
     return myreq
 
 
@@ -1138,6 +1240,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: All2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
         (B_global, D_local_sum) = input_embeddings.shape
@@ -1191,16 +1294,24 @@ def forward(
             input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
             qcomm_ctx = None
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## alltoall_fwd_single ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
@@ -1219,7 +1330,7 @@ def forward(
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         pg = ctx.pg
         my_rank = dist.get_rank(pg)
         myreq = ctx.myreq
@@ -1242,7 +1353,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class All2All_Pooled_Wait(Function):
@@ -1386,6 +1497,7 @@ def forward(
         myreq: Request[Tensor],
         a2ai: VariableBatchAll2AllPooledInfo,
         input_embeddings: Tensor,
+        all_to_all_single_comm: All2AllSingle,
     ) -> Tensor:
         my_rank = dist.get_rank(pg)
 
@@ -1439,16 +1551,24 @@ def forward(
                 for split in input_split_sizes
             ]
 
-        sharded_output_embeddings = torch.empty(
-            sum(output_split_sizes),
+        sharded_output_embeddings = all_to_all_single_comm.allocate(
+            (sum(output_split_sizes),),
             dtype=sharded_input_embeddings.dtype,
             device=sharded_input_embeddings.device,
         )
 
+        sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
+            sharded_input_embeddings.shape,
+            dtype=sharded_input_embeddings.dtype,
+            device=sharded_input_embeddings.device,
+        )
+
+        sharded_input_embeddings_registered.copy_(sharded_input_embeddings)
+
         with record_function("## alltoall_fwd_single ##"):
-            req = dist.all_to_all_single(
-                output=sharded_output_embeddings,
-                input=sharded_input_embeddings,
+            req = all_to_all_single_comm(
+                output_tensor=sharded_output_embeddings,
+                input_tensor=sharded_input_embeddings_registered,
                 output_split_sizes=output_split_sizes,
                 input_split_sizes=input_split_sizes,
                 group=pg,
@@ -1467,7 +1587,7 @@ def forward(
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.
-    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
+    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
@@ -1487,7 +1607,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             grad_input.div_(dist.get_world_size(ctx.pg))
         myreq.tensor = None
         myreq.dummy_tensor = None
-        return (None, None, None, grad_input)
+        return (None, None, None, grad_input, None)
 
 
 class Variable_Batch_All2All_Pooled_Wait(Function):
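
Note (not part of the patch): the snippet below is a minimal sketch of how a caller could plug a custom All2AllSingle into alltoall_pooled so that the comm buffers come from a dedicated CUDA memory pool rather than the default caching allocator. The class name PooledAll2AllSingle is hypothetical, and the sketch assumes a recent PyTorch that provides torch.cuda.MemPool and torch.cuda.use_mem_pool; registering the pool's memory with the process group, if desired, is out of scope here.

# Sketch only, not part of this diff. Assumes torch.cuda.MemPool and
# torch.cuda.use_mem_pool are available in the installed PyTorch release.
from collections.abc import Sequence
from typing import Optional, Union

import torch
import torch.distributed as dist

from torchrec.distributed.comm_ops import All2AllSingle, alltoall_pooled


class PooledAll2AllSingle(All2AllSingle):  # hypothetical example implementation
    """Serves both all-to-all buffers out of a caller-provided CUDA memory pool."""

    def __init__(self, mem_pool: torch.cuda.MemPool) -> None:
        self._mem_pool = mem_pool

    def allocate(
        self,
        size: Sequence[Union[int, torch.SymInt]],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        # Route only the comm buffers through the pool; all other allocations
        # keep using the regular caching allocator.
        with torch.cuda.use_mem_pool(self._mem_pool):
            return torch.empty(*size, dtype=dtype, device=device)

    def __call__(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        output_split_sizes: Optional[list[int]] = None,
        input_split_sizes: Optional[list[int]] = None,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False,
    ) -> Optional[dist.Work]:
        # Same collective as DefaultAll2AllSingle; only where the buffers live differs.
        return dist.all_to_all_single(
            output_tensor,
            input_tensor,
            output_split_sizes,
            input_split_sizes,
            group=group,
            async_op=async_op,
        )


# Usage inside an initialized process group, keeping alltoall_pooled's usual arguments:
#   comm = PooledAll2AllSingle(torch.cuda.MemPool())
#   awaitable = alltoall_pooled(..., all_to_all_single_comm=comm)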