diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index 7267f17ae..ea92ac66a 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -1247,6 +1247,14 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
           has_non_local_store = true;
         }
-      }
+      } else if (call->op.same_as(builtin::address_of())) {
+        // call_extern may pass address_of(non-local-buffer) pointers, and
+        // PostOrderVisit reaches the address_of call directly.
+        if (const auto *load = call->args[0].as<BufferLoadNode>()) {
+          if (!IsLocalBuffer(load->buffer)) {
+            has_non_local_store = true;
+          }
+        }
+      }
     }
   }
 });
diff --git a/src/transform/producer_consumer_ws.cc b/src/transform/producer_consumer_ws.cc
index 943759408..c1598e909 100644
--- a/src/transform/producer_consumer_ws.cc
+++ b/src/transform/producer_consumer_ws.cc
@@ -1027,6 +1027,67 @@ static Optional<Var> ExtractProducerWriteBufferData(const Stmt &stmt) {
   return Optional<Var>();
 }
 
+static int
+FindFirstAsyncProducerConsumerRead(const Stmt &producer_stmt,
+                                   const Array<Stmt> &consumer_compute_stmts,
+                                   const BufferDataToBufferMap &buffer_map) {
+  int earliest_read = static_cast<int>(consumer_compute_stmts.size());
+  auto update_earliest_read = [&](const Var &buffer_data) {
+    for (size_t ci = 0; ci < static_cast<size_t>(earliest_read); ++ci) {
+      BufferDataAccessInfo access = AnalyzeBufferDataAccess(
+          consumer_compute_stmts[ci], buffer_data, buffer_map);
+      if (access.read) {
+        earliest_read = static_cast<int>(ci);
+        return;
+      }
+    }
+  };
+  if (Optional<Var> write_buffer_data =
+          ExtractProducerWriteBufferData(producer_stmt)) {
+    update_earliest_read(write_buffer_data.value());
+  }
+  PostOrderVisit(producer_stmt, [&](const ObjectRef &obj) {
+    if (earliest_read == 0) {
+      return;
+    }
+    if (const auto *store = obj.as<BufferStoreNode>()) {
+      if (IsSharedBuffer(store->buffer)) {
+        update_earliest_read(store->buffer->data);
+      }
+      return;
+    }
+    const auto *call = obj.as<CallNode>();
+    if (!call || !(call->op.same_as(builtin::ptx_cp_async()) ||
+                   call->op.same_as(tl::ptx_cp_async()))) {
+      return;
+    }
+    PostOrderVisit(call->args[0], [&](const ObjectRef &ptr_obj) {
+      if (earliest_read == 0) {
+        return;
+      }
+      if (const auto *load = ptr_obj.as<BufferLoadNode>()) {
+        if (IsSharedBuffer(load->buffer)) {
+          update_earliest_read(load->buffer->data);
+        }
+        return;
+      }
+      const auto *ptr_call = ptr_obj.as<CallNode>();
+      if (!ptr_call || !ptr_call->op.same_as(builtin::tvm_access_ptr())) {
+        return;
+      }
+      const auto *var = ptr_call->args[1].as<VarNode>();
+      if (!var) {
+        return;
+      }
+      auto it = buffer_map.find(ffi::GetRef<Var>(var));
+      if (it != buffer_map.end() && IsSharedBuffer(it->second)) {
+        update_earliest_read(it->second->data);
+      }
+    });
+  });
+  return earliest_read;
+}
+
 static Stmt RewritePreludeTmaProducerStmt(const Stmt &stmt,
                                           const Buffer &barrier_buf,
                                           PrimExpr barrier_id) {
@@ -1330,6 +1391,27 @@ class ProducerConsumerWSRewriter : public StmtExprMutator {
       ++access_group_idx;
     }
 
+    // --- Adjust wait positions for SIMT/cp.async producers ---
+    // SIMT and cp.async producers tie their completion to all forward barriers.
+    // If a consumer reads any such shared destination before the first TMA
+    // read, pull all waits earlier so the async producer is also covered.
+    if (has_simt_producer || has_cp_async_producer) {
+      int earliest_async_read = static_cast<int>(consumer_compute_stmts.size());
+      for (size_t i = 0; i < flat_stmts.size(); ++i) {
+        if (kinds[i] != TileStmtKind::kSimtProducer &&
+            kinds[i] != TileStmtKind::kCpAsyncProducer) {
+          continue;
+        }
+        int first_read = FindFirstAsyncProducerConsumerRead(
+            flat_stmts[i], consumer_compute_stmts, buffer_data_to_buffer_);
+        earliest_async_read = std::min(earliest_async_read, first_read);
+      }
+      // Pull all wait positions earlier if needed.
+      for (int g = 0; g < num_producer_groups; ++g) {
+        wait_insert_pos[g] = std::min(wait_insert_pos[g], earliest_async_read);
+      }
+    }
+
     // --- Determine if TMA barriers can be merged ---
     // When all pure-TMA producers wait at the same consumer position and
     // release at the same position, forward and back-pressure barriers can
diff --git a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py
index eee80b2c4..eb5ed7793 100644
--- a/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py
+++ b/testing/python/transform/test_tilelang_transform_producer_consumer_ws.py
@@ -114,6 +114,34 @@ def main(
     return main
 
 
+def explicit_cp_async_wait_position(iters=4, block=16, cp_elems=8, dtype="float16", threads=128):
+    """A mixed TMA + explicit cp.async pipeline with cp.async consumed first."""
+
+    @T.prim_func
+    def main(
+        A: T.Buffer((iters, block), dtype),
+        B: T.Buffer((iters, cp_elems), dtype),
+        B_out: T.Buffer((iters,), dtype),
+        A_out: T.Buffer((iters, block), dtype),
+    ):
+        with T.Kernel(1, threads=threads) as _:
+            A_shared = T.alloc_shared((block,), dtype)
+            B_shared = T.alloc_shared((cp_elems,), dtype)
+
+            for ko in T.Pipelined(iters, num_stages=2):
+                T.ptx_cp_async(
+                    T.access_ptr(B_shared[0], "w", cp_elems),
+                    T.access_ptr(B[ko, 0], "r", cp_elems),
+                    cp_elems,
+                )
+                T.copy(A[ko, 0], A_shared)
+                B_out[ko] = B_shared[0]
+                for i in T.Parallel(block):
+                    A_out[ko, i] = A_shared[i]
+
+    return main
+
+
 def grouped_gemm_padded_pipelined(
     batch_sizes,
     K,
@@ -431,6 +459,27 @@ def test_tiled_ws_sinks_preloop_tma_waits_into_consumer():
     assert k_load < v_load < branch < first_wait
 
 
+def test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read():
+    """Explicit cp.async destinations must pull the consumer wait earlier."""
+
+    func = explicit_cp_async_wait_position().with_attr("global_symbol", "main")
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda -arch=sm_90"))(mod)
+    mod = tilelang.transform.ProducerConsumerWarpSpecialized()(mod)
+    script = mod["main"].script()
+
+    assert "tl_tiled_ws_applied" in script
+    assert "T.ptx_cp_async" in script
+    assert "T.tma_copy" in script
+
+    consumer_branch = _find_after(script, "else:")
+    wait = _find_after(script, "T.mbarrier_wait_parity", consumer_branch)
+    cp_async_read = _find_after(script, "B_out[ko] = B_shared[0]", consumer_branch)
+    tma_read = _find_after(script, "A_out[ko, i] = A_shared", consumer_branch)
+
+    assert wait < cp_async_read < tma_read
+
+
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(9, 0)
 def test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm():
@@ -468,5 +517,6 @@ def test_tiled_ws_does_not_clone_local_var_into_producer_branch():
     test_tiled_ws_swizzled_layout_allows_ws()
     test_tiled_ws_incompatible_layout_blocks_ws()
     test_tiled_ws_sinks_preloop_tma_waits_into_consumer()
+    test_tiled_ws_explicit_cp_async_wait_precedes_first_consumer_read()
    test_tiled_ws_keeps_shared_prelude_local_vars_for_grouped_gemm()
     test_tiled_ws_does_not_clone_local_var_into_producer_branch()