Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/build/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ wheel
jinja2>=3.1.6
amdsmi==7.0.2
timm>=1.0.17
tilelang==0.1.10
1 change: 1 addition & 0 deletions requirements/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
tilelang==0.1.10
1 change: 1 addition & 0 deletions requirements/test/rocm.in
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ schemathesis>=3.39.15 # Required for openai schema test
# quantization
bitsandbytes==0.49.2
buildkite-test-collector==0.1.9
tilelang==0.1.10

genai_perf>=0.0.8
tritonclient>=2.51.0
Expand Down
23 changes: 21 additions & 2 deletions requirements/test/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ anyio==4.13.0
# starlette
# watchfiles
apache-tvm-ffi==0.1.10
# via xgrammar
# via
# tilelang
# xgrammar
arctic-inference==0.1.1
# via -r requirements/test/rocm.in
argcomplete==3.6.3
Expand Down Expand Up @@ -129,7 +131,9 @@ click==8.3.1
# typer
# uvicorn
cloudpickle==3.1.2
# via -r requirements/test/../common.txt
# via
# -r requirements/test/../common.txt
# tilelang
colorama==0.4.6
# via
# perceptron
Expand Down Expand Up @@ -511,6 +515,8 @@ mistral-common==1.11.2
# -c requirements/common.txt
# -r requirements/test/../common.txt
# -r requirements/test/rocm.in
ml-dtypes==0.5.4
# via tilelang
model-hosting-container-standards==0.1.14
# via
# -c requirements/common.txt
Expand Down Expand Up @@ -587,6 +593,7 @@ numpy==2.2.6
# lm-eval
# matplotlib
# mistral-common
# ml-dtypes
# mteb
# numba
# opencv-python-headless
Expand All @@ -610,6 +617,7 @@ numpy==2.2.6
# statsmodels
# tensorizer
# tifffile
# tilelang
# torchvision
# transformers
# tritonclient
Expand Down Expand Up @@ -811,6 +819,7 @@ psutil==7.2.2
# accelerate
# peft
# tensorizer
# tilelang
py==1.11.0
# via pytest-forked
py-cpuinfo==9.0.0
Expand Down Expand Up @@ -1192,6 +1201,10 @@ tiktoken==0.12.0
# gpt-oss
# lm-eval
# mistral-common
tilelang==0.1.10
# via
# -c requirements/rocm.txt
# -r requirements/test/rocm.in
timm==1.0.17
# via
# -c requirements/rocm.txt
Expand All @@ -1208,6 +1221,8 @@ tomli==2.4.0
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch-c-dlpack-ext==0.1.5
# via tilelang
tqdm==4.67.3
# via
# -r requirements/test/../common.txt
Expand All @@ -1225,6 +1240,7 @@ tqdm==4.67.3
# pqdm
# segmentation-models-pytorch
# sentence-transformers
# tilelang
# transformers
transformers==5.5.3
# via
Expand Down Expand Up @@ -1293,6 +1309,7 @@ typing-extensions==4.15.0
# sentence-transformers
# sqlalchemy
# starlette
# tilelang
# torch
# typeguard
# typing-inspection
Expand Down Expand Up @@ -1359,6 +1376,8 @@ yarl==1.23.0
# via
# aiohttp
# schemathesis
z3-solver==4.15.4.0
# via tilelang
zipp==3.23.0
# via importlib-metadata

Expand Down
168 changes: 166 additions & 2 deletions tests/kernels/test_mhc_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import torch

import vllm.model_executor.kernels.mhc # noqa: F401
from vllm.model_executor.kernels.mhc.tilelang import (
_tilelang_hc_prenorm_gemm,
_torch_hc_prenorm_gemm,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.torch_utils import set_random_seed

DEVICE = current_platform.device_type
Expand Down Expand Up @@ -92,8 +97,128 @@ def hc_head_ref(


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="CUDA required",
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
hc_mult2 = hc_mult * hc_mult
hc_mult3 = 2 * hc_mult + hc_mult2
fn = (
torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float)
* 1e-4
* (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1))
).flatten(1, 2)
hc_scale = torch.randn((3,), dtype=torch.float) * 0.1
hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1

hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6
sinkhorn_repeat = 20
hc_post_alpha = 1.0

ref = mhc_pre_ref(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_alpha,
sinkhorn_repeat,
)
out = torch.ops.vllm.mhc_pre_tilelang(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_alpha,
sinkhorn_repeat,
)

for actual, expected in zip(out, ref, strict=True):
torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize(
("num_tokens", "hidden_size"),
[
(1, 1280),
(512, 1280),
(2048, 1280),
(1, 4096),
(64, 4096),
(512, 4096),
(2048, 4096),
(1, 7168),
(64, 7168),
(512, 7168),
(2048, 7168),
],
)
def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size):
torch.set_default_device(DEVICE)
set_random_seed(0)

hc_mult = 4
hc_mult3 = 2 * hc_mult + hc_mult * hc_mult
x = torch.randn((num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16)
fn = torch.randn((hc_mult3, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32)
sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32)
out = torch.empty_like(out_ref)
sqrsum = torch.empty_like(sqrsum_ref)

_torch_hc_prenorm_gemm(x, fn, out_ref, sqrsum_ref)
_tilelang_hc_prenorm_gemm(x, fn, out, sqrsum, hidden_size, hc_mult)

torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(sqrsum, sqrsum_ref, atol=8.0, rtol=5e-4)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16)
residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32)
comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32)

ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix)
out = torch.ops.vllm.mhc_post_tilelang(
x,
residual,
post_layer_mix,
comb_res_mix,
)

torch.testing.assert_close(out, ref, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
Expand Down Expand Up @@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult):

out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
rms_eps = hc_eps = 1e-6

out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16)
out.fill_(float("nan"))

result = torch.ops.vllm.hc_head_fused_kernel_tilelang(
residual,
fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_eps,
hc_eps,
hc_mult,
)

assert result is None
assert not torch.isnan(out).any()

out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
Loading
Loading