diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h
index eb747f78dac..647b4d0a132 100644
--- a/oneflow/core/job/eager_ccl_comm_manager.h
+++ b/oneflow/core/job/eager_ccl_comm_manager.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/job/plan.pb.h"
+#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"
 
 namespace oneflow {
 
@@ -29,6 +30,13 @@ class EagerCclCommMgr {
   virtual void CreateCommFromPlan(const Plan& plan) = 0;
   virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;
   virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;
+  virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0;
+  virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,
+                                                              const std::string& stream_name) = 0;
+  virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,
+                                                            const std::string& stream_name,
+                                                            const int64_t this_parallel_id,
+                                                            const std::string& comm_key) = 0;
 
   template<typename T>
   T* As() {
diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp
index 2b3e5a4a735..fe2755438b7 100644
--- a/oneflow/core/job/eager_nccl_comm_manager.cpp
+++ b/oneflow/core/job/eager_nccl_comm_manager.cpp
@@ -156,6 +156,90 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
   return comm;
 }
 
+ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) {
+  std::set<std::pair<int64_t, int64_t>> device_set;
+  FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
+    int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+    int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
+    device_set.emplace(std::make_pair(machine_id, device_id));
+  }
+
+  ncclComm_t comm = GetCommForDevice(device_set);
+  std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
+  ccl::CclComm ccl_comm(ncclCommAdapter);
+  return ccl_comm;
+}
+
+ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName(
+    const ParallelDesc& parallel_desc, const std::string& stream_name) {
+  std::set<std::pair<int64_t, int64_t>> device_set;
+  FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
+    int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+    int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
+    device_set.emplace(std::make_pair(machine_id, device_id));
+  }
+
+  ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
+  std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
+  ccl::CclComm ccl_comm(ncclCommAdapter);
+  return ccl_comm;
+}
+
+ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy(
+    const ParallelDesc& parallel_desc, const std::string& stream_name,
+    const int64_t this_parallel_id, const std::string& comm_key) {
+  std::set<std::pair<int64_t, int64_t>> device_set;
+  const Shape& hierarchy = *parallel_desc.hierarchy();
+  CHECK_LE(hierarchy.NumAxes(), 2);
+
+  if (hierarchy.NumAxes() == 1) {
+    // 1D hierarchy
+    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
+      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
+      device_set.emplace(std::make_pair(machine_id, device_id));
+    }
+  } else if (hierarchy.NumAxes() == 2) {
+    // 2D hierarchy
+    CHECK(comm_key == "SameDim0" || comm_key ==
"SameDim1"); + if (comm_key == "SameDim0") { + const int64_t num_groups = hierarchy.At(0); + const int64_t group_size = hierarchy.At(1); + CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); + const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size; + CHECK_EQ(this_group_begin_parallel_id % group_size, 0); + CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num()); + for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { + const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; + const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + } else if (comm_key == "SameDim1") { + const int64_t group_size = hierarchy.At(0); + const int64_t num_groups = hierarchy.At(1); + CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); + const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups; + CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, + parallel_desc.parallel_num()); + for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { + const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); + const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + } else { + UNIMPLEMENTED(); + } + } + + ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); + std::shared_ptr ncclCommAdapter = std::make_shared(comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + return ccl_comm; +} + void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) { const int64_t rank = GlobalProcessCtx::Rank(); const int64_t dev = GlobalProcessCtx::LocalRank(); diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 2210983f3a1..e335772dc39 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -25,6 +25,19 @@ limitations under the License. 
#include "oneflow/core/device/cuda_util.h" namespace oneflow { +namespace ccl { + +class NcclCommAdapter : public CommBase { + public: + explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} + + void* getComm() const override { return const_cast(static_cast(&comm_)); } + + private: + ncclComm_t comm_; +}; + +} // namespace ccl class EagerNcclCommMgr final : public EagerCclCommMgr { public: @@ -36,6 +49,13 @@ class EagerNcclCommMgr final : public EagerCclCommMgr { ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); + ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override; + ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc, + const std::string& stream_name) override; + ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc, + const std::string& stream_name, + const int64_t this_parallel_id, + const std::string& comm_key) override; void CreateCommFromPlan(const Plan& plan) override; bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; } diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index c342f4a2f42..825c43a36a6 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -41,31 +42,25 @@ class NcclSendRecvBoxingKernel final : public Kernel { const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } const bool has_input() const { return has_input_; } const bool has_output() const { return has_output_; } - ncclComm_t comm() const { return GetOrCreate().comm; } + ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; } private: struct Comm { - Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; + Comm(ccl::CclComm comm) : ccl_comm(comm) {} + ccl::CclComm ccl_comm; }; void Init() const { ParallelDesc parallel_desc(parallel_conf_); - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ncclComm_t comm = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); + ccl::CclComm ccl_comm = + comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_); + ccl_comm_.reset(new Comm(ccl_comm)); } const Comm& GetOrCreate() const { - if (!comm_) { Init(); } - return *comm_; + if (!ccl_comm_) { Init(); } + return *ccl_comm_; } void VirtualKernelInit(KernelContext* ctx) override; @@ -73,7 +68,7 @@ class NcclSendRecvBoxingKernel final : public Kernel { std::string stream_name_; ParallelConf parallel_conf_; - mutable std::unique_ptr comm_; + mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> 
in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; @@ -85,27 +80,31 @@ class NcclSendRecvBoxingKernel final : public Kernel { void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { Blob* buf = ctx->BnInOp2Blob("buf"); - ncclComm_t comm = this->comm(); - cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + ccl::CclComm ccl_comm = this->ccl_comm(); const std::vector& send_elem_cnts = this->send_elem_cnts(); const std::vector& recv_elem_cnts = this->recv_elem_cnts(); const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num(); const DataType data_type = buf->data_type(); std::vector send_in_ptr; std::vector recv_out_ptr; + std::vector send_offsets; + std::vector recv_offsets; char* buf_ptr = buf->mut_dptr(); - int64_t offset = 0; + uint64_t offset = 0; if (this->has_input()) { for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); + send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); } } + const uint64_t recv_offset = offset; if (this->has_output()) { for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); + recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); } } @@ -119,18 +118,17 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } } - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < parallel_num; ++i) { - if (this->has_input() && send_elem_cnts.at(i) != 0) { - OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, - comm, cuda_stream)); - } - if (this->has_output() && recv_elem_cnts.at(i) != 0) { - OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), - i, comm, cuda_stream)); - } + + if (this->has_input() || this->has_output()) { + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), data_type, data_type, parallel_num); + void* send_buf = reinterpret_cast(buf_ptr); + void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); + all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), + recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, + this->has_input(), this->has_output()); } - OF_NCCL_CHECK(ncclGroupEnd()); + if (!this->has_output()) { return; } Blob* out = ctx->BnInOp2Blob("out"); const std::vector>& out_tensor_slice_copier_vec = diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp index f6b6f2005fe..a09ebec52d4 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp @@ -95,6 +95,11 @@ class CpuAllGather final : public AllGather { CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp index 6550f9b81fb..c03bed6c1db 100644 --- 
a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp @@ -148,6 +148,11 @@ class CpuAllReduce final : public AllReduce { cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; ReduceType reduce_type_; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp index 412e2442c12..2640d8f372f 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -40,6 +40,11 @@ class CpuRecvImpl final : public Recv { CHECK_JUST(CpuRecv(out, buffer_size, src)); } + void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, + const ccl::CclComm& ccl_comm) const override { + Launch(stream, out, elem_cnt, src); + } + private: size_t size_of_dtype_; }; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp index 8ff362c2eaa..7ba9db44570 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp @@ -120,6 +120,11 @@ class CpuReduceScatter final : public ReduceScatter { cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; ReduceType reduce_type_; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp index a0e62957fbd..829bf5a09cf 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp @@ -40,6 +40,11 @@ class CpuSendImpl final : public Send { CHECK_JUST(CpuSend(in, buffer_size, dst)); } + void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, + const ccl::CclComm& comm) const override { + Launch(stream, in, elem_cnt, dst); + } + private: size_t size_of_dtype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp index a3012783f74..5de37a2ea08 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp @@ -40,6 +40,13 @@ class CudaAllGather final : public AllGather { stream->As()->cuda_stream())); } + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp index fa6803d54df..e9567863a1c 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp @@ 
-57,6 +57,13 @@ class CudaAllReduce final : public AllReduce { stream->As()->cuda_stream())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp new file mode 100644 index 00000000000..313a14f00e0 --- /dev/null +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -0,0 +1,113 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifdef WITH_CUDA +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/common/device_type.h" + +namespace oneflow { + +namespace ccl { + +class CudaAllToAll final : public AllToAll { + public: + OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll); + CudaAllToAll() + : send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {} + ~CudaAllToAll() = default; + + void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override { + this->send_dtype_ = send_dtype; + this->recv_dtype_ = recv_dtype; + this->nccl_send_dtype_ = GetNcclDataType(send_dtype); + this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype); + this->rank_count_ = parallel_num; + } + + void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count, + const ccl::CclComm& ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + int64_t send_offset = 0; + int64_t recv_offset = 0; + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + if (send_count > 0) { + char* send_ptr = static_cast(send) + send_offset; + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + send_offset += send_count * GetSizeOfDataType(this->send_dtype_); + if (recv_count) { + char* recv_ptr = static_cast(recv) + recv_offset; + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_); + } + OF_NCCL_CHECK(ncclGroupEnd()); + } + + void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, + void* recv, const void* recv_counts, const void* recv_offsets, + const ccl::CclComm& ccl_comm, const bool has_input, + const bool has_output) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + int64_t* send_counts_ptr = 
static_cast(const_cast(send_counts)); + int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); + int64_t* send_offsets_ptr = static_cast(const_cast(send_offsets)); + int64_t* recv_offsets_ptr = static_cast(const_cast(recv_offsets)); + if (has_input || has_output) { + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + if (has_input) { + const uint64_t send_count = static_cast(send_counts_ptr[i]); + if (send_count > 0) { + uint64_t send_offset = static_cast(send_offsets_ptr[i]); + char* send_ptr = static_cast(send) + send_offset; + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } + if (has_output) { + const uint64_t recv_count = static_cast(recv_counts_ptr[i]); + if (recv_count > 0) { + uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); + char* recv_ptr = static_cast(recv) + recv_offset; + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } + } + OF_NCCL_CHECK(ncclGroupEnd()); + } + } + + private: + DataType send_dtype_; + DataType recv_dtype_; + ncclDataType_t nccl_send_dtype_; + ncclDataType_t nccl_recv_dtype_; + size_t rank_count_; +}; + +REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllToAll, CudaAllToAll); + +} // namespace ccl + +} // namespace oneflow + +#endif // WITH_CUDA diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp index cc4bcfafe3f..46dcdc7b871 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp @@ -16,6 +16,7 @@ limitations under the License. 
#ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/recv.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { @@ -40,6 +41,17 @@ class CudaRecv final : public Recv { #endif // HAS_NCCL_SEND_RECV } + void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, + const ccl::CclComm& ccl_comm) const override { +#if HAS_NCCL_SEND_RECV + ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, src, *comm, + stream->As()->cuda_stream())); +#else + UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7" +#endif // HAS_NCCL_SEND_RECV + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp index 80419a84759..3d57f89d6fe 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp @@ -57,6 +57,13 @@ class CudaReduceScatter final : public ReduceScatter { stream->As()->cuda_stream())); } + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp index da7ac181252..78de1d86d8e 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp @@ -16,6 +16,7 @@ limitations under the License. 
#ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { @@ -40,6 +41,17 @@ class CudaSend final : public Send { #endif // HAS_NCCL_SEND_RECV } + void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, + const ccl::CclComm& ccl_comm) const override { +#if HAS_NCCL_SEND_RECV + ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm, + stream->As()->cuda_stream())); +#else + UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" +#endif // HAS_NCCL_SEND_RECV + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/include/all_gather.h b/oneflow/user/kernels/collective_communication/include/all_gather.h index 66b520be6a5..e765f27e5ca 100644 --- a/oneflow/user/kernels/collective_communication/include/all_gather.h +++ b/oneflow/user/kernels/collective_communication/include/all_gather.h @@ -32,6 +32,9 @@ class AllGather : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; + + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllGatherRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/all_reduce.h b/oneflow/user/kernels/collective_communication/include/all_reduce.h index 0dcc685b966..4253864d026 100644 --- a/oneflow/user/kernels/collective_communication/include/all_reduce.h +++ b/oneflow/user/kernels/collective_communication/include/all_reduce.h @@ -32,6 +32,9 @@ class AllReduce : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; + + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllReduceRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/all_to_all.h b/oneflow/user/kernels/collective_communication/include/all_to_all.h new file mode 100644 index 00000000000..81c35ce80ab --- /dev/null +++ b/oneflow/user/kernels/collective_communication/include/all_to_all.h @@ -0,0 +1,52 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_
+#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_
+
+#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"
+
+namespace oneflow {
+
+namespace ccl {
+
+class AllToAll : public CollectiveCommunication {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(AllToAll);
+  AllToAll() = default;
+  ~AllToAll() override = default;
+
+  virtual void Init(DataType send_dtype, DataType recv_dtype, size_t rank_count) = 0;
+
+  // for a normal all-to-all (balanced send/recv count)
+  virtual void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv,
+                      int64_t recv_count, const ccl::CclComm& ccl_comm) const = 0;
+
+  // for an unbalanced all-to-all (e.g. NCCL all2all via send/recv; HCCL HcclAlltoAllV)
+  virtual void Launch(ep::Stream* stream, void* send, const void* send_counts,
+                      const void* send_offsets, void* recv, const void* recv_counts,
+                      const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input,
+                      const bool has_output) const = 0;
+};
+
+inline bool IsAllToAllRegistered(DeviceType device_type) {
+  return IsClassRegistered<DeviceType, AllToAll>(device_type);
+}
+
+}  // namespace ccl
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_
diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h
index c197820d974..4f2211cf9d3 100644
--- a/oneflow/user/kernels/collective_communication/include/collective_communication.h
+++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h
@@ -41,6 +41,26 @@ enum ReduceType {
   MAKE_TYPED_CTRV_SEQ(ReduceType, \
                       OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, REDUCE_TYPE_SEQ))
 
+// abstract base class for a communicator handle
+class CommBase {
+ public:
+  virtual ~CommBase() = default;
+
+  // return the underlying communicator implementation
+  virtual void* getComm() const = 0;
+};
+
+class CclComm {
+ public:
+  CclComm() {}
+  explicit CclComm(std::shared_ptr<CommBase> comm) : comm_(std::move(comm)) {}
+
+  void* getComm() const { return comm_->getComm(); }
+
+ private:
+  std::shared_ptr<CommBase> comm_{};
+};
+
 class CollectiveCommunication {
  public:
   OF_DISALLOW_COPY_AND_MOVE(CollectiveCommunication);
diff --git a/oneflow/user/kernels/collective_communication/include/communication_context.h b/oneflow/user/kernels/collective_communication/include/communication_context.h
index 9423f0af997..ae6dbf1fdb9 100644
--- a/oneflow/user/kernels/collective_communication/include/communication_context.h
+++ b/oneflow/user/kernels/collective_communication/include/communication_context.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_ +#include "collective_communication.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/common/auto_registration_factory.h" diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h index 59c1aef849f..f0cf5d34627 100644 --- a/oneflow/user/kernels/collective_communication/include/recv.h +++ b/oneflow/user/kernels/collective_communication/include/recv.h @@ -31,6 +31,9 @@ class Recv : public CollectiveCommunication { virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0; + + virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsRecvRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h index a3b179b48fb..da62e5d6085 100644 --- a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h +++ b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h @@ -32,6 +32,9 @@ class ReduceScatter : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, const std::shared_ptr& communicator) const = 0; + + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsReduceScatterRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/send.h b/oneflow/user/kernels/collective_communication/include/send.h index 6658c7de292..4ca4491c7e5 100644 --- a/oneflow/user/kernels/collective_communication/include/send.h +++ b/oneflow/user/kernels/collective_communication/include/send.h @@ -31,6 +31,9 @@ class Send : public CollectiveCommunication { virtual void Init(DataType dtype) = 0; virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0; + + virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsSendRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index 2b5511b9c64..7518efb0c33 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -21,6 +21,7 @@ limitations under the License. 
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -34,27 +35,20 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache { ~EagerNcclOpKernelCache() override = default; Symbol parallel_desc() const { return parallel_desc_; } - ncclComm_t comm() const { return comm_; } + const ccl::CclComm& ccl_comm() const { return ccl_comm_; } private: void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; - std::set> device_set; CHECK(TxtString2PbMessage(parallel_conf_txt, ¶llel_conf)); parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf)); - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_->parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } - comm_ = CHECK_NOTNULL(Singleton::Get()) - ->As() - ->GetCommForDevice(device_set); + EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); + ccl_comm_ = comm_mgr->GetCclCommForParallelDesc(parallel_conf); } Symbol parallel_desc_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; size_t InferEagerNcclS2SKernelTmpBufferSize(user_op::InferContext* ctx) { @@ -148,21 +142,12 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { { // NOTE: Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, - kernel_cache->comm(), - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_cache->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + auto& ccl_comm = kernel_cache->ccl_comm(); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index dac2545f905..03cecc9e427 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -21,6 +21,9 @@ limitations under the License. 
#include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_reduce.h" +#include "oneflow/user/kernels/collective_communication/include/all_gather.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -39,9 +42,9 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { } ~NcclLogical2DSameDim0KernelCommState() override = default; - ncclComm_t comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { Init(); } - return comm_; + return ccl_comm_; } int64_t num_ranks() { @@ -54,24 +57,12 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { private: void Init() { CHECK(!is_init_); - std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); - const int64_t num_groups = hierarchy.At(0); const int64_t group_size = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size; - CHECK_EQ(this_group_begin_parallel_id % group_size, 0); - CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, "SameDim0"); num_ranks_ = group_size; is_init_ = true; } @@ -81,7 +72,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { ParallelDesc parallel_desc_; int64_t this_parallel_id_; int64_t num_ranks_{}; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogical2DSameDim0AllGatherNoncontinuousKernelState @@ -127,19 +118,22 @@ class NcclLogical2DSameDim0AllReduce final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << "[NcclLogical2D][SameDim0AllReduce] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim0AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = 
ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -161,18 +155,22 @@ class NcclLogical2DSameDim0AllGather final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); - const int64_t num_ranks = nccl_comm->num_ranks(); + const int64_t num_ranks = comm_state->num_ranks(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - VLOG(3) << "[NcclLogical2D][SameDim0AllGather] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim0AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -225,9 +223,13 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. 
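// The kernels in this file (and the all-to-all hunk that follows) all share one
// pattern: the kernel state caches a device-independent ccl::CclComm obtained from
// EagerCclCommMgr, builds the matching collective primitive through
// ccl::NewCollectiveCommunication<...>(), and calls Launch() instead of invoking
// ncclAllReduce/ncclAllGather/ncclSend/ncclRecv directly. The snippet below is a
// minimal sketch of that pattern for the balanced all-to-all used by the S2S
// kernels; it is not part of this patch, and the helper name
// LaunchBalancedAllToAll is hypothetical.
//
// #include <memory>
// #include "oneflow/core/ep/include/stream.h"
// #include "oneflow/user/kernels/collective_communication/include/all_to_all.h"
//
// namespace oneflow {
//
// inline void LaunchBalancedAllToAll(ep::Stream* stream, DataType dtype, int64_t elem_cnt,
//                                    int64_t num_ranks, const void* pack_to_ptr,
//                                    void* unpack_from_ptr, const ccl::CclComm& ccl_comm) {
//   // Every rank exchanges an equally sized chunk with every peer.
//   const int64_t elem_per_chunk = elem_cnt / num_ranks;
//   std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
//       stream->device_type(), dtype, dtype, num_ranks);
//   all_to_all->Launch(stream, const_cast<void*>(pack_to_ptr), elem_per_chunk, unpack_from_ptr,
//                      elem_per_chunk, ccl_comm);
// }
//
// }  // namespace oneflow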
@@ -342,21 +344,12 @@ class NcclLogical2DSameDim0All2All final : public user_op::OpKernel { { // NOTE(chengcheng): Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, - kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { @@ -414,29 +407,16 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState } ~NcclLogical2DSameDim1KernelCommState() = default; - ncclComm_t comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { - std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); - const int64_t group_size = hierarchy.At(0); - const int64_t num_groups = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups; - CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, - parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, "SameDim1"); is_init_ = true; } - return comm_; + return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } @@ -446,7 +426,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState std::string stream_name_; ParallelDesc parallel_desc_; int64_t this_parallel_id_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogical2DSameDim1AllReduce final : public user_op::OpKernel { @@ -462,19 +442,23 @@ class NcclLogical2DSameDim1AllReduce final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << 
"[NcclLogical2D][SameDim1AllReduce] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim1AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp index aeb906b6387..4efe792d1a8 100644 --- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" @@ -21,6 +20,13 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "collective_communication/include/collective_communication.h" +#include "collective_communication/include/send.h" +#include "collective_communication/include/recv.h" +#include "collective_communication/include/all_gather.h" +#include "collective_communication/include/all_reduce.h" +#include "collective_communication/include/all_to_all.h" +#include "collective_communication/include/reduce_scatter.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -83,9 +89,9 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { } ~NcclLogicalFusionKernelState() override = default; - ncclComm_t comm() { + ccl::CclComm ccl_comm() { if (!is_init_) { InitComm(); } - return comm_; + return ccl_comm_; } int64_t num_ranks() { @@ -121,45 +127,17 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { private: void InitComm() { CHECK(!is_init_); - std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); if (hierarchy.NumAxes() == 1) { num_ranks_ = parallel_desc_.parallel_num(); - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } } else if (hierarchy.NumAxes() == 2) { CHECK(comm_key_ == "SameDim0" || comm_key_ == "SameDim1"); if (comm_key_ == "SameDim0") { - const int64_t num_groups = hierarchy.At(0); const int64_t group_size = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t 
this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size; - CHECK_EQ(this_group_begin_parallel_id % group_size, 0); - CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } num_ranks_ = group_size; } else if (comm_key_ == "SameDim1") { const int64_t group_size = hierarchy.At(0); - const int64_t num_groups = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups; - CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, - parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } num_ranks_ = group_size; } else { UNIMPLEMENTED(); @@ -169,8 +147,8 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, comm_key_); is_init_ = true; } @@ -277,7 +255,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { std::vector dst_split_axis_list_; std::vector tmp_buffer_offset_; std::vector tmp_buffer_size_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogicalFusionKernel final : public user_op::OpKernel { @@ -425,109 +403,120 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_p const std::string& nccl_type, const user_op::Tensor* in, user_op::Tensor* out, user_op::KernelComputeContext* ctx, NcclLogicalFusionKernelState* kernel_state, const int32_t i, - const ncclComm_t& comm) { + ccl::CclComm ccl_comm) { + std::unique_ptr ccl_send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); + std::unique_ptr ccl_recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); + const int64_t num_ranks = kernel_state->num_ranks(); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Try launch nccl_type: " << nccl_type; if (nccl_type == "_nccl_logical_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr 
ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); + } else if (nccl_type == "_nccl_logical_reduce_scatter") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { CHECK(in->dptr() != pack_to_ptr); // do in -> pack to ptr CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(pack_to_ptr, out->mut_dptr(), out->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_s2s") { const int64_t elem_cnt = in->shape_view().elem_cnt(); - const int64_t dtype_size = GetSizeOfDataType(in->data_type()); const int64_t elem_per_chunk = elem_cnt / num_ranks; 
-    const int64_t chunk_size = elem_per_chunk * dtype_size;
-    for (int64_t j = 0; j < num_ranks; ++j) {
-      OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                                 reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                             elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-      OF_NCCL_CHECK(ncclRecv(
-          reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-          elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-    }
+    std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
+        ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);
+    all_to_all->Launch(ctx->stream(), const_cast<void*>(pack_to_ptr), elem_per_chunk,
+                       unpack_from_ptr, elem_per_chunk, ccl_comm);
+
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), reduce_type, comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                        in->data_type(), ccl_reduce_type);
+    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() != unpack_from_ptr);  // do unpack from ptr -> out
     CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt());
-    OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    std::unique_ptr<ccl::AllGather> ccl_all_gather =
+        ccl::NewCollectiveCommunication<ccl::AllGather>(ctx->stream()->device_type(),
+                                                        in->data_type());
+    ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(),
+                           ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") {
     const int64_t elem_cnt = in->shape_view().elem_cnt();
-    const int64_t dtype_size = GetSizeOfDataType(in->data_type());
     const int64_t elem_per_chunk = elem_cnt / num_ranks;
-    const int64_t chunk_size = elem_per_chunk * dtype_size;
-    for (int64_t j = 0; j < num_ranks; ++j) {
-      OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
-                                 reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                             elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-                             ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-      OF_NCCL_CHECK(ncclRecv(
-          reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
-          elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
-          ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
-    }
+    std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
+        ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks);
+    all_to_all->Launch(ctx->stream(), const_cast<void*>(pack_to_ptr), elem_per_chunk,
+                       unpack_from_ptr, elem_per_chunk, ccl_comm);
   } else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") {
     CHECK(in->dptr() == pack_to_ptr);
     CHECK(out->mut_dptr() == unpack_from_ptr);
-    ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum;
-    if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; }
-    OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
-                                GetNcclDataType(in->data_type()), reduce_type, comm,
-                                ctx->stream()->As<ep::CudaStream>()->cuda_stream()));
+    ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum;
+    if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; }
+    std::unique_ptr<ccl::AllReduce> ccl_all_reduce =
+        ccl::NewCollectiveCommunication<ccl::AllReduce>(ctx->stream()->device_type(),
+                                                        in->data_type(), ccl_reduce_type);
+    ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(),
+                           ccl_comm);
+
   } else {
     UNIMPLEMENTED();
   }
@@ -663,7 +652,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,
   }

   // NOTE(chengcheng): init nccl comm need before ncclGroupStart.
-  ncclComm_t comm = kernel_state->comm();
+  ccl::CclComm ccl_comm = kernel_state->ccl_comm();

   // do nccl compute in group
   OF_NCCL_CHECK(ncclGroupStart());
@@ -671,7 +660,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx,
     const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i);
     user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i);
     DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i),
-                                   nccl_type_list.at(i), in, out, ctx, kernel_state, i, comm);
+                                   nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm);
   }
   OF_NCCL_CHECK(ncclGroupEnd());

diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp
index be88acbca18..f965ac6c418 100644
--- a/oneflow/user/kernels/nccl_logical_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -21,6 +21,10 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" +#include "oneflow/user/kernels/collective_communication/include/all_reduce.h" +#include "oneflow/user/kernels/collective_communication/include/all_gather.h" +#include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -38,20 +42,13 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { } ~NcclLogicalKernelCommState() override = default; - ncclComm_t comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { - std::set> device_set; - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_); is_init_ = true; } - return comm_; + return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } @@ -60,7 +57,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogicalAllGatherNoncontinuousKernelState : public NcclLogicalKernelCommState { @@ -118,19 +115,23 @@ class NcclLogicalAllReduceKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << "[NcclLogical][AllReduce] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -152,20 +153,24 @@ class NcclLogicalReduceScatterKernel final : public user_op::OpKernel { private: void 
Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); - VLOG(3) << "[NcclLogical][ReduceScatter] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][ReduceScatter] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter( - in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), GetNcclDataType(in->data_type()), - reduce_type, nccl_comm->comm(), ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -187,18 +192,22 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - VLOG(3) << "[NcclLogical][AllGather] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -250,9 +259,12 @@ class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel { // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), kernel_state->comm(), - 
ctx->stream()->As()->cuda_stream())); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. @@ -334,12 +346,15 @@ class NcclLogicalReduceScatterNoncontinuous final : public user_op::OpKernel { transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); VLOG(3) << "[NcclLogical][ReduceScatterNoncontinuous] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(tmp_buffer->dptr(), out->mut_dptr(), - out->shape_view().elem_cnt(), GetNcclDataType(in->data_type()), - reduce_type, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), tmp_buffer->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -437,23 +452,13 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { } { - // NOTE(chengcheng): init nccl comm need before ncclGroupStart. - ncclComm_t comm = kernel_state->comm(); // NOTE(chengcheng): Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 36478b8c27a..fa06e22fd44 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "collective_communication/include/collective_communication.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" @@ -27,6 +28,7 @@ limitations under the License. 
#include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -44,22 +46,23 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; } const std::vector& send_elem_cnts() const { return send_elem_cnts_; } const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } - ncclComm_t comm() const { return GetOrCreateComm().comm; } + ccl::CclComm ccl_comm() const { return GetOrCreateComm().ccl_comm; } private: struct Comm { - explicit Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; + Comm(ccl::CclComm comm) : ccl_comm(comm) {} + ccl::CclComm ccl_comm; }; + void InitComm() const; const Comm& GetOrCreateComm() const { - if (!comm_) { InitComm(); } - return *comm_; + if (!ccl_comm_) { InitComm(); } + return *ccl_comm_; } std::string stream_name_; std::unique_ptr parallel_desc_; - mutable std::unique_ptr comm_; + mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; @@ -123,16 +126,10 @@ NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* c } void NcclLogicalSendRecvState::InitComm() const { - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < parallel_desc_->parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ncclComm_t comm = nullptr; - comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); + ccl::CclComm ccl_comm = + comm_mgr->GetCclCommForParallelDescAndStreamName(*parallel_desc_.get(), stream_name_); + ccl_comm_.reset(new Comm(ccl_comm)); } class NcclLogicalSendRecv final : public user_op::OpKernel { @@ -163,8 +160,7 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - ncclComm_t comm = kernel_state->comm(); - cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); const int64_t parallel_num = send_elem_cnts.size(); @@ -172,16 +168,21 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O std::vector send_in_ptr; std::vector recv_out_ptr; + std::vector send_offsets; + std::vector recv_offsets; char* buf_ptr = tmp_buffer->mut_dptr(); - int64_t offset = 0; + uint64_t offset = 0; for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); + send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); } + const uint64_t recv_offset = offset; for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = 
reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); + recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); } @@ -192,21 +193,15 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); } } - const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < parallel_num; ++i) { - if (send_elem_cnts.at(i) != 0) { - LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, - comm, cuda_stream)); - } - if (recv_elem_cnts.at(i) != 0) { - LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), - i, comm, cuda_stream)); - } - } - OF_NCCL_CHECK(ncclGroupEnd()); + + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), data_type, data_type, parallel_num); + void* send_buf = reinterpret_cast(buf_ptr); + void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); + all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf, + recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true, + /*has_output=*/true); + const std::vector>& out_tensor_slice_copier_vec = kernel_state->out_tensor_slice_copier_vec();
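The unbalanced all-to-all above relies on a specific tmp_buffer layout: the per-rank send regions are packed first, followed by the per-rank recv regions, and the offset vectors are byte offsets (recv offsets relative to the start of the recv region). A self-contained sketch of that bookkeeping, mirroring the two loops above; the struct and function names are illustrative only, not part of the codebase:

    // Illustrative sketch only: computes the byte offsets that the variable-size
    // ccl::AllToAll launch above expects for each peer's send and recv region.
    #include <cstdint>
    #include <vector>

    struct BufferLayout {
      std::vector<uint64_t> send_offsets;  // byte offset of rank i's send region from buf start
      std::vector<uint64_t> recv_offsets;  // byte offset of rank i's recv region from recv_begin
      uint64_t recv_begin = 0;             // recv regions start here, i.e. recv_buf = buf + recv_begin
    };

    inline BufferLayout ComputeSendRecvLayout(const std::vector<int64_t>& send_elem_cnts,
                                              const std::vector<int64_t>& recv_elem_cnts,
                                              int64_t size_of_dtype) {
      BufferLayout layout;
      uint64_t offset = 0;
      // Send regions are laid out back to back at the front of the buffer.
      for (int64_t cnt : send_elem_cnts) {
        layout.send_offsets.push_back(offset);
        offset += cnt * size_of_dtype;
      }
      // Recv regions follow; their offsets are recorded relative to recv_begin.
      layout.recv_begin = offset;
      for (int64_t cnt : recv_elem_cnts) {
        layout.recv_offsets.push_back(offset - layout.recv_begin);
        offset += cnt * size_of_dtype;
      }
      return layout;
    }

Expressing the offsets in bytes keeps the launch independent of how many elements each peer exchanges, so empty send or recv slots simply reuse the previous offset with a zero count.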