Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev refactor xccl primitive #10613

Merged
merged 32 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c9b7811
raw impl
Flowingsun007 Dec 27, 2024
7bcc593
refine
Flowingsun007 Dec 27, 2024
d70c464
refine
Flowingsun007 Jan 1, 2025
89411cd
impl of ccl::CclComm
Flowingsun007 Jan 2, 2025
9ec2bc6
refine
Flowingsun007 Jan 2, 2025
a0b0391
refine
Flowingsun007 Jan 2, 2025
2c48a5e
refactor _nccl_logical_send_recv using ccl::Comm primitive
Flowingsun007 Jan 2, 2025
aac19b4
refactor _nccl_logical_send_recv using ccl::Comm primitive
Flowingsun007 Jan 2, 2025
3cb872a
refactor ccl::AllGather AllReduce ReduceScatter primitive using ccl::…
Flowingsun007 Jan 2, 2025
e5777b9
refactor _nccl_logical_fusion kernel using ccl::CclComm
Flowingsun007 Jan 2, 2025
ad9e7ee
refine
Flowingsun007 Jan 3, 2025
1809053
Merge branch 'master' into dev_refactor_xccl_primitive
Flowingsun007 Jan 3, 2025
a9cd8df
support ccl::AllToAll
Flowingsun007 Jan 9, 2025
ca258af
Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/…
Flowingsun007 Jan 9, 2025
15198c6
more kernels using ccl::CclComm and ccl apis
Flowingsun007 Jan 9, 2025
fa3d77f
refine
Flowingsun007 Jan 10, 2025
81394af
refine all2all
Flowingsun007 Jan 10, 2025
94dafda
refine
Flowingsun007 Jan 10, 2025
21529c8
refine
Flowingsun007 Jan 14, 2025
5636f1b
refine
Flowingsun007 Jan 22, 2025
8110bd2
Update oneflow/user/kernels/eager_nccl_s2s_kernel.cu
Flowingsun007 Jan 24, 2025
9bb1fb8
Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Flowingsun007 Jan 24, 2025
45fec55
Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Flowingsun007 Jan 24, 2025
b369a2c
Update oneflow/user/kernels/nccl_logical_kernels.cpp
Flowingsun007 Jan 24, 2025
3e67ade
Update oneflow/user/kernels/collective_communication/include/recv.h
Flowingsun007 Jan 24, 2025
269dd3e
Update oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp
Flowingsun007 Jan 24, 2025
510524a
refactor GetCclCommForParallelDesc series functions
Flowingsun007 Jan 25, 2025
c56c527
Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/…
Flowingsun007 Jan 25, 2025
b014608
refine
Flowingsun007 Jan 25, 2025
0dc6cbc
refactor const ccl::CclComm&
Flowingsun007 Jan 26, 2025
65e3046
refactor ccl::AllToAll
Flowingsun007 Jan 26, 2025
0a883a7
Merge branch 'master' into dev_refactor_xccl_primitive
Flowingsun007 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions oneflow/core/job/eager_ccl_comm_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include "oneflow/core/common/util.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"

namespace oneflow {

Expand All @@ -29,6 +30,10 @@ class EagerCclCommMgr {
virtual void CreateCommFromPlan(const Plan& plan) = 0;
virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;
virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;
virtual ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) = 0;
virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) = 0;
Copy link
Contributor

@clackhan clackhan Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
virtual ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) = 0;
virtual ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) = 0;
virtual ccl::CclComm GetCclCommForParallelDesc(
const ParallelDesc& parallel_desc) = 0;
virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(
const ParallelDesc& parallel_desc, const std::string& stream_name) = 0;

这里的参数应该是parallel_desc,device_set是cuda需要的形式,parallel_desc 构建device_set的过程应该放到派生类中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果统一传ParallelDesc,那对于1D/2D的情况不太好处理吧?😂譬如:oneflow/user/kernels/nccl_logical_fusion_kernel.cpp这种,device_set的创建即和hierarchy.NumAxes()相关,还需要comm_key_

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


