Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tokenspeed-kernel/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def mi300_platform() -> PlatformInfo:


@pytest.fixture
def mi355_platform() -> PlatformInfo:
def mi350_platform() -> PlatformInfo:
return PlatformInfo(
vendor="amd",
arch_version=ArchVersion(9, 5),
device_name="AMD Instinct MI355X",
device_name="AMD Instinct MI350X/MI355X",
device_count=8,
total_memory=288 * (1024**3),
memory_bandwidth=8000.0,
Expand Down
6 changes: 5 additions & 1 deletion tokenspeed-kernel/test/test_callsite_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ def _site_id(site: CallSite) -> str:
)
@pytest.mark.parametrize(
"platform_name",
["h100_platform", "b200_platform"],
[
"h100_platform",
"b200_platform",
"mi350_platform",
],
)
def test_kernel_selection(site, platform_name, request):
platform = request.getfixturevalue(platform_name)
Expand Down
14 changes: 7 additions & 7 deletions tokenspeed-kernel/test/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_filter_by_features(self, sample_specs):
assert "paged" in s.features

def test_filter_by_platform(
self, sample_specs, h100_platform, mi300_platform, mi355_platform
self, sample_specs, h100_platform, mi300_platform, mi350_platform
):
reg = KernelRegistry.get()
register_all_samples(reg, sample_specs)
Expand All @@ -159,13 +159,13 @@ def test_filter_by_platform(
assert "flashinfer_decode" not in amd_names
assert "aiter_decode" in amd_names

mi355_kernels = reg.get_for_operator(
"attention", "decode", platform=mi355_platform
mi350_kernels = reg.get_for_operator(
"attention", "decode", platform=mi350_platform
)
mi355_names = {s.name for s in mi355_kernels}
assert "flashinfer_decode" not in mi355_names
assert "aiter_decode" in mi355_names
assert "triton_decode" in mi355_names
mi350_names = {s.name for s in mi350_kernels}
assert "flashinfer_decode" not in mi350_names
assert "aiter_decode" in mi350_names
assert "triton_decode" in mi350_names

def test_filter_by_dtype(self, sample_specs):
reg = KernelRegistry.get()
Expand Down
4 changes: 2 additions & 2 deletions tokenspeed-kernel/test/test_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,15 +687,15 @@ def test_amd_platform_selects_aiter(self, sample_specs, mi300_platform):
)
assert impl() == "aiter_decode"

def test_amd_mi355_platform_selects_aiter(self, sample_specs, mi355_platform):
def test_amd_mi350_platform_selects_aiter(self, sample_specs, mi350_platform):
reg = KernelRegistry.get()
register_all_samples(reg, sample_specs)

impl = select_kernel(
"attention",
"decode",
torch.bfloat16,
platform=mi355_platform,
platform=mi350_platform,
)
assert impl() == "aiter_decode"

Expand Down
Loading