Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/build/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ wheel
jinja2>=3.1.6
amdsmi==7.0.2
timm>=1.0.17
tilelang>=0.1.10
1 change: 1 addition & 0 deletions requirements/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
tilelang>=0.1.10
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test and even nicer feature. i am wondering for CI purposes to avoid any regressions from future versions, should we pin the version? Maybe we can add it in rocm.in too. We can also do that in a follow-up PR. But let me know if you agree.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it in a follow up PR. Would like to land this PR and let @WoosukKwon continue with the restructuring of the mhc kernels.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have pinned the version to exact version.

168 changes: 166 additions & 2 deletions tests/kernels/test_mhc_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import torch

import vllm.model_executor.kernels.mhc # noqa: F401
from vllm.model_executor.kernels.mhc.tilelang import (
_tilelang_hc_prenorm_gemm,
_torch_hc_prenorm_gemm,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.torch_utils import set_random_seed

DEVICE = current_platform.device_type
Expand Down Expand Up @@ -92,8 +97,128 @@ def hc_head_ref(


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="CUDA required",
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
hc_mult2 = hc_mult * hc_mult
hc_mult3 = 2 * hc_mult + hc_mult2
fn = (
torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float)
* 1e-4
* (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1))
).flatten(1, 2)
hc_scale = torch.randn((3,), dtype=torch.float) * 0.1
hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1

hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6
sinkhorn_repeat = 20
hc_post_alpha = 1.0

ref = mhc_pre_ref(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_alpha,
sinkhorn_repeat,
)
out = torch.ops.vllm.mhc_pre_tilelang(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_alpha,
sinkhorn_repeat,
)

for actual, expected in zip(out, ref, strict=True):
torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize(
("num_tokens", "hidden_size"),
[
(1, 1280),
(512, 1280),
(2048, 1280),
(1, 4096),
(64, 4096),
(512, 4096),
(2048, 4096),
(1, 7168),
(64, 7168),
(512, 7168),
(2048, 7168),
],
)
def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size):
torch.set_default_device(DEVICE)
set_random_seed(0)

hc_mult = 4
hc_mult3 = 2 * hc_mult + hc_mult * hc_mult
x = torch.randn((num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16)
fn = torch.randn((hc_mult3, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32)
sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32)
out = torch.empty_like(out_ref)
sqrsum = torch.empty_like(sqrsum_ref)

_torch_hc_prenorm_gemm(x, fn, out_ref, sqrsum_ref)
_tilelang_hc_prenorm_gemm(x, fn, out, sqrsum, hidden_size, hc_mult)

torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(sqrsum, sqrsum_ref, atol=8.0, rtol=5e-4)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16)
residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32)
comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32)

ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix)
out = torch.ops.vllm.mhc_post_tilelang(
x,
residual,
post_layer_mix,
comb_res_mix,
)

torch.testing.assert_close(out, ref, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
Expand Down Expand Up @@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult):

out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult):
torch.set_default_device(DEVICE)
set_random_seed(0)

residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
rms_eps = hc_eps = 1e-6

out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16)
out.fill_(float("nan"))

result = torch.ops.vllm.hc_head_fused_kernel_tilelang(
residual,
fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_eps,
hc_eps,
hc_mult,
)

assert result is None
assert not torch.isnan(out).any()

out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
Loading