Skip to content
Merged
70 changes: 70 additions & 0 deletions transformer_engine/common/util/padding.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
Expand All @@ -13,6 +15,9 @@

#include "../common.h"
#include "../utils.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include "rocm_device_utils.cuh" // for rocm_upper_bound()
#endif

namespace transformer_engine {

Expand Down Expand Up @@ -65,15 +70,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int padded_num_rows = args.padded_num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -83,6 +95,35 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
Comment thread
alextmagro marked this conversation as resolved.
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
const int remaining = row_length - col;
Comment thread
aris134 marked this conversation as resolved.
Outdated
if (row < num_rows) {
// Valid data row: skip copy when in-place
if (!inplace) {
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0);
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
} else if (row < padded_num_rows) {
// Padding row: fill with zeros
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.clear();
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -125,6 +166,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

template <int nvec, typename Type>
Expand All @@ -150,14 +192,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

// Find tensor corresponding to block
#ifdef __HIP_PLATFORM_AMD__
const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid);
#else
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
#endif
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
#ifdef __HIP_PLATFORM_AMD__
const bool inplace = (input == output);
#endif

// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
Expand All @@ -167,6 +216,26 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;

#ifdef __HIP_PLATFORM_AMD__
// Process subtiles with vectorized loads/stores
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
Comment thread
aris134 marked this conversation as resolved.
Outdated
if (row < num_rows && !inplace) {
const int remaining = row_length - col;
const size_t offset = static_cast<size_t>(row) * row_length + col;
Vec v;
v.load_from_elts(input, offset, remaining > 0 ? min(remaining, nvec) : 0);
v.store_to_elts(output, offset, remaining > 0 ? min(remaining, nvec) : 0);
}
}
}
#else // !__HIP_PLATFORM_AMD__
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Expand Down Expand Up @@ -202,6 +271,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
}
}
}
#endif // __HIP_PLATFORM_AMD__
}

} // namespace
Expand Down
17 changes: 17 additions & 0 deletions transformer_engine/common/util/rocm_device_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ __device__ __forceinline__ void rocm_atomicMaxFloat(float *addr, float val) {
atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
}

// Binary search on a sorted array for the last element not exceeding val.
//
// Returns the largest index i in [0, n) such that arr[i] <= val.
// Note: despite the name, this is not std::upper_bound semantics. When the
// precondition holds, the result equals
//   (std::upper_bound(arr, arr + n, val) - arr) - 1.
//
// Parameters:
//   arr - pointer to n elements sorted in non-decreasing order.
//   n   - number of elements to search.
//   val - value to locate.
//
// Precondition: arr is sorted in non-decreasing order and arr[0] <= val.
// If no element satisfies arr[i] <= val, or if n <= 1, the function returns 0;
// arr[0] itself is never read because every probed index is >= 1.
template <typename T>
__device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) {
  int lo = 0, hi = n - 1;
  while (lo < hi) {
    // Upper midpoint, so that `lo = mid` always shrinks the interval.
    // Written as an offset from lo because (lo + hi + 1) overflows int
    // once n exceeds 2^30; the value is identical otherwise.
    const int mid = lo + (hi - lo + 1) / 2;
    if (arr[mid] <= val) {
      lo = mid;  // arr[mid] qualifies; the answer is mid or to its right.
    } else {
      hi = mid - 1;  // arr[mid] is too large; the answer is strictly left.
    }
  }
  return lo;
}

template <int WARPS>
__device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) {
__shared__ float staging[WARPS];
Expand Down
Loading