Summary
T.ws(1, 2) (merged multi-consumer warp group) produces incorrect GEMM results, while T.ws(1) + T.ws(2) (separate consumers) works correctly. This is a TileLang codegen/lowering bug.
Minimal Reproduction
Two 3-WG (threads=384) GEMM variants, same A×B computation:
Variant A: Merged consumer T.ws(1, 2) — FAIL
with T.ws(1, 2):  # WG1 + WG2 execute the same code block
    T.gemm(A_shared, B_shared, C_local)
Result: max_diff=207.0859 ❌
Variant B: Separate consumers T.ws(1) + T.ws(2) — PASS
with T.ws(1):  # WG1 independently
    T.gemm(A_shared, B_shared, C1_local)
with T.ws(2):  # WG2 independently
    T.gemm(A_shared, B_shared, C2_local)
Result: max_diff=0.0626 ✅
2-WG baseline T.ws(1) — PASS
with T.ws(1):  # single consumer WG
    T.gemm(A_shared, B_shared, C_local)
Result: max_diff=0.0610 ✅
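The report does not define max_diff. A plausible reading is the largest element-wise absolute difference between the kernel output and a reference A×B product. The pure-Python sketch below shows that metric under this assumption. The function name is hypothetical, and the actual test scripts presumably compute it on tensors instead.

```python
def max_abs_diff(actual, expected):
    """Largest element-wise |actual - expected| over two equally shaped 2-D lists.

    Assumption: this mirrors what the report calls max_diff; the real
    scripts may compute it differently (e.g. on GPU tensors).
    """
    if len(actual) != len(expected):
        raise ValueError("row count mismatch")
    worst = 0.0
    for row_a, row_e in zip(actual, expected):
        if len(row_a) != len(row_e):
            raise ValueError("column count mismatch")
        for a, e in zip(row_a, row_e):
            worst = max(worst, abs(a - e))
    return worst


# Tiny illustration with made-up values (not data from the report).
reference = [[1.0, 2.0], [3.0, 4.0]]
output = [[1.0, 2.5], [3.0, 3.75]]
print(max_abs_diff(output, reference))  # 0.5
```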
Root Cause Analysis
From TileLang source (src/ir.cc), T.ws(1, 2) generates:
// Merge consecutive groups [1,2] -> range [1, 3)
PrimExpr range_cond = (thread_idx >= 128) && (thread_idx < 384);
IfFrame if_frame = If(range_cond);
This is purely a thread-range guard (if-then block + warp_specialize attribute). It does NOT handle:
- WGMMA work distribution across 2 warp groups
- Fragment partitioning between WG1 and WG2
- Correct accumulator merging
When T.gemm is lowered to WGMMA inside T.ws(1, 2), it likely generates WGMMA instructions assuming a single warp group context, but 256 threads (2 WGs) are active. The fragment layout and MMA tile mapping break.
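To make the guard's effect concrete, the sketch below models the range condition in plain Python and counts which warp groups pass it. It is not TileLang code: the helper names are hypothetical, and the only hardware fact assumed is that a Hopper warp group is 128 threads (4 warps of 32).

```python
WARP_GROUP_SIZE = 128  # 4 warps x 32 threads per warp group


def merged_guard(thread_idx, first_wg, last_wg):
    """Model of the range condition: merged groups [first_wg, last_wg] -> [lo, hi)."""
    lo = first_wg * WARP_GROUP_SIZE
    hi = (last_wg + 1) * WARP_GROUP_SIZE
    return lo <= thread_idx < hi


def active_warp_groups(num_threads, first_wg, last_wg):
    """Count, per warp group, how many threads pass the merged guard."""
    counts = {}
    for t in range(num_threads):
        if merged_guard(t, first_wg, last_wg):
            wg = t // WARP_GROUP_SIZE
            counts[wg] = counts.get(wg, 0) + 1
    return counts


# threads=384 with merged consumers 1 and 2, as in Variant A.
print(active_warp_groups(384, 1, 2))  # {1: 128, 2: 128}
```

Both consumer warp groups enter the same block with all 256 threads active, and nothing in the guard itself assigns them different tiles or accumulators.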
Implications
This blocks the FA3-aligned GQA forward kernel from achieving parity with the pipelined variant:
- 2-WG WS (1 producer + 1 consumer): correct but ~2-3x slower (consumer has only 128 threads vs pipelined's 256)
- 3-WG WS with T.ws(1, 2): incorrect results
- 3-WG WS with separate consumers: correct, but requires splitting the attention computation between WG1 and WG2 manually (each handles different Q rows)
Workaround
Use separate T.ws(1) and T.ws(2) blocks with manually partitioned work instead of merged T.ws(1, 2).
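The manual partitioning amounts to giving each consumer warp group its own contiguous slice of rows. The sketch below shows one way to compute those slices in plain Python. The function name, the even-split policy, and the block size of 128 are illustrative assumptions, not taken from the reproduction scripts.

```python
def consumer_row_range(block_m, consumer_idx, num_consumers=2):
    """Return the half-open row range [start, stop) owned by one consumer.

    consumer_idx 0 would correspond to the T.ws(1) block and 1 to the
    T.ws(2) block (hypothetical mapping). Assumes an even split.
    """
    if not 0 <= consumer_idx < num_consumers:
        raise ValueError("consumer_idx out of range")
    if block_m % num_consumers != 0:
        raise ValueError("block_m must divide evenly among consumers")
    rows = block_m // num_consumers
    start = consumer_idx * rows
    return start, start + rows


# Example: a 128-row tile split between two consumer warp groups.
print(consumer_row_range(128, 0))  # (0, 64)
print(consumer_row_range(128, 1))  # (64, 128)
```

Each consumer block would then run its T.gemm on its own row slice and write to its own accumulator (C1_local, C2_local), as in Variant B.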
Full Reproduction Script
File: _test_ws_3wg_gemm.py in this session's TileOPs working tree. Tests both matmul_ws_3wg_merged (FAIL) and matmul_ws_2wg (PASS).
File: _test_ws_3wg_gemm_v2.py tests matmul_ws_3wg_separate (PASS) with independent T.ws(1) + T.ws(2) consumers.
Environment
- TileLang: 0.1.8+cuda.git5f70374c
- GPU: NVIDIA H100
- PyTorch: 2.9.1+cu128
Related