Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tests/pytorch/test_float8blockwisetensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def _to_list(x: Union[Iterable, Any]) -> List:
DimsType = Union[Iterable[int], int]

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this be always True on ROCm TE?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated to (9, 5)

else:
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS."


Expand Down
8 changes: 3 additions & 5 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ list(APPEND transformer_engine_cuda_sources
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should stay in transformer_engine_cuda_arch_specific_sources

transpose/swap_first_dims.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
Expand Down Expand Up @@ -233,7 +234,6 @@ list(APPEND transformer_engine_cuda_sources
comm_gemm_overlap/userbuffers/userbuffers.cu)

set(cuda_only_cuda_sources
transpose/quantize_transpose_vector_blockwise.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
Expand All @@ -257,7 +257,6 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
recipe/nvfp4.cu
transpose/quantize_transpose_square_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise_fp4.cu)

set(cuda_only_cuda_arch_specific_sources
Expand All @@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
transpose/quantize_transpose_square_blockwise.cu)
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu)

# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
Expand Down
4 changes: 0 additions & 4 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
#endif
break;
}
#ifndef __HIP_PLATFORM_AMD__
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
Expand Down Expand Up @@ -196,7 +195,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
#endif//#ifndef __HIP_PLATFORM_AMD__
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
Expand Down Expand Up @@ -317,7 +315,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
#endif
break;
}
#ifndef __HIP_PLATFORM_AMD__
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT),
Expand Down Expand Up @@ -351,7 +348,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
#endif //#ifndef __HIP_PLATFORM_AMD__
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ struct TypeExtrema<fp8e4m3> {
static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f;
static constexpr float max_inverse = 1.0 / max;
#else
static float max;
static float max_inverse;
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary? fp8e4m3 max depends on the device type on AMD.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h.

The #else branch of TypeExtrema<fp8e4m3> declared max as a non-constexpr static float.
This caused the constexpr static float max_finite_value initializer in TypeInfo (in the same file) to fail when the template was instantiated on the host.

The fix uses HIP_FP8_TYPE_FNUZ, used in hip_float8.h for selecting FNUZ at compile time, to make the host-pass branch constexpr as well.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the value is really used on the host side, it should be detected at runtime. If it is only for host translation of GPU code (i.e., the results are discarded), you can keep 448; no extra ifdefs are needed.

#endif
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,50 @@
************************************************************************/

#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif

#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
#include "common/util/ptx.cuh"
#endif
Comment thread
asdfvg123 marked this conversation as resolved.
Outdated
#include "common/utils.cuh"

#ifndef __HIP_PLATFORM_AMD__
#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \
(defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900)
#define TMA_HW_SUPPORTED
#endif
#endif

namespace transformer_engine {
namespace {

#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm should not use it. See how *_sync calls are guarded in other places

#else
using WarpSyncMask = unsigned;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu;
#endif

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is platform dependent.

#else
constexpr size_t kThreadsPerWarp = 32;
#endif
#ifdef TMA_HW_SUPPORTED
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 32;
Expand Down Expand Up @@ -62,6 +82,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP

#define MIN(a, b) (a < b ? a : b)

#ifdef TMA_HW_SUPPORTED
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c,
Expand Down Expand Up @@ -133,7 +154,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warp_reduce_max in transformer_engine/common/utils.cuh uses THREADS_PER_WARP = 32, which creates a bug here. Let me know if there is a better way.

// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero);

// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
Expand Down Expand Up @@ -247,6 +268,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#endif
}
}
#endif // TMA_HW_SUPPORTED

template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned(
Expand Down Expand Up @@ -360,7 +382,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
warp_tile_amax = __shfl_sync(kFullWarpMask, warp_tile_amax, lane_zero);

// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
Expand Down Expand Up @@ -456,6 +478,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}

#ifdef TMA_HW_SUPPORTED
Comment thread
asdfvg123 marked this conversation as resolved.
Outdated
template <typename OutputType>
CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) {
CUtensorMapDataType dataType;
Expand All @@ -473,6 +496,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8);
return tensor_map_output_trans;
}
#endif // TMA_HW_SUPPORTED

} // namespace
} // namespace transformer_engine
Expand Down Expand Up @@ -546,6 +570,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;

#ifdef TMA_HW_SUPPORTED
if (full_tile) {
CUtensorMap tensor_map_output_trans;
if (return_transpose) {
Expand All @@ -561,7 +586,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale, noop_ptr);
} else {
} else

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid splitting up the } else { line here. We can add another macro guard instead if needed.

#else
(void)full_tile;
#endif
{
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
************************************************************************/

#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif
#include <utility>

#include "common/common.h"
Expand All @@ -26,6 +30,14 @@ namespace {
using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
using transformer_engine::detail::FP8BlockwiseRowwiseOption;

#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;
#else
using WarpSyncMask = unsigned;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFu;
#endif

// clang-format off
/*

Expand Down Expand Up @@ -145,14 +157,23 @@ Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast
*/
// clang-format on

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;
#else
constexpr size_t kThreadsPerWarp = 32;
#endif

// Hyperparameters for performance tuning
constexpr int kTileDim = 128; // Fixed to 128 because we are using 1x128 and 128x1 quantization
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches

#ifdef __HIP_PLATFORM_AMD__
constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kernel expects 8 waves per block, so I increased the number of threads.

#else
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
#endif

// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
Expand Down Expand Up @@ -259,7 +280,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
// the first thread to do the reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
Expand Down Expand Up @@ -350,7 +371,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
// the first thread to do the reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const WarpSyncMask mask = ((WarpSyncMask{1} << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
Expand Down Expand Up @@ -474,7 +495,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
constexpr int lane_zero = 0;
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
amax = __shfl_sync(kFullWarpMask, amax, lane_zero);
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def check_nvfp4_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
return False, "FP8 block scaled gemm not yet supported for ROCm"
Comment thread
asdfvg123 marked this conversation as resolved.
return True, ""
if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
return True, ""
return (
Expand Down