[BugFix] Consider non-local store in external call and SIMT producer for warp specialize#2166
Conversation
…lized pipelines

Two bugs caused incorrect results when T.Parallel loops use T.call_extern with T.address_of(shared_buffer) inside pipelined, warp-specialized kernels:

Bug 1 (lower_tile_op.cc): has_non_local_store did not recognize address_of(shared_buffer) inside call_extern as a non-local access, so parallel_loop was set to false and thread partitioning was skipped entirely — every thread ran the full serial loop.

Bug 2 (producer_consumer_ws.cc): wait_insert_pos was only computed for kTmaProducer statements. SIMT/cp.async producers (e.g. T.Parallel global→shared copies) were skipped, so the consumer read from SIMT-produced shared buffers before the barrier wait — a data race.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
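The first fix can be sketched with a small self-contained model (illustrative only; the real pass in lower_tile_op.cc walks TVM TIR nodes, and the toy classes and string scopes below are hypothetical stand-ins):

```python
# Minimal model of the has_non_local_store check. Before the fix, only direct
# buffer stores were inspected; address_of(shared_buffer) passed to an extern
# call was missed, so thread partitioning of the parallel loop was skipped.
from dataclasses import dataclass, field


@dataclass
class Call:
    op: str                       # e.g. "call_extern", "address_of"
    args: list = field(default_factory=list)


@dataclass
class BufferStore:
    scope: str                    # "local", "shared", or "global"


def has_non_local_store(stmt) -> bool:
    """Return True if the statement may write outside thread-local storage."""
    if isinstance(stmt, BufferStore):
        return stmt.scope != "local"
    if isinstance(stmt, Call) and stmt.op == "call_extern":
        # The fix: also treat address_of(<non-local buffer>) arguments of an
        # extern call as a potential non-local store.
        return any(
            isinstance(a, Call) and a.op == "address_of" and a.args[0] != "local"
            for a in stmt.args
        )
    return False


# An extern call taking the address of a shared buffer now counts as non-local:
extern = Call("call_extern", [Call("address_of", ["shared"])])
assert has_non_local_store(extern)
```

With this in place, parallel_loop stays true for such loops and thread partitioning is applied instead of every thread running the full serial loop.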
📝 Walkthrough

This PR extends non-local store detection in the parallel-loop visitor and adjusts barrier wait positioning for SIMT/cp.async producers.

Changes:
- Non-Local Store Detection in Parallel Loops
- Barrier Wait Positioning for SIMT/cp.async Producers

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/transform/producer_consumer_ws.cc`:
- Around line 1333-1377: The kCpAsyncProducer branch currently only looks for
BufferStoreNode (so written_vars stays empty for cp.async); modify the loop that
builds written_vars (inside the has_simt_producer || has_cp_async_producer
block) to also detect cp.async call statements (CallNode/PrimCallNode) that
correspond to TileStmtKind::kCpAsyncProducer, decode their tvm_access_ptr /
tl::access_ptr arguments to extract target shared buffers and rw-masks (reuse
the same logic/pattern from AnalyzeBufferDataAccess / LocalAccessCollector to
parse pointer base and mask), and insert the underlying VarNode* into
written_vars when the call writes shared memory; this will ensure
earliest_simt_read and wait_insert_pos are updated correctly for cp.async
producers alongside BufferStoreNode writes.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4c966c02-5d6a-4be5-bb53-5c957a1c5a0a
📒 Files selected for processing (2)
- src/transform/lower_tile_op.cc
- src/transform/producer_consumer_ws.cc
Actionable comments posted: 1
🧹 Nitpick comments (1)
testing/python/transform/test_tilelang_transform_producer_consumer_ws.py (1)
462-481: ⚡ Quick winConsider adding a correctness check.
The test validates structural patterns in the generated code but doesn't verify the kernel produces correct results when executed. For a bug-fix test, running the kernel and checking output would increase confidence that the fix resolves the race condition mentioned in the PR.
💡 Example correctness check
After the structural assertions, you could add:
```python
# Verify correctness
import torch

iters, block, cp_elems = 4, 16, 8
A = torch.randn(iters, block, dtype=torch.float16, device="cuda")
B = torch.randn(iters, cp_elems, dtype=torch.float16, device="cuda")

# Compile and run using the transformed module
target = determine_target()
kernel = tilelang.compile(mod["main"], target=target, out_idx=[2, 3])
B_out, A_out = kernel(A, B)

# Verify outputs match inputs
torch.testing.assert_close(A_out, A, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(B_out[:, 0], B[:, 0], rtol=1e-3, atol=1e-3)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py` around lines 462 - 481, The test only asserts generated code patterns but should also run the transformed kernel to validate correctness: after the existing structural assertions in test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read, build sample input tensors (e.g., using torch on CUDA), compile the transformed function mod["main"] with tilelang.compile (using determine_target() or the same target used earlier) to get a kernel, run the kernel on those inputs, and add numeric assertions (e.g., torch.testing.assert_close) comparing A_out and B_out to expected values or to the original inputs to ensure the cp.async reordering did not break correctness; reference explicit_cp_async_wait_position, mod["main"], tilelang.compile, determine_target, kernel, and torch.testing.assert_close when locating where to insert the runtime correctness check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`:
- Around line 462-481: The test function
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read is missing the
CUDA test decorators used by other CUDA tests; add the same decorators (e.g.,
`@tvm.testing.requires_cuda` and the file's CUDA marker such as `@pytest.mark.cuda`
or whichever marker is used consistently in this file) immediately above the
function definition so the test is only collected/run on systems with CUDA
available.
---
Nitpick comments:
In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`:
- Around line 462-481: The test only asserts generated code patterns but should
also run the transformed kernel to validate correctness: after the existing
structural assertions in
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read, build sample
input tensors (e.g., using torch on CUDA), compile the transformed function
mod["main"] with tilelang.compile (using determine_target() or the same target
used earlier) to get a kernel, run the kernel on those inputs, and add numeric
assertions (e.g., torch.testing.assert_close) comparing A_out and B_out to
expected values or to the original inputs to ensure the cp.async reordering did
not break correctness; reference explicit_cp_async_wait_position, mod["main"],
tilelang.compile, determine_target, kernel, and torch.testing.assert_close when
locating where to insert the runtime correctness check.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f0d29adb-9696-43b9-b9b1-4764de19599b
📒 Files selected for processing (3)
- src/transform/lower_tile_op.cc
- src/transform/producer_consumer_ws.cc
- testing/python/transform/test_tilelang_transform_producer_consumer_ws.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/lower_tile_op.cc
- src/transform/producer_consumer_ws.cc
```python
def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
    """Explicit cp.async destinations must pull the consumer wait earlier."""

    func = explicit_cp_async_wait_position().with_attr("global_symbol", "main")
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda -arch=sm_90"))(mod)
    mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod)
    script = mod["main"].script()

    assert "tl_tiled_ws_applied" in script
    assert "T.ptx_cp_async" in script
    assert "T.tma_copy" in script

    consumer_branch = _find_after(script, "else:")
    wait = _find_after(script, "T.mbarrier_wait_parity", consumer_branch)
    cp_async_read = _find_after(script, "B_out[ko] = B_shared[0]", consumer_branch)
    tma_read = _find_after(script, "A_out[ko, i] = A_shared", consumer_branch)

    assert wait < cp_async_read < tma_read
```
Add required CUDA test decorators.
The test is missing decorators that are present on all other CUDA tests in this file. Without them, the test will attempt to run on non-CUDA systems and fail.
🔧 Proposed fix

```diff
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(9, 0)
 def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
     """Explicit cp.async destinations must pull the consumer wait earlier."""
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`
around lines 462 - 481, The test function
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read is missing the
CUDA test decorators used by other CUDA tests; add the same decorators (e.g.,
`@tvm.testing.requires_cuda` and the file's CUDA marker such as `@pytest.mark.cuda`
or whichever marker is used consistently in this file) immediately above the
function definition so the test is only collected/run on systems with CUDA
available.
@regression-perf
Performance Regression Test Report

Triggered by: @LeiWang1999

Results

Artifacts
Two bugs caused incorrect results when T.Parallel loops use T.call_extern with T.address_of(shared_buffer) inside pipelined, warp-specialized kernels:
Bug 1 (lower_tile_op.cc): has_non_local_store did not recognize address_of(shared_buffer) inside call_extern as a non-local access, so parallel_loop was set to false and thread partitioning was skipped entirely — every thread ran the full serial loop.
Reproduction:
Bug 2 (producer_consumer_ws.cc): wait_insert_pos was only computed for kTmaProducer statements. SIMT/cp.async producers (e.g. T.Parallel global→shared copies) were skipped, so the consumer read from SIMT-produced shared buffers before the barrier wait — a data race.
Reproduction:
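The original reproduction is not included above; the shape of the fix, however, can be modelled with a small sketch (hypothetical data structures; the real pass in producer_consumer_ws.cc operates on TIR statement lists):

```python
# Sketch: compute where the consumer's barrier wait must be inserted so that
# it precedes the first read of any producer-written shared buffer.
def wait_insert_pos(consumer_stmts, produced_buffers):
    """consumer_stmts: list of (op, buffers_read) pairs in program order.

    Before the fix, only TMA-produced buffers were in produced_buffers;
    SIMT/cp.async-produced buffers were skipped, letting their reads race
    ahead of the barrier wait.
    """
    for pos, (op, reads) in enumerate(consumer_stmts):
        if reads & produced_buffers:
            return pos           # wait must go before the first such read
    return len(consumer_stmts)   # no read: wait can go at the end


stmts = [
    ("compute", set()),
    ("load", {"B_shared"}),      # first read of a cp.async-produced buffer
    ("load", {"A_shared"}),      # first read of a TMA-produced buffer
]
# Considering SIMT/cp.async-produced buffers moves the wait before stmt 1:
assert wait_insert_pos(stmts, {"B_shared", "A_shared"}) == 1
# Ignoring them (the bug) would place the wait only before the TMA read:
assert wait_insert_pos(stmts, {"A_shared"}) == 2
```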
Summary by CodeRabbit
Performance Improvements
Tests