Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,14 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
has_non_local_store = true;
}
}
} else if (call->op.same_as(builtin::address_of())) {
// call_extern may pass address_of(non-local-buffer) pointers, and
// PostOrderVisit reaches the address_of call directly.
if (const auto *load = call->args[0].as<BufferLoadNode>()) {
if (!IsLocalBuffer(load->buffer)) {
has_non_local_store = true;
}
}
}
}
});
Expand Down
82 changes: 82 additions & 0 deletions src/transform/producer_consumer_ws.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,67 @@ static Optional<Var> ExtractProducerWriteBufferData(const Stmt &stmt) {
return Optional<Var>();
}

static int
FindFirstAsyncProducerConsumerRead(const Stmt &producer_stmt,
const Array<Stmt> &consumer_compute_stmts,
const BufferDataToBufferMap &buffer_map) {
int earliest_read = static_cast<int>(consumer_compute_stmts.size());
auto update_earliest_read = [&](const Var &buffer_data) {
for (size_t ci = 0; ci < static_cast<size_t>(earliest_read); ++ci) {
BufferDataAccessInfo access = AnalyzeBufferDataAccess(
consumer_compute_stmts[ci], buffer_data, buffer_map);
if (access.read) {
earliest_read = static_cast<int>(ci);
return;
}
}
};
if (Optional<Var> write_buffer_data =
ExtractProducerWriteBufferData(producer_stmt)) {
update_earliest_read(write_buffer_data.value());
}
PostOrderVisit(producer_stmt, [&](const ObjectRef &obj) {
if (earliest_read == 0) {
return;
}
if (const auto *store = obj.as<BufferStoreNode>()) {
if (IsSharedBuffer(store->buffer)) {
update_earliest_read(store->buffer->data);
}
return;
}
const auto *call = obj.as<CallNode>();
if (!call || !(call->op.same_as(builtin::ptx_cp_async()) ||
call->op.same_as(tl::ptx_cp_async()))) {
return;
}
PostOrderVisit(call->args[0], [&](const ObjectRef &ptr_obj) {
if (earliest_read == 0) {
return;
}
if (const auto *load = ptr_obj.as<BufferLoadNode>()) {
if (IsSharedBuffer(load->buffer)) {
update_earliest_read(load->buffer->data);
}
return;
}
const auto *ptr_call = ptr_obj.as<CallNode>();
if (!ptr_call || !ptr_call->op.same_as(builtin::tvm_access_ptr())) {
return;
}
const auto *var = ptr_call->args[1].as<VarNode>();
if (!var) {
return;
}
auto it = buffer_map.find(ffi::GetRef<Var>(var));
if (it != buffer_map.end() && IsSharedBuffer(it->second)) {
update_earliest_read(it->second->data);
}
});
});
return earliest_read;
}

