Summary
Successfully implemented a 3-WG warp-specialized GEMM in which each consumer WG computes a full output tile (not half the rows of a shared tile). Each CTA covers 2 adjacent M-tiles, so the grid is halved along M.
Design
Grid: ceildiv(N, block_N) × ceildiv(M, 2*block_M)   (each CTA covers 2 M-tiles)
CTA (threads=384):
WG0 (producer): TMA load A1, A2, B (shared)
WG1 (consumer 1): C[row_base_1] += A_shared_1 @ B_shared (full block_M × block_N)
WG2 (consumer 2): C[row_base_2] += A_shared_2 @ B_shared (full block_M × block_N)
Key insight: don't split a tile's rows between WGs — give each WG its own complete tile. This avoids the half_m fragment bug and preserves full WGMMA efficiency.
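The schedule above can be checked on the host with a small pure-Python emulation. This is only a sketch of the tiling logic, not the TileLang kernel: the function names, the `block_K` loop, and the handling of a missing second tile (when M is not a multiple of 2*block_M) are assumptions made for illustration.

```python
def ceildiv(a, b):
    return -(-a // b)


def matmul_ref(A, B):
    """Plain triple-loop reference GEMM on lists of lists."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]


def matmul_two_tiles_per_cta(A, B, block_M, block_N, block_K):
    """Emulate the grid: one 'CTA' per (bx, by), two full M-tiles per CTA."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for bx in range(ceildiv(N, block_N)):
        for by in range(ceildiv(M, 2 * block_M)):
            col0 = bx * block_N
            # row_base_1 and row_base_2: two adjacent M-tiles
            row_bases = (2 * by * block_M, (2 * by + 1) * block_M)
            for k0 in range(0, K, block_K):
                # "Producer": one B tile shared by both consumers, two A tiles
                B_shared = [row[col0:col0 + block_N] for row in B[k0:k0 + block_K]]
                A_shared = [[row[k0:k0 + block_K] for row in A[rb:rb + block_M]]
                            for rb in row_bases]
                # "Consumers": each accumulates its own full tile independently
                for rb, A_tile in zip(row_bases, A_shared):
                    for i, a_row in enumerate(A_tile):
                        for j in range(len(B_shared[0])):
                            C[rb + i][col0 + j] += sum(
                                a_row[k] * B_shared[k][j] for k in range(len(a_row)))
    return C
```

Because each consumer owns a disjoint set of output rows, no reduction between consumers is needed; the emulation writes each element of C from exactly one consumer.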
Results
| Config | max_diff | Status |
|---|---|---|
| M=1024 N=1024 K=1024 | 0.0614 | ✅ |
| M=2048 N=2048 K=1024 | 0.0618 | ✅ |
| M=512 N=512 K=512 | 0.0313 | ✅ |
Why This Works (vs previous failures)
| Approach | Fragment size | Status | Issue |
|---|---|---|---|
| T.ws(1, 2) merged | block_M | ❌ | TileLang WGMMA lowering bug (#5) |
| Separate WGs, half_m fragments | half_M | ❌ | Fragment/WGMMA conflict at half size |
| Separate WGs, full block_M, same tile | block_M | ✅ | Works but redundant compute |
| Separate WGs, full block_M, split-KV | block_M | ✅ | Works but slow (4x shared mem) |
| Separate WGs, full block_M, 2 tiles per CTA | block_M | ✅ | Correct & efficient |
Each consumer uses the same full block_M × block_N fragment size that already works with a single consumer WG. The only differences from the 2-WG baseline are that the grid is halved and the producer loads 2 A-tiles plus 1 shared B-tile.
Implications for GQA Attention
Same pattern applies:
- Grid: ceildiv(seq_len, 2*block_m) instead of ceildiv(seq_len, block_m)
- WG1: full attention for Q[2*bx*block_m : (2*bx+1)*block_m]
- WG2: full attention for Q[(2*bx+1)*block_m : (2*bx+2)*block_m]
- Both share K/V tiles from producer
- No combine needed — each consumer writes its own output independently
- Full block_m fragments everywhere
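As a rough sketch of the index arithmetic in the list above, the helper below maps a block index `bx` to the Q row ranges owned by WG1 and WG2. The function name is hypothetical, and clamping to `seq_len` is an assumption about how a tail block would be handled; the notes above do not cover that case.

```python
def ceildiv(a, b):
    return -(-a // b)


def consumer_q_ranges(bx, block_m, seq_len):
    """Return ((start, end), (start, end)) Q row ranges for WG1 and WG2.

    Ranges are half-open and clamped to seq_len (assumption), so the last
    block may give WG2 a short or empty range.
    """
    base = 2 * bx * block_m
    wg1 = (min(base, seq_len), min(base + block_m, seq_len))
    wg2 = (min(base + block_m, seq_len), min(base + 2 * block_m, seq_len))
    return wg1, wg2


def all_q_rows(block_m, seq_len):
    """Collect every Q row visited across the halved grid, in order."""
    rows = []
    for bx in range(ceildiv(seq_len, 2 * block_m)):
        for start, end in consumer_q_ranges(bx, block_m, seq_len):
            rows.extend(range(start, end))
    return rows
```

With these ranges every query row belongs to exactly one consumer, which is why no combine step is required.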
File
_test_ws_3wg_gemm_2tile.py
Related
T.ws(1,2)merged consumer bug