enable blockwise FP8 quantization on rocm #609

Open
asdfvg123 wants to merge 3 commits into dev from yeonsoo/blockwise_fp8

@asdfvg123

Description

Please include a brief summary of the changes, relevant motivation and context.

Enable blockwise FP8 quantization on ROCm.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Remove the HIP guard in quantization.py
  • Guard the quantization kernels that use TMA
  • Add a branch to handle the different number of threads per wave on ROCm

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
    recipe_available = get_device_compute_capability() >= (9, 0)
Contributor

Wouldn't this be always True on ROCm TE?

Author

Updated to (9, 5).
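
For reference, the check discussed in this thread can be sketched as a pure function. This is a minimal illustration, not the code in the PR: `blockwise_recipe_available` and its parameters are hypothetical names, and the caller is assumed to pass in the values that `IS_HIP_EXTENSION`, `get_device_compute_capability()` and `torch.version.cuda` would provide.

```python
from typing import Optional, Tuple


def blockwise_recipe_available(
    is_hip: bool,
    compute_capability: Tuple[int, int],
    cuda_version: Optional[str] = None,
) -> bool:
    """Hypothetical helper mirroring the availability check under review."""
    if is_hip:
        # ROCm branch: threshold raised to (9, 5), as stated in the reply above.
        return compute_capability >= (9, 5)
    # CUDA branch: compute capability 9.0+ and CUDA 12.8+, as in the original line.
    return (
        cuda_version is not None
        and compute_capability >= (9, 0)
        and float(cuda_version) >= 12.8
    )


print(blockwise_recipe_available(True, (9, 4)))           # ROCm device below the threshold
print(blockwise_recipe_available(True, (9, 5)))           # ROCm device at the threshold
print(blockwise_recipe_available(False, (9, 0), "12.8"))  # CUDA device meeting both conditions
```

With a (9, 0) threshold on the ROCm branch, any device reporting (9, 0) or higher would pass, which is the concern raised in the review comment. The sketch also inherits the original line's float parsing of the CUDA version string, so a hypothetical "12.10" would parse as 12.1.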

Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;
Contributor

Is this change necessary? fp8e4m3 max depends on the device type on AMD.

Author

quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h.

The #else branch of TypeExtrema<fp8e4m3> declared max as a plain static float. This caused the constexpr static float max_finite_value initializer in TypeInfo, in the same file, to fail when the template was instantiated on the host.

The fix uses HIP_FP8_TYPE_FNUZ, which hip_float8.h uses to select FNUZ at compile time, to make the host-pass branch constexpr as well.

Collaborator

If the value is really used on the host side, it should be detected at runtime. If it is only for the host compilation pass of GPU code (i.e., the results are discarded), you can keep 448; no extra ifdefs are needed.
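
To make the device dependence concrete: the largest finite value of the OCP E4M3 encoding (E4M3FN) is 448, while the FNUZ variant tops out at 240. The sketch below derives both numbers from the bit layouts and shows how the choice changes a scale computed as max / amax. The function names are hypothetical, and the scale formula is a simplified assumption, not the kernel's exact computation.

```python
def e4m3_max(fnuz: bool) -> float:
    """Largest finite value of an 8-bit E4M3 float (hypothetical helper)."""
    if fnuz:
        # E4M3FNUZ: exponent bias 8; the top exponent/mantissa pattern is finite.
        return 2.0 ** (15 - 8) * (1 + 7 / 8)
    # E4M3FN: exponent bias 7; exponent 15 with mantissa 0b111 encodes NaN,
    # so the largest finite mantissa at that exponent is 0b110.
    return 2.0 ** (15 - 7) * (1 + 6 / 8)


def block_scale(amax: float, fnuz: bool, epsilon: float = 0.0) -> float:
    """Simplified scale: map the block's absolute maximum onto the FP8 maximum."""
    amax = max(amax, epsilon)
    if amax == 0.0:
        return 1.0  # Assumption: an all-zero block gets a neutral scale.
    return e4m3_max(fnuz) / amax


print(e4m3_max(fnuz=False), e4m3_max(fnuz=True))
print(block_scale(2.0, fnuz=False), block_scale(2.0, fnuz=True))
```

A value hard-coded to 448 is therefore only harmless where the result is never used, which is the distinction drawn in the comment above.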

@alextmagro
Contributor

alextmagro commented Jun 3, 2026

Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD.

If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests).

@asdfvg123
Author

@alextmagro
This PR enables only quantization on AMD GPUs, not the GEMM. Upstream has two kinds of quantization kernels: those that use TMA and those that do not. I guarded the kernels that use TMA and used the non-TMA kernels to quantize on AMD.

I tested with
tests/pytorch/test_float8blockwisetensor.py
and it passes [175 passed / 32 xpassed / 5 warnings]

@alextmagro
Contributor

> @alextmagro This PR enables only quantization on AMD GPUs, not the GEMM. Upstream has two kinds of quantization kernels: those that use TMA and those that do not. I guarded the kernels that use TMA and used the non-TMA kernels to quantize on AMD.
>
> I tested with tests/pytorch/test_float8blockwisetensor.py and it passes [175 passed / 32 xpassed / 5 warnings]

OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh.

…dant HIP guards, revert unnecessary common.h change
Comment thread ci/pytorch.sh
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/pytorch/test_float8blockwisetensor.py Outdated
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = blockwise_warp_reduce_max(amax);
#else
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
Contributor

We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.

Author

warp_reduce_max in transformer_engine/common/utils.cuh uses THREADS_PER_WARP = 32, defined in that file, which causes a bug. Let me know if there is a better way.
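
The failure mode described here can be reproduced with a small host-side simulation. The sketch below models a shuffle-based max reduction over a 64-lane wave. A butterfly (XOR) exchange pattern is assumed for illustration; the real warp_reduce_max may use a different shuffle. `simulate_warp_reduce_max` is a hypothetical name.

```python
from typing import List


def simulate_warp_reduce_max(lane_values: List[float], reduce_width: int) -> List[float]:
    """Simulate a butterfly max reduction across the lanes of one wave.

    reduce_width plays the role of the hard-coded threads-per-warp constant.
    Both len(lane_values) and reduce_width are assumed to be powers of two.
    """
    values = list(lane_values)
    offset = reduce_width // 2
    while offset >= 1:
        # Each lane takes the max of its own value and its XOR partner's value.
        values = [max(values[i], values[i ^ offset]) for i in range(len(values))]
        offset //= 2
    return values


# A 64-lane wave where the largest value sits in the upper half (lane 63).
wave = [float(i) for i in range(64)]

narrow = simulate_warp_reduce_max(wave, reduce_width=32)
full = simulate_warp_reduce_max(wave, reduce_width=64)

print(narrow[0])  # lane 0 only sees lanes 0..31
print(full[0])    # lane 0 sees all 64 lanes
```

With a reduction width of 32 on a 64-lane wave, the two halves never exchange values, so lane 0 ends up with the maximum of lanes 0..31 rather than the maximum of the whole wave.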

Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale, noop_ptr);
} else {
} else
Contributor

Let's avoid splitting up the } else { line here. We can add another macro guard instead if needed.

constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches

#ifdef __HIP_PLATFORM_AMD__
constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total
Contributor

Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.

Author

The kernel expects 8 waves per block, so I increased the number of threads.
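
The arithmetic behind this reply, as a minimal sketch: if the kernel's layout fixes the number of waves per block at 8, the block size has to scale with the wave width. The names below are hypothetical, and the wave32 figure is derived here for comparison rather than quoted from the source.

```python
WAVES_PER_BLOCK = 8  # Fixed by the kernel's tiling, per the reply above.


def threads_per_block(threads_per_wave: int) -> int:
    """Block size needed to keep WAVES_PER_BLOCK waves at a given wave width."""
    return WAVES_PER_BLOCK * threads_per_wave


print(threads_per_block(64))  # wave64: matches kThreadsPerBlock = 512 in the diff
print(threads_per_block(32))  # wave32: derived for comparison
```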

Comment thread transformer_engine/pytorch/quantization.py
@alextmagro
Contributor

alextmagro commented Jun 4, 2026

By the way, to run CI you need to add a CI level label. L3 is required before merging; L1 is for lighter testing (mostly sGPU tests) when you are midway through the ticket and expect to make more changes.

asdfvg123 added the ci-level 1 label Jun 4, 2026
Collaborator

Copyright


#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;
Collaborator

ROCm should not use it. See how *_sync calls are guarded in other places

}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
Collaborator

This whole block of code is under #ifndef __HIP_PLATFORM_AMD__.

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;
Collaborator

It is platform dependent.

transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu
Collaborator

It should stay in transformer_engine_cuda_arch_specific_sources

Labels

ci-level 1 CI test level 1

4 participants