diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b459776..8b2a2e720264 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -51,9 +51,9 @@ using FlagType = uint32_t; // waiting for counter. We use alternating counter array to avoid this // possibility. struct Signal { - alignas(128) FlagType start[kMaxBlocks][8]; - alignas(128) FlagType end[kMaxBlocks][8]; - alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank + alignas(128) FlagType start[kMaxBlocks][8]; // 32*kMaxBlocks + alignas(128) FlagType end[kMaxBlocks][8]; // // 32*kMaxBlocks + alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank // 4*kMaxBlocks }; struct __align__(16) RankData { @@ -320,29 +320,30 @@ template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; // 线程的全局索引,也是packed的数据索引 + int stride = gridDim.x * blockDim.x; // 当前block的当前线程,在下一个迭代中要处理的数据索引相对于当前处理的数据索引的偏移量。 using P = typename packed_t::P; using A = typename packed_t::A; - int part = size / ngpus; - int start = rank * part; - int end = rank == ngpus - 1 ? size : start + part; - int largest_part = part + size % ngpus; - const P* ptrs[ngpus]; - P* tmps[ngpus]; + int part = size / ngpus; // 每个rank分到多少packed数据 + int start = rank * part; // 本rank的数据起点 + int end = rank == ngpus - 1 ? size : start + part;// 本rank的终点,一般是start+end, last_rank则是size + int largest_part = part + size % ngpus; // last_rank的数据量 + const P* ptrs[ngpus]; // 每个元素都是 P*,共ngpus个 + P* tmps[ngpus]; // 同上临时变量 #pragma unroll for (int i = 0; i < ngpus; i++) { - int target = (rank + i) % ngpus; + int target = (rank + i) % ngpus; // 每个rank都算,对rank0:0,1,2,3; 对rank1: 1,2,3,0; 对rank2:2,3,0,1 ptrs[i] = (const P*)_dp->ptrs[target]; - tmps[i] = get_tmp_buf

(sg.signals[target]); + tmps[i] = get_tmp_buf

(sg.signals[target]); // 取得最开始在python中申请的临时buf位,每个rank都有一个tmp_buf大小为max_size=16mb } - auto tmp_out = tmps[0]; + auto tmp_out = tmps[0]; // 上面的轮换顺序,意味着这里的tmps[0]其实就是每个rank的自己的tmp_buf barrier_at_start(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { - tmp_out[idx - start] = packed_reduce(ptrs, idx); + tmp_out[idx - start] = packed_reduce(ptrs, idx); // 因此这里的赋值,是对每个rank的tmp_buf都进行赋值 } + // 这里不同rank的tmp[0],其实是数据的不同part, barrier_at_end(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between