Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions transformer_engine/common/util/standalone_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,6 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
// TODO: it's UB
union {
WideT scalar;
T array[items_per_scalar]; // NOLINT(runtime/arrays)
} wide;

int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
Expand All @@ -198,11 +193,13 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const
const idxT len_cast = (len - skip_cnt) / items_per_scalar;

for (idxT i = thread_rank; i < len_cast; i += num_threads) {
wide.scalar = in_cast[i];
const WideT wide_data = in_cast[i];
T local_array[items_per_scalar]; // NOLINT(runtime/arrays)
__builtin_memcpy(local_array, &wide_data, sizeof(WideT));
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j);
f(local_array[j], real_i + j);
}
}

Expand Down Expand Up @@ -236,10 +233,6 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width
} else {
static_assert(sizeof(WideT) % sizeof(T) == 0);
constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
union {
WideT scalar;
T array[items_per_scalar]; // NOLINT(runtime/arrays)
} wide;

int skip_cnt =
(reinterpret_cast<size_t>(in) % sizeof(WideT))
Expand All @@ -251,16 +244,24 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width
const WideT *in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;

const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (idxT i = tid; i < len_cast_for_sync; i += stride) {
bool valid = i < len_cast;
if (valid) {
wide.scalar = in_cast[i];
}
const idxT real_i = skip_cnt + i * items_per_scalar;
// Skip when no full vector chunk exists: avoids len_cast_for_sync underflow and
// OOB companion reads (in_cast[0] needs at least one valid WideT).
if (len_cast > 0) {
const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (idxT i = tid; i < len_cast_for_sync; i += stride) {
const bool valid = i < len_cast;
// Unconditional 128-bit vector load: invalid threads read in_cast[0] (cached,
// discarded via valid=false) so NVCC emits LDG.E.128 instead of predicated load.
// Safe because len_cast > 0 guarantees at least one valid WideT at in_cast[0].
const WideT *load_ptr = valid ? &in_cast[i] : &in_cast[0];
const WideT wide_data = *load_ptr;
Comment thread
solos marked this conversation as resolved.
Outdated
T local_array[items_per_scalar]; // NOLINT(runtime/arrays)
__builtin_memcpy(local_array, &wide_data, sizeof(WideT));
const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j) {
f(wide.array[j], real_i + j, valid);
for (int j = 0; j < items_per_scalar; ++j) {
f(local_array[j], real_i + j, valid);
}
}
}

Expand Down