diff --git a/requirements/build/rocm.txt b/requirements/build/rocm.txt index e5c2176a2c8c..e09bdc078bf5 100644 --- a/requirements/build/rocm.txt +++ b/requirements/build/rocm.txt @@ -16,3 +16,4 @@ wheel jinja2>=3.1.6 amdsmi==7.0.2 timm>=1.0.17 +tilelang==0.1.10 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 61fcbc07010c..0520f4ca1e91 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -22,3 +22,4 @@ timm>=1.0.17 # amd-quark: required for Quark quantization on ROCm # To be consistent with test_quark.py amd-quark>=0.8.99 +tilelang==0.1.10 diff --git a/requirements/test/rocm.in b/requirements/test/rocm.in index 812fb736b570..97e0658fb106 100644 --- a/requirements/test/rocm.in +++ b/requirements/test/rocm.in @@ -43,6 +43,7 @@ schemathesis>=3.39.15 # Required for openai schema test # quantization bitsandbytes==0.49.2 buildkite-test-collector==0.1.9 +tilelang==0.1.10 genai_perf>=0.0.8 tritonclient>=2.51.0 diff --git a/requirements/test/rocm.txt b/requirements/test/rocm.txt index b15e00edf1dd..c39f268709b5 100644 --- a/requirements/test/rocm.txt +++ b/requirements/test/rocm.txt @@ -43,7 +43,9 @@ anyio==4.13.0 # starlette # watchfiles apache-tvm-ffi==0.1.10 - # via xgrammar + # via + # tilelang + # xgrammar arctic-inference==0.1.1 # via -r requirements/test/rocm.in argcomplete==3.6.3 @@ -129,7 +131,9 @@ click==8.3.1 # typer # uvicorn cloudpickle==3.1.2 - # via -r requirements/test/../common.txt + # via + # -r requirements/test/../common.txt + # tilelang colorama==0.4.6 # via # perceptron @@ -511,6 +515,8 @@ mistral-common==1.11.2 # -c requirements/common.txt # -r requirements/test/../common.txt # -r requirements/test/rocm.in +ml-dtypes==0.5.4 + # via tilelang model-hosting-container-standards==0.1.14 # via # -c requirements/common.txt @@ -587,6 +593,7 @@ numpy==2.2.6 # lm-eval # matplotlib # mistral-common + # ml-dtypes # mteb # numba # opencv-python-headless @@ -610,6 +617,7 @@ numpy==2.2.6 # statsmodels # tensorizer # tifffile + # tilelang # torchvision # transformers # tritonclient @@ -811,6 +819,7 @@ psutil==7.2.2 # accelerate # peft # tensorizer + # tilelang py==1.11.0 # via pytest-forked py-cpuinfo==9.0.0 @@ -1192,6 +1201,10 @@ tiktoken==0.12.0 # gpt-oss # lm-eval # mistral-common +tilelang==0.1.10 + # via + # -c requirements/rocm.txt + # -r requirements/test/rocm.in timm==1.0.17 # via # -c requirements/rocm.txt @@ -1208,6 +1221,8 @@ tomli==2.4.0 # via schemathesis tomli-w==1.2.0 # via schemathesis +torch-c-dlpack-ext==0.1.5 + # via tilelang tqdm==4.67.3 # via # -r requirements/test/../common.txt @@ -1225,6 +1240,7 @@ tqdm==4.67.3 # pqdm # segmentation-models-pytorch # sentence-transformers + # tilelang # transformers transformers==5.5.3 # via @@ -1293,6 +1309,7 @@ typing-extensions==4.15.0 # sentence-transformers # sqlalchemy # starlette + # tilelang # torch # typeguard # typing-inspection @@ -1359,6 +1376,8 @@ yarl==1.23.0 # via # aiohttp # schemathesis +z3-solver==4.15.4.0 + # via tilelang zipp==3.23.0 # via importlib-metadata diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 81fceeceac10..e7d4cde43f1d 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -4,7 +4,12 @@ import torch import vllm.model_executor.kernels.mhc # noqa: F401 +from vllm.model_executor.kernels.mhc.tilelang import ( + _tilelang_hc_prenorm_gemm, + _torch_hc_prenorm_gemm, +) from vllm.platforms import current_platform +from vllm.utils.import_utils import has_tilelang from vllm.utils.torch_utils import set_random_seed DEVICE = current_platform.device_type @@ -92,8 +97,128 @@ def hc_head_ref( @pytest.mark.skipif( - not current_platform.is_cuda(), - reason="CUDA required", + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", +) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + hc_mult2 = hc_mult * hc_mult + hc_mult3 = 2 * hc_mult + hc_mult2 + fn = ( + torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float) + * 1e-4 + * (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + hc_scale = torch.randn((3,), dtype=torch.float) * 0.1 + hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1 + + hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6 + sinkhorn_repeat = 20 + hc_post_alpha = 1.0 + + ref = mhc_pre_ref( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_alpha, + sinkhorn_repeat, + ) + out = torch.ops.vllm.mhc_pre_tilelang( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_alpha, + sinkhorn_repeat, + ) + + for actual, expected in zip(out, ref, strict=True): + torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", +) +@pytest.mark.parametrize( + ("num_tokens", "hidden_size"), + [ + (1, 1280), + (512, 1280), + (2048, 1280), + (1, 4096), + (64, 4096), + (512, 4096), + (2048, 4096), + (1, 7168), + (64, 7168), + (512, 7168), + (2048, 7168), + ], +) +def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size): + torch.set_default_device(DEVICE) + set_random_seed(0) + + hc_mult = 4 + hc_mult3 = 2 * hc_mult + hc_mult * hc_mult + x = torch.randn((num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16) + fn = torch.randn((hc_mult3, hc_mult * hidden_size), dtype=torch.float32) * 1e-4 + out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32) + sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32) + out = torch.empty_like(out_ref) + sqrsum = torch.empty_like(sqrsum_ref) + + _torch_hc_prenorm_gemm(x, fn, out_ref, sqrsum_ref) + _tilelang_hc_prenorm_gemm(x, fn, out, sqrsum, hidden_size, hc_mult) + + torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(sqrsum, sqrsum_ref, atol=8.0, rtol=5e-4) + + +@pytest.mark.skipif( + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", +) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16) + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32) + comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32) + + ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix) + out = torch.ops.vllm.mhc_post_tilelang( + x, + residual, + post_layer_mix, + comb_res_mix, + ) + + torch.testing.assert_close(out, ref, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) @pytest.mark.parametrize("hidden_size", [4096, 7168]) @@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult): out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps) torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not (current_platform.is_cuda_alike() and has_tilelang()), + reason="CUDA or ROCm and tilelang required", +) +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128]) +@pytest.mark.parametrize("hidden_size", [4096, 7168]) +@pytest.mark.parametrize("hc_mult", [4]) +def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult): + torch.set_default_device(DEVICE) + set_random_seed(0) + + residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16) + fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4 + hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1 + hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1 + rms_eps = hc_eps = 1e-6 + + out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16) + out.fill_(float("nan")) + + result = torch.ops.vllm.hc_head_fused_kernel_tilelang( + residual, + fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_eps, + hc_eps, + hc_mult, + ) + + assert result is None + assert not torch.isnan(out).any() + + out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps) + torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2) diff --git a/vllm/_tilelang_ops.py b/vllm/_tilelang_ops.py index aa742fe50320..5cc91a470a31 100644 --- a/vllm/_tilelang_ops.py +++ b/vllm/_tilelang_ops.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from functools import cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch @@ -10,8 +10,9 @@ from vllm.utils.import_utils import has_tilelang from vllm.utils.math_utils import cdiv -# tilelang is only available on CUDA platforms -if TYPE_CHECKING or current_platform.is_cuda(): +# TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so +# registering the Python wrapper modules does not require TileLang everywhere. +if TYPE_CHECKING or current_platform.is_cuda_alike(): if not has_tilelang(): raise ImportError( "tilelang is required for mhc but is not installed. Install it with " @@ -23,6 +24,8 @@ tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] +ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda() + @cache def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: @@ -37,12 +40,17 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: return split_k +pass_configs: dict[tilelang.PassConfigKey, Any] = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +if current_platform.is_cuda(): + pass_configs[tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL] = 10 + + @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_pre_big_fuse_tilelang( gemm_out_mul, @@ -78,7 +86,8 @@ def mhc_pre_big_fuse_tilelang( layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] with T.Kernel(num_tokens, threads=96) as i: - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() ################################################################## # _pre_norm_fn_fwd_norm rms = T.alloc_fragment(1, T.float32) @@ -174,18 +183,16 @@ def mhc_pre_big_fuse_tilelang( ol[i1_h] += pre * xl[i_hc, i1_h] T.copy(ol, layer_input[i, i0_h * hidden_block]) - T.pdl_trigger() + + if ENABLE_PDL: + T.pdl_trigger() # Copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/mhc.py#L478 @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_pre_big_fuse_with_norm_tilelang( gemm_out_mul, @@ -230,7 +237,8 @@ def mhc_pre_big_fuse_with_norm_tilelang( T.clear(mixes) rms[0] = 0 - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() for i_split in T.serial(n_splits): rms[0] += gemm_out_sqrsum[i_split, i] @@ -341,15 +349,12 @@ def mhc_pre_big_fuse_with_norm_tilelang( T.copy(ol, layer_input[i, i0_h * hidden_block]) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_fused_tilelang( comb_mix, @@ -390,8 +395,8 @@ def mhc_fused_tilelang( with T.Kernel(m, n_tiles, split_k, threads=n_thr) as (i_n, i_nt, i_ks): tid = T.get_thread_binding() - warp_id = T.get_warp_idx() - lane = T.get_lane_idx() + warp_id = tid // 32 + lane = tid % 32 s_warp = T.alloc_shared((num_warps, tile_n + 1), T.float32) s_post = T.alloc_shared((hc,), T.float32) @@ -407,7 +412,8 @@ def mhc_fused_tilelang( T.clear(sqr) h_split_start = i_ks * h_per_split - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() T.copy(post_mix[i_n, 0], s_post) T.copy(comb_mix[i_n, 0, 0], s_comb) @@ -466,15 +472,12 @@ def mhc_fused_tilelang( v2 += s_warp[w, tile_n] rp_out[i_ks, i_n] = v2 - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def mhc_post_tilelang( a, @@ -507,7 +510,8 @@ def mhc_post_tilelang( a_local = T.alloc_fragment((hc, hc), T.float32) c_local = T.alloc_fragment(hc, T.float32) - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() T.copy(a[i_n, 0, 0], a_local) T.copy(c[i_n, 0], c_local) @@ -523,15 +527,193 @@ def mhc_post_tilelang( x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] T.copy(x_local, x[i_n, 0, i0_h * h_blk]) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() + + +@tilelang.jit( + pass_configs=pass_configs, +) +def hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size: int, + hc_mult: int = 4, + n_out: int = 24, + n_thr: int = 512, + tile_n: int = 12, + n_splits: int = 1, +) -> tilelang.JITKernel: + num_tokens = T.dynamic("num_tokens") + hc_hidden_size = hc_mult * hidden_size + k_per_split = hc_hidden_size // n_splits + k_iters = k_per_split // n_thr + n_tiles = T.ceildiv(n_out, tile_n) + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type] + fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type] + out: T.Tensor((n_splits, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type] + sqrsum: T.Tensor((n_splits, num_tokens), T.float32) # type: ignore[no-redef, valid-type] + + with T.Kernel(num_tokens, n_tiles, n_splits, threads=n_thr) as ( + i_n, + i_t, + i_s, + ): + tid = T.get_thread_binding() + acc = T.alloc_local((tile_n,), T.float32) + sqr = T.alloc_local((1,), T.float32) + T.clear(acc) + T.clear(sqr) + + if ENABLE_PDL: + T.pdl_sync() + + for it in T.serial(k_iters): + i_k = i_s * k_per_split + it * n_thr + tid + x_val = x[i_n, i_k] + for i_o in T.unroll(tile_n): + out_idx = i_t * tile_n + i_o + if out_idx < n_out: + acc[i_o] += x_val * fn[out_idx, i_k] + if i_t == 0: + sqr[0] += x_val * x_val + + for i_o in T.unroll(tile_n): + acc[i_o] = T.warp_reduce_sum(acc[i_o]) + if i_t == 0: + sqr[0] = T.warp_reduce_sum(sqr[0]) + + lane = tid % 32 + warp_id = tid // 32 + num_warps = n_thr // 32 + warp_acc = T.alloc_shared((num_warps, tile_n), T.float32) + warp_sqr = T.alloc_shared(num_warps, T.float32) + + if lane == 0: + for i_o in T.unroll(tile_n): + warp_acc[warp_id, i_o] = acc[i_o] + if i_t == 0: + warp_sqr[warp_id] = sqr[0] + T.sync_threads() + + if warp_id == 0: + if lane < tile_n: + reduced_acc = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_acc += warp_acc[i_w, lane] + out_idx = i_t * tile_n + lane + if out_idx < n_out: + out[i_s, i_n, out_idx] = reduced_acc + if lane == 0 and i_t == 0: + reduced_sqr = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_sqr += warp_sqr[i_w] + sqrsum[i_s, i_n] = reduced_sqr + + if ENABLE_PDL: + T.pdl_trigger() + + +@tilelang.jit( + pass_configs=pass_configs, +) +def hc_prenorm_gemm_block_m_tilelang( + x, + fn, + out, + sqrsum, + hidden_size: int, + hc_mult: int = 4, + n_out: int = 24, + n_thr: int = 512, + tile_n: int = 12, + block_m: int = 2, +) -> tilelang.JITKernel: + num_tokens = T.dynamic("num_tokens") + hc_hidden_size = hc_mult * hidden_size + k_iters = hc_hidden_size // n_thr + n_tiles = T.ceildiv(n_out, tile_n) + m_tiles = T.ceildiv(num_tokens, block_m) + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type] + fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type] + out: T.Tensor((1, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type] + sqrsum: T.Tensor((1, num_tokens), T.float32) # type: ignore[no-redef, valid-type] + + with T.Kernel(m_tiles, n_tiles, threads=n_thr) as (i_mt, i_t): + tid = T.get_thread_binding() + acc = T.alloc_local((block_m, tile_n), T.float32) + sqr = T.alloc_local((block_m,), T.float32) + T.clear(acc) + T.clear(sqr) + + if ENABLE_PDL: + T.pdl_sync() + + for it in T.serial(k_iters): + i_k = it * n_thr + tid + fn_val = T.alloc_local((tile_n,), T.float32) + for i_o in T.unroll(tile_n): + out_idx = i_t * tile_n + i_o + if out_idx < n_out: + fn_val[i_o] = fn[out_idx, i_k] + else: + fn_val[i_o] = 0.0 + for i_m in T.unroll(block_m): + token_idx = i_mt * block_m + i_m + if token_idx < num_tokens: + x_val = x[token_idx, i_k] + for i_o in T.unroll(tile_n): + acc[i_m, i_o] += x_val * fn_val[i_o] + if i_t == 0: + sqr[i_m] += x_val * x_val + + for i_m in T.unroll(block_m): + for i_o in T.unroll(tile_n): + acc[i_m, i_o] = T.warp_reduce_sum(acc[i_m, i_o]) + if i_t == 0: + sqr[i_m] = T.warp_reduce_sum(sqr[i_m]) + + lane = tid % 32 + warp_id = tid // 32 + num_warps = n_thr // 32 + warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32) + warp_sqr = T.alloc_shared((num_warps, block_m), T.float32) + + if lane == 0: + for i_m in T.unroll(block_m): + for i_o in T.unroll(tile_n): + warp_acc[warp_id, i_m, i_o] = acc[i_m, i_o] + if i_t == 0: + warp_sqr[warp_id, i_m] = sqr[i_m] + T.sync_threads() + + if warp_id == 0: + for i_m in T.unroll(block_m): + token_idx = i_mt * block_m + i_m + if token_idx < num_tokens: + if lane < tile_n: + reduced_acc = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_acc += warp_acc[i_w, i_m, lane] + out_idx = i_t * tile_n + lane + if out_idx < n_out: + out[0, token_idx, out_idx] = reduced_acc + if lane == 0 and i_t == 0: + reduced_sqr = T.alloc_var(T.float32, init=0.0) + for i_w in T.unroll(num_warps): + reduced_sqr += warp_sqr[i_w, i_m] + sqrsum[0, token_idx] = reduced_sqr + + if ENABLE_PDL: + T.pdl_trigger() @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, + pass_configs=pass_configs, ) def hc_head_fuse_tilelang( residual, @@ -566,7 +748,8 @@ def hc_head_fuse_tilelang( out: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef,valid-type] with T.Kernel(num_tokens, threads=n_thr) as i: - T.pdl_sync() + if ENABLE_PDL: + T.pdl_sync() # ------------------------------------------------------------------ # Pass 1 – for each residual channel m_c and h_block: @@ -624,4 +807,5 @@ def hc_head_fuse_tilelang( T.copy(ol, out[i, i0_h * h_block], disable_tma=True) - T.pdl_trigger() + if ENABLE_PDL: + T.pdl_trigger() diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index c242fef2d026..d76123bb7625 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -5,6 +5,88 @@ from vllm.utils.torch_utils import direct_register_custom_op +def _torch_hc_prenorm_gemm( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, +) -> None: + assert out.shape[0] == 1 + assert sqrsum.shape[0] == 1 + x_float = x.float() + out[0].copy_(x_float @ fn.t()) + sqrsum[0].copy_(x_float.square().sum(dim=-1)) + + +def _tilelang_hc_prenorm_gemm( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + hidden_size: int, + hc_mult: int, + tile_n: int = 12, + n_thr: int = 512, + n_splits: int = 1, +) -> None: + from vllm._tilelang_ops import ( + hc_prenorm_gemm_block_m_tilelang, + hc_prenorm_gemm_tilelang, + ) + + assert out.shape[0] == n_splits + assert sqrsum.shape[0] == n_splits + assert x.shape[1] == hc_mult * hidden_size + assert x.shape[1] % n_splits == 0 + assert (x.shape[1] // n_splits) % n_thr == 0 + use_default_config = tile_n == 12 and n_thr == 512 + if n_splits == 1 and use_default_config and x.shape[0] >= 1024: + hc_prenorm_gemm_block_m_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + n_thr, + tile_n, + 2, + ) + return + if ( + n_splits == 1 + and use_default_config + and x.shape[0] < 128 + and x.shape[1] % 1024 == 0 + ): + hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + 1024, + 4, + n_splits, + ) + return + hc_prenorm_gemm_tilelang( + x, + fn, + out, + sqrsum, + hidden_size, + hc_mult, + fn.shape[0], + n_thr, + tile_n, + n_splits, + ) + + def mhc_pre_tilelang( residual: torch.Tensor, fn: torch.Tensor, @@ -80,10 +162,16 @@ def mhc_pre_tilelang( residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] - # these numbers are from deepgemm kernel impl - block_k = 64 - block_m = 64 - n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + from vllm.utils.deep_gemm import is_deep_gemm_supported + + use_deep_gemm = is_deep_gemm_supported() + if use_deep_gemm: + # these numbers are from deepgemm kernel impl + block_k = 64 + block_m = 64 + n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + else: + n_splits = 1 post_mix = torch.empty( num_tokens, hc_mult, dtype=torch.float32, device=residual.device @@ -102,13 +190,24 @@ def mhc_pre_tilelang( n_splits, num_tokens, dtype=torch.float32, device=residual.device ) - tf32_hc_prenorm_gemm( - residual_flat.view(num_tokens, hc_mult * hidden_size), - fn, - gemm_out_mul, - gemm_out_sqrsum, - n_splits, - ) + residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size) + if use_deep_gemm: + tf32_hc_prenorm_gemm( + residual_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + n_splits, + ) + else: + _tilelang_hc_prenorm_gemm( + residual_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + hidden_size, + hc_mult, + ) if norm_weight is None: mhc_pre_big_fuse_tilelang( @@ -304,16 +403,24 @@ def mhc_fused_post_pre_tilelang( post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult) comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult) - fma_token_threshold = 16 - if num_tokens <= fma_token_threshold: + from vllm.utils.deep_gemm import is_deep_gemm_supported + + use_deep_gemm = is_deep_gemm_supported() + use_small_fma = num_tokens <= 16 + if use_small_fma: # TODO(gnovack): investigate autotuning these heuristics tile_n = 2 if num_tokens < 8 else 3 n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4 else: - # these number are from deepgemm kernel impl - block_k = 64 - block_m = 64 - n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + if use_deep_gemm: + # these number are from deepgemm kernel impl + block_k = 64 + block_m = 64 + n_splits = compute_num_split( + block_k, hc_hidden_size, cdiv(num_tokens, block_m) + ) + else: + n_splits = 1 gemm_out_mul = torch.empty( n_splits, @@ -348,7 +455,7 @@ def mhc_fused_post_pre_tilelang( device=residual.device, ) - if num_tokens <= fma_token_threshold: + if use_small_fma: mhc_fused_tilelang( comb_res_mix_flat, residual_flat, @@ -375,15 +482,26 @@ def mhc_fused_post_pre_tilelang( residual.shape[-1], ) - from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm - - tf32_hc_prenorm_gemm( - residual_cur.view(num_tokens, hc_mult * hidden_size), - fn, - gemm_out_mul, - gemm_out_sqrsum, - n_splits, - ) + residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size) + if use_deep_gemm: + from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm + + tf32_hc_prenorm_gemm( + residual_cur_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + n_splits, + ) + else: + _tilelang_hc_prenorm_gemm( + residual_cur_2d, + fn, + gemm_out_mul, + gemm_out_sqrsum, + hidden_size, + hc_mult, + ) if norm_weight is None: mhc_pre_big_fuse_tilelang( diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index cb3965f764ba..b720fa1f6fe2 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -3,8 +3,12 @@ import torch # this import will also register the custom ops +# import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels from vllm.model_executor.custom_op import CustomOp +from vllm.utils.import_utils import has_tilelang + +HAS_TILELANG = has_tilelang() # --8<-- [start:mhc_pre] @@ -85,6 +89,52 @@ def forward_hip( # sinkhorn_repeat, # ) # else: + if HAS_TILELANG: + return torch.ops.vllm.mhc_pre_tilelang( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + norm_weight, + norm_eps, + ) + else: + return self.forward_native( + residual, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + norm_weight, + norm_eps, + ) + + def forward_native( + self, + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + norm_weight: torch.Tensor | None = None, + norm_eps: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return mhc_kernels.mhc_pre_torch( residual, fn, @@ -97,9 +147,6 @@ def forward_hip( sinkhorn_repeat, ) - def forward_native(self, *args, **kwargs): - raise NotImplementedError("Native implementation of mhc_pre is not available") - # --8<-- [start:mhc_post] @CustomOp.register("mhc_post") @@ -147,6 +194,20 @@ def forward_hip( # comb_res_mix, # ) # else: + if HAS_TILELANG: + return torch.ops.vllm.mhc_post_tilelang( + x, residual, post_layer_mix, comb_res_mix + ) + else: + return self.forward_native(x, residual, post_layer_mix, comb_res_mix) + + def forward_native( + self, + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, + ) -> torch.Tensor: return mhc_kernels.mhc_post_torch( x, residual, @@ -154,9 +215,6 @@ def forward_hip( comb_res_mix, ) - def forward_native(self, *args, **kwargs): - raise NotImplementedError("Native implementation of mhc_post is not available") - # --8<-- [start:hc_head] @CustomOp.register("hc_head") @@ -220,17 +278,32 @@ def forward_hip( out = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device ) - torch.ops.vllm.hc_head_triton( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) + + if HAS_TILELANG: + torch.ops.vllm.hc_head_fused_kernel_tilelang( + hs_flat, + hc_fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_norm_eps, + hc_eps, + hc_mult, + ) + else: + torch.ops.vllm.hc_head_triton( + hs_flat, + hc_fn, + hc_scale, + hc_base, + out, + hidden_size, + rms_norm_eps, + hc_eps, + hc_mult, + ) + return out.view(*outer_shape, hidden_size) def forward_native(self, *args, **kwargs): @@ -290,9 +363,42 @@ def forward_cuda( norm_eps, ) - def forward_hip(self, *args, **kwargs): - raise NotImplementedError( - "Hip implementation of mhc_fused_post_pre is not available" + def forward_hip( + self, + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + tile_n: int = 1, + norm_weight: torch.Tensor | None = None, + norm_eps: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return torch.ops.vllm.mhc_fused_post_pre_tilelang( + x, + residual, + post_layer_mix, + comb_res_mix, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + tile_n, + norm_weight, + norm_eps, ) def forward_native(self, *args, **kwargs): diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 84318a8107d3..0a829c117ed9 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -54,6 +54,7 @@ ) from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils.import_utils import has_tilelang class DeepseekV4MLP(nn.Module): @@ -473,6 +474,7 @@ def __init__( self.mhc_pre = MHCPreOp() self.mhc_post = MHCPostOp() self.mhc_fused_post_pre = MHCFusedPostPreOp() + self.has_tilelang = has_tilelang() def hc_pre( self, @@ -503,7 +505,7 @@ def hc_post( ): return self.mhc_post(x, residual, post, comb) - def _forward_cuda( + def _forward_fused_post_pre( self, x: torch.Tensor, positions: torch.Tensor, @@ -555,7 +557,7 @@ def _forward_cuda( x = self.ffn(x, input_ids) return x, residual, post_mix, res_mix - def _forward_rocm( + def _forward_unfused_post_pre( self, x: torch.Tensor, positions: torch.Tensor, @@ -594,12 +596,13 @@ def forward( ) -> tuple[ torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None ]: - if current_platform.is_rocm(): - return self._forward_rocm( + if not self.has_tilelang: + return self._forward_unfused_post_pre( x, positions, input_ids, post_mix, res_mix, residual ) - - return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual) + return self._forward_fused_post_pre( + x, positions, input_ids, post_mix, res_mix, residual + ) @support_torch_compile @@ -682,6 +685,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) self.hc_head_op = HCHeadOp() + self.has_tilelang = has_tilelang() # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. @@ -748,7 +752,7 @@ def forward( res_mix, residual, ) - if layer is not None and current_platform.is_cuda(): + if layer is not None and self.has_tilelang: hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: diff --git a/vllm/models/deepseek_v4/amd/mtp.py b/vllm/models/deepseek_v4/amd/mtp.py index 168d9b938b76..2f83af5ad98d 100644 --- a/vllm/models/deepseek_v4/amd/mtp.py +++ b/vllm/models/deepseek_v4/amd/mtp.py @@ -39,6 +39,7 @@ from vllm.model_executor.models.utils import maybe_prefix from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils.import_utils import has_tilelang from .model import DeepseekV4DecoderLayer @@ -118,6 +119,7 @@ def __init__( ) self.hc_head_op = HCHeadOp() + self.has_tilelang = has_tilelang() def forward( self, @@ -144,7 +146,7 @@ def forward( hidden_states, residual, post_mix, res_mix = self.mtp_block( positions=positions, x=hidden_states, input_ids=None ) - if current_platform.is_cuda(): + if self.has_tilelang: hidden_states = self.mtp_block.hc_post( hidden_states, residual, post_mix, res_mix ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 78a5be4e84c7..58cef2ec976e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -592,6 +592,15 @@ def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConf default, rms_norm=rms_norm, fused_add_rms_norm=rms_norm ) + @classmethod + def is_arch_support_pdl(cls) -> bool: + try: + device = torch.cuda.current_device() + major, _ = torch.cuda.get_device_capability(device) + except Exception: + return False + return major >= 9 + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 9a93ef9f82a7..cf774b7bda91 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1016,6 +1016,13 @@ def get_default_ir_op_priority( # Native always used by default. Platforms can override this behavior. return IrOpPriorityConfig.with_default(["native"]) + @classmethod + def is_arch_support_pdl(cls) -> bool: + """ + Does the current platform support PDL (Programmatic Dependent Launch)? + """ + return False + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 5822e5840afc..e97228bfa609 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -430,6 +430,7 @@ def has_triton_kernels() -> bool: return is_available +@cache def has_tilelang() -> bool: """Whether the optional `tilelang` package is available.""" return _has_module("tilelang")