Enable gfx950 CI on dev branch #401
    #ifdef __HIP_PLATFORM_AMD__
    // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.
What does unstable mean?
6192 - OperatorTest/GEMMTestSuite.Testfp8xfp8xbf16xbf16xbf16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
6768 - OperatorTest/GEMMTestSuite.Testfp8xbf8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7344 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp32/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7488 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
These test cases fail intermittently, so we decided to skip them for this MI350 bring-up. When I tested on ROCm 7.2, the issue did not reproduce.
Then guard it with #if HIP_VERSION < 70200000. That way, the comments about the temporary disable, the re-enable, and ROCm 7.2 can be removed.
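A minimal sketch of the guard suggested above, assuming the skip condition from the diff is factored into a helper. `GemmShape`, `skip_gfx950_tn`, and `self_check` are hypothetical names, and the fallback `HIP_VERSION` value is only a stand-in so the sketch builds without ROCm installed; in a HIP build the macro comes from `<hip/hip_version.h>`.

```cpp
#include <cassert>

// In a HIP build, HIP_VERSION is encoded as
// major * 10000000 + minor * 100000 + patch, so ROCm 7.2.0 maps to 70200000.
#if defined(__has_include)
#if __has_include(<hip/hip_version.h>)
#include <hip/hip_version.h>
#endif
#endif
#ifndef HIP_VERSION
#define HIP_VERSION 70100000  // assumption: stand-in for a pre-7.2 toolchain
#endif

// Hypothetical mirror of the test parameters referenced in the diff.
struct GemmShape {
  int m, k, n;
  bool transa, transb;
};

// Returns true when the TN case should be skipped on gfx950 (compute 9.5).
// On ROCm 7.2 and newer the guard compiles the skip away entirely.
inline bool skip_gfx950_tn(int major, int minor, const GemmShape& p) {
#if HIP_VERSION < 70200000
  return major == 9 && minor == 5 && p.transa && !p.transb &&
         p.m == 2304 && p.k == 768 && p.n == 4096;
#else
  (void)major;
  (void)minor;
  (void)p;
  return false;
#endif
}

// Shapes other than the failing one are never skipped, on any version.
inline void self_check() {
  assert(!skip_gfx950_tn(9, 5, {128, 128, 128, true, false}));
}
```

Because the condition is resolved by the preprocessor, the skip needs no "temporary" or "re-enable" comments: it stops applying as soon as the code is built against a newer toolchain.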
    #ifdef __HIP_PLATFORM_AMD__
    // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.
    // Re-enable after ROCm 7.2 once hipBLASLt fixes land.
    if (prop.major == 9 && prop.minor == 5 &&
        params.transa && !params.transb &&
        params.m == 2304 && params.k == 768 && params.n == 4096) {
There is only one size for DqTest. Instead of skipping the test, use a different size in test_case_sizes_mxfp8, for example 768, 3072, 4096.
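As a rough illustration of this suggestion, the sketch below swaps the single MXFP8 size for the proposed one instead of adding a skip. The container type, the (M, K, N) ordering, and the names `GemmSize`, `example_sizes_mxfp8`, and `is_known_flaky_gfx950_size` are assumptions; the real `test_case_sizes_mxfp8` definition is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Assumed element type: one (M, K, N) triple per test case.
using GemmSize = std::tuple<std::size_t, std::size_t, std::size_t>;

// Hypothetical stand-in for the size list, holding the size proposed above.
inline std::vector<GemmSize> example_sizes_mxfp8() {
  return {GemmSize{768, 3072, 4096}};
}

// The shape reported as failing at random on gfx950.
inline bool is_known_flaky_gfx950_size(const GemmSize& s) {
  return s == GemmSize{2304, 768, 4096};
}

// The replacement list keeps the test running without touching the flaky shape.
inline void check_sizes_avoid_flaky_shape() {
  for (const GemmSize& s : example_sizes_mxfp8()) {
    assert(!is_known_flaky_gfx950_size(s));
  }
}
```

Changing the size keeps the test active on every platform, whereas a skip would leave gfx950 with no coverage for this case.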
rebase to dev
…or gfx950 ci enablement
…ed with hipblaslt
Force-pushed from 5a83295 to b551b3f
Test report for MI355 with Level=3:
* [CI] Skipped test_gpt_full_activation_recompute tests for gfx950
* [CI] Skipped unsupported test_basic_linear_quantized tests on gfx950
* [CI] Fixed test_numerics, test_norms, test_fused_optimizer failures for gfx950 ci enablement
* [CI] Disabled gfx950 support until FP8 GEMM layout coverage is verified with hipblaslt
* [CI] [gfx950] Disable cudaGraph for gemmm and grouped-gemm
* Addressed reviews
* [CI] Add MI355 nodes to github actions workflow
* [CI] Update docker image
* [CI] add MI355 runner matrix and keep matrix legs independent
* Skip unstable Gemm tests on gfx950
* Addressed reviews
* Guard gfx950 TN skip by ROCm version and adjust MXFP8 Dq test size
* Removed ROCM7.2 guards
* Reverted ROCM7.2 guards
* Update rocm-ci.yml
Description
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Disable cudaGraph registration for JAX gemm and grouped_gemm FFI on ROCm to stop graph-capture hangs for gfx950 (transformer_engine/jax/csrc/extensions/gemm.cpp).
Keep is_fp8_gemm_with_all_layouts_supported false on gfx950 until hipBLASLt FP8 layout coverage is validated (transformer_engine/jax/quantize/device_utils.py).
Fix RMSNorm Triton kernel for misaligned row strides by only applying 16B alignment hints when the pointers/strides are aligned; this resolves test_norms dgamma mismatches and the test_transformer_layer_hidden_states_format numerics issues. Also relax fused-optimizer FP8 tolerances on MI350 (transformer_engine/pytorch/triton_kernels/rmsnorm.py, tests/pytorch/test_numerics.py, tests/pytorch/test_fused_optimizer.py).
Skip unsupported FP8 quantized linear combinations on gfx950 where hipBLASLt lacks algorithms (tests/pytorch/test_fusible_ops.py).
Add gfx950 detection helper and skip test_gpt_full_activation_recompute on MI350 configs that hipBLASLt cannot serve (transformer_engine/pytorch/utils.py, tests/pytorch/test_numerics.py).
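The RMSNorm item in the list above depends on a simple predicate: a 16-byte alignment hint is only valid when both the base pointer and the row stride in bytes are multiples of 16. The actual change is in a Python/Triton kernel; the sketch below only restates that condition in C++ to match the snippet under review, and `can_use_16b_hints` and `self_check_alignment` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// True when every row start is 16-byte aligned: the base address must be a
// multiple of 16 and so must the distance between consecutive rows in bytes.
inline bool can_use_16b_hints(const void* base, std::size_t row_stride_elems,
                              std::size_t elem_size_bytes) {
  constexpr std::size_t kAlign = 16;
  const auto addr = reinterpret_cast<std::uintptr_t>(base);
  return addr % kAlign == 0 &&
         (row_stride_elems * elem_size_bytes) % kAlign == 0;
}

inline void self_check_alignment() {
  alignas(16) unsigned char buf[64] = {};
  // 8 two-byte elements per row: 16-byte stride, so hints are safe.
  assert(can_use_16b_hints(buf, 8, 2));
  // 7 two-byte elements per row: 14-byte stride, so the second row is misaligned.
  assert(!can_use_16b_hints(buf, 7, 2));
}
```

If the hint is applied when this predicate is false, the compiler may emit vectorized accesses that assume an alignment the data does not have, which is consistent with the dgamma mismatches described in the list.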
Checklist: