From c9b7811e74180d5de64e4ac54f81ddc168bf6f6e Mon Sep 17 00:00:00 2001 From: luyang Date: Fri, 27 Dec 2024 06:45:33 +0000 Subject: [PATCH 01/28] raw impl --- .../core/graph/boxing/collective_boxing.proto | 5 +++ .../kernel/nccl_send_recv_boxing_kernel.cpp | 30 ++++++++++++---- .../collective_communication/cpu/cpu_recv.cpp | 5 +++ .../collective_communication/cpu/cpu_send.cpp | 5 +++ .../cuda/cuda_communication_context.h | 10 ++++++ .../cuda/cuda_recv.cpp | 12 +++++++ .../cuda/cuda_send.cpp | 12 +++++++ .../include/collective_communication.h | 34 +++++++++++++++++++ .../collective_communication/include/recv.h | 3 ++ .../collective_communication/include/send.h | 4 +++ .../kernels/nccl_logical_send_recv_kernel.cpp | 21 +++++++++--- 11 files changed, 131 insertions(+), 10 deletions(-) diff --git a/oneflow/core/graph/boxing/collective_boxing.proto b/oneflow/core/graph/boxing/collective_boxing.proto index 0f3d01d525c..c024eb733a2 100644 --- a/oneflow/core/graph/boxing/collective_boxing.proto +++ b/oneflow/core/graph/boxing/collective_boxing.proto @@ -15,6 +15,11 @@ enum OpType { kOpTypeAll2All = 6; } +enum CclCommType{ + kCommTypeCudaNccl = 0; + kCommTypeAscendHccl = 1; +} + enum ReduceMethod { kReduceMethodInvalid = 0; kReduceMethodSum = 1; diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index c342f4a2f42..f6a19b97eef 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -20,8 +20,10 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" +#include "oneflow/user/kernels/collective_communication/include/send.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" -#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 +// #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 namespace oneflow { @@ -84,6 +86,7 @@ class NcclSendRecvBoxingKernel final : public Kernel { }; void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { + printf("\n NcclSendRecvBoxingKernel::ForwardDataContent()"); Blob* buf = ctx->BnInOp2Blob("buf"); ncclComm_t comm = this->comm(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); @@ -122,12 +125,27 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { if (this->has_input() && send_elem_cnts.at(i) != 0) { - OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, - comm, cuda_stream)); + // OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), + // i, + // comm, cuda_stream)); + std::unique_ptr send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + + std::shared_ptr ncclCommAdapter = + std::make_shared(&comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); } if (this->has_output() && recv_elem_cnts.at(i) != 0) { - OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), - i, comm, cuda_stream)); + // OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), + // GetNcclDataType(data_type), + // i, comm, cuda_stream)); + std::unique_ptr recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), 
data_type); + std::shared_ptr ncclCommAdapter = + std::make_shared(&comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); } } OF_NCCL_CHECK(ncclGroupEnd()); @@ -254,4 +272,4 @@ REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxi } // namespace oneflow -#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 +// #endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp index 412e2442c12..9f903ef7761 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -40,6 +40,11 @@ class CpuRecvImpl final : public Recv { CHECK_JUST(CpuRecv(out, buffer_size, src)); } + void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, + CclComm ccl_comm) const override { + Launch(stream, out, elem_cnt, src); + } + private: size_t size_of_dtype_; }; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp index a0e62957fbd..f4cbddeede3 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp @@ -40,6 +40,11 @@ class CpuSendImpl final : public Send { CHECK_JUST(CpuSend(in, buffer_size, dst)); } + void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, + CclComm comm) const override { + Launch(stream, in, elem_cnt, dst); + } + private: size_t size_of_dtype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h b/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h index c3a45939cae..804c699a8a1 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h @@ -28,6 +28,16 @@ namespace oneflow { namespace ccl { +// class NcclCommAdapter : public CommBase { +// public: +// NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {} + +// void* getComm() override { return static_cast(comm_); } + +// private: +// ncclComm_t* comm_; +// }; + class CudaCommunicationContext : public CommunicationContext { public: explicit CudaCommunicationContext() = default; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp index cc4bcfafe3f..64ae471cd63 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp @@ -16,6 +16,7 @@ limitations under the License. 
#ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/recv.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { @@ -40,6 +41,17 @@ class CudaRecv final : public Recv { #endif // HAS_NCCL_SEND_RECV } + void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, + CclComm ccl_comm) const override { +#if HAS_NCCL_SEND_RECV + ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, src, *comm, + stream->As()->cuda_stream())); +#else + UNIMPLEMENTED() << "GPU recv is only supported when nccl version >= 2.7" +#endif // HAS_NCCL_SEND_RECV + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp index da7ac181252..f5aa6d9045c 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp @@ -16,6 +16,7 @@ limitations under the License. #ifdef WITH_CUDA #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_send_recv_util.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" namespace oneflow { @@ -40,6 +41,17 @@ class CudaSend final : public Send { #endif // HAS_NCCL_SEND_RECV } + void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, + CclComm ccl_comm) const override { +#if HAS_NCCL_SEND_RECV + ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm, + stream->As()->cuda_stream())); +#else + UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" +#endif // HAS_NCCL_SEND_RECV + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h index c197820d974..73070ac1c4a 100644 --- a/oneflow/user/kernels/collective_communication/include/collective_communication.h +++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h @@ -16,6 +16,8 @@ limitations under the License. 
 #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_
 #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_
 
+#include
+#include
 #include "oneflow/core/common/auto_registration_factory.h"
 #include "oneflow/core/common/switch_func.h"
 #include "oneflow/user/kernels/collective_communication/include/communication_context.h"
@@ -41,6 +43,38 @@ enum ReduceType {
   MAKE_TYPED_CTRV_SEQ(ReduceType, \
                       OF_PP_FOR_EACH_TUPLE(OF_PP_I_MAKE_REPLICATE_TUPLE_SEQ, REDUCE_TYPE_SEQ))
 
+// abstruct base class for comm
+class CommBase {
+ public:
+  virtual ~CommBase() = default;
+
+  // return impl of comm
+  virtual void* getComm() = 0;
+};
+
+#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700
+#include
+class NcclCommAdapter : public CommBase {
+ public:
+  NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {}
+
+  void* getComm() override { return static_cast<void*>(comm_); }
+
+ private:
+  ncclComm_t* comm_;
+};
+#endif  // WITH_CUDA && NCCL_VERSION_CODE > 2700
+
+class CclComm {
+ public:
+  explicit CclComm(std::shared_ptr<CommBase> comm) : comm_(std::move(comm)) {}
+
+  void* getComm() { return comm_->getComm(); }
+
+ private:
+  std::shared_ptr<CommBase> comm_{};
+};
+
 class CollectiveCommunication {
  public:
   OF_DISALLOW_COPY_AND_MOVE(CollectiveCommunication);
diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h
index 59c1aef849f..7d6b1c24a5c 100644
--- a/oneflow/user/kernels/collective_communication/include/recv.h
+++ b/oneflow/user/kernels/collective_communication/include/recv.h
@@ -31,6 +31,9 @@ class Recv : public CollectiveCommunication {
   virtual void Init(DataType dtype) = 0;
 
   virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;
+
+  virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
+                      CclComm ccl_comm) const = 0;
 };
 
 inline bool IsRecvRegistered(DeviceType device_type) {
diff --git a/oneflow/user/kernels/collective_communication/include/send.h b/oneflow/user/kernels/collective_communication/include/send.h
index 6658c7de292..00e3800c12a 100644
--- a/oneflow/user/kernels/collective_communication/include/send.h
+++ b/oneflow/user/kernels/collective_communication/include/send.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_
 #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_
 
+#include "collective_communication.h"
 #include "oneflow/user/kernels/collective_communication/include/collective_communication.h"
 
 namespace oneflow {
@@ -31,6 +32,9 @@ class Send : public CollectiveCommunication {
   virtual void Init(DataType dtype) = 0;
 
   virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0;
+
+  virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,
+                      CclComm ccl_comm) const = 0;
 };
 
 inline bool IsSendRegistered(DeviceType device_type) {
diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
index 36478b8c27a..add3fa4ef12 100644
--- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
+++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp
@@ -27,6 +27,8 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" +#include "oneflow/user/kernels/collective_communication/include/send.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -158,6 +160,7 @@ class NcclLogicalSendRecv final : public user_op::OpKernel { void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const { + printf("\n NcclLogicalSendRecv::Compute()"); auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -197,13 +200,23 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O for (int64_t i = 0; i < parallel_num; ++i) { if (send_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i, - comm, cuda_stream)); + // OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), + // i, + // comm, cuda_stream)); + printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Send"); + std::unique_ptr send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i); } if (recv_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type), - i, comm, cuda_stream)); + // OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), + // GetNcclDataType(data_type), + // i, comm, cuda_stream)); + printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Recv"); + std::unique_ptr recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i); } } OF_NCCL_CHECK(ncclGroupEnd()); From 7bcc5938feade07a9fd703c04a47e5c27185678a Mon Sep 17 00:00:00 2001 From: zhaoluyang Date: Fri, 27 Dec 2024 13:34:55 +0000 Subject: [PATCH 02/28] refine --- oneflow/core/job/eager_ccl_comm_manager.h | 1 + oneflow/core/job/eager_nccl_comm_manager.h | 11 +++++++++++ .../include/collective_communication.h | 13 ------------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h index eb747f78dac..13a54b272d7 100644 --- a/oneflow/core/job/eager_ccl_comm_manager.h +++ b/oneflow/core/job/eager_ccl_comm_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" +#include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 2210983f3a1..5a004b734d6 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -25,6 +25,17 @@ limitations under the License. 
#include "oneflow/core/device/cuda_util.h" namespace oneflow { +namespace ccl { +class NcclCommAdapter : public CommBase { + public: + NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {} + + void* getComm() override { return static_cast(comm_); } + + private: + ncclComm_t* comm_; +}; +} // namespace ccl class EagerNcclCommMgr final : public EagerCclCommMgr { public: diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h index 73070ac1c4a..83f2671771e 100644 --- a/oneflow/user/kernels/collective_communication/include/collective_communication.h +++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h @@ -52,19 +52,6 @@ class CommBase { virtual void* getComm() = 0; }; -#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 -#include -class NcclCommAdapter : public CommBase { - public: - NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {} - - void* getComm() override { return static_cast(comm_); } - - private: - ncclComm_t* comm_; -}; -#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 - class CclComm { public: explicit CclComm(std::shared_ptr comm) : comm_(std::move(comm)) {} From d70c46466131fbcf620281702d70e0900a3cfd8a Mon Sep 17 00:00:00 2001 From: zhaoluyang Date: Wed, 1 Jan 2025 13:39:21 +0000 Subject: [PATCH 03/28] refine --- .../kernel/nccl_send_recv_boxing_kernel.cpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index f6a19b97eef..8804cd65620 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -122,29 +122,25 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } } + std::unique_ptr send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + std::unique_ptr recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + std::shared_ptr ncclCommAdapter = + std::make_shared(&comm); + ccl::CclComm ccl_comm(ncclCommAdapter); OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { if (this->has_input() && send_elem_cnts.at(i) != 0) { // OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), // i, // comm, cuda_stream)); - std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - - std::shared_ptr ncclCommAdapter = - std::make_shared(&comm); - ccl::CclComm ccl_comm(ncclCommAdapter); send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); } if (this->has_output() && recv_elem_cnts.at(i) != 0) { // OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), // GetNcclDataType(data_type), // i, comm, cuda_stream)); - std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - std::shared_ptr ncclCommAdapter = - std::make_shared(&comm); - ccl::CclComm ccl_comm(ncclCommAdapter); recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); } } From 89411cd154f9a3181f3ec0ad1df9dca26369fb49 Mon Sep 17 00:00:00 2001 From: zhaoluyang Date: Thu, 2 Jan 2025 01:30:05 +0000 Subject: [PATCH 04/28] impl of ccl::CclComm --- oneflow/core/job/eager_ccl_comm_manager.h | 5 +++ oneflow/core/job/eager_nccl_comm_manager.cpp | 8 +++++ oneflow/core/job/eager_nccl_comm_manager.h | 10 
++++-- .../kernel/nccl_send_recv_boxing_kernel.cpp | 32 +++++++------------ .../include/collective_communication.h | 1 + .../kernels/nccl_logical_send_recv_kernel.cpp | 8 ----- 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h index 13a54b272d7..3c2405dcf89 100644 --- a/oneflow/core/job/eager_ccl_comm_manager.h +++ b/oneflow/core/job/eager_ccl_comm_manager.h @@ -30,6 +30,11 @@ class EagerCclCommMgr { virtual void CreateCommFromPlan(const Plan& plan) = 0; virtual bool IsAsyncLaunchCclLogicalKernel() const = 0; virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0; + virtual ccl::CclComm GetCclCommForDeviceAndStreamName( + const std::set>& device_set, const std::string& stream_name) { + ccl::CclComm ccl_comm{}; + return ccl_comm; + } template T* As() { diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp index 2b3e5a4a735..342c394293f 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.cpp +++ b/oneflow/core/job/eager_nccl_comm_manager.cpp @@ -156,6 +156,14 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName( return comm; } +ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName( + const std::set>& device_set, const std::string& stream_name) { + ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); + std::shared_ptr ncclCommAdapter = std::make_shared(comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + return ccl_comm; +} + void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) { const int64_t rank = GlobalProcessCtx::Rank(); const int64_t dev = GlobalProcessCtx::LocalRank(); diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 5a004b734d6..8e96fc62f2d 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -26,15 +26,17 @@ limitations under the License. 
namespace oneflow { namespace ccl { + class NcclCommAdapter : public CommBase { public: - NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {} + NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} - void* getComm() override { return static_cast(comm_); } + void* getComm() override { return static_cast(&comm_); } private: - ncclComm_t* comm_; + ncclComm_t comm_; }; + } // namespace ccl class EagerNcclCommMgr final : public EagerCclCommMgr { @@ -47,6 +49,8 @@ class EagerNcclCommMgr final : public EagerCclCommMgr { ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); + ccl::CclComm GetCclCommForDeviceAndStreamName( + const std::set>& device_set, const std::string& stream_name); void CreateCommFromPlan(const Plan& plan) override; bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; } diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index 8804cd65620..e1111fc8061 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -43,12 +43,12 @@ class NcclSendRecvBoxingKernel final : public Kernel { const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } const bool has_input() const { return has_input_; } const bool has_output() const { return has_output_; } - ncclComm_t comm() const { return GetOrCreate().comm; } + ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; } private: struct Comm { - Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; + Comm(ccl::CclComm comm) : ccl_comm(comm) {} + ccl::CclComm ccl_comm; }; void Init() const { @@ -60,14 +60,13 @@ class NcclSendRecvBoxingKernel final : public Kernel { device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ncclComm_t comm = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); + ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_.reset(new Comm(ccl_comm)); } const Comm& GetOrCreate() const { - if (!comm_) { Init(); } - return *comm_; + if (!ccl_comm_) { Init(); } + return *ccl_comm_; } void VirtualKernelInit(KernelContext* ctx) override; @@ -75,7 +74,7 @@ class NcclSendRecvBoxingKernel final : public Kernel { std::string stream_name_; ParallelConf parallel_conf_; - mutable std::unique_ptr comm_; + mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; @@ -86,10 +85,8 @@ class NcclSendRecvBoxingKernel final : public Kernel { }; void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { - printf("\n NcclSendRecvBoxingKernel::ForwardDataContent()"); Blob* buf = ctx->BnInOp2Blob("buf"); - ncclComm_t comm = this->comm(); - cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + ccl::CclComm ccl_comm = this->ccl_comm(); const std::vector& send_elem_cnts = this->send_elem_cnts(); const std::vector& recv_elem_cnts = this->recv_elem_cnts(); const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num(); @@ -122,25 +119,18 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } } + // init ccl Send/Recv primitive std::unique_ptr send = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), 
data_type); std::unique_ptr recv = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - std::shared_ptr ncclCommAdapter = - std::make_shared(&comm); - ccl::CclComm ccl_comm(ncclCommAdapter); + // launch ccl::Send/Recv OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { if (this->has_input() && send_elem_cnts.at(i) != 0) { - // OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), - // i, - // comm, cuda_stream)); send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); } if (this->has_output() && recv_elem_cnts.at(i) != 0) { - // OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), - // GetNcclDataType(data_type), - // i, comm, cuda_stream)); recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); } } diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h index 83f2671771e..53433d711c1 100644 --- a/oneflow/user/kernels/collective_communication/include/collective_communication.h +++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h @@ -54,6 +54,7 @@ class CommBase { class CclComm { public: + CclComm() {} explicit CclComm(std::shared_ptr comm) : comm_(std::move(comm)) {} void* getComm() { return comm_->getComm(); } diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index add3fa4ef12..28f9d5fb1f0 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -160,13 +160,11 @@ class NcclLogicalSendRecv final : public user_op::OpKernel { void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const { - printf("\n NcclLogicalSendRecv::Compute()"); auto* kernel_state = dynamic_cast(state); CHECK_NOTNULL(kernel_state); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - ncclComm_t comm = kernel_state->comm(); cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); @@ -200,9 +198,6 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O for (int64_t i = 0; i < parallel_num; ++i) { if (send_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - // OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), - // i, - // comm, cuda_stream)); printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Send"); std::unique_ptr send = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); @@ -210,9 +205,6 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O } if (recv_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - // OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), - // GetNcclDataType(data_type), - // i, comm, cuda_stream)); printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Recv"); std::unique_ptr recv = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), 
data_type); From 9ec2bc6b8b00e294cd599f087ae68daf53d19d59 Mon Sep 17 00:00:00 2001 From: zhaoluyang Date: Thu, 2 Jan 2025 01:50:23 +0000 Subject: [PATCH 05/28] refine --- oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp | 6 +----- .../cuda/cuda_communication_context.h | 10 ---------- .../kernels/collective_communication/include/send.h | 1 - oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp | 2 -- 4 files changed, 1 insertion(+), 18 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index e1111fc8061..5e179a45278 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -23,8 +23,6 @@ limitations under the License. #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" -// #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 - namespace oneflow { class NcclSendRecvBoxingKernel final : public Kernel { @@ -256,6 +254,4 @@ REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingKernel) REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf); -} // namespace oneflow - -// #endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h b/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h index 804c699a8a1..c3a45939cae 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h @@ -28,16 +28,6 @@ namespace oneflow { namespace ccl { -// class NcclCommAdapter : public CommBase { -// public: -// NcclCommAdapter(ncclComm_t* comm) : comm_(comm) {} - -// void* getComm() override { return static_cast(comm_); } - -// private: -// ncclComm_t* comm_; -// }; - class CudaCommunicationContext : public CommunicationContext { public: explicit CudaCommunicationContext() = default; diff --git a/oneflow/user/kernels/collective_communication/include/send.h b/oneflow/user/kernels/collective_communication/include/send.h index 00e3800c12a..23fbda3cb2b 100644 --- a/oneflow/user/kernels/collective_communication/include/send.h +++ b/oneflow/user/kernels/collective_communication/include/send.h @@ -16,7 +16,6 @@ limitations under the License. 
#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_SEND_H_ -#include "collective_communication.h" #include "oneflow/user/kernels/collective_communication/include/collective_communication.h" namespace oneflow { diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 28f9d5fb1f0..94c0ecd5a79 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -198,14 +198,12 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O for (int64_t i = 0; i < parallel_num; ++i) { if (send_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Send"); std::unique_ptr send = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i); } if (recv_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - printf("\n NcclLogicalSendRecv::Compute() >>> ccl::Recv"); std::unique_ptr recv = ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i); From a0b0391ed9e0d96112518b3567c32cd1bc4384fc Mon Sep 17 00:00:00 2001 From: zhaoluyang Date: Thu, 2 Jan 2025 01:52:24 +0000 Subject: [PATCH 06/28] refine --- oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 94c0ecd5a79..b11accbb8fb 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -194,18 +194,19 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O } } const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); + + std::unique_ptr send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + std::unique_ptr recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { if (send_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i); } if (recv_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i); } } From 2c48a5e8365badc223660e3c2f801369af4ab948 Mon Sep 17 00:00:00 2001 From: luyang Date: Thu, 2 Jan 2025 02:57:55 +0000 Subject: [PATCH 07/28] refactor _nccl_logical_send_recv using ccl::Comm primitive --- .../kernel/nccl_send_recv_boxing_kernel.cpp | 6 +-- .../kernels/nccl_logical_send_recv_kernel.cpp | 48 ++++++++++++------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index 5e179a45278..afc366be174 100644 --- 
a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -119,9 +119,9 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } // init ccl Send/Recv primitive std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); // launch ccl::Send/Recv OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { @@ -254,4 +254,4 @@ REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingKernel) REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf); -} // namespace oneflow \ No newline at end of file +} // namespace oneflow diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index b11accbb8fb..d9e51aca1a8 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "collective_communication/include/collective_communication.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/framework/framework.h" @@ -30,7 +31,7 @@ limitations under the License. #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" -#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 +// #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 namespace oneflow { @@ -46,22 +47,33 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; } const std::vector& send_elem_cnts() const { return send_elem_cnts_; } const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } - ncclComm_t comm() const { return GetOrCreateComm().comm; } + // ncclComm_t comm() const { return GetOrCreateComm().comm; } + ccl::CclComm ccl_comm() const { return GetOrCreateComm().ccl_comm; } private: + // struct Comm { + // explicit Comm(ncclComm_t comm) : comm(comm) {} + // ncclComm_t comm; + // }; struct Comm { - explicit Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; + Comm(ccl::CclComm comm) : ccl_comm(comm) {} + ccl::CclComm ccl_comm; }; + void InitComm() const; + // const Comm& GetOrCreateComm() const { + // if (!comm_) { InitComm(); } + // return *comm_; + // } const Comm& GetOrCreateComm() const { - if (!comm_) { InitComm(); } - return *comm_; + if (!ccl_comm_) { InitComm(); } + return *ccl_comm_; } std::string stream_name_; std::unique_ptr parallel_desc_; - mutable std::unique_ptr comm_; + // mutable std::unique_ptr comm_; + mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; std::vector> out_tensor_slice_copier_vec_; @@ -132,9 +144,11 @@ void NcclLogicalSendRecvState::InitComm() const { device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ncclComm_t comm = nullptr; - comm = 
comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); + // ncclComm_t comm = nullptr; + // comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, + // stream_name_); comm_.reset(new Comm(comm)); + ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_.reset(new Comm(ccl_comm)); } class NcclLogicalSendRecv final : public user_op::OpKernel { @@ -165,7 +179,9 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); + // ncclComm_t comm = kernel_state->comm(); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + // cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); const int64_t parallel_num = send_elem_cnts.size(); @@ -196,18 +212,18 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); OF_NCCL_CHECK(ncclGroupStart()); for (int64_t i = 0; i < parallel_num; ++i) { if (send_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i); + send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); } if (recv_elem_cnts.at(i) != 0) { LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i); + recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); } } OF_NCCL_CHECK(ncclGroupEnd()); @@ -300,4 +316,4 @@ REGISTER_USER_KERNEL("_nccl_logical_send_recv") } // namespace oneflow -#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 +// #endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 From aac19b45898ebb518e0fb01870c82db7a773d973 Mon Sep 17 00:00:00 2001 From: luyang Date: Thu, 2 Jan 2025 03:00:29 +0000 Subject: [PATCH 08/28] refactor _nccl_logical_send_recv using ccl::Comm primitive --- .../kernels/nccl_logical_send_recv_kernel.cpp | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index d9e51aca1a8..59bc8fa151c 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -31,7 +31,7 @@ limitations under the License. 
#include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" -// #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 +#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 namespace oneflow { @@ -47,24 +47,15 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { bool src_nd_sbp_has_no_partial_parallel() const { return src_nd_sbp_no_partial_parallel_; } const std::vector& send_elem_cnts() const { return send_elem_cnts_; } const std::vector& recv_elem_cnts() const { return recv_elem_cnts_; } - // ncclComm_t comm() const { return GetOrCreateComm().comm; } ccl::CclComm ccl_comm() const { return GetOrCreateComm().ccl_comm; } private: - // struct Comm { - // explicit Comm(ncclComm_t comm) : comm(comm) {} - // ncclComm_t comm; - // }; struct Comm { Comm(ccl::CclComm comm) : ccl_comm(comm) {} ccl::CclComm ccl_comm; }; void InitComm() const; - // const Comm& GetOrCreateComm() const { - // if (!comm_) { InitComm(); } - // return *comm_; - // } const Comm& GetOrCreateComm() const { if (!ccl_comm_) { InitComm(); } return *ccl_comm_; @@ -72,7 +63,6 @@ class NcclLogicalSendRecvState final : public user_op::OpKernelState { std::string stream_name_; std::unique_ptr parallel_desc_; - // mutable std::unique_ptr comm_; mutable std::unique_ptr ccl_comm_; bool src_nd_sbp_no_partial_parallel_; std::vector> in_tensor_slice_copier_vec_; @@ -144,9 +134,6 @@ void NcclLogicalSendRecvState::InitComm() const { device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - // ncclComm_t comm = nullptr; - // comm = comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, - // stream_name_); comm_.reset(new Comm(comm)); ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); ccl_comm_.reset(new Comm(ccl_comm)); } @@ -179,9 +166,7 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - // ncclComm_t comm = kernel_state->comm(); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); - // cudaStream_t cuda_stream = ctx->stream()->As()->cuda_stream(); const std::vector& send_elem_cnts = kernel_state->send_elem_cnts(); const std::vector& recv_elem_cnts = kernel_state->recv_elem_cnts(); const int64_t parallel_num = send_elem_cnts.size(); @@ -316,4 +301,4 @@ REGISTER_USER_KERNEL("_nccl_logical_send_recv") } // namespace oneflow -// #endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 +#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 From 3cb872abc3bd69068f3889ff868beb21421ca5f8 Mon Sep 17 00:00:00 2001 From: luyang Date: Thu, 2 Jan 2025 14:04:29 +0000 Subject: [PATCH 09/28] refactor ccl::AllGather AllReduce ReduceScatter primitive using ccl::CclComm --- oneflow/core/job/eager_ccl_comm_manager.h | 5 +++++ oneflow/core/job/eager_nccl_comm_manager.cpp | 8 ++++++++ oneflow/core/job/eager_nccl_comm_manager.h | 1 + .../collective_communication/cpu/cpu_all_gather.cpp | 5 +++++ .../collective_communication/cpu/cpu_all_reduce.cpp | 5 +++++ .../collective_communication/cpu/cpu_reduce_scatter.cpp | 5 +++++ .../collective_communication/cuda/cuda_all_gather.cpp | 7 +++++++ .../collective_communication/cuda/cuda_all_reduce.cpp | 7 +++++++ .../collective_communication/cuda/cuda_reduce_scatter.cpp | 7 +++++++ 
.../kernels/collective_communication/include/all_gather.h | 3 +++ .../kernels/collective_communication/include/all_reduce.h | 3 +++ .../include/communication_context.h | 1 + .../collective_communication/include/reduce_scatter.h | 3 +++ 13 files changed, 60 insertions(+) diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h index 3c2405dcf89..b0cefa73efd 100644 --- a/oneflow/core/job/eager_ccl_comm_manager.h +++ b/oneflow/core/job/eager_ccl_comm_manager.h @@ -30,6 +30,11 @@ class EagerCclCommMgr { virtual void CreateCommFromPlan(const Plan& plan) = 0; virtual bool IsAsyncLaunchCclLogicalKernel() const = 0; virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0; + virtual ccl::CclComm GetCclCommForDevice( + const std::set>& device_set) { + ccl::CclComm ccl_comm{}; + return ccl_comm; + } virtual ccl::CclComm GetCclCommForDeviceAndStreamName( const std::set>& device_set, const std::string& stream_name) { ccl::CclComm ccl_comm{}; diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp index 342c394293f..9aadffcd559 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.cpp +++ b/oneflow/core/job/eager_nccl_comm_manager.cpp @@ -156,6 +156,14 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName( return comm; } +ccl::CclComm EagerNcclCommMgr::GetCclCommForDevice( + const std::set>& device_set) { + ncclComm_t comm = GetCommForDevice(device_set); + std::shared_ptr ncclCommAdapter = std::make_shared(comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + return ccl_comm; +} + ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName( const std::set>& device_set, const std::string& stream_name) { ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 8e96fc62f2d..a17cffb55cc 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -49,6 +49,7 @@ class EagerNcclCommMgr final : public EagerCclCommMgr { ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); + ccl::CclComm GetCclCommForDevice(const std::set>& device_set); ccl::CclComm GetCclCommForDeviceAndStreamName( const std::set>& device_set, const std::string& stream_name); diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp index f6b6f2005fe..c1a581c7aa3 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp @@ -95,6 +95,11 @@ class CpuAllGather final : public AllGather { CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp index 6550f9b81fb..b4a63890e1c 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp @@ -148,6 +148,11 @@ class CpuAllReduce final : public AllReduce { 
cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; ReduceType reduce_type_; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp index 8ff362c2eaa..67f06dabb64 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp @@ -120,6 +120,11 @@ class CpuReduceScatter final : public ReduceScatter { cpu_communication_ctx->parallel_desc())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + UNIMPLEMENTED(); + } + private: DataType datatype_; ReduceType reduce_type_; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp index a3012783f74..936d492bc75 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp @@ -40,6 +40,13 @@ class CudaAllGather final : public AllGather { stream->As()->cuda_stream())); } + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; }; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp index fa6803d54df..2d035c6ff60 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp @@ -57,6 +57,13 @@ class CudaAllReduce final : public AllReduce { stream->As()->cuda_stream())); } + void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp index 80419a84759..fd3811df309 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp @@ -57,6 +57,13 @@ class CudaReduceScatter final : public ReduceScatter { stream->As()->cuda_stream())); } + virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, + ccl::CclComm ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, + stream->As()->cuda_stream())); + } + private: ncclDataType_t nccl_datatype_; ncclRedOp_t nccl_reduce_op_; diff --git a/oneflow/user/kernels/collective_communication/include/all_gather.h 
b/oneflow/user/kernels/collective_communication/include/all_gather.h
index 66b520be6a5..1212b5eea4b 100644
--- a/oneflow/user/kernels/collective_communication/include/all_gather.h
+++ b/oneflow/user/kernels/collective_communication/include/all_gather.h
@@ -32,6 +32,9 @@ class AllGather : public CollectiveCommunication {
 
   virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
                       const std::shared_ptr<CommunicationContext>& communicator) const = 0;
+
+  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
+                      ccl::CclComm ccl_comm) const = 0;
 };
 
 inline bool IsAllGatherRegistered(DeviceType device_type) {
diff --git a/oneflow/user/kernels/collective_communication/include/all_reduce.h b/oneflow/user/kernels/collective_communication/include/all_reduce.h
index 0dcc685b966..b152bb8934f 100644
--- a/oneflow/user/kernels/collective_communication/include/all_reduce.h
+++ b/oneflow/user/kernels/collective_communication/include/all_reduce.h
@@ -32,6 +32,9 @@ class AllReduce : public CollectiveCommunication {
 
   virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
                       const std::shared_ptr<CommunicationContext>& communicator) const = 0;
+
+  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
+                      ccl::CclComm ccl_comm) const = 0;
 };
 
 inline bool IsAllReduceRegistered(DeviceType device_type) {
diff --git a/oneflow/user/kernels/collective_communication/include/communication_context.h b/oneflow/user/kernels/collective_communication/include/communication_context.h
index 9423f0af997..ae6dbf1fdb9 100644
--- a/oneflow/user/kernels/collective_communication/include/communication_context.h
+++ b/oneflow/user/kernels/collective_communication/include/communication_context.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_
 #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COMMUNICATION_CONTEXT_H_
 
+#include "collective_communication.h"
 #include "oneflow/core/job/parallel_desc.h"
 #include "oneflow/core/common/auto_registration_factory.h"
 
diff --git a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h
index a3b179b48fb..8b97079e30d 100644
--- a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h
+++ b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h
@@ -32,6 +32,9 @@ class ReduceScatter : public CollectiveCommunication {
 
   virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
                       const std::shared_ptr<CommunicationContext>& communicator) const = 0;
+
+  virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
+                      ccl::CclComm ccl_comm) const = 0;
 };
 
 inline bool IsReduceScatterRegistered(DeviceType device_type) {

From e5777b99e6223aca272df7d610e4771e1f8132db Mon Sep 17 00:00:00 2001
From: luyang
Date: Thu, 2 Jan 2025 14:05:14 +0000
Subject: [PATCH 10/28] refactor _nccl_logical_fusion kernel using ccl::CclComm

---
 .../kernels/nccl_logical_fusion_kernel.cpp | 149 +++++++++++-------
 1 file changed, 90 insertions(+), 59 deletions(-)

diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
index aeb906b6387..662426c5ea1 100644
--- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
+++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp
@@ -14,6 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "collective_communication/include/all_gather.h"
+#include "collective_communication/include/all_reduce.h"
+#include "collective_communication/include/collective_communication.h"
+#include "collective_communication/include/reduce_scatter.h"
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/device/nccl_util.h"
 #include "oneflow/core/job/eager_nccl_comm_manager.h"
 #include "oneflow/core/register/tensor_slice_copier.h"
#include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/user/kernels/collective_communication/include/send.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -83,9 +89,9 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { } ~NcclLogicalFusionKernelState() override = default; - ncclComm_t comm() { + ccl::CclComm ccl_comm() { if (!is_init_) { InitComm(); } - return comm_; + return ccl_comm_; } int64_t num_ranks() { @@ -169,8 +175,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); is_init_ = true; } @@ -277,7 +282,7 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { std::vector dst_split_axis_list_; std::vector tmp_buffer_offset_; std::vector tmp_buffer_size_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogicalFusionKernel final : public user_op::OpKernel { @@ -425,109 +430,135 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_p const std::string& nccl_type, const user_op::Tensor* in, user_op::Tensor* out, user_op::KernelComputeContext* ctx, NcclLogicalFusionKernelState* kernel_state, const int32_t i, - const ncclComm_t& comm) { + ccl::CclComm ccl_comm) { + std::unique_ptr ccl_send = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); + std::unique_ptr ccl_recv = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), in->data_type()); + const int64_t num_ranks = kernel_state->num_ranks(); VLOG(3) << "[NcclLogicalFusion] op: " << ctx->op_name() << " , i= " << i << ", stream: " << kernel_state->stream_name() << " Try launch nccl_type: " << nccl_type; if (nccl_type == "_nccl_logical_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); + } else if (nccl_type == "_nccl_logical_reduce_scatter") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = 
ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_all_gather_noncontinuous") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_reduce_scatter_noncontinuous") { CHECK(in->dptr() != pack_to_ptr); // do in -> pack to ptr CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(pack_to_ptr, out->mut_dptr(), out->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), pack_to_ptr, out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_s2s") { const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t dtype_size = GetSizeOfDataType(in->data_type()); const int64_t elem_per_chunk = elem_cnt / num_ranks; const int64_t chunk_size = elem_per_chunk * dtype_size; for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( + ccl_send->Launch(ctx->stream(), + reinterpret_cast(reinterpret_cast(pack_to_ptr) + + j * chunk_size), + elem_per_chunk, j, ccl_comm); + ccl_recv->Launch( + ctx->stream(), reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); + elem_per_chunk, j, ccl_comm); } } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; 
- if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_gather_noncontinuous") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() != unpack_from_ptr); // do unpack from ptr -> out CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), comm, - ctx->stream()->As()->cuda_stream())); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { const int64_t elem_cnt = in->shape_view().elem_cnt(); const int64_t dtype_size = GetSizeOfDataType(in->data_type()); const int64_t elem_per_chunk = elem_cnt / num_ranks; const int64_t chunk_size = elem_per_chunk * dtype_size; for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( + ccl_send->Launch(ctx->stream(), + reinterpret_cast(reinterpret_cast(pack_to_ptr) + + j * chunk_size), + elem_per_chunk, j, ccl_comm); + ccl_recv->Launch( + ctx->stream(), reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); + elem_per_chunk, j, ccl_comm); } } else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, comm, - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_all_reduce = + 
ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); + } else { UNIMPLEMENTED(); } @@ -663,7 +694,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx, } // NOTE(chengcheng): init nccl comm need before ncclGroupStart. - ncclComm_t comm = kernel_state->comm(); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); // do nccl compute in group OF_NCCL_CHECK(ncclGroupStart()); @@ -671,7 +702,7 @@ void NcclLogicalFusionKernel::Compute(user_op::KernelComputeContext* ctx, const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", i); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", i); DoNcclComputeByNcclTypeInGroup(pack_to_ptr_list.at(i), unpack_from_ptr_list.at(i), - nccl_type_list.at(i), in, out, ctx, kernel_state, i, comm); + nccl_type_list.at(i), in, out, ctx, kernel_state, i, ccl_comm); } OF_NCCL_CHECK(ncclGroupEnd()); From ad9e7ee8aa282967e34174462e25a015e101aa68 Mon Sep 17 00:00:00 2001 From: luyang Date: Fri, 3 Jan 2025 07:24:39 +0000 Subject: [PATCH 11/28] refine --- .../core/kernel/nccl_send_recv_boxing_kernel.cpp | 4 ++++ .../include/collective_communication.h | 2 -- oneflow/user/kernels/nccl_logical_fusion_kernel.cpp | 13 ++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index afc366be174..6fcb71e61fa 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -23,6 +23,8 @@ limitations under the License. #include "oneflow/user/kernels/collective_communication/include/send.h" #include "oneflow/user/kernels/collective_communication/include/recv.h" +#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 + namespace oneflow { class NcclSendRecvBoxingKernel final : public Kernel { @@ -255,3 +257,5 @@ REGISTER_KERNEL(OperatorConf::kNcclSendRecvBoxingConf, NcclSendRecvBoxingKernel) REGISTER_SYSTEM_OP_KERNEL_UNIFIED_NCCL_COMM_INIT(OperatorConf::kNcclSendRecvBoxingConf); } // namespace oneflow + +#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700 diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h index 53433d711c1..14ddd364a5a 100644 --- a/oneflow/user/kernels/collective_communication/include/collective_communication.h +++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_ #define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_COLLECTIVE_COMMUNICATION_H_ -#include -#include #include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/common/switch_func.h" #include "oneflow/user/kernels/collective_communication/include/communication_context.h" diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp index 662426c5ea1..c282955908c 100644 --- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp @@ -13,11 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ - -#include "collective_communication/include/all_gather.h" -#include "collective_communication/include/all_reduce.h" -#include "collective_communication/include/collective_communication.h" -#include "collective_communication/include/reduce_scatter.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/nccl_util.h" #include "oneflow/core/job/eager_nccl_comm_manager.h" @@ -25,8 +20,12 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" -#include "oneflow/user/kernels/collective_communication/include/send.h" -#include "oneflow/user/kernels/collective_communication/include/recv.h" +#include "collective_communication/include/send.h" +#include "collective_communication/include/recv.h" +#include "collective_communication/include/all_gather.h" +#include "collective_communication/include/all_reduce.h" +#include "collective_communication/include/collective_communication.h" +#include "collective_communication/include/reduce_scatter.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 From a9cd8df02cd5024ec060a4c050a95aa8792ea4d3 Mon Sep 17 00:00:00 2001 From: luyang Date: Thu, 9 Jan 2025 09:23:08 +0000 Subject: [PATCH 12/28] support ccl::AllToAll --- oneflow/core/job/eager_nccl_comm_manager.h | 2 +- .../kernel/nccl_send_recv_boxing_kernel.cpp | 34 +++--- .../cuda/cuda_all_to_all.cpp | 109 ++++++++++++++++++ .../include/all_to_all.h | 51 ++++++++ 4 files changed, 177 insertions(+), 19 deletions(-) create mode 100644 oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp create mode 100644 oneflow/user/kernels/collective_communication/include/all_to_all.h diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index a17cffb55cc..9a91692a4d7 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -29,7 +29,7 @@ namespace ccl { class NcclCommAdapter : public CommBase { public: - NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} + explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} void* getComm() override { return static_cast(&comm_); } diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index 6fcb71e61fa..e7fc1939f89 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -20,8 +20,7 @@ limitations under the License. 
#include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" -#include "oneflow/user/kernels/collective_communication/include/send.h" -#include "oneflow/user/kernels/collective_communication/include/recv.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -93,19 +92,24 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { const DataType data_type = buf->data_type(); std::vector send_in_ptr; std::vector recv_out_ptr; + std::vector send_offsets; + std::vector recv_offsets; char* buf_ptr = buf->mut_dptr(); - int64_t offset = 0; + uint64_t offset = 0; if (this->has_input()) { for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); + send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); } } + const uint64_t recv_offset = offset; if (this->has_output()) { for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); + recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); } } @@ -119,22 +123,16 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } } - // init ccl Send/Recv primitive - std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - // launch ccl::Send/Recv - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < parallel_num; ++i) { - if (this->has_input() && send_elem_cnts.at(i) != 0) { - send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); - } - if (this->has_output() && recv_elem_cnts.at(i) != 0) { - recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); - } + + if (this->has_input() && this->has_output()) { + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), data_type, data_type, parallel_num); + void* send_buf = reinterpret_cast(buf_ptr); + void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); + all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), + recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); } - OF_NCCL_CHECK(ncclGroupEnd()); + if (!this->has_output()) { return; } Blob* out = ctx->BnInOp2Blob("out"); const std::vector>& out_tensor_slice_copier_vec = diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp new file mode 100644 index 00000000000..1dd71fb4c98 --- /dev/null +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -0,0 +1,109 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifdef WITH_CUDA +#include "oneflow/user/kernels/collective_communication/include/send.h" +#include "oneflow/user/kernels/collective_communication/include/recv.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" +#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/common/device_type.h" + +namespace oneflow { + +namespace ccl { + +class CudaAllToAll final : public AllToAll { + public: + OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll); + CudaAllToAll() + : send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {} + ~CudaAllToAll() = default; + + void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override { + this->send_dtype_ = send_dtype; + this->recv_dtype_ = recv_dtype; + this->nccl_send_dtype_ = GetNcclDataType(send_dtype); + this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype); + this->rank_count_ = parallel_num; + } + + void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count, + ccl::CclComm ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + int64_t send_offset = 0; + int64_t recv_offset = 0; + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + if (send_count > 0) { + char* send_ptr = static_cast(send) + send_offset; + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + send_offset += send_count * GetSizeOfDataType(this->send_dtype_); + if (recv_count) { + char* recv_ptr = static_cast(recv) + recv_offset; + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_); + } + OF_NCCL_CHECK(ncclGroupEnd()); + } + + void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, + void* recv, const void* recv_counts, const void* recv_offsets, + ccl::CclComm ccl_comm) const override { + ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); + int64_t* send_counts_ptr = static_cast(const_cast(send_counts)); + int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); + int64_t* send_offsets_ptr = static_cast(const_cast(send_offsets)); + int64_t* recv_offsets_ptr = static_cast(const_cast(recv_offsets)); + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + uint64_t send_offset = static_cast(send_offsets_ptr[i]); + uint64_t send_count = static_cast(send_counts_ptr[i]); + char* send_ptr = static_cast(send) + send_offset; + if (send_count > 0) { + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + + uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); + uint64_t recv_count = static_cast(recv_counts_ptr[i]); + char* recv_ptr = static_cast(recv) + recv_offset; + if (recv_count > 0) { + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } + OF_NCCL_CHECK(ncclGroupEnd()); + } + + private: + DataType send_dtype_; + DataType recv_dtype_; + ncclDataType_t nccl_send_dtype_; + ncclDataType_t nccl_recv_dtype_; + size_t rank_count_; +}; + +REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, 
AllToAll, CudaAllToAll); + +} // namespace ccl + +} // namespace oneflow + +#endif // WITH_CUDA diff --git a/oneflow/user/kernels/collective_communication/include/all_to_all.h b/oneflow/user/kernels/collective_communication/include/all_to_all.h new file mode 100644 index 00000000000..0f39cf3d875 --- /dev/null +++ b/oneflow/user/kernels/collective_communication/include/all_to_all.h @@ -0,0 +1,51 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ +#define ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ + +#include "oneflow/user/kernels/collective_communication/include/collective_communication.h" + +namespace oneflow { + +namespace ccl { + +class AllToAll : public CollectiveCommunication { + public: + OF_DISALLOW_COPY_AND_MOVE(AllToAll); + AllToAll() = default; + ~AllToAll() override = default; + + virtual void Init(DataType send_dtype, DataType recv_dtype, size_t rank_count) = 0; + + // for normal alltoall(balanced send/resv count) + virtual void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, + int64_t recv_count, ccl::CclComm ccl_comm) const = 0; + + // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV) + virtual void Launch(ep::Stream* stream, void* send, const void* send_counts, + const void* send_offsets, void* recv, const void* recv_counts, + const void* recv_offsets, ccl::CclComm ccl_comm) const = 0; +}; + +inline bool IsAllToAllRegistered(DeviceType device_type) { + return IsClassRegistered(device_type); +} + +} // namespace ccl + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_COLLECTIVE_COMMUNICATION_INCLUDE_ALL_TO_ALL_H_ From 15198c6f4202af85e09e6404b3bfe1da2f766db6 Mon Sep 17 00:00:00 2001 From: luyang Date: Thu, 9 Jan 2025 14:41:02 +0000 Subject: [PATCH 13/28] more kernels using ccl::CclComm and ccl apis --- oneflow/user/kernels/eager_nccl_s2s_kernel.cu | 29 ++--- .../kernels/nccl_logical_2d_sbp_kernels.cpp | 107 +++++++++-------- oneflow/user/kernels/nccl_logical_kernels.cpp | 113 ++++++++++-------- 3 files changed, 129 insertions(+), 120 deletions(-) diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index 2b5511b9c64..2b38f634b32 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -21,6 +21,7 @@ limitations under the License. 
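For orientation, a minimal usage sketch of the ccl::AllToAll interface added above in all_to_all.h. The sketch is illustrative and not part of the patch series; the helper name and its parameters are assumptions, and it presumes a ccl::CclComm has already been obtained from EagerCclCommMgr. Both Launch overloads take counts in elements; the offsets of the unbalanced overload are byte offsets into the send/recv buffers, matching how the boxing kernels build them.

#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"

namespace oneflow {

// Illustrative sketch (assumed helper, not introduced by these patches).
inline void LaunchBalancedAllToAll(ep::Stream* stream, void* send_ptr, void* recv_ptr,
                                   int64_t elem_cnt_per_rank, size_t parallel_num,
                                   DataType data_type, ccl::CclComm ccl_comm) {
  // The factory instantiates the implementation registered for the stream's
  // device type (CudaAllToAll when built WITH_CUDA) and Init()s it with the
  // send/recv dtypes and the rank count.
  std::unique_ptr<ccl::AllToAll> all_to_all =
      ccl::NewCollectiveCommunication<ccl::AllToAll>(stream->device_type(), data_type,
                                                     data_type, parallel_num);
  // Balanced overload: every peer exchanges the same number of elements.
  all_to_all->Launch(stream, send_ptr, elem_cnt_per_rank, recv_ptr, elem_cnt_per_rank,
                     ccl_comm);
  // Unbalanced overload (per-peer element counts plus byte offsets) would be:
  //   all_to_all->Launch(stream, send_ptr, send_elem_cnts.data(), send_byte_offsets.data(),
  //                      recv_ptr, recv_elem_cnts.data(), recv_byte_offsets.data(), ccl_comm);
}

}  // namespace oneflow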
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -34,7 +35,7 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache { ~EagerNcclOpKernelCache() override = default; Symbol parallel_desc() const { return parallel_desc_; } - ncclComm_t comm() const { return comm_; } + ccl::CclComm ccl_comm() const { return ccl_comm_; } private: void Init(user_op::KernelCacheContext* ctx) { @@ -48,13 +49,12 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache { int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } - comm_ = CHECK_NOTNULL(Singleton::Get()) - ->As() - ->GetCommForDevice(device_set); + EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); + ccl_comm_ = comm_mgr->GetCclCommForDevice(device_set); } Symbol parallel_desc_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; size_t InferEagerNcclS2SKernelTmpBufferSize(user_op::InferContext* ctx) { @@ -148,21 +148,12 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { { // NOTE: Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, - kernel_cache->comm(), - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_cache->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + ccl::CclComm ccl_comm = kernel_cache->ccl_comm(); + all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, + elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index dac2545f905..495a9e6ade6 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -21,6 +21,9 @@ limitations under the License. 
#include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_reduce.h" +#include "oneflow/user/kernels/collective_communication/include/all_gather.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -39,9 +42,9 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { } ~NcclLogical2DSameDim0KernelCommState() override = default; - ncclComm_t comm() { + ccl::CclComm ccl_comm() { if (!is_init_) { Init(); } - return comm_; + return ccl_comm_; } int64_t num_ranks() { @@ -70,8 +73,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); num_ranks_ = group_size; is_init_ = true; } @@ -81,7 +83,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { ParallelDesc parallel_desc_; int64_t this_parallel_id_; int64_t num_ranks_{}; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogical2DSameDim0AllGatherNoncontinuousKernelState @@ -127,19 +129,22 @@ class NcclLogical2DSameDim0AllReduce final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << "[NcclLogical2D][SameDim0AllReduce] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim0AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -161,18 +166,22 @@ class NcclLogical2DSameDim0AllGather final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = 
ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); - const int64_t num_ranks = nccl_comm->num_ranks(); + const int64_t num_ranks = comm_state->num_ranks(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - VLOG(3) << "[NcclLogical2D][SameDim0AllGather] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim0AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -225,9 +234,13 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); + + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. 
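The kSum-unless-kBool selection that this and the other ported kernels repeat inline follows one rule, captured here as a small helper for clarity. The helper is illustrative only and is not introduced by these patches.

// Illustrative helper (assumed, not part of the series): bool tensors have no
// useful sum-reduction (summed flags would exceed 1), so the ported kernels
// fall back to a max-reduction, which behaves as a logical OR on 0/1 values.
inline ccl::ReduceType GetCclReduceType(DataType data_type) {
  return data_type == DataType::kBool ? ccl::ReduceType::kMax : ccl::ReduceType::kSum;
}

// Typical use with the collective factory, mirroring the ported kernels:
//   std::unique_ptr<ccl::AllReduce> all_reduce =
//       ccl::NewCollectiveCommunication<ccl::AllReduce>(
//           ctx->stream()->device_type(), in->data_type(),
//           GetCclReduceType(in->data_type()));
//   all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(),
//                      in->shape_view().elem_cnt(), ccl_comm);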
@@ -342,21 +355,12 @@ class NcclLogical2DSameDim0All2All final : public user_op::OpKernel { { // NOTE(chengcheng): Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, - kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, + elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { @@ -414,7 +418,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState } ~NcclLogical2DSameDim1KernelCommState() = default; - ncclComm_t comm() { + ccl::CclComm ccl_comm() { if (!is_init_) { std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); @@ -432,11 +436,10 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); is_init_ = true; } - return comm_; + return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } @@ -446,7 +449,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState std::string stream_name_; ParallelDesc parallel_desc_; int64_t this_parallel_id_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogical2DSameDim1AllReduce final : public user_op::OpKernel { @@ -462,19 +465,23 @@ class NcclLogical2DSameDim1AllReduce final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << "[NcclLogical2D][SameDim1AllReduce] " << nccl_comm->stream_name() << " " + VLOG(3) << "[NcclLogical2D][SameDim1AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + + ccl::CclComm ccl_comm = 
comm_state->ccl_comm(); + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index be88acbca18..85d7b539fbd 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -21,6 +21,10 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" +#include "oneflow/user/kernels/collective_communication/include/all_reduce.h" +#include "oneflow/user/kernels/collective_communication/include/all_gather.h" +#include "oneflow/user/kernels/collective_communication/include/reduce_scatter.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -38,7 +42,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { } ~NcclLogicalKernelCommState() override = default; - ncclComm_t comm() { + ccl::CclComm ccl_comm() { if (!is_init_) { std::set> device_set; FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { @@ -47,11 +51,10 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { device_set.emplace(std::make_pair(machine_id, device_id)); } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_ = - comm_mgr->As()->GetCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); is_init_ = true; } - return comm_; + return ccl_comm_; } const std::string& stream_name() const { return stream_name_; } @@ -60,7 +63,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; - ncclComm_t comm_{}; + ccl::CclComm ccl_comm{}; }; class NcclLogicalAllGatherNoncontinuousKernelState : public NcclLogicalKernelCommState { @@ -118,19 +121,23 @@ class NcclLogicalAllReduceKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape_view(), out->shape_view()); CHECK_EQ(in->data_type(), out->data_type()); - VLOG(3) << "[NcclLogical][AllReduce] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][AllReduce] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), reduce_type, nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if 
(in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_all_reduce = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_all_reduce->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -152,20 +159,24 @@ class NcclLogicalReduceScatterKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt() * num_ranks); - VLOG(3) << "[NcclLogical][ReduceScatter] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][ReduceScatter] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == DataType::kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter( - in->dptr(), out->mut_dptr(), out->shape_view().elem_cnt(), GetNcclDataType(in->data_type()), - reduce_type, nccl_comm->comm(), ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -187,18 +198,22 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache*) const override { - auto* nccl_comm = dynamic_cast(state); - CHECK(nccl_comm != nullptr); + auto* comm_state = dynamic_cast(state); + CHECK(comm_state != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const int64_t num_ranks = ctx->parallel_ctx().parallel_num(); CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - VLOG(3) << "[NcclLogical][AllGather] " << nccl_comm->stream_name() << " " << ctx->op_name() + VLOG(3) << "[NcclLogical][AllGather] " << comm_state->stream_name() << " " << ctx->op_name() << std::endl; - OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), nccl_comm->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = comm_state->ccl_comm(); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + 
ccl_all_gather->Launch(ctx->stream(), in->dptr(), out->mut_dptr(), in->shape_view().elem_cnt(), + ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -250,9 +265,12 @@ class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel { // NOTE(chengcheng): Do AllGather CHECK_EQ(in->shape_view().elem_cnt() * num_ranks, out->shape_view().elem_cnt()); - OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), - GetNcclDataType(in->data_type()), kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + std::unique_ptr ccl_all_gather = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type()); + ccl_all_gather->Launch(ctx->stream(), in->dptr(), unpack_from_ptr, in->shape_view().elem_cnt(), + ccl_comm); CHECK_GT(in_split_axis, 0); // NOTE(chengcheng): Do unpack. @@ -334,12 +352,15 @@ class NcclLogicalReduceScatterNoncontinuous final : public user_op::OpKernel { transpose_in_dim_vec.data(), in->dptr(), perm.data(), tmp_buffer->mut_dptr()); VLOG(3) << "[NcclLogical][ReduceScatterNoncontinuous] " << kernel_state->stream_name() << " " << ctx->op_name() << std::endl; - ncclRedOp_t reduce_type = ncclRedOp_t::ncclSum; - if (in->data_type() == kBool) { reduce_type = ncclRedOp_t::ncclMax; } - OF_NCCL_CHECK(ncclReduceScatter(tmp_buffer->dptr(), out->mut_dptr(), - out->shape_view().elem_cnt(), GetNcclDataType(in->data_type()), - reduce_type, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); + + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + ccl::ReduceType ccl_reduce_type = ccl::ReduceType::kSum; + if (in->data_type() == DataType::kBool) { ccl_reduce_type = ccl::ReduceType::kMax; } + std::unique_ptr ccl_reduce_scatter = + ccl::NewCollectiveCommunication(ctx->stream()->device_type(), + in->data_type(), ccl_reduce_type); + ccl_reduce_scatter->Launch(ctx->stream(), tmp_buffer->dptr(), out->mut_dptr(), + out->shape_view().elem_cnt(), ccl_comm); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } bool IsKernelLaunchSynchronized() const override { @@ -437,23 +458,13 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { } { - // NOTE(chengcheng): init nccl comm need before ncclGroupStart. 
- ncclComm_t comm = kernel_state->comm(); // NOTE(chengcheng): Do S2S - OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - OF_NCCL_CHECK(ncclSend(reinterpret_cast( - reinterpret_cast(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, - ctx->stream()->As()->cuda_stream())); - OF_NCCL_CHECK(ncclRecv( - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(), - ctx->stream()->As()->cuda_stream())); - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + ccl::CclComm ccl_comm = kernel_cache->ccl_comm(); + all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, + elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { From fa3d77fb0bb2f76b7ad92a35ad1edf5782652f63 Mon Sep 17 00:00:00 2001 From: luyang Date: Fri, 10 Jan 2025 03:40:56 +0000 Subject: [PATCH 14/28] refine --- oneflow/user/kernels/eager_nccl_s2s_kernel.cu | 4 ++-- oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp | 4 ++-- oneflow/user/kernels/nccl_logical_kernels.cpp | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index 2b38f634b32..101f1ca8ec9 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -152,8 +152,8 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); ccl::CclComm ccl_comm = kernel_cache->ccl_comm(); - all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, - elem_per_chunk, ccl_comm); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index 495a9e6ade6..d337350844f 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -359,8 +359,8 @@ class NcclLogical2DSameDim0All2All final : public user_op::OpKernel { std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); ccl::CclComm ccl_comm = kernel_state->ccl_comm(); - all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, - elem_per_chunk, ccl_comm); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index 85d7b539fbd..34dfb2e8dd5 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -63,7 +63,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { bool is_init_; std::string stream_name_; ParallelDesc parallel_desc_; - ccl::CclComm ccl_comm{}; + ccl::CclComm ccl_comm_{}; }; class NcclLogicalAllGatherNoncontinuousKernelState : public 
NcclLogicalKernelCommState { @@ -462,9 +462,9 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); - ccl::CclComm ccl_comm = kernel_cache->ccl_comm(); - all_to_all->Launch(ctx->stream(), pack_to_ptr, elem_per_chunk, unpack_from_ptr, - elem_per_chunk, ccl_comm); + ccl::CclComm ccl_comm = kernel_state->ccl_comm(); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } if (in_split_axis != 0) { From 81394af6ce036f002106531e718e7ec657910f62 Mon Sep 17 00:00:00 2001 From: luyang Date: Fri, 10 Jan 2025 04:13:00 +0000 Subject: [PATCH 15/28] refine all2all --- .../kernels/nccl_logical_fusion_kernel.cpp | 36 ++++++------------- .../kernels/nccl_logical_send_recv_kernel.cpp | 33 ++++++++--------- 2 files changed, 25 insertions(+), 44 deletions(-) diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp index c282955908c..b7bee7d8f17 100644 --- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp @@ -20,11 +20,12 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/permute.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/user/ops/nccl_logical_util.h" +#include "collective_communication/include/collective_communication.h" #include "collective_communication/include/send.h" #include "collective_communication/include/recv.h" #include "collective_communication/include/all_gather.h" #include "collective_communication/include/all_reduce.h" -#include "collective_communication/include/collective_communication.h" +#include "collective_communication/include/all_to_all.h" #include "collective_communication/include/reduce_scatter.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -491,19 +492,12 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_p out->shape_view().elem_cnt(), ccl_comm); } else if (nccl_type == "_nccl_logical_s2s") { const int64_t elem_cnt = in->shape_view().elem_cnt(); - const int64_t dtype_size = GetSizeOfDataType(in->data_type()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - ccl_send->Launch(ctx->stream(), - reinterpret_cast(reinterpret_cast(pack_to_ptr) - + j * chunk_size), - elem_per_chunk, j, ccl_comm); - ccl_recv->Launch( - ctx->stream(), - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, j, ccl_comm); - } + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); + } else if (nccl_type == "_nccl_logical_2D_same_dim0_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); @@ -534,19 +528,11 @@ void DoNcclComputeByNcclTypeInGroup(const void* pack_to_ptr, void* unpack_from_p ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim0_all2all") { const int64_t elem_cnt = in->shape_view().elem_cnt(); - const int64_t dtype_size = GetSizeOfDataType(in->data_type()); const int64_t elem_per_chunk = elem_cnt / num_ranks; - const int64_t 
chunk_size = elem_per_chunk * dtype_size; - for (int64_t j = 0; j < num_ranks; ++j) { - ccl_send->Launch(ctx->stream(), - reinterpret_cast(reinterpret_cast(pack_to_ptr) - + j * chunk_size), - elem_per_chunk, j, ccl_comm); - ccl_recv->Launch( - ctx->stream(), - reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, j, ccl_comm); - } + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); + all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, + unpack_from_ptr, elem_per_chunk, ccl_comm); } else if (nccl_type == "_nccl_logical_2D_same_dim1_all_reduce") { CHECK(in->dptr() == pack_to_ptr); CHECK(out->mut_dptr() == unpack_from_ptr); diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 59bc8fa151c..1606f7d2ccf 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -28,8 +28,7 @@ limitations under the License. #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" -#include "oneflow/user/kernels/collective_communication/include/send.h" -#include "oneflow/user/kernels/collective_communication/include/recv.h" +#include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 @@ -174,16 +173,21 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O std::vector send_in_ptr; std::vector recv_out_ptr; + std::vector send_offsets; + std::vector recv_offsets; char* buf_ptr = tmp_buffer->mut_dptr(); - int64_t offset = 0; + uint64_t offset = 0; for (int64_t i = 0; i < parallel_num; ++i) { void* send_ptr = reinterpret_cast(buf_ptr + offset); send_in_ptr.push_back(send_ptr); + send_offsets.push_back(offset); offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type); } + const uint64_t recv_offset = offset; for (int64_t i = 0; i < parallel_num; ++i) { void* recv_ptr = reinterpret_cast(buf_ptr + offset); recv_out_ptr.push_back(recv_ptr); + recv_offsets.push_back(offset - recv_offset); offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type); } @@ -196,22 +200,13 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O } const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); - std::unique_ptr send = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - std::unique_ptr recv = - ccl::NewCollectiveCommunication(ctx->stream()->device_type(), data_type); - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < parallel_num; ++i) { - if (send_elem_cnts.at(i) != 0) { - LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i; - send->Launch(ctx->stream(), send_in_ptr.at(i), send_elem_cnts.at(i), i, ccl_comm); - } - if (recv_elem_cnts.at(i) != 0) { - LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i; - recv->Launch(ctx->stream(), recv_out_ptr.at(i), recv_elem_cnts.at(i), i, ccl_comm); - } - } - OF_NCCL_CHECK(ncclGroupEnd()); + std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( + ctx->stream()->device_type(), data_type, data_type, parallel_num); + void* send_buf = reinterpret_cast(buf_ptr); + void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); + all_to_all->Launch(ctx->stream(), send_buf, 
send_elem_cnts.data(), send_offsets.data(), recv_buf, + recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); + const std::vector>& out_tensor_slice_copier_vec = kernel_state->out_tensor_slice_copier_vec(); From 94dafda4f0917d572074448369f79ab3931fad91 Mon Sep 17 00:00:00 2001 From: luyang Date: Fri, 10 Jan 2025 15:23:46 +0000 Subject: [PATCH 16/28] refine --- oneflow/core/graph/boxing/collective_boxing.proto | 5 ----- .../collective_communication/cuda/cuda_all_to_all.cpp | 2 -- 2 files changed, 7 deletions(-) diff --git a/oneflow/core/graph/boxing/collective_boxing.proto b/oneflow/core/graph/boxing/collective_boxing.proto index c024eb733a2..0f3d01d525c 100644 --- a/oneflow/core/graph/boxing/collective_boxing.proto +++ b/oneflow/core/graph/boxing/collective_boxing.proto @@ -15,11 +15,6 @@ enum OpType { kOpTypeAll2All = 6; } -enum CclCommType{ - kCommTypeCudaNccl = 0; - kCommTypeAscendHccl = 1; -} - enum ReduceMethod { kReduceMethodInvalid = 0; kReduceMethodSum = 1; diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp index 1dd71fb4c98..9f9c3e31ef1 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -14,8 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef WITH_CUDA -#include "oneflow/user/kernels/collective_communication/include/send.h" -#include "oneflow/user/kernels/collective_communication/include/recv.h" #include "oneflow/user/kernels/collective_communication/include/all_to_all.h" #include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h" #include "oneflow/core/device/nccl_util.h" From 21529c874a0dcaec943445b2806cab0bd9dc1430 Mon Sep 17 00:00:00 2001 From: luyang Date: Tue, 14 Jan 2025 03:10:01 +0000 Subject: [PATCH 17/28] refine --- oneflow/core/job/eager_ccl_comm_manager.h | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h index b0cefa73efd..fcce9e9509a 100644 --- a/oneflow/core/job/eager_ccl_comm_manager.h +++ b/oneflow/core/job/eager_ccl_comm_manager.h @@ -31,15 +31,9 @@ class EagerCclCommMgr { virtual bool IsAsyncLaunchCclLogicalKernel() const = 0; virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0; virtual ccl::CclComm GetCclCommForDevice( - const std::set>& device_set) { - ccl::CclComm ccl_comm{}; - return ccl_comm; - } + const std::set>& device_set) = 0; virtual ccl::CclComm GetCclCommForDeviceAndStreamName( - const std::set>& device_set, const std::string& stream_name) { - ccl::CclComm ccl_comm{}; - return ccl_comm; - } + const std::set>& device_set, const std::string& stream_name) = 0; template T* As() { From 5636f1b00c597f424151f6e59d16bed532babb2c Mon Sep 17 00:00:00 2001 From: luyang Date: Wed, 22 Jan 2025 07:27:17 +0000 Subject: [PATCH 18/28] refine --- oneflow/core/job/eager_nccl_comm_manager.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 9a91692a4d7..489138eb1d9 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -49,9 +49,11 @@ class EagerNcclCommMgr final : public EagerCclCommMgr { ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t 
GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); - ccl::CclComm GetCclCommForDevice(const std::set>& device_set); + ccl::CclComm GetCclCommForDevice( + const std::set>& device_set) override; ccl::CclComm GetCclCommForDeviceAndStreamName( - const std::set>& device_set, const std::string& stream_name); + const std::set>& device_set, + const std::string& stream_name) override; void CreateCommFromPlan(const Plan& plan) override; bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; } From 8110bd2d2ebc6d7879e5355256152b0bfbc22959 Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:35:22 +0800 Subject: [PATCH 19/28] Update oneflow/user/kernels/eager_nccl_s2s_kernel.cu Co-authored-by: binbinHan --- oneflow/user/kernels/eager_nccl_s2s_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index 101f1ca8ec9..426a3b8b66f 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -35,7 +35,7 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache { ~EagerNcclOpKernelCache() override = default; Symbol parallel_desc() const { return parallel_desc_; } - ccl::CclComm ccl_comm() const { return ccl_comm_; } + const ccl::CclComm& ccl_comm() const { return ccl_comm_; } private: void Init(user_op::KernelCacheContext* ctx) { From 9bb1fb8665b117ba11ae59176788ee6b76029814 Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:35:34 +0800 Subject: [PATCH 20/28] Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp Co-authored-by: binbinHan --- oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index d337350844f..f173c0e6dfd 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { } ~NcclLogical2DSameDim0KernelCommState() override = default; - ccl::CclComm ccl_comm() { + const ccl::CclComm& ccl_comm() const { if (!is_init_) { Init(); } return ccl_comm_; } From 45fec552efe2dfd747878cec1bc8d42940ba15d0 Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:35:47 +0800 Subject: [PATCH 21/28] Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp Co-authored-by: binbinHan --- oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index f173c0e6dfd..a2a829cb851 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -418,7 +418,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState } ~NcclLogical2DSameDim1KernelCommState() = default; - ccl::CclComm ccl_comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); From b369a2cbbdab91d2d8dc575fd86f0bca939eec29 Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:35:59 +0800 Subject: [PATCH 22/28] Update oneflow/user/kernels/nccl_logical_kernels.cpp Co-authored-by: binbinHan --- 
oneflow/user/kernels/nccl_logical_kernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index 34dfb2e8dd5..e2c87e61d23 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -42,7 +42,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { } ~NcclLogicalKernelCommState() override = default; - ccl::CclComm ccl_comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { std::set> device_set; FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { From 3e67ade992b589c11fb0f16965a79ccc583b3adc Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:36:43 +0800 Subject: [PATCH 23/28] Update oneflow/user/kernels/collective_communication/include/recv.h Co-authored-by: binbinHan --- oneflow/user/kernels/collective_communication/include/recv.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h index 7d6b1c24a5c..46eec9e4426 100644 --- a/oneflow/user/kernels/collective_communication/include/recv.h +++ b/oneflow/user/kernels/collective_communication/include/recv.h @@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - CclComm ccl_comm) const = 0; + const CclComm& ccl_comm) const = 0; }; inline bool IsRecvRegistered(DeviceType device_type) { From 269dd3e6c1183adf69a45bc574437e3809f25927 Mon Sep 17 00:00:00 2001 From: Luyang Date: Fri, 24 Jan 2025 20:36:55 +0800 Subject: [PATCH 24/28] Update oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp Co-authored-by: binbinHan --- oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp index 9f903ef7761..cd0d202eade 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -41,7 +41,7 @@ class CpuRecvImpl final : public Recv { } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - CclComm ccl_comm) const override { + const CclComm& ccl_comm) const override { Launch(stream, out, elem_cnt, src); } From 510524ac7c1b95285f813f02c3b5adf97dd3d909 Mon Sep 17 00:00:00 2001 From: luyang Date: Sat, 25 Jan 2025 14:19:31 +0000 Subject: [PATCH 25/28] refactor GetCclCommForParallelDesc series functions --- oneflow/core/job/eager_ccl_comm_manager.h | 11 ++- oneflow/core/job/eager_nccl_comm_manager.cpp | 76 ++++++++++++++++++- oneflow/core/job/eager_nccl_comm_manager.h | 12 +-- .../kernel/nccl_send_recv_boxing_kernel.cpp | 9 +-- oneflow/user/kernels/eager_nccl_s2s_kernel.cu | 8 +- .../kernels/nccl_logical_2d_sbp_kernels.cpp | 31 +------- .../kernels/nccl_logical_fusion_kernel.cpp | 31 +------- oneflow/user/kernels/nccl_logical_kernels.cpp | 8 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 9 +-- 9 files changed, 98 insertions(+), 97 deletions(-) diff --git a/oneflow/core/job/eager_ccl_comm_manager.h b/oneflow/core/job/eager_ccl_comm_manager.h index fcce9e9509a..647b4d0a132 100644 --- a/oneflow/core/job/eager_ccl_comm_manager.h 
+++ b/oneflow/core/job/eager_ccl_comm_manager.h @@ -30,10 +30,13 @@ class EagerCclCommMgr { virtual void CreateCommFromPlan(const Plan& plan) = 0; virtual bool IsAsyncLaunchCclLogicalKernel() const = 0; virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0; - virtual ccl::CclComm GetCclCommForDevice( - const std::set>& device_set) = 0; - virtual ccl::CclComm GetCclCommForDeviceAndStreamName( - const std::set>& device_set, const std::string& stream_name) = 0; + virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0; + virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc, + const std::string& stream_name) = 0; + virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc, + const std::string& stream_name, + const int64_t this_parallel_id, + const std::string& comm_key) = 0; template T* As() { diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp index 9aadffcd559..fe2755438b7 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.cpp +++ b/oneflow/core/job/eager_nccl_comm_manager.cpp @@ -156,16 +156,84 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName( return comm; } -ccl::CclComm EagerNcclCommMgr::GetCclCommForDevice( - const std::set>& device_set) { +ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) { + std::set> device_set; + FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) { + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + ncclComm_t comm = GetCommForDevice(device_set); std::shared_ptr ncclCommAdapter = std::make_shared(comm); ccl::CclComm ccl_comm(ncclCommAdapter); return ccl_comm; } -ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName( - const std::set>& device_set, const std::string& stream_name) { +ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName( + const ParallelDesc& parallel_desc, const std::string& stream_name) { + std::set> device_set; + FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) { + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + + ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); + std::shared_ptr ncclCommAdapter = std::make_shared(comm); + ccl::CclComm ccl_comm(ncclCommAdapter); + return ccl_comm; +} + +ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy( + const ParallelDesc& parallel_desc, const std::string& stream_name, + const int64_t this_parallel_id, const std::string& comm_key) { + std::set> device_set; + const Shape& hierarchy = *parallel_desc.hierarchy(); + CHECK_LE(hierarchy.NumAxes(), 2); + + // 1D + if (hierarchy.NumAxes() == 1) { + // 1D hierarchy + for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) { + int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + } else if (hierarchy.NumAxes() == 2) { + // 2D hierarchy + CHECK(comm_key == "SameDim0" || comm_key == "SameDim1"); + if 
(comm_key == "SameDim0") { + const int64_t num_groups = hierarchy.At(0); + const int64_t group_size = hierarchy.At(1); + CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); + const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size; + CHECK_EQ(this_group_begin_parallel_id % group_size, 0); + CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num()); + for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { + const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; + const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + } else if (comm_key == "SameDim1") { + const int64_t group_size = hierarchy.At(0); + const int64_t num_groups = hierarchy.At(1); + CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num()); + const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups; + CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, + parallel_desc.parallel_num()); + for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { + const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); + const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); + const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + } else { + UNIMPLEMENTED(); + } + } + ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name); std::shared_ptr ncclCommAdapter = std::make_shared(comm); ccl::CclComm ccl_comm(ncclCommAdapter); diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 489138eb1d9..96351e875a6 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -49,11 +49,13 @@ class EagerNcclCommMgr final : public EagerCclCommMgr { ncclComm_t GetCommForDevice(const std::set>& device_set); ncclComm_t GetCommForDeviceAndStreamName(const std::set>& device_set, const std::string& stream_name); - ccl::CclComm GetCclCommForDevice( - const std::set>& device_set) override; - ccl::CclComm GetCclCommForDeviceAndStreamName( - const std::set>& device_set, - const std::string& stream_name) override; + ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override; + ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc, + const std::string& stream_name) override; + ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc, + const std::string& stream_name, + const int64_t this_parallel_id, + const std::string& comm_key) override; void CreateCommFromPlan(const Plan& plan) override; bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; } diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index e7fc1939f89..63694047bbe 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -52,14 +52,9 @@ class NcclSendRecvBoxingKernel final : public Kernel { void Init() const { ParallelDesc parallel_desc(parallel_conf_); - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < 
parallel_desc.parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl::CclComm ccl_comm = + comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_); ccl_comm_.reset(new Comm(ccl_comm)); } diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index 101f1ca8ec9..d3c5dca522d 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -41,16 +41,10 @@ class EagerNcclOpKernelCache final : public user_op::OpKernelCache { void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; - std::set> device_set; CHECK(TxtString2PbMessage(parallel_conf_txt, ¶llel_conf)); parallel_desc_ = SymbolOf(ParallelDesc(parallel_conf)); - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_->parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl_comm_ = comm_mgr->GetCclCommForDevice(device_set); + ccl_comm_ = comm_mgr->GetCclCommForParallelDesc(parallel_conf); } Symbol parallel_desc_; diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index d337350844f..655c2b60de7 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -57,23 +57,12 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { private: void Init() { CHECK(!is_init_); - std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); - const int64_t num_groups = hierarchy.At(0); const int64_t group_size = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size; - CHECK_EQ(this_group_begin_parallel_id % group_size, 0); - CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, "SameDim0"); num_ranks_ = group_size; is_init_ = true; } @@ -420,23 +409,11 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState ccl::CclComm ccl_comm() { if (!is_init_) { - std::set> device_set; const Shape& hierarchy = 
*parallel_desc_.hierarchy(); CHECK_EQ(hierarchy.NumAxes(), 2); - const int64_t group_size = hierarchy.At(0); - const int64_t num_groups = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups; - CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, - parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, "SameDim1"); is_init_ = true; } return ccl_comm_; diff --git a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp index b7bee7d8f17..4efe792d1a8 100644 --- a/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_fusion_kernel.cpp @@ -127,45 +127,17 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { private: void InitComm() { CHECK(!is_init_); - std::set> device_set; const Shape& hierarchy = *parallel_desc_.hierarchy(); if (hierarchy.NumAxes() == 1) { num_ranks_ = parallel_desc_.parallel_num(); - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } } else if (hierarchy.NumAxes() == 2) { CHECK(comm_key_ == "SameDim0" || comm_key_ == "SameDim1"); if (comm_key_ == "SameDim0") { - const int64_t num_groups = hierarchy.At(0); const int64_t group_size = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ / group_size * group_size; - CHECK_EQ(this_group_begin_parallel_id % group_size, 0); - CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + id_in_group; - const int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } num_ranks_ = group_size; } else if (comm_key_ == "SameDim1") { const int64_t group_size = hierarchy.At(0); - const int64_t num_groups = hierarchy.At(1); - CHECK_EQ(num_groups * group_size, parallel_desc_.parallel_num()); - const int64_t this_group_begin_parallel_id = this_parallel_id_ % num_groups; - CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups, - parallel_desc_.parallel_num()); - for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) { - const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups); - const int64_t machine_id = 
CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - const int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } num_ranks_ = group_size; } else { UNIMPLEMENTED(); @@ -175,7 +147,8 @@ class NcclLogicalFusionKernelState : public user_op::OpKernelState { } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescNdHierarchy(parallel_desc_, stream_name_, + this_parallel_id_, comm_key_); is_init_ = true; } diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index 34dfb2e8dd5..b9775976318 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -44,14 +44,8 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState { ccl::CclComm ccl_comm() { if (!is_init_) { - std::set> device_set; - FOR_RANGE(int64_t, parallel_id, 0, parallel_desc_.parallel_num()) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl_comm_ = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl_comm_ = comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc_, stream_name_); is_init_ = true; } return ccl_comm_; diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index 1606f7d2ccf..d8e6b396cb2 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -126,14 +126,9 @@ NcclLogicalSendRecvState::NcclLogicalSendRecvState(user_op::KernelInitContext* c } void NcclLogicalSendRecvState::InitComm() const { - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < parallel_desc_->parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc_->MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_->DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_); + ccl::CclComm ccl_comm = + comm_mgr->GetCclCommForParallelDescAndStreamName(*parallel_desc_.get(), stream_name_); ccl_comm_.reset(new Comm(ccl_comm)); } From b0146084256a975fc7c5ab3c782e46fbd63840fb Mon Sep 17 00:00:00 2001 From: luyang Date: Sat, 25 Jan 2025 15:43:02 +0000 Subject: [PATCH 26/28] refine --- .../user/kernels/collective_communication/cpu/cpu_recv.cpp | 2 +- .../user/kernels/collective_communication/cuda/cuda_send.cpp | 4 +++- oneflow/user/kernels/collective_communication/include/recv.h | 2 +- oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp index cd0d202eade..9f903ef7761 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -41,7 +41,7 @@ class CpuRecvImpl 
final : public Recv { } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - const CclComm& ccl_comm) const override { + CclComm ccl_comm) const override { Launch(stream, out, elem_cnt, src); } diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp index f5aa6d9045c..65c4b9d23b6 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp @@ -34,6 +34,7 @@ class CudaSend final : public Send { void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override { #if HAS_NCCL_SEND_RECV const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst); + printf("\n CudaSend >>> Launch >>> communication_ctx"); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second, comm_and_peer_rank.first, stream->As()->cuda_stream())); #else @@ -42,7 +43,7 @@ class CudaSend final : public Send { } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, - CclComm ccl_comm) const override { + ccl::CclComm ccl_comm) const override { #if HAS_NCCL_SEND_RECV ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm, @@ -50,6 +51,7 @@ class CudaSend final : public Send { #else UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV + printf("\n CudaSend >>> Launch >>> ccl::CclComm"); } private: diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h index 46eec9e4426..7d6b1c24a5c 100644 --- a/oneflow/user/kernels/collective_communication/include/recv.h +++ b/oneflow/user/kernels/collective_communication/include/recv.h @@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - const CclComm& ccl_comm) const = 0; + CclComm ccl_comm) const = 0; }; inline bool IsRecvRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index 3df0d541efd..268d2ba4d5f 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { } ~NcclLogical2DSameDim0KernelCommState() override = default; - const ccl::CclComm& ccl_comm() const { + ccl::CclComm ccl_comm() { if (!is_init_) { Init(); } return ccl_comm_; } From 0dc6cbcf86a0db943587a2d2c0899a56256d0770 Mon Sep 17 00:00:00 2001 From: luyang Date: Sun, 26 Jan 2025 10:49:54 +0000 Subject: [PATCH 27/28] refactor const ccl::CclComm& --- oneflow/core/job/eager_nccl_comm_manager.h | 2 +- .../kernels/collective_communication/cpu/cpu_all_gather.cpp | 2 +- .../kernels/collective_communication/cpu/cpu_all_reduce.cpp | 2 +- .../user/kernels/collective_communication/cpu/cpu_recv.cpp | 2 +- .../collective_communication/cpu/cpu_reduce_scatter.cpp | 2 +- .../user/kernels/collective_communication/cpu/cpu_send.cpp | 2 +- .../kernels/collective_communication/cuda/cuda_all_gather.cpp | 2 +- .../kernels/collective_communication/cuda/cuda_all_reduce.cpp | 2 +- .../kernels/collective_communication/cuda/cuda_all_to_all.cpp | 4 
++-- .../user/kernels/collective_communication/cuda/cuda_recv.cpp | 2 +- .../collective_communication/cuda/cuda_reduce_scatter.cpp | 2 +- .../user/kernels/collective_communication/cuda/cuda_send.cpp | 4 +--- .../kernels/collective_communication/include/all_gather.h | 2 +- .../kernels/collective_communication/include/all_reduce.h | 2 +- .../kernels/collective_communication/include/all_to_all.h | 4 ++-- .../include/collective_communication.h | 4 ++-- oneflow/user/kernels/collective_communication/include/recv.h | 2 +- .../kernels/collective_communication/include/reduce_scatter.h | 2 +- oneflow/user/kernels/collective_communication/include/send.h | 2 +- oneflow/user/kernels/eager_nccl_s2s_kernel.cu | 2 +- oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp | 2 +- 21 files changed, 24 insertions(+), 26 deletions(-) diff --git a/oneflow/core/job/eager_nccl_comm_manager.h b/oneflow/core/job/eager_nccl_comm_manager.h index 96351e875a6..e335772dc39 100644 --- a/oneflow/core/job/eager_nccl_comm_manager.h +++ b/oneflow/core/job/eager_nccl_comm_manager.h @@ -31,7 +31,7 @@ class NcclCommAdapter : public CommBase { public: explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {} - void* getComm() override { return static_cast(&comm_); } + void* getComm() const override { return const_cast(static_cast(&comm_)); } private: ncclComm_t comm_; diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp index c1a581c7aa3..a09ebec52d4 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_gather.cpp @@ -96,7 +96,7 @@ class CpuAllGather final : public AllGather { } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp index b4a63890e1c..c03bed6c1db 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_all_reduce.cpp @@ -149,7 +149,7 @@ class CpuAllReduce final : public AllReduce { } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp index 9f903ef7761..2640d8f372f 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp @@ -41,7 +41,7 @@ class CpuRecvImpl final : public Recv { } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { Launch(stream, out, elem_cnt, src); } diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp index 67f06dabb64..7ba9db44570 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_reduce_scatter.cpp @@ -121,7 +121,7 @@ class CpuReduceScatter final : public ReduceScatter { } void 
Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { UNIMPLEMENTED(); } diff --git a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp index f4cbddeede3..829bf5a09cf 100644 --- a/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp +++ b/oneflow/user/kernels/collective_communication/cpu/cpu_send.cpp @@ -41,7 +41,7 @@ class CpuSendImpl final : public Send { } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, - CclComm comm) const override { + const ccl::CclComm& comm) const override { Launch(stream, in, elem_cnt, dst); } diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp index 936d492bc75..5de37a2ea08 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_gather.cpp @@ -41,7 +41,7 @@ class CudaAllGather final : public AllGather { } virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm, stream->As()->cuda_stream())); diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp index 2d035c6ff60..e9567863a1c 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_reduce.cpp @@ -58,7 +58,7 @@ class CudaAllReduce final : public AllReduce { } void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, stream->As()->cuda_stream())); diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp index 9f9c3e31ef1..c32363d5c4d 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -39,7 +39,7 @@ class CudaAllToAll final : public AllToAll { } void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t send_offset = 0; int64_t recv_offset = 0; @@ -63,7 +63,7 @@ class CudaAllToAll final : public AllToAll { void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, const void* recv_offsets, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t* send_counts_ptr = static_cast(const_cast(send_counts)); int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); diff --git 
a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp index 64ae471cd63..46dcdc7b871 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_recv.cpp @@ -42,7 +42,7 @@ class CudaRecv final : public Recv { } void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { #if HAS_NCCL_SEND_RECV ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclRecv(out, elem_cnt, nccl_datatype_, src, *comm, diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp index fd3811df309..3d57f89d6fe 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_reduce_scatter.cpp @@ -58,7 +58,7 @@ class CudaReduceScatter final : public ReduceScatter { } virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclReduceScatter(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm, stream->As()->cuda_stream())); diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp index 65c4b9d23b6..78de1d86d8e 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp @@ -34,7 +34,6 @@ class CudaSend final : public Send { void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override { #if HAS_NCCL_SEND_RECV const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst); - printf("\n CudaSend >>> Launch >>> communication_ctx"); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second, comm_and_peer_rank.first, stream->As()->cuda_stream())); #else @@ -43,7 +42,7 @@ class CudaSend final : public Send { } void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, - ccl::CclComm ccl_comm) const override { + const ccl::CclComm& ccl_comm) const override { #if HAS_NCCL_SEND_RECV ncclComm_t* comm = reinterpret_cast(ccl_comm.getComm()); OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm, @@ -51,7 +50,6 @@ class CudaSend final : public Send { #else UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7" #endif // HAS_NCCL_SEND_RECV - printf("\n CudaSend >>> Launch >>> ccl::CclComm"); } private: diff --git a/oneflow/user/kernels/collective_communication/include/all_gather.h b/oneflow/user/kernels/collective_communication/include/all_gather.h index 1212b5eea4b..e765f27e5ca 100644 --- a/oneflow/user/kernels/collective_communication/include/all_gather.h +++ b/oneflow/user/kernels/collective_communication/include/all_gather.h @@ -34,7 +34,7 @@ class AllGather : public CollectiveCommunication { const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const = 0; + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllGatherRegistered(DeviceType device_type) { diff --git 
a/oneflow/user/kernels/collective_communication/include/all_reduce.h b/oneflow/user/kernels/collective_communication/include/all_reduce.h index b152bb8934f..4253864d026 100644 --- a/oneflow/user/kernels/collective_communication/include/all_reduce.h +++ b/oneflow/user/kernels/collective_communication/include/all_reduce.h @@ -34,7 +34,7 @@ class AllReduce : public CollectiveCommunication { const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const = 0; + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllReduceRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/all_to_all.h b/oneflow/user/kernels/collective_communication/include/all_to_all.h index 0f39cf3d875..af7bd3fb3b4 100644 --- a/oneflow/user/kernels/collective_communication/include/all_to_all.h +++ b/oneflow/user/kernels/collective_communication/include/all_to_all.h @@ -32,12 +32,12 @@ class AllToAll : public CollectiveCommunication { // for normal alltoall(balanced send/resv count) virtual void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, - int64_t recv_count, ccl::CclComm ccl_comm) const = 0; + int64_t recv_count, const ccl::CclComm& ccl_comm) const = 0; // for unbalanced all to all(e.g. nccl all2all using send/recv; hccl HcclAlltoAllV) virtual void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, - const void* recv_offsets, ccl::CclComm ccl_comm) const = 0; + const void* recv_offsets, const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsAllToAllRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/collective_communication.h b/oneflow/user/kernels/collective_communication/include/collective_communication.h index 14ddd364a5a..4f2211cf9d3 100644 --- a/oneflow/user/kernels/collective_communication/include/collective_communication.h +++ b/oneflow/user/kernels/collective_communication/include/collective_communication.h @@ -47,7 +47,7 @@ class CommBase { virtual ~CommBase() = default; // return impl of comm - virtual void* getComm() = 0; + virtual void* getComm() const = 0; }; class CclComm { @@ -55,7 +55,7 @@ class CclComm { CclComm() {} explicit CclComm(std::shared_ptr comm) : comm_(std::move(comm)) {} - void* getComm() { return comm_->getComm(); } + void* getComm() const { return comm_->getComm(); } private: std::shared_ptr comm_{}; diff --git a/oneflow/user/kernels/collective_communication/include/recv.h b/oneflow/user/kernels/collective_communication/include/recv.h index 7d6b1c24a5c..f0cf5d34627 100644 --- a/oneflow/user/kernels/collective_communication/include/recv.h +++ b/oneflow/user/kernels/collective_communication/include/recv.h @@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0; virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src, - CclComm ccl_comm) const = 0; + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsRecvRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h index 8b97079e30d..da62e5d6085 100644 --- a/oneflow/user/kernels/collective_communication/include/reduce_scatter.h +++ 
b/oneflow/user/kernels/collective_communication/include/reduce_scatter.h @@ -34,7 +34,7 @@ class ReduceScatter : public CollectiveCommunication { const std::shared_ptr& communicator) const = 0; virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt, - ccl::CclComm ccl_comm) const = 0; + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsReduceScatterRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/collective_communication/include/send.h b/oneflow/user/kernels/collective_communication/include/send.h index 23fbda3cb2b..4ca4491c7e5 100644 --- a/oneflow/user/kernels/collective_communication/include/send.h +++ b/oneflow/user/kernels/collective_communication/include/send.h @@ -33,7 +33,7 @@ class Send : public CollectiveCommunication { virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const = 0; virtual void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst, - CclComm ccl_comm) const = 0; + const ccl::CclComm& ccl_comm) const = 0; }; inline bool IsSendRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu index bf9e035d5e3..7518efb0c33 100644 --- a/oneflow/user/kernels/eager_nccl_s2s_kernel.cu +++ b/oneflow/user/kernels/eager_nccl_s2s_kernel.cu @@ -145,7 +145,7 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { const int64_t elem_per_chunk = elem_cnt / num_ranks; std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), in->data_type(), in->data_type(), num_ranks); - ccl::CclComm ccl_comm = kernel_cache->ccl_comm(); + auto& ccl_comm = kernel_cache->ccl_comm(); all_to_all->Launch(ctx->stream(), const_cast(pack_to_ptr), elem_per_chunk, unpack_from_ptr, elem_per_chunk, ccl_comm); } diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index 268d2ba4d5f..03cecc9e427 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState { } ~NcclLogical2DSameDim0KernelCommState() override = default; - ccl::CclComm ccl_comm() { + const ccl::CclComm& ccl_comm() { if (!is_init_) { Init(); } return ccl_comm_; } From 65e30465662be65a3d32670dc124cedb52b02872 Mon Sep 17 00:00:00 2001 From: luyang Date: Sun, 26 Jan 2025 14:13:47 +0000 Subject: [PATCH 28/28] refactor ccl::AllToAll --- .../kernel/nccl_send_recv_boxing_kernel.cpp | 5 ++- .../cuda/cuda_all_to_all.cpp | 42 +++++++++++-------- .../include/all_to_all.h | 3 +- .../kernels/nccl_logical_send_recv_kernel.cpp | 4 +- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp index 63694047bbe..825c43a36a6 100644 --- a/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp +++ b/oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp @@ -119,13 +119,14 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const { } } - if (this->has_input() && this->has_output()) { + if (this->has_input() || this->has_output()) { std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), 
send_buf, send_elem_cnts.data(), send_offsets.data(), - recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); + recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, + this->has_input(), this->has_output()); } if (!this->has_output()) { return; } diff --git a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp index c32363d5c4d..313a14f00e0 100644 --- a/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp +++ b/oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp @@ -63,31 +63,37 @@ class CudaAllToAll final : public AllToAll { void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, const void* recv_offsets, - const ccl::CclComm& ccl_comm) const override { + const ccl::CclComm& ccl_comm, const bool has_input, + const bool has_output) const override { ncclComm_t* nccl_comm = reinterpret_cast(ccl_comm.getComm()); int64_t* send_counts_ptr = static_cast(const_cast(send_counts)); int64_t* recv_counts_ptr = static_cast(const_cast(recv_counts)); int64_t* send_offsets_ptr = static_cast(const_cast(send_offsets)); int64_t* recv_offsets_ptr = static_cast(const_cast(recv_offsets)); - OF_NCCL_CHECK(ncclGroupStart()); - for (int64_t i = 0; i < this->rank_count_; ++i) { - uint64_t send_offset = static_cast(send_offsets_ptr[i]); - uint64_t send_count = static_cast(send_counts_ptr[i]); - char* send_ptr = static_cast(send) + send_offset; - if (send_count > 0) { - OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, - stream->As()->cuda_stream())); - } - - uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); - uint64_t recv_count = static_cast(recv_counts_ptr[i]); - char* recv_ptr = static_cast(recv) + recv_offset; - if (recv_count > 0) { - OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, - stream->As()->cuda_stream())); + if (has_input || has_output) { + OF_NCCL_CHECK(ncclGroupStart()); + for (int64_t i = 0; i < this->rank_count_; ++i) { + if (has_input) { + const uint64_t send_count = static_cast(send_counts_ptr[i]); + if (send_count > 0) { + uint64_t send_offset = static_cast(send_offsets_ptr[i]); + char* send_ptr = static_cast(send) + send_offset; + OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } + if (has_output) { + const uint64_t recv_count = static_cast(recv_counts_ptr[i]); + if (recv_count > 0) { + uint64_t recv_offset = static_cast(recv_offsets_ptr[i]); + char* recv_ptr = static_cast(recv) + recv_offset; + OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm, + stream->As()->cuda_stream())); + } + } } + OF_NCCL_CHECK(ncclGroupEnd()); } - OF_NCCL_CHECK(ncclGroupEnd()); } private: diff --git a/oneflow/user/kernels/collective_communication/include/all_to_all.h b/oneflow/user/kernels/collective_communication/include/all_to_all.h index af7bd3fb3b4..81c35ce80ab 100644 --- a/oneflow/user/kernels/collective_communication/include/all_to_all.h +++ b/oneflow/user/kernels/collective_communication/include/all_to_all.h @@ -37,7 +37,8 @@ class AllToAll : public CollectiveCommunication { // for unbalanced all to all(e.g. 
nccl all2all using send/recv; hccl HcclAlltoAllV) virtual void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets, void* recv, const void* recv_counts, - const void* recv_offsets, const ccl::CclComm& ccl_comm) const = 0; + const void* recv_offsets, const ccl::CclComm& ccl_comm, const bool has_input, + const bool has_output) const = 0; }; inline bool IsAllToAllRegistered(DeviceType device_type) { diff --git a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp index d8e6b396cb2..fa06e22fd44 100644 --- a/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp +++ b/oneflow/user/kernels/nccl_logical_send_recv_kernel.cpp @@ -193,14 +193,14 @@ void NcclLogicalSendRecv::Compute(user_op::KernelComputeContext* ctx, user_op::O in_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), send_in_ptr.at(i), in->dptr()); } } - const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); std::unique_ptr all_to_all = ccl::NewCollectiveCommunication( ctx->stream()->device_type(), data_type, data_type, parallel_num); void* send_buf = reinterpret_cast(buf_ptr); void* recv_buf = reinterpret_cast(buf_ptr + recv_offset); all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(), recv_buf, - recv_elem_cnts.data(), recv_offsets.data(), ccl_comm); + recv_elem_cnts.data(), recv_offsets.data(), ccl_comm, /*has_input=*/true, + /*has_output=*/true); const std::vector>& out_tensor_slice_copier_vec = kernel_state->out_tensor_slice_copier_vec();
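
For reference, the call sequence that this series converges on (EagerCclCommMgr -> ccl::CclComm -> device-agnostic ccl::AllToAll) can be sketched as follows. This is a minimal illustration, not part of any patch above: the helper name LaunchUnbalancedAllToAll and its buffer/count parameters are hypothetical stand-ins for what a caller such as NcclLogicalSendRecv already computes, while the EagerCclCommMgr, ccl::CclComm, and ccl::AllToAll calls follow the signatures introduced in patches 25-28.

// Hypothetical caller-side sketch; only the comm-manager and collective-communication
// calls are taken from this patch series.
#include <memory>
#include <string>
#include <vector>
#include "oneflow/core/job/eager_ccl_comm_manager.h"
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"

namespace oneflow {

void LaunchUnbalancedAllToAll(user_op::KernelComputeContext* ctx,
                              const ParallelDesc& parallel_desc, const std::string& stream_name,
                              DataType data_type, int64_t parallel_num, void* send_buf,
                              void* recv_buf, const std::vector<int64_t>& send_elem_cnts,
                              const std::vector<int64_t>& send_offsets,
                              const std::vector<int64_t>& recv_elem_cnts,
                              const std::vector<int64_t>& recv_offsets) {
  // 1. Resolve a device-agnostic communicator from the comm manager singleton
  //    (GetCclCommForParallelDescAndStreamName, added in PATCH 25).
  EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
  ccl::CclComm ccl_comm =
      comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name);

  // 2. Build the AllToAll primitive for the current device type.
  std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
      ctx->stream()->device_type(), data_type, data_type, parallel_num);

  // 3. One Launch call replaces the former per-rank ncclSend/ncclRecv group
  //    (unbalanced variant with explicit counts/offsets, PATCH 28).
  all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),
                     recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm,
                     /*has_input=*/true, /*has_output=*/true);
}

}  // namespace oneflow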