diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh index 3d19cbfcf2..e5274b2060 100644 --- a/transformer_engine/common/util/standalone_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -181,11 +181,6 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - // TODO: it's UB - union { - WideT scalar; - T array[items_per_scalar]; // NOLINT(runtime/arrays) - } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) @@ -198,11 +193,13 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const const idxT len_cast = (len - skip_cnt) / items_per_scalar; for (idxT i = thread_rank; i < len_cast; i += num_threads) { - wide.scalar = in_cast[i]; + const WideT wide_data = in_cast[i]; + T local_array[items_per_scalar]; // NOLINT(runtime/arrays) + __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j); + f(local_array[j], real_i + j); } } @@ -236,10 +233,6 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); - union { - WideT scalar; - T array[items_per_scalar]; // NOLINT(runtime/arrays) - } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) @@ -251,16 +244,24 @@ __device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width const WideT *in_cast = reinterpret_cast(in + skip_cnt); const idxT len_cast = (len - skip_cnt) / items_per_scalar; - const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; - for (idxT i = tid; i < len_cast_for_sync; i += stride) { - bool valid = i < len_cast; - if (valid) { - wide.scalar = in_cast[i]; - } - const idxT real_i = skip_cnt + i * items_per_scalar; + // Skip when no full vector chunk exists: avoids len_cast_for_sync underflow and + // OOB companion reads (in_cast[0] needs at least one valid WideT). + if (len_cast > 0) { + const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (idxT i = tid; i < len_cast_for_sync; i += stride) { + const bool valid = i < len_cast; + // Unconditional 128-bit vector load: invalid threads read in_cast[0] (cached, + // discarded via valid=false) so NVCC emits LDG.E.128 instead of predicated load. + // Index clamping (not pointer ternary) avoids C++ UB from &in_cast[i] when i >= len_cast. + const idxT safe_i = valid ? i : static_cast(0); + const WideT wide_data = in_cast[safe_i]; + T local_array[items_per_scalar]; // NOLINT(runtime/arrays) + __builtin_memcpy(local_array, &wide_data, sizeof(WideT)); + const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll - for (int j = 0; j < items_per_scalar; ++j) { - f(wide.array[j], real_i + j, valid); + for (int j = 0; j < items_per_scalar; ++j) { + f(local_array[j], real_i + j, valid); + } } }