Dev refactor xccl primitive #10613
Merged

32 commits
c9b7811 raw impl (Flowingsun007)
7bcc593 refine (Flowingsun007)
d70c464 refine (Flowingsun007)
89411cd impl of ccl::CclComm (Flowingsun007)
9ec2bc6 refine (Flowingsun007)
a0b0391 refine (Flowingsun007)
2c48a5e refactor _nccl_logical_send_recv using ccl::Comm primitive (Flowingsun007)
aac19b4 refactor _nccl_logical_send_recv using ccl::Comm primitive (Flowingsun007)
3cb872a refactor ccl::AllGather AllReduce ReduceScatter primitive using ccl::… (Flowingsun007)
e5777b9 refactor _nccl_logical_fusion kernel using ccl::CclComm (Flowingsun007)
ad9e7ee refine (Flowingsun007)
1809053 Merge branch 'master' into dev_refactor_xccl_primitive (Flowingsun007)
a9cd8df support ccl::AllToAll (Flowingsun007)
ca258af Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/… (Flowingsun007)
15198c6 more kernels using ccl::CclComm and ccl apis (Flowingsun007)
fa3d77f refine (Flowingsun007)
81394af refine all2all (Flowingsun007)
94dafda refine (Flowingsun007)
21529c8 refine (Flowingsun007)
5636f1b refine (Flowingsun007)
8110bd2 Update oneflow/user/kernels/eager_nccl_s2s_kernel.cu (Flowingsun007)
9bb1fb8 Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp (Flowingsun007)
45fec55 Update oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp (Flowingsun007)
b369a2c Update oneflow/user/kernels/nccl_logical_kernels.cpp (Flowingsun007)
3e67ade Update oneflow/user/kernels/collective_communication/include/recv.h (Flowingsun007)
269dd3e Update oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp (Flowingsun007)
510524a refactor GetCclCommForParallelDesc series functions (Flowingsun007)
c56c527 Merge branch 'dev_refactor_xccl_primitive' of github.com:Oneflow-Inc/… (Flowingsun007)
b014608 refine (Flowingsun007)
0dc6cbc refactor const ccl::CclComm& (Flowingsun007)
65e3046 refactor ccl::AllToAll (Flowingsun007)
0a883a7 Merge branch 'master' into dev_refactor_xccl_primitive (Flowingsun007)
oneflow/user/kernels/collective_communication/cuda/cuda_all_to_all.cpp (107 additions, 0 deletions)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/collective_communication/include/all_to_all.h"
#include "oneflow/user/kernels/collective_communication/cuda/cuda_communication_context.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/common/device_type.h"

namespace oneflow {

namespace ccl {

class CudaAllToAll final : public AllToAll {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaAllToAll);
  CudaAllToAll()
      : send_dtype_(), recv_dtype_(), nccl_send_dtype_(), nccl_recv_dtype_(), rank_count_(0) {}
  ~CudaAllToAll() = default;

  void Init(DataType send_dtype, DataType recv_dtype, size_t parallel_num) override {
    this->send_dtype_ = send_dtype;
    this->recv_dtype_ = recv_dtype;
    this->nccl_send_dtype_ = GetNcclDataType(send_dtype);
    this->nccl_recv_dtype_ = GetNcclDataType(recv_dtype);
    this->rank_count_ = parallel_num;
  }

  // Fixed-count all-to-all: exchange send_count/recv_count elements with every
  // peer rank inside one ncclGroupStart/ncclGroupEnd group.
  void Launch(ep::Stream* stream, void* send, int64_t send_count, void* recv, int64_t recv_count,
              ccl::CclComm ccl_comm) const override {
    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
    int64_t send_offset = 0;
    int64_t recv_offset = 0;
    OF_NCCL_CHECK(ncclGroupStart());
    for (int64_t i = 0; i < this->rank_count_; ++i) {
      if (send_count > 0) {
        char* send_ptr = static_cast<char*>(send) + send_offset;
        OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
                               stream->As<ep::CudaStream>()->cuda_stream()));
      }
      send_offset += send_count * GetSizeOfDataType(this->send_dtype_);
      if (recv_count > 0) {
        char* recv_ptr = static_cast<char*>(recv) + recv_offset;
        OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
                               stream->As<ep::CudaStream>()->cuda_stream()));
      }
      recv_offset += recv_count * GetSizeOfDataType(this->recv_dtype_);
    }
    OF_NCCL_CHECK(ncclGroupEnd());
  }

  // Variable-count all-to-all: per-rank element counts and byte offsets come
  // from the caller-supplied arrays; peers with a zero count are skipped.
  void Launch(ep::Stream* stream, void* send, const void* send_counts, const void* send_offsets,
              void* recv, const void* recv_counts, const void* recv_offsets,
              ccl::CclComm ccl_comm) const override {
    ncclComm_t* nccl_comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
    int64_t* send_counts_ptr = static_cast<int64_t*>(const_cast<void*>(send_counts));
    int64_t* recv_counts_ptr = static_cast<int64_t*>(const_cast<void*>(recv_counts));
    int64_t* send_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(send_offsets));
    int64_t* recv_offsets_ptr = static_cast<int64_t*>(const_cast<void*>(recv_offsets));
    OF_NCCL_CHECK(ncclGroupStart());
    for (int64_t i = 0; i < this->rank_count_; ++i) {
      uint64_t send_offset = static_cast<uint64_t>(send_offsets_ptr[i]);
      uint64_t send_count = static_cast<uint64_t>(send_counts_ptr[i]);
      char* send_ptr = static_cast<char*>(send) + send_offset;
      if (send_count > 0) {
        OF_NCCL_CHECK(ncclSend(send_ptr, send_count, this->nccl_send_dtype_, i, *nccl_comm,
                               stream->As<ep::CudaStream>()->cuda_stream()));
      }

      uint64_t recv_offset = static_cast<uint64_t>(recv_offsets_ptr[i]);
      uint64_t recv_count = static_cast<uint64_t>(recv_counts_ptr[i]);
      char* recv_ptr = static_cast<char*>(recv) + recv_offset;
      if (recv_count > 0) {
        OF_NCCL_CHECK(ncclRecv(recv_ptr, recv_count, this->nccl_recv_dtype_, i, *nccl_comm,
                               stream->As<ep::CudaStream>()->cuda_stream()));
      }
    }
    OF_NCCL_CHECK(ncclGroupEnd());
  }

 private:
  DataType send_dtype_;
  DataType recv_dtype_;
  ncclDataType_t nccl_send_dtype_;
  ncclDataType_t nccl_recv_dtype_;
  size_t rank_count_;
};

REGISTER_COLLECTIVE_COMMUNICATION(DeviceType::kCUDA, AllToAll, CudaAllToAll);

}  // namespace ccl

}  // namespace oneflow

#endif  // WITH_CUDA
Review comment: The parameter here should be a parallel_desc; the device_set is the form CUDA needs, and building the device_set from the parallel_desc should happen in the derived class.

Reply (Flowingsun007): If we uniformly pass a ParallelDesc, the 1D/2D cases are hard to handle, aren't they? 😂 For example, in oneflow/user/kernels/nccl_logical_fusion_kernel.cpp the creation of the device_set depends on hierarchy.NumAxes() and also needs comm_key_.

Reply (Flowingsun007): Fixed.
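The thread above is about keying communicators on more than the device set: the same devices can need distinct communicators for 1D vs. 2D hierarchies, which is why the fusion kernel also folds comm_key_ into the lookup. A minimal sketch of that caching scheme, with entirely hypothetical names (CommCache, GetOrCreate) standing in for OneFlow's actual comm-management code:

```cpp
#include <map>
#include <string>

// Hypothetical communicator cache: lookups combine the device-set description
// with an extra comm key (e.g. derived from hierarchy.NumAxes()), so 1D and 2D
// layouts over the same devices get separate communicators. The int id stands
// in for an expensive real initialization such as ncclCommInitRank.
class CommCache {
 public:
  int GetOrCreate(const std::string& device_set, const std::string& comm_key) {
    const std::string key = device_set + "|" + comm_key;
    auto it = cache_.find(key);
    if (it != cache_.end()) { return it->second; }  // reuse cached communicator
    const int comm_id = next_id_++;                 // create-once on first use
    cache_.emplace(key, comm_id);
    return comm_id;
  }
  size_t size() const { return cache_.size(); }

 private:
  std::map<std::string, int> cache_;
  int next_id_ = 0;
};
```

With this shape, requesting ("0:0-3", "1D") and ("0:0-3", "2D") creates two distinct communicators, while a repeated ("0:0-3", "1D") request hits the cache.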