Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/build/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ wheel
jinja2>=3.1.6
amdsmi==7.0.2
timm>=1.0.17
tilelang>=0.1.10
1 change: 1 addition & 0 deletions requirements/rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
tilelang>=0.1.10
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test and an even nicer feature. I am wondering whether, for CI purposes and to avoid regressions from future versions, we should pin the version. Maybe we can add it to rocm.in too. We can also do that in a follow-up PR. Let me know if you agree.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it in a follow-up PR. I would like to land this PR and let @WoosukKwon continue with the restructuring of the mhc kernels.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have pinned the version to an exact version.

168 changes: 166 additions & 2 deletions tests/kernels/test_mhc_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import torch

import vllm.model_executor.kernels.mhc # noqa: F401
from vllm.model_executor.kernels.mhc.tilelang import (
_tilelang_hc_prenorm_gemm,
_torch_hc_prenorm_gemm,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.torch_utils import set_random_seed

DEVICE = current_platform.device_type
Expand Down Expand Up @@ -92,8 +97,128 @@ def hc_head_ref(


@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult):
    """Check ``torch.ops.vllm.mhc_pre_tilelang`` against ``mhc_pre_ref``.

    Both implementations are called with the same randomly generated inputs
    and their results are compared element by element (``zip`` in strict
    mode, so a differing number of outputs is also a failure).
    """
    torch.set_default_device(DEVICE)
    # Fixed seed so the random inputs are reproducible across runs.
    set_random_seed(0)

    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    hc_mult2 = hc_mult * hc_mult
    hc_mult3 = 2 * hc_mult + hc_mult2
    # Small-magnitude float weights; each hc_mult slice gets a slightly
    # different scale factor before the last two dims are flattened together.
    fn = (
        torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float)
        * 1e-4
        * (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1))
    ).flatten(1, 2)
    hc_scale = torch.randn((3,), dtype=torch.float) * 0.1
    hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1

    hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6
    sinkhorn_repeat = 20
    hc_post_alpha = 1.0

    ref = mhc_pre_ref(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_alpha,
        sinkhorn_repeat,
    )
    out = torch.ops.vllm.mhc_pre_tilelang(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_alpha,
        sinkhorn_repeat,
    )

    for actual, expected in zip(out, ref, strict=True):
        torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize(
    ("num_tokens", "hidden_size"),
    # Same cases, in the same order, as an explicit list would give:
    # three token counts for 1280, four token counts for 4096 and 7168.
    [(tokens, 1280) for tokens in (1, 512, 2048)]
    + [
        (tokens, hidden)
        for hidden in (4096, 7168)
        for tokens in (1, 64, 512, 2048)
    ],
)
def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size):
    """Compare ``_tilelang_hc_prenorm_gemm`` with ``_torch_hc_prenorm_gemm``.

    Both functions write into caller-provided output buffers; the buffers
    filled by the TileLang path are checked against those filled by the
    torch path.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    hc_mult = 4
    hc_mult3 = 2 * hc_mult + hc_mult * hc_mult
    flat_width = hc_mult * hidden_size

    x = torch.randn((num_tokens, flat_width), dtype=torch.bfloat16)
    fn = torch.randn((hc_mult3, flat_width), dtype=torch.float32) * 1e-4

    # Reference buffers first, then same-shaped buffers for the kernel.
    expected_out = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32)
    expected_sqrsum = torch.empty((1, num_tokens), dtype=torch.float32)
    actual_out = torch.empty_like(expected_out)
    actual_sqrsum = torch.empty_like(expected_sqrsum)

    _torch_hc_prenorm_gemm(x, fn, expected_out, expected_sqrsum)
    _tilelang_hc_prenorm_gemm(x, fn, actual_out, actual_sqrsum, hidden_size, hc_mult)

    torch.testing.assert_close(actual_out, expected_out, atol=1e-5, rtol=1e-4)
    torch.testing.assert_close(actual_sqrsum, expected_sqrsum, atol=8.0, rtol=5e-4)


@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult):
    """Check ``torch.ops.vllm.mhc_post_tilelang`` against ``mhc_post_ref``."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    # Inputs are drawn in a fixed order so the seeded values are reproducible.
    x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16)
    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32)
    comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32)
    op_args = (x, residual, post_layer_mix, comb_res_mix)

    expected = mhc_post_ref(*op_args)
    actual = torch.ops.vllm.mhc_post_tilelang(*op_args)

    torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
Expand Down Expand Up @@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult):

out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)


@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult):
    """Check ``torch.ops.vllm.hc_head_fused_kernel_tilelang`` against ``hc_head_ref``.

    The op is expected to return ``None`` and write its result into the
    supplied ``out`` buffer, which is then compared with the reference.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
    hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
    hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
    rms_eps = hc_eps = 1e-6

    # Start from an all-NaN buffer so any element the op leaves unwritten
    # is caught by the NaN check below.
    out = torch.full(
        (num_tokens, hidden_size),
        float("nan"),
        dtype=torch.bfloat16,
    )

    result = torch.ops.vllm.hc_head_fused_kernel_tilelang(
        residual,
        fn,
        hc_scale,
        hc_base,
        out,
        hidden_size,
        rms_eps,
        hc_eps,
        hc_mult,
    )

    assert result is None
    assert not torch.isnan(out).any()

    expected = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
    torch.testing.assert_close(out, expected, atol=5e-2, rtol=1e-2)
Loading
Loading