template<typename T>
T* As() {
Expand Down
16 changes: 16 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
return comm;
}

// Returns the cached NCCL communicator for `device_set`, wrapped in the
// backend-agnostic ccl::CclComm handle.
ccl::CclComm EagerNcclCommMgr::GetCclCommForDevice(
    const std::set<std::pair<int64_t, int64_t>>& device_set) {
  // Reuse the existing raw-communicator lookup, then adapt the ncclComm_t to
  // the ccl::CommBase interface.
  auto adapter = std::make_shared<ccl::NcclCommAdapter>(GetCommForDevice(device_set));
  return ccl::CclComm(adapter);
}

// Returns the cached NCCL communicator for `device_set` keyed additionally by
// `stream_name`, wrapped in the backend-agnostic ccl::CclComm handle.
ccl::CclComm EagerNcclCommMgr::GetCclCommForDeviceAndStreamName(
    const std::set<std::pair<int64_t, int64_t>>& device_set, const std::string& stream_name) {
  // Look up the raw communicator first, then adapt it to ccl::CommBase.
  auto adapter =
      std::make_shared<ccl::NcclCommAdapter>(GetCommForDeviceAndStreamName(device_set, stream_name));
  return ccl::CclComm(adapter);
}

void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) {
const int64_t rank = GlobalProcessCtx::Rank();
const int64_t dev = GlobalProcessCtx::LocalRank();
Expand Down
18 changes: 18 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ limitations under the License.
#include "oneflow/core/device/cuda_util.h"

namespace oneflow {
namespace ccl {

// Adapts a raw ncclComm_t to the backend-agnostic ccl::CommBase interface so
// an NCCL communicator can travel through ccl::CclComm.
class NcclCommAdapter : public CommBase {
 public:
  explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {}

  // Returns a pointer to the stored ncclComm_t; consumers cast it back with
  // reinterpret_cast<ncclComm_t*> (see the CUDA collective implementations).
  void* getComm() override { return static_cast<void*>(&comm_); }

 private:
  ncclComm_t comm_;  // Not owned here; lifetime is managed by the comm manager.
};

} // namespace ccl

class EagerNcclCommMgr final : public EagerCclCommMgr {
public:
Expand All @@ -36,6 +49,11 @@ class EagerNcclCommMgr final : public EagerCclCommMgr {
ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name);
ccl::CclComm GetCclCommForDevice(
const std::set<std::pair<int64_t, int64_t>>& device_set) override;
ccl::CclComm GetCclCommForDeviceAndStreamName(
const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name) override;

void CreateCommFromPlan(const Plan& plan) override;
bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; }
Expand Down
48 changes: 25 additions & 23 deletions oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"

#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700

Expand All @@ -41,12 +42,12 @@ class NcclSendRecvBoxingKernel final : public Kernel {
const std::vector<int64_t>& recv_elem_cnts() const { return recv_elem_cnts_; }
const bool has_input() const { return has_input_; }
const bool has_output() const { return has_output_; }
ncclComm_t comm() const { return GetOrCreate().comm; }
ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; }

private:
struct Comm {
Comm(ncclComm_t comm) : comm(comm) {}
ncclComm_t comm;
Comm(ccl::CclComm comm) : ccl_comm(comm) {}
ccl::CclComm ccl_comm;
};

void Init() const {
Expand All @@ -58,22 +59,21 @@ class NcclSendRecvBoxingKernel final : public Kernel {
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ncclComm_t comm =
comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);
comm_.reset(new Comm(comm));
ccl::CclComm ccl_comm = comm_mgr->GetCclCommForDeviceAndStreamName(device_set, stream_name_);
ccl_comm_.reset(new Comm(ccl_comm));
}

const Comm& GetOrCreate() const {
if (!comm_) { Init(); }
return *comm_;
if (!ccl_comm_) { Init(); }
return *ccl_comm_;
}

void VirtualKernelInit(KernelContext* ctx) override;
void ForwardDataContent(KernelContext* ctx) const override;

std::string stream_name_;
ParallelConf parallel_conf_;
mutable std::unique_ptr<Comm> comm_;
mutable std::unique_ptr<Comm> ccl_comm_;
bool src_nd_sbp_no_partial_parallel_;
std::vector<std::shared_ptr<TensorSliceCopier>> in_tensor_slice_copier_vec_;
std::vector<std::shared_ptr<TensorSliceCopier>> out_tensor_slice_copier_vec_;
Expand All @@ -85,27 +85,31 @@ class NcclSendRecvBoxingKernel final : public Kernel {

void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* buf = ctx->BnInOp2Blob("buf");
ncclComm_t comm = this->comm();
cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
ccl::CclComm ccl_comm = this->ccl_comm();
const std::vector<int64_t>& send_elem_cnts = this->send_elem_cnts();
const std::vector<int64_t>& recv_elem_cnts = this->recv_elem_cnts();
const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num();
const DataType data_type = buf->data_type();
std::vector<void*> send_in_ptr;
std::vector<void*> recv_out_ptr;
std::vector<int64_t> send_offsets;
std::vector<int64_t> recv_offsets;
char* buf_ptr = buf->mut_dptr<char>();
int64_t offset = 0;
uint64_t offset = 0;
if (this->has_input()) {
for (int64_t i = 0; i < parallel_num; ++i) {
void* send_ptr = reinterpret_cast<void*>(buf_ptr + offset);
send_in_ptr.push_back(send_ptr);
send_offsets.push_back(offset);
offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type);
}
}
const uint64_t recv_offset = offset;
if (this->has_output()) {
for (int64_t i = 0; i < parallel_num; ++i) {
void* recv_ptr = reinterpret_cast<void*>(buf_ptr + offset);
recv_out_ptr.push_back(recv_ptr);
recv_offsets.push_back(offset - recv_offset);
offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type);
}
}
Expand All @@ -119,18 +123,16 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
}
}
}
OF_NCCL_CHECK(ncclGroupStart());
for (int64_t i = 0; i < parallel_num; ++i) {
if (this->has_input() && send_elem_cnts.at(i) != 0) {
OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i,
comm, cuda_stream));
}
if (this->has_output() && recv_elem_cnts.at(i) != 0) {
OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type),
i, comm, cuda_stream));
}

