Fix BUG and Optimize performance of mm operator for mthreads backend #2219
Vincent-Xiao wants to merge 11 commits into flagos-ai:master
Conversation
… LibEntry caches the autotuner-selected config.pre_hook together with config.all_kwargs() on a cache miss, and replays these pre_hooks on the direct-launch path on a cache hit, before launching the kernel.
Pull request overview
This PR targets correctness and performance improvements for the mthreads backend matrix multiply path, including better handling of the N=1 case and reducing overhead in SQMMA launches.
Changes:
- Added a dedicated GEMV (N=1) kernel path for `mm`/`mm_out`.
- Added caching for TMA device descriptors to reduce repeated descriptor construction overhead.
- Updated `@libentry()` execution to cache and replay Triton autotuner `pre_hook` calls on cached kernel launches.
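The pre-hook caching described above can be sketched as follows. This is a minimal illustration of the idea, not the actual `LibEntry` code: the class name `LibEntrySketch`, the cache key, and the `launch` callable are all hypothetical stand-ins for the real autotuner/JIT wrapping.

```python
class LibEntrySketch:
    """Minimal sketch: cache the selected config's pre_hook and replay it on hits."""

    def __init__(self, kernel):
        self.kernel = kernel  # underlying autotuned kernel (unused in this sketch)
        self.cache = {}       # key -> (compiled_kernel, pre_hook, all_kwargs)

    def __call__(self, key, launch):
        entry = self.cache.get(key)
        if entry is None:
            # Cache miss: run the full autotuner path once, then remember the
            # selected config's pre_hook together with its kwargs.
            compiled, pre_hook, all_kwargs = launch()
            self.cache[key] = (compiled, pre_hook, all_kwargs)
            return compiled
        compiled, pre_hook, all_kwargs = entry
        # Cache hit: replay the pre_hook before the direct launch so per-launch
        # side effects (e.g. refreshing TMA descriptors) still happen.
        if pre_hook is not None:
            pre_hook(all_kwargs)
        return compiled
```

Without the replay step, any side effect the autotuner's `pre_hook` performs (such as updating device descriptors) would silently be skipped on every cached launch.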
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/flag_gems/utils/libentry.py` | Cache autotuner pre-hooks and replay them for cached compiled-kernel launches. |
| `src/flag_gems/runtime/backend/_mthreads/ops/utils.py` | Introduce an LRU-style cache for TMA device descriptors. |
| `src/flag_gems/runtime/backend/_mthreads/ops/mm.py` | Add GEMV fast path for N=1; refactor SQMMA descriptor setup to use cached descriptors + pre-hook. |
```python
mm_sqmma_kernel[grid](
    A,
    B,
    C,
    desc_a,
```
mm_sqmma() launches mm_sqmma_kernel without with torch_device_fn.device(A.device): (unlike mm_fma(), mm_out(), etc.). This can run the kernel on the wrong current device in multi-device contexts and can also break descriptor creation/usage that depends on the active device. Wrap the kernel launch in the same torch_device_fn.device(A.device) context manager used elsewhere in this backend.
```python
nargs["b_desc_ptr"].copy_(
    get_cached_tma_device_descriptor(b, block_k, block_n, device)
)
nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device))
```
In sqmma_descriptor_pre_hook, create_tma_device_descriptor(c, ...) allocates a new device tensor every launch and then copy_() immediately copies it into c_desc_ptr, adding an extra allocation + device-to-device copy on the critical path. Consider changing the descriptor helper to fill an existing descriptor tensor (or to return a CPU tensor and copy directly into c_desc_ptr) so the pre-hook avoids per-launch device allocations.
Suggested change (use the cached helper instead of creating a fresh descriptor):
```python
nargs["c_desc_ptr"].copy_(
    get_cached_tma_device_descriptor(c, block_m, block_n, device)
)
```
```python
def is_sqmma_compatible(a, b, N, K):
    return (
        os.getenv("MUSA_ENABLE_SQMMA", "0") == "1"
        and a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
    )
```
mm() now gates SQMMA via is_sqmma_compatible() instead of the shared should_enable_sqmma() helper in ops/utils.py, which means the explicit shape exclusions in should_enable_sqmma (e.g. (15, 160, 1024)) are no longer applied for mm. If those exclusions are still required to avoid known SQMMA failures, this is a behavioral regression; consider reusing should_enable_sqmma here or moving any required exclusions/alignment checks into a single shared predicate used by mm/addmm/bmm.
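The single shared predicate the review suggests could look roughly like this. This is a sketch only: `should_use_sqmma`, its parameters, and the exclusion set are hypothetical, with the one shape `(15, 160, 1024)` taken from the review comment above; the real `should_enable_sqmma` checks tensor dtype and layout directly rather than taking booleans.

```python
import os

# Known-bad shapes excluded from SQMMA; (15, 160, 1024) is the example cited
# in the review, the full list lives in should_enable_sqmma.
_SQMMA_EXCLUDED_SHAPES = {(15, 160, 1024)}

def should_use_sqmma(M, N, K, dtype_ok, layout_ok):
    """One shared gate for mm/addmm/bmm: env flag, shape exclusions, alignment."""
    if os.getenv("MUSA_ENABLE_SQMMA", "0") != "1":
        return False
    if (M, N, K) in _SQMMA_EXCLUDED_SHAPES:
        return False  # skip shapes with known SQMMA failures
    return dtype_ok and layout_ok and N % 8 == 0 and K % 8 == 0
```

Routing all three operators through one predicate like this would keep the exclusion list from silently diverging between call sites, which is the regression risk the comment describes.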
PR Category
Operator
Type of Change
Bug Fix | Performance Optimization
Description
Issue
Progress
Performance