Dev refactor xccl primitive #10613

Merged 32 commits from dev_refactor_xccl_primitive into master on Jan 26, 2025.

Commits (32, all by Flowingsun007; the diff below shows changes from 30 of them):
c9b7811  raw impl (Dec 27, 2024)
7bcc593  refine (Dec 27, 2024)
d70c464  refine (Jan 1, 2025)
89411cd  impl of ccl::CclComm (Jan 2, 2025)
9ec2bc6  refine (Jan 2, 2025)
a0b0391  refine (Jan 2, 2025)
2c48a5e  refactor _nccl_logical_send_recv using ccl::Comm primitive (Jan 2, 2025)
aac19b4  refactor _nccl_logical_send_recv using ccl::Comm primitive (Jan 2, 2025)
3cb872a  refactor ccl::AllGather AllReduce ReduceScatter primitive using ccl::… (Jan 2, 2025)
e5777b9  refactor _nccl_logical_fusion kernel using ccl::CclComm (Jan 2, 2025)
ad9e7ee  refine (Jan 3, 2025)
1809053  Merge branch 'master' into dev_refactor_xccl_primitive (Jan 3, 2025)
a9cd8df  support ccl::AllToAll (Jan 9, 2025)
ca258af  Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/… (Jan 9, 2025)
15198c6  more kernels using ccl::CclComm and ccl apis (Jan 9, 2025)
fa3d77f  refine (Jan 10, 2025)
81394af  refine all2all (Jan 10, 2025)
94dafda  refine (Jan 10, 2025)
21529c8  refine (Jan 14, 2025)
5636f1b  refine (Jan 22, 2025)
8110bd2  Update oneflow/user/kernels/eager_nccl_s2s_kernel.cu (Jan 24, 2025)
9bb1fb8  Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp (Jan 24, 2025)
45fec55  Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp (Jan 24, 2025)
b369a2c  Update oneflow/user/kernels/nccl_logical_kernels.cpp (Jan 24, 2025)
3e67ade  Update oneflow/user/kernels/collective_communication/include/recv.h (Jan 24, 2025)
269dd3e  Update oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp (Jan 24, 2025)
510524a  refactor GetCclCommForParallelDesc series functions (Jan 25, 2025)
c56c527  Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/… (Jan 25, 2025)
b014608  refine (Jan 25, 2025)
0dc6cbc  refactor const ccl::CclComm& (Jan 26, 2025)
65e3046  refactor ccl::AllToAll (Jan 26, 2025)
0a883a7  Merge branch 'master' into dev_refactor_xccl_primitive (Jan 26, 2025)
8 changes: 8 additions & 0 deletions oneflow/core/job/eager_ccl_comm_manager.h
@@ -18,6 +18,7 @@ limitations under the License.

#include "oneflow/core/common/util.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/user/kernels/collective_communication/include/collective_communication.h"

namespace oneflow {

@@ -29,6 +30,13 @@ class EagerCclCommMgr {
virtual void CreateCommFromPlan(const Plan& plan) = 0;
virtual bool IsAsyncLaunchCclLogicalKernel() const = 0;
virtual void SetAsyncLaunchCclLogicalKernel(bool val) = 0;
virtual ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) = 0;
virtual ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,
const std::string& stream_name) = 0;
virtual ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,
const std::string& stream_name,
const int64_t this_parallel_id,
const std::string& comm_key) = 0;

template<typename T>
T* As() {
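For context, these three accessors are intended to be called from kernels through the type-erased EagerCclCommMgr singleton. A minimal sketch of the intended call pattern, mirroring the call sites later in this diff (variable names such as parallel_desc, stream_name_ and this_parallel_id stand for whatever the surrounding kernel already has; error handling elided):

    // Inside a kernel's lazy Init(): fetch the device-agnostic communicator once.
    EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());

    // Communicator over the whole placement:
    ccl::CclComm comm = comm_mgr->GetCclCommForParallelDesc(parallel_desc);

    // Communicator bound to a named stream (as used by NcclSendRecvBoxingKernel below):
    ccl::CclComm stream_comm =
        comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_);

    // Communicator over one group of a 2D hierarchy ("SameDim0" or "SameDim1"):
    ccl::CclComm group_comm = comm_mgr->GetCclCommForParallelDescNdHierarchy(
        parallel_desc, stream_name_, this_parallel_id, /*comm_key=*/"SameDim0");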
84 changes: 84 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.cpp
@@ -156,6 +156,90 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
return comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) {
std::set<std::pair<int64_t, int64_t>> device_set;
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}

ncclComm_t comm = GetCommForDevice(device_set);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescAndStreamName(
const ParallelDesc& parallel_desc, const std::string& stream_name) {
std::set<std::pair<int64_t, int64_t>> device_set;
FOR_RANGE(int64_t, parallel_id, 0, parallel_desc.parallel_num()) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}

ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

ccl::CclComm EagerNcclCommMgr::GetCclCommForParallelDescNdHierarchy(
const ParallelDesc& parallel_desc, const std::string& stream_name,
const int64_t this_parallel_id, const std::string& comm_key) {
std::set<std::pair<int64_t, int64_t>> device_set;
const Shape& hierarchy = *parallel_desc.hierarchy();
CHECK_LE(hierarchy.NumAxes(), 2);

// 1D
if (hierarchy.NumAxes() == 1) {
// 1D hierarchy
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else if (hierarchy.NumAxes() == 2) {
// 2D hierarchy
CHECK(comm_key == "SameDim0" || comm_key == "SameDim1");
if (comm_key == "SameDim0") {
const int64_t num_groups = hierarchy.At(0);
const int64_t group_size = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id / group_size * group_size;
CHECK_EQ(this_group_begin_parallel_id % group_size, 0);
CHECK_LE(this_group_begin_parallel_id + group_size, parallel_desc.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + id_in_group;
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else if (comm_key == "SameDim1") {
const int64_t group_size = hierarchy.At(0);
const int64_t num_groups = hierarchy.At(1);
CHECK_EQ(num_groups * group_size, parallel_desc.parallel_num());
const int64_t this_group_begin_parallel_id = this_parallel_id % num_groups;
CHECK_LT(this_group_begin_parallel_id + (group_size - 1) * num_groups,
parallel_desc.parallel_num());
for (int64_t id_in_group = 0; id_in_group < group_size; ++id_in_group) {
const int64_t parallel_id = this_group_begin_parallel_id + (id_in_group * num_groups);
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
} else {
UNIMPLEMENTED();
}
}

ncclComm_t comm = GetCommForDeviceAndStreamName(device_set, stream_name);
std::shared_ptr<ccl::CommBase> ncclCommAdapter = std::make_shared<ccl::NcclCommAdapter>(comm);
ccl::CclComm ccl_comm(ncclCommAdapter);
return ccl_comm;
}

void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) {
const int64_t rank = GlobalProcessCtx::Rank();
const int64_t dev = GlobalProcessCtx::LocalRank();
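To make the 2D branch of GetCclCommForParallelDescNdHierarchy concrete, the following standalone sketch reproduces only its index arithmetic for a hypothetical hierarchy (2, 4) with parallel ids 0..7 (the machine/device lookups and the communicator creation are omitted):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Same arithmetic as the SameDim0 / SameDim1 branches above.
    std::vector<int64_t> GroupMembers(int64_t dim0, int64_t dim1, int64_t this_parallel_id,
                                      const std::string& comm_key) {
      std::vector<int64_t> ids;
      if (comm_key == "SameDim0") {
        const int64_t group_size = dim1;  // hierarchy.At(1)
        const int64_t begin = this_parallel_id / group_size * group_size;
        for (int64_t i = 0; i < group_size; ++i) { ids.push_back(begin + i); }
      } else {  // "SameDim1"
        const int64_t group_size = dim0;   // hierarchy.At(0)
        const int64_t num_groups = dim1;   // hierarchy.At(1)
        const int64_t begin = this_parallel_id % num_groups;
        for (int64_t i = 0; i < group_size; ++i) { ids.push_back(begin + i * num_groups); }
      }
      return ids;
    }

    int main() {
      // Hierarchy (2, 4), this_parallel_id = 5:
      //   SameDim0 -> {4, 5, 6, 7}  (the dim0 row containing rank 5)
      //   SameDim1 -> {1, 5}        (the dim1 column containing rank 5)
      for (const char* key : {"SameDim0", "SameDim1"}) {
        std::cout << key << ":";
        for (int64_t id : GroupMembers(2, 4, 5, key)) { std::cout << " " << id; }
        std::cout << "\n";
      }
      return 0;
    }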
20 changes: 20 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.h
@@ -25,6 +25,19 @@ limitations under the License.
#include "oneflow/core/device/cuda_util.h"

namespace oneflow {
namespace ccl {

class NcclCommAdapter : public CommBase {
public:
explicit NcclCommAdapter(ncclComm_t comm) : comm_(comm) {}

void* getComm() const override { return const_cast<void*>(static_cast<const void*>(&comm_)); }

private:
ncclComm_t comm_;
};

} // namespace ccl

class EagerNcclCommMgr final : public EagerCclCommMgr {
public:
@@ -36,6 +49,13 @@ class EagerNcclCommMgr final : public EagerCclCommMgr {
ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name);
ccl::CclComm GetCclCommForParallelDesc(const ParallelDesc& parallel_desc) override;
ccl::CclComm GetCclCommForParallelDescAndStreamName(const ParallelDesc& parallel_desc,
const std::string& stream_name) override;
ccl::CclComm GetCclCommForParallelDescNdHierarchy(const ParallelDesc& parallel_desc,
const std::string& stream_name,
const int64_t this_parallel_id,
const std::string& comm_key) override;

void CreateCommFromPlan(const Plan& plan) override;
bool IsAsyncLaunchCclLogicalKernel() const override { return async_launch_nccl_logical_kernel_; }
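The NcclCommAdapter above is what makes the type-erased ccl::CclComm handle usable by the CUDA backends: getComm() hands out a void* that points at the stored ncclComm_t, and the collective implementations cast it back (see the CudaAllGather and CudaAllReduce hunks below). A condensed sketch of the round trip, using only names that appear in this diff:

    // Producer side, inside EagerNcclCommMgr: wrap the raw NCCL communicator.
    ncclComm_t comm = GetCommForDevice(device_set);
    std::shared_ptr<ccl::CommBase> adapter = std::make_shared<ccl::NcclCommAdapter>(comm);
    ccl::CclComm ccl_comm(adapter);

    // Consumer side, inside a CUDA collective's Launch(): unwrap it again.
    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
    OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,
                                stream->As<ep::CudaStream>()->cuda_stream()));

Because getComm() returns the address of the adapter's comm_ member, the pointer stays valid for as long as the CclComm constructed from that shared_ptr is alive.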
55 changes: 26 additions & 29 deletions oneflow/core/kernel/nccl_send_recv_boxing_kernel.cpp
@@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/memset.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/operator/nccl_send_recv_boxing_op_util.h"
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"

#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700

@@ -41,39 +42,33 @@ class NcclSendRecvBoxingKernel final : public Kernel {
const std::vector<int64_t>& recv_elem_cnts() const { return recv_elem_cnts_; }
const bool has_input() const { return has_input_; }
const bool has_output() const { return has_output_; }
ncclComm_t comm() const { return GetOrCreate().comm; }
ccl::CclComm ccl_comm() const { return GetOrCreate().ccl_comm; }

private:
struct Comm {
Comm(ncclComm_t comm) : comm(comm) {}
ncclComm_t comm;
Comm(ccl::CclComm comm) : ccl_comm(comm) {}
ccl::CclComm ccl_comm;
};

void Init() const {
ParallelDesc parallel_desc(parallel_conf_);
std::set<std::pair<int64_t, int64_t>> device_set;
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerCclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerCclCommMgr>::Get());
ncclComm_t comm =
comm_mgr->As<EagerNcclCommMgr>()->GetCommForDeviceAndStreamName(device_set, stream_name_);
comm_.reset(new Comm(comm));
ccl::CclComm ccl_comm =
comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name_);
ccl_comm_.reset(new Comm(ccl_comm));
}

const Comm& GetOrCreate() const {
if (!comm_) { Init(); }
return *comm_;
if (!ccl_comm_) { Init(); }
return *ccl_comm_;
}

void VirtualKernelInit(KernelContext* ctx) override;
void ForwardDataContent(KernelContext* ctx) const override;

std::string stream_name_;
ParallelConf parallel_conf_;
mutable std::unique_ptr<Comm> comm_;
mutable std::unique_ptr<Comm> ccl_comm_;
bool src_nd_sbp_no_partial_parallel_;
std::vector<std::shared_ptr<TensorSliceCopier>> in_tensor_slice_copier_vec_;
std::vector<std::shared_ptr<TensorSliceCopier>> out_tensor_slice_copier_vec_;
@@ -85,27 +80,31 @@ class NcclSendRecvBoxingKernel final : public Kernel {

void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
Blob* buf = ctx->BnInOp2Blob("buf");
ncclComm_t comm = this->comm();
cudaStream_t cuda_stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();
ccl::CclComm ccl_comm = this->ccl_comm();
const std::vector<int64_t>& send_elem_cnts = this->send_elem_cnts();
const std::vector<int64_t>& recv_elem_cnts = this->recv_elem_cnts();
const int64_t parallel_num = this->kernel_conf().parallel_ctx().parallel_num();
const DataType data_type = buf->data_type();
std::vector<void*> send_in_ptr;
std::vector<void*> recv_out_ptr;
std::vector<int64_t> send_offsets;
std::vector<int64_t> recv_offsets;
char* buf_ptr = buf->mut_dptr<char>();
int64_t offset = 0;
uint64_t offset = 0;
if (this->has_input()) {
for (int64_t i = 0; i < parallel_num; ++i) {
void* send_ptr = reinterpret_cast<void*>(buf_ptr + offset);
send_in_ptr.push_back(send_ptr);
send_offsets.push_back(offset);
offset += send_elem_cnts.at(i) * GetSizeOfDataType(data_type);
}
}
const uint64_t recv_offset = offset;
if (this->has_output()) {
for (int64_t i = 0; i < parallel_num; ++i) {
void* recv_ptr = reinterpret_cast<void*>(buf_ptr + offset);
recv_out_ptr.push_back(recv_ptr);
recv_offsets.push_back(offset - recv_offset);
offset += recv_elem_cnts.at(i) * GetSizeOfDataType(data_type);
}
}
@@ -119,18 +118,16 @@ void NcclSendRecvBoxingKernel::ForwardDataContent(KernelContext* ctx) const {
}
}
}
OF_NCCL_CHECK(ncclGroupStart());
for (int64_t i = 0; i < parallel_num; ++i) {
if (this->has_input() && send_elem_cnts.at(i) != 0) {
OF_NCCL_CHECK(ncclSend(send_in_ptr.at(i), send_elem_cnts.at(i), GetNcclDataType(data_type), i,
comm, cuda_stream));
}
if (this->has_output() && recv_elem_cnts.at(i) != 0) {
OF_NCCL_CHECK(ncclRecv(recv_out_ptr.at(i), recv_elem_cnts.at(i), GetNcclDataType(data_type),
i, comm, cuda_stream));
}

if (this->has_input() && this->has_output()) {
std::unique_ptr<ccl::AllToAll> all_to_all = ccl::NewCollectiveCommunication<ccl::AllToAll>(
ctx->stream()->device_type(), data_type, data_type, parallel_num);
void* send_buf = reinterpret_cast<void*>(buf_ptr);
void* recv_buf = reinterpret_cast<void*>(buf_ptr + recv_offset);
all_to_all->Launch(ctx->stream(), send_buf, send_elem_cnts.data(), send_offsets.data(),
recv_buf, recv_elem_cnts.data(), recv_offsets.data(), ccl_comm);
}
OF_NCCL_CHECK(ncclGroupEnd());

if (!this->has_output()) { return; }
Blob* out = ctx->BnInOp2Blob("out");
const std::vector<std::shared_ptr<TensorSliceCopier>>& out_tensor_slice_copier_vec =
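A note on the rewritten ForwardDataContent above: the single "buf" blob holds all per-rank send chunks followed by all per-rank recv chunks, and the offsets passed to ccl::AllToAll::Launch are byte offsets relative to the start of their own region (send offsets from buf_ptr, recv offsets from buf_ptr + recv_offset). A toy calculation with assumed numbers (2 ranks, 4-byte elements, send counts {3, 5}, recv counts {4, 2}) shows what the two loops compute:

    #include <cstdint>
    #include <vector>

    int main() {
      const std::vector<int64_t> send_elem_cnts = {3, 5};  // assumed per-rank send counts
      const std::vector<int64_t> recv_elem_cnts = {4, 2};  // assumed per-rank recv counts
      const int64_t dtype_size = 4;                        // e.g. float

      std::vector<int64_t> send_offsets;
      std::vector<int64_t> recv_offsets;
      uint64_t offset = 0;
      for (int64_t cnt : send_elem_cnts) {     // send region starts at buf_ptr
        send_offsets.push_back(offset);        // -> {0, 12}
        offset += cnt * dtype_size;
      }
      const uint64_t recv_offset = offset;     // recv region starts at buf_ptr + 32
      for (int64_t cnt : recv_elem_cnts) {
        recv_offsets.push_back(offset - recv_offset);  // -> {0, 16}, relative to the recv region
        offset += cnt * dtype_size;
      }
      return 0;  // total buffer bytes consumed: 56
    }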
@@ -95,6 +95,11 @@ class CpuAllGather final : public AllGather {
CHECK_JUST(AllGatherImpl(in, out, elem_cnt, datatype_, cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const ccl::CclComm& ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
};
@@ -148,6 +148,11 @@ class CpuAllReduce final : public AllReduce {
cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const ccl::CclComm& ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
ReduceType reduce_type_;
@@ -40,6 +40,11 @@ class CpuRecvImpl final : public Recv {
CHECK_JUST(CpuRecv(out, buffer_size, src));
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
const ccl::CclComm& ccl_comm) const override {
Launch(stream, out, elem_cnt, src);
}

private:
size_t size_of_dtype_;
};
@@ -120,6 +120,11 @@ class CpuReduceScatter final : public ReduceScatter {
cpu_communication_ctx->parallel_desc()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const ccl::CclComm& ccl_comm) const override {
UNIMPLEMENTED();
}

private:
DataType datatype_;
ReduceType reduce_type_;
@@ -40,6 +40,11 @@ class CpuSendImpl final : public Send {
CHECK_JUST(CpuSend(in, buffer_size, dst));
}

void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,
const ccl::CclComm& comm) const override {
Launch(stream, in, elem_cnt, dst);
}

private:
size_t size_of_dtype_;
};
@@ -40,6 +40,13 @@ class CudaAllGather final : public AllGather {
stream->As<ep::CudaStream>()->cuda_stream()));
}

virtual void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const ccl::CclComm& ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclAllGather(in, out, elem_cnt, nccl_datatype_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
};
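For completeness, the new Launch overload above is meant to pair with a communicator obtained from EagerCclCommMgr rather than a raw ncclComm_t. A rough sketch of the intended wiring, by analogy with the AllToAll call in NcclSendRecvBoxingKernel above (the AllGather factory arguments are assumed here, not taken from this diff, and may differ from the real NewCollectiveCommunication signature):

    std::unique_ptr<ccl::AllGather> all_gather = ccl::NewCollectiveCommunication<ccl::AllGather>(
        ctx->stream()->device_type(), data_type);  // assumed factory arguments
    ccl::CclComm ccl_comm =
        comm_mgr->GetCclCommForParallelDescAndStreamName(parallel_desc, stream_name);
    all_gather->Launch(ctx->stream(), in_ptr, out_ptr, elem_cnt, ccl_comm);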
@@ -57,6 +57,13 @@ class CudaAllReduce final : public AllReduce {
stream->As<ep::CudaStream>()->cuda_stream()));
}

void Launch(ep::Stream* stream, const void* in, void* out, size_t elem_cnt,
const ccl::CclComm& ccl_comm) const override {
ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclAllReduce(in, out, elem_cnt, nccl_datatype_, nccl_reduce_op_, *nccl_comm,
stream->As<ep::CudaStream>()->cuda_stream()));
}

private:
ncclDataType_t nccl_datatype_;
ncclRedOp_t nccl_reduce_op_;