
[BugFix] Consider non-local store in external call and SIMT producer for warp specialize #2166

Open
Rachmanino wants to merge 3 commits into tile-ai:main from Rachmanino:fix/call-extern-parallel-and-simt-barrier

Conversation

@Rachmanino
Collaborator

@Rachmanino Rachmanino commented May 7, 2026

Two bugs caused incorrect results when T.Parallel loops use T.call_extern with T.address_of(shared_buffer) inside pipelined, warp-specialized kernels:

Bug 1 (lower_tile_op.cc): has_non_local_store did not recognize address_of(shared_buffer) inside call_extern as a non-local access, so parallel_loop was set to false and thread partitioning was skipped entirely — every thread ran the full serial loop.
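The fix itself is in C++ (lower_tile_op.cc); below is a small, hedged sketch of the detection idea expressed with TVM's Python API, not the actual patch. The helper name body_has_non_local_store and the is_non_local predicate are illustrative assumptions rather than names from the codebase.

import tvm
from tvm import tir


def body_has_non_local_store(stmt: tir.Stmt, is_non_local) -> bool:
    """Return True if the loop body contains a write-like non-local access.

    Besides plain BufferStores, an address_of(BufferLoad(buf)) argument passed
    to an extern call is treated as a potential non-local write, which is the
    pattern the original visitor missed.
    """
    found = False

    def visit(node):
        nonlocal found
        if isinstance(node, tir.Call) and node.op.name == "tir.address_of":
            arg = node.args[0]
            if isinstance(arg, tir.BufferLoad) and is_non_local(arg.buffer):
                found = True

    tir.stmt_functor.post_order_visit(stmt, visit)
    return found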

Reproduction:

"""
Minimal reproducer for Bug 1 (fixed in 0dba22aa):

  lower_tile_op.cc: has_non_local_store did not recognize
  address_of(shared_buffer) inside call_extern as a non-local access,
  so parallel_loop was set to false and thread partitioning was skipped
  entirely - every thread ran the full serial loop.

Trigger conditions:
  1. T.Parallel loop inside a warp-specialized pipeline (T.Pipelined
     with num_stages > 0, auto-WS triggered by T.copy/TMA).
  2. The parallel loop body uses T.call_extern with
     T.address_of(shared_buffer_element).
  3. No other non-local BufferStore inside the parallel loop body.

Without fix: all consumer threads execute every iteration.
With fix:    iterations are partitioned across consumer threads.

Observable effect:
  The external function `atomic_inc` does atomicAdd(ptr, 1).  With
  correct partitioning each shared element is bumped once per stage.
  With the bug, each element is bumped N times (N = consumer threads).
"""

import tilelang
from tilelang import language as T
import torch


@tilelang.jit
def bug1_kernel(
    M: int = 16,
    K: int = 32,
    block_M: int = 16,
    block_K: int = 16,
    dtype: str = "float16",
):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        Counts: T.Tensor((M, block_K), "int32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as by:
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            counters = T.alloc_shared((block_K,), "int32")

            T.import_source(
                'extern "C" __device__ void atomic_inc(int* ptr) { atomicAdd(ptr, 1); }\n'
            )

            T.clear(counters)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                # TMA copy triggers auto warp-specialization (2 warp groups)
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Bug 1 trigger: T.Parallel + call_extern + address_of(shared).
                # Without the fix, has_non_local_store stays false because
                # address_of(BufferLoad(shared)) inside call_extern is not
                # recognized as a non-local access.  Thread partitioning
                # is skipped - all consumer threads run the full loop.
                for i in T.Parallel(block_K):
                    T.call_extern("void", "atomic_inc", T.address_of(counters[i]))

            # Write counters to output for inspection
            for i in T.Parallel(block_K):
                Counts[by * block_M, i] = counters[i]

    return main


if __name__ == "__main__":
    M, K = 16, 32
    block_M, block_K = 16, 16

    kernel = bug1_kernel(M, K, block_M, block_K)

    A = torch.randn((M, K), device="cuda", dtype=torch.float16)
    Counts = torch.zeros((M, block_K), device="cuda", dtype=torch.int32)

    kernel(A, Counts)
    torch.cuda.synchronize()

    num_stages = T.ceildiv(K, block_K).value
    actual = Counts.cpu()

    print(f"Pipeline stages: {num_stages}")
    print(f"Counters (should all == {num_stages}):")
    print(actual.int())
    # Only first block row (by=0) was written; check that row
    written = actual[0]
    print(f"  first row: {written.int().tolist()}")

    if written.min().item() == num_stages and written.max().item() == num_stages:
        print("PASS: Partitioning correct - each element bumped once per stage.")
    elif written.min().item() > num_stages:
        print("FAIL (BUG PRESENT): counters > expected!")
        print("  Each element was bumped by EVERY consumer thread,")
        print("  proving thread partitioning was skipped.")
    else:
        print(f"UNEXPECTED: range [{written.min().item()}, {written.max().item()}]")

Bug 2 (producer_consumer_ws.cc): wait_insert_pos was only computed for kTmaProducer statements. SIMT/cp.async producers (e.g. T.Parallel global→shared copies) were skipped, so the consumer read from SIMT-produced shared buffers before the barrier wait — a data race.
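Before the reproducer, here is a hedged Python sketch of the clamping idea (not the actual pass code in producer_consumer_ws.cc): the wait position derived from the TMA producer is pulled back to the earliest consumer statement that reads any SIMT/cp.async-produced shared buffer. The names clamp_wait_insert_pos, consumer_reads, and simt_written are illustrative.

def clamp_wait_insert_pos(wait_insert_pos, consumer_reads, simt_written):
    """Clamp the forward-barrier wait position for one producer group.

    consumer_reads: list of sets; consumer_reads[i] holds the shared buffers
    read by consumer statement i.  simt_written: shared buffers written by
    SIMT / cp.async producers in that group.
    """
    for idx, reads in enumerate(consumer_reads):
        if reads & simt_written:
            # The wait must happen no later than the first such read.
            return min(wait_insert_pos, idx)
    return wait_insert_pos


# In the reproducer below, stmt 0 reads B_shared (SIMT-produced) and stmt 1
# reads A_shared (TMA-produced).  The TMA-derived position is 1, but the
# clamp pulls it back to 0.
assert clamp_wait_insert_pos(1, [{"B_shared"}, {"A_shared"}], {"B_shared"}) == 0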

Reproduction:

"""
Minimal reproducer for Bug 2 (fixed in 0dba22aa):

  producer_consumer_ws.cc: wait_insert_pos was only computed for
  kTmaProducer statements.  SIMT/cp.async producers were skipped, so
  the consumer read from SIMT-produced shared buffers before the
  barrier wait - a data race.

Trigger conditions:
  1. Warp-specialized pipeline (num_stages > 0, auto-WS via TMA).
  2. At least one SIMT producer writes to a shared buffer.
  3. Consumer reads the SIMT-produced buffer at an EARLIER statement
     than the first TMA-produced buffer read.  (If both are consumed
     in the same gemm/tile-op, the TMA-derived wait position covers
     both and the bug is masked.)

What the generated CUDA looks like WITHOUT the fix:
  // Consumer warp group (T.ws(0)):
  stmt_0:  read B_shared[SIMT-produced]  // NO barrier wait yet!
  mbarrier_wait(...)                      // wait_insert_pos = 1 (from TMA)
  stmt_1:  read A_shared[TMA-produced]

  -> B_shared is read before the producer signalled "done".
  -> Consumer may see stale/partially-written data (data race).

WITH the fix:
  // Consumer warp group (T.ws(0)):
  mbarrier_wait(...)                      // wait_insert_pos = 0 (pulled back)
  stmt_0:  read B_shared[SIMT-produced]  // safe
  stmt_1:  read A_shared[TMA-produced]

Observable effect:
  Each pipeline stage the SIMT producer copies a distinct row of
  B values to B_shared.  The consumer reads B_shared at stmt 0.
  Without the fix, the consumer may read stale/wrong B_shared data
  because it races ahead of the SIMT producer's forward barrier.
"""

import tilelang
from tilelang import language as T
import torch


@tilelang.jit
def bug2_kernel(
    num_stages: int = 4,
    block_N: int = 64,
):
    @T.prim_func
    def main(
        A_in: T.Tensor((num_stages, block_N), "float16"),
        B_in: T.Tensor((num_stages, block_N), "float16"),
        B_out: T.Tensor((num_stages, block_N), "float16"),
        A_out: T.Tensor((num_stages, block_N), "float16"),
    ):
        with T.Kernel(1, threads=128) as _bx:
            A_shared = T.alloc_shared((block_N,), "float16")
            B_shared = T.alloc_shared((block_N,), "float16")

            for ko in T.Pipelined(num_stages, num_stages=2):
                # ---- Producer side (auto-assigned to T.ws(1)) ----
                # SIMT producer: global->shared copy of B_in.
                # Classified as kSimtProducer (writes_shared+reads_global).
                for i in T.Parallel(block_N):
                    B_shared[i] = B_in[ko, i]

                # TMA producer: global->shared copy of A_in.
                # Classified as kTmaProducer (triggers auto-WS on Hopper).
                T.copy(A_in[ko, :], A_shared)

                # ---- Consumer side (auto-assigned to T.ws(0)) ----
                # CRITICAL ORDERING: stmt 0 reads SIMT buffer BEFORE
                # stmt 1 reads TMA buffer.  Without the fix, the
                # compiler places the forward-barrier wait at stmt 1
                # (derived from TMA buffer's first read), so stmt 0
                # executes without waiting — a data race on B_shared.
                #
                # Use a single-element read (not T.Parallel) so the
                # compiler doesn't merge stmt 0 and stmt 1 into one
                # consumer statement.
                #
                # Stmt 0: read SIMT-produced B_shared (barrier-skipped w/o fix).
                B_out[ko, 0] = B_shared[0]

                # Stmt 1: read TMA-produced A_shared (barrier placed here w/o fix).
                for i in T.Parallel(block_N):
                    A_out[ko, i] = A_shared[i]

    return main


if __name__ == '__main__':
    num_stages = 4
    block_N = 64

    kernel = bug2_kernel(num_stages, block_N)

    A_in = torch.arange(num_stages * block_N, device="cuda", dtype=torch.float16).reshape(
        num_stages, block_N
    )
    # B_in: each stage has a distinct first-element value
    B_in = torch.zeros((num_stages, block_N), device="cuda", dtype=torch.float16)
    for k in range(num_stages):
        B_in[k, 0] = float(k * 100)  # stage 0=0, stage 1=100, stage 2=200, ...

    B_out = torch.zeros((num_stages, block_N), device="cuda", dtype=torch.float16)
    A_out = torch.zeros((num_stages, block_N), device="cuda", dtype=torch.float16)

    kernel(A_in, B_in, B_out, A_out)
    torch.cuda.synchronize()

    b_in_cpu = B_in.cpu()
    b_out_cpu = B_out.cpu()
    a_out_cpu = A_out.cpu()
    a_in_cpu = A_in.cpu()

    print("=== Bug 2 reproducer: SIMT producer barrier wait position ===")
    print()
    print("Each stage: SIMT producer copies B to B_shared [global->shared],")
    print("TMA copies A to A_shared.")
    print("Consumer reads B_shared[0] at stmt 0, A_shared at stmt 1.")
    print()

    b_ok = torch.allclose(b_out_cpu[:, 0], b_in_cpu[:, 0])
    a_ok = torch.allclose(a_out_cpu, a_in_cpu)

    print(f"B_out first-elem matches B_in: {b_ok}")
    print(f"A_out matches A_in: {a_ok}")

    if b_ok and a_ok:
        print()
        print("PASS: Both outputs match. Bug is fixed.")
    elif not b_ok:
        print()
        print("FAIL (BUG PRESENT): B_out first element does not match B_in!")
        print("  The consumer read B_shared[0] before the SIMT producer's")
        print("  forward-barrier arrive was waited on, getting a stale value.")
        for k in range(num_stages):
            expected = b_in_cpu[k, 0].item()
            actual = b_out_cpu[k, 0].item()
            print(f"  Stage {k}: got {actual:.0f} expected {expected:.0f}")
    else:
        print()
        print(f"UNEXPECTED: A mismatch (A_out)")

Summary by CodeRabbit

  • Performance Improvements

    • Better detection of non-local memory accesses inside parallel loops to avoid missed non-local writes.
    • Refined producer/consumer synchronization for SIMT and cp.async producers so producer waits are advanced to match consumer reads, reducing stalls.
  • Tests

    • Added a test validating that producer-side waits occur before the first consumer read in pipelined cp.async scenarios.

Review Change Stack

…lized pipelines

Two bugs caused incorrect results when T.Parallel loops use
T.call_extern with T.address_of(shared_buffer) inside pipelined,
warp-specialized kernels:

Bug 1 (lower_tile_op.cc): has_non_local_store did not recognize
address_of(shared_buffer) inside call_extern as a non-local access,
so parallel_loop was set to false and thread partitioning was skipped
entirely — every thread ran the full serial loop.

Bug 2 (producer_consumer_ws.cc): wait_insert_pos was only computed
for kTmaProducer statements. SIMT/cp.async producers (e.g. T.Parallel
global→shared copies) were skipped, so the consumer read from
SIMT-produced shared buffers before the barrier wait — a data race.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions

github-actions Bot commented May 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented May 7, 2026

📝 Walkthrough

Walkthrough

This PR extends non-local store detection in the parallel-loop visitor to recognize address_of(BufferLoad(non-local)) patterns reachable from call_extern. It also adds a SIMT/cp.async adjustment that clamps producer-group forward-barrier wait_insert_pos values to the earliest consumer read of shared buffers, and it adds tests for cp.async wait ordering.

Changes

Non-Local Store Detection in Parallel Loops

  • Call Extern Argument Scanning (src/transform/lower_tile_op.cc):
    Non-local store detection now treats address_of(BufferLoad(non-local_buffer)) patterns as non-local and sets has_non_local_store = true.

Barrier Wait Positioning for SIMT/cp.async Producers

  • Find First Async Producer-Consumer Read (src/transform/producer_consumer_ws.cc):
    New helper computes the earliest consumer compute-stmt index that reads any shared buffer written by a SIMT/cp.async producer.
  • Clamp Producer-Group Wait Positions (src/transform/producer_consumer_ws.cc):
    When SIMT or cp.async producers exist, compute the minimum earliest consumer read across those producers and clamp each producer group's wait_insert_pos to at most that index.
  • Tests / Runner (testing/python/transform/test_tilelang_transform_producer_consumer_ws.py):
    Add an explicit_cp_async_wait_position kernel generator, a test asserting that the mbarrier wait precedes cp.async consumer reads, which in turn precede TMA reads, and update the __main__ runner.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • tile-ai/tilelang#1835: Both PRs modify non-local buffer detection in LowerTileOpPass within src/transform/lower_tile_op.cc.
  • tile-ai/tilelang#1917: Related changes to non-local-store detection and how that flag drives parallel-loop partitioning decisions.
  • tile-ai/tilelang#1689: Related address_of and access_ptr handling changes in lower_tile_op.cc.

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

🐰 I sniffed the loop where addresses hide,
Found BufferLoads tucked well inside.
I nudged the waits to earlier days,
So cp.async and consumers play,
Happy hops — no races in my stride!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 58.33%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title clearly describes the main bugfix focus (considering non-local stores in external calls and SIMT producers for warp specialization), which aligns with the core changes across all three files.
  • Linked Issues check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): check skipped because no linked issues were found for this pull request.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/transform/producer_consumer_ws.cc`:
- Around line 1333-1377: The kCpAsyncProducer branch currently only looks for
BufferStoreNode (so written_vars stays empty for cp.async); modify the loop that
builds written_vars (inside the has_simt_producer || has_cp_async_producer
block) to also detect cp.async call statements (CallNode/PrimCallNode) that
correspond to TileStmtKind::kCpAsyncProducer, decode their tvm_access_ptr /
tl::access_ptr arguments to extract target shared buffers and rw-masks (reuse
the same logic/pattern from AnalyzeBufferDataAccess / LocalAccessCollector to
parse pointer base and mask), and insert the underlying VarNode* into
written_vars when the call writes shared memory; this will ensure
earliest_simt_read and wait_insert_pos are updated correctly for cp.async
producers alongside BufferStoreNode writes.
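A hedged Python sketch of what this suggestion describes, not code from the PR: decoding the tvm_access_ptr arguments of a call to collect the buffers it writes (the write bit of rw_mask has value 2). The helper name written_vars_from_call is illustrative; restricting the result to shared-memory buffers would still rely on scope information held elsewhere in the pass.

from tvm import tir


def written_vars_from_call(call: tir.Call):
    """Collect data Vars that a call writes to through tvm_access_ptr args."""
    written = set()
    for arg in call.args:
        if isinstance(arg, tir.Call) and arg.op.name == "tir.tvm_access_ptr":
            # tvm_access_ptr args: (type_annotation, data, offset, extent, rw_mask)
            data_var, rw_mask = arg.args[1], arg.args[4]
            if isinstance(rw_mask, tir.IntImm) and rw_mask.value & 2:
                written.add(data_var)
    return written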

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4c966c02-5d6a-4be5-bb53-5c957a1c5a0a

📥 Commits

Reviewing files that changed from the base of the PR and between 0fdd0f8 and 0dba22a.

📒 Files selected for processing (2)
  • src/transform/lower_tile_op.cc
  • src/transform/producer_consumer_ws.cc

Comment thread src/transform/producer_consumer_ws.cc
@Rachmanino Rachmanino changed the title [WIP][BugFix] Consider non-local store in external call and SIMT producer for warp specialize [BugFix] Consider non-local store in external call and SIMT producer for warp specialize May 8, 2026
Comment thread src/transform/lower_tile_op.cc Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_producer_consumer_ws.py (1)

462-481: ⚡ Quick win

Consider adding a correctness check.

The test validates structural patterns in the generated code but doesn't verify the kernel produces correct results when executed. For a bug-fix test, running the kernel and checking output would increase confidence that the fix resolves the race condition mentioned in the PR.

💡 Example correctness check

After the structural assertions, you could add:

# Verify correctness
import torch
iters, block, cp_elems = 4, 16, 8
A = torch.randn(iters, block, dtype=torch.float16, device="cuda")
B = torch.randn(iters, cp_elems, dtype=torch.float16, device="cuda")

# Compile and run using the transformed module
target = determine_target()
kernel = tilelang.compile(mod["main"], target=target, out_idx=[2, 3])
B_out, A_out = kernel(A, B)

# Verify outputs match inputs
torch.testing.assert_close(A_out, A, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(B_out[:, 0], B[:, 0], rtol=1e-3, atol=1e-3)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`
around lines 462 - 481, The test only asserts generated code patterns but should
also run the transformed kernel to validate correctness: after the existing
structural assertions in
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read, build sample
input tensors (e.g., using torch on CUDA), compile the transformed function
mod["main"] with tilelang.compile (using determine_target() or the same target
used earlier) to get a kernel, run the kernel on those inputs, and add numeric
assertions (e.g., torch.testing.assert_close) comparing A_out and B_out to
expected values or to the original inputs to ensure the cp.async reordering did
not break correctness; reference explicit_cp_async_wait_position, mod["main"],
tilelang.compile, determine_target, kernel, and torch.testing.assert_close when
locating where to insert the runtime correctness check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`:
- Around line 462-481: The test function
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read is missing the
CUDA test decorators used by other CUDA tests; add the same decorators (e.g.,
`@tvm.testing.requires_cuda` and the file's CUDA marker such as `@pytest.mark.cuda`
or whichever marker is used consistently in this file) immediately above the
function definition so the test is only collected/run on systems with CUDA
available.

---

Nitpick comments:
In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`:
- Around line 462-481: The test only asserts generated code patterns but should
also run the transformed kernel to validate correctness: after the existing
structural assertions in
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read, build sample
input tensors (e.g., using torch on CUDA), compile the transformed function
mod["main"] with tilelang.compile (using determine_target() or the same target
used earlier) to get a kernel, run the kernel on those inputs, and add numeric
assertions (e.g., torch.testing.assert_close) comparing A_out and B_out to
expected values or to the original inputs to ensure the cp.async reordering did
not break correctness; reference explicit_cp_async_wait_position, mod["main"],
tilelang.compile, determine_target, kernel, and torch.testing.assert_close when
locating where to insert the runtime correctness check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f0d29adb-9696-43b9-b9b1-4764de19599b

📥 Commits

Reviewing files that changed from the base of the PR and between bddf8f5 and f1d1d77.

📒 Files selected for processing (3)
  • src/transform/lower_tile_op.cc
  • src/transform/producer_consumer_ws.cc
  • testing/python/transform/test_tilelang_transform_producer_consumer_ws.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/transform/lower_tile_op.cc
  • src/transform/producer_consumer_ws.cc

Comment on lines +462 to +481
def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
    """Explicit cp.async destinations must pull the consumer wait earlier."""

    func = explicit_cp_async_wait_position().with_attr("global_symbol", "main")
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda -arch=sm_90"))(mod)
    mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod)
    script = mod["main"].script()

    assert "tl_tiled_ws_applied" in script
    assert "T.ptx_cp_async" in script
    assert "T.tma_copy" in script

    consumer_branch = _find_after(script, "else:")
    wait = _find_after(script, "T.mbarrier_wait_parity", consumer_branch)
    cp_async_read = _find_after(script, "B_out[ko] = B_shared[0]", consumer_branch)
    tma_read = _find_after(script, "A_out[ko, i] = A_shared", consumer_branch)

    assert wait < cp_async_read < tma_read

Contributor


⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Add required CUDA test decorators.

The test is missing decorators that are present on all other CUDA tests in this file. Without them, the test will attempt to run on non-CUDA systems and fail.

🔧 Proposed fix
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(9, 0)
 def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
     """Explicit cp.async destinations must pull the consumer wait earlier."""
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`
around lines 462 - 481, The test function
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read is missing the
CUDA test decorators used by other CUDA tests; add the same decorators (e.g.,
`@tvm.testing.requires_cuda` and the file's CUDA marker such as `@pytest.mark.cuda`
or whichever marker is used consistently in this file) immediately above the
function definition so the test is only collected/run on systems with CUDA
available.

@LeiWang1999
Member

@regression-perf

@github-actions

github-actions Bot commented May 9, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/25596047923

Results

File Original Latency Current Latency Speedup
example_linear_attn_fwd 0.0364438 0.0366095 0.995473
example_warp_specialize_gemm_barrierpipe_stage2 0.0401829 0.0402565 0.998174
example_topk 0.0111118 0.0111293 0.998428
example_tilelang_gemm_splitk 0.981939 0.983355 0.99856
example_vertical_slash_sparse_attn 0.231145 0.231363 0.999059
example_gemv 0.28822 0.288433 0.999261
example_mha_fwd_bhsd 0.0116164 0.0116236 0.99938
example_dequant_gemv_fp16xint4 0.0283393 0.0283564 0.999399
example_gqa_decode 0.048613 0.0486346 0.999556
example_mha_fwd_varlen 0.0444584 0.0444781 0.999556
example_mha_bwd_bhsd 0.0407856 0.0408035 0.999561
example_tilelang_gemm_fp8_2xAcc 0.128112 0.128165 0.99959
example_gqa_fwd_bshd 0.0690085 0.0690271 0.99973
block_sparse_attn_tilelang 0.00918125 0.00918272 0.99984
example_mha_inference 0.0787723 0.078782 0.999876
example_mha_bwd_bshd 0.0402041 0.0402078 0.99991
example_gemm_autotune 0.022521 0.0225223 0.999941
example_dequant_gemm_w4a8 5.5794 5.57931 1.00002
example_gemm_intrinsics 0.0348683 0.0348665 1.00005
example_mha_fwd_bshd 0.0248379 0.0248331 1.00019
example_gqa_bwd_tma_reduce_varlen 0.0463616 0.0463525 1.0002
example_fusedmoe_tilelang 0.13316 0.133132 1.00021
example_dynamic 0.638025 0.637867 1.00025
example_tilelang_gemm_fp8 0.305084 0.304958 1.00041
example_mhc_post 0.109868 0.109818 1.00046
example_tilelang_gemm_splitk_vectorize_atomicadd 0.983964 0.983328 1.00065
example_warp_specialize_gemm_copy_1_gemm_0 0.0275745 0.0275566 1.00065
example_gqa_bwd 0.0465776 0.0465423 1.00076
example_elementwise_add 0.115664 0.115545 1.00103
example_linear_attn_bwd 0.153298 0.153111 1.00122
example_warp_specialize_gemm_softpipe_stage2 0.0275789 0.0275378 1.00149
example_gemm 0.0223221 0.0222774 1.002
example_convolution_autotune 0.982107 0.980131 1.00202
example_dequant_gemm_fp4_hopper 1.03333 1.03078 1.00247
example_per_token_cast_to_fp8 0.00738878 0.0073696 1.0026
example_mha_sink_fwd_bhsd 0.0169761 0.0169317 1.00262
example_group_per_split_token_cast_to_fp8 0.0104284 0.010395 1.00321
example_tilelang_nsa_decode 0.00683811 0.0068161 1.00323
example_tilelang_nsa_fwd 0.00695168 0.00692025 1.00454
example_tilelang_block_sparse_attn 0.00941585 0.00936614 1.00531
sparse_mla_bwd 0.29645 0.294797 1.00561
example_blocksparse_gemm 0.019253 0.019141 1.00585
example_mha_sink_fwd_bhsd_sliding_window 0.0163369 0.0162402 1.00595
example_mha_sink_bwd_bhsd 0.0675632 0.0671205 1.0066
sparse_mla_fwd_pipelined 0.0906498 0.0900504 1.00666
example_mha_sink_bwd_bhsd_sliding_window 0.0497073 0.0493689 1.00686
fp8_lighting_indexer 0.0326338 0.032388 1.00759
sparse_mla_fwd 0.108755 0.107925 1.00769
example_tilelang_sparse_gqa_decode_varlen_indice 0.0161101 0.0159811 1.00807
example_tilelang_sparse_gqa_decode_varlen_mask 0.0177615 0.0176163 1.00824
topk_selector 0.0543538 0.0539015 1.00839
example_dequant_gemm_bf16_fp4_hopper 0.562544 0.556589 1.0107
example_mhc_pre 0.154033 0.15235 1.01105
example_gqa_sink_bwd_bhsd_sliding_window 0.0255441 0.0252646 1.01106
example_convolution 1.30471 1.29029 1.01117
example_gqa_sink_bwd_bhsd 0.0432816 0.0427668 1.01204
example_warp_specialize_gemm_copy_0_gemm_1 0.0373537 0.036904 1.01218
example_dequant_gemm_bf16_mxfp4_hopper 0.520854 0.514256 1.01283
example_mla_decode 0.453897 0.447822 1.01357

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.
