
Commit f428a90

fix shape check cond and add tests to cover it

Signed-off-by: Masaki Kozuki <[email protected]>

1 parent 2dab9ae

2 files changed: +27 -14 lines changed

thunder/tests/test_ops.py

Lines changed: 26 additions & 13 deletions
@@ -439,15 +439,20 @@ def fn(a):
 MXFP8_GROUPED_MSG = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"
 
 @requiresCUDA
-def test_scaled_grouped_mm_3d2d_rowwise():
+@pytest.mark.parametrize(
+    "group_sizes,k,n",
+    [
+        ([8, 8], 16, 16),
+        ([16, 16], 16, 16),
+    ],
+)
+def test_scaled_grouped_mm_2d3d_rowwise(group_sizes, k, n):
+    """Test 2D x 3D grouped matmul with various dimensions."""
     if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
         pytest.skip(F8_GROUPED_MSG)
     device = "cuda"
-    group_sizes = [16, 16]
     groups = len(group_sizes)
     total_rows = sum(group_sizes)
-    k = 16
-    n = 16
 
     mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
     mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
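For reference, the `offs` argument built by these tests encodes group boundaries as the inclusive cumulative sum of `group_sizes`: entry i is the exclusive end offset of group i's rows in the 2D operand. A minimal sketch of that encoding (plain Python, no GPU or FP8 support needed):

    import torch

    group_sizes = [8, 8]
    # Inclusive cumulative sum: entry i marks where group i ends.
    offs = torch.tensor(group_sizes, dtype=torch.int32).cumsum(0, dtype=torch.int32)
    print(offs)  # tensor([ 8, 16], dtype=torch.int32)
    # Group 0 covers rows [0, 8) and group 1 covers rows [8, 16)
    # of the 2D (total_rows, k) operand.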
@@ -475,21 +480,29 @@ def fn(a, b, scale_a, scale_b, offs):
     assert_consistency_of_compiletime_and_runtime(jitted, result)
 
 @requiresCUDA
-def test_scaled_grouped_mm_2d3d_rowwise():
+@pytest.mark.parametrize(
+    "group_sizes,m,k,n",
+    [
+        ([8, 8], 16, 32, 16),  # k != n to catch the dimension check bug
+        ([8, 8], 16, 16, 16),  # k == n edge case
+    ],
+)
+def test_scaled_grouped_mm_3d2d_rowwise(group_sizes, m, k, n):
+    """Test 3D x 2D grouped matmul with various dimensions.
+
+    Note: k != n in first test case specifically catches the bug where
+    mat_a.shape[2] was incorrectly compared with mat_b.shape[1].
+    """
     if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
         pytest.skip(F8_GROUPED_MSG)
     device = "cuda"
-    group_sizes = [8, 8]
     groups = len(group_sizes)
-    total_rows = sum(group_sizes)
-    k = 16
-    n = 16
 
-    mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
-    mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+    mat_a = torch.randn(groups, m, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+    mat_b = torch.randn(n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
     offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
-    scale_a = torch.ones(total_rows, device=device, dtype=torch.float32)
-    scale_b = torch.ones(groups, n, device=device, dtype=torch.float32)
+    scale_a = torch.ones(groups, m, device=device, dtype=torch.float32)
+    scale_b = torch.ones(n, device=device, dtype=torch.float32)
 
     def fn(a, b, scale_a, scale_b, offs):
         return torch.nn.functional.scaled_grouped_mm(
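One way to run only the tests touched by this commit (they skip on platforms without FP8 grouped-GEMM support) is pytest's programmatic entry point, invoked from the repo root:

    import pytest

    # Select both parametrized tests by keyword; -v prints each parametrization.
    pytest.main(["thunder/tests/test_ops.py", "-k", "scaled_grouped_mm", "-v"])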

thunder/torch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5849,7 +5849,7 @@ def scaled_grouped_mm(
         out_shape = (offs.shape[0], mat_a.shape[0], mat_b.shape[1])
     elif mat_a.ndim == 3 and mat_b.ndim == 2:
         utils.check(
-            mat_a.shape[2] == mat_b.shape[1],
+            mat_a.shape[2] == mat_b.shape[0],
             lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
         )
         utils.check(
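To see concretely what the one-line fix changes, here is a minimal sketch of the old and new conditions on the k != n shapes from the new test. It assumes the 2D operand reaching this check is laid out as (k, n), with the reduction dimension leading, which is what the fixed condition and the test docstring imply (the test's `fn` body, truncated above, presumably passes mat_b transposed); `utils.check` itself is not reproduced:

    import torch

    # 3D x 2D case: mat_a is (groups, m, k); the 2D operand is (k, n).
    mat_a = torch.empty(2, 16, 32)  # k = 32
    mat_b = torch.empty(32, 16)     # n = 16

    # Old condition compared the reduction dim k against the output dim n,
    # so it rejected valid inputs whenever k != n and only passed by
    # coincidence when k == n (as in the pre-existing 16/16 tests).
    assert not (mat_a.shape[2] == mat_b.shape[1])  # 32 vs 16

    # Fixed condition: the inner (reduction) dimensions must match.
    assert mat_a.shape[2] == mat_b.shape[0]        # 32 vs 32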