if (this->has_input() && this->has_output()) {
std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
ctx->stream()->device_type(), data_type, data_type, parallel_num);
void* send_buf = reinterpret_cast<void*>(buf_ptr);
void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);
all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),
recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm);
}
OF_NCCL_CHECK(ncclGroupEnd());

if (!this->has_output()) { return; }
Blob* out = ctx->BnInOp2Blob("out");
const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class CpuAllGather final : public AllGather {
CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc()));
}

  // Comm-handle overload: the CPU all-gather has no communicator object to
  // drive, so this entry point is intentionally unsupported and aborts.
  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
              ccl::CclComm ccl_comm) const override {
    UNIMPLEMENTED();
  }

private:
DataType datatype_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class CpuAllReduce final : public AllReduce {
cpu_communication_ctx->parallel_desc()));
}

  // Comm-handle overload: the CPU all-reduce has no communicator object to
  // drive, so this entry point is intentionally unsupported and aborts.
  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
              ccl::CclComm ccl_comm) const override {
    UNIMPLEMENTED();
  }

private:
DataType datatype_;
ReduceType reduce_type_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class CpuRecvImpl final : public Recv {
CHECK_JUST(CpuRecv(out, buffer_size, src));
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
CclComm ccl_comm) const override {
Flowingsun007 marked this conversation as resolved.
Show resolved Hide resolved
Launch(stream, out, elem_cnt, src);
}

