Skip to content
241 changes: 191 additions & 50 deletions src/flag_gems/runtime/backend/_mthreads/ops/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,33 @@
from flag_gems.utils import libentry, libtuner
from flag_gems.utils import triton_lang_extension as tle

from .utils import create_tma_device_descriptor, should_enable_sqmma
from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor

logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")


def is_supported_sqmma_layout(tensor):
    """Return True if *tensor*'s memory layout can be fed to the SQMMA path.

    Two layouts are accepted: a row-major (contiguous) tensor, or a
    column-major one — i.e. the first dimension has unit stride and the
    second dimension's stride equals the number of rows (the layout a
    contiguous 2-D tensor gets after ``.t()``).
    """
    if tensor.is_contiguous():
        return True
    # Column-major check: stride pattern of a transposed contiguous matrix.
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]

logger = logging.getLogger(
f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

def is_sqmma_compatible(a, b, N, K):
    """Decide whether the SQMMA kernel may be used for ``a @ b``.

    Requires the ``MUSA_ENABLE_SQMMA`` env switch, 2-D fp16/bf16 operands of
    matching dtype in a supported (row- or column-major) layout, and N/K
    divisible by 8.

    NOTE(review): unlike ``should_enable_sqmma`` in ops/utils.py, this
    predicate applies no explicit shape exclusions — confirm those are no
    longer needed for mm.
    """
    if os.getenv("MUSA_ENABLE_SQMMA", "0") != "1":
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    # SQMMA tiles require 8-aligned N and K.
    return N % 8 == 0 and K % 8 == 0
Comment on lines +24 to +35
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mm() now gates SQMMA via is_sqmma_compatible() instead of the shared should_enable_sqmma() helper in ops/utils.py, which means the explicit shape exclusions in should_enable_sqmma (e.g. (15, 160, 1024)) are no longer applied for mm. If those exclusions are still required to avoid known SQMMA failures, this is a behavioral regression; consider reusing should_enable_sqmma here or moving any required exclusions/alignment checks into a single shared predicate used by mm/addmm/bmm.

Copilot uses AI. Check for mistakes.


def matmul_get_configs():
    """Fetch the tuned autotuning configs registered for the "mm" kernel."""
    tuned_configs = runtime.get_tuned_config("mm")
    return tuned_configs


@triton.jit
Expand All @@ -25,9 +47,9 @@ def prev_multiple_of(a, b):

@libentry()
@libtuner(
configs=runtime.get_tuned_config("mm"),
key=["M", "N", "K"],
strategy=["align32", "align32", "align32"],
configs=matmul_get_configs(),
key=["M", "N", "K", "stride_am", "stride_bk"],
strategy=["align32", "align32", "align32", "align32", "align32"],
)
@triton.jit
def mm_kernel(
Expand All @@ -43,6 +65,7 @@ def mm_kernel(
stride_bn,
stride_cm,
stride_cn,
dtype: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
Expand Down Expand Up @@ -101,6 +124,58 @@ def mm_kernel(
tl.store(C, acc, mask=mask)


def gemv_get_configs():
    """Single hand-picked config for the N == 1 GEMV fallback kernel."""
    block_sizes = {"BLOCK_M": 64, "BLOCK_K": 64}
    return [triton.Config(block_sizes)]


@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
)
@triton.jit
# GEMV kernel for the N == 1 case of mm: computes C[m] = sum_k A[m, k] * B[k].
# Each program instance handles one BLOCK_M-row slice of A and walks K in
# BLOCK_K-sized chunks.
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tle.program_id(0)

    # Rows of A owned by this program; mask guards the ragged final block.
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # One fp32 partial sum per row.
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # (BLOCK_M, BLOCK_K) tile of A; out-of-range elements load as 0.0
        # so they contribute nothing to the reduction.
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # BLOCK_K-long slice of the vector B.
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Keep the reduction in fp32 so N=1 GEMV matches the mm path more closely.
        a = a.to(tl.float32)
        b = b.to(tl.float32)
        acc += tl.sum(a * b[None, :], axis=1)

    # Cast back to the output dtype and store the masked rows.
    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)


_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


Expand Down Expand Up @@ -151,11 +226,34 @@ def mm_fma(a, b):
b.stride(1),
c.stride(0),
c.stride(1),
dtype=str(a.dtype).split(".")[-1],
GROUP_M=8,
)
return c


def gemv_mm(a, b, c, M, K):
    """Launch the GEMV fallback kernel (N == 1) and return the output *c*.

    Computes ``c = a @ b`` where ``a`` is (M, K) and ``b`` has a single
    column; ``c`` is written in place and returned.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # One program per BLOCK_M rows of `a`.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            c.stride(0),
        )
    return c


def mm_out(a, b, *, out):
logger.debug("GEMS_MTHREADS MM_OUT")
# handle non-contiguous inputs if necessary
Expand All @@ -169,6 +267,8 @@ def mm_out(a, b, *, out):
_, N = b.shape
# allocates output
c = out
if N == 1:
return gemv_mm(a, b, c, M, K)
# launch kernel
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
Expand All @@ -187,62 +287,106 @@ def mm_out(a, b, *, out):
b.stride(1),
c.stride(0),
c.stride(1),
dtype=str(a.dtype).split(".")[-1],
GROUP_M=8,
)
return c


def sqmma_descriptor_pre_hook(nargs):
    """Autotuner pre-hook: refresh the TMA descriptors before each launch.

    ``nargs`` is the kernel-argument dict supplied by the tuner; the
    ``*_desc_ptr`` entries are descriptor buffers that this hook fills for the
    current (A, B, C, BLOCK_M/N/K) combination.
    """
    a = nargs["A"]
    b = nargs["B"]
    c = nargs["C"]
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]
    device = c.device

    # Input descriptors come from the cache helper.
    nargs["a_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(a, block_m, block_k, device)
    )
    nargs["b_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(b, block_k, block_n, device)
    )
    # NOTE(review): creating C's descriptor fresh each launch adds a device
    # allocation + copy on the hot path; presumably the cached helper could be
    # used here too, but its cache-key semantics for output tensors are not
    # visible from this file — confirm before switching.
    nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In sqmma_descriptor_pre_hook, create_tma_device_descriptor(c, ...) allocates a new device tensor every launch and then copy_() immediately copies it into c_desc_ptr, adding an extra allocation + device-to-device copy on the critical path. Consider changing the descriptor helper to fill an existing descriptor tensor (or to return a CPU tensor and copy directly into c_desc_ptr) so the pre-hook avoids per-launch device allocations.

Suggested change
nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
nargs["c_desc_ptr"].copy_(
get_cached_tma_device_descriptor(c, block_m, block_n, device)
)

Copilot uses AI. Check for mistakes.


def sqmma_get_configs(pre_hook=sqmma_descriptor_pre_hook):
    """Return the SQMMA kernel config list, wiring in the descriptor pre-hook."""
    config = triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64},
        num_stages=1,
        num_warps=4,
        pre_hook=pre_hook,
    )
    return [config]


@libentry()
@libtuner(
configs=sqmma_get_configs(),
key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
strategy=["align32", "align32", "align32", "align32", "align32", "default"],
)
@triton.jit
def mm_sqmma_kernel(
A,
B,
C,
a_desc_ptr,
b_desc_ptr,
c_desc_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
dtype: tl.constexpr,
GROUP_M: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ab_dtype: tl.constexpr,
c_dtype: tl.constexpr,
is_transpose_a: tl.constexpr = False,
is_transpose_b: tl.constexpr = False,
):
pid = tle.program_id(0)
grid_m = tl.cdiv(M, BLOCK_SIZE_M)
grid_n = tl.cdiv(N, BLOCK_SIZE_N)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
offs_am = pid_m * BLOCK_SIZE_M
offs_bn = pid_n * BLOCK_SIZE_N
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = 0
offs_am = offs_am.to(tl.int32)
offs_bn = offs_bn.to(tl.int32)
offs_k = offs_k.to(tl.int32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
tme_load_ab_dtype = ab_dtype
c_store_dtype = c_dtype
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl._experimental_descriptor_load(
a_desc_ptr,
[offs_am, offs_k],
[BLOCK_SIZE_M, BLOCK_SIZE_K],
[BLOCK_M, BLOCK_K],
tme_load_ab_dtype,
is_transpose_a,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[offs_k, offs_bn],
[BLOCK_SIZE_K, BLOCK_SIZE_N],
[BLOCK_K, BLOCK_N],
tme_load_ab_dtype,
is_transpose_b,
)
accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
offs_k += BLOCK_SIZE_K
offs_k += BLOCK_K
accumulator = accumulator.to(c_store_dtype)
tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

Expand All @@ -256,9 +400,9 @@ def get_triton_type(elem_type):
return type_map.get(elem_type, None)


def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages):
def mm_sqmma(A, B, M, N, K, GROUP_M):
logger.debug("GEMS_MTHREADS MM(SQMMA)")
device = "musa"
device = A.device
# handle non-contiguous inputs if necessary
is_transpose_a = False
is_transpose_b = False
Expand All @@ -277,24 +421,32 @@ def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_s
assert a_type == b_type, "Mat A and Mat B should have the same dtype"
c_dtype = get_higher_dtype(a_type, b_type)
C = torch.empty((M, N), dtype=c_dtype, device=device)
desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device)
desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device)
desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device)
mm_sqmma_kernel[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)](
desc_a = torch.empty((64,), dtype=torch.int8, device=device)
desc_b = torch.empty((64,), dtype=torch.int8, device=device)
desc_c = torch.empty((64,), dtype=torch.int8, device=device)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1,
1,
)
mm_sqmma_kernel[grid](
A,
B,
C,
desc_a,
Comment on lines +432 to 436
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mm_sqmma() launches mm_sqmma_kernel without with torch_device_fn.device(A.device): (unlike mm_fma(), mm_out(), etc.). This can run the kernel on the wrong current device in multi-device contexts and can also break descriptor creation/usage that depends on the active device. Wrap the kernel launch in the same torch_device_fn.device(A.device) context manager used elsewhere in this backend.

Copilot uses AI. Check for mistakes.
desc_b,
desc_c,
M,
N,
K,
GROUP_M,
BLOCK_M,
BLOCK_N,
BLOCK_K,
get_triton_type(a_type),
get_triton_type(c_dtype),
num_warps=num_warps,
num_stages=num_stages,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(1),
str(a_type).split(".")[-1],
GROUP_M=GROUP_M,
ab_dtype=get_triton_type(a_type),
c_dtype=get_triton_type(c_dtype),
is_transpose_a=is_transpose_a,
is_transpose_b=is_transpose_b,
)
Expand All @@ -306,30 +458,19 @@ def mm(a, b):
b_dtype = b.dtype
M, K = a.shape
_, N = b.shape
use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K)
if use_sqmma:
if N == 1:
c_dtype = get_higher_dtype(a_dtype, b_dtype)
c = torch.empty((M, N), device=a.device, dtype=c_dtype)
return gemv_mm(a, b, c, M, K)
if is_sqmma_compatible(a, b, N, K):
GROUP_M = 8
BLOCK_M = 128
BLOCK_N = BLOCK_M
BLOCK_K = 64
num_warps = 16 if BLOCK_M == 256 else 4
num_stages = 1
return mm_sqmma(
a,
b,
M,
N,
K,
GROUP_M,
BLOCK_M,
BLOCK_N,
BLOCK_K,
num_warps,
num_stages,
)
else:
enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None)
result = mm_fma(a, b)
if enable_sqmma:
os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma
return result
return mm_fma(a, b)
Loading
Loading