
Fix bug and optimize performance of mm operator for mthreads backend #2219

Open

Vincent-Xiao wants to merge 11 commits into flagos-ai:master from Vincent-Xiao:mm_mthreads

Conversation

@Vincent-Xiao
Contributor

PR Category

Operator

Type of Change

Bug Fix | Performance Optimization

Description

  1. Added a dedicated GEMV kernel path to handle the N=1 case.
  2. Improved SQMMA kernel performance by caching the device descriptor.
  3. Reduced launch overhead by caching the pre_hook when using @libentry().
  4. Achieved up to a 124% performance improvement.
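The GEMV fast path in the first item boils down to a shape-based dispatch: when the right-hand operand has a single column, the problem is a matrix-vector product and a dedicated kernel avoids wasting a whole tile dimension on one output column. The selector and kernel names below are illustrative, not the PR's actual symbols:

```python
# Sketch of the N == 1 dispatch described above. "gemv" and "fma" stand in
# for the backend's kernel variants; only the shape logic is the point here.

def select_mm_kernel(M: int, N: int, K: int) -> str:
    """Pick a kernel variant for C[M, N] = A[M, K] @ B[K, N]."""
    if N == 1:
        # Matrix-vector product: a dedicated GEMV kernel keeps the tile
        # shape sensible instead of launching a full GEMM for one column.
        return "gemv"
    return "fma"

assert select_mm_kernel(1024, 1, 1024) == "gemv"
assert select_mm_kernel(1024, 512, 1024) == "fma"
```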

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Contributor

Copilot AI left a comment


Pull request overview

This PR targets correctness and performance improvements for the mthreads backend matrix multiply path, including better handling of the N=1 case and reducing overhead in SQMMA launches.

Changes:

  • Added a dedicated GEMV (N=1) kernel path for mm/mm_out.
  • Added caching for TMA device descriptors to reduce repeated descriptor construction overhead.
  • Updated @libentry() execution to cache and replay Triton autotuner pre_hook calls on cached kernel launches.
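The pre_hook caching in the last bullet can be modeled roughly as follows. This is a simplified sketch of the control flow only, not libentry's actual implementation: the point is that once a compiled kernel is cached and the autotuner is bypassed, the winning config's pre_hook must still be replayed on every launch.

```python
# Toy model: cache the autotuner's pre_hook at first launch and replay it
# for subsequent cached launches. CachedLaunch is illustrative, not the
# real @libentry() wrapper.

class CachedLaunch:
    def __init__(self, kernel, pre_hook=None):
        self.kernel = kernel
        self.pre_hook = pre_hook   # recorded with the winning autotune config
        self.compiled = None

    def __call__(self, *args):
        if self.compiled is None:
            # First launch: "autotune" and cache the compiled kernel.
            self.compiled = self.kernel
        if self.pre_hook is not None:
            self.pre_hook(args)    # replayed on every launch, cached or not
        return self.compiled(*args)

calls = []
launch = CachedLaunch(lambda x: x * 2, pre_hook=lambda a: calls.append(a))
assert launch(3) == 6 and launch(4) == 8
assert len(calls) == 2   # pre_hook also ran on the cached second launch
```

Without the replay, the second launch would skip the hook that fills in launch-time state such as the TMA descriptors, which is exactly the bug class this change targets.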

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Reviewed files:

  • src/flag_gems/utils/libentry.py: Cache autotuner pre-hooks and replay them for cached compiled-kernel launches.
  • src/flag_gems/runtime/backend/_mthreads/ops/utils.py: Introduce an LRU-style cache for TMA device descriptors.
  • src/flag_gems/runtime/backend/_mthreads/ops/mm.py: Add GEMV fast path for N=1; refactor SQMMA descriptor setup to use cached descriptors and the pre-hook.
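The LRU-style descriptor cache in utils.py could plausibly be built on functools.lru_cache, keyed by the fields that determine the descriptor's contents. The function name, key fields, and placeholder return value below are assumptions; in the backend the body would build a real TMA device descriptor once and reuse it:

```python
from functools import lru_cache

# Hypothetical sketch of caching TMA device descriptors. A tuple stands in
# for the descriptor tensor. Note the caveat: keying on data_ptr assumes the
# underlying allocation is not reused for a different tensor.

@lru_cache(maxsize=128)
def get_cached_descriptor(data_ptr, rows, cols, block_m, block_n, device):
    # First call with a given key builds the descriptor; later launches
    # with the same key get the cached object back with no rebuild.
    return ("desc", data_ptr, rows, cols, block_m, block_n, device)

d1 = get_cached_descriptor(0x1000, 512, 512, 128, 128, "musa:0")
d2 = get_cached_descriptor(0x1000, 512, 512, 128, 128, "musa:0")
assert d1 is d2                                   # cache hit: same object
assert get_cached_descriptor.cache_info().hits == 1
```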


Comment on lines +432 to 436

```python
mm_sqmma_kernel[grid](
    A,
    B,
    C,
    desc_a,
```

Copilot AI Apr 2, 2026


mm_sqmma() launches mm_sqmma_kernel without with torch_device_fn.device(A.device): (unlike mm_fma(), mm_out(), etc.). This can run the kernel on the wrong current device in multi-device contexts and can also break descriptor creation/usage that depends on the active device. Wrap the kernel launch in the same torch_device_fn.device(A.device) context manager used elsewhere in this backend.

Copilot uses AI. Check for mistakes.
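The fix this comment asks for can be illustrated with a toy device-context model. Here device_ctx and the device-name strings are stand-ins for the backend's torch_device_fn.device context manager, not the real API; the sketch only shows why pinning the launch to A.device matters:

```python
from contextlib import contextmanager

# Toy model: a kernel launch uses whatever device is "current", so the
# launch must be wrapped in a context that sets (and restores) it.

_current_device = ["musa:0"]

@contextmanager
def device_ctx(dev):
    prev, _current_device[0] = _current_device[0], dev
    try:
        yield
    finally:
        _current_device[0] = prev

def launch_kernel():
    return _current_device[0]   # runs on the current device, right or wrong

with device_ctx("musa:1"):                # analogous to torch_device_fn.device(A.device)
    assert launch_kernel() == "musa:1"    # launch pinned to the input's device
assert launch_kernel() == "musa:0"        # previous device restored on exit
```

Without the wrapper, a launch issued while another device is current would run there, which is the multi-device hazard the reviewer describes.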
```python
nargs["b_desc_ptr"].copy_(
    get_cached_tma_device_descriptor(b, block_k, block_n, device)
)
nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
```

Copilot AI Apr 2, 2026


In sqmma_descriptor_pre_hook, create_tma_device_descriptor(c, ...) allocates a new device tensor every launch and then copy_() immediately copies it into c_desc_ptr, adding an extra allocation + device-to-device copy on the critical path. Consider changing the descriptor helper to fill an existing descriptor tensor (or to return a CPU tensor and copy directly into c_desc_ptr) so the pre-hook avoids per-launch device allocations.

Suggested change

```diff
-nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
+nargs["c_desc_ptr"].copy_(
+    get_cached_tma_device_descriptor(c, block_m, block_n, device)
+)
```

Comment on lines +24 to +35

```python
def is_sqmma_compatible(a, b, N, K):
    return (
        os.getenv("MUSA_ENABLE_SQMMA", "0") == "1"
        and a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
    )
```

Copilot AI Apr 2, 2026


mm() now gates SQMMA via is_sqmma_compatible() instead of the shared should_enable_sqmma() helper in ops/utils.py, which means the explicit shape exclusions in should_enable_sqmma (e.g. (15, 160, 1024)) are no longer applied for mm. If those exclusions are still required to avoid known SQMMA failures, this is a behavioral regression; consider reusing should_enable_sqmma here or moving any required exclusions/alignment checks into a single shared predicate used by mm/addmm/bmm.
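One way to follow this suggestion is a single shared predicate that keeps the exclusion list in one place so mm, addmm, and bmm cannot drift apart. Apart from the (15, 160, 1024) shape and the MUSA_ENABLE_SQMMA gate, which come from this page, the code below is a sketch:

```python
import os

# Sketch of a shared SQMMA gate. The tensor-level checks (dtype, layout)
# from is_sqmma_compatible would live here too; this sketch keeps only the
# env gate, shape exclusions, and alignment checks for brevity.

_SQMMA_EXCLUDED_SHAPES = {(15, 160, 1024)}   # known-bad shape from the review

def should_enable_sqmma(M, N, K):
    return (
        os.getenv("MUSA_ENABLE_SQMMA", "0") == "1"
        and (M, N, K) not in _SQMMA_EXCLUDED_SHAPES
        and N % 8 == 0
        and K % 8 == 0
    )

os.environ["MUSA_ENABLE_SQMMA"] = "1"
assert should_enable_sqmma(128, 160, 1024)
assert not should_enable_sqmma(15, 160, 1024)   # excluded shape stays excluded
assert not should_enable_sqmma(128, 7, 1024)    # N not a multiple of 8
```

Having every matmul entry point call the same predicate means a newly discovered bad shape needs exactly one edit.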

