156 changes: 138 additions & 18 deletions torchrec/distributed/comm_ops.py
@@ -7,14 +7,15 @@

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple, TypeVar
from typing import Any, List, Optional, Tuple, TypeVar, Union

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives

from torch import Tensor
from torch.autograd import Function
from torch.autograd.profiler import record_function
@@ -334,6 +335,93 @@ def _get_split_lengths_by_len(
return (my_len, splits)


class Comm(ABC):
"""
Interface for communication primitives.
A primitive primarily needs to handle 3 tasks, namely:

1. How to allocate memory for communication
Depending on the goal, an implementation can choose to:
a. associate each call with a temporary buffer
(best for flexibility and simplicity)
b. reuse a persistent buffer for efficiency

2. Where to allocate memory
(e.g. an NCCL mem pool or the regular CUDA caching allocator)

3. What to do/call when the comm is invoked
(see `AllGather` interface as an example)
"""

@abstractmethod
def allocate(
self,
size: Sequence[Union[int, torch.SymInt]],
*,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""
This handles the "how to allocate memory" part.

A default implementation could simply be:

.. code-block:: python

    with self.mem_pool:
        return torch.empty(...)

Args:
size (Sequence[Union[int, torch.SymInt]]): size of the tensor buffer
dtype (torch.dtype): dtype of the tensor buffer
device (torch.device): which device to allocate the tensor onto
"""
...


class All2AllSingle(Comm):
@abstractmethod
def __call__(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_split_sizes: Optional[list[int]] = None,
input_split_sizes: Optional[list[int]] = None,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
) -> Optional[dist.Work]: ...


class DefaultAllocMixin:
def allocate(
self,
size: Sequence[Union[int, torch.SymInt]],
*,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
return torch.empty(*size, dtype=dtype, device=device)


class DefaultAll2AllSingle(DefaultAllocMixin, All2AllSingle):
def __call__(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_split_sizes: Optional[list[int]] = None,
input_split_sizes: Optional[list[int]] = None,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
) -> Optional[dist.Work]:
return dist.all_to_all_single(
output_tensor,
input_tensor,
output_split_sizes,
input_split_sizes,
group=group,
async_op=async_op,
)
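
# --- Illustrative sketch (not part of this change) ---------------------------
# A hypothetical All2AllSingle that backs both of its buffers with a dedicated
# CUDA memory pool, showing how the `Comm.allocate` hook separates "where the
# buffers live" from "which collective runs". It reuses the imports at the top
# of this module and assumes a recent PyTorch where `torch.cuda.MemPool` and
# `torch.cuda.use_mem_pool` exist; treat the names below as assumptions, not an
# API this PR introduces.
class MemPoolAll2AllSingle(All2AllSingle):
    def __init__(self, mem_pool: "torch.cuda.MemPool") -> None:
        self._mem_pool = mem_pool

    def allocate(
        self,
        size: Sequence[Union[int, torch.SymInt]],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        # Route the buffer through the custom pool instead of the default
        # caching allocator (e.g. so it can be registered with NCCL).
        with torch.cuda.use_mem_pool(self._mem_pool):
            return torch.empty(*size, dtype=dtype, device=device)

    def __call__(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        output_split_sizes: Optional[list[int]] = None,
        input_split_sizes: Optional[list[int]] = None,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False,
    ) -> Optional[dist.Work]:
        # The collective itself is unchanged; only the allocation policy differs.
        return dist.all_to_all_single(
            output_tensor,
            input_tensor,
            output_split_sizes,
            input_split_sizes,
            group=group,
            async_op=async_op,
        )
# -----------------------------------------------------------------------------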


def alltoall_pooled(
a2a_pooled_embs_tensor: Tensor,
batch_size_per_rank: List[int],
@@ -342,6 +430,7 @@ def alltoall_pooled(
cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None,
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
all_to_all_single_comm: Optional[All2AllSingle] = None,
) -> Awaitable[Tensor]:
"""
Performs AlltoAll operation for a single pooled embedding tensor. Each process
@@ -391,7 +480,15 @@ def alltoall_pooled(
return NoWait(all2all_pooled_sync(group, a2ai, a2a_pooled_embs_tensor))

myreq = Request(group, device=a2a_pooled_embs_tensor.device)
All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
if all_to_all_single_comm is None:
all_to_all_single_comm = DefaultAll2AllSingle()
All2All_Pooled_Req.apply(
group,
myreq,
a2ai,
a2a_pooled_embs_tensor,
all_to_all_single_comm,
)
return myreq
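
# Hypothetical call site (not part of this change): a custom comm plugs in via
# the new `all_to_all_single_comm` keyword, while passing `None` keeps the
# previous behavior through `DefaultAll2AllSingle`. Arguments collapsed out of
# this hunk are elided below.
#
#     custom_comm = MemPoolAll2AllSingle(torch.cuda.MemPool())  # sketch above
#     awaitable = alltoall_pooled(
#         a2a_pooled_embs_tensor=local_pooled_embs,  # hypothetical local tensor
#         batch_size_per_rank=batch_size_per_rank,
#         ...,  # remaining arguments exactly as before this change
#         group=pg,
#         all_to_all_single_comm=custom_comm,
#     )
#     pooled_embs = awaitable.wait()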


@@ -476,6 +573,7 @@ def variable_batch_alltoall_pooled(
emb_dim_per_rank_per_feature: List[List[int]],
group: Optional[dist.ProcessGroup] = None,
codecs: Optional[QuantizedCommCodecs] = None,
all_to_all_single_comm: Optional[All2AllSingle] = None,
) -> Awaitable[Tensor]:

if group is None:
@@ -497,7 +595,11 @@
)

myreq = Request(group, device=a2a_pooled_embs_tensor.device)
Variable_Batch_All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor)
if all_to_all_single_comm is None:
all_to_all_single_comm = DefaultAll2AllSingle()
Variable_Batch_All2All_Pooled_Req.apply(
group, myreq, a2ai, a2a_pooled_embs_tensor, all_to_all_single_comm
)
return myreq


@@ -1138,6 +1240,7 @@ def forward(
myreq: Request[Tensor],
a2ai: All2AllPooledInfo,
input_embeddings: Tensor,
all_to_all_single_comm: All2AllSingle,
) -> Tensor:
my_rank = dist.get_rank(pg)
(B_global, D_local_sum) = input_embeddings.shape
@@ -1191,16 +1294,24 @@ def forward(
input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
qcomm_ctx = None

sharded_output_embeddings = torch.empty(
sum(output_split_sizes),
sharded_output_embeddings = all_to_all_single_comm.allocate(
(sum(output_split_sizes),),
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

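# Stage the input in a buffer obtained from the comm's allocator as well, so
# both sides of the all-to-all live in comm-managed memory (e.g. an NCCL mem
# pool, per the `Comm` interface notes above).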
sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
sharded_input_embeddings.shape,
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

sharded_input_embeddings_registered.copy_(sharded_input_embeddings)

with record_function("## alltoall_fwd_single ##"):
req = dist.all_to_all_single(
output=sharded_output_embeddings,
input=sharded_input_embeddings,
req = all_to_all_single_comm(
output_tensor=sharded_output_embeddings,
input_tensor=sharded_input_embeddings_registered,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=pg,
@@ -1219,7 +1330,7 @@ def forward(
@staticmethod
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
pg = ctx.pg
my_rank = dist.get_rank(pg)
myreq = ctx.myreq
@@ -1242,7 +1353,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
grad_input.div_(dist.get_world_size(ctx.pg))
myreq.tensor = None
myreq.dummy_tensor = None
return (None, None, None, grad_input)
return (None, None, None, grad_input, None)


class All2All_Pooled_Wait(Function):
@@ -1386,6 +1497,7 @@ def forward(
myreq: Request[Tensor],
a2ai: VariableBatchAll2AllPooledInfo,
input_embeddings: Tensor,
all_to_all_single_comm: All2AllSingle,
) -> Tensor:
my_rank = dist.get_rank(pg)

@@ -1439,16 +1551,24 @@ def forward(
for split in input_split_sizes
]

sharded_output_embeddings = torch.empty(
sum(output_split_sizes),
sharded_output_embeddings = all_to_all_single_comm.allocate(
(sum(output_split_sizes),),
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

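# As in the fixed-batch path above, stage the input in a buffer obtained from
# the comm's allocator so both sides of the all-to-all live in comm-managed
# memory (e.g. an NCCL mem pool).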
sharded_input_embeddings_registered = all_to_all_single_comm.allocate(
sharded_input_embeddings.shape,
dtype=sharded_input_embeddings.dtype,
device=sharded_input_embeddings.device,
)

sharded_input_embeddings_registered.copy_(sharded_input_embeddings)

with record_function("## alltoall_fwd_single ##"):
req = dist.all_to_all_single(
output=sharded_output_embeddings,
input=sharded_input_embeddings,
req = all_to_all_single_comm(
output_tensor=sharded_output_embeddings,
input_tensor=sharded_input_embeddings_registered,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=pg,
@@ -1467,7 +1587,7 @@ def forward(
@staticmethod
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
def backward(ctx, *unused) -> Tuple[None, None, None, Tensor, None]:
myreq = ctx.myreq
a2ai = myreq.a2ai
assert myreq.req is not None
@@ -1487,7 +1607,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
grad_input.div_(dist.get_world_size(ctx.pg))
myreq.tensor = None
myreq.dummy_tensor = None
return (None, None, None, grad_input)
return (None, None, None, grad_input, None)


class Variable_Batch_All2All_Pooled_Wait(Function):