private:
size_t size_of_dtype_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class CpuReduceScatter final : public ReduceScatter {
cpu_communication_ctx->parallel_desc()));
}

  // Comm-handle overload: the CPU reduce-scatter has no communicator object
  // to drive, so this entry point is intentionally unsupported and aborts.
  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
              ccl::CclComm ccl_comm) const override {
    UNIMPLEMENTED();
  }

private:
DataType datatype_;
ReduceType reduce_type_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class CpuSendImpl final : public Send {
CHECK_JUST(CpuSend(in, buffer_size, dst));
}

  // Comm-handle overload: the CPU transport does not use a communicator
  // object, so delegate to the plain overload and ignore `comm`.
  void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,
              CclComm comm) const override {
    Launch(stream, in, elem_cnt, dst);
  }

private:
size_t size_of_dtype_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ class CudaAllGather final : public AllGather {
stream->As<ep::CudaStream>()->cuda_stream()));
}

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ class CudaAllReduce final : public AllReduce {
stream->As<ep::CudaStream>()->cuda_stream()));
}

  // Comm-handle overload: runs ncclAllReduce on the communicator carried by
  // `ccl_comm` (an adapted ncclComm_t), on the stream's CUDA stream, using
  // the dtype and reduce op configured at Init time.
  void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
              ccl::CclComm ccl_comm) const override {
    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
    OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,
                                stream->As<ep::CudaStream>()->cuda_stream()));
  }

private:
ncclDataType_t nccl_datatype_;
ncclRedOp_t nccl_reduce_op_;
Expand Down
107 changes: 107 additions & 0 deletions oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"
#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/common/device_type.h"

namespace oneflow {

namespace ccl {

class CudaAllToAll final : public AllToAll {
public:
OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll);
CudaAllToAll()
: send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {}
~CudaAllToAll() = default;

void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override {
this->send_dtype_ = send_dtype;
this->recv_dtype_ = recv_dtype;
this->nccl_send_dtype_ = GetNcclDataType(send_dtype);
this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype);
this->rank_count_ = parallel_num;
}

void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
int64_t send_offset = 0;
int64_t recv_offset = 0;
OF_NCCL_CHECK(ncclGroupStart());
for (int64_t i = 0; i < this->rank_count_; ++i) {
if (send_count > 0) {
char* send_ptr = static_cast<char*>(send) + send_offset;
OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}
send_offset += send_count * GetSizeOfDataType(this->send_dtype_);
if (recv_count) {
char* recv_ptr = static_cast<char*>(recv) + recv_offset;
OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}
recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_);
}
OF_NCCL_CHECK(ncclGroupEnd());
}

void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets,
void* recv, const void* recv_counts, const void* recv_offsets,
ccl::CclComm ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
int64_t* send_counts_ptr = static_cast<int64_t*>(const_cast<void*>(send_counts));
int64_t* recv_counts_ptr = static_cast<int64_t*>(const_cast<void*>(recv_counts));
int64_t* send_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(send_offsets));
int64_t* recv_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(recv_offsets));
OF_NCCL_CHECK(ncclGroupStart());
for (int64_t i = 0; i < this->rank_count_; ++i) {
uint64_t send_offset = static_cast<uint64_t>(send_offsets_ptr[i]);
uint64_t send_count = static_cast<uint64_t>(send_counts_ptr[i]);
char* send_ptr = static_cast<char*>(send) + send_offset;
if (send_count > 0) {
OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

uint64_t recv_offset = static_cast<uint64_t>(recv_offsets_ptr[i]);
uint64_t recv_count = static_cast<uint64_t>(recv_counts_ptr[i]);
char* recv_ptr = static_cast<char*>(recv) + recv_offset;
if (recv_count > 0) {
OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}
}
OF_NCCL_CHECK(ncclGroupEnd());
}

private:
DataType send_dtype_;
DataType recv_dtype_;
ncclDataType_t nccl_send_dtype_;
ncclDataType_t nccl_recv_dtype_;
size_t rank_count_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllToAll, CudaAllToAll);

} // namespace ccl

} // namespace oneflow

#endif // WITH_CUDA
Loading
Loading