Use cuDNN for row-scaled NVFP4 grouped GEMM #3042

Draft
zianglih wants to merge 2 commits into NVIDIA:main from zianglih:codex/cudnn-row-scale-nvfp4-grouped-gemm

@zianglih

Contributor

Summary

  • route row-scaled NVFP4 grouped GEMM through cuDNN grouped GEMM quant
  • remove the per-GEMM fallback for the row-scaled grouped path so unsupported cases fail explicitly
  • tighten NVFP4 grouped GEMM tests to cover the cuDNN wrapper path, a supported functional case, and an unsupported no-fallback case
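
The no-fallback behavior and the two kinds of test described above can be sketched roughly as follows. This is an illustration only, not the code in this PR: `row_scaled_grouped_gemm`, `UnsupportedGroupedGemmError`, `grouped_kernel`, and `is_supported` are hypothetical names, and the argument layout is an assumption. The real support criteria live in TransformerEngine and cuDNN Frontend and are not reproduced here.

```python
from unittest import mock


class UnsupportedGroupedGemmError(RuntimeError):
    """Hypothetical error type for cases the grouped kernel cannot handle."""


def row_scaled_grouped_gemm(inputs, weights, row_scales, *, grouped_kernel, is_supported):
    """Call the grouped kernel, or fail; never fall back to per-GEMM calls."""
    if not is_supported(inputs, weights, row_scales):
        # Explicit failure instead of a silent per-GEMM fallback.
        raise UnsupportedGroupedGemmError(
            "row-scaled grouped GEMM: unsupported configuration, no per-GEMM fallback"
        )
    return grouped_kernel(inputs, weights, row_scale_tensor=row_scales)


# Wrapper-path check: a supported case reaches the grouped kernel exactly once.
kernel = mock.Mock(return_value="out")
result = row_scaled_grouped_gemm(
    [1], [2], [3], grouped_kernel=kernel, is_supported=lambda *args: True
)
kernel.assert_called_once_with([1], [2], row_scale_tensor=[3])

# No-fallback check: an unsupported case raises and the kernel is never called.
kernel.reset_mock()
try:
    row_scaled_grouped_gemm(
        [1], [2], [3], grouped_kernel=kernel, is_supported=lambda *args: False
    )
except UnsupportedGroupedGemmError:
    failed_explicitly = True
else:
    failed_explicitly = False
kernel.assert_not_called()
```

Passing the kernel and the support predicate in as arguments keeps the sketch runnable without a GPU or cuDNN; a real test would more likely patch the wrapper where the module looks it up.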

Required dependency

This PR requires the cuDNN Frontend feature added in NVIDIA/cudnn-frontend#251: a cudnn-frontend version whose cudnn.grouped_gemm_quant_wrapper_sm100(...) accepts row_scale_tensor. Without that feature, this TransformerEngine PR is expected to fail on the row-scaled grouped GEMM path.
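
One way to check this precondition up front is to inspect the wrapper's signature. The sketch below assumes the wrapper is a Python callable whose signature `inspect.signature` can read and that it is reachable as an attribute of the `cudnn` module. `accepts_keyword` and `cudnn_wrapper_supports_row_scale` are hypothetical helper names, not part of TransformerEngine or cudnn-frontend.

```python
import inspect


def accepts_keyword(fn, name):
    """Return True if `fn` declares a parameter `name` that can be passed by keyword."""
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some extension callables expose no introspectable signature.
        return False
    param = params.get(name)
    # Deliberately strict: a bare **kwargs does not count as support.
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def cudnn_wrapper_supports_row_scale():
    """Best-effort check; returns False if cudnn is missing or fails to import."""
    try:
        import cudnn  # cuDNN Frontend Python package
    except Exception:
        return False
    wrapper = getattr(cudnn, "grouped_gemm_quant_wrapper_sm100", None)
    return wrapper is not None and accepts_keyword(wrapper, "row_scale_tensor")
```

A test module could use such a check to skip, rather than fail, when the installed cudnn-frontend predates the feature. Whether skipping is wanted is a design choice, since this PR prefers explicit failure on the runtime path.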

Motivation

Related to the row-scaled NVFP4 work in #2931. This PR is intended to land only after TransformerEngine can depend on a cudnn-fe version containing the row-scaled grouped GEMM quant feature.

Validation

  • python3 -m py_compile transformer_engine/pytorch/cpp_extensions/gemm.py tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
  • git diff --check -- transformer_engine/pytorch/cpp_extensions/gemm.py tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
  • pre-commit run --all-files
  • B200 devbox: installed the cudnn-fe branch from NVIDIA/cudnn-frontend#251 ("Add row-scale support to grouped GEMM quant") and verified row_scale_tensor is in grouped_gemm_quant_wrapper_sm100
  • B200 devbox: built and installed TransformerEngine with NVTE_FRAMEWORK=pytorch NVTE_CUDA_ARCHS=100a MAX_JOBS=4
  • B200 devbox: pytest -q tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py::test_nvfp4_row_scaled_grouped_gemm_uses_cudnn_quant_wrapper --tb=short passed: 2 passed
  • B200 devbox: supported BF16 case passed: test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm[mae_err-default-single_output-no_bias-torch.bfloat16-torch.bfloat16-torch.bfloat16-m_splits4-1024-1024]
  • B200 devbox: unsupported no-fallback case passed: test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm[mae_err-default-list_output-no_bias-torch.float32-torch.float32-torch.float32-m_splits0-128-128]
  • B200 devbox: python3 -m pylint transformer_engine/pytorch/cpp_extensions/gemm.py passed: 10.00/10

zianglih added 2 commits May 22, 2026 21:35
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@github-actions (bot) added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on May 26, 2026
@zianglih

Contributor Author

Need to rebase and refactor to match the recent CuTe DSL integration.
