Enable gfx950 CI on dev branch#401

Merged
VeeraRajasekhar merged 15 commits into dev from veergopu_gfx950_ci
Jan 9, 2026
Conversation

@VeeraRajasekhar

Contributor

Description

  • Enable gfx950 (MI350) CI by addressing the specific failures we saw: FP8 GEMM coverage gaps in hipBLASLt, RMSNorm misalignment on odd strides (e.g., N=17389), fused optimizer tolerances, and unsupported quantized/activation-recompute test cases on ROCm.
  • Prevent JAX GEMM/grouped-GEMM FFI from being marked cudaGraph-safe on ROCm to avoid failures; keep gfx950 FP8 layout support disabled until hipBLASLt coverage is validated.
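
To make the odd-stride RMSNorm case above concrete, here is a minimal sketch of the kind of check that decides whether a 16-byte alignment hint is safe. The function name, the bf16 element size, and the example base address are assumptions for illustration; this is not the code in rmsnorm.py.

```python
def can_use_16b_hint(base_addr: int, row_stride_elems: int, elem_size: int) -> bool:
    """Return True only if every row start is 16-byte aligned.

    Hypothetical helper: row i starts at
    base_addr + i * row_stride_elems * elem_size, so both the base address
    and the row stride in bytes must be multiples of 16.
    """
    align = 16
    stride_bytes = row_stride_elems * elem_size
    return base_addr % align == 0 and stride_bytes % align == 0


# Assumed example: bf16 rows (2 bytes per element) at an aligned base address.
aligned_case = can_use_16b_hint(base_addr=0x1000, row_stride_elems=4096, elem_size=2)
# N=17389 gives a row stride of 34778 bytes, which is not a multiple of 16.
odd_case = can_use_16b_hint(base_addr=0x1000, row_stride_elems=17389, elem_size=2)
print(aligned_case, odd_case)
```

With a stride like N=17389, only the first row is guaranteed to start on a 16-byte boundary, so a kernel that promises 16-byte alignment for every row would be making a false claim to the compiler.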

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Disable cudaGraph registration for JAX gemm and grouped_gemm FFI on ROCm to stop graph-capture hangs for gfx950 (transformer_engine/jax/csrc/extensions/gemm.cpp).

  • Keep is_fp8_gemm_with_all_layouts_supported false on gfx950 until hipBLASLt FP8 layout coverage is validated (transformer_engine/jax/quantize/device_utils.py).

  • Fix RMSNorm Triton kernel for misaligned row strides by only applying 16B alignment hints when the pointers/strides are aligned; this resolves test_norms dgamma mismatches and the test_transformer_layer_hidden_states_format numerics issues. Also relax fused-optimizer FP8 tolerances on MI350 (transformer_engine/pytorch/triton_kernels/rmsnorm.py, tests/pytorch/test_numerics.py, tests/pytorch/test_fused_optimizer.py).

  • Skip unsupported FP8 quantized linear combinations on gfx950 where hipBLASLt lacks algorithms (tests/pytorch/test_fusible_ops.py).

  • Add gfx950 detection helper and skip test_gpt_full_activation_recompute on MI350 configs that hipBLASLt cannot serve (transformer_engine/pytorch/utils.py, tests/pytorch/test_numerics.py).
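
The detection-and-skip pattern described in the last bullet can be sketched roughly as follows. Both function names are hypothetical and are not the helper added to transformer_engine/pytorch/utils.py; the sketch assumes an architecture string of the form `gfx950:sramecc+:xnack-`, which ROCm builds of PyTorch typically expose as `torch.cuda.get_device_properties(i).gcnArchName`.

```python
from typing import Optional, Set


def is_gfx950(gcn_arch_name: str) -> bool:
    """Hypothetical helper: True if the arch string names a gfx950 device."""
    # The base architecture precedes the first ':'; feature flags follow it.
    return gcn_arch_name.split(":", 1)[0].strip().lower() == "gfx950"


def skip_reason(gcn_arch_name: str, test_name: str,
                unsupported_on_gfx950: Set[str]) -> Optional[str]:
    """Return a skip message for tests known to be unsupported on gfx950."""
    if is_gfx950(gcn_arch_name) and test_name in unsupported_on_gfx950:
        return f"{test_name} is not supported on gfx950"
    return None


# Assumed test name taken from the PR description; the set itself is illustrative.
unsupported = {"test_gpt_full_activation_recompute"}
print(skip_reason("gfx950:sramecc+:xnack-", "test_gpt_full_activation_recompute", unsupported))
print(skip_reason("gfx942:sramecc+:xnack-", "test_gpt_full_activation_recompute", unsupported))
```

In a pytest suite, a non-None result would normally be passed to `pytest.skip(...)` at the top of the test body, so the skip stays scoped to gfx950 and other architectures keep running the test.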

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread transformer_engine/jax/quantize/device_utils.py
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py
Comment thread tests/pytorch/test_fused_optimizer.py Outdated
Comment thread tests/pytorch/test_numerics.py
Comment thread transformer_engine/jax/csrc/extensions/gemm.cpp
Comment thread transformer_engine/jax/quantize/device_utils.py Outdated
Comment thread transformer_engine/pytorch/triton_kernels/rmsnorm.py
Comment thread tests/pytorch/test_fused_optimizer.py Outdated
Comment thread ci/ci_config.json Outdated
Comment thread tests/pytorch/test_fused_optimizer.py

#ifdef __HIP_PLATFORM_AMD__

// Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.

Collaborator

What does unstable mean?

Contributor Author

6192 - OperatorTest/GEMMTestSuite.Testfp8xfp8xbf16xbf16xbf16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
6768 - OperatorTest/GEMMTestSuite.Testfp8xbf8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7344 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp32/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7488 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)

These test cases fail at random, so we decided to skip them for this MI350 bring-up. When I tested on ROCm 7.2, there was no issue.

Collaborator

Guard it with #if HIP_VERSION < 70200000 then, so the comments about the temporary disable, the re-enable, and ROCm 7.2 can be removed.
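
For readers unfamiliar with the constant in that guard: HIP_VERSION is composed as major * 10000000 + minor * 100000 + patch, so 70200000 corresponds to 7.2.0. The sketch below works through that arithmetic and mirrors the skip condition from the diff in Python. The function names are hypothetical, and the predicate illustrates the reviewer's suggestion rather than the code that was finally merged.

```python
def hip_version(major: int, minor: int, patch: int = 0) -> int:
    """Encode a version the way the HIP_VERSION macro does."""
    return major * 10_000_000 + minor * 100_000 + patch


def skip_unstable_tn_case(hip_ver: int, cc_major: int, cc_minor: int,
                          transa: bool, transb: bool,
                          m: int, k: int, n: int) -> bool:
    """Illustrative Python mirror of the version-guarded C++ skip condition."""
    if hip_ver >= hip_version(7, 2):
        # On ROCm 7.2 and later the guard compiles the skip out entirely.
        return False
    is_gfx950 = (cc_major, cc_minor) == (9, 5)
    is_tn = transa and not transb
    return is_gfx950 and is_tn and (m, k, n) == (2304, 768, 4096)


print(hip_version(7, 2))
print(skip_unstable_tn_case(hip_version(7, 1), 9, 5, True, False, 2304, 768, 4096))
print(skip_unstable_tn_case(hip_version(7, 2), 9, 5, True, False, 2304, 768, 4096))
```

Keying the skip to the toolkit version means it disappears automatically once the build moves to a ROCm release where the failure no longer reproduces, which is why the explanatory comments become unnecessary.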

Comment thread tests/pytorch/test_numerics.py Outdated
Comment thread tests/pytorch/test_numerics.py Outdated

#ifdef __HIP_PLATFORM_AMD__

// Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.

// Re-enable after ROCm 7.2 once hipBLASLt fixes land.
if (prop.major == 9 && prop.minor == 5 &&
params.transa && !params.transb &&
params.m == 2304 && params.k == 768 && params.n == 4096) {

Collaborator

There is only one size for DqTest. Instead of skipping the test, just use a different size for test_case_sizes_mxfp8, for example 768, 3072, 4096.

Contributor Author

updated

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/pytorch/test_fused_optimizer.py Outdated
@VeeraRajasekhar

Contributor Author

rebase to dev

@VeeraRajasekhar

Contributor Author

Test report for MI355 with Level=3:

  • No issues reported with the sgpu tests.
  • PyTorch mgpu tests had no issues.
  • JAX: the [auto] test_distributed_fused_attn.py timeout is triggered by a known hang. All other JAX tests passed.

@VeeraRajasekhar VeeraRajasekhar merged commit f141f34 into dev Jan 9, 2026
2 of 4 checks passed
wangye805 pushed a commit that referenced this pull request Feb 2, 2026
* [CI] Skipped test_gpt_full_activation_recompute tests for gfx950

* [CI] Skipped unsupported test_basic_linear_quantized tests on gfx950

* [CI] Fixed test_numerics, test_norms, test_fused_optimizer failures for gfx950 ci enablement

* [CI] Disabled gfx950 support until FP8 GEMM layout coverage is verified with hipblaslt

* [CI] [gfx950] Disable cudaGraph for gemmm and grouped-gemm

* Addressed reviews

* [CI] Add MI355 nodes to github actions workflow

* [CI] Update docker image

* [CI] add MI355 runner matrix and keep matrix legs independent

* Skip unstable Gemm tests on gfx950

* Addressed reviews

* Guard gfx950 TN skip by ROCm version and adjust MXFP8 Dq test size

* Removed ROCM7.2 guards

* Reverted ROCM7.2 guards

* Update rocm-ci.yml