31 changes: 31 additions & 0 deletions thunder/executors/torchex.py
@@ -1531,6 +1531,8 @@ def _copy_with_setitem_impl(a, key, value):
baddbmm = _register_torch_operation("baddbmm")
if LooseVersion(torch.__version__) >= "2.8":
_grouped_mm = _register_torch_operation("_grouped_mm")
if _scaled_grouped_mm_available := hasattr(torch.nn.functional, "scaled_grouped_mm"):
scaled_grouped_mm = _register_torch_operation("scaled_grouped_mm", module=torch.nn.functional)
convolution = _register_torch_operation("convolution")
conv1d = _register_torch_operation("conv1d", module=torch.nn.functional)
conv2d = _register_torch_operation("conv2d", module=torch.nn.functional)
@@ -1827,11 +1829,40 @@ def _grouped_mm_checker(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) ->
return a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16 and offsets.dtype == dtypes.int32


def _scaled_grouped_mm_checker(
mat_a: TensorProxy,
mat_b: TensorProxy,
scale_a,
scale_recipe_a,
scale_b,
scale_recipe_b,
swizzle_a=None,
swizzle_b=None,
bias: None | TensorProxy = None,
offs: None | TensorProxy = None,
output_dtype: None | dtypeLike = None,
contraction_dim: Sequence[int] | tuple[int, ...] = (),
use_fast_accum: bool = False,
) -> bool:
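"""Accept the call only when `offs` is provided and, if it is a TensorProxy, has an integer dtype."""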
if offs is None:
return False
if isinstance(offs, TensorProxy):
return utils.is_integer_dtype(offs.dtype)
return True


_register_implementation(ltorch.baddbmm, baddbmm, checker=_always_executable)
_register_implementation(ltorch.bmm, bmm, checker=_always_executable)
if LooseVersion(torch.__version__) >= "2.8":
_register_implementation(prims._grouped_mm, _grouped_mm, checker=_grouped_mm_checker)
_register_implementation(ltorch._grouped_mm, _grouped_mm, checker=_grouped_mm_checker)

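# Only register the implementation when this PyTorch build exposes scaled_grouped_mm.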
if _scaled_grouped_mm_available:
_register_implementation(
ltorch.scaled_grouped_mm,
scaled_grouped_mm,
checker=_scaled_grouped_mm_checker,
)
_register_implementation(ltorch.convolution, checker=_always_executable, execution_transform=_convolution_transform)
_register_implementation(ltorch.conv1d, conv1d, checker=_always_executable)
_register_implementation(ltorch.conv2d, conv2d, checker=_always_executable)
169 changes: 169 additions & 0 deletions thunder/tests/test_ops.py
@@ -6,6 +6,16 @@
import torch
from torch.testing import assert_close

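# torch.nn.functional.scaled_grouped_mm (and the ScalingType/SwizzleType enums used with it) only
# exists in newer PyTorch builds, so the imports and tests below are guarded on its availability.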
HAS_SCALED_GROUPED_MM = hasattr(torch.nn.functional, "scaled_grouped_mm")

if HAS_SCALED_GROUPED_MM:
from torch.nn.functional import ScalingType, SwizzleType
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
)
from torch.testing._internal.common_quantized import to_blocked, to_mxfp

import thunder
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
@@ -419,6 +429,165 @@ def fn(a):
assert_close(b, b_ref)


if HAS_SCALED_GROUPED_MM:
# NOTE: The following tests exercise torch.nn.functional.scaled_grouped_mm via Thunder.
# They validate that Thunder's traced execution matches eager execution for representative
# 2D/3D, 3D/2D, and 2D/2D grouped matmul shapes, matching the combinations accepted by
# thunder.core.prims._grouped_mm_meta. These scenarios mirror the small-tensor smoke tests
# in PyTorch's scaled matmul CUDA suite (pytorch/test_scaled_matmul_cuda.py,
# https://github.com/pytorch/pytorch/blob/2f023bf7/test/test_scaled_matmul_cuda.py).
F8_GROUPED_MSG = "FP8 grouped is only supported on SM90 and MI300+ devices"
MXFP8_GROUPED_MSG = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+"

@requiresCUDA
@pytest.mark.parametrize(
"group_sizes,k,n",
[
([8, 8], 16, 16),
([16, 16], 16, 16),
],
)
def test_scaled_grouped_mm_2d3d_rowwise(group_sizes, k, n):
"""Test 2D x 3D grouped matmul with various dimensions."""
if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
pytest.skip(F8_GROUPED_MSG)
device = "cuda"
groups = len(group_sizes)
total_rows = sum(group_sizes)

mat_a = torch.randn(total_rows, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
mat_b = torch.randn(groups, n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
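# Row-wise scaling: one scale per row of mat_a and one per (group, output column) of mat_b.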
scale_a = torch.ones(total_rows, device=device, dtype=torch.float32)
scale_b = torch.ones(groups, n, device=device, dtype=torch.float32)

def fn(a, b, scale_a, scale_b, offs):
return torch.nn.functional.scaled_grouped_mm(
a,
b.transpose(-2, -1),
scale_a,
ScalingType.RowWise,
scale_b,
ScalingType.RowWise,
offs=offs,
output_dtype=torch.bfloat16,
)

eager = fn(mat_a, mat_b, scale_a, scale_b, offs)
jitted = thunder.jit(fn)
result = jitted(mat_a, mat_b, scale_a, scale_b, offs)

torch.testing.assert_close(result, eager)
assert_consistency_of_compiletime_and_runtime(jitted, result)

@requiresCUDA
@pytest.mark.parametrize(
"group_sizes,m,k,n",
[
([8, 8], 16, 32, 16), # k != n to catch the dimension check bug
([8, 8], 16, 16, 16), # k == n edge case
],
)
def test_scaled_grouped_mm_3d2d_rowwise(group_sizes, m, k, n):
"""Test 3D x 2D grouped matmul with various dimensions.

Note: k != n in the first test case specifically catches the bug where
mat_a.shape[2] was incorrectly compared with mat_b.shape[1].
"""
if not bool(PLATFORM_SUPPORTS_FP8_GROUPED_GEMM):
pytest.skip(F8_GROUPED_MSG)
device = "cuda"
groups = len(group_sizes)

mat_a = torch.randn(groups, m, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
mat_b = torch.randn(n, k, device=device, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
offs = torch.tensor(group_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
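# Row-wise scaling: one scale per (group, row) of mat_a and one per output column of mat_b.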
scale_a = torch.ones(groups, m, device=device, dtype=torch.float32)
scale_b = torch.ones(n, device=device, dtype=torch.float32)

def fn(a, b, scale_a, scale_b, offs):
return torch.nn.functional.scaled_grouped_mm(
a,
b.transpose(-2, -1),
scale_a,
ScalingType.RowWise,
scale_b,
ScalingType.RowWise,
offs=offs,
output_dtype=torch.bfloat16,
)

eager = fn(mat_a, mat_b, scale_a, scale_b, offs)
jitted = thunder.jit(fn)
result = jitted(mat_a, mat_b, scale_a, scale_b, offs)

torch.testing.assert_close(result, eager)
assert_consistency_of_compiletime_and_runtime(jitted, result)

@requiresCUDA
def test_scaled_grouped_mm_2d2d_mxfp8_blockwise():
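"""Test 2D x 2D grouped matmul with MXFP8 inputs and blockwise (1x32) scales."""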
if not bool(PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM):
pytest.skip(MXFP8_GROUPED_MSG)

device = "cuda"
torch.manual_seed(0)

group_sizes = [64, 32]
m = 128
n = 96
total_k = sum(group_sizes)

raw_offs = torch.tensor(group_sizes, device=device, dtype=torch.int32)
offs = torch.cumsum(raw_offs, dim=0, dtype=torch.int32)

def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y

def quantize_to_mxfp8(mat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
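"""Quantize each K-group of `mat` to MXFP8 and return the concatenated low-precision
tensor together with the per-group scales converted to the blocked layout."""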
segments: list[torch.Tensor] = []
scale_segments: list[torch.Tensor] = []
start = 0
for end in offs.tolist():
segment = mat[:, start:end].contiguous()
scale, lowp = to_mxfp(segment, format="mxfp8")
scale_segments.append(to_blocked(scale))
segments.append(lowp)
start = end
lowp_full = torch.cat(segments, dim=1)
rows_rounded = round_up(mat.shape[0], 128)
scale_full = torch.cat(scale_segments, dim=0).reshape(rows_rounded, -1)
return lowp_full, scale_full

mat_a_hp = torch.randn(m, total_k, device=device, dtype=torch.bfloat16) * 0.1
mat_b_hp = torch.randn(n, total_k, device=device, dtype=torch.bfloat16) * 0.01

mat_a_lp, scale_a = quantize_to_mxfp8(mat_a_hp)
mat_b_lp_rows, scale_b = quantize_to_mxfp8(mat_b_hp)
mat_b_lp = mat_b_lp_rows.transpose(0, 1)

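# Use the 32x4x4 swizzled scale layout on CUDA builds and no swizzle otherwise (e.g. ROCm).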
swizzle = SwizzleType.SWIZZLE_32_4_4 if torch.version.cuda else SwizzleType.NO_SWIZZLE

def fn(mat_a, mat_b, scale_a, scale_b, offs):
return torch.nn.functional.scaled_grouped_mm(
mat_a,
mat_b,
scale_a,
ScalingType.BlockWise1x32,
scale_b,
ScalingType.BlockWise1x32,
swizzle_a=swizzle,
swizzle_b=swizzle,
offs=offs,
output_dtype=torch.bfloat16,
)

eager = fn(mat_a_lp, mat_b_lp, scale_a, scale_b, offs)
jitted = thunder.jit(fn)
result = jitted(mat_a_lp, mat_b_lp, scale_a, scale_b, offs)

torch.testing.assert_close(result, eager)
assert_consistency_of_compiletime_and_runtime(jitted, result)


# https://github.com/Lightning-AI/lightning-thunder/issues/1857
def test_max_with_int():
def f(x, ids):
88 changes: 88 additions & 0 deletions thunder/torch/__init__.py
@@ -5798,6 +5798,94 @@ def _grouped_mm(
return prims._grouped_mm(a, b, offsets)


if hasattr(torch.nn.functional, "scaled_grouped_mm"):

@torchsymbol(
torch.nn.functional.scaled_grouped_mm,
id="torch.nn.functional.scaled_grouped_mm",
is_method=False,
is_prim=True,
)
def scaled_grouped_mm(
mat_a: TensorProxy,
mat_b: TensorProxy,
scale_a,
scale_recipe_a,
scale_b,
scale_recipe_b,
swizzle_a=None,
swizzle_b=None,
bias: None | TensorProxy = None,
offs: None | TensorProxy = None,
output_dtype: dtypeLike = torch.bfloat16,
contraction_dim: Sequence[int] | tuple[int, ...] = (),
use_fast_accum: bool = False,
) -> TensorProxy:
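"""Shape and dtype meta for torch.nn.functional.scaled_grouped_mm, registered as a Thunder prim.

Validates the supported 2D/2D, 3D/2D, and 2D/3D operand combinations, requires an integer
`offs` tensor, and returns a TensorProxy with the inferred output shape and requested dtype.
"""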
utils.check(
offs is not None,
lambda: "scaled_grouped_mm currently requires `offs`.",
)
utils.check_type(offs, TensorProxy)
utils.check(
bias is None,
lambda: "scaled_grouped_mm currently doesn't support `bias`.",
)
utils.check(
len(contraction_dim) == 0,
lambda: f"scaled_grouped_mm currently expects an empty `contraction_dim`, but got {contraction_dim}.",
)

utils.check_type(mat_a, TensorProxy)
utils.check_type(mat_b, TensorProxy)
utils.check(mat_a.ndim in (2, 3), lambda: f"Expected mat_a to have 2 or 3 dimensions, got {mat_a.ndim}")
utils.check(mat_b.ndim in (2, 3), lambda: f"Expected mat_b to have 2 or 3 dimensions, got {mat_b.ndim}")

utils.check(offs.ndim == 1, lambda: f"`offs` must be a vector, got shape {offs.shape}")
if mat_a.ndim == 2 and mat_b.ndim == 2:
utils.check(
mat_a.shape[1] == mat_b.shape[0],
lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
)
out_shape = (offs.shape[0], mat_a.shape[0], mat_b.shape[1])
elif mat_a.ndim == 3 and mat_b.ndim == 2:
utils.check(
mat_a.shape[2] == mat_b.shape[0],
lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
)
utils.check(
mat_a.shape[0] == offs.shape[0],
lambda: f"Group count mismatch: {mat_a.shape} vs {offs.shape}",
)
out_shape = (mat_a.shape[1], mat_b.shape[1])
elif mat_a.ndim == 2 and mat_b.ndim == 3:
utils.check(
mat_a.shape[1] == mat_b.shape[1],
lambda: f"Inner dimension mismatch: {mat_a.shape} vs {mat_b.shape}",
)
utils.check(
mat_b.shape[0] == offs.shape[0],
lambda: f"Group count mismatch: {mat_b.shape} vs {offs.shape}",
)
out_shape = (mat_a.shape[0], mat_b.shape[2])
else:
utils.check(False, lambda: f"Unexpected shape combination: {mat_a.shape} and {mat_b.shape}")

utils.check_same_dtype(mat_a, mat_b)
allowed_input_dtypes = dtypes.float_math_dtypes | dtypes.float_8bit_dtypes
utils.check(
mat_a.dtype in allowed_input_dtypes,
lambda: f"`mat_a` must be a floating dtype, got {mat_a.dtype}",
)
utils.check(
utils.is_integer_dtype(offs.dtype),
lambda: f"`offs` must be integers, got {offs.dtype}",
)
utils.check_same_device(mat_a, mat_b)

target_dtype = to_dtype(output_dtype) if output_dtype is not None else mat_a.dtype
return TensorProxy(like=mat_a, shape=out_shape, dtype=target_dtype)


@torchsymbol(torch.logsumexp, is_method=True)
def logsumexp(a: TensorLike, /, dim: int | Sequence[int], keepdim: bool = False) -> TensorLike:
input_max = amax(a, dim, keepdim=True)