static Stmt RewritePreludeTmaProducerStmt(const Stmt &stmt,
const Buffer &barrier_buf,
PrimExpr barrier_id) {
Expand Down Expand Up @@ -1330,6 +1391,27 @@ class ProducerConsumerWSRewriter : public StmtExprMutator {
++access_group_idx;
}

// --- Adjust wait positions for SIMT/cp.async producers ---
// SIMT and cp.async producers tie their completion to all forward barriers.
// If a consumer reads any such shared destination before the first TMA
// read, pull all waits earlier so the async producer is also covered.
if (has_simt_producer || has_cp_async_producer) {
int earliest_async_read = static_cast<int>(consumer_compute_stmts.size());
for (size_t i = 0; i < flat_stmts.size(); ++i) {
if (kinds[i] != TileStmtKind::kSimtProducer &&
kinds[i] != TileStmtKind::kCpAsyncProducer) {
continue;
}
int first_read = FindFirstAsyncProducerConsumerRead(
flat_stmts[i], consumer_compute_stmts, buffer_data_to_buffer_);
earliest_async_read = std::min(earliest_async_read, first_read);
}
// Pull all wait positions earlier if needed.
for (int g = 0; g < num_producer_groups; ++g) {
wait_insert_pos[g] = std::min(wait_insert_pos[g], earliest_async_read);
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// --- Determine if TMA barriers can be merged ---
// When all pure-TMA producers wait at the same consumer position and
// release at the same position, forward and back-pressure barriers can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,34 @@ def main(
return main


def explicit_cp_async_wait_position(iters=4, block=16, cp_elems=8, dtype="float16", threads=128):
    """A mixed TMA + explicit cp.async pipeline with cp.async consumed first.

    Builds a pipelined kernel where each iteration issues an explicit
    ``T.ptx_cp_async`` into ``B_shared`` alongside a ``T.copy`` (lowered to
    TMA on sm_90) into ``A_shared``, and the consumer reads the cp.async
    destination (``B_shared``) *before* the TMA destination (``A_shared``).
    Exercises the wait-position adjustment for explicit cp.async producers.

    NOTE: the local names (``A_shared``, ``B_shared``, ``ko``) are matched
    as strings against the printed TIR script by the companion test — do
    not rename them.
    """

    @T.prim_func
    def main(
        A: T.Buffer((iters, block), dtype),
        B: T.Buffer((iters, cp_elems), dtype),
        B_out: T.Buffer((iters,), dtype),
        A_out: T.Buffer((iters, block), dtype),
    ):
        with T.Kernel(1, threads=threads) as _:
            A_shared = T.alloc_shared((block,), dtype)
            B_shared = T.alloc_shared((cp_elems,), dtype)

            for ko in T.Pipelined(iters, num_stages=2):
                # Explicit async copy B[ko, :] -> B_shared (cp.async path).
                T.ptx_cp_async(
                    T.access_ptr(B_shared[0], "w", cp_elems),
                    T.access_ptr(B[ko, 0], "r", cp_elems),
                    cp_elems,
                )
                # Bulk copy A[ko, :] -> A_shared (TMA path on sm_90).
                T.copy(A[ko, 0], A_shared)
                # Consumer reads the cp.async destination first ...
                B_out[ko] = B_shared[0]
                # ... and the TMA destination second.
                for i in T.Parallel(block):
                    A_out[ko, i] = A_shared[i]

    return main


def grouped_gemm_padded_pipelined(
batch_sizes,
K,
Expand Down Expand Up @@ -431,6 +459,27 @@ def test_tiled_ws_sinks_preloop_tma_waits_into_consumer():
assert k_load < v_load < branch < first_wait


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
    """Explicit cp.async destinations must pull the consumer wait earlier.

    Fix: added the CUDA gating decorators used by every other CUDA test in
    this file; without them the test is collected and fails on hosts with
    no sm_90-capable GPU.
    """

    func = explicit_cp_async_wait_position().with_attr("global_symbol", "main")
    mod = tvm.IRModule.from_expr(func)
    # Bind an sm_90 target so warp specialization / TMA lowering applies.
    mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda -arch=sm_90"))(mod)
    mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod)
    script = mod["main"].script()

    assert "tl_tiled_ws_applied" in script
    assert "T.ptx_cp_async" in script
    assert "T.tma_copy" in script

    # In the consumer branch, the barrier wait must precede the first read
    # of the cp.async destination, which precedes the TMA destination read.
    consumer_branch = _find_after(script, "else:")
    wait = _find_after(script, "T.mbarrier_wait_parity", consumer_branch)
    cp_async_read = _find_after(script, "B_out[ko] = B_shared[0]", consumer_branch)
    tma_read = _find_after(script, "A_out[ko, i] = A_shared", consumer_branch)

    assert wait < cp_async_read < tma_read

Comment on lines +462 to +481
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Add required CUDA test decorators.

The test is missing decorators that are present on all other CUDA tests in this file. Without them, the test will attempt to run on non-CUDA systems and fail.

🔧 Proposed fix
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(9, 0)
 def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
     """Explicit cp.async destinations must pull the consumer wait earlier."""
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@testing/python/transform/test_tilelang_transform_producer_consumer_ws.py`
around lines 462 - 481, The test function
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read is missing the
CUDA test decorators used by other CUDA tests; add the same decorators (e.g.,
`@tvm.testing.requires_cuda` and the file's CUDA marker such as `@pytest.mark.cuda`
or whichever marker is used consistently in this file) immediately above the
function definition so the test is only collected/run on systems with CUDA
available.


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm():
Expand Down Expand Up @@ -468,5 +517,6 @@ def test_tiled_ws_does_not_clone_local_var_into_producer_branch():
test_tiled_ws_swizzled_layout_allows_ws()
test_tiled_ws_incompatible_layout_blocks_ws()
test_tiled_ws_sinks_preloop_tma_waits_into_consumer()
test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read()
test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm()
test_tiled_ws_does_not_clone_local_var_into_producer_branch()
Loading