@@ -439,15 +439,20 @@ def fn(a):
 MXFP8_GROUPED_MSG = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"
 
 @requiresCUDA
-def test_scaled_grouped_mm_3d2d_rowwise():
+@pytest.mark.parametrize(
+    "group_sizes,k,n",
+    [
+        ([8, 8], 16, 16),
+        ([16, 16], 16, 16),
+    ],
+)
+def test_scaled_grouped_mm_2d3d_rowwise(group_sizes, k, n):
+    """Test 2D x 3D grouped matmul with various dimensions."""
     if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
         pytest.skip(F8_GROUPED_MSG)
     device = "cuda"
-    group_sizes = [16, 16]
     groups = len(group_sizes)
     total_rows = sum(group_sizes)
-    k = 16
-    n = 16
 
     mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
     mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
@@ -475,21 +480,29 @@ def fn(a, b, scale_a, scale_b, offs):
     assert_consistency_of_compiletime_and_runtime(jitted, result)
 
 @requiresCUDA
-def test_scaled_grouped_mm_2d3d_rowwise():
+@pytest.mark.parametrize(
+    "group_sizes,m,k,n",
+    [
+        ([8, 8], 16, 32, 16),  # k != n to catch the dimension check bug
+        ([8, 8], 16, 16, 16),  # k == n edge case
+    ],
+)
+def test_scaled_grouped_mm_3d2d_rowwise(group_sizes, m, k, n):
+    """Test 3D x 2D grouped matmul with various dimensions.
+
+    Note: k != n in the first test case specifically catches the bug where
+    mat_a.shape[2] was incorrectly compared with mat_b.shape[1].
+    """
     if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
         pytest.skip(F8_GROUPED_MSG)
     device = "cuda"
-    group_sizes = [8, 8]
     groups = len(group_sizes)
-    total_rows = sum(group_sizes)
-    k = 16
-    n = 16
 
-    mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
-    mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+    mat_a = torch.randn(groups, m, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+    mat_b = torch.randn(n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
     offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
-    scale_a = torch.ones(total_rows, device=device, dtype=torch.float32)
-    scale_b = torch.ones(groups, n, device=device, dtype=torch.float32)
+    scale_a = torch.ones(groups, m, device=device, dtype=torch.float32)
+    scale_b = torch.ones(n, device=device, dtype=torch.float32)
 
     def fn(a, b, scale_a, scale_b, offs):
         return torch.nn.functional.scaled_grouped_mm(
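The docstring's rationale for the k != n parametrization can be shown in isolation: when k and n happen to be equal, a shape check that reads the wrong axis of mat_b still passes by coincidence, so only a k != n case surfaces the mix-up. The snippet below is a minimal, self-contained illustration of that effect only; it is not the code this PR touches and does not claim which comparison the fix uses.

    # Illustrative only: why equal k and n can mask a swapped-axis shape check.
    import torch

    # k == n == 16: comparing mat_a's k against the wrong axis of mat_b
    # "works" by accident, so a buggy check would not be caught.
    mat_a = torch.empty(2, 16, 16)  # (groups, m, k)
    mat_b = torch.empty(16, 16)     # (n, k)
    assert mat_a.shape[2] == mat_b.shape[0]  # k vs n, but 16 == 16 hides the mistake

    # k = 32, n = 16: the same wrong comparison now fails, which is what the
    # ([8, 8], 16, 32, 16) parametrization above relies on to expose the bug.
    mat_a = torch.empty(2, 16, 32)  # (groups, m, k)
    mat_b = torch.empty(16, 32)     # (n, k)
    assert mat_a.shape[2] != mat_b.shape[0]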