Add nccl_alltoall function (pytorch#3551)
Summary:
Pull Request resolved: pytorch#3551

X-link: facebookresearch/FBGEMM#636

Expose a generic nccl_alltoall collective through FBGEMM.

Differential Revision: D67870377

fbshipit-source-id: a40eeb48cfbf1ad88ed0aa6eb852f537c3ec202c
jasonjk-park authored and facebook-github-bot committed Jan 7, 2025
1 parent af77a89 commit dc116f6
Showing 1 changed file with 19 additions and 0 deletions.
fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -169,6 +169,14 @@ void nccl_alltoall_single(
       src, dst, world_size, *get_nccl_comm(comm_idx), stream);
 }
 
+void nccl_alltoall(
+    std::vector<at::Tensor> dsts,
+    std::vector<at::Tensor> srcs,
+    int64_t comm_idx) {
+  auto stream = at::cuda::getCurrentCUDAStream();
+  torch::cuda::nccl::all2all(dsts, srcs, *get_nccl_comm(comm_idx), stream);
+}
+
 void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
@@ -272,6 +280,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 
   m.def(
       "nccl_alltoall_single(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()");
+  m.def("nccl_alltoall(Tensor(a!)[] dst, Tensor[] src, int comm_idx=0) -> ()");
 
   m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");

@@ -299,6 +308,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("nccl_allreduce", nccl_allreduce);
   m.impl("nccl_allgather", nccl_allgather);
   m.impl("nccl_alltoall_single", nccl_alltoall_single);
+  m.impl("nccl_alltoall", nccl_alltoall);
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -310,6 +320,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("nccl_allreduce", nccl_allreduce);
   m.impl("nccl_allgather", nccl_allgather);
   m.impl("nccl_alltoall_single", nccl_alltoall_single);
+  m.impl("nccl_alltoall", nccl_alltoall);
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
@@ -339,6 +350,13 @@ void nccl_alltoall_single_meta(
   return;
 }
 
+void nccl_alltoall_meta(
+    std::vector<at::Tensor> /* dsts */,
+    std::vector<at::Tensor> /* srcs */,
+    int64_t /* comm_idx */) {
+  return;
+}
+
 void nccl_reducescatter_meta(
     at::Tensor /* dst */,
     at::Tensor /* src */,
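
As with the other collectives in this file, the Meta kernel is a deliberate no-op: under the Meta dispatch key the op only needs to be callable so that tracing and shape propagation can step over the collective with no GPU or NCCL communicator present. A minimal sketch of what that enables, assuming the fbgemm_gpu gen_ai ops containing this change are built and loaded:

import torch

# Sketch only: assumes the fbgemm_gpu gen_ai op library is loaded.
dsts = [torch.empty(1024, device="meta") for _ in range(4)]
srcs = [torch.empty(1024, device="meta") for _ in range(4)]

# Dispatches to nccl_alltoall_meta, which returns without doing any work.
torch.ops.fbgemm.nccl_alltoall(dsts, srcs, 0)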
@@ -366,6 +384,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_allreduce", nccl_allreduce_meta);
   m.impl("nccl_allgather", nccl_allgather_meta);
   m.impl("nccl_alltoall_single", nccl_alltoall_single_meta);
+  m.impl("nccl_alltoall", nccl_alltoall_meta);
   m.impl("nccl_reducescatter", nccl_reducescatter_meta);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);