Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ using FlagType = uint32_t;
// waiting for counter. We use alternating counter array to avoid this
// possibility.
struct Signal {
// Flag storage. Each array is declared alignas(128), so each one begins on a
// 128-byte boundary. FlagType is uint32_t (4 bytes).
// NOTE(review): every member appears twice in this view -- the lines look
// like the before/after sides of a diff rendered without +/- markers. A real
// struct can declare each member only once; confirm against the actual file.
alignas(128) FlagType start[kMaxBlocks][8];
alignas(128) FlagType end[kMaxBlocks][8];
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
alignas(128) FlagType start[kMaxBlocks][8]; // 8 flags * 4 bytes = 32 bytes per block, 32 * kMaxBlocks bytes total
alignas(128) FlagType end[kMaxBlocks][8]; // same layout as start: 32 * kMaxBlocks bytes total
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank; 4 * kMaxBlocks bytes total
};

struct __align__(16) RankData {
Expand Down Expand Up @@ -320,29 +320,30 @@ template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x; // 线程的全局索引,也是packed的数据索引
int stride = gridDim.x * blockDim.x; // 当前block的当前线程,在下一个迭代中要处理的数据索引相对于当前处理的数据索引的偏移量。
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
int part = size / ngpus; // 每个rank分到多少packed数据
int start = rank * part; // 本rank的数据起点
int end = rank == ngpus - 1 ? size : start + part;// 本rank的终点,一般是start+end, last_rank则是size
int largest_part = part + size % ngpus; // last_rank的数据量
const P* ptrs[ngpus]; // 每个元素都是 P*,共ngpus个
P* tmps[ngpus]; // 同上临时变量
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
int target = (rank + i) % ngpus; // 每个rank都算,对rank0:0,1,2,3; 对rank1: 1,2,3,0; 对rank2:2,3,0,1
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
tmps[i] = get_tmp_buf<P>(sg.signals[target]); // 取得最开始在python中申请的临时buf位,每个rank都有一个tmp_buf大小为max_size=16mb
}
auto tmp_out = tmps[0];
auto tmp_out = tmps[0]; // 上面的轮换顺序,意味着这里的tmps[0]其实就是每个rank的自己的tmp_buf
barrier_at_start<ngpus>(sg, self_sg, rank);

// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx); // 因此这里的赋值,是对每个rank的tmp_buf都进行赋值
}
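// At this point tmps[0] on each rank (the buffer obtained from sg.signals[rank]) holds the reduced values for that rank's own part of the data only.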
barrier_at_end<ngpus>(sg, self_sg, rank);

// stage 2: allgather. Note: it's important to match the tid between
Expand Down