Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
Flowingsun007 committed Jan 25, 2025
1 parent c56c527 commit b014608
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CpuRecvImpl final : public Recv {
}

void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
const CclComm& ccl_comm) const override {
CclComm ccl_comm) const override {
Launch(stream, out, elem_cnt, src);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class CudaSend final : public Send {
void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
#if HAS_NCCL_SEND_RECV
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
printf("\n CudaSend >>> Launch >>> communication_ctx");
OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
#else
Expand All @@ -42,14 +43,15 @@ class CudaSend final : public Send {
}

void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,
CclComm ccl_comm) const override {
ccl::CclComm ccl_comm) const override {
#if HAS_NCCL_SEND_RECV
ncclComm_t* comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm,
stream->As<ep::CudaStream>()->cuda_stream()));
#else
UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7"
#endif // HAS_NCCL_SEND_RECV
printf("\n CudaSend >>> Launch >>> ccl::CclComm");
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication {
virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;

virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
const CclComm& ccl_comm) const = 0;
CclComm ccl_comm) const = 0;
};

inline bool IsRecvRegistered(DeviceType device_type) {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
}
~NcclLogical2DSameDim0KernelCommState() override = default;

const ccl::CclComm& ccl_comm() const {
ccl::CclComm ccl_comm() {
if (!is_init_) { Init(); }
return ccl_comm_;
}
Expand Down

0 comments on commit b014608

Please sign in to comment.