-
Notifications
You must be signed in to change notification settings - Fork 308
Fix BUG and Optimize performance of mm operator for mthreads backend #2219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
e6e15b2
a69c8a6
4f6b211
d2c0340
1e3994b
b151753
4164af3
0956652
9d87b5d
57adbdf
2f014e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,11 +10,33 @@ | |||||||||
| from flag_gems.utils import libentry, libtuner | ||||||||||
| from flag_gems.utils import triton_lang_extension as tle | ||||||||||
|
|
||||||||||
| from .utils import create_tma_device_descriptor, should_enable_sqmma | ||||||||||
| from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor | ||||||||||
|
|
||||||||||
| logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm") | ||||||||||
|
|
||||||||||
|
|
||||||||||
def is_supported_sqmma_layout(tensor):
    """Return True when a 2-D tensor's memory layout is usable by SQMMA.

    Accepted layouts are row-major contiguous, or column-major (i.e. a
    transposed contiguous matrix: innermost stride 1 along dim 0 and the
    outer stride equal to the number of rows).
    """
    if tensor.is_contiguous():
        return True
    # Column-major: stride pattern of a contiguous matrix seen through .t().
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
|
|
||||||||||
| logger = logging.getLogger( | ||||||||||
| f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}' | ||||||||||
| ) | ||||||||||
|
|
||||||||||
def is_sqmma_compatible(a, b, N, K):
    """Decide whether the (a, b) matmul may take the SQMMA fast path.

    Requires the MUSA_ENABLE_SQMMA=1 opt-in, two 2-D operands of the same
    half-precision dtype (fp16/bf16), SQMMA-supported layouts on both
    operands, and N, K both divisible by 8.
    """
    # Check order mirrors the original conjunction, cheapest test first.
    if os.getenv("MUSA_ENABLE_SQMMA", "0") != "1":
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype:
        return False
    if a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0
|
|
||||||||||
|
|
||||||||||
def matmul_get_configs():
    """Fetch the autotuning configurations registered for the ``mm`` op."""
    configs = runtime.get_tuned_config("mm")
    return configs
