
TileLang Bug: T.ws(1,2) merged multi-WG consumer produces incorrect GEMM results #5

@superAngGao


Summary

T.ws(1, 2) (merged multi-consumer warp group) produces incorrect GEMM results, while T.ws(1) + T.ws(2) (separate consumers) works correctly. This is a TileLang codegen/lowering bug.

Minimal Reproduction

Two 3-WG GEMM variants (threads=384, i.e. 3 warp groups of 128 threads) computing the same A×B, plus a 2-WG baseline:

Variant A: Merged consumer T.ws(1, 2): FAIL

with T.ws(1, 2):   # WG1+WG2 execute same code block
    T.gemm(A_shared, B_shared, C_local)

Result: max_diff=207.0859

Variant B: Separate consumers T.ws(1) + T.ws(2): PASS

with T.ws(1):       # WG1 independently
    T.gemm(A_shared, B_shared, C1_local)
with T.ws(2):       # WG2 independently
    T.gemm(A_shared, B_shared, C2_local)

Result: max_diff=0.0626

2-WG baseline T.ws(1): PASS

with T.ws(1):       # single consumer WG
    T.gemm(A_shared, B_shared, C_local)

Result: max_diff=0.0610
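The max_diff values above are the largest absolute elementwise difference between the kernel output and a reference product. A minimal stdlib-only sketch of that check follows; `matmul_ref`, `max_abs_diff`, `check`, and the tolerance `TOL` are hypothetical names and values chosen for illustration (the actual reproduction script presumably compares against a PyTorch reference with its own threshold).

```python
# Hypothetical model of the "max_diff" correctness check (not the real script).

def matmul_ref(a, b):
    """Reference C = A @ B on nested lists (rows x cols)."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def max_abs_diff(c, ref):
    """Largest |c[i][j] - ref[i][j]| over all elements."""
    return max(abs(x - y) for crow, rrow in zip(c, ref) for x, y in zip(crow, rrow))


TOL = 1.0  # assumed threshold: ~0.06 (fp16 rounding noise) passes, ~207 fails


def check(c, a, b, tol=TOL):
    """Return (max_diff, passed) for a kernel output c against A @ B."""
    diff = max_abs_diff(c, matmul_ref(a, b))
    return diff, diff <= tol


if __name__ == "__main__":
    a = [[1.0, 2.0], [3.0, 4.0]]
    b = [[5.0, 6.0], [7.0, 8.0]]
    good = [[19.0, 22.0], [43.0, 50.0]]
    bad = [[19.0, 22.0], [0.0, 0.0]]  # e.g. half of the tile never written
    print(check(good, a, b))  # (0.0, True)
    print(check(bad, a, b))   # (50.0, False)
```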

Root Cause Analysis

In the TileLang source (src/ir.cc), T.ws(1, 2) generates:

// Merge consecutive groups [1,2] -> WG range [1, 3) = thread range [128, 384)
PrimExpr range_cond = (thread_idx >= 128) && (thread_idx < 384);
IfFrame if_frame = If(range_cond);

This is purely a thread-range guard (if-then block + warp_specialize attribute). It does NOT handle:

  • WGMMA work distribution across 2 warp groups
  • Fragment partitioning between WG1 and WG2
  • Correct accumulator merging
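To make the guard concrete: with 128 threads per warp group, merging consecutive groups [1, 2] gives the half-open thread range [128, 384). The plain-Python model below reproduces only that mapping; `WG_SIZE`, `ws_thread_range`, and `in_ws_region` are hypothetical names, not TileLang APIs. The predicate decides which threads enter the block and nothing more, so the enclosed T.gemm gets no information about how many warp groups are inside.

```python
WG_SIZE = 128  # threads per warp group on Hopper (4 warps x 32 lanes)


def ws_thread_range(groups, wg_size=WG_SIZE):
    """Merge consecutive warp-group ids into a half-open thread range [lo, hi)."""
    groups = sorted(groups)
    if groups != list(range(groups[0], groups[-1] + 1)):
        raise ValueError("warp groups must be consecutive")
    return groups[0] * wg_size, (groups[-1] + 1) * wg_size


def in_ws_region(thread_idx, groups, wg_size=WG_SIZE):
    """Model of range_cond: (thread_idx >= lo) && (thread_idx < hi)."""
    lo, hi = ws_thread_range(groups, wg_size)
    return lo <= thread_idx < hi


if __name__ == "__main__":
    print(ws_thread_range([1, 2]))  # (128, 384)
    active = [t for t in range(384) if in_ws_region(t, [1, 2])]
    print(len(active))  # 256 threads enter the merged consumer block
    print({t // WG_SIZE for t in active})  # {1, 2}: both WGs run the same block
```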

When T.gemm is lowered to WGMMA inside T.ws(1, 2), it likely emits WGMMA instructions that assume a single-warp-group context, even though 256 threads (2 WGs) are active. As a result, the fragment layout and the MMA tile mapping no longer match the threads that execute them.
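For comparison, here is one mapping a merged consumer would need, sketched under the assumption that the accumulator is split by rows (each warp-level WGMMA tile covers 64 rows, and FA3-style kernels give each consumer warp group its own row slice). `consumer_row_slices` and `local_wg` are hypothetical helpers, not TileLang internals: they show that each warp group needs an index *relative to the merged region* and a disjoint slice of the tile, which the thread-range guard alone does not provide.

```python
WG_SIZE = 128  # threads per warp group
WGMMA_M = 64   # rows covered by one warpgroup-level wgmma tile


def consumer_row_slices(block_m, groups):
    """Hypothetical row partition of a block_m-row accumulator across merged WGs.

    Each warp group in `groups` gets a contiguous, disjoint slice of rows,
    keyed by its global WG id but ordered by its position in the region.
    """
    n = len(groups)
    if block_m % (n * WGMMA_M) != 0:
        raise ValueError("block_m must be a multiple of len(groups) * 64")
    rows_per_wg = block_m // n
    return {wg: range(i * rows_per_wg, (i + 1) * rows_per_wg)
            for i, wg in enumerate(groups)}


def local_wg(thread_idx, groups, wg_size=WG_SIZE):
    """Warp-group index relative to the start of the merged region (0, 1, ...)."""
    return thread_idx // wg_size - groups[0]


if __name__ == "__main__":
    print(consumer_row_slices(128, [1, 2]))  # {1: range(0, 64), 2: range(64, 128)}
    print(local_wg(130, [1, 2]), local_wg(300, [1, 2]))  # 0 1
```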

Implications

This blocks the FA3-aligned GQA forward kernel from reaching performance parity with the pipelined variant:

  • 2-WG WS (1 producer + 1 consumer): correct but ~2-3x slower (the consumer has only 128 threads vs. 256 in the pipelined variant)
  • 3-WG WS with T.ws(1, 2): incorrect results
  • 3-WG WS with separate consumers: correct, but the attention computation must be split between WG1 and WG2 by hand (each handles different Q rows)

Workaround

Use separate T.ws(1) and T.ws(2) blocks with manually partitioned work instead of merged T.ws(1, 2).
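A plain-Python model of the workaround, assuming the output tile is split by rows as in the attention case (WG1 and WG2 each handle different Q rows). `wg_gemm_rows` is a hypothetical stand-in for the T.gemm that each separate T.ws block would issue on its own slice; because the slices are disjoint, the partial results stitch back together with no accumulator merging.

```python
def wg_gemm_rows(a, b, row_start, row_end):
    """Stand-in for one consumer WG: compute rows [row_start, row_end) of A @ B."""
    k, n = len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(row_start, row_end)]


def split_consumers(a, b):
    """Model of T.ws(1) + T.ws(2) with manual partitioning by rows."""
    m = len(a)
    half = m // 2
    c_wg1 = wg_gemm_rows(a, b, 0, half)  # work done inside the T.ws(1) block
    c_wg2 = wg_gemm_rows(a, b, half, m)  # work done inside the T.ws(2) block
    return c_wg1 + c_wg2                 # disjoint rows: concatenation suffices


if __name__ == "__main__":
    a = [[1, 2], [3, 4], [5, 6], [7, 8]]
    identity = [[1, 0], [0, 1]]
    print(split_consumers(a, identity) == a)  # True
```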

Full Reproduction Script

File: _test_ws_3wg_gemm.py, in the TileOPs working tree used for this report, tests both matmul_ws_3wg_merged (FAIL) and matmul_ws_2wg (PASS).

File: _test_ws_3wg_gemm_v2.py tests matmul_ws_3wg_separate (PASS) with independent T.ws(1) + T.ws(2) consumers.

Environment

  • TileLang: 0.1.8+cuda.git5f70374c
  • GPU: NVIDIA H100
  • PyTorch: 2.9.1+cu128
