From 9676182482d210a43790ed1bea3edc3bbf194b7b Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 25 May 2026 21:11:42 -0500 Subject: [PATCH 01/13] use tilelang mhc on rocm Signed-off-by: tjtanaa --- tests/kernels/test_mhc_kernels.py | 123 +++++++++++++++++++- vllm/_tilelang_ops.py | 91 ++++++++------- vllm/model_executor/kernels/mhc/tilelang.py | 95 ++++++++++----- vllm/model_executor/layers/mhc.py | 74 ++++++++++-- vllm/models/deepseek_v4/amd/model.py | 12 +- vllm/platforms/cuda.py | 8 ++ vllm/platforms/interface.py | 6 + 7 files changed, 326 insertions(+), 83 deletions(-) diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 81fceeceac10..a2e76e9f8acf 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -92,8 +92,88 @@ def hc_head_ref( @pytest.mark.skipif( - not current_platform.is_cuda(), - reason="CUDA required", + not (current_platform.is_cuda() or current_platform.is_rocm()), + reason="CUDA or ROCm required", +) +@pytest.mark.parametrize("num_tokens", [4, 256]) +@pytest.mark.parametrize("hidden_size", [1280, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + hc_mult2 = hc_mult * hc_mult + hc_mult3 = 2 * hc_mult + hc_mult2 + fn = ( + torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float) + * 1e-4 + * (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + hc_scale = torch.randn((3,), dtype=torch.float) * 0.1 + hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1 + + hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6 + sinkhorn_repeat = 20 + hc_post_alpha = 1.0 + + ref = mhc_pre_ref( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_alpha, + sinkhorn_repeat, + ) + out = torch.ops.vllm.mhc_pre_tilelang( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_alpha, + sinkhorn_repeat, + ) + + for actual, expected in zip(out, ref, strict=True): + torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda() or current_platform.is_rocm()), + reason="CUDA or ROCm required", +) +@pytest.mark.parametrize("num_tokens", [4, 256]) +@pytest.mark.parametrize("hidden_size", [1280, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16) + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32) + comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32) + + ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix) + out = torch.ops.vllm.mhc_post_tilelang( + x, + residual, + post_layer_mix, + comb_res_mix, + ) + + torch.testing.assert_close(out, ref, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda() or current_platform.is_rocm()), + reason="CUDA or ROCm required", ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) @pytest.mark.parametrize("hidden_size", [4096, 7168]) @@ -196,3 +276,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult): out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps) torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda() or current_platform.is_rocm()), + reason="CUDA or ROCm required", +) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4 + hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1 + hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1 + rms_eps = hc_eps = 1e-6 + + out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16) + out.fill_(float("nan")) + + result = torch.ops.vllm.hc_head_fused_kernel_tilelang( + residual, + fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_eps, + hc_eps, + hc_mult, + ) + + assert result is None + assert not torch.isnan(out).any() + + out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps) + torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index aa742fe50320..9f2827cceee3 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -10,8 +10,20 @@ from vllm.utils.import_utils import has_tilelang from vllm.utils.math_utils import cdiv -# tilelang is only available on CUDA platforms -if TYPE_CHECKING or current_platform.is_cuda(): +def _is_tilelang_platform() -> bool: + return current_platform.is_cuda() or current_platform.is_rocm() + + +def _is_pdl_supported() -> bool: + is_arch_support_pdl = getattr(current_platform, "is_arch_support_pdl", None) + if not callable(is_arch_support_pdl): + return False + return is_arch_support_pdl() + + +# TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so +# registering the Python wrapper modules does not require TileLang everywhere. +if TYPE_CHECKING or _is_tilelang_platform(): if not has_tilelang(): raise ImportError( "tilelang is required for mhc but is not installed. Install it with " @@ -23,6 +35,7 @@ tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] +ENABLE_PDL = _is_pdl_supported() @cache def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: @@ -37,12 +50,17 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: return split_k +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +if current_platform.is_cuda(): + pass_configs[tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL] = 10 + + @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_pre_big_fuse_tilelang( gemm_out_mul, @@ -78,7 +96,8 @@ def mhc_pre_big_fuse_tilelang( layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] with T.Kernel(num_tokens, threads=96) as i: - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() ################################################################## # _pre_norm_fn_fwd_norm rms = T.alloc_fragment(1, T.float32) @@ -174,18 +193,16 @@ def mhc_pre_big_fuse_tilelang( ol[i1_h] += pre * xl[i_hc, i1_h] T.copy(ol, layer_input[i, i0_h * hidden_block]) - T.pdl_trigger() + + if ENABLE_PDL: + T.pdl_trigger() # Copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/mhc.py#L478 @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_pre_big_fuse_with_norm_tilelang( gemm_out_mul, @@ -230,7 +247,8 @@ def mhc_pre_big_fuse_with_norm_tilelang( T.clear(mixes) rms[0] = 0 - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() for i_split in T.serial(n_splits): rms[0] += gemm_out_sqrsum[i_split, i] @@ -341,15 +359,12 @@ def mhc_pre_big_fuse_with_norm_tilelang( T.copy(ol, layer_input[i, i0_h * hidden_block]) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_fused_tilelang( comb_mix, @@ -390,8 +405,8 @@ def mhc_fused_tilelang( with T.Kernel(m, n_tiles, split_k, threads=n_thr) as (i_n, i_nt, i_ks): tid = T.get_thread_binding() - warp_id = T.get_warp_idx() - lane = T.get_lane_idx() + warp_id = tid // 32 + lane = tid % 32 s_warp = T.alloc_shared((num_warps, tile_n + 1), T.float32) s_post = T.alloc_shared((hc,), T.float32) @@ -407,7 +422,8 @@ def mhc_fused_tilelang( T.clear(sqr) h_split_start = i_ks * h_per_split - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() T.copy(post_mix[i_n, 0], s_post) T.copy(comb_mix[i_n, 0, 0], s_comb) @@ -466,15 +482,12 @@ def mhc_fused_tilelang( v2 += s_warp[w, tile_n] rp_out[i_ks, i_n] = v2 - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_post_tilelang( a, @@ -507,7 +520,8 @@ def mhc_post_tilelang( a_local = T.alloc_fragment((hc, hc), T.float32) c_local = T.alloc_fragment(hc, T.float32) - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() T.copy(a[i_n, 0, 0], a_local) T.copy(c[i_n, 0], c_local) @@ -523,15 +537,12 @@ def mhc_post_tilelang( x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] T.copy(x_local, x[i_n, 0, i0_h * h_blk]) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def hc_head_fuse_tilelang( residual, @@ -566,7 +577,8 @@ def hc_head_fuse_tilelang( out: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef,valid-type] with T.Kernel(num_tokens, threads=n_thr) as i: - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() # ------------------------------------------------------------------ # Pass 1 – for each residual channel m_c and h_block: @@ -624,4 +636,5 @@ def hc_head_fuse_tilelang( T.copy(ol, out[i, i0_h * h_block], disable_tma=True) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index c242fef2d026..3000cd402dd1 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -5,6 +5,25 @@ from vllm.utils.torch_utils import direct_register_custom_op +def _can_use_deep_gemm_hc_prenorm() -> bool: + from vllm.utils.deep_gemm import is_deep_gemm_supported + + return is_deep_gemm_supported() + + +def _torch_hc_prenorm_gemm( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, +) -> None: + assert out.shape[0] == 1 + assert sqrsum.shape[0] == 1 + x_float = x.float() + out[0].copy_(x_float @ fn.t()) + sqrsum[0].copy_(x_float.square().sum(dim=-1)) + + def mhc_pre_tilelang( residual: torch.Tensor, fn: torch.Tensor, @@ -80,10 +99,16 @@ def mhc_pre_tilelang( residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] - # these numbers are from deepgemm kernel impl - block_k = 64 - block_m = 64 - n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + use_deep_gemm = _can_use_deep_gemm_hc_prenorm() + if use_deep_gemm: + # these numbers are from deepgemm kernel impl + block_k = 64 + block_m = 64 + n_splits = compute_num_split( + block_k, hc_hidden_size, cdiv(num_tokens, block_m) + ) + else: + n_splits = 1 post_mix = torch.empty( num_tokens, hc_mult, dtype=torch.float32, device=residual.device @@ -102,13 +127,17 @@ def mhc_pre_tilelang( n_splits, num_tokens, dtype=torch.float32, device=residual.device ) - tf32_hc_prenorm_gemm( - residual_flat.view(num_tokens, hc_mult * hidden_size), - fn, - gemm_out_mul, - gemm_out_sqrsum, - n_splits, - ) + residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size) + if use_deep_gemm: + tf32_hc_prenorm_gemm( + residual_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + n_splits, + ) + else: + _torch_hc_prenorm_gemm(residual_2d, fn, gemm_out_mul, gemm_out_sqrsum) if norm_weight is None: mhc_pre_big_fuse_tilelang( @@ -304,16 +333,22 @@ def mhc_fused_post_pre_tilelang( post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult) comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult) - fma_token_threshold = 16 - if num_tokens <= fma_token_threshold: + use_deep_gemm = _can_use_deep_gemm_hc_prenorm() + use_small_fma = num_tokens <= 16 + if use_small_fma: # TODO(gnovack): investigate autotuning these heuristics tile_n = 2 if num_tokens < 8 else 3 n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4 else: - # these number are from deepgemm kernel impl - block_k = 64 - block_m = 64 - n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + if use_deep_gemm: + # these number are from deepgemm kernel impl + block_k = 64 + block_m = 64 + n_splits = compute_num_split( + block_k, hc_hidden_size, cdiv(num_tokens, block_m) + ) + else: + n_splits = 1 gemm_out_mul = torch.empty( n_splits, @@ -348,7 +383,7 @@ def mhc_fused_post_pre_tilelang( device=residual.device, ) - if num_tokens <= fma_token_threshold: + if use_small_fma: mhc_fused_tilelang( comb_res_mix_flat, residual_flat, @@ -375,15 +410,21 @@ def mhc_fused_post_pre_tilelang( residual.shape[-1], ) - from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm - - tf32_hc_prenorm_gemm( - residual_cur.view(num_tokens, hc_mult * hidden_size), - fn, - gemm_out_mul, - gemm_out_sqrsum, - n_splits, - ) + residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size) + if use_deep_gemm: + from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm + + tf32_hc_prenorm_gemm( + residual_cur_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + n_splits, + ) + else: + _torch_hc_prenorm_gemm( + residual_cur_2d, fn, gemm_out_mul, gemm_out_sqrsum + ) if norm_weight is None: mhc_pre_big_fuse_tilelang( diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index cb3965f764ba..ecf3b7a3c024 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -3,6 +3,7 @@ import torch # this import will also register the custom ops +# import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels from vllm.model_executor.custom_op import CustomOp @@ -85,7 +86,18 @@ def forward_hip( # sinkhorn_repeat, # ) # else: - return mhc_kernels.mhc_pre_torch( + # return mhc_kernels.mhc_pre_torch( + # residual, + # fn, + # hc_scale, + # hc_base, + # rms_eps, + # hc_pre_eps, + # hc_sinkhorn_eps, + # hc_post_mult_value, + # sinkhorn_repeat, + # ) + return torch.ops.vllm.mhc_pre_tilelang( residual, fn, hc_scale, @@ -95,6 +107,9 @@ def forward_hip( hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, + n_splits, + norm_weight, + norm_eps, ) def forward_native(self, *args, **kwargs): @@ -147,11 +162,14 @@ def forward_hip( # comb_res_mix, # ) # else: - return mhc_kernels.mhc_post_torch( - x, - residual, - post_layer_mix, - comb_res_mix, + # return mhc_kernels.mhc_post_torch( + # x, + # residual, + # post_layer_mix, + # comb_res_mix, + # ) + return torch.ops.vllm.mhc_post_tilelang( + x, residual, post_layer_mix, comb_res_mix ) def forward_native(self, *args, **kwargs): @@ -220,7 +238,7 @@ def forward_hip( out = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device ) - torch.ops.vllm.hc_head_triton( + torch.ops.vllm.hc_head_fused_kernel_tilelang( hs_flat, hc_fn, hc_scale, @@ -290,10 +308,46 @@ def forward_cuda( norm_eps, ) - def forward_hip(self, *args, **kwargs): - raise NotImplementedError( - "Hip implementation of mhc_fused_post_pre is not available" + def forward_hip( + self, + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + tile_n: int = 1, + norm_weight: torch.Tensor | None = None, + norm_eps: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return torch.ops.vllm.mhc_fused_post_pre_tilelang( + x, + residual, + post_layer_mix, + comb_res_mix, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + tile_n, + norm_weight, + norm_eps, ) + # raise NotImplementedError( + # "Hip implementation of mhc_fused_post_pre is not available" + # ) def forward_native(self, *args, **kwargs): raise NotImplementedError( diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 1540667d1a4b..1680d10a192e 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -1195,10 +1195,10 @@ def forward( ) -> tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None ]: - if current_platform.is_rocm(): - return self._forward_rocm( - x, positions, input_ids, post_mix, res_mix, residual - ) + # if current_platform.is_rocm(): + # return self._forward_rocm( + # x, positions, input_ids, post_mix, res_mix, residual + # ) return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual) @@ -1361,7 +1361,9 @@ def forward( res_mix, residual, ) - if layer is not None and current_platform.is_cuda(): + if layer is not None and ( + current_platform.is_cuda() or current_platform.is_rocm() + ): hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 4a5be741d06b..b5cfce40b1b9 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -592,6 +592,14 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf default, rms_norm=rms_norm, fused_add_rms_norm=rms_norm ) + @classmethod + def is_arch_support_pdl(cls) -> bool: + try: + device = torch.cuda.current_device() + major, _ = torch.cuda.get_device_capability(device) + except Exception: + return False + return major >= 9 # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 9a93ef9f82a7..8c1046881113 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1016,6 +1016,12 @@ def get_default_ir_op_priority( # Native always used by default. Platforms can override this behavior. return IrOpPriorityConfig.with_default(["native"]) + @classmethod + def is_arch_support_pdl(cls) -> bool: + """ + Does the current platform support PDL (Programmatic Dependent Launch)? + """ + return False class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED From 5be8268972dcad982aab020898e3eb9b1e80bc0c Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 25 May 2026 21:21:36 -0500 Subject: [PATCH 02/13] clean up mhc test Signed-off-by: tjtanaa --- tests/kernels/test_mhc_kernels.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index a2e76e9f8acf..597f8e1ff7cd 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -6,6 +6,7 @@ import vllm.model_executor.kernels.mhc # noqa: F401 from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed +from vllm.utils.import_utils import has_tilelang DEVICE = current_platform.device_type @@ -92,11 +93,11 @@ def hc_head_ref( @pytest.mark.skipif( - not (current_platform.is_cuda() or current_platform.is_rocm()), - reason="CUDA or ROCm required", + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", ) -@pytest.mark.parametrize("num_tokens", [4, 256]) -@pytest.mark.parametrize("hidden_size", [1280, 7168]) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) @pytest.mark.parametrize("hc_mult", [4]) def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult): torch.set_default_device(DEVICE) @@ -145,11 +146,11 @@ def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult): @pytest.mark.skipif( - not (current_platform.is_cuda() or current_platform.is_rocm()), - reason="CUDA or ROCm required", + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", ) -@pytest.mark.parametrize("num_tokens", [4, 256]) -@pytest.mark.parametrize("hidden_size", [1280, 7168]) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) @pytest.mark.parametrize("hc_mult", [4]) def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult): torch.set_default_device(DEVICE) @@ -172,8 +173,8 @@ def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult): @pytest.mark.skipif( - not (current_platform.is_cuda() or current_platform.is_rocm()), - reason="CUDA or ROCm required", + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) @pytest.mark.parametrize("hidden_size", [4096, 7168]) @@ -279,8 +280,8 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult): @pytest.mark.skipif( - not (current_platform.is_cuda() or current_platform.is_rocm()), - reason="CUDA or ROCm required", + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) @pytest.mark.parametrize("hidden_size", [4096, 7168]) From 122ad4490afea61a884ffd6c88714bc8387b2e0e Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 25 May 2026 21:21:59 -0500 Subject: [PATCH 03/13] add tilelang to requirements.txt Signed-off-by: tjtanaa --- requirements/build/rocm.txt | 1 + requirements/rocm.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements/build/rocm.txt b/requirements/build/rocm.txt index e5c2176a2c8c..752fb1db786a 100644 --- a/requirements/build/rocm.txt +++ b/requirements/build/rocm.txt @@ -16,3 +16,4 @@ wheel jinja2>=3.1.6 amdsmi==7.0.2 timm>=1.0.17 +tilelang>=0.1.10 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 61fcbc07010c..37bdf1afd15a 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -22,3 +22,4 @@ timm>=1.0.17 # amd-quark: required for Quark quantization on ROCm # To be consistent with test_quark.py amd-quark>=0.8.99 +tilelang>=0.1.10 From 63d5159e30c5716bb1c908f7677adc169147eefc Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 25 May 2026 23:02:52 -0500 Subject: [PATCH 04/13] add tilelang_hc_pre_norm_gemm Signed-off-by: tjtanaa --- tests/kernels/test_mhc_kernels.py | 48 ++++++ vllm/_tilelang_ops.py | 181 ++++++++++++++++++++ vllm/model_executor/kernels/mhc/tilelang.py | 87 +++++++++- 3 files changed, 313 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 597f8e1ff7cd..985cb6830105 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -4,6 +4,10 @@ import torch import vllm.model_executor.kernels.mhc # noqa: F401 +from vllm.model_executor.kernels.mhc.tilelang import ( + _tilelang_hc_prenorm_gemm, + _torch_hc_prenorm_gemm, +) from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed from vllm.utils.import_utils import has_tilelang @@ -145,6 +149,50 @@ def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult): torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2) +@pytest.mark.skipif( + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", +) +@pytest.mark.parametrize( + ("num_tokens", "hidden_size"), + [ + (1, 1280), + (512, 1280), + (2048, 1280), + (1, 4096), + (64, 4096), + (512, 4096), + (2048, 4096), + (1, 7168), + (64, 7168), + (512, 7168), + (2048, 7168), + ], +) +def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size): + torch.set_default_device(DEVICE) + set_random_seed(0) + + hc_mult = 4 + hc_mult3 = 2 * hc_mult + hc_mult * hc_mult + x = torch.randn( + (num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16 + ) + fn = torch.randn( + (hc_mult3, hc_mult * hidden_size), dtype=torch.float32 + ) * 1e-4 + out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32) + sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32) + out = torch.empty_like(out_ref) + sqrsum = torch.empty_like(sqrsum_ref) + + _torch_hc_prenorm_gemm(x, fn, out_ref, sqrsum_ref) + _tilelang_hc_prenorm_gemm(x, fn, out, sqrsum, hidden_size, hc_mult) + + torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(sqrsum, sqrsum_ref, atol=8.0, rtol=5e-4) + + @pytest.mark.skipif( not (current_platform.is_cuda_alike() and has_tilelang()), reason="CUDA or ROCm and tilelang required", diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index 9f2827cceee3..d30193f08a67 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -541,6 +541,187 @@ def mhc_post_tilelang( T.pdl_trigger() +@tilelang.jit( + pass_configs=pass_configs, +) +def hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size: int, + hc_mult: int = 4, + n_out: int = 24, + n_thr: int = 512, + tile_n: int = 12, + n_splits: int = 1, +) -> tilelang.JITKernel: + num_tokens = T.dynamic("num_tokens") + hc_hidden_size = hc_mult * hidden_size + k_per_split = hc_hidden_size // n_splits + k_iters = k_per_split // n_thr + n_tiles = T.ceildiv(n_out, tile_n) + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type] + fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type] + out: T.Tensor((n_splits, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type] + sqrsum: T.Tensor((n_splits, num_tokens), T.float32) # type: ignore[no-redef, valid-type] + + with T.Kernel(num_tokens, n_tiles, n_splits, threads=n_thr) as ( + i_n, + i_t, + i_s, + ): + tid = T.get_thread_binding() + acc = T.alloc_local((tile_n,), T.float32) + sqr = T.alloc_local((1,), T.float32) + T.clear(acc) + T.clear(sqr) + + if ENABLE_PDL: + T.pdl_sync() + + for it in T.serial(k_iters): + i_k = i_s * k_per_split + it * n_thr + tid + x_val = x[i_n, i_k] + for i_o in T.unroll(tile_n): + out_idx = i_t * tile_n + i_o + if out_idx < n_out: + acc[i_o] += x_val * fn[out_idx, i_k] + if i_t == 0: + sqr[0] += x_val * x_val + + for i_o in T.unroll(tile_n): + acc[i_o] = T.warp_reduce_sum(acc[i_o]) + if i_t == 0: + sqr[0] = T.warp_reduce_sum(sqr[0]) + + lane = tid % 32 + warp_id = tid // 32 + num_warps = n_thr // 32 + warp_acc = T.alloc_shared((num_warps, tile_n), T.float32) + warp_sqr = T.alloc_shared(num_warps, T.float32) + + if lane == 0: + for i_o in T.unroll(tile_n): + warp_acc[warp_id, i_o] = acc[i_o] + if i_t == 0: + warp_sqr[warp_id] = sqr[0] + T.sync_threads() + + if warp_id == 0: + if lane < tile_n: + reduced_acc = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_acc += warp_acc[i_w, lane] + out_idx = i_t * tile_n + lane + if out_idx < n_out: + out[i_s, i_n, out_idx] = reduced_acc + if lane == 0 and i_t == 0: + reduced_sqr = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_sqr += warp_sqr[i_w] + sqrsum[i_s, i_n] = reduced_sqr + + if ENABLE_PDL: + T.pdl_trigger() + + +@tilelang.jit( + pass_configs=pass_configs, +) +def hc_prenorm_gemm_block_m_tilelang( + x, + fn, + out, + sqrsum, + hidden_size: int, + hc_mult: int = 4, + n_out: int = 24, + n_thr: int = 512, + tile_n: int = 12, + block_m: int = 2, +) -> tilelang.JITKernel: + num_tokens = T.dynamic("num_tokens") + hc_hidden_size = hc_mult * hidden_size + k_iters = hc_hidden_size // n_thr + n_tiles = T.ceildiv(n_out, tile_n) + m_tiles = T.ceildiv(num_tokens, block_m) + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type] + fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type] + out: T.Tensor((1, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type] + sqrsum: T.Tensor((1, num_tokens), T.float32) # type: ignore[no-redef, valid-type] + + with T.Kernel(m_tiles, n_tiles, threads=n_thr) as (i_mt, i_t): + tid = T.get_thread_binding() + acc = T.alloc_local((block_m, tile_n), T.float32) + sqr = T.alloc_local((block_m,), T.float32) + T.clear(acc) + T.clear(sqr) + + if ENABLE_PDL: + T.pdl_sync() + + for it in T.serial(k_iters): + i_k = it * n_thr + tid + fn_val = T.alloc_local((tile_n,), T.float32) + for i_o in T.unroll(tile_n): + out_idx = i_t * tile_n + i_o + if out_idx < n_out: + fn_val[i_o] = fn[out_idx, i_k] + else: + fn_val[i_o] = 0.0 + for i_m in T.unroll(block_m): + token_idx = i_mt * block_m + i_m + if token_idx < num_tokens: + x_val = x[token_idx, i_k] + for i_o in T.unroll(tile_n): + acc[i_m, i_o] += x_val * fn_val[i_o] + if i_t == 0: + sqr[i_m] += x_val * x_val + + for i_m in T.unroll(block_m): + for i_o in T.unroll(tile_n): + acc[i_m, i_o] = T.warp_reduce_sum(acc[i_m, i_o]) + if i_t == 0: + sqr[i_m] = T.warp_reduce_sum(sqr[i_m]) + + lane = tid % 32 + warp_id = tid // 32 + num_warps = n_thr // 32 + warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32) + warp_sqr = T.alloc_shared((num_warps, block_m), T.float32) + + if lane == 0: + for i_m in T.unroll(block_m): + for i_o in T.unroll(tile_n): + warp_acc[warp_id, i_m, i_o] = acc[i_m, i_o] + if i_t == 0: + warp_sqr[warp_id, i_m] = sqr[i_m] + T.sync_threads() + + if warp_id == 0: + for i_m in T.unroll(block_m): + token_idx = i_mt * block_m + i_m + if token_idx < num_tokens: + if lane < tile_n: + reduced_acc = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_acc += warp_acc[i_w, i_m, lane] + out_idx = i_t * tile_n + lane + if out_idx < n_out: + out[0, token_idx, out_idx] = reduced_acc + if lane == 0 and i_t == 0: + reduced_sqr = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_sqr += warp_sqr[i_w, i_m] + sqrsum[0, token_idx] = reduced_sqr + + if ENABLE_PDL: + T.pdl_trigger() + + @tilelang.jit( pass_configs=pass_configs, ) diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index 3000cd402dd1..16ef72fcf19c 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -24,6 +24,75 @@ def _torch_hc_prenorm_gemm( sqrsum[0].copy_(x_float.square().sum(dim=-1)) +def _tilelang_hc_prenorm_gemm( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + hidden_size: int, + hc_mult: int, + tile_n: int = 12, + n_thr: int = 512, + n_splits: int = 1, +) -> None: + from vllm._tilelang_ops import ( + hc_prenorm_gemm_block_m_tilelang, + hc_prenorm_gemm_tilelang, + ) + + assert out.shape[0] == n_splits + assert sqrsum.shape[0] == n_splits + assert x.shape[1] == hc_mult * hidden_size + assert x.shape[1] % n_splits == 0 + assert (x.shape[1] // n_splits) % n_thr == 0 + use_default_config = tile_n == 12 and n_thr == 512 + if n_splits == 1 and use_default_config and x.shape[0] >= 1024: + hc_prenorm_gemm_block_m_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + n_thr, + tile_n, + 2, + ) + return + if ( + n_splits == 1 + and use_default_config + and x.shape[0] < 128 + and x.shape[1] % 1024 == 0 + ): + hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + 1024, + 4, + n_splits, + ) + return + hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + n_thr, + tile_n, + n_splits, + ) + + def mhc_pre_tilelang( residual: torch.Tensor, fn: torch.Tensor, @@ -137,7 +206,14 @@ def mhc_pre_tilelang( n_splits, ) else: - _torch_hc_prenorm_gemm(residual_2d, fn, gemm_out_mul, gemm_out_sqrsum) + _tilelang_hc_prenorm_gemm( + residual_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + hidden_size, + hc_mult, + ) if norm_weight is None: mhc_pre_big_fuse_tilelang( @@ -422,8 +498,13 @@ def mhc_fused_post_pre_tilelang( n_splits, ) else: - _torch_hc_prenorm_gemm( - residual_cur_2d, fn, gemm_out_mul, gemm_out_sqrsum + _tilelang_hc_prenorm_gemm( + residual_cur_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + hidden_size, + hc_mult, ) if norm_weight is None: From c966a7dab19b33afd4c933d3f948b1b80fa814ab Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 26 May 2026 01:01:06 -0500 Subject: [PATCH 05/13] fix prefix commit; support fuse and unfused codepath Signed-off-by: tjtanaa --- tests/kernels/test_mhc_kernels.py | 10 +-- vllm/_tilelang_ops.py | 6 +- vllm/model_executor/kernels/mhc/tilelang.py | 4 +- vllm/model_executor/layers/mhc.py | 93 +++++++++++++++------ vllm/models/deepseek_v4/amd/model.py | 23 ++--- vllm/platforms/cuda.py | 1 + vllm/platforms/interface.py | 1 + 7 files changed, 91 insertions(+), 47 deletions(-) diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 985cb6830105..e7d4cde43f1d 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -9,8 +9,8 @@ _torch_hc_prenorm_gemm, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import set_random_seed from vllm.utils.import_utils import has_tilelang +from vllm.utils.torch_utils import set_random_seed DEVICE = current_platform.device_type @@ -175,12 +175,8 @@ def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size): hc_mult = 4 hc_mult3 = 2 * hc_mult + hc_mult * hc_mult - x = torch.randn( - (num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16 - ) - fn = torch.randn( - (hc_mult3, hc_mult * hidden_size), dtype=torch.float32 - ) * 1e-4 + x = torch.randn((num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16) + fn = torch.randn((hc_mult3, hc_mult * hidden_size), dtype=torch.float32) * 1e-4 out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32) sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32) out = torch.empty_like(out_ref) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index d30193f08a67..3b34f7f098a7 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from functools import cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch @@ -10,6 +10,7 @@ from vllm.utils.import_utils import has_tilelang from vllm.utils.math_utils import cdiv + def _is_tilelang_platform() -> bool: return current_platform.is_cuda() or current_platform.is_rocm() @@ -37,6 +38,7 @@ def _is_pdl_supported() -> bool: ENABLE_PDL = _is_pdl_supported() + @cache def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: device_props = torch.cuda.get_device_properties(0) @@ -50,7 +52,7 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: return split_k -pass_configs = { +pass_configs: dict[tilelang.PassConfigKey, Any] = { tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, } diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index 16ef72fcf19c..61f5c6aa2fa7 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -173,9 +173,7 @@ def mhc_pre_tilelang( # these numbers are from deepgemm kernel impl block_k = 64 block_m = 64 - n_splits = compute_num_split( - block_k, hc_hidden_size, cdiv(num_tokens, block_m) - ) + n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) else: n_splits = 1 diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index ecf3b7a3c024..aadefac380f5 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -86,17 +86,6 @@ def forward_hip( # sinkhorn_repeat, # ) # else: - # return mhc_kernels.mhc_pre_torch( - # residual, - # fn, - # hc_scale, - # hc_base, - # rms_eps, - # hc_pre_eps, - # hc_sinkhorn_eps, - # hc_post_mult_value, - # sinkhorn_repeat, - # ) return torch.ops.vllm.mhc_pre_tilelang( residual, fn, @@ -112,8 +101,32 @@ def forward_hip( norm_eps, ) - def forward_native(self, *args, **kwargs): - raise NotImplementedError("Native implementation of mhc_pre is not available") + def forward_native( + self, + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + norm_weight: torch.Tensor | None = None, + norm_eps: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return mhc_kernels.mhc_pre_torch( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + ) # --8<-- [start:mhc_post] @@ -162,18 +175,23 @@ def forward_hip( # comb_res_mix, # ) # else: - # return mhc_kernels.mhc_post_torch( - # x, - # residual, - # post_layer_mix, - # comb_res_mix, - # ) return torch.ops.vllm.mhc_post_tilelang( x, residual, post_layer_mix, comb_res_mix ) - def forward_native(self, *args, **kwargs): - raise NotImplementedError("Native implementation of mhc_post is not available") + def forward_native( + self, + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, + ) -> torch.Tensor: + return mhc_kernels.mhc_post_torch( + x, + residual, + post_layer_mix, + comb_res_mix, + ) # --8<-- [start:hc_head] @@ -251,6 +269,36 @@ def forward_hip( ) return out.view(*outer_shape, hidden_size) + def _forward_triton( + self, + hidden_states: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_norm_eps: float, + hc_eps: float, + ) -> torch.Tensor: + hc_mult, hidden_size = hidden_states.shape[-2:] + outer_shape = hidden_states.shape[:-2] + hs_flat = hidden_states.view(-1, hc_mult, hidden_size) + num_tokens = hs_flat.shape[0] + + out = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + torch.ops.vllm.hc_head_triton( + hs_flat, + hc_fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_norm_eps, + hc_eps, + hc_mult, + ) + return out.view(*outer_shape, hidden_size) + def forward_native(self, *args, **kwargs): raise NotImplementedError("Native implementation of hc_head is not available") @@ -345,9 +393,6 @@ def forward_hip( norm_weight, norm_eps, ) - # raise NotImplementedError( - # "Hip implementation of mhc_fused_post_pre is not available" - # ) def forward_native(self, *args, **kwargs): raise NotImplementedError( diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 1680d10a192e..035aee4bd3ab 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -61,6 +61,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton +from vllm.utils.import_utils import has_tilelang from vllm.utils.torch_utils import direct_register_custom_op @@ -1074,6 +1075,7 @@ def __init__( self.mhc_pre = MHCPreOp() self.mhc_post = MHCPostOp() self.mhc_fused_post_pre = MHCFusedPostPreOp() + self.has_tilelang = has_tilelang() def hc_pre( self, @@ -1104,7 +1106,7 @@ def hc_post( ): return self.mhc_post(x, residual, post, comb) - def _forward_cuda( + def _forward_fused_post_pre( self, x: torch.Tensor, positions: torch.Tensor, @@ -1156,7 +1158,7 @@ def _forward_cuda( x = self.ffn(x, input_ids) return x, residual, post_mix, res_mix - def _forward_rocm( + def _forward_unfused_post_pre( self, x: torch.Tensor, positions: torch.Tensor, @@ -1195,12 +1197,13 @@ def forward( ) -> tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None ]: - # if current_platform.is_rocm(): - # return self._forward_rocm( - # x, positions, input_ids, post_mix, res_mix, residual - # ) - - return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual) + if not self.has_tilelang: + return self._forward_unfused_post_pre( + x, positions, input_ids, post_mix, res_mix, residual + ) + return self._forward_fused_post_pre( + x, positions, input_ids, post_mix, res_mix, residual + ) @support_torch_compile @@ -1361,9 +1364,7 @@ def forward( res_mix, residual, ) - if layer is not None and ( - current_platform.is_cuda() or current_platform.is_rocm() - ): + if layer is not None and self.has_tilelang: hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b5cfce40b1b9..4b6f52b12fea 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -601,6 +601,7 @@ def is_arch_support_pdl(cls) -> bool: return False return major >= 9 + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 8c1046881113..cf774b7bda91 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1023,6 +1023,7 @@ def is_arch_support_pdl(cls) -> bool: """ return False + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED device_type = "" From d2f3c6dadde6cbeec63ceb9242742e6b5f3e4165 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 26 May 2026 02:32:57 -0500 Subject: [PATCH 06/13] use tilelang mhc in mtp as well Signed-off-by: tjtanaa --- vllm/model_executor/layers/mhc.py | 119 ++++++++++++++------------- vllm/models/deepseek_v4/amd/model.py | 1 + vllm/models/deepseek_v4/amd/mtp.py | 4 +- vllm/utils/import_utils.py | 1 + 4 files changed, 68 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index aadefac380f5..b720fa1f6fe2 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -6,6 +6,9 @@ # import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels from vllm.model_executor.custom_op import CustomOp +from vllm.utils.import_utils import has_tilelang + +HAS_TILELANG = has_tilelang() # --8<-- [start:mhc_pre] @@ -86,20 +89,36 @@ def forward_hip( # sinkhorn_repeat, # ) # else: - return torch.ops.vllm.mhc_pre_tilelang( - residual, - fn, - hc_scale, - hc_base, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - n_splits, - norm_weight, - norm_eps, - ) + if HAS_TILELANG: + return torch.ops.vllm.mhc_pre_tilelang( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + norm_weight, + norm_eps, + ) + else: + return self.forward_native( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + norm_weight, + norm_eps, + ) def forward_native( self, @@ -175,9 +194,12 @@ def forward_hip( # comb_res_mix, # ) # else: - return torch.ops.vllm.mhc_post_tilelang( - x, residual, post_layer_mix, comb_res_mix - ) + if HAS_TILELANG: + return torch.ops.vllm.mhc_post_tilelang( + x, residual, post_layer_mix, comb_res_mix + ) + else: + return self.forward_native(x, residual, post_layer_mix, comb_res_mix) def forward_native( self, @@ -256,47 +278,32 @@ def forward_hip( out = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device ) - torch.ops.vllm.hc_head_fused_kernel_tilelang( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) - return out.view(*outer_shape, hidden_size) - def _forward_triton( - self, - hidden_states: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_norm_eps: float, - hc_eps: float, - ) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - num_tokens = hs_flat.shape[0] + if HAS_TILELANG: + torch.ops.vllm.hc_head_fused_kernel_tilelang( + hs_flat, + hc_fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_norm_eps, + hc_eps, + hc_mult, + ) + else: + torch.ops.vllm.hc_head_triton( + hs_flat, + hc_fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_norm_eps, + hc_eps, + hc_mult, + ) - out = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device - ) - torch.ops.vllm.hc_head_triton( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) return out.view(*outer_shape, hidden_size) def forward_native(self, *args, **kwargs): diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 035aee4bd3ab..b1d8618f98bc 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -1295,6 +1295,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) self.hc_head_op = HCHeadOp() + self.has_tilelang = has_tilelang() # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. diff --git a/vllm/models/deepseek_v4/amd/mtp.py b/vllm/models/deepseek_v4/amd/mtp.py index bcdd76de4c29..7f7501d7a4f0 100644 --- a/vllm/models/deepseek_v4/amd/mtp.py +++ b/vllm/models/deepseek_v4/amd/mtp.py @@ -39,6 +39,7 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils.import_utils import has_tilelang from .model import ( DeepseekV4DecoderLayer, @@ -121,6 +122,7 @@ def __init__( ) self.hc_head_op = HCHeadOp() + self.has_tilelang = has_tilelang() def forward( self, @@ -147,7 +149,7 @@ def forward( hidden_states, residual, post_mix, res_mix = self.mtp_block( positions=positions, x=hidden_states, input_ids=None ) - if current_platform.is_cuda(): + if self.has_tilelang: hidden_states = self.mtp_block.hc_post( hidden_states, residual, post_mix, res_mix ) diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 5822e5840afc..e97228bfa609 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -430,6 +430,7 @@ def has_triton_kernels() -> bool: return is_available +@cache def has_tilelang() -> bool: """Whether the optional `tilelang` package is available.""" return _has_module("tilelang") From 9111eeee93ed21619f18b8ff9cb29920599dc420 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 26 May 2026 08:28:42 -0500 Subject: [PATCH 07/13] clean up code Signed-off-by: tjtanaa --- vllm/model_executor/kernels/mhc/tilelang.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index 61f5c6aa2fa7..d76123bb7625 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -5,12 +5,6 @@ from vllm.utils.torch_utils import direct_register_custom_op -def _can_use_deep_gemm_hc_prenorm() -> bool: - from vllm.utils.deep_gemm import is_deep_gemm_supported - - return is_deep_gemm_supported() - - def _torch_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -168,7 +162,9 @@ def mhc_pre_tilelang( residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] - use_deep_gemm = _can_use_deep_gemm_hc_prenorm() + from vllm.utils.deep_gemm import is_deep_gemm_supported + + use_deep_gemm = is_deep_gemm_supported() if use_deep_gemm: # these numbers are from deepgemm kernel impl block_k = 64 @@ -407,7 +403,9 @@ def mhc_fused_post_pre_tilelang( post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult) comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult) - use_deep_gemm = _can_use_deep_gemm_hc_prenorm() + from vllm.utils.deep_gemm import is_deep_gemm_supported + + use_deep_gemm = is_deep_gemm_supported() use_small_fma = num_tokens <= 16 if use_small_fma: # TODO(gnovack): investigate autotuning these heuristics From fbb43d77fe2e05e1cae8a84f616252662b0a02ca Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 26 May 2026 09:14:26 -0500 Subject: [PATCH 08/13] clean up code Signed-off-by: tjtanaa --- vllm/_tilelang_ops.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index 3b34f7f098a7..8808a1ef0abf 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -11,10 +11,6 @@ from vllm.utils.math_utils import cdiv -def _is_tilelang_platform() -> bool: - return current_platform.is_cuda() or current_platform.is_rocm() - - def _is_pdl_supported() -> bool: is_arch_support_pdl = getattr(current_platform, "is_arch_support_pdl", None) if not callable(is_arch_support_pdl): @@ -24,7 +20,7 @@ def _is_pdl_supported() -> bool: # TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so # registering the Python wrapper modules does not require TileLang everywhere. -if TYPE_CHECKING or _is_tilelang_platform(): +if TYPE_CHECKING or current_platform.is_cuda_alike(): if not has_tilelang(): raise ImportError( "tilelang is required for mhc but is not installed. Install it with " From ed3db3288417fa051abf113893590ed2be1d8178 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 26 May 2026 09:49:22 -0500 Subject: [PATCH 09/13] clean up code Signed-off-by: tjtanaa --- vllm/_tilelang_ops.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index 8808a1ef0abf..5cc91a470a31 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -10,14 +10,6 @@ from vllm.utils.import_utils import has_tilelang from vllm.utils.math_utils import cdiv - -def _is_pdl_supported() -> bool: - is_arch_support_pdl = getattr(current_platform, "is_arch_support_pdl", None) - if not callable(is_arch_support_pdl): - return False - return is_arch_support_pdl() - - # TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so # registering the Python wrapper modules does not require TileLang everywhere. if TYPE_CHECKING or current_platform.is_cuda_alike(): @@ -32,7 +24,7 @@ def _is_pdl_supported() -> bool: tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] -ENABLE_PDL = _is_pdl_supported() +ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda() @cache From 97c342af98bc938022a58829d6a741433d55aa3e Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 27 May 2026 19:41:49 -0500 Subject: [PATCH 10/13] add adaptation citation Signed-off-by: tjtanaa --- vllm/_tilelang_ops.py | 2 ++ vllm/platforms/cuda.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index 5cc91a470a31..272cd256939a 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -24,6 +24,8 @@ tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] +# Conditions adapted from +# https://github.com/sgl-project/sglang/blob/0abe6a85a51f2b7f1c3ca0e8f78944b609b94344/python/sglang/srt/layers/mhc.py#L33 # noqa: E501 ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 58cef2ec976e..1a6a59535fb2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -594,6 +594,8 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf @classmethod def is_arch_support_pdl(cls) -> bool: + # Conditions adapted from + # https://github.com/sgl-project/sglang/blob/0abe6a85a51f2b7f1c3ca0e8f78944b609b94344/sgl-kernel/python/sgl_kernel/utils.py#L61 # noqa: E501 try: device = torch.cuda.current_device() major, _ = torch.cuda.get_device_capability(device) From 5bc61a45a1ed5a5f068c7e1246e1e0333ace2e98 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 27 May 2026 19:47:01 -0500 Subject: [PATCH 11/13] remove unnecessary citation Signed-off-by: tjtanaa --- vllm/_tilelang_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index 272cd256939a..5cc91a470a31 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -24,8 +24,6 @@ tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] -# Conditions adapted from -# https://github.com/sgl-project/sglang/blob/0abe6a85a51f2b7f1c3ca0e8f78944b609b94344/python/sglang/srt/layers/mhc.py#L33 # noqa: E501 ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda() From 685bede28efa744f9c3644682c33284acd10ccf6 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 27 May 2026 19:47:18 -0500 Subject: [PATCH 12/13] remove unnecessary citation Signed-off-by: tjtanaa --- vllm/platforms/cuda.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 1a6a59535fb2..58cef2ec976e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -594,8 +594,6 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf @classmethod def is_arch_support_pdl(cls) -> bool: - # Conditions adapted from - # https://github.com/sgl-project/sglang/blob/0abe6a85a51f2b7f1c3ca0e8f78944b609b94344/sgl-kernel/python/sgl_kernel/utils.py#L61 # noqa: E501 try: device = torch.cuda.current_device() major, _ = torch.cuda.get_device_capability(device) From 8f40bc10951c3637e399b3d79608a67b21f8dc7f Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 27 May 2026 21:09:28 -0500 Subject: [PATCH 13/13] pin tilelang version rather than relax Signed-off-by: tjtanaa --- requirements/build/rocm.txt | 2 +- requirements/rocm.txt | 2 +- requirements/test/rocm.in | 1 + requirements/test/rocm.txt | 23 +++++++++++++++++++++-- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/requirements/build/rocm.txt b/requirements/build/rocm.txt index 752fb1db786a..e09bdc078bf5 100644 --- a/requirements/build/rocm.txt +++ b/requirements/build/rocm.txt @@ -16,4 +16,4 @@ wheel jinja2>=3.1.6 amdsmi==7.0.2 timm>=1.0.17 -tilelang>=0.1.10 +tilelang==0.1.10 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 37bdf1afd15a..0520f4ca1e91 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -22,4 +22,4 @@ timm>=1.0.17 # amd-quark: required for Quark quantization on ROCm # To be consistent with test_quark.py amd-quark>=0.8.99 -tilelang>=0.1.10 +tilelang==0.1.10 diff --git a/requirements/test/rocm.in b/requirements/test/rocm.in index 812fb736b570..97e0658fb106 100644 --- a/requirements/test/rocm.in +++ b/requirements/test/rocm.in @@ -43,6 +43,7 @@ schemathesis>=3.39.15 # Required for openai schema test # quantization bitsandbytes==0.49.2 buildkite-test-collector==0.1.9 +tilelang==0.1.10 genai_perf>=0.0.8 tritonclient>=2.51.0 diff --git a/requirements/test/rocm.txt b/requirements/test/rocm.txt index b15e00edf1dd..c39f268709b5 100644 --- a/requirements/test/rocm.txt +++ b/requirements/test/rocm.txt @@ -43,7 +43,9 @@ anyio==4.13.0 # starlette # watchfiles apache-tvm-ffi==0.1.10 - # via xgrammar + # via + # tilelang + # xgrammar arctic-inference==0.1.1 # via -r requirements/test/rocm.in argcomplete==3.6.3 @@ -129,7 +131,9 @@ click==8.3.1 # typer # uvicorn cloudpickle==3.1.2 - # via -r requirements/test/../common.txt + # via + # -r requirements/test/../common.txt + # tilelang colorama==0.4.6 # via # perceptron @@ -511,6 +515,8 @@ mistral-common==1.11.2 # -c requirements/common.txt # -r requirements/test/../common.txt # -r requirements/test/rocm.in +ml-dtypes==0.5.4 + # via tilelang model-hosting-container-standards==0.1.14 # via # -c requirements/common.txt @@ -587,6 +593,7 @@ numpy==2.2.6 # lm-eval # matplotlib # mistral-common + # ml-dtypes # mteb # numba # opencv-python-headless @@ -610,6 +617,7 @@ numpy==2.2.6 # statsmodels # tensorizer # tifffile + # tilelang # torchvision # transformers # tritonclient @@ -811,6 +819,7 @@ psutil==7.2.2 # accelerate # peft # tensorizer + # tilelang py==1.11.0 # via pytest-forked py-cpuinfo==9.0.0 @@ -1192,6 +1201,10 @@ tiktoken==0.12.0 # gpt-oss # lm-eval # mistral-common +tilelang==0.1.10 + # via + # -c requirements/rocm.txt + # -r requirements/test/rocm.in timm==1.0.17 # via # -c requirements/rocm.txt @@ -1208,6 +1221,8 @@ tomli==2.4.0 # via schemathesis tomli-w==1.2.0 # via schemathesis +torch-c-dlpack-ext==0.1.5 + # via tilelang tqdm==4.67.3 # via # -r requirements/test/../common.txt @@ -1225,6 +1240,7 @@ tqdm==4.67.3 # pqdm # segmentation-models-pytorch # sentence-transformers + # tilelang # transformers transformers==5.5.3 # via @@ -1293,6 +1309,7 @@ typing-extensions==4.15.0 # sentence-transformers # sqlalchemy # starlette + # tilelang # torch # typeguard # typing-inspection @@ -1359,6 +1376,8 @@ yarl==1.23.0 # via # aiohttp # schemathesis +z3-solver==4.15.4.0 + # via tilelang zipp==3.23.0 # via importlib-metadata