|
|
||||||||||
|
|
||||||||||
| @triton.jit | ||||||||||
|
|
@@ -25,9 +47,9 @@ def prev_multiple_of(a, b): | |||||||||
|
|
||||||||||
| @libentry() | ||||||||||
| @libtuner( | ||||||||||
| configs=runtime.get_tuned_config("mm"), | ||||||||||
| key=["M", "N", "K"], | ||||||||||
| strategy=["align32", "align32", "align32"], | ||||||||||
| configs=matmul_get_configs(), | ||||||||||
| key=["M", "N", "K", "stride_am", "stride_bk"], | ||||||||||
| strategy=["align32", "align32", "align32", "align32", "align32"], | ||||||||||
| ) | ||||||||||
| @triton.jit | ||||||||||
| def mm_kernel( | ||||||||||
|
|
@@ -43,6 +65,7 @@ def mm_kernel( | |||||||||
| stride_bn, | ||||||||||
| stride_cm, | ||||||||||
| stride_cn, | ||||||||||
| dtype: tl.constexpr, | ||||||||||
| BLOCK_M: tl.constexpr, | ||||||||||
| BLOCK_N: tl.constexpr, | ||||||||||
| BLOCK_K: tl.constexpr, | ||||||||||
|
|
@@ -101,6 +124,58 @@ def mm_kernel( | |||||||||
| tl.store(C, acc, mask=mask) | ||||||||||
|
|
||||||||||
|
|
||||||||||
def gemv_get_configs():
    """Single fixed tuning config for the N == 1 GEMV kernel."""
    config = triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})
    return [config]
|
|
||||||||||
|
|
||||||||||
@libentry()
@libtuner(
    configs=gemv_get_configs(),
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Matrix-vector product: C[m] = sum_k A[m, k] * B[k].
    # One program instance handles a contiguous band of BLOCK_M rows.
    block_id = tle.program_id(0)

    rows = block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_valid = rows < M

    # Per-row fp32 accumulator.
    partial = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k0 in range(0, K, BLOCK_K):
        cols = k0 + tl.arange(0, BLOCK_K)
        cols_valid = cols < K

        a_tile = tl.load(
            A + rows[:, None] * stride_am + cols[None, :] * stride_ak,
            mask=rows_valid[:, None] & cols_valid[None, :],
            other=0.0,
        )
        b_vec = tl.load(B + cols * stride_bk, mask=cols_valid, other=0.0)

        # Keep the reduction in fp32 so the N=1 GEMV matches the mm path
        # more closely.
        a_f32 = a_tile.to(tl.float32)
        b_f32 = b_vec.to(tl.float32)
        partial += tl.sum(a_f32 * b_f32[None, :], axis=1)

    out = partial.to(C.dtype.element_ty)
    tl.store(C + rows * stride_cm, out, mask=rows_valid)
|
|
||||||||||
|
|
||||||||||
# Dtypes listed lowest-precision first; presumably used to pick the result
# dtype when operand dtypes differ — TODO confirm against the consumer.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]
|
|
||||||||||
|
|
||||||||||
|
|
@@ -151,11 +226,34 @@ def mm_fma(a, b): | |||||||||
| b.stride(1), | ||||||||||
| c.stride(0), | ||||||||||
| c.stride(1), | ||||||||||
| dtype=str(a.dtype).split(".")[-1], | ||||||||||
| GROUP_M=8, | ||||||||||
| ) | ||||||||||
| return c | ||||||||||
|
|
||||||||||
|
|
||||||||||
def gemv_mm(a, b, c, M, K):
    """Run the degenerate N == 1 matmul (a @ b) through the GEMV kernel.

    Writes the (M, 1) result into the preallocated output ``c`` and
    returns it.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # One program per BLOCK_M band of rows.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    # Launch on the input's device so multi-device setups pick the right one.
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            c.stride(0),
        )
    return c
|
|
||||||||||
|
|
||||||||||
| def mm_out(a, b, *, out): | ||||||||||
| logger.debug("GEMS_MTHREADS MM_OUT") | ||||||||||
| # handle non-contiguous inputs if necessary | ||||||||||
|
|
@@ -169,6 +267,8 @@ def mm_out(a, b, *, out): | |||||||||
| _, N = b.shape | ||||||||||
| # allocates output | ||||||||||
| c = out | ||||||||||
| if N == 1: | ||||||||||
| return gemv_mm(a, b, c, M, K) | ||||||||||
| # launch kernel | ||||||||||
| grid = lambda META: ( | ||||||||||
| triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), | ||||||||||
|
|
@@ -187,62 +287,106 @@ def mm_out(a, b, *, out): | |||||||||
| b.stride(1), | ||||||||||
| c.stride(0), | ||||||||||
| c.stride(1), | ||||||||||
| dtype=str(a.dtype).split(".")[-1], | ||||||||||
| GROUP_M=8, | ||||||||||
| ) | ||||||||||
| return c | ||||||||||
|
|
||||||||||
|
|
||||||||||
def sqmma_descriptor_pre_hook(nargs):
    # Autotuner pre-hook: rebuild the TMA device descriptors for A, B and C
    # so they match the block shape of the config about to be benchmarked.
    a = nargs["A"]
    b = nargs["B"]
    c = nargs["C"]
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]
    device = c.device

    # Input descriptors come from the cache helper, keyed by tensor/block shape.
    nargs["a_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(a, block_m, block_k, device)
    )
    nargs["b_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(b, block_k, block_n, device)
    )
    # NOTE(review): C builds a fresh descriptor instead of the cached variant
    # used for A/B — confirm whether this asymmetry is intentional (e.g.
    # because the output buffer differs between autotuning runs).
    nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
|
||||||||||
| nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device)) | |
| nargs["c_desc_ptr"].copy_( | |
| get_cached_tma_device_descriptor(c, block_m, block_n, device) | |
| ) |
Copilot
AI
Apr 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mm_sqmma() launches mm_sqmma_kernel without with torch_device_fn.device(A.device): (unlike mm_fma(), mm_out(), etc.). This can run the kernel on the wrong current device in multi-device contexts and can also break descriptor creation/usage that depends on the active device. Wrap the kernel launch in the same torch_device_fn.device(A.device) context manager used elsewhere in this backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`mm()` now gates SQMMA via `is_sqmma_compatible()` instead of the shared `should_enable_sqmma()` helper in `ops/utils.py`, which means the explicit shape exclusions in `should_enable_sqmma` (e.g. `(15, 160, 1024)`) are no longer applied for `mm`. If those exclusions are still required to avoid known SQMMA failures, this is a behavioral regression; consider reusing `should_enable_sqmma` here, or moving any required exclusions/alignment checks into a single shared predicate used by `mm`/`addmm`/`bmm`.