Summary
When two T.ws consumer blocks each allocate a shared buffer of the same shape, TileLang's shared memory allocator treats the two buffers as having non-overlapping lifetimes and assigns them the same physical address. But T.ws blocks execute in parallel on different warp groups (WGs), so both WGs read and write the same memory concurrently, causing silent data corruption.
Root Cause
The compiler treats T.ws(1) and T.ws(2) as mutually exclusive code paths (like if-else branches), allowing their T.alloc_shared buffers to alias. In reality, warp-specialized blocks run concurrently on different hardware warp groups.
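The root cause can be illustrated with a toy offset planner. This is a hypothetical model, not TileLang's allocator: buffers are tagged with the T.ws region that allocates them. With concurrent=False, every region restarts at offset 0 as if the regions were if/else branches, which reproduces the aliasing. With concurrent=True, regions are laid out back to back, which is what concurrently running warp groups need. Offsets are relative to the P buffers only; in the real kernel both land at 24576 because other shared buffers come first.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Buf:
    name: str
    nbytes: int
    region: str  # hypothetical tag: the T.ws block that allocates this buffer


def plan_offsets(bufs, concurrent):
    """Toy shared-memory planner (not TileLang's real allocator).

    Buffers inside one region are packed back to back. If `concurrent` is
    False, every region restarts at offset 0, as if the regions were mutually
    exclusive if/else branches. If True, each region starts where the
    previous one ended, so regions never overlap.
    """
    by_region = defaultdict(list)
    for b in bufs:
        by_region[b.region].append(b)
    offsets, region_base, arena = {}, 0, 0
    for members in by_region.values():
        cursor = region_base
        for b in members:
            offsets[b.name] = cursor
            cursor += b.nbytes
        arena = max(arena, cursor)
        if concurrent:
            region_base = cursor  # next region starts after this one
    return offsets, arena


def overlaps(bufs, offsets):
    """Return pairs of buffers whose byte ranges intersect."""
    hits = []
    for i, a in enumerate(bufs):
        for b in bufs[i + 1:]:
            a0, b0 = offsets[a.name], offsets[b.name]
            if a0 < b0 + b.nbytes and b0 < a0 + a.nbytes:
                hits.append((a.name, b.name))
    return hits


P_BYTES = 64 * 128 * 2  # [64, 128] float16 = 16384 bytes
bufs = [Buf("p_shared_1", P_BYTES, "ws1"), Buf("p_shared_2", P_BYTES, "ws2")]

off, size = plan_offsets(bufs, concurrent=False)
print(off, size, overlaps(bufs, off))
# {'p_shared_1': 0, 'p_shared_2': 0} 16384 [('p_shared_1', 'p_shared_2')]

off, size = plan_offsets(bufs, concurrent=True)
print(off, size, overlaps(bufs, off))
# {'p_shared_1': 0, 'p_shared_2': 16384} 32768 []
```

The model suggests the fix direction: T.ws lowering should treat allocations in different T.ws blocks as live at the same time. The price is a larger arena (32 KB instead of 16 KB for the two P buffers).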
Reproduction
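# Excerpt from the kernel body; acc_s_*, acc_o_* and v_shared are allocated earlier.
with T.ws(1):
    # Consumer WG1: stage P in shared memory for the PV gemm
    p_shared_1 = T.alloc_shared([64, 128], "float16")  # 16KB
    T.copy(acc_s_1, p_shared_1)
    T.gemm(p_shared_1, v_shared, acc_o_1)
with T.ws(2):
    # Consumer WG2: same shape, but must not share storage with p_shared_1
    p_shared_2 = T.alloc_shared([64, 128], "float16")  # 16KB
    T.copy(acc_s_2, p_shared_2)
    T.gemm(p_shared_2, v_shared, acc_o_2)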
The compiled CUDA shows both buffers mapped to the same shared memory offset:
// WG1 PV gemm
initialize_wgmma_descriptor(desc_a_1, &smem[24576]); // p_shared_1
// WG2 PV gemm
initialize_wgmma_descriptor(desc_a_3, &smem[24576]); // p_shared_2 ← SAME ADDRESS
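A quick way to spot this kind of aliasing in generated code is to group WGMMA descriptors by the shared memory offset they are built from. The sketch below assumes the call form shown above, initialize_wgmma_descriptor(desc, &smem[OFFSET]); real TileLang output may spell the buffer name or template arguments differently, so the regex is an assumption to adjust. A shared offset is only a candidate: read-only operands such as v_shared are legitimately used by both WGs, so each hit still needs to be traced back to its buffer by hand.

```python
import re
from collections import defaultdict

# Assumed call shape: initialize_wgmma_descriptor[<...>](desc_var, &smem[OFFSET])
DESC_RE = re.compile(
    r"initialize_wgmma_descriptor(?:<[^>]*>)?\(\s*(\w+)\s*,\s*&smem\[(\d+)\]\s*\)"
)


def descriptor_offsets(cuda_src):
    """Map each smem byte offset to the descriptor variables initialized from it."""
    by_offset = defaultdict(list)
    for desc, off in DESC_RE.findall(cuda_src):
        by_offset[int(off)].append(desc)
    return dict(by_offset)


def shared_offsets(cuda_src):
    """Offsets used by more than one descriptor: candidates for aliasing."""
    return {off: d for off, d in descriptor_offsets(cuda_src).items() if len(d) > 1}


src = """// WG1 PV gemm
initialize_wgmma_descriptor(desc_a_1, &smem[24576]); // p_shared_1
// WG2 PV gemm
initialize_wgmma_descriptor(desc_a_3, &smem[24576]); // p_shared_2
"""
print(shared_offsets(src))  # {24576: ['desc_a_1', 'desc_a_3']}
```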
Observed Behavior
- D=64 (small buffers): both P buffers alias → max_diff > 1.0 (incorrect results)
- D=128 (larger buffers): the buffers happen NOT to alias due to size constraints → max_diff = 0.0002 (correct)
Setting tir.disable_storage_rewrite: True does NOT fix the issue; the aliasing occurs during T.ws lowering, not in the storage rewrite pass.
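For reference, a max_diff figure like the ones above can be computed against a float32 NumPy implementation of GQA attention. This is a sketch under assumptions the issue does not state: non-causal attention, a 1/sqrt(D) softmax scale, a [batch, heads, seq, dim] layout, query head h reading KV head h // (Hq // Hkv), and max_diff meaning the largest absolute elementwise difference.

```python
import numpy as np


def gqa_attention_ref(q, k, v):
    """Float32 reference for grouped-query attention (non-causal, assumed layout).

    q: [B, Hq, Sq, D]; k, v: [B, Hkv, Skv, D] with Hq % Hkv == 0.
    """
    q, k, v = (np.asarray(x, dtype=np.float32) for x in (q, k, v))
    group = q.shape[1] // k.shape[1]
    k = np.repeat(k, group, axis=1)  # query head h uses KV head h // group
    v = np.repeat(v, group, axis=1)
    s = q @ np.swapaxes(k, -1, -2) / q.shape[-1] ** 0.5
    s -= s.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v


def max_diff(out, ref):
    """Largest absolute elementwise difference between kernel output and reference."""
    return float(np.max(np.abs(np.asarray(out, np.float32) - ref)))
```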
Additional Issue: T.wgmma_gemm fence/wait ordering
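While investigating, we also found that T.wgmma_gemm() emits warpgroup_fence_operand before wait_wgmma. The correct WGMMA sequence waits first and fences afterwards: per the NVIDIA PTX spec, the accumulator registers of a wgmma.mma_async must not be accessed until a wgmma.wait_group covering that operation has completed.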
// Generated (incorrect):
warpgroup_commit_batch();
warpgroup_fence_operand(acc, N); // fence BEFORE wait
wait_wgmma<0>(); // wait AFTER fence
// Correct (as in T.gemm and FA3):
warpgroup_commit_batch();
warpgroup_wait<0>(); // wait FIRST
warpgroup_fence_operand(acc, N); // fence AFTER
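The ordering in the two snippets above can be checked mechanically on generated source. The sketch below is a heuristic and only knows the call names shown in this issue: it flags a warpgroup_fence_operand that follows warpgroup_commit_batch() with no wait in between. It ignores wait<N> counts and does not check which accumulator a fence targets, so in an overlapped schedule it can report false positives or miss cases.

```python
import re

# Calls relevant to the WGMMA completion sequence, as spelled in the snippets above.
CALL_RE = re.compile(
    r"\b(warpgroup_commit_batch|warpgroup_fence_operand|wait_wgmma|warpgroup_wait)\b"
)
WAITS = {"wait_wgmma", "warpgroup_wait"}


def fences_before_wait(cuda_src):
    """Return 1-based line numbers of warpgroup_fence_operand calls that follow a
    warpgroup_commit_batch() with no wait in between (heuristic check)."""
    bad, committed = [], False
    for m in CALL_RE.finditer(cuda_src):
        name = m.group(1)
        if name == "warpgroup_commit_batch":
            committed = True
        elif name in WAITS:
            committed = False
        elif committed:  # a fence while the committed batch may still be in flight
            bad.append(cuda_src.count("\n", 0, m.start()) + 1)
    return bad


generated = """warpgroup_commit_batch();
warpgroup_fence_operand(acc, N);
wait_wgmma<0>();
"""
fixed = """warpgroup_commit_batch();
warpgroup_wait<0>();
warpgroup_fence_operand(acc, N);
"""
print(fences_before_wait(generated))  # [2]
print(fences_before_wait(fixed))      # []
```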
The synchronous T.gemm() generates the correct order; only T.wgmma_gemm() is affected. This ordering bug blocks implementing FA3-style IntraWGOverlap (overlapping QK[n] and PV[n-1] within a consumer warp group).
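Why the ordering matters for IntraWGOverlap can be sketched with a small model. This is a hypothetical issue order for one consumer WG, not FA3's or TileLang's code. It models the PTX wgmma.wait_group N rule: the wait returns once at most N of the most recently committed groups are pending and all older groups are complete. "use" stands for the fence plus first read of an accumulator (softmax on S[n], or rescaling O), and check() rejects any use before the producing group is known to be complete.

```python
def overlap_schedule(num_tiles):
    """Hypothetical issue order for one consumer WG, overlapping QK[n] with PV[n-1].

    ("mma", g):  issue one async WGMMA and commit it as its own group g
    ("wait", n): wgmma wait_group n
    ("use", g):  fence + first read of the accumulator written by group g
    """
    ops = []
    for n in range(num_tiles):
        ops.append(("mma", f"QK[{n}]"))
        if n > 0:
            ops.append(("mma", f"PV[{n - 1}]"))  # runs while softmax(n) executes
        ops.append(("wait", 1 if n > 0 else 0))  # QK[n] complete; PV[n-1] may still run
        ops.append(("use", f"QK[{n}]"))          # softmax on S[n]
        if n > 0:
            ops.append(("wait", 0))              # PV[n-1] complete
            ops.append(("use", f"PV[{n - 1}]"))  # rescale O with the new row max
    last = f"PV[{num_tiles - 1}]"
    ops += [("mma", last), ("wait", 0), ("use", last)]
    return ops


def check(ops):
    """Simulate wait_group semantics; False if an accumulator is used too early."""
    pending, done = [], set()  # pending: committed groups, oldest first
    for kind, arg in ops:
        if kind == "mma":
            pending.append(arg)
        elif kind == "wait":
            while len(pending) > arg:  # groups retire oldest first
                done.add(pending.pop(0))
        elif kind == "use" and arg not in done:
            return False
    return True


sched = overlap_schedule(3)
print(check(sched))  # True
# Same schedule with the final fence/use moved ahead of its wait,
# mirroring the fence-before-wait order emitted by T.wgmma_gemm():
print(check(sched[:-2] + [sched[-1], sched[-2]]))  # False
```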
Context
These issues were found while implementing FA3-aligned warp-specialized GQA attention kernels. Related: #7 (3-WG WS GEMM 2-Tile), #4 (GQA WS parent issue).
Environment
- TileLang from tilelang 0.1.8+cuda.git5f70374c
- CUDA SM90 (Hopper)
- Conda env: env_ws_test