
TileLang: T.ws blocks incorrectly alias shared memory across parallel warp groups #8

@superAngGao


Summary

When two T.ws consumer blocks each allocate a shared buffer of the same shape, TileLang's shared-memory allocator treats the two buffers as having non-overlapping lifetimes and assigns them the same physical address. However, T.ws blocks execute in parallel on different warp groups (WGs), so both WGs read and write the same memory concurrently, which silently corrupts data.

Root Cause

The compiler treats T.ws(1) and T.ws(2) as mutually exclusive code paths (like the branches of an if-else), which allows their T.alloc_shared buffers to alias. In reality, warp-specialized blocks run concurrently on different hardware warp groups, so their shared buffers are live at the same time.
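To make the difference concrete, here is a minimal Python model of a shared-memory planner. It is not TileLang's actual allocator: the functions assign_offsets and overlaps, the ws_concurrent flag, the 128-byte alignment and the base offset of 24576 (picked only to match the offset seen in the compiled CUDA in this issue) are all hypothetical. Each P buffer is 64 × 128 × 2 bytes = 16384 bytes (16 KB).

```python
from collections import defaultdict

# Hypothetical model, NOT TileLang's allocator: each buffer is
# (name, size in bytes, id of the T.ws scope that allocates it).
P_BYTES = 64 * 128 * 2  # [64, 128] float16 = 16384 bytes
BUFS = [("p_shared_1", P_BYTES, 1), ("p_shared_2", P_BYTES, 2)]
SIZES = {name: nbytes for name, nbytes, _ in BUFS}


def align_up(x, a):
    return (x + a - 1) // a * a


def assign_offsets(bufs, ws_concurrent, base=0, align=128):
    """Assign shared-memory offsets to buffers grouped by T.ws scope.

    ws_concurrent=False: scopes are treated like if/else branches, so every
    scope restarts at `base` and buffers in different scopes may alias.
    ws_concurrent=True: scopes are live at the same time, so each scope is
    placed after the previous one.
    """
    by_scope = defaultdict(list)
    for name, nbytes, scope in bufs:
        by_scope[scope].append((name, nbytes))
    offsets, cursor = {}, base
    for scope in sorted(by_scope):
        off = cursor if ws_concurrent else base
        for name, nbytes in by_scope[scope]:
            off = align_up(off, align)
            offsets[name] = off
            off += nbytes
        if ws_concurrent:
            cursor = off
    return offsets


def overlaps(offsets, sizes):
    """Return every pair of buffers whose byte ranges intersect."""
    names = sorted(offsets)
    return [
        (a, b)
        for i, a in enumerate(names)
        for b in names[i + 1:]
        if offsets[a] < offsets[b] + sizes[b] and offsets[b] < offsets[a] + sizes[a]
    ]


exclusive = assign_offsets(BUFS, ws_concurrent=False, base=24576)
concurrent = assign_offsets(BUFS, ws_concurrent=True, base=24576)
print(exclusive, overlaps(exclusive, SIZES))
print(concurrent, overlaps(concurrent, SIZES))
```

With ws_concurrent=False both buffers land at offset 24576 and overlaps() reports the pair; with ws_concurrent=True p_shared_2 moves to 40960 and the ranges no longer intersect. A compiler fix might instead mark buffers from different T.ws scopes as simultaneously live in its liveness analysis, but the effect on these two buffers would be the same.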

Reproduction

# Consumer section of the kernel; acc_s_*, v_shared and acc_o_* are defined elsewhere.
with T.ws(1):
    p_shared_1 = T.alloc_shared([64, 128], "float16")  # 16KB
    T.copy(acc_s_1, p_shared_1)
    T.gemm(p_shared_1, v_shared, acc_o_1)

with T.ws(2):
    p_shared_2 = T.alloc_shared([64, 128], "float16")  # 16KB
    T.copy(acc_s_2, p_shared_2)
    T.gemm(p_shared_2, v_shared, acc_o_2)

Compiled CUDA shows both mapped to the same address:

// WG1 PV gemm
initialize_wgmma_descriptor(desc_a_1, &smem[24576]);  // p_shared_1
// WG2 PV gemm
initialize_wgmma_descriptor(desc_a_3, &smem[24576]);  // p_shared_2 ← SAME ADDRESS

Observed Behavior

  • D=64 (small buffers): Both P buffers alias → max_diff > 1.0 (incorrect results)
  • D=128 (larger buffers): Buffers happen to NOT alias due to size constraints → max_diff = 0.0002 (correct)
  • Setting the pass config tir.disable_storage_rewrite: True does NOT fix the issue; the aliasing is introduced during T.ws lowering, not by the storage-rewrite pass
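The following Python sketch is a toy model of why aliasing produces a large max_diff, not a reproduction of the kernel. It uses made-up 2 × 2 tiles, an identity V so each WG's output equals the P tile it reads, and one assumed interleaving in which both WGs finish their T.copy into P before either runs its PV gemm. The names simulate and matmul and all matrix values are hypothetical.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def simulate(aliased):
    """Return max |out - ref| over both warp groups for one fixed interleaving."""
    v = [[1.0, 0.0], [0.0, 1.0]]  # identity V: each WG's output equals the P it reads
    acc_s = {1: [[1.0, 2.0], [3.0, 4.0]], 2: [[5.0, 6.0], [7.0, 8.0]]}
    # Aliased: both P buffers resolve to one shared-memory slot (same offset).
    slot = {1: "p", 2: "p"} if aliased else {1: "p_shared_1", 2: "p_shared_2"}
    smem = {}
    # Assumed schedule: WG1 and WG2 both execute T.copy(acc_s_i, p_shared_i) ...
    for wg in (1, 2):
        smem[slot[wg]] = [row[:] for row in acc_s[wg]]
    # ... and only then run T.gemm(p_shared_i, v_shared, acc_o_i).
    out = {wg: matmul(smem[slot[wg]], v) for wg in (1, 2)}
    ref = {wg: matmul(acc_s[wg], v) for wg in (1, 2)}
    return max(
        abs(o - r)
        for wg in (1, 2)
        for o_row, r_row in zip(out[wg], ref[wg])
        for o, r in zip(o_row, r_row)
    )


print("aliased max_diff:", simulate(aliased=True))
print("separate max_diff:", simulate(aliased=False))
```

Under this schedule WG1 consumes WG2's P tile, so its output is wrong while WG2's is correct. On real hardware the relative timing of the two WGs can vary, so the amount of corruption may differ from launch to launch.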

Additional Issue: T.wgmma_gemm fence/wait ordering

While investigating, we also found that T.wgmma_gemm() emits warpgroup_fence_operand before wait_wgmma, whereas the correct WGMMA sequence (per the NVIDIA PTX specification) waits for the committed batch before fencing the accumulator operands:

// Generated (incorrect):
warpgroup_commit_batch();
warpgroup_fence_operand(acc, N);  // fence BEFORE wait
wait_wgmma<0>();                  // wait AFTER fence

// Correct (as in T.gemm and FA3):
warpgroup_commit_batch();
warpgroup_wait<0>();              // wait FIRST
warpgroup_fence_operand(acc, N);  // fence AFTER

The synchronous T.gemm() emits the correct order; only T.wgmma_gemm() is affected. This ordering bug blocks implementing FA3-style IntraWGOverlap (overlapping QK[n] with PV[n-1] within a consumer warp group).
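Until the emitted order is fixed, one way to catch it is to scan the generated CUDA source. The sketch below is a hypothetical lint, not part of TileLang: it flags any warpgroup_fence_operand that follows a warpgroup_commit_batch with no intervening wait_wgmma<...> or warpgroup_wait<...>. It matches only the call names shown in this issue, assumes one statement per line, and treats any wait as draining all committed batches, so it is a heuristic rather than a full check of pipelined wait<N> patterns.

```python
import re

COMMIT = re.compile(r"\bwarpgroup_commit_batch\s*\(")
WAIT = re.compile(r"\b(?:wait_wgmma|warpgroup_wait)\s*<")
FENCE = re.compile(r"\bwarpgroup_fence_operand\s*\(")


def fences_before_wait(cuda_src):
    """Return 1-based line numbers of fences issued while a committed
    WGMMA batch has not yet been waited on."""
    bad = []
    pending = False  # True between a commit and the next wait
    for lineno, line in enumerate(cuda_src.splitlines(), start=1):
        code = line.split("//", 1)[0]  # drop trailing // comments
        if COMMIT.search(code):
            pending = True
        elif WAIT.search(code):
            pending = False
        elif FENCE.search(code) and pending:
            bad.append(lineno)
    return bad


GENERATED = """\
warpgroup_commit_batch();
warpgroup_fence_operand(acc, N);  // fence BEFORE wait
wait_wgmma<0>();                  // wait AFTER fence
"""

CORRECT = """\
warpgroup_commit_batch();
warpgroup_wait<0>();              // wait FIRST
warpgroup_fence_operand(acc, N);  // fence AFTER
"""

print(fences_before_wait(GENERATED))
print(fences_before_wait(CORRECT))
```

Fences emitted before the wgmma itself (ahead of any commit) are not flagged, since no batch is pending at that point.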

Context

These issues were found while implementing FA3-aligned warp-specialized GQA attention kernels. Related: #7 (3-WG WS GEMM 2-Tile), #4 (GQA WS parent issue).

Environment

  • TileLang 0.1.8+cuda.git5f70374c
  • CUDA SM90 (Hopper)
  • Conda env: env_ws_test
