Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ run_test_config(){
fi
run 1 test_cuda_graphs.py
run_default_fa 1 test_deferred_init.py
run_default_fa 1 test_quantized_tensor.py
run_default_fa 1 test_float8_current_scaling_exact.py
run_default_fa 1 test_float8blockwisetensor.py
Comment thread
asdfvg123 marked this conversation as resolved.
run_default_fa 1 test_quantized_tensor.py
test $_fus_attn = auto -o $_fus_attn = ck && run 1 test_cpu_offloading.py
test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py
run_default_fa 1 test_fused_rope.py
Expand Down
3 changes: 1 addition & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ add_executable(test_operator
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu #CUDA-only test
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_cast_transpose.cu
Expand All @@ -41,7 +41,6 @@ if(USE_ROCM)
get_target_property(test_cuda_sources test_operator SOURCES)
# Remove CUDA-only tests and add ROCm specific ones
list(REMOVE_ITEM test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu
test_grouped_gemm.cu)
list(APPEND test_cuda_sources
Expand Down
18 changes: 18 additions & 0 deletions tests/cpp/operator/test_cast_float8blockwise.cu
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright

Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,28 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
}

// Computes the stride derived from `inner_elements` that the scale-factor
// comparison helpers use to index the tested buffer.
//   - AMD HIP builds (__HIP_PLATFORM_AMD__): the count is returned unchanged.
//   - All other builds: the count is rounded up to the next multiple of 4.
inline size_t scale_align_stride(size_t inner_elements) {
#ifdef __HIP_PLATFORM_AMD__
  return inner_elements;
#else
  constexpr size_t kAlignment = 4;
  const size_t padded = inner_elements + kAlignment - 1;
  return (padded / kAlignment) * kAlignment;
#endif
}

void compare_scaling_factors(const std::string& name, const float* test, const float* ref,
const size_t row_blocks, const size_t col_blocks,
const size_t test_stride, const size_t ref_stride) {
#ifdef __HIP_PLATFORM_AMD__
const float atol = 1e-6f;
#endif
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i * test_stride + j;
const int ref_idx = i * ref_stride + j;
#ifdef __HIP_PLATFORM_AMD__
ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol)
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
#endif
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
Expand All @@ -248,12 +259,19 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows,
const size_t col_blocks) {
#ifdef __HIP_PLATFORM_AMD__
const float atol = 1e-6f;
#endif
const size_t test_stride = scale_align_stride(rows);
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i + test_stride * j;
const int ref_idx = i + rows * j;
#ifdef __HIP_PLATFORM_AMD__
ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol)
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
#endif
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
Expand Down
5 changes: 4 additions & 1 deletion tests/pytorch/test_float8blockwisetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 0)
else:
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS."


Expand Down
8 changes: 3 additions & 5 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ list(APPEND transformer_engine_cuda_sources
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should stay in transformer_engine_cuda_arch_specific_sources

transpose/swap_first_dims.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
Expand Down Expand Up @@ -233,7 +234,6 @@ list(APPEND transformer_engine_cuda_sources
comm_gemm_overlap/userbuffers/userbuffers.cu)

set(cuda_only_cuda_sources
transpose/quantize_transpose_vector_blockwise.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
Expand All @@ -257,7 +257,6 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
recipe/nvfp4.cu
transpose/quantize_transpose_square_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise_fp4.cu)

set(cuda_only_cuda_arch_specific_sources
Expand All @@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
transpose/quantize_transpose_square_blockwise.cu)
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu)

# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
Expand Down
4 changes: 0 additions & 4 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
#endif
break;
}
#ifndef __HIP_PLATFORM_AMD__
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
Expand Down Expand Up @@ -196,7 +195,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
#endif//#ifndef __HIP_PLATFORM_AMD__
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
Expand Down Expand Up @@ -317,7 +315,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
#endif
break;
}
#ifndef __HIP_PLATFORM_AMD__
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
Expand Down Expand Up @@ -351,7 +348,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
#endif //#ifndef __HIP_PLATFORM_AMD__
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,12 @@ struct TypeExtrema<fp8e4m3> {
#elif defined(__HIP_DEVICE_COMPILE__)
static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f;
static constexpr float max_inverse = 1.0 / max;
#elif defined(HIP_FP8_TYPE_FNUZ)
static constexpr float max = 240.0f;
static constexpr float max_inverse = 1.0 / max;
#else
static float max;
static float max_inverse;
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;
#endif
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
Expand All @@ -18,17 +20,31 @@
#include "common/util/ptx.cuh"
#include "common/utils.cuh"

#ifndef __HIP_PLATFORM_AMD__
#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \
(defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900)
#define TMA_HW_SUPPORTED
#endif
#endif

namespace transformer_engine {
namespace {

#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm should not use it. See how *_sync calls are guarded in other places

#else
using WarpSyncMask = unsigned;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu;
#endif

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is platform dependent.

#else
constexpr size_t kThreadsPerWarp = 32;
#endif
#ifdef TMA_HW_SUPPORTED
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 32;
Expand All @@ -40,8 +56,12 @@ constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 64;
constexpr size_t WARP_TILE_DIM_Y = 32;
constexpr size_t THREAD_TILE_DIM_X = 8;
#ifdef __HIP_PLATFORM_AMD__
constexpr size_t THREAD_TILE_DIM_Y = 4;
#else
constexpr size_t THREAD_TILE_DIM_Y = 8;
#endif
#endif

#ifdef TMA_HW_SUPPORTED
constexpr size_t NUM_BYTES_PER_BANK = 4;
Expand All @@ -62,6 +82,16 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP

#define MIN(a, b) (a < b ? a : b)

#ifdef __HIP_PLATFORM_AMD__
// Max-reduction across a warp, compiled only for AMD HIP builds.
// Each step combines a lane's running value with the value fetched through
// __shfl_down from `offset` lanes away, starting at kThreadsPerWarp / 2 and
// halving `offset` until it reaches 1.
// NOTE(review): the kernels broadcast the result from lane 0 afterwards, so
// presumably only lane 0's return value is meaningful — confirm.
__device__ __forceinline__ float blockwise_warp_reduce_max(float val) {
  float running_max = val;
#pragma unroll
  for (int offset = kThreadsPerWarp / 2; offset >= 1; offset >>= 1) {
    const float neighbour = __shfl_down(running_max, offset, kThreadsPerWarp);
    running_max = fmaxf(running_max, neighbour);
  }
  return running_max;
}
#endif

#ifndef __HIP_PLATFORM_AMD__
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c,
Expand Down Expand Up @@ -130,10 +160,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole block of code is under `#ifndef __HIP_PLATFORM_AMD__`

warp_tile_amax = blockwise_warp_reduce_max(amax);
#else
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warp_reduce_max in transformer_engine/common/utils.cuh uses THREADS_PER_WARP = 32, which creates a bug here. Let me know if there is a better way.

#endif
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero);

// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
Expand Down Expand Up @@ -247,6 +281,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#endif
}
}
#endif // __HIP_PLATFORM_AMD__

template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned(
Expand Down Expand Up @@ -357,10 +392,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = blockwise_warp_reduce_max(amax);
#else
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
#endif
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero);

// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
Expand Down Expand Up @@ -456,6 +495,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}

#ifndef __HIP_PLATFORM_AMD__
template <typename OutputType>
CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) {
CUtensorMapDataType dataType;
Expand All @@ -473,6 +513,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8);
return tensor_map_output_trans;
}
#endif // __HIP_PLATFORM_AMD__

} // namespace
} // namespace transformer_engine
Expand Down Expand Up @@ -543,9 +584,10 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
return_transpose, kReturnTranspose,

dim3 grid(num_blocks_x, num_blocks_y, 1);

#ifndef __HIP_PLATFORM_AMD__
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;

if (full_tile) {
CUtensorMap tensor_map_output_trans;
if (return_transpose) {
Expand Down Expand Up @@ -573,6 +615,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale, noop_ptr);
} // full-tile
#else
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale, noop_ptr);
#endif // __HIP_PLATFORM_AMD__
) // return_transpose
) // OutputType
) // InputType
Expand Down
Loading
Loading