diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index d6fb0b31c1..e76c11d3f3 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1531,6 +1531,8 @@ def _copy_with_setitem_impl(a, key, value):
 baddbmm = _register_torch_operation("baddbmm")
 if LooseVersion(torch.__version__) >= "2.8":
     _grouped_mm = _register_torch_operation("_grouped_mm")
+if _scaled_grouped_mm_available := hasattr(torch.nn.functional, "scaled_grouped_mm"):
+    scaled_grouped_mm = _register_torch_operation("scaled_grouped_mm", module=torch.nn.functional)
 convolution = _register_torch_operation("convolution")
 conv1d = _register_torch_operation("conv1d", module=torch.nn.functional)
 conv2d = _register_torch_operation("conv2d", module=torch.nn.functional)
@@ -1827,11 +1829,40 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) ->
     return a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16 and offsets.dtype == dtypes.int32
 
 
+def _scaled_grouped_mm_checker(
+    mat_a: TensorProxy,
+    mat_b: TensorProxy,
+    scale_a,
+    scale_recipe_a,
+    scale_b,
+    scale_recipe_b,
+    swizzle_a=None,
+    swizzle_b=None,
+    bias: None | TensorProxy = None,
+    offs: None | TensorProxy = None,
+    output_dtype: None | dtypeLike = None,
+    contraction_dim: Sequence[int] | tuple[int, ...] = (),
+    use_fast_accum: bool = False,
+) -> bool:
+    if offs is None:
+        return False
+    if isinstance(offs, TensorProxy):
+        return utils.is_integer_dtype(offs.dtype)
+    return True
+
+
 _register_implementation(ltorch.baddbmm, baddbmm, checker=_always_executable)
 _register_implementation(ltorch.bmm, bmm, checker=_always_executable)
 if LooseVersion(torch.__version__) >= "2.8":
     _register_implementation(prims._grouped_mm, _grouped_mm, checker=_grouped_mm_checker)
     _register_implementation(ltorch._grouped_mm, _grouped_mm, checker=_grouped_mm_checker)
+
+if _scaled_grouped_mm_available:
+    _register_implementation(
+        ltorch.scaled_grouped_mm,
+        scaled_grouped_mm,
+        checker=_scaled_grouped_mm_checker,
+    )
 _register_implementation(ltorch.convolution, checker=_always_executable, execution_transform=_convolution_transform)
 _register_implementation(ltorch.conv1d, conv1d, checker=_always_executable)
 _register_implementation(ltorch.conv2d, conv2d, checker=_always_executable)
diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index 62479e1223..659ac99eaf 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -6,6 +6,16 @@
 import torch
 from torch.testing import assert_close
 
+HAS_SCALED_GROUPED_MM = hasattr(torch.nn.functional, "scaled_grouped_mm")
+
+if HAS_SCALED_GROUPED_MM:
+    from torch.nn.functional import ScalingType, SwizzleType
+    from torch.testing._internal.common_cuda import (
+        PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
+        PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
+    )
+    from torch.testing._internal.common_quantized import to_blocked, to_mxfp
+
 import thunder
 import thunder.core.devices as devices
 import thunder.core.dtypes as dtypes
@@ -419,6 +429,165 @@ def fn(a):
     assert_close(b, b_ref)
 
 
+if HAS_SCALED_GROUPED_MM:
+    # NOTE: The following tests exercise torch.nn.functional.scaled_grouped_mm via Thunder.
+    # They validate that the Thunder tracing mirrors eager execution for representative 2D/2D
+    # and 2D/3D grouped matmul shapes that correspond to the accepted combinations in
+    # thunder.core.prims._grouped_mm_meta. These scenarios mirror the small tensor smoke tests
+    # in PyTorch's scaled matmul CUDA suite (https://github.com/pytorch/pytorch/blob/2f023bf7/test/test_scaled_matmul_cuda.py).
+    F8_GROUPED_MSG = "FP8 grouped is only supported on SM90 and MI300+ devices"
+    MXFP8_GROUPED_MSG = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"
+
+    @requiresCUDA
+    @pytest.mark.parametrize(
+        "group_sizes,k,n",
+        [
+            ([8, 8], 16, 16),
+            ([16, 16], 16, 16),
+        ],
+    )
+    def test_scaled_grouped_mm_2d3d_rowwise(group_sizes, k, n):
+        """Test 2D x 3D grouped matmul with various dimensions."""
+        if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
+            pytest.skip(F8_GROUPED_MSG)
+        device = "cuda"
+        groups = len(group_sizes)
+        total_rows = sum(group_sizes)
+
+        mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
+        scale_a = torch.ones(total_rows, device=device, dtype=torch.float32)
+        scale_b = torch.ones(groups, n, device=device, dtype=torch.float32)
+
+        def fn(a, b, scale_a, scale_b, offs):
+            return torch.nn.functional.scaled_grouped_mm(
+                a,
+                b.transpose(-2, -1),
+                scale_a,
+                ScalingType.RowWise,
+                scale_b,
+                ScalingType.RowWise,
+                offs=offs,
+                output_dtype=torch.bfloat16,
+            )
+
+        eager = fn(mat_a, mat_b, scale_a, scale_b, offs)
+        jitted = thunder.jit(fn)
+        result = jitted(mat_a, mat_b, scale_a, scale_b, offs)
+
+        torch.testing.assert_close(result, eager)
+        assert_consistency_of_compiletime_and_runtime(jitted, result)
+
+    @requiresCUDA
+    @pytest.mark.parametrize(
+        "group_sizes,m,k,n",
+        [
+            ([8, 8], 16, 32, 16),  # k != n to catch the dimension check bug
+            ([8, 8], 16, 16, 16),  # k == n edge case
+        ],
+    )
+    def test_scaled_grouped_mm_3d2d_rowwise(group_sizes, m, k, n):
+        """Test 3D x 2D grouped matmul with various dimensions.
+
+        Note: k != n in the first test case specifically catches the bug where
+        mat_a.shape[2] was incorrectly compared with mat_b.shape[1].
+        """
+        if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
+            pytest.skip(F8_GROUPED_MSG)
+        device = "cuda"
+        groups = len(group_sizes)
+
+        mat_a = torch.randn(groups, m, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        mat_b = torch.randn(n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
+        offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
+        scale_a = torch.ones(groups, m, device=device, dtype=torch.float32)
+        scale_b = torch.ones(n, device=device, dtype=torch.float32)
+
+        def fn(a, b, scale_a, scale_b, offs):
+            return torch.nn.functional.scaled_grouped_mm(
+                a,
+                b.transpose(-2, -1),
+                scale_a,
+                ScalingType.RowWise,
+                scale_b,
+                ScalingType.RowWise,
+                offs=offs,
+                output_dtype=torch.bfloat16,
+            )
+
+        eager = fn(mat_a, mat_b, scale_a, scale_b, offs)
+        jitted = thunder.jit(fn)
+        result = jitted(mat_a, mat_b, scale_a, scale_b, offs)
+
+        torch.testing.assert_close(result, eager)
+        assert_consistency_of_compiletime_and_runtime(jitted, result)
+
+    @requiresCUDA
+    def test_scaled_grouped_mm_2d2d_mxfp8_blockwise():
+        if not bool(PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM):
+            pytest.skip(MXFP8_GROUPED_MSG)
+
+        device = "cuda"
+        torch.manual_seed(0)
+
+        group_sizes = [64, 32]
+        m = 128
+        n = 96
+        total_k = sum(group_sizes)
+
+        raw_offs = torch.tensor(group_sizes, device=device, dtype=torch.int32)
+        offs = torch.cumsum(raw_offs, dim=0, dtype=torch.int32)
+
+        def round_up(x: int, y: int) -> int:
+            return ((x + y - 1) // y) * y
+
+        def quantize_to_mxfp8(mat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            segments: list[torch.Tensor] = []
+            scale_segments: list[torch.Tensor] = []
+            start = 0
+            for end in offs.tolist():
+                segment = mat[:, start:end].contiguous()
+                scale, lowp = to_mxfp(segment, format="mxfp8")
+                scale_segments.append(to_blocked(scale))
+                segments.append(lowp)
+                start = end
+            lowp_full = torch.cat(segments, dim=1)
+            rows_rounded = round_up(mat.shape[0], 128)
+            scale_full = torch.cat(scale_segments, dim=0).reshape(rows_rounded, -1)
+            return lowp_full, scale_full
+
+        mat_a_hp = torch.randn(m, total_k, device=device, dtype=torch.bfloat16) * 0.1
+        mat_b_hp = torch.randn(n, total_k, device=device, dtype=torch.bfloat16) * 0.01
+
+        mat_a_lp, scale_a = quantize_to_mxfp8(mat_a_hp)
+        mat_b_lp_rows, scale_b = quantize_to_mxfp8(mat_b_hp)
+        mat_b_lp = mat_b_lp_rows.transpose(0, 1)
+
+        swizzle = SwizzleType.SWIZZLE_32_4_4 if torch.version.cuda else SwizzleType.NO_SWIZZLE
+
+        def fn(mat_a, mat_b, scale_a, scale_b, offs):
+            return torch.nn.functional.scaled_grouped_mm(
+                mat_a,
+                mat_b,
+                scale_a,
+                ScalingType.BlockWise1x32,
+                scale_b,
+                ScalingType.BlockWise1x32,
+                swizzle_a=swizzle,
+                swizzle_b=swizzle,
+                offs=offs,
+                output_dtype=torch.bfloat16,
+            )
+
+        eager = fn(mat_a_lp, mat_b_lp, scale_a, scale_b, offs)
+        jitted = thunder.jit(fn)
+        result = jitted(mat_a_lp, mat_b_lp, scale_a, scale_b, offs)
+
+        torch.testing.assert_close(result, eager)
+        assert_consistency_of_compiletime_and_runtime(jitted, result)
+
+
 # https://github.com/Lightning-AI/lightning-thunder/issues/1857
 def test_max_with_int():
     def f(x, ids):
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 090f46643e..dacd58c116 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5806,6 +5806,94 @@ def _grouped_mm(
     return prims._grouped_mm(a, b, offsets)
 
 
+if hasattr(torch.nn.functional, "scaled_grouped_mm"):
+
+    @torchsymbol(
+        torch.nn.functional.scaled_grouped_mm,
+        id="torch.nn.functional.scaled_grouped_mm",
+        is_method=False,
+        is_prim=True,
+    )
+    def scaled_grouped_mm(
+        mat_a: TensorProxy,
+        mat_b: TensorProxy,
+        scale_a,
+        scale_recipe_a,
+        scale_b,
+        scale_recipe_b,
+        swizzle_a=None,
+        swizzle_b=None,
+        bias: None | TensorProxy = None,
+        offs: None | TensorProxy = None,
+        output_dtype: dtypeLike = torch.bfloat16,
+        contraction_dim: Sequence[int] | tuple[int, ...] = (),
+        use_fast_accum: bool = False,
+    ) -> TensorProxy:
+        utils.check(
+            offs is not None,
+            lambda: "scaled_grouped_mm currently requires `offs`.",
+        )
+        utils.check_type(offs, TensorProxy)
+        utils.check(
+            bias is None,
+            lambda: "scaled_grouped_mm currently doesn't support `bias`.",
+        )
+        utils.check(
+            len(contraction_dim) == 0,
+            lambda: f"scaled_grouped_mm currently expects an empty `contraction_dim`, but got {contraction_dim}.",
+        )
+
+        utils.check_type(mat_a, TensorProxy)
+        utils.check_type(mat_b, TensorProxy)
+        utils.check(mat_a.ndim in (2, 3), lambda: f"Expected mat_a to have 2 or 3 dimensions, got {mat_a.ndim}")
+        utils.check(mat_b.ndim in (2, 3), lambda: f"Expected mat_b to have 2 or 3 dimensions, got {mat_b.ndim}")
+
+        utils.check(offs.ndim == 1, lambda: f"`offs` must be a vector, got shape {offs.shape}")
+        if mat_a.ndim == 2 and mat_b.ndim == 2:
+            utils.check(
+                mat_a.shape[1] == mat_b.shape[0],
+                lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
+            )
+            out_shape = (offs.shape[0], mat_a.shape[0], mat_b.shape[1])
+        elif mat_a.ndim == 3 and mat_b.ndim == 2:
+            utils.check(
+                mat_a.shape[2] == mat_b.shape[0],
+                lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
+            )
+            utils.check(
+                mat_a.shape[0] == offs.shape[0],
+                lambda: f"Group count mismatch: {mat_a.shape} vs {offs.shape}",
+            )
+            out_shape = (mat_a.shape[1], mat_b.shape[1])
+        elif mat_a.ndim == 2 and mat_b.ndim == 3:
+            utils.check(
+                mat_a.shape[1] == mat_b.shape[1],
+                lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
+            )
+            utils.check(
+                mat_b.shape[0] == offs.shape[0],
+                lambda: f"Group count mismatch: {mat_b.shape} vs {offs.shape}",
+            )
+            out_shape = (mat_a.shape[0], mat_b.shape[2])
+        else:
+            utils.check(False, lambda: f"Unsupported shape combination: {mat_a.shape} and {mat_b.shape}")
+
+        utils.check_same_dtype(mat_a, mat_b)
+        allowed_input_dtypes = dtypes.float_math_dtypes | dtypes.float_8bit_dtypes
+        utils.check(
+            mat_a.dtype in allowed_input_dtypes,
+            lambda: f"`mat_a` must be a floating dtype, got {mat_a.dtype}",
+        )
+        utils.check(
+            utils.is_integer_dtype(offs.dtype),
+            lambda: f"`offs` must be integers, got {offs.dtype}",
+        )
+        utils.check_same_device(mat_a, mat_b)
+
+        target_dtype = to_dtype(output_dtype) if output_dtype is not None else mat_a.dtype
+        return TensorProxy(like=mat_a, shape=out_shape, dtype=target_dtype)
+
+
 @torchsymbol(torch.logsumexp, is_method=True)
 def logsumexp(a: TensorLike, /, dim: int | Sequence[int], keepdim: bool = False) -> TensorLike:
     input_max = amax(a, dim, keepdim=True)
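
A minimal usage sketch of the path this patch adds, mirroring the 2D x 3D row-wise test above. It is illustrative only and not part of the patch: it assumes a PyTorch build that exposes torch.nn.functional.scaled_grouped_mm and ScalingType plus a CUDA device with FP8 grouped-GEMM support (SM90 / MI300+), and the helper name grouped_fp8_mm is made up for the example.

import torch
from torch.nn.functional import ScalingType

import thunder


def grouped_fp8_mm(a, b, scale_a, scale_b, offs):
    # 2D x 3D row-wise case: `a` is (total_rows, k) in float8_e4m3fn, `b` is (groups, n, k)
    # and is passed transposed, `offs` holds the cumulative row count per group.
    return torch.nn.functional.scaled_grouped_mm(
        a,
        b.transpose(-2, -1),
        scale_a,
        ScalingType.RowWise,
        scale_b,
        ScalingType.RowWise,
        offs=offs,
        output_dtype=torch.bfloat16,
    )


device = "cuda"
group_sizes = [8, 8]
k = n = 16
groups, total_rows = len(group_sizes), sum(group_sizes)

a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
scale_a = torch.ones(total_rows, device=device, dtype=torch.float32)
scale_b = torch.ones(groups, n, device=device, dtype=torch.float32)

jitted = thunder.jit(grouped_fp8_mm)
out = jitted(a, b, scale_a, scale_b, offs)  # expected: shape (total_rows, n), dtype bfloat16

# The final execution trace should show ltorch.scaled_grouped_mm claimed by the torch executor.
print(thunder.last_traces(jitted)[-1])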