[ROCm][DSV4] Enable Tilelang MHC replacing torch/triton mhc#43679
Conversation
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
There was a problem hiding this comment.
Code Review
This pull request enables TileLang-based Multi-Head Latent Attention (MHC) kernels on ROCm (AMD GPUs) by adapting the existing CUDA implementations. It introduces platform-aware checks for Programmatic Dependent Launch (PDL) support, caches TileLang availability, and integrates TileLang operators into the DeepSeek V4 and MTP models. Critical feedback highlights that hardcoding a warp size of 32 in reduction logic will cause incorrect results on ROCm, where the wavefront size is 64; a platform-aware WARP_SIZE constant should be defined and used instead. Additionally, querying device capability via torch.cuda.current_device() at import time should be replaced with a stateless query to prevent eager CUDA initialization.
| lane = tid % 32 | ||
| warp_id = tid // 32 | ||
| num_warps = n_thr // 32 | ||
| warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32) | ||
| warp_sqr = T.alloc_shared((num_warps, block_m), T.float32) |
There was a problem hiding this comment.
On AMD GPUs (ROCm), the hardware wavefront/warp size is 64, not 32. Hardcoding 32 for warp/lane index calculations will cause incorrect reduction results and double-counting when using T.warp_reduce_sum on ROCm. Please use WARP_SIZE instead of 32.
| lane = tid % 32 | |
| warp_id = tid // 32 | |
| num_warps = n_thr // 32 | |
| warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32) | |
| warp_sqr = T.alloc_shared((num_warps, block_m), T.float32) | |
| lane = tid % WARP_SIZE | |
| warp_id = tid // WARP_SIZE | |
| num_warps = n_thr // WARP_SIZE | |
| warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32) | |
| warp_sqr = T.alloc_shared((num_warps, block_m), T.float32) |
There was a problem hiding this comment.
AMD CDNA hardware wavefront size is 64, but TileLang_s HIP warp_reduce_sum intentionally preserves 32-lane logical warp semantics. The installed TileLang source says this directly in:
tilelang/src/tl_templates/hip/reduce.h
Shows that it uses:
__shfl_xor(value, 16, 32)
...
__shfl_xor(value, 1, 32)
https://github.com/tile-ai/tilelang/blob/23d91c584dd98810b1acf91ec83bb1587dadf3c2/src/tl_templates/hip/reduce.h#L161 and https://github.com/tile-ai/tilelang/blob/23d91c584dd98810b1acf91ec83bb1587dadf3c2/src/tl_templates/hip/reduce.h#L171 comment explains that on CDNA wave64, width=32 splits the wavefront into two independent 32-lane logical groups, exactly for kernels that assume CUDA-like 32-lane warp behavior.
| @classmethod | ||
| def is_arch_support_pdl(cls) -> bool: | ||
| try: | ||
| device = torch.cuda.current_device() | ||
| major, _ = torch.cuda.get_device_capability(device) | ||
| except Exception: | ||
| return False | ||
| return major >= 9 |
There was a problem hiding this comment.
Calling torch.cuda.current_device() at import time eagerly initializes the CUDA context, which can break multi-processing, Ray, and distributed setups in vLLM. Since is_arch_support_pdl is called at the module level of vllm/_tilelang_ops.py during import, we should query the device capability statelessly using cls.get_device_capability(0) to avoid eager CUDA initialization.
| @classmethod | |
| def is_arch_support_pdl(cls) -> bool: | |
| try: | |
| device = torch.cuda.current_device() | |
| major, _ = torch.cuda.get_device_capability(device) | |
| except Exception: | |
| return False | |
| return major >= 9 | |
| @classmethod | |
| def is_arch_support_pdl(cls) -> bool: | |
| try: | |
| capability = cls.get_device_capability(0) | |
| if capability is None: | |
| return False | |
| major = capability.major | |
| except Exception: | |
| return False | |
| return major >= 9 |
There was a problem hiding this comment.
I don't think this is necessary as it is only triggered once and all machines on market has homogeneous GPUs within a single node.
| if current_platform.is_rocm(): | ||
| return self._forward_rocm( | ||
| if not self.has_tilelang: | ||
| return self._forward_unfused_post_pre( |
There was a problem hiding this comment.
I would like to keep the unfused path as well as torch fallback path when tilelang is not available and to prepare for a path to reenable the aiter MHC and allow developers easily validate if separate aiter MHC post and pre is faster than tilelang fused post and pre MHC in later PRs.
| ) | ||
|
|
||
| @classmethod | ||
| def is_arch_support_pdl(cls) -> bool: |
There was a problem hiding this comment.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
|
This pull request has merge conflicts that must be resolved before it can be |
| self.head_dim, | ||
| ) | ||
| self._use_cutedsl_sparse_compressor = has_cutedsl() | ||
| if self._use_cutedsl_sparse_compressor: |
There was a problem hiding this comment.
This is a bugfix for import error cutlass not found, introduced in this PR #43584
There was a problem hiding this comment.
I have removed my fix. Thanks for fixing the import issue.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
|
This pull request has merge conflicts that must be resolved before it can be |
|
WOW, tilelang in vllm now. |
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
| # amd-quark: required for Quark quantization on ROCm | ||
| # To be consistent with test_quark.py | ||
| amd-quark>=0.8.99 | ||
| tilelang>=0.1.10 |
There was a problem hiding this comment.
Nice test and even nicer feature. i am wondering for CI purposes to avoid any regressions from future versions, should we pin the version? Maybe we can add it in rocm.in too. We can also do that in a follow-up PR. But let me know if you agree.
There was a problem hiding this comment.
Let's do it in a follow up PR. Would like to land this PR and let @WoosukKwon continue with the restructuring of the mhc kernels.
There was a problem hiding this comment.
Thanks. I have pinned the version to exact version.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Re-enable the aiter multi-head-consensus (mHC) pre/post ops as the preferred ROCm path for DeepSeek V4. PR vllm-project#43679 added a tilelang fused post+pre mHC kernel and left a hook to switch back to the (faster) aiter mHC kernels once an aiter release with the sqrsum race-condition fix was available; this is that follow-up. Dispatch is now purely capability based (no new env knobs): aiter mHC pre/post -> tilelang fused post+pre -> torch/triton reference The aiter kernels are used whenever aiter is available on a supported ROCm device (``is_aiter_found_and_supported()``) and the hidden size is a multiple of 256; otherwise we fall back to the tilelang fused kernel, and finally to the torch/triton reference implementation. On CUDA the tilelang path is unchanged. The aiter mHC kernels require aiter >= 0.1.14, which contains the sqrsum race-condition fix in ``mhc_pre_gemm_sqrsum_kernel`` (ROCm/aiter@b639cb6); without it results are wrong at large token counts. AITER_BRANCH in docker/Dockerfile.rocm_base is bumped v0.1.13 -> v0.1.14. The unfused aiter path applies hc_post inline and returns no residual streams, so the deferred hc_post in DeepseekV4Model.forward and the MTP layer is gated on ``residual is not None`` rather than has_tilelang/is_cuda. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: vLLM Contributor <contributor@vllm.ai> Co-authored-by: Cursor <cursoragent@cursor.com>
Re-enable the aiter multi-head-consensus (mHC) pre/post ops as the preferred ROCm path for DeepSeek V4. PR vllm-project#43679 added a tilelang fused post+pre mHC kernel and left a hook to switch back to the (faster) aiter mHC kernels once an aiter release with the sqrsum race-condition fix was available; this is that follow-up. Dispatch is now purely capability based (no new env knobs): aiter mHC pre/post -> tilelang fused post+pre -> torch/triton reference The aiter kernels are used whenever aiter is available on a supported ROCm device (``is_aiter_found_and_supported()``) and the hidden size is a multiple of 256; otherwise we fall back to the tilelang fused kernel, and finally to the torch/triton reference implementation. On CUDA the tilelang path is unchanged. The aiter mHC kernels require aiter >= 0.1.14, which contains the sqrsum race-condition fix in ``mhc_pre_gemm_sqrsum_kernel`` (ROCm/aiter@b639cb6); without it results are wrong at large token counts. AITER_BRANCH in docker/Dockerfile.rocm_base is bumped v0.1.13 -> v0.1.14. The unfused aiter path applies hc_post inline and returns no residual streams, so the deferred hc_post in DeepseekV4Model.forward and the MTP layer is gated on ``residual is not None`` rather than has_tilelang/is_cuda. Signed-off-by: Fangzhou Ai <fangzhou.ai@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…ject#43679) Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Xiaoran Chen <xiaoran@fb.com>
Purpose
In recent tilelang PR they support Vendor free compilation among CUDA and ROCM wheels in tile-ai/tilelang#2195 . So on ROCm we are pip installing the tilelang wheel from pypi directly.
This PR follows the way in sglang https://github.com/sgl-project/sglang/blob/c47f0e7cdde48ddc718e3c6ee8bc87bebee2e8ff/python/sglang/srt/layers/mhc.py#L88 to add a ENABLE_PDL control so that we can set it to False on unsupported platform like ROCm.
Now the Tilelang kernel is compatible for CUDA and ROCm.
This is used to replace the slow inference torch kernel if tilelang is support on the platform. I would like to keep the unfused path as well as torch fallback path when tilelang is not available and to prepare for a path to reenable the aiter MHC and allow developers easily validate if separate aiter MHC post and pre is faster than tilelang fused post and pre MHC in later PRs.
Test Plan
Lmeval score of no MTP and MTP must be around 0.95 for gsm8k score . MTP acceptance draft rate must be normal > 2.6 for gsm8k.
Perf gain over using torch mhc (before the PR)
kernel test case
tests/kernels/test_mhc_kernels.py======================= 51 passed, 38 warnings in 21.15s =======================Test Result
experimental details
No MTP server command
MTP command
Server benchmark script
Lmeval command (NOTE: numshot 8 and numshot 20 must both give a score of around 0.95 (+-0.01) for both strict and relax; else there could have accuracy issue)
DeepSeek V4 Pro no MTP
DeepSeek V4 Pro with MTP
Acceptance score:
[metrics.py:101] SpecDecoding metrics: Mean acceptance length: 2.71, Accepted throughput: 371.60 tokens/s, Drafted throughput: 435.60 tokens/s, Accepted: 3716 tokens, Drafted: 4356 tokens, Per-position acceptance rate: 0.962, 0.744, Avg Draft acceptance rate: 85.3%Performance gain
no MTP (Before PR - torch mhc/triton mhc and after PR - tilelang mhc)
Throughput and Latency
With MTP (Before PR - torch mhc/triton mhc and after PR - tilelang mhc)
Throughput, Latency, and Spec Decode
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.