diff --git a/.agent/skills/translate_cpp2py/references/example_translation/.gitignore b/.agent/skills/translate_cpp2py/references/example_translation/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/.agent/skills/translate_cpp2py/references/example_translation/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/compile.sh new file mode 100644 index 00000000..7c64ddc7 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./geglu_builder.py > ./geglu.pto +ptoas --enable-insert-sync ./geglu.pto -o ./geglu.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.cpp b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.cpp new file mode 100644 index 00000000..beae8982 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.cpp @@ -0,0 +1,155 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void _kernel(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5) { + unsigned v6 = 1; + unsigned v7 = 0; + const int32_t v8 = 16384; + const int32_t v9 = 1; + const int32_t v10 = 0; + const int64_t v11 = 0; + const int64_t v12 = 32768; + const int64_t v13 = 65536; + const int64_t v14 = 98304; + const int64_t v15 = 131072; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + if (v5 > v10) { + if (v5 <= v8) { + int64_t v16 = get_block_idx(); + int64_t v17 = get_subblockid(); + int64_t v18 = get_subblockdim(); + int64_t v19 = (int64_t) v18; + int64_t v20 = get_block_num(); + int32_t v21 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v20) * (uint64_t) v19); + int32_t v22 = v4 / v21; + int32_t v23 = v4 % v21 != v10 && v4 < v10 == v21 < v10 ? v22 + v9 : v22; + int32_t v24 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v16) * (uint64_t) v19) + (uint64_t) ((int64_t) v17))) * (uint32_t) v23); + int32_t v25 = (int32_t) ((uint32_t) v24 + (uint32_t) v23); + int32_t v26 = (int32_t) ((uint32_t) ((uint32_t) v25 < (uint32_t) v4 ? v25 : v4) - (uint32_t) v24); + int32_t v27 = (int32_t) ((uint32_t) v4 * (uint32_t) v5); + if (v26 > v10) { + Tile v28; + TASSIGN(v28, v11); + Tile v29 = Tile(v5); + __ubuf__ half* v30 = v28.data(); + uint64_t v31 = reinterpret_cast(v30); + TASSIGN(v29, v31); + Tile v32; + TASSIGN(v32, v12); + Tile v33 = Tile(v5); + __ubuf__ half* v34 = v32.data(); + uint64_t v35 = reinterpret_cast(v34); + TASSIGN(v33, v35); + Tile v36; + TASSIGN(v36, v13); + Tile v37 = Tile(v5); + __ubuf__ half* v38 = v36.data(); + uint64_t v39 = reinterpret_cast(v38); + TASSIGN(v37, v39); + Tile v40; + TASSIGN(v40, v14); + Tile v41 = Tile(v5); + __ubuf__ half* v42 = v40.data(); + uint64_t v43 = reinterpret_cast(v42); + TASSIGN(v41, v43); + Tile v44; + TASSIGN(v44, v15); + Tile v45 = Tile(v5); + __ubuf__ half* v46 = v44.data(); + uint64_t v47 = reinterpret_cast(v46); + TASSIGN(v45, v47); + for (size_t v48 = (size_t) v10; v48 < ((size_t) v26); v48 += (size_t) v9) { + int32_t v49 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v24 + (uint32_t) ((int32_t) v48)) * (uint32_t) v5); + unsigned v50 = (unsigned) v5 * v6; + pto::Shape<1, 1, 1, 1, -1> v51 = pto::Shape<1, 1, 1, 1, -1>(v5); + pto::Stride<-1, -1, -1, -1, 1> v52 = pto::Stride<-1, -1, -1, -1, 1>(v50, v50, v50, v50); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v53 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) v49 * (unsigned) v9), v51, v52); + unsigned v54 = (unsigned) v5 * v6; + pto::Shape<1, 1, 1, 1, -1> v55 = pto::Shape<1, 1, 1, 1, -1>(v5); + pto::Stride<-1, -1, -1, -1, 1> v56 = pto::Stride<-1, -1, -1, -1, 1>(v54, v54, v54, v54); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v57 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v2 + (v7 + (unsigned) v49 * (unsigned) v9), v55, v56); + unsigned v58 = (unsigned) v5 * v6; + pto::Shape<1, 1, 1, 1, -1> v59 = pto::Shape<1, 1, 1, 1, -1>(v5); + pto::Stride<-1, -1, -1, -1, 1> v60 = pto::Stride<-1, -1, -1, -1, 1>(v58, v58, v58, v58); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v61 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v7 + (unsigned) v49 * (unsigned) v9), v59, v60); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v29, v53); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v33, v57); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TSUB(v45, v29, v29); + pipe_barrier(PIPE_V); + TEXP(v37, v45); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TADD(v41, v29, v29); + pipe_barrier(PIPE_V); + TEXP(v41, v41); + pipe_barrier(PIPE_V); + TSUB(v45, v41, v37); + pipe_barrier(PIPE_V); + TADD(v41, v41, v37); + pipe_barrier(PIPE_V); + TDIV(v45, v45, v41); + pipe_barrier(PIPE_V); + TADD(v41, v37, v45); + pipe_barrier(PIPE_V); + TMUL(v41, v29, v41); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TADD(v45, v37, v37); + pipe_barrier(PIPE_V); + TDIV(v41, v41, v45); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMUL(v41, v41, v33); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_MTE3); + TSTORE(v61, v41); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + }; + }; + }; + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.pto b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.pto new file mode 100644 index 00000000..8b731674 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu.pto @@ -0,0 +1,67 @@ +module { + func.func @_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16384 = arith.constant 16384 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + pto.section.vector { + %2 = arith.cmpi sgt, %1, %c0 : index + scf.if %2 { + %3 = arith.cmpi sge, %c16384, %1 : index + scf.if %3 { + %4 = pto.get_block_idx + %5 = pto.get_subblock_idx + %6 = pto.get_subblock_num + %7 = pto.get_block_num + %8 = arith.muli %4, %6 : i64 + %9 = arith.addi %8, %5 : i64 + %10 = arith.index_cast %9 : i64 to index + %11 = arith.muli %7, %6 : i64 + %12 = arith.index_cast %11 : i64 to index + %13 = arith.ceildivsi %0, %12 : index + %14 = arith.muli %10, %13 : index + %15 = arith.addi %14, %13 : index + %16 = arith.minui %15, %0 : index + %17 = arith.subi %16, %14 : index + %18 = arith.muli %0, %1 : index + %19 = pto.make_tensor_view %arg0, shape = [%18], strides = [%c1] : !pto.tensor_view + %20 = pto.make_tensor_view %arg1, shape = [%18], strides = [%c1] : !pto.tensor_view + %21 = pto.make_tensor_view %arg2, shape = [%18], strides = [%c1] : !pto.tensor_view + %22 = arith.cmpi sgt, %17, %c0 : index + scf.if %22 { + %23 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %24 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %25 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %26 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %27 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + scf.for %arg5 = %c0 to %17 step %c1 { + %28 = arith.addi %14, %arg5 : index + %29 = arith.muli %28, %1 : index + %30 = pto.partition_view %19, offsets = [%29], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %31 = pto.partition_view %20, offsets = [%29], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %32 = pto.partition_view %21, offsets = [%29], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + pto.tload ins(%30 : !pto.partition_tensor_view<1x16384xf16>) outs(%23 : !pto.tile_buf) + pto.tload ins(%31 : !pto.partition_tensor_view<1x16384xf16>) outs(%24 : !pto.tile_buf) + pto.tsub ins(%23, %23 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.texp ins(%27 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tadd ins(%23, %23 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.texp ins(%26 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tsub ins(%26, %25 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tadd ins(%26, %25 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tdiv ins(%27, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tadd ins(%25, %27 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmul ins(%23, %26 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tadd ins(%25, %25 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tdiv ins(%26, %27 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmul ins(%26, %24 : !pto.tile_buf, !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tstore ins(%26 : !pto.tile_buf) outs(%32 : !pto.partition_tensor_view<1x16384xf16>) + } + } + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu_builder.py new file mode 100644 index 00000000..7d1e88b3 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/geglu_dynamic_multicore/geglu/geglu_builder.py @@ -0,0 +1,179 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# 32 KB of UB / sizeof(fp16) = 16384 elements per tile +ELEMENTS_PER_TILE = 32 * 1024 // 2 + + +def meta_data(): + dtype = pto.float16 + ptr_type = pto.PtrType(dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=dtype) + subtensor_type = pto.SubTensorType(shape=[1, ELEMENTS_PER_TILE], dtype=dtype) + + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + } + + +def build_geglu(fn_name="geglu_fp16"): + """ + Build a dynamic-batch GEGLU kernel in PTO DSL. + + Computes c = gelu_approx(a) * b, where: + gelu_approx(a) = 0.5 * a * (1 + tanh(a)) + tanh(a) = (exp(2a) - 1) / (exp(2a) + 1) + + Constants (1.0, 2.0) are derived from the input tile itself using + the identity exp(a - a) = exp(0) = 1.0, which avoids the need for + scalar-tile broadcast operations not available in PTO DSL. + + UB tile budget (fp16, 5 tiles × 32 KB = 160 KB < 192 KB): + tb_a : input row a + tb_b : input row b + tb_ones : constant 1.0 (recomputed each row via exp(a-a)) + tb_tmp1 : intermediate / final output + tb_tmp2 : intermediate + + Kernel args: + a_ptr : fp16[batch * n_cols] -- gating input + b_ptr : fp16[batch * n_cols] -- linear input + c_ptr : fp16[batch * n_cols] -- output + batch : int32 -- number of rows + n_cols : int32 -- elements per row; must be <= 16384 + """ + + @to_ir_module(meta_data=meta_data) + def _kernel( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile = const(ELEMENTS_PER_TILE) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + # Guard: n_cols must be in (0, ELEMENTS_PER_TILE]. + + with pto.if_context(n_cols > c0): + with pto.if_context(c_tile >= n_cols): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast( + num_blocks * sub_bnum + ) # number of vector cores + + # Distribute rows across cores (row-level parallelism). + rows_per_core = s.ceil_div(batch, num_cores) + row_start = vid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + num_rows = row_end - row_start + + total_elems = batch * n_cols + tv_a = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[total_elems], strides=[c1] + ) + tv_b = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[total_elems], strides=[c1] + ) + tv_c = pto.as_tensor( + tensor_type, ptr=c_ptr, shape=[total_elems], strides=[c1] + ) + + with pto.if_context(num_rows > c0): + # Allocate 5 UB tiles (160 KB total, well under 192 KB UB). + tb_a = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_b = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_ones = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_tmp1 = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_tmp2 = pto.alloc_tile(tile_type, valid_col=n_cols) + + for row_i in pto.range(c0, num_rows, c1): + gm_offset = (row_start + row_i) * n_cols + + sv_a = pto.slice_view( + subtensor_type, + source=tv_a, + offsets=[gm_offset], + sizes=[n_cols], + ) + sv_b = pto.slice_view( + subtensor_type, + source=tv_b, + offsets=[gm_offset], + sizes=[n_cols], + ) + sv_c = pto.slice_view( + subtensor_type, + source=tv_c, + offsets=[gm_offset], + sizes=[n_cols], + ) + + pto.load(sv_a, tb_a) + pto.load(sv_b, tb_b) + + # Derive constants from data (no scalar-tile broadcast needed): + # a - a = 0 => exp(0) = 1.0 + tile.sub(tb_a, tb_a, tb_tmp2) # tmp2 = 0.0 + tile.exp(tb_tmp2, tb_ones) # ones = 1.0 + + # tanh(a) = (exp(2a) - 1) / (exp(2a) + 1) + tile.add(tb_a, tb_a, tb_tmp1) # tmp1 = 2a + tile.exp(tb_tmp1, tb_tmp1) # tmp1 = exp(2a) + tile.sub(tb_tmp1, tb_ones, tb_tmp2) # tmp2 = exp(2a) - 1 + tile.add(tb_tmp1, tb_ones, tb_tmp1) # tmp1 = exp(2a) + 1 + tile.div(tb_tmp2, tb_tmp1, tb_tmp2) # tmp2 = tanh(a) + + # gelu_approx(a) = a * (1 + tanh(a)) / 2 + tile.add(tb_ones, tb_tmp2, tb_tmp1) # tmp1 = 1 + tanh(a) + tile.mul(tb_a, tb_tmp1, tb_tmp1) # tmp1 = a * (1 + tanh(a)) + tile.add(tb_ones, tb_ones, tb_tmp2) # tmp2 = 2.0 + tile.div(tb_tmp1, tb_tmp2, tb_tmp1) # tmp1 = gelu_approx(a) + + # GEGLU: c = gelu_approx(a) * b + tile.mul(tb_tmp1, tb_b, tb_tmp1) # tmp1 = c + pto.store(tb_tmp1, sv_c) + + _ = fn_name + return _kernel + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--fn-name", + default="geglu_fp16", + help="Generated kernel function name.", + ) + args = parser.parse_args() + print(build_geglu(fn_name=args.fn_name)) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/compile.sh new file mode 100644 index 00000000..74d8cc13 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python relu_builder.py > ./relu.pto +ptoas --enable-insert-sync ./relu.pto > generated_relu.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/generated_relu.cpp b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/generated_relu.cpp new file mode 100644 index 00000000..72934456 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/generated_relu.cpp @@ -0,0 +1,97 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void sync_kernel_dyn(__gm__ float* v1, __gm__ float* v2, int32_t v3) { + unsigned v4 = 0; + const int32_t v5 = 0; + const int32_t v6 = 1; + const int32_t v7 = 32; + const int64_t v8 = 0; + const int64_t v9 = 128; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + int64_t v10 = get_block_idx(); + int64_t v11 = get_subblockid(); + int64_t v12 = get_subblockdim(); + int64_t v13 = (int64_t) v12; + int64_t v14 = get_block_num(); + int32_t v15 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v14) * (uint64_t) v13); + int32_t v16 = v3 / v15; + int32_t v17 = v3 % v15 != v5 && v3 < v5 == v15 < v5 ? v16 + v6 : v16; + int32_t v18 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v10) * (uint64_t) v13) + (uint64_t) ((int64_t) v11))) * (uint32_t) v17); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + if (v18 < v3) { + int32_t v19 = (int32_t) ((uint32_t) v18 + (uint32_t) v17); + int32_t v20 = (uint32_t) v19 < (uint32_t) v3 ? v19 : v3; + int32_t v21 = (int32_t) ((uint32_t) v20 - (uint32_t) v18); + int32_t v22 = v21 / v7; + for (size_t v23 = (size_t) v5; v23 < ((size_t) (v21 % v7 != v5 && v21 < v5 == v7 < v5 ? v22 + v6 : v22)); v23 += (size_t) v6) { + int32_t v24 = (int32_t) ((uint32_t) v18 + (uint32_t) ((int32_t) (uint32_t) ((int32_t) v23) * (uint32_t) v7)); + int32_t v25 = (int32_t) ((uint32_t) v20 - (uint32_t) v24); + int32_t v26 = (uint32_t) v25 < (uint32_t) v7 ? v25 : v7; + Tile v27; + TASSIGN(v27, v8); + Tile v28 = Tile(v6, v26); + __ubuf__ float* v29 = v27.data(); + uint64_t v30 = reinterpret_cast(v29); + TASSIGN(v28, v30); + Tile v31; + TASSIGN(v31, v9); + Tile v32 = Tile(v6, v26); + __ubuf__ float* v33 = v31.data(); + uint64_t v34 = reinterpret_cast(v33); + TASSIGN(v32, v34); + pto::Shape<1, 1, 1, 1, 32> v35 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v36 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v37 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v4 + (unsigned) v24 * (unsigned) v6), v35, v36); + pto::Shape<1, 1, 1, 1, 32> v38 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v39 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v40 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v2 + (v4 + (unsigned) v24 * (unsigned) v6), v38, v39); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v28, v37); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TRELU(v32, v28); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_MTE3); + TSTORE(v40, v32); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + }; + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu.pto b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu.pto new file mode 100644 index 00000000..da778db5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu.pto @@ -0,0 +1,45 @@ +module { + func.func @sync_kernel_dyn(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: i32) { + pto.section.vector { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %0 = arith.index_cast %arg2 : i32 to index + %1 = pto.get_block_idx + %2 = pto.get_subblock_idx + %3 = pto.get_subblock_num + %4 = arith.muli %1, %3 : i64 + %5 = arith.addi %4, %2 : i64 + %6 = arith.index_cast %5 : i64 to index + %7 = pto.get_block_num + %8 = arith.muli %7, %3 : i64 + %9 = arith.index_cast %8 : i64 to index + %10 = arith.ceildivsi %0, %9 : index + %11 = arith.muli %6, %10 : index + %12 = pto.make_tensor_view %arg0, shape = [%0], strides = [%c1] : !pto.tensor_view + %13 = pto.make_tensor_view %arg1, shape = [%0], strides = [%c1] : !pto.tensor_view + %14 = arith.cmpi slt, %11, %0 : index + scf.if %14 { + %15 = arith.addi %11, %10 : index + %16 = arith.minui %15, %0 : index + %17 = arith.subi %16, %11 : index + %18 = arith.ceildivsi %17, %c32 : index + scf.for %arg3 = %c0 to %18 step %c1 { + %19 = arith.muli %arg3, %c32 : index + %20 = arith.addi %11, %19 : index + %21 = arith.subi %16, %20 : index + %22 = arith.minui %21, %c32 : index + %23 = pto.alloc_tile valid_row = %c1 valid_col = %22 : !pto.tile_buf + %24 = pto.alloc_tile valid_row = %c1 valid_col = %22 : !pto.tile_buf + %25 = pto.partition_view %12, offsets = [%20], sizes = [%c32] : !pto.tensor_view -> !pto.partition_tensor_view<32xf32> + %26 = pto.partition_view %13, offsets = [%20], sizes = [%c32] : !pto.tensor_view -> !pto.partition_tensor_view<32xf32> + pto.tload ins(%25 : !pto.partition_tensor_view<32xf32>) outs(%23 : !pto.tile_buf) + pto.trelu ins(%23 : !pto.tile_buf) outs(%24 : !pto.tile_buf) + pto.tstore ins(%24 : !pto.tile_buf) outs(%26 : !pto.partition_tensor_view<32xf32>) + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu_builder.py new file mode 100644 index 00000000..0ce2397e --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/activations/relu_dynamic_multicore/relu/relu_builder.py @@ -0,0 +1,103 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +def build(): + tile_w = 32 + + def meta_data(): + dtype = pto.float32 + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + tensor_type = pto.TensorType(rank=1, dtype=dtype) + subtensor_type = pto.SubTensorType(shape=[tile_w], dtype=dtype) + tile_cfg = pto.TileBufConfig() + # Dynamic valid shape so we can mask partial tiles via valid_row/valid_col. + tile_type = pto.TileBufType( + shape=[1, tile_w], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + "tile_w": tile_w, + } + + const = s.const + + @to_ir_module(meta_data=meta_data) + def sync_kernel_dyn( + arg0: "ptr_type", arg1: "ptr_type", argN: "index_dtype" + ) -> None: + with pto.vector_section(): + c0 = const(0) + c1 = const(1) + c_tile_w = const(tile_w) + total_elements = s.index_cast(argN) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + vid = s.index_cast(cid * sub_bnum + sub_bid) + num_blocks = s.index_cast(pto.get_block_num() * sub_bnum) + num_el_per_core = s.ceil_div(total_elements, num_blocks) + + # Per-core range: [core_start, core_end) + core_start = vid * num_el_per_core + + # GM tensors shape N with stride 1. + tv0 = pto.as_tensor( + tensor_type, ptr=arg0, shape=[total_elements], strides=[c1] + ) + tv1 = pto.as_tensor( + tensor_type, ptr=arg1, shape=[total_elements], strides=[c1] + ) + + with pto.if_context(core_start < total_elements): + core_end_unclamped = core_start + num_el_per_core + core_end = s.min_u(core_end_unclamped, total_elements) + core_len = core_end - core_start + + # Per-core number of tiles: ceil(core_len / tile_w). + num_tiles = s.ceil_div(core_len, c_tile_w) + + for i in pto.range(c0, num_tiles, c1): + offset_tile = i * c_tile_w + offset_total = core_start + offset_tile + + remaining_core = core_end - offset_total + valid_len = s.min_u(remaining_core, c_tile_w) + + # Keep per-iteration tile alloc to match original behavior. + tb0 = pto.alloc_tile(tile_type, valid_row=c1, valid_col=valid_len) + tb1 = pto.alloc_tile(tile_type, valid_row=c1, valid_col=valid_len) + + # each core c takes a tile at offset c*num_el_per_core + i*tile_w + sv0 = pto.slice_view( + subtensor_type, + source=tv0, + offsets=[offset_total], + sizes=[c_tile_w], + ) + sv1 = pto.slice_view( + subtensor_type, + source=tv1, + offsets=[offset_total], + sizes=[c_tile_w], + ) + + pto.load(sv0, tb0) + tile.relu(tb0, tb1) + pto.store(tb1, sv1) + + return sync_kernel_dyn + + +if __name__ == "__main__": + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/compile.sh new file mode 100644 index 00000000..011431e4 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./matmul_builder.py > matmul.pto +ptoas matmul.pto -o matmul.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.cpp b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.cpp new file mode 100644 index 00000000..d0248341 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.cpp @@ -0,0 +1,111 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void RunTMATMULSplitK(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, bool v5, int32_t v6) { + unsigned v7 = 0; + const int32_t v8 = 0; + const int32_t v9 = 1; + const int32_t v10 = 128; + const int32_t v11 = 32; + const int32_t v12 = 4; + const int64_t v13 = 16384; + const int64_t v14 = 0; + const int64_t v15 = 32768; + using T = float; + size_t v16 = (size_t) v9; + + #if defined(__DAV_CUBE__) + int32_t v17 = (int32_t) ((uint32_t) v6 * (uint32_t) v10); + int64_t v18 = get_block_num(); + int32_t v19 = (int32_t) ((int64_t) v18); + int32_t v20 = v6 / v19; + int32_t v21 = v6 % v19 != v8 && v6 < v8 == v19 < v8 ? v20 + v9 : v20; + int64_t v22 = get_block_idx(); + int32_t v23 = (int32_t) ((uint32_t) ((int32_t) (int64_t) v22) * (uint32_t) v21); + int32_t v24 = (int32_t) ((uint32_t) v23 + (uint32_t) v21); + Tile v25; + TASSIGN(v25, v13); + Tile v26; + TASSIGN(v26, v14); + Tile v27; + TASSIGN(v27, v15); + Tile v28; + TASSIGN(v28, v14); + Tile v29; + TASSIGN(v29, v14); + Tile v30; + TASSIGN(v30, v14); + Tile v31; + TASSIGN(v31, v14); + for (size_t v32 = (size_t) v23; v32 < ((size_t) ((uint32_t) v24 < (uint32_t) v6 ? v24 : v6)); v32 += v16) { + int32_t v33 = (int32_t) ((uint32_t) ((int32_t) v32) * (uint32_t) v10); + for (size_t v34 = (size_t) v8; v34 < ((size_t) v12); v34 += v16) { + int32_t v35 = (int32_t) v34; + int32_t v36 = (int32_t) ((uint32_t) v35 * (uint32_t) v11); + pto::Shape<1, 1, 1, 128, 32> v37 = pto::Shape<1, 1, 1, 128, 32>(); + pto::Stride<16384, 16384, 16384, 128, 1> v38 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v39 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v2 + (v7 + (unsigned) v33 * (unsigned) v10 + (unsigned) v36 * (unsigned) v9), v37, v38); + pto::Shape<1, 1, 1, 32, 128> v40 = pto::Shape<1, 1, 1, 32, 128>(); + pto::Stride<4096, 4096, 4096, 128, 1> v41 = pto::Stride<4096, 4096, 4096, 128, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 128, 1>, pto::Layout::ND> v42 = GlobalTensor, pto::Stride<4096, 4096, 4096, 128, 1>, pto::Layout::ND>(v3 + (v7 + (unsigned) v36 * (unsigned) v10 + v7 * (unsigned) v9), v40, v41); + pto::Shape<1, 1, 1, 1, 128> v43 = pto::Shape<1, 1, 1, 1, 128>(); + pto::Stride<128, 128, 128, 128, 1> v44 = pto::Stride<128, 128, 128, 128, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND> v45 = GlobalTensor, pto::Stride<128, 128, 128, 128, 1>, pto::Layout::ND>(v4 + (v7 + v7 * (unsigned) v10 + v7 * (unsigned) v9), v43, v44); + TLOAD(v25, v39); + TLOAD(v26, v42); + if (v5) { + TLOAD(v27, v45); + }; + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v28, v25); + TMOV(v29, v26); + if (v5) { + TMOV(v31, v27); + }; + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v35 == v8) { + if (v5) { + TMATMUL_BIAS(v30, v28, v29, v31); + } else { + TMATMUL(v30, v28, v29); + }; + } else { + TMATMUL_ACC(v30, v30, v28, v29); + }; + set_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); + }; + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 128, 128> v46 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v47 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v48 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) v33 * (unsigned) v10 + v7 * (unsigned) v9), v46, v47); + TSTORE(v48, v30); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + } + #endif // __DAV_CUBE__ + + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.pto b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.pto new file mode 100644 index 00000000..954cd09e --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul.pto @@ -0,0 +1,79 @@ +module { + func.func @RunTMATMULSplitK(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: i1, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c128_0 = arith.constant 128 : index + %c128_1 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c128_2 = arith.constant 128 : index + %c128_3 = arith.constant 128 : index + %0 = arith.index_cast %arg5 : i32 to index + %1 = arith.muli %0, %c128 : index + %2 = pto.get_block_num + %3 = arith.index_cast %2 : i64 to index + %4 = arith.ceildivsi %0, %3 : index + %5 = pto.get_block_idx + %6 = arith.index_cast %5 : i64 to index + %7 = arith.muli %6, %4 : index + %8 = arith.addi %7, %4 : index + %9 = arith.minui %8, %0 : index + %10 = pto.make_tensor_view %arg1, shape = [%1, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + %11 = pto.make_tensor_view %arg2, shape = [%c128_0, %c128_1], strides = [%c128_1, %c1] : !pto.tensor_view + %12 = pto.make_tensor_view %arg0, shape = [%1, %c128_1], strides = [%c128_1, %c1] : !pto.tensor_view + %13 = pto.make_tensor_view %arg3, shape = [%c1, %c128_1], strides = [%c128_1, %c1] : !pto.tensor_view + %14 = pto.alloc_tile : !pto.tile_buf + %15 = pto.alloc_tile : !pto.tile_buf + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + scf.for %arg6 = %7 to %9 step %c1 { + %21 = arith.muli %arg6, %c128 : index + scf.for %arg7 = %c0 to %c4 step %c1 { + %23 = arith.muli %arg7, %c32 : index + %24 = pto.partition_view %10, offsets = [%21, %23], sizes = [%c128_2, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<128x32xf32> + %25 = pto.partition_view %11, offsets = [%23, %c0], sizes = [%c32, %c128_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x128xf32> + %26 = pto.partition_view %13, offsets = [%c0, %c0], sizes = [%c1, %c128_3] : !pto.tensor_view -> !pto.partition_tensor_view<1x128xf32> + pto.tload ins(%24 : !pto.partition_tensor_view<128x32xf32>) outs(%14 : !pto.tile_buf) + pto.tload ins(%25 : !pto.partition_tensor_view<32x128xf32>) outs(%15 : !pto.tile_buf) + scf.if %arg4 { + pto.tload ins(%26 : !pto.partition_tensor_view<1x128xf32>) outs(%16 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmov ins(%14 : !pto.tile_buf) outs(%17 : !pto.tile_buf) + pto.tmov ins(%15 : !pto.tile_buf) outs(%18 : !pto.tile_buf) + scf.if %arg4 { + pto.tmov ins(%16 : !pto.tile_buf) outs(%20 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + %27 = arith.cmpi eq, %arg7, %c0 : index + scf.if %27 { + scf.if %arg4 { + pto.tmatmul.bias ins(%17, %18, %20 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%19 : !pto.tile_buf) + } else { + pto.tmatmul ins(%17, %18 : !pto.tile_buf, !pto.tile_buf) outs(%19 : !pto.tile_buf) + } + } else { + pto.tmatmul.acc ins(%19, %17, %18 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%19 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + %22 = pto.partition_view %12, offsets = [%21, %c0], sizes = [%c128_2, %c128_3] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%19 : !pto.tile_buf) outs(%22 : !pto.partition_tensor_view<128x128xf32>) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul_builder.py new file mode 100644 index 00000000..28015228 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore/matmul/matmul_builder.py @@ -0,0 +1,193 @@ +from mlir.ir import IntegerType + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +def build( + M=128, + K=128, + N=128, + validM=128, + validK=128, + validN=128, + BASEK=32, +): + assert K % BASEK == 0 + iters = K // BASEK + + def meta_data(): + dtype = pto.float32 + ptr_dtype = pto.PtrType(dtype) + i1 = IntegerType.get_signless(1) + i32 = pto.int32 + + tensor_type = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M, BASEK], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[BASEK, N], dtype=dtype) + tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) + tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) + + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") + tile_buf_biasTile = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="BIAS" + ) + + return { + "ptr_type": ptr_dtype, + "i1": i1, + "i32": i32, + "tensor_type": tensor_type, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_out": tile_view_out, + "tile_view_bias": tile_view_bias, + "tile_buf_aMat": tile_buf_aMat, + "tile_buf_bMat": tile_buf_bMat, + "tile_buf_biasData": tile_buf_biasData, + "tile_buf_aTile": tile_buf_aTile, + "tile_buf_bTile": tile_buf_bTile, + "tile_buf_cTile": tile_buf_cTile, + "tile_buf_biasTile": tile_buf_biasTile, + } + + const = s.const + + @to_ir_module(meta_data=meta_data) + def RunTMATMULSplitK( + out_ptr: "ptr_type", + a_ptr: "ptr_type", + b_ptr: "ptr_type", + bias_ptr: "ptr_type", + isBias: "i1", + batch_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + cM = const(validM) + cK = const(validK) + cN = const(validN) + cBASEK = const(BASEK) + cIter = const(iters) + cTileM = const(M) + cTileN = const(N) + + batch = s.index_cast(batch_i32) + cBM = batch * cM + + num_blocks = s.index_cast(pto.get_block_num()) + batches_per_core = s.ceil_div(batch, num_blocks) + bid = s.index_cast(pto.get_block_idx()) + b_start = bid * batches_per_core + b_end_unclamped = b_start + batches_per_core + b_end = s.min_u(b_end_unclamped, batch) + + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) + + aMatTile = pto.alloc_tile(tile_buf_aMat) + bMatTile = pto.alloc_tile(tile_buf_bMat) + biasDataTile = pto.alloc_tile(tile_buf_biasData) + aTile = pto.alloc_tile(tile_buf_aTile) + bTile = pto.alloc_tile(tile_buf_bTile) + cTile = pto.alloc_tile(tile_buf_cTile) + biasTile = pto.alloc_tile(tile_buf_biasTile) + + for b_idx in pto.range(b_start, b_end, c1): + row_off = b_idx * cM + + for i in pto.range(c0, cIter, c1): + kOff = i * cBASEK + svA = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[row_off, kOff], + sizes=[cTileM, cBASEK], + ) + svB = pto.slice_view( + tile_view_b, + source=tvB, + offsets=[kOff, c0], + sizes=[cBASEK, cTileN], + ) + svBias = pto.slice_view( + tile_view_bias, + source=tvBias, + offsets=[c0, c0], + sizes=[c1, cTileN], + ) + + pto.load(svA, aMatTile) + pto.load(svB, bMatTile) + with pto.if_context(isBias): + pto.load(svBias, biasDataTile) + + pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) + + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) + with pto.if_context(isBias): + tile.mov(biasDataTile, biasTile) + + pto.record_wait_pair("MOV_M2L", "MATMUL", event_id=0) + + is_i0 = s.eq(i, c0) + + def _first_iter(): + pto.cond( + isBias, + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), + ) + + pto.cond( + is_i0, + _first_iter, + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), + ) + + pto.record_wait_pair("MATMUL", "LOAD", event_id=0) + + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + svOut = pto.slice_view( + tile_view_out, + source=tvOut, + offsets=[row_off, c0], + sizes=[cTileM, cTileN], + ) + pto.store(cTile, svOut) + pto.record_wait_pair("STORE_ACC", "MATMUL", event_id=0) + + return RunTMATMULSplitK + + +if __name__ == "__main__": + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/compile.sh new file mode 100644 index 00000000..011431e4 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./matmul_builder.py > matmul.pto +ptoas matmul.pto -o matmul.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.cpp b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.cpp new file mode 100644 index 00000000..8d995638 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.cpp @@ -0,0 +1,85 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void RunTMATMULSplitK(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, bool v5, int32_t v6) { + unsigned v7 = 0; + const int32_t v8 = 0; + const int32_t v9 = 1; + const int32_t v10 = 128; + const int32_t v11 = 16384; + const int64_t v12 = 65536; + const int64_t v13 = 0; + using T = float; + + #if defined(__DAV_CUBE__) + int64_t v14 = get_block_num(); + int32_t v15 = (int32_t) ((int64_t) v14); + int64_t v16 = get_block_idx(); + int32_t v17 = (int32_t) ((int64_t) v16); + int32_t v18 = v6 / v15; + int32_t v19 = v6 % v15; + int32_t v20 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v17 * (uint32_t) v18) + (uint32_t) ((uint32_t) v17 < (uint32_t) v19 ? v17 : v19)); + int32_t v21 = (int32_t) ((uint32_t) v20 + (uint32_t) ((int32_t) (uint32_t) v18 + (uint32_t) (v17 < v19 ? v9 : v8))); + Tile v22; + TASSIGN(v22, v12); + Tile v23; + TASSIGN(v23, v13); + Tile v24; + TASSIGN(v24, v13); + Tile v25; + TASSIGN(v25, v13); + Tile v26; + TASSIGN(v26, v13); + pto::Shape<1, 1, 1, 128, 128> v27 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v28 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v29 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v3 + (v7 + v7 * (unsigned) v10 + v7 * (unsigned) v9), v27, v28); + TLOAD(v23, v29); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v25, v23); + for (size_t v30 = (size_t) v20; v30 < ((size_t) ((uint32_t) v21 < (uint32_t) v6 ? v21 : v6)); v30 += (size_t) v9) { + int32_t v31 = (int32_t) v30; + pto::Shape<1, 1, 1, 128, 128> v32 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v33 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v34 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v2 + ((v7 + (unsigned) v31 * (unsigned) v11) + v7 * (unsigned) v10 + v7 * (unsigned) v9), v32, v33); + pto::Shape<1, 1, 1, 128, 128> v35 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v36 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v37 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v1 + ((v7 + (unsigned) v31 * (unsigned) v11) + v7 * (unsigned) v10 + v7 * (unsigned) v9), v35, v36); + TLOAD(v22, v34); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v24, v22); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v26, v24, v25); + set_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v37, v26); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + } + #endif // __DAV_CUBE__ + + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.pto b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.pto new file mode 100644 index 00000000..7f9dffe3 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul.pto @@ -0,0 +1,63 @@ +module { + func.func @RunTMATMULSplitK(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: i1, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c128_0 = arith.constant 128 : index + %c128_1 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + %c16384_2 = arith.constant 16384 : index + %c128_3 = arith.constant 128 : index + %c128_4 = arith.constant 128 : index + %0 = arith.index_cast %arg5 : i32 to index + %1 = pto.get_block_num + %2 = arith.index_cast %1 : i64 to index + %3 = pto.get_block_idx + %4 = arith.index_cast %3 : i64 to index + %5 = arith.divsi %0, %2 : index + %6 = arith.remsi %0, %2 : index + %7 = arith.cmpi slt, %4, %6 : index + %8 = arith.minui %4, %6 : index + %9 = arith.muli %4, %5 : index + %10 = arith.addi %9, %8 : index + %11 = arith.select %7, %c1, %c0 : index + %12 = arith.addi %5, %11 : index + %13 = arith.addi %10, %12 : index + %14 = arith.minui %13, %0 : index + %15 = pto.make_tensor_view %arg1, shape = [%0, %c128, %c128_0], strides = [%c16384, %c128_0, %c1] : !pto.tensor_view + %16 = pto.make_tensor_view %arg2, shape = [%c128_0, %c128_1], strides = [%c128_1, %c1] : !pto.tensor_view + %17 = pto.make_tensor_view %arg0, shape = [%0, %c128, %c128_1], strides = [%c16384_2, %c128_1, %c1] : !pto.tensor_view + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.partition_view %16, offsets = [%c0, %c0], sizes = [%c128_0, %c128_4] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf32> + pto.tload ins(%23 : !pto.partition_tensor_view<128x128xf32>) outs(%19 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmov ins(%19 : !pto.tile_buf) outs(%21 : !pto.tile_buf) + scf.for %arg6 = %10 to %14 step %c1 { + %24 = pto.partition_view %15, offsets = [%arg6, %c0, %c0], sizes = [%c1, %c128_3, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf32> + %25 = pto.partition_view %17, offsets = [%arg6, %c0, %c0], sizes = [%c1, %c128_3, %c128_4] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf32> + pto.tload ins(%24 : !pto.partition_tensor_view<128x128xf32>) outs(%18 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmov ins(%18 : !pto.tile_buf) outs(%20 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul ins(%20, %21 : !pto.tile_buf, !pto.tile_buf) outs(%22 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%22 : !pto.tile_buf) outs(%25 : !pto.partition_tensor_view<128x128xf32>) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul_builder.py new file mode 100644 index 00000000..b711f4d6 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/batch_matmul/matmul_dynbatch_multicore_opt/matmul/matmul_builder.py @@ -0,0 +1,147 @@ +# adapted from https://github.com/zhangstevenunity/PTOAS/blob/a301aa43b388d9b2e1ba0db8773b3a719e8c445b/test/samples/MatMul/tmatmulk.py + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +def build( + M=128, + K=128, + N=128, + validM=128, + validK=128, + validN=128, +): + def meta_data(): + dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + bool_type = pto.bool + index_dtype = pto.int32 + + tv_a = pto.TensorType(rank=3, dtype=dtype) + tv_b = pto.TensorType(rank=2, dtype=dtype) + tv_out = pto.TensorType(rank=3, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M, K], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K, N], dtype=dtype) + tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) + + tile_buf_aMat = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="MAT") + tile_buf_bMat = pto.TileBufType(shape=[K, N], dtype=dtype, memory_space="MAT") + tile_buf_aTile = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="LEFT") + tile_buf_bTile = pto.TileBufType( + shape=[K, N], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") + + return { + "ptr_type": ptr_type, + "bool_type": bool_type, + "index_dtype": index_dtype, + "tv_a": tv_a, + "tv_b": tv_b, + "tv_out": tv_out, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_out": tile_view_out, + "tile_buf_aMat": tile_buf_aMat, + "tile_buf_bMat": tile_buf_bMat, + "tile_buf_aTile": tile_buf_aTile, + "tile_buf_bTile": tile_buf_bTile, + "tile_buf_cTile": tile_buf_cTile, + } + + const = s.const + + @to_ir_module(meta_data=meta_data) + def RunTMATMULSplitK( + out_ptr: "ptr_type", + a_ptr: "ptr_type", + b_ptr: "ptr_type", + bias_ptr: "ptr_type", + isBias: "bool_type", + batch_i32: "index_dtype", + ) -> None: + # Keep unused args to preserve original function signature/ABI. + _ = bias_ptr + _ = isBias + + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + cM = const(validM) + cK = const(validK) + cN = const(validN) + cKM = const(validK * validM) + cMN = const(validM * validN) + cTileM = const(M) + cTileN = const(N) + + batch = s.index_cast(batch_i32) + + # Distribute batches over cores with "base + remainder" policy. + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + base = batch // num_blocks + rem = batch % num_blocks + lt_rem = s.lt(bid, rem) + min_bid_rem = s.min_u(bid, rem) + b_start = bid * base + min_bid_rem + length = base + s.select(lt_rem, c1, c0) + b_end = s.min_u(b_start + length, batch) + + tvA = pto.as_tensor( + tv_a, ptr=a_ptr, shape=[batch, cM, cK], strides=[cKM, cK, c1] + ) + tvB = pto.as_tensor(tv_b, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) + tvOut = pto.as_tensor( + tv_out, ptr=out_ptr, shape=[batch, cM, cN], strides=[cMN, cN, c1] + ) + + aMatTile = pto.alloc_tile(tile_buf_aMat) + bMatTile = pto.alloc_tile(tile_buf_bMat) + aTile = pto.alloc_tile(tile_buf_aTile) + bTile = pto.alloc_tile(tile_buf_bTile) + cTile = pto.alloc_tile(tile_buf_cTile) + + # B is shared across batches: load once GM->L1->L0B. + svB = pto.slice_view( + tile_view_b, source=tvB, offsets=[c0, c0], sizes=[cK, cTileN] + ) + pto.load(svB, bMatTile) + pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) + tile.mov(bMatTile, bTile) + + for b_idx in pto.range(b_start, b_end, c1): + svA = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[b_idx, c0, c0], + sizes=[c1, cTileM, cK], + ) + svOut = pto.slice_view( + tile_view_out, + source=tvOut, + offsets=[b_idx, c0, c0], + sizes=[c1, cTileM, cTileN], + ) + + pto.load(svA, aMatTile) + pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) + + tile.mov(aMatTile, aTile) + pto.record_wait_pair("MOV_M2L", "MATMUL", event_id=0) + tile.matmul(aTile, bTile, cTile) + pto.record_wait_pair("MATMUL", "LOAD", event_id=0) + + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + pto.store(cTile, svOut) + pto.record_wait_pair("STORE_ACC", "MATMUL", event_id=0) + + return RunTMATMULSplitK + + +if __name__ == "__main__": + m = build() + print(m) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.cpp b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.cpp new file mode 100644 index 00000000..46b9c50f --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.cpp @@ -0,0 +1,95 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void vec_add_1d_dynamic(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4) { + unsigned v5 = 0; + const int32_t v6 = 8192; + const int32_t v7 = 1; + const int32_t v8 = 0; + const int64_t v9 = 65536; + const int64_t v10 = 0; + const int64_t v11 = 32768; + using T = float; + int64_t v12 = get_block_idx(); + int64_t v13 = get_subblockid(); + int64_t v14 = get_subblockdim(); + int64_t v15 = (int64_t) v14; + int64_t v16 = get_block_num(); + int32_t v17 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v16) * (uint64_t) v15); + int32_t v18 = v4 / v6; + int32_t v19 = v4 % v6 != v8 && v4 < v8 == v6 < v8 ? v18 + v7 : v18; + int32_t v20 = v19 / v17; + int32_t v21 = v19 % v17 != v8 && v19 < v8 == v17 < v8 ? v20 + v7 : v20; + int32_t v22 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v12) * (uint64_t) v15) + (uint64_t) ((int64_t) v13))) * (uint32_t) v21); + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v23; + TASSIGN(v23, v9); + Tile v24; + TASSIGN(v24, v10); + Tile v25; + TASSIGN(v25, v11); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + if (v22 < v19) { + int32_t v26 = (int32_t) ((uint32_t) v22 + (uint32_t) v21) > v19 ? (int32_t) ((uint32_t) v19 - (uint32_t) v22) : v21; + if ((int32_t) ((uint32_t) v26 * (uint32_t) v6) > v8) { + for (size_t v27 = (size_t) v8; v27 < ((size_t) v26); v27 += (size_t) v7) { + int32_t v28 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) ((int32_t) v27) + (uint32_t) v22) * (uint32_t) v6); + pto::Shape<1, 1, 1, 1, 8192> v29 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v30 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v31 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v1 + (v5 + (unsigned) v28 * (unsigned) v7), v29, v30); + pto::Shape<1, 1, 1, 1, 8192> v32 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v33 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v34 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v2 + (v5 + (unsigned) v28 * (unsigned) v7), v32, v33); + pto::Shape<1, 1, 1, 1, 8192> v35 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v36 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v37 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v3 + (v5 + (unsigned) v28 * (unsigned) v7), v35, v36); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v23, v31); + TLOAD(v24, v34); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TADD(v25, v23, v24); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_MTE3); + TSTORE(v37, v25); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + }; + }; + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.pto b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.pto new file mode 100644 index 00000000..0bd3223f --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add.pto @@ -0,0 +1,52 @@ +module { + func.func @vec_add_1d_dynamic(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8192 = arith.constant 8192 : index + %0 = pto.get_block_idx + %1 = pto.get_subblock_idx + %2 = pto.get_subblock_num + %3 = pto.get_block_num + %4 = arith.muli %0, %2 : i64 + %5 = arith.addi %4, %1 : i64 + %6 = arith.index_cast %5 : i64 to index + %7 = arith.muli %3, %2 : i64 + %8 = arith.index_cast %7 : i64 to index + %9 = arith.index_cast %arg3 : i32 to index + %10 = arith.ceildivsi %9, %c8192 : index + %11 = arith.ceildivsi %10, %8 : index + %12 = arith.muli %6, %11 : index + pto.section.vector { + %13 = pto.make_tensor_view %arg0, shape = [%9], strides = [%c1] : !pto.tensor_view + %14 = pto.make_tensor_view %arg1, shape = [%9], strides = [%c1] : !pto.tensor_view + %15 = pto.make_tensor_view %arg2, shape = [%9], strides = [%c1] : !pto.tensor_view + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = arith.cmpi slt, %12, %10 : index + scf.if %19 { + %20 = arith.addi %12, %11 : index + %21 = arith.cmpi sgt, %20, %10 : index + %22 = arith.subi %10, %12 : index + %23 = arith.select %21, %22, %11 : index + %24 = arith.muli %23, %c8192 : index + %25 = arith.cmpi sgt, %24, %c0 : index + scf.if %25 { + scf.for %arg4 = %c0 to %23 step %c1 { + %26 = arith.addi %arg4, %12 : index + %27 = arith.muli %26, %c8192 : index + %28 = pto.partition_view %13, offsets = [%27], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + %29 = pto.partition_view %14, offsets = [%27], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + %30 = pto.partition_view %15, offsets = [%27], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + pto.tload ins(%28 : !pto.partition_tensor_view<1x8192xf32>) outs(%16 : !pto.tile_buf) + pto.tload ins(%29 : !pto.partition_tensor_view<1x8192xf32>) outs(%17 : !pto.tile_buf) + pto.tadd ins(%16, %17 : !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) + pto.tstore ins(%18 : !pto.tile_buf) outs(%30 : !pto.partition_tensor_view<1x8192xf32>) + } + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add_builder.py new file mode 100644 index 00000000..c4a67865 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/add_builder.py @@ -0,0 +1,109 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + dtype = pto.float32 + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + tensor_type = pto.TensorType(rank=1, dtype=dtype) + tile_length = 8192 # >=16 KB DMA gets high BW util + subtensor_type = pto.SubTensorType(shape=[1, tile_length], dtype=dtype) + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, tile_length], + valid_shape=[1, tile_length], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + "tile_length": tile_length, + } + + +@to_ir_module(meta_data=meta_data) +def vec_add_1d_dynamic( + arg0: "ptr_type", + arg1: "ptr_type", + arg2: "ptr_type", + argN: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c_tile = const(tile_length) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + # Convert i64/i32 values to index for arithmetic ops. + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + total_elements = s.index_cast(argN) + + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid * num_tiles_per_core + + with pto.vector_section(): + tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[total_elements], strides=[c1]) + tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[total_elements], strides=[c1]) + tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[total_elements], strides=[c1]) + + tb0 = pto.alloc_tile(tile_type) + tb1 = pto.alloc_tile(tile_type) + tb2 = pto.alloc_tile(tile_type) + + # Skip whole core if its starting tile is already out-of-bound. + with pto.if_context(tile_offset_this_core < num_tiles_global): + tiles_end_this_core = tile_offset_this_core + num_tiles_per_core + need_truncate = tiles_end_this_core > num_tiles_global + remaining_tiles = num_tiles_global - tile_offset_this_core + + tiles_to_process = s.select( + need_truncate, remaining_tiles, num_tiles_per_core + ) + + elements_to_process = tiles_to_process * c_tile + with pto.if_context(elements_to_process > c0): + for i in pto.range(c0, tiles_to_process, c1): + tile_offset_global = i + tile_offset_this_core + offset_global = tile_offset_global * c_tile + + sv0 = pto.slice_view( + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], + ) + sv1 = pto.slice_view( + subtensor_type, + source=tv1, + offsets=[offset_global], + sizes=[c_tile], + ) + sv2 = pto.slice_view( + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], + ) + + pto.load(sv0, tb0) + pto.load(sv1, tb1) + tile.add(tb0, tb1, tb2) + pto.store(tb2, sv2) + + +if __name__ == "__main__": + module = vec_add_1d_dynamic + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/compile.sh new file mode 100644 index 00000000..2eacc832 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./add_builder.py > ./add.pto +ptoas --enable-insert-sync ./add.pto -o ./add.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.cpp b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.cpp new file mode 100644 index 00000000..c1a30590 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.cpp @@ -0,0 +1,130 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void vec_add_1d_dynamic(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4) { + unsigned v5 = 0; + const int32_t v6 = 8192; + const int32_t v7 = 2; + const int32_t v8 = 1; + const int32_t v9 = 0; + const int64_t v10 = 98304; + const int64_t v11 = 32768; + const int64_t v12 = 131072; + const int64_t v13 = 0; + const int64_t v14 = 65536; + const int64_t v15 = 163840; + using T = float; + int64_t v16 = get_block_idx(); + int64_t v17 = get_subblockid(); + int64_t v18 = get_subblockdim(); + int64_t v19 = (int64_t) v18; + int64_t v20 = get_block_num(); + int32_t v21 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v20) * (uint64_t) v19); + int32_t v22 = v4 / v6; + int32_t v23 = v4 % v6 != v9 && v4 < v9 == v6 < v9 ? v22 + v8 : v22; + int32_t v24 = v23 / v21; + int32_t v25 = v23 % v21 != v9 && v23 < v9 == v21 < v9 ? v24 + v8 : v24; + int32_t v26 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v16) * (uint64_t) v19) + (uint64_t) ((int64_t) v17))) * (uint32_t) v25); + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v27; + TASSIGN(v27, v10); + Tile v28; + TASSIGN(v28, v11); + Tile v29; + TASSIGN(v29, v12); + Tile v30; + TASSIGN(v30, v13); + Tile v31; + TASSIGN(v31, v14); + Tile v32; + TASSIGN(v32, v15); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + if (v26 < v23) { + int32_t v33 = (int32_t) ((uint32_t) v26 + (uint32_t) v25) > v23 ? (int32_t) ((uint32_t) v23 - (uint32_t) v26) : v25; + if ((int32_t) ((uint32_t) v33 * (uint32_t) v6) > v9) { + for (size_t v34 = (size_t) v9; v34 < ((size_t) v33); v34 += (size_t) v8) { + int32_t v35 = (int32_t) v34; + int32_t v36 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v35 + (uint32_t) v26) * (uint32_t) v6); + pto::Shape<1, 1, 1, 1, 8192> v37 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v38 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v39 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v1 + (v5 + (unsigned) v36 * (unsigned) v8), v37, v38); + pto::Shape<1, 1, 1, 1, 8192> v40 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v41 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v42 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v2 + (v5 + (unsigned) v36 * (unsigned) v8), v40, v41); + pto::Shape<1, 1, 1, 1, 8192> v43 = pto::Shape<1, 1, 1, 1, 8192>(); + pto::Stride<8192, 8192, 8192, 8192, 1> v44 = pto::Stride<8192, 8192, 8192, 8192, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND> v45 = GlobalTensor, pto::Stride<8192, 8192, 8192, 8192, 1>, pto::Layout::ND>(v3 + (v5 + (unsigned) v36 * (unsigned) v8), v43, v44); + if (v35 % v7 == v9) { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v27, v39); + TLOAD(v28, v42); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TADD(v29, v27, v28); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_MTE3); + TSTORE(v45, v29); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } else { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v30, v39); + TLOAD(v31, v42); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TADD(v32, v30, v31); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pipe_barrier(PIPE_MTE3); + TSTORE(v45, v32); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + }; + }; + }; + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.pto b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.pto new file mode 100644 index 00000000..e03a9e11 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double.pto @@ -0,0 +1,65 @@ +module { + func.func @vec_add_1d_dynamic(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8192 = arith.constant 8192 : index + %0 = pto.get_block_idx + %1 = pto.get_subblock_idx + %2 = pto.get_subblock_num + %3 = pto.get_block_num + %4 = arith.muli %0, %2 : i64 + %5 = arith.addi %4, %1 : i64 + %6 = arith.index_cast %5 : i64 to index + %7 = arith.muli %3, %2 : i64 + %8 = arith.index_cast %7 : i64 to index + %9 = arith.index_cast %arg3 : i32 to index + %10 = arith.ceildivsi %9, %c8192 : index + %11 = arith.ceildivsi %10, %8 : index + %12 = arith.muli %6, %11 : index + pto.section.vector { + %13 = pto.make_tensor_view %arg0, shape = [%9], strides = [%c1] : !pto.tensor_view + %14 = pto.make_tensor_view %arg1, shape = [%9], strides = [%c1] : !pto.tensor_view + %15 = pto.make_tensor_view %arg2, shape = [%9], strides = [%c1] : !pto.tensor_view + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = arith.cmpi slt, %12, %10 : index + scf.if %22 { + %23 = arith.addi %12, %11 : index + %24 = arith.cmpi sgt, %23, %10 : index + %25 = arith.subi %10, %12 : index + %26 = arith.select %24, %25, %11 : index + %27 = arith.muli %26, %c8192 : index + %28 = arith.cmpi sgt, %27, %c0 : index + scf.if %28 { + scf.for %arg4 = %c0 to %26 step %c1 { + %29 = arith.addi %arg4, %12 : index + %30 = arith.muli %29, %c8192 : index + %31 = pto.partition_view %13, offsets = [%30], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + %32 = pto.partition_view %14, offsets = [%30], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + %33 = pto.partition_view %15, offsets = [%30], sizes = [%c8192] : !pto.tensor_view -> !pto.partition_tensor_view<1x8192xf32> + %34 = arith.remsi %arg4, %c2 : index + %35 = arith.cmpi eq, %34, %c0 : index + scf.if %35 { + pto.tload ins(%31 : !pto.partition_tensor_view<1x8192xf32>) outs(%16 : !pto.tile_buf) + pto.tload ins(%32 : !pto.partition_tensor_view<1x8192xf32>) outs(%17 : !pto.tile_buf) + pto.tadd ins(%16, %17 : !pto.tile_buf, !pto.tile_buf) outs(%18 : !pto.tile_buf) + pto.tstore ins(%18 : !pto.tile_buf) outs(%33 : !pto.partition_tensor_view<1x8192xf32>) + } else { + pto.tload ins(%31 : !pto.partition_tensor_view<1x8192xf32>) outs(%19 : !pto.tile_buf) + pto.tload ins(%32 : !pto.partition_tensor_view<1x8192xf32>) outs(%20 : !pto.tile_buf) + pto.tadd ins(%19, %20 : !pto.tile_buf, !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tstore ins(%21 : !pto.tile_buf) outs(%33 : !pto.partition_tensor_view<1x8192xf32>) + } + } + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double_builder.py new file mode 100644 index 00000000..022a18bd --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/add_double_builder.py @@ -0,0 +1,119 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + dtype = pto.float32 + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + tensor_type = pto.TensorType(rank=1, dtype=dtype) + tile_length = 8192 # >=16 KB DMA gets high BW util + subtensor_type = pto.SubTensorType(shape=[1, tile_length], dtype=dtype) + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, tile_length], + valid_shape=[1, tile_length], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + "tile_length": tile_length, + } + + +@to_ir_module(meta_data=meta_data) +def vec_add_1d_dynamic( + arg0: "ptr_type", + arg1: "ptr_type", + arg2: "ptr_type", + argN: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + c_tile = const(tile_length) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + # Convert i64/i32 values to index for arithmetic ops. + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + total_elements = s.index_cast(argN) + + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid * num_tiles_per_core + + with pto.vector_section(): + tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[total_elements], strides=[c1]) + tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[total_elements], strides=[c1]) + tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[total_elements], strides=[c1]) + + # Ping/pong tile buffers for software pipelining. + tb0_ping = pto.alloc_tile(tile_type) + tb1_ping = pto.alloc_tile(tile_type) + tb2_ping = pto.alloc_tile(tile_type) + tb0_pong = pto.alloc_tile(tile_type) + tb1_pong = pto.alloc_tile(tile_type) + tb2_pong = pto.alloc_tile(tile_type) + + # Skip whole core if its starting tile is already out-of-bound. + with pto.if_context(tile_offset_this_core < num_tiles_global): + tiles_end_this_core = tile_offset_this_core + num_tiles_per_core + need_truncate = tiles_end_this_core > num_tiles_global + remaining_tiles = num_tiles_global - tile_offset_this_core + + tiles_to_process = s.select( + need_truncate, remaining_tiles, num_tiles_per_core + ) + + elements_to_process = tiles_to_process * c_tile + with pto.if_context(elements_to_process > c0): + for i in pto.range(c0, tiles_to_process, c1): + tile_offset_global = i + tile_offset_this_core + offset_global = tile_offset_global * c_tile + + sv0 = pto.slice_view( + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], + ) + sv1 = pto.slice_view( + subtensor_type, + source=tv1, + offsets=[offset_global], + sizes=[c_tile], + ) + sv2 = pto.slice_view( + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], + ) + with pto.if_context((i % c2) == c0, has_else=True) as branch: + pto.load(sv0, tb0_ping) + pto.load(sv1, tb1_ping) + tile.add(tb0_ping, tb1_ping, tb2_ping) + pto.store(tb2_ping, sv2) + with branch.else_context(): + pto.load(sv0, tb0_pong) + pto.load(sv1, tb1_pong) + tile.add(tb0_pong, tb1_pong, tb2_pong) + pto.store(tb2_pong, sv2) + + +if __name__ == "__main__": + module = vec_add_1d_dynamic + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/compile.sh new file mode 100644 index 00000000..d4953a66 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/elementwise/add_dynamic_multicore/add_double/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./add_double_builder.py > ./add_double.pto +ptoas --enable-insert-sync ./add_double.pto -o ./add_double.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/compile.sh new file mode 100644 index 00000000..76d6cb80 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./hadamard_builder.py > ./hadamard_auto_sync.pto +ptoas --enable-insert-sync ./hadamard_auto_sync.pto -o ./hadamard_auto_sync.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.cpp b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.cpp new file mode 100644 index 00000000..d72fde97 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.cpp @@ -0,0 +1,199 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void fast_hadamard_autosync(__gm__ half* v1, int32_t v2, int32_t v3, int32_t v4) { + unsigned v5 = 16384; + unsigned v6 = 1; + unsigned v7 = 0; + const int32_t v8 = 2; + const int32_t v9 = 1; + const int32_t v10 = 0; + const int32_t v11 = 8192; + const int64_t v12 = 0; + const int64_t v13 = 32768; + const int64_t v14 = 49152; + const int64_t v15 = 65536; + const int64_t v16 = 98304; + const int64_t v17 = 114688; + using T = float; + size_t v18 = (size_t) v10; + size_t v19 = (size_t) v9; + size_t v20 = (size_t) v4; + int64_t v21 = get_block_idx(); + int64_t v22 = get_subblockid(); + int64_t v23 = get_subblockdim(); + int64_t v24 = (int64_t) v23; + int64_t v25 = get_block_num(); + int32_t v26 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v25) * (uint64_t) v24); + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + int32_t v27 = v2 / v26; + int32_t v28 = v2 % v26 != v10 && v2 < v10 == v26 < v10 ? v27 + v9 : v27; + int32_t v29 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v21) * (uint64_t) v24) + (uint64_t) ((int64_t) v22))) * (uint32_t) v28); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID5); + if (v29 < v2) { + int32_t v30 = (int32_t) ((uint32_t) v29 + (uint32_t) v28) > v2 ? (int32_t) ((uint32_t) v2 - (uint32_t) v29) : v28; + if (v30 > v10) { + int32_t v31 = (int32_t) ((uint32_t) v2 * (uint32_t) v3); + Tile v32; + TASSIGN(v32, v12); + Tile v33 = Tile(v3); + __ubuf__ half* v34 = v32.data(); + uint64_t v35 = reinterpret_cast(v34); + TASSIGN(v33, v35); + int32_t v36 = v3 / v8; + Tile v37; + TASSIGN(v37, v13); + Tile v38 = Tile(v36); + __ubuf__ half* v39 = v37.data(); + uint64_t v40 = reinterpret_cast(v39); + TASSIGN(v38, v40); + Tile v41; + TASSIGN(v41, v14); + Tile v42 = Tile(v36); + __ubuf__ half* v43 = v41.data(); + uint64_t v44 = reinterpret_cast(v43); + TASSIGN(v42, v44); + Tile v45; + TASSIGN(v45, v15); + Tile v46 = Tile(v3); + __ubuf__ half* v47 = v45.data(); + uint64_t v48 = reinterpret_cast(v47); + TASSIGN(v46, v48); + Tile v49; + TASSIGN(v49, v16); + Tile v50 = Tile(v36); + __ubuf__ half* v51 = v49.data(); + uint64_t v52 = reinterpret_cast(v51); + TASSIGN(v50, v52); + Tile v53; + TASSIGN(v53, v17); + Tile v54 = Tile(v36); + __ubuf__ half* v55 = v53.data(); + uint64_t v56 = reinterpret_cast(v55); + TASSIGN(v54, v56); + for (size_t v57 = v18; v57 < ((size_t) v30); v57 += v19) { + int32_t v58 = (int32_t) v57; + int32_t v59 = (int32_t) ((uint32_t) v30 - (uint32_t) v58); + int32_t v60 = v59 < v9 ? v59 : v9; + size_t v61 = (size_t) v60; + if (v60 > v10) { + int32_t v62 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v29 + (uint32_t) v58) * (uint32_t) v3); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); + if (v58 % v8 == v10) { + for (size_t v63 = v18; v63 < v61; v63 += v19) { + unsigned v64 = (unsigned) v3 * v6; + pto::Shape<1, 1, 1, 1, -1> v65 = pto::Shape<1, 1, 1, 1, -1>(v3); + pto::Stride<-1, -1, -1, -1, 1> v66 = pto::Stride<-1, -1, -1, -1, 1>(v64, v64, v64, v64); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v67 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) ((int32_t) (uint32_t) ((int32_t) v63) * (uint32_t) v3)) * (unsigned) v9), v65, v66); + __ubuf__ half* v68 = v33.data(); + int64_t v69 = (int64_t) v3; + int32_t v70 = (int32_t) ((int64_t) (uint64_t) v69 - (uint64_t) ((int64_t) (uint64_t) v12 % (uint64_t) v69)); + Tile v71 = Tile(v70 < v11 ? v70 : v11); + uint64_t v72 = reinterpret_cast((__ubuf__ half*) (v68 + (v7 + v7 * v5) + v7 * v6)); + TASSIGN(v71, v72); + __ubuf__ half* v73 = v33.data(); + int32_t v74 = (int32_t) ((int64_t) (uint64_t) v69 - (uint64_t) ((int64_t) (uint64_t) ((int64_t) v36) % (uint64_t) v69)); + Tile v75 = Tile(v74 < v11 ? v74 : v11); + uint64_t v76 = reinterpret_cast((__ubuf__ half*) (v73 + (v7 + v7 * v5) + (unsigned) v36 * v6)); + TASSIGN(v75, v76); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID2); + TLOAD(v33, v67); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + for (size_t v77 = v18; v77 < v20; v77 += v19) { + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v38, v33); + TGATHER, Tile, MaskPattern::P1010>(v42, v33); + pipe_barrier(PIPE_V); + TADD(v71, v38, v42); + pipe_barrier(PIPE_V); + TSUB(v75, v38, v42); + }; + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_MTE3); + TSTORE(v67, v33); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID2); + }; + } else { + for (size_t v78 = v18; v78 < v61; v78 += v19) { + unsigned v79 = (unsigned) v3 * v6; + pto::Shape<1, 1, 1, 1, -1> v80 = pto::Shape<1, 1, 1, 1, -1>(v3); + pto::Stride<-1, -1, -1, -1, 1> v81 = pto::Stride<-1, -1, -1, -1, 1>(v79, v79, v79, v79); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v82 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) ((int32_t) (uint32_t) ((int32_t) v78) * (uint32_t) v3)) * (unsigned) v9), v80, v81); + __ubuf__ half* v83 = v46.data(); + int64_t v84 = (int64_t) v3; + int32_t v85 = (int32_t) ((int64_t) (uint64_t) v84 - (uint64_t) ((int64_t) (uint64_t) v12 % (uint64_t) v84)); + Tile v86 = Tile(v85 < v11 ? v85 : v11); + uint64_t v87 = reinterpret_cast((__ubuf__ half*) (v83 + (v7 + v7 * v5) + v7 * v6)); + TASSIGN(v86, v87); + __ubuf__ half* v88 = v46.data(); + int32_t v89 = (int32_t) ((int64_t) (uint64_t) v84 - (uint64_t) ((int64_t) (uint64_t) ((int64_t) v36) % (uint64_t) v84)); + Tile v90 = Tile(v89 < v11 ? v89 : v11); + uint64_t v91 = reinterpret_cast((__ubuf__ half*) (v88 + (v7 + v7 * v5) + (unsigned) v36 * v6)); + TASSIGN(v90, v91); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID4); + TLOAD(v46, v82); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + for (size_t v92 = v18; v92 < v20; v92 += v19) { + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v50, v46); + TGATHER, Tile, MaskPattern::P1010>(v54, v46); + pipe_barrier(PIPE_V); + TADD(v86, v50, v54); + pipe_barrier(PIPE_V); + TSUB(v90, v50, v54); + }; + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pipe_barrier(PIPE_MTE3); + TSTORE(v82, v46); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID4); + }; + }; + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + }; + }; + }; + } + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID5); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.pto b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.pto new file mode 100644 index 00000000..081d2d1d --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_auto_sync.pto @@ -0,0 +1,95 @@ +module { + func.func @fast_hadamard_autosync(%arg0: !pto.ptr, %arg1: i32, %arg2: i32, %arg3: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.index_cast %arg2 : i32 to index + %2 = arith.index_cast %arg3 : i32 to index + %3 = pto.get_block_idx + %4 = pto.get_subblock_idx + %5 = pto.get_subblock_num + %6 = pto.get_block_num + %7 = arith.muli %3, %5 : i64 + %8 = arith.addi %7, %4 : i64 + %9 = arith.index_cast %8 : i64 to index + %10 = arith.muli %6, %5 : i64 + %11 = arith.index_cast %10 : i64 to index + pto.section.vector { + %12 = arith.ceildivsi %0, %11 : index + %13 = arith.muli %9, %12 : index + %14 = arith.cmpi slt, %13, %0 : index + scf.if %14 { + %15 = arith.addi %13, %12 : index + %16 = arith.cmpi sgt, %15, %0 : index + %17 = arith.subi %0, %13 : index + %18 = arith.select %16, %17, %12 : index + %19 = arith.cmpi sgt, %18, %c0 : index + scf.if %19 { + %20 = arith.muli %0, %1 : index + %21 = pto.make_tensor_view %arg0, shape = [%20], strides = [%c1] : !pto.tensor_view + %22 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %23 = arith.divsi %1, %c2 : index + %24 = pto.alloc_tile valid_col = %23 : !pto.tile_buf + %25 = arith.divsi %1, %c2 : index + %26 = pto.alloc_tile valid_col = %25 : !pto.tile_buf + %27 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %28 = arith.divsi %1, %c2 : index + %29 = pto.alloc_tile valid_col = %28 : !pto.tile_buf + %30 = arith.divsi %1, %c2 : index + %31 = pto.alloc_tile valid_col = %30 : !pto.tile_buf + %32 = arith.divsi %1, %c2 : index + %33 = arith.ceildivsi %18, %c1 : index + scf.for %arg4 = %c0 to %33 step %c1 { + %34 = arith.muli %arg4, %c1 : index + %35 = arith.subi %18, %34 : index + %36 = arith.cmpi slt, %35, %c1 : index + %37 = arith.select %36, %35, %c1 : index + %38 = arith.cmpi sgt, %37, %c0 : index + scf.if %38 { + %39 = arith.addi %13, %34 : index + %40 = arith.muli %39, %1 : index + %41 = arith.remsi %arg4, %c2 : index + %42 = arith.cmpi eq, %41, %c0 : index + scf.if %42 { + scf.for %arg5 = %c0 to %37 step %c1 { + %43 = arith.muli %arg5, %1 : index + %44 = arith.addi %40, %43 : index + %45 = pto.partition_view %21, offsets = [%44], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %46 = pto.subset %22[%c0, %c0] sizes [1, 8192] : !pto.tile_buf + %47 = pto.subset %22[%c0, %32] sizes [1, 8192] : !pto.tile_buf + pto.tload ins(%45 : !pto.partition_tensor_view<1x16384xf16>) outs(%22 : !pto.tile_buf) + scf.for %arg6 = %c0 to %2 step %c1 { + pto.tgather ins(%22, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%24 : !pto.tile_buf) + pto.tgather ins(%22, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tadd ins(%24, %26 : !pto.tile_buf, !pto.tile_buf) outs(%46 : !pto.tile_buf) + pto.tsub ins(%24, %26 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + } + pto.tstore ins(%22 : !pto.tile_buf) outs(%45 : !pto.partition_tensor_view<1x16384xf16>) + } + } else { + scf.for %arg5 = %c0 to %37 step %c1 { + %43 = arith.muli %arg5, %1 : index + %44 = arith.addi %40, %43 : index + %45 = pto.partition_view %21, offsets = [%44], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %46 = pto.subset %27[%c0, %c0] sizes [1, 8192] : !pto.tile_buf + %47 = pto.subset %27[%c0, %32] sizes [1, 8192] : !pto.tile_buf + pto.tload ins(%45 : !pto.partition_tensor_view<1x16384xf16>) outs(%27 : !pto.tile_buf) + scf.for %arg6 = %c0 to %2 step %c1 { + pto.tgather ins(%27, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tgather ins(%27, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%31 : !pto.tile_buf) + pto.tadd ins(%29, %31 : !pto.tile_buf, !pto.tile_buf) outs(%46 : !pto.tile_buf) + pto.tsub ins(%29, %31 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + } + pto.tstore ins(%27 : !pto.tile_buf) outs(%45 : !pto.partition_tensor_view<1x16384xf16>) + } + } + } + } + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_builder.py new file mode 100644 index 00000000..6d641a26 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_auto_sync/hadamard_builder.py @@ -0,0 +1,286 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +ELEMENTS_PER_TILE = 32 * 1024 // 2 # 32KB UB / sizeof(fp16) +HALF_ELEMENTS_PER_TILE = ELEMENTS_PER_TILE // 2 + + +def meta_data(): + dtype = pto.float16 + ptr_type = pto.PtrType(dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=dtype) + subtensor_full = pto.SubTensorType(shape=[1, ELEMENTS_PER_TILE], dtype=dtype) + subtensor_half = pto.SubTensorType(shape=[1, HALF_ELEMENTS_PER_TILE], dtype=dtype) + + tile_cfg = pto.TileBufConfig() + tile_full = pto.TileBufType( + shape=[1, ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + tile_half = pto.TileBufType( + shape=[1, HALF_ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_full": subtensor_full, + "subtensor_half": subtensor_half, + "tile_full": tile_full, + "tile_half": tile_half, + } + + +@to_ir_module(meta_data=meta_data) +def fast_hadamard_autosync( + x_ptr: "ptr_type", + batch_i32: "index_dtype", + n_i32: "index_dtype", + log2_n_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + + batch = s.index_cast(batch_i32) + n = s.index_cast(n_i32) + log2_n = s.index_cast(log2_n_i32) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + + with pto.vector_section(): + samples_per_core = s.ceil_div(batch, num_cores) + sample_offset = vid * samples_per_core + + with pto.if_context(sample_offset < batch): + samples_end = sample_offset + samples_per_core + samples_to_process = s.select( + samples_end > batch, + batch - sample_offset, + samples_per_core, + ) + + with pto.if_context(samples_to_process > c0): + total_elements = batch * n + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1] + ) + + # Two independent tile sets (ping/pong) so event_id 0/1 map to + # disjoint UB buffers, matching the manual C++ reference. + tb_row_0 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + + tb_row_1 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + + n_half = n // c2 + + # Keep one sample per chunk. Multi-sample chunks interact + # poorly with static tile subset sizing in current PTO Python + # bindings and can corrupt rows for larger batches. + samples_per_load = c1 + num_chunks = s.ceil_div(samples_to_process, samples_per_load) + + def process_rows(tb_row, tb_even, tb_odd, gm_offset, cur_samples): + for s in pto.range(c0, cur_samples, c1): + row_offset = gm_offset + s * n + sv_row = pto.slice_view( + subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n] + ) + # Alias row halves inside UB row tile (no GM round-trip + # per Hadamard iteration). + tb_first = tile.subset( + tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE] + ) + tb_second = tile.subset( + tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE] + ) + + pto.load(sv_row, tb_row) + for _ in pto.range(c0, log2_n, c1): + tile.gather(tb_row, tb_even, mask_pattern="P0101") + tile.gather(tb_row, tb_odd, mask_pattern="P1010") + tile.add(tb_even, tb_odd, tb_first) + tile.sub(tb_even, tb_odd, tb_second) + pto.store(tb_row, sv_row) + + for chunk_i in pto.range(c0, num_chunks, c1): + sample_done = chunk_i * samples_per_load + chunk_left = samples_to_process - sample_done + cur_samples = s.select( + chunk_left < samples_per_load, chunk_left, samples_per_load + ) + + with pto.if_context(cur_samples > c0): + gm_offset = (sample_offset + sample_done) * n + use_ev0 = (chunk_i % c2) == c0 + + with pto.if_context(use_ev0, has_else=True) as branch: + process_rows( + tb_row_0, tb_even_0, tb_odd_0, gm_offset, cur_samples + ) + with branch.else_context(): + process_rows( + tb_row_1, tb_even_1, tb_odd_1, gm_offset, cur_samples + ) + + +@to_ir_module(meta_data=meta_data) +def fast_hadamard_manualsync( + x_ptr: "ptr_type", + batch_i32: "index_dtype", + n_i32: "index_dtype", + log2_n_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + + batch = s.index_cast(batch_i32) + n = s.index_cast(n_i32) + log2_n = s.index_cast(log2_n_i32) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + + with pto.vector_section(): + samples_per_core = s.ceil_div(batch, num_cores) + sample_offset = vid * samples_per_core + + with pto.if_context(sample_offset < batch): + samples_end = sample_offset + samples_per_core + samples_to_process = s.select( + samples_end > batch, + batch - sample_offset, + samples_per_core, + ) + + with pto.if_context(samples_to_process > c0): + total_elements = batch * n + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1] + ) + + # Two independent tile sets (ping/pong) so event_id 0/1 map to + # disjoint UB buffers, matching the manual C++ reference. + tb_row_0 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + + tb_row_1 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + + n_half = n // c2 + + # Keep one sample per chunk. Multi-sample chunks interact + # poorly with static tile subset sizing in current PTO Python + # bindings and can corrupt rows for larger batches. + samples_per_load = c1 + num_chunks = s.ceil_div(samples_to_process, samples_per_load) + + def process_rows( + tb_row, tb_even, tb_odd, event_id, gm_offset, cur_samples + ): + for s in pto.range(c0, cur_samples, c1): + row_offset = gm_offset + s * n + sv_row = pto.slice_view( + subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n] + ) + # Alias row halves inside UB row tile (no GM round-trip + # per Hadamard iteration). + tb_first = tile.subset( + tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE] + ) + tb_second = tile.subset( + tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE] + ) + + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + pto.load(sv_row, tb_row) + pto.record_wait_pair("LOAD", "VEC", event_id=event_id) + + for _ in pto.range(c0, log2_n, c1): + tile.gather(tb_row, tb_even, mask_pattern="P0101") + tile.gather(tb_row, tb_odd, mask_pattern="P1010") + pto.barrier("VEC") + tile.add(tb_even, tb_odd, tb_first) + tile.sub(tb_even, tb_odd, tb_second) + pto.barrier("VEC") + + pto.record_wait_pair("VEC", "STORE_VEC", event_id=event_id) + pto.store(tb_row, sv_row) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + pto.record_event("VEC", "LOAD", event_id=event_id) + + for event_id in (0, 1): + pto.record_event("VEC", "LOAD", event_id=event_id) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + + for chunk_i in pto.range(c0, num_chunks, c1): + sample_done = chunk_i * samples_per_load + chunk_left = samples_to_process - sample_done + cur_samples = s.select( + chunk_left < samples_per_load, chunk_left, samples_per_load + ) + + with pto.if_context(cur_samples > c0): + gm_offset = (sample_offset + sample_done) * n + use_ev0 = (chunk_i % c2) == c0 + + with pto.if_context(use_ev0, has_else=True) as branch: + process_rows( + tb_row_0, tb_even_0, tb_odd_0, 0, gm_offset, cur_samples + ) + with branch.else_context(): + process_rows( + tb_row_1, tb_even_1, tb_odd_1, 1, gm_offset, cur_samples + ) + + for event_id in (0, 1): + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--manual-sync", + action="store_true", + help="Emit explicit record/wait events instead of relying on --enable-insert-sync.", + ) + args = parser.parse_args() + if args.manual_sync: + module = fast_hadamard_manualsync + else: + module = fast_hadamard_autosync + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/compile.sh new file mode 100644 index 00000000..cac03f7b --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./hadamard_builder.py --manual-sync > ./hadamard_manual_sync.pto +ptoas ./hadamard_manual_sync.pto -o ./hadamard_manual_sync.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_builder.py new file mode 100644 index 00000000..6d641a26 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_builder.py @@ -0,0 +1,286 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +ELEMENTS_PER_TILE = 32 * 1024 // 2 # 32KB UB / sizeof(fp16) +HALF_ELEMENTS_PER_TILE = ELEMENTS_PER_TILE // 2 + + +def meta_data(): + dtype = pto.float16 + ptr_type = pto.PtrType(dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=dtype) + subtensor_full = pto.SubTensorType(shape=[1, ELEMENTS_PER_TILE], dtype=dtype) + subtensor_half = pto.SubTensorType(shape=[1, HALF_ELEMENTS_PER_TILE], dtype=dtype) + + tile_cfg = pto.TileBufConfig() + tile_full = pto.TileBufType( + shape=[1, ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + tile_half = pto.TileBufType( + shape=[1, HALF_ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_full": subtensor_full, + "subtensor_half": subtensor_half, + "tile_full": tile_full, + "tile_half": tile_half, + } + + +@to_ir_module(meta_data=meta_data) +def fast_hadamard_autosync( + x_ptr: "ptr_type", + batch_i32: "index_dtype", + n_i32: "index_dtype", + log2_n_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + + batch = s.index_cast(batch_i32) + n = s.index_cast(n_i32) + log2_n = s.index_cast(log2_n_i32) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + + with pto.vector_section(): + samples_per_core = s.ceil_div(batch, num_cores) + sample_offset = vid * samples_per_core + + with pto.if_context(sample_offset < batch): + samples_end = sample_offset + samples_per_core + samples_to_process = s.select( + samples_end > batch, + batch - sample_offset, + samples_per_core, + ) + + with pto.if_context(samples_to_process > c0): + total_elements = batch * n + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1] + ) + + # Two independent tile sets (ping/pong) so event_id 0/1 map to + # disjoint UB buffers, matching the manual C++ reference. + tb_row_0 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + + tb_row_1 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + + n_half = n // c2 + + # Keep one sample per chunk. Multi-sample chunks interact + # poorly with static tile subset sizing in current PTO Python + # bindings and can corrupt rows for larger batches. + samples_per_load = c1 + num_chunks = s.ceil_div(samples_to_process, samples_per_load) + + def process_rows(tb_row, tb_even, tb_odd, gm_offset, cur_samples): + for s in pto.range(c0, cur_samples, c1): + row_offset = gm_offset + s * n + sv_row = pto.slice_view( + subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n] + ) + # Alias row halves inside UB row tile (no GM round-trip + # per Hadamard iteration). + tb_first = tile.subset( + tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE] + ) + tb_second = tile.subset( + tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE] + ) + + pto.load(sv_row, tb_row) + for _ in pto.range(c0, log2_n, c1): + tile.gather(tb_row, tb_even, mask_pattern="P0101") + tile.gather(tb_row, tb_odd, mask_pattern="P1010") + tile.add(tb_even, tb_odd, tb_first) + tile.sub(tb_even, tb_odd, tb_second) + pto.store(tb_row, sv_row) + + for chunk_i in pto.range(c0, num_chunks, c1): + sample_done = chunk_i * samples_per_load + chunk_left = samples_to_process - sample_done + cur_samples = s.select( + chunk_left < samples_per_load, chunk_left, samples_per_load + ) + + with pto.if_context(cur_samples > c0): + gm_offset = (sample_offset + sample_done) * n + use_ev0 = (chunk_i % c2) == c0 + + with pto.if_context(use_ev0, has_else=True) as branch: + process_rows( + tb_row_0, tb_even_0, tb_odd_0, gm_offset, cur_samples + ) + with branch.else_context(): + process_rows( + tb_row_1, tb_even_1, tb_odd_1, gm_offset, cur_samples + ) + + +@to_ir_module(meta_data=meta_data) +def fast_hadamard_manualsync( + x_ptr: "ptr_type", + batch_i32: "index_dtype", + n_i32: "index_dtype", + log2_n_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + + batch = s.index_cast(batch_i32) + n = s.index_cast(n_i32) + log2_n = s.index_cast(log2_n_i32) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + + with pto.vector_section(): + samples_per_core = s.ceil_div(batch, num_cores) + sample_offset = vid * samples_per_core + + with pto.if_context(sample_offset < batch): + samples_end = sample_offset + samples_per_core + samples_to_process = s.select( + samples_end > batch, + batch - sample_offset, + samples_per_core, + ) + + with pto.if_context(samples_to_process > c0): + total_elements = batch * n + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1] + ) + + # Two independent tile sets (ping/pong) so event_id 0/1 map to + # disjoint UB buffers, matching the manual C++ reference. + tb_row_0 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + + tb_row_1 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + + n_half = n // c2 + + # Keep one sample per chunk. Multi-sample chunks interact + # poorly with static tile subset sizing in current PTO Python + # bindings and can corrupt rows for larger batches. + samples_per_load = c1 + num_chunks = s.ceil_div(samples_to_process, samples_per_load) + + def process_rows( + tb_row, tb_even, tb_odd, event_id, gm_offset, cur_samples + ): + for s in pto.range(c0, cur_samples, c1): + row_offset = gm_offset + s * n + sv_row = pto.slice_view( + subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n] + ) + # Alias row halves inside UB row tile (no GM round-trip + # per Hadamard iteration). + tb_first = tile.subset( + tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE] + ) + tb_second = tile.subset( + tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE] + ) + + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + pto.load(sv_row, tb_row) + pto.record_wait_pair("LOAD", "VEC", event_id=event_id) + + for _ in pto.range(c0, log2_n, c1): + tile.gather(tb_row, tb_even, mask_pattern="P0101") + tile.gather(tb_row, tb_odd, mask_pattern="P1010") + pto.barrier("VEC") + tile.add(tb_even, tb_odd, tb_first) + tile.sub(tb_even, tb_odd, tb_second) + pto.barrier("VEC") + + pto.record_wait_pair("VEC", "STORE_VEC", event_id=event_id) + pto.store(tb_row, sv_row) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + pto.record_event("VEC", "LOAD", event_id=event_id) + + for event_id in (0, 1): + pto.record_event("VEC", "LOAD", event_id=event_id) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + + for chunk_i in pto.range(c0, num_chunks, c1): + sample_done = chunk_i * samples_per_load + chunk_left = samples_to_process - sample_done + cur_samples = s.select( + chunk_left < samples_per_load, chunk_left, samples_per_load + ) + + with pto.if_context(cur_samples > c0): + gm_offset = (sample_offset + sample_done) * n + use_ev0 = (chunk_i % c2) == c0 + + with pto.if_context(use_ev0, has_else=True) as branch: + process_rows( + tb_row_0, tb_even_0, tb_odd_0, 0, gm_offset, cur_samples + ) + with branch.else_context(): + process_rows( + tb_row_1, tb_even_1, tb_odd_1, 1, gm_offset, cur_samples + ) + + for event_id in (0, 1): + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--manual-sync", + action="store_true", + help="Emit explicit record/wait events instead of relying on --enable-insert-sync.", + ) + args = parser.parse_args() + if args.manual_sync: + module = fast_hadamard_manualsync + else: + module = fast_hadamard_autosync + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.cpp b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.cpp new file mode 100644 index 00000000..5b65c93c --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.cpp @@ -0,0 +1,190 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void fast_hadamard_manualsync(__gm__ half* v1, int32_t v2, int32_t v3, int32_t v4) { + unsigned v5 = 16384; + unsigned v6 = 1; + unsigned v7 = 0; + const int32_t v8 = 2; + const int32_t v9 = 1; + const int32_t v10 = 0; + const int32_t v11 = 8192; + const int64_t v12 = 0; + const int64_t v13 = 32768; + const int64_t v14 = 49152; + const int64_t v15 = 65536; + const int64_t v16 = 98304; + const int64_t v17 = 114688; + using T = float; + size_t v18 = (size_t) v10; + size_t v19 = (size_t) v9; + size_t v20 = (size_t) v4; + int64_t v21 = get_block_idx(); + int64_t v22 = get_subblockid(); + int64_t v23 = get_subblockdim(); + int64_t v24 = (int64_t) v23; + int64_t v25 = get_block_num(); + int32_t v26 = (int32_t) ((int64_t) (uint64_t) ((int64_t) v25) * (uint64_t) v24); + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + int32_t v27 = v2 / v26; + int32_t v28 = v2 % v26 != v10 && v2 < v10 == v26 < v10 ? v27 + v9 : v27; + int32_t v29 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v21) * (uint64_t) v24) + (uint64_t) ((int64_t) v22))) * (uint32_t) v28); + if (v29 < v2) { + int32_t v30 = (int32_t) ((uint32_t) v29 + (uint32_t) v28) > v2 ? (int32_t) ((uint32_t) v2 - (uint32_t) v29) : v28; + if (v30 > v10) { + int32_t v31 = (int32_t) ((uint32_t) v2 * (uint32_t) v3); + Tile v32; + TASSIGN(v32, v12); + Tile v33 = Tile(v3); + __ubuf__ half* v34 = v32.data(); + uint64_t v35 = reinterpret_cast(v34); + TASSIGN(v33, v35); + int32_t v36 = v3 / v8; + Tile v37; + TASSIGN(v37, v13); + Tile v38 = Tile(v36); + __ubuf__ half* v39 = v37.data(); + uint64_t v40 = reinterpret_cast(v39); + TASSIGN(v38, v40); + Tile v41; + TASSIGN(v41, v14); + Tile v42 = Tile(v36); + __ubuf__ half* v43 = v41.data(); + uint64_t v44 = reinterpret_cast(v43); + TASSIGN(v42, v44); + Tile v45; + TASSIGN(v45, v15); + Tile v46 = Tile(v3); + __ubuf__ half* v47 = v45.data(); + uint64_t v48 = reinterpret_cast(v47); + TASSIGN(v46, v48); + Tile v49; + TASSIGN(v49, v16); + Tile v50 = Tile(v36); + __ubuf__ half* v51 = v49.data(); + uint64_t v52 = reinterpret_cast(v51); + TASSIGN(v50, v52); + Tile v53; + TASSIGN(v53, v17); + Tile v54 = Tile(v36); + __ubuf__ half* v55 = v53.data(); + uint64_t v56 = reinterpret_cast(v55); + TASSIGN(v54, v56); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + for (size_t v57 = v18; v57 < ((size_t) v30); v57 += v19) { + int32_t v58 = (int32_t) v57; + int32_t v59 = (int32_t) ((uint32_t) v30 - (uint32_t) v58); + int32_t v60 = v59 < v9 ? v59 : v9; + size_t v61 = (size_t) v60; + if (v60 > v10) { + int32_t v62 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v29 + (uint32_t) v58) * (uint32_t) v3); + if (v58 % v8 == v10) { + for (size_t v63 = v18; v63 < v61; v63 += v19) { + unsigned v64 = (unsigned) v3 * v6; + pto::Shape<1, 1, 1, 1, -1> v65 = pto::Shape<1, 1, 1, 1, -1>(v3); + pto::Stride<-1, -1, -1, -1, 1> v66 = pto::Stride<-1, -1, -1, -1, 1>(v64, v64, v64, v64); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v67 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) ((int32_t) (uint32_t) ((int32_t) v63) * (uint32_t) v3)) * (unsigned) v9), v65, v66); + __ubuf__ half* v68 = v33.data(); + int64_t v69 = (int64_t) v3; + int32_t v70 = (int32_t) ((int64_t) (uint64_t) v69 - (uint64_t) ((int64_t) (uint64_t) v12 % (uint64_t) v69)); + Tile v71 = Tile(v70 < v11 ? v70 : v11); + uint64_t v72 = reinterpret_cast((__ubuf__ half*) (v68 + (v7 + v7 * v5) + v7 * v6)); + TASSIGN(v71, v72); + __ubuf__ half* v73 = v33.data(); + int32_t v74 = (int32_t) ((int64_t) (uint64_t) v69 - (uint64_t) ((int64_t) (uint64_t) ((int64_t) v36) % (uint64_t) v69)); + Tile v75 = Tile(v74 < v11 ? v74 : v11); + uint64_t v76 = reinterpret_cast((__ubuf__ half*) (v73 + (v7 + v7 * v5) + (unsigned) v36 * v6)); + TASSIGN(v75, v76); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TLOAD(v33, v67); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + for (size_t v77 = v18; v77 < v20; v77 += v19) { + TGATHER, Tile, MaskPattern::P0101>(v38, v33); + TGATHER, Tile, MaskPattern::P1010>(v42, v33); + pipe_barrier(PIPE_V); + TADD(v71, v38, v42); + TSUB(v75, v38, v42); + pipe_barrier(PIPE_V); + }; + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v67, v33); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + }; + } else { + for (size_t v78 = v18; v78 < v61; v78 += v19) { + unsigned v79 = (unsigned) v3 * v6; + pto::Shape<1, 1, 1, 1, -1> v80 = pto::Shape<1, 1, 1, 1, -1>(v3); + pto::Stride<-1, -1, -1, -1, 1> v81 = pto::Stride<-1, -1, -1, -1, 1>(v79, v79, v79, v79); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v82 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) ((int32_t) (uint32_t) ((int32_t) v78) * (uint32_t) v3)) * (unsigned) v9), v80, v81); + __ubuf__ half* v83 = v46.data(); + int64_t v84 = (int64_t) v3; + int32_t v85 = (int32_t) ((int64_t) (uint64_t) v84 - (uint64_t) ((int64_t) (uint64_t) v12 % (uint64_t) v84)); + Tile v86 = Tile(v85 < v11 ? v85 : v11); + uint64_t v87 = reinterpret_cast((__ubuf__ half*) (v83 + (v7 + v7 * v5) + v7 * v6)); + TASSIGN(v86, v87); + __ubuf__ half* v88 = v46.data(); + int32_t v89 = (int32_t) ((int64_t) (uint64_t) v84 - (uint64_t) ((int64_t) (uint64_t) ((int64_t) v36) % (uint64_t) v84)); + Tile v90 = Tile(v89 < v11 ? v89 : v11); + uint64_t v91 = reinterpret_cast((__ubuf__ half*) (v88 + (v7 + v7 * v5) + (unsigned) v36 * v6)); + TASSIGN(v90, v91); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TLOAD(v46, v82); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + for (size_t v92 = v18; v92 < v20; v92 += v19) { + TGATHER, Tile, MaskPattern::P0101>(v50, v46); + TGATHER, Tile, MaskPattern::P1010>(v54, v46); + pipe_barrier(PIPE_V); + TADD(v86, v50, v54); + TSUB(v90, v50, v54); + pipe_barrier(PIPE_V); + }; + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v82, v46); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + }; + }; + }; + }; + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + }; + } + #endif // __DAV_VEC__ + + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.pto b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.pto new file mode 100644 index 00000000..936d1e23 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_hadamard/hadamard_manual_sync/hadamard_manual_sync.pto @@ -0,0 +1,123 @@ +module { + func.func @fast_hadamard_manualsync(%arg0: !pto.ptr, %arg1: i32, %arg2: i32, %arg3: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.index_cast %arg1 : i32 to index + %1 = arith.index_cast %arg2 : i32 to index + %2 = arith.index_cast %arg3 : i32 to index + %3 = pto.get_block_idx + %4 = pto.get_subblock_idx + %5 = pto.get_subblock_num + %6 = pto.get_block_num + %7 = arith.muli %3, %5 : i64 + %8 = arith.addi %7, %4 : i64 + %9 = arith.index_cast %8 : i64 to index + %10 = arith.muli %6, %5 : i64 + %11 = arith.index_cast %10 : i64 to index + pto.section.vector { + %12 = arith.ceildivsi %0, %11 : index + %13 = arith.muli %9, %12 : index + %14 = arith.cmpi slt, %13, %0 : index + scf.if %14 { + %15 = arith.addi %13, %12 : index + %16 = arith.cmpi sgt, %15, %0 : index + %17 = arith.subi %0, %13 : index + %18 = arith.select %16, %17, %12 : index + %19 = arith.cmpi sgt, %18, %c0 : index + scf.if %19 { + %20 = arith.muli %0, %1 : index + %21 = pto.make_tensor_view %arg0, shape = [%20], strides = [%c1] : !pto.tensor_view + %22 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %23 = arith.divsi %1, %c2 : index + %24 = pto.alloc_tile valid_col = %23 : !pto.tile_buf + %25 = arith.divsi %1, %c2 : index + %26 = pto.alloc_tile valid_col = %25 : !pto.tile_buf + %27 = pto.alloc_tile valid_col = %1 : !pto.tile_buf + %28 = arith.divsi %1, %c2 : index + %29 = pto.alloc_tile valid_col = %28 : !pto.tile_buf + %30 = arith.divsi %1, %c2 : index + %31 = pto.alloc_tile valid_col = %30 : !pto.tile_buf + %32 = arith.divsi %1, %c2 : index + %33 = arith.ceildivsi %18, %c1 : index + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg4 = %c0 to %33 step %c1 { + %34 = arith.muli %arg4, %c1 : index + %35 = arith.subi %18, %34 : index + %36 = arith.cmpi slt, %35, %c1 : index + %37 = arith.select %36, %35, %c1 : index + %38 = arith.cmpi sgt, %37, %c0 : index + scf.if %38 { + %39 = arith.addi %13, %34 : index + %40 = arith.muli %39, %1 : index + %41 = arith.remsi %arg4, %c2 : index + %42 = arith.cmpi eq, %41, %c0 : index + scf.if %42 { + scf.for %arg5 = %c0 to %37 step %c1 { + %43 = arith.muli %arg5, %1 : index + %44 = arith.addi %40, %43 : index + %45 = pto.partition_view %21, offsets = [%44], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %46 = pto.subset %22[%c0, %c0] sizes [1, 8192] : !pto.tile_buf + %47 = pto.subset %22[%c0, %32] sizes [1, 8192] : !pto.tile_buf + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%45 : !pto.partition_tensor_view<1x16384xf16>) outs(%22 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg6 = %c0 to %2 step %c1 { + pto.tgather ins(%22, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%24 : !pto.tile_buf) + pto.tgather ins(%22, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.barrier_sync[] + pto.tadd ins(%24, %26 : !pto.tile_buf, !pto.tile_buf) outs(%46 : !pto.tile_buf) + pto.tsub ins(%24, %26 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + pto.barrier_sync[] + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%22 : !pto.tile_buf) outs(%45 : !pto.partition_tensor_view<1x16384xf16>) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + scf.for %arg5 = %c0 to %37 step %c1 { + %43 = arith.muli %arg5, %1 : index + %44 = arith.addi %40, %43 : index + %45 = pto.partition_view %21, offsets = [%44], sizes = [%1] : !pto.tensor_view -> !pto.partition_tensor_view<1x16384xf16> + %46 = pto.subset %27[%c0, %c0] sizes [1, 8192] : !pto.tile_buf + %47 = pto.subset %27[%c0, %32] sizes [1, 8192] : !pto.tile_buf + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%45 : !pto.partition_tensor_view<1x16384xf16>) outs(%27 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg6 = %c0 to %2 step %c1 { + pto.tgather ins(%27, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tgather ins(%27, {maskPattern = #pto.mask_pattern} : !pto.tile_buf) outs(%31 : !pto.tile_buf) + pto.barrier_sync[] + pto.tadd ins(%29, %31 : !pto.tile_buf, !pto.tile_buf) outs(%46 : !pto.tile_buf) + pto.tsub ins(%29, %31 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + pto.barrier_sync[] + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%27 : !pto.tile_buf) outs(%45 : !pto.partition_tensor_view<1x16384xf16>) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + } + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/compile.sh new file mode 100644 index 00000000..f4f4e1da --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./inverse_builder.py --matrix-size 128 > ./inverse_basic_dense_128.pto +ptoas --enable-insert-sync ./inverse_basic_dense_128.pto -o ./inverse_basic_dense_128.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.cpp b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.cpp new file mode 100644 index 00000000..8118b9b3 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.cpp @@ -0,0 +1,188 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void tri_inv_trick_fp16(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5) { + unsigned v6 = 0; + const int32_t v7 = 0; + const int32_t v8 = 1; + const int32_t v9 = 128; + const int64_t v10 = 0; + const int64_t v11 = 65536; + const int64_t v12 = 98304; + const int64_t v13 = 32768; + using T = float; + size_t v14 = (size_t) v8; + + #if defined(__DAV_CUBE__) + int64_t v15 = get_block_idx(); + int32_t v16 = (int32_t) ((int64_t) v15); + int64_t v17 = get_block_num(); + int32_t v18 = (int32_t) ((int64_t) v17); + int32_t v19 = (int32_t) ((uint32_t) v4 * (uint32_t) v9); + int32_t v20 = v4 / v18; + int32_t v21 = v4 % v18; + int32_t v22 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v16 * (uint32_t) v20) + (uint32_t) ((uint32_t) v16 < (uint32_t) v21 ? v16 : v21)); + int32_t v23 = (int32_t) ((uint32_t) v22 + (uint32_t) ((int32_t) (uint32_t) v20 + (uint32_t) (v16 < v21 ? v8 : v7))); + pto::Shape<1, 1, 1, 128, 128> v24 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v25 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v26 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v24, v25); + Tile v27; + TASSIGN(v27, v10); + Tile v28; + TASSIGN(v28, v11); + Tile v29; + TASSIGN(v29, v12); + Tile v30; + TASSIGN(v30, v13); + Tile v31; + TASSIGN(v31, v10); + Tile v32; + TASSIGN(v32, v10); + Tile v33; + TASSIGN(v33, v10); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TLOAD(v27, v26); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v31, v27); + TMOV(v32, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v30, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + for (size_t v34 = (size_t) v22; v34 < ((size_t) ((uint32_t) v23 < (uint32_t) v4 ? v23 : v4)); v34 += v14) { + int32_t v35 = (int32_t) ((uint32_t) ((int32_t) v34) * (uint32_t) v9); + pto::Shape<1, 1, 1, 128, 128> v36 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v37 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v38 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v35 * (unsigned) v9 + v6 * (unsigned) v8), v36, v37); + pto::Shape<1, 1, 1, 128, 128> v39 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v40 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v41 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v35 * (unsigned) v9 + v6 * (unsigned) v8), v39, v40); + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v29, v38); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TMOV(v31, v29); + TMOV(v32, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TMOV(v29, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v32, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TMOV(v31, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v33, v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + TMOV(v28, v33); + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + for (size_t v42 = (size_t) v7; v42 < ((size_t) v5); v42 += v14) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v31, v28); + TMOV(v32, v30); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TMOV(v32, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v33, v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + if ((int32_t) ((uint32_t) ((int32_t) v42) + (uint32_t) v8) < v5) { + TMOV(v28, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMOV(v31, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TMOV(v29, v33); + }; + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TSTORE(v41, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + } + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.pto b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.pto new file mode 100644 index 00000000..04513918 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_basic_dense_128.pto @@ -0,0 +1,75 @@ +module { + func.func @tri_inv_trick_fp16(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = pto.get_block_num + %5 = arith.index_cast %4 : i64 to index + %6 = arith.muli %0, %c128 : index + %7 = arith.divsi %0, %5 : index + %8 = arith.remsi %0, %5 : index + %9 = arith.cmpi slt, %3, %8 : index + %10 = arith.minui %3, %8 : index + %11 = arith.muli %3, %7 : index + %12 = arith.addi %11, %10 : index + %13 = arith.select %9, %c1, %c0 : index + %14 = arith.addi %7, %13 : index + %15 = arith.addi %12, %14 : index + %16 = arith.minui %15, %0 : index + %17 = pto.make_tensor_view %arg1, shape = [%6, %c128], strides = [%c128, %c1] : !pto.tensor_view + %18 = pto.make_tensor_view %arg0, shape = [%6, %c128], strides = [%c128, %c1] : !pto.tensor_view + %19 = pto.make_tensor_view %arg2, shape = [%c128, %c128], strides = [%c128, %c1] : !pto.tensor_view + %20 = pto.partition_view %19, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + %25 = pto.alloc_tile : !pto.tile_buf + %26 = pto.alloc_tile : !pto.tile_buf + %27 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%20 : !pto.partition_tensor_view<128x128xf16>) outs(%21 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%24 : !pto.tile_buf) + scf.for %arg5 = %12 to %16 step %c1 { + %28 = arith.muli %arg5, %c128 : index + %29 = pto.partition_view %17, offsets = [%28, %c0], sizes = [%c128, %c128] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + %30 = pto.partition_view %18, offsets = [%28, %c0], sizes = [%c128, %c128] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf32> + pto.tload ins(%29 : !pto.partition_tensor_view<128x128xf16>) outs(%23 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%23 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmatmul.acc ins(%27, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + scf.for %arg6 = %c0 to %1 step %c1 { + pto.tmov ins(%22 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul.acc ins(%27, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + %31 = arith.addi %arg6, %c1 : index + %32 = arith.cmpi slt, %31, %1 : index + scf.if %32 { + pto.tmov ins(%27 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%23 : !pto.tile_buf) + } + } + pto.tstore ins(%27 : !pto.tile_buf) outs(%30 : !pto.partition_tensor_view<128x128xf32>) + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_builder.py new file mode 100644 index 00000000..498abdc5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_128/inverse_builder.py @@ -0,0 +1,174 @@ +# pyright: reportUndefinedVariable=false +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) + + +def make_meta_data(n: int): + def meta_data(): + in_dtype = pto.float16 + out_dtype = pto.float32 + i32 = pto.int32 + + in_ptr_type = pto.PtrType(in_dtype) + out_ptr_type = pto.PtrType(out_dtype) + in_tensor_type = pto.TensorType(rank=2, dtype=in_dtype) + out_tensor_type = pto.TensorType(rank=2, dtype=out_dtype) + in_subtensor = pto.SubTensorType(shape=[n, n], dtype=in_dtype) + out_subtensor = pto.SubTensorType(shape=[n, n], dtype=out_dtype) + l1_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="MAT" + ) + l0a_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="LEFT" + ) + l0b_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="RIGHT" + ) + l0c_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=out_dtype, memory_space="ACC" + ) + + return { + "in_ptr_type": in_ptr_type, + "out_ptr_type": out_ptr_type, + "i32": i32, + "in_tensor_type": in_tensor_type, + "out_tensor_type": out_tensor_type, + "in_subtensor": in_subtensor, + "out_subtensor": out_subtensor, + "l1_tile_type": l1_tile_type, + "l0a_tile_type": l0a_tile_type, + "l0b_tile_type": l0b_tile_type, + "l0c_tile_type": l0c_tile_type, + } + + return meta_data + + +def build_kernel(matrix_size: int): + @to_ir_module(meta_data=make_meta_data(matrix_size)) + def tri_inv_trick_fp16( + out_ptr: "out_ptr_type", + in_ptr: "in_ptr_type", + i_neg_ptr: "in_ptr_type", + matrix_size_i32: "i32", + log2_blocksize_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + n_c = const(matrix_size) + + batch_size = s.index_cast(matrix_size_i32) + log2_blocksize = s.index_cast(log2_blocksize_i32) + block_idx = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + total_rows = batch_size * n_c + + # Persistent-kernel work split: base + remainder. + base = batch_size // num_cores + rem = batch_size % num_cores + lt_rem = s.lt(block_idx, rem) + min_bid_rem = s.min_u(block_idx, rem) + b_start = block_idx * base + min_bid_rem + length = base + s.select(lt_rem, c1, c0) + b_end = s.min_u(b_start + length, batch_size) + + tv_m = pto.as_tensor( + in_tensor_type, ptr=in_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_out = pto.as_tensor( + out_tensor_type, ptr=out_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_i_neg = pto.as_tensor( + in_tensor_type, ptr=i_neg_ptr, shape=[n_c, n_c], strides=[n_c, c1] + ) + + sv_i_neg = pto.slice_view( + in_subtensor, source=tv_i_neg, offsets=[c0, c0], sizes=[n_c, n_c] + ) + + i_neg_l1 = pto.alloc_tile(l1_tile_type) + x_l1 = pto.alloc_tile(l1_tile_type) + y_l1 = pto.alloc_tile(l1_tile_type) + i_l1 = pto.alloc_tile(l1_tile_type) + a_l0 = pto.alloc_tile(l0a_tile_type) + b_l0 = pto.alloc_tile(l0b_tile_type) + c_l0 = pto.alloc_tile(l0c_tile_type) + + pto.load(sv_i_neg, i_neg_l1) + # I = (-I) @ (-I) is batch-invariant, so compute it once. + tile.mov(i_neg_l1, a_l0) + tile.mov(i_neg_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, i_l1) + + for b_idx in pto.range(b_start, b_end, c1): + row_offset = b_idx * n_c + sv_m = pto.slice_view( + in_subtensor, + source=tv_m, + offsets=[row_offset, c0], + sizes=[n_c, n_c], + ) + sv_out = pto.slice_view( + out_subtensor, + source=tv_out, + offsets=[row_offset, c0], + sizes=[n_c, n_c], + ) + + # in_ptr carries A = M - I, where M is the dense matrix to invert. + pto.load(sv_m, y_l1) + + tile.mov(y_l1, a_l0) + tile.mov(y_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y_l1) # y = A @ A + + tile.mov(i_neg_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) # c = -A + + tile.mov(i_neg_l1, a_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) # c = I - A + tile.mov(c_l0, x_l1) # x = I - A + + # Mirrors: + # for i in range(log2_c - 1): + # X, Y = (X + X @ Y, Y @ Y) + for iter_idx in pto.range(c0, log2_blocksize, c1): + tile.mov(x_l1, a_l0) + tile.mov(i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + + tile.mov(y_l1, b_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) # x + x @ y + + with pto.if_context(iter_idx + c1 < log2_blocksize): + tile.mov(c_l0, x_l1) + tile.mov(y_l1, a_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y_l1) # y = y @ y + + pto.store(c_l0, sv_out) + + return tri_inv_trick_fp16 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--matrix-size", + type=int, + choices=SUPPORTED_MATRIX_SIZES, + default=64, + help="Compile-time specialized dense matrix size.", + ) + args = parser.parse_args() + module = build_kernel(args.matrix_size) + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/compile.sh new file mode 100644 index 00000000..44da3433 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./inverse_builder.py --matrix-size 64 > ./inverse_basic_dense_64.pto +ptoas --enable-insert-sync ./inverse_basic_dense_64.pto -o ./inverse_basic_dense_64.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.cpp b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.cpp new file mode 100644 index 00000000..e1b003f5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.cpp @@ -0,0 +1,188 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void tri_inv_trick_fp16(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5) { + unsigned v6 = 0; + const int32_t v7 = 0; + const int32_t v8 = 1; + const int32_t v9 = 64; + const int64_t v10 = 0; + const int64_t v11 = 16384; + const int64_t v12 = 24576; + const int64_t v13 = 8192; + using T = float; + size_t v14 = (size_t) v8; + + #if defined(__DAV_CUBE__) + int64_t v15 = get_block_idx(); + int32_t v16 = (int32_t) ((int64_t) v15); + int64_t v17 = get_block_num(); + int32_t v18 = (int32_t) ((int64_t) v17); + int32_t v19 = (int32_t) ((uint32_t) v4 * (uint32_t) v9); + int32_t v20 = v4 / v18; + int32_t v21 = v4 % v18; + int32_t v22 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v16 * (uint32_t) v20) + (uint32_t) ((uint32_t) v16 < (uint32_t) v21 ? v16 : v21)); + int32_t v23 = (int32_t) ((uint32_t) v22 + (uint32_t) ((int32_t) (uint32_t) v20 + (uint32_t) (v16 < v21 ? v8 : v7))); + pto::Shape<1, 1, 1, 64, 64> v24 = pto::Shape<1, 1, 1, 64, 64>(); + pto::Stride<4096, 4096, 4096, 64, 1> v25 = pto::Stride<4096, 4096, 4096, 64, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND> v26 = GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v24, v25); + Tile v27; + TASSIGN(v27, v10); + Tile v28; + TASSIGN(v28, v11); + Tile v29; + TASSIGN(v29, v12); + Tile v30; + TASSIGN(v30, v13); + Tile v31; + TASSIGN(v31, v10); + Tile v32; + TASSIGN(v32, v10); + Tile v33; + TASSIGN(v33, v10); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TLOAD(v27, v26); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v31, v27); + TMOV(v32, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v30, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + for (size_t v34 = (size_t) v22; v34 < ((size_t) ((uint32_t) v23 < (uint32_t) v4 ? v23 : v4)); v34 += v14) { + int32_t v35 = (int32_t) ((uint32_t) ((int32_t) v34) * (uint32_t) v9); + pto::Shape<1, 1, 1, 64, 64> v36 = pto::Shape<1, 1, 1, 64, 64>(); + pto::Stride<4096, 4096, 4096, 64, 1> v37 = pto::Stride<4096, 4096, 4096, 64, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND> v38 = GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v35 * (unsigned) v9 + v6 * (unsigned) v8), v36, v37); + pto::Shape<1, 1, 1, 64, 64> v39 = pto::Shape<1, 1, 1, 64, 64>(); + pto::Stride<4096, 4096, 4096, 64, 1> v40 = pto::Stride<4096, 4096, 4096, 64, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND> v41 = GlobalTensor, pto::Stride<4096, 4096, 4096, 64, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v35 * (unsigned) v9 + v6 * (unsigned) v8), v39, v40); + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v29, v38); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TMOV(v31, v29); + TMOV(v32, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TMOV(v29, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v32, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TMOV(v31, v27); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v33, v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + TMOV(v28, v33); + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + for (size_t v42 = (size_t) v7; v42 < ((size_t) v5); v42 += v14) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v31, v28); + TMOV(v32, v30); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TMOV(v32, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v33, v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + if ((int32_t) ((uint32_t) ((int32_t) v42) + (uint32_t) v8) < v5) { + TMOV(v28, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMOV(v31, v29); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v33, v31, v32); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TMOV(v29, v33); + }; + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TSTORE(v41, v33); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + } + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.pto b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.pto new file mode 100644 index 00000000..41f795fe --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_basic_dense_64.pto @@ -0,0 +1,75 @@ +module { + func.func @tri_inv_trick_fp16(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = pto.get_block_num + %5 = arith.index_cast %4 : i64 to index + %6 = arith.muli %0, %c64 : index + %7 = arith.divsi %0, %5 : index + %8 = arith.remsi %0, %5 : index + %9 = arith.cmpi slt, %3, %8 : index + %10 = arith.minui %3, %8 : index + %11 = arith.muli %3, %7 : index + %12 = arith.addi %11, %10 : index + %13 = arith.select %9, %c1, %c0 : index + %14 = arith.addi %7, %13 : index + %15 = arith.addi %12, %14 : index + %16 = arith.minui %15, %0 : index + %17 = pto.make_tensor_view %arg1, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view + %18 = pto.make_tensor_view %arg0, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view + %19 = pto.make_tensor_view %arg2, shape = [%c64, %c64], strides = [%c64, %c1] : !pto.tensor_view + %20 = pto.partition_view %19, offsets = [%c0, %c0], sizes = [%c64, %c64] : !pto.tensor_view -> !pto.partition_tensor_view<64x64xf16> + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + %25 = pto.alloc_tile : !pto.tile_buf + %26 = pto.alloc_tile : !pto.tile_buf + %27 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%20 : !pto.partition_tensor_view<64x64xf16>) outs(%21 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%24 : !pto.tile_buf) + scf.for %arg5 = %12 to %16 step %c1 { + %28 = arith.muli %arg5, %c64 : index + %29 = pto.partition_view %17, offsets = [%28, %c0], sizes = [%c64, %c64] : !pto.tensor_view -> !pto.partition_tensor_view<64x64xf16> + %30 = pto.partition_view %18, offsets = [%28, %c0], sizes = [%c64, %c64] : !pto.tensor_view -> !pto.partition_tensor_view<64x64xf32> + pto.tload ins(%29 : !pto.partition_tensor_view<64x64xf16>) outs(%23 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%23 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmatmul.acc ins(%27, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + scf.for %arg6 = %c0 to %1 step %c1 { + pto.tmov ins(%22 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmatmul.acc ins(%27, %25, %26 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + %31 = arith.addi %arg6, %c1 : index + %32 = arith.cmpi slt, %31, %1 : index + scf.if %32 { + pto.tmov ins(%27 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmatmul ins(%25, %26 : !pto.tile_buf, !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%27 : !pto.tile_buf) outs(%23 : !pto.tile_buf) + } + } + pto.tstore ins(%27 : !pto.tile_buf) outs(%30 : !pto.partition_tensor_view<64x64xf32>) + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_builder.py new file mode 100644 index 00000000..498abdc5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/basic_dense/inverse_basic_dense_64/inverse_builder.py @@ -0,0 +1,174 @@ +# pyright: reportUndefinedVariable=false +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) + + +def make_meta_data(n: int): + def meta_data(): + in_dtype = pto.float16 + out_dtype = pto.float32 + i32 = pto.int32 + + in_ptr_type = pto.PtrType(in_dtype) + out_ptr_type = pto.PtrType(out_dtype) + in_tensor_type = pto.TensorType(rank=2, dtype=in_dtype) + out_tensor_type = pto.TensorType(rank=2, dtype=out_dtype) + in_subtensor = pto.SubTensorType(shape=[n, n], dtype=in_dtype) + out_subtensor = pto.SubTensorType(shape=[n, n], dtype=out_dtype) + l1_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="MAT" + ) + l0a_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="LEFT" + ) + l0b_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="RIGHT" + ) + l0c_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=out_dtype, memory_space="ACC" + ) + + return { + "in_ptr_type": in_ptr_type, + "out_ptr_type": out_ptr_type, + "i32": i32, + "in_tensor_type": in_tensor_type, + "out_tensor_type": out_tensor_type, + "in_subtensor": in_subtensor, + "out_subtensor": out_subtensor, + "l1_tile_type": l1_tile_type, + "l0a_tile_type": l0a_tile_type, + "l0b_tile_type": l0b_tile_type, + "l0c_tile_type": l0c_tile_type, + } + + return meta_data + + +def build_kernel(matrix_size: int): + @to_ir_module(meta_data=make_meta_data(matrix_size)) + def tri_inv_trick_fp16( + out_ptr: "out_ptr_type", + in_ptr: "in_ptr_type", + i_neg_ptr: "in_ptr_type", + matrix_size_i32: "i32", + log2_blocksize_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + n_c = const(matrix_size) + + batch_size = s.index_cast(matrix_size_i32) + log2_blocksize = s.index_cast(log2_blocksize_i32) + block_idx = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + total_rows = batch_size * n_c + + # Persistent-kernel work split: base + remainder. + base = batch_size // num_cores + rem = batch_size % num_cores + lt_rem = s.lt(block_idx, rem) + min_bid_rem = s.min_u(block_idx, rem) + b_start = block_idx * base + min_bid_rem + length = base + s.select(lt_rem, c1, c0) + b_end = s.min_u(b_start + length, batch_size) + + tv_m = pto.as_tensor( + in_tensor_type, ptr=in_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_out = pto.as_tensor( + out_tensor_type, ptr=out_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_i_neg = pto.as_tensor( + in_tensor_type, ptr=i_neg_ptr, shape=[n_c, n_c], strides=[n_c, c1] + ) + + sv_i_neg = pto.slice_view( + in_subtensor, source=tv_i_neg, offsets=[c0, c0], sizes=[n_c, n_c] + ) + + i_neg_l1 = pto.alloc_tile(l1_tile_type) + x_l1 = pto.alloc_tile(l1_tile_type) + y_l1 = pto.alloc_tile(l1_tile_type) + i_l1 = pto.alloc_tile(l1_tile_type) + a_l0 = pto.alloc_tile(l0a_tile_type) + b_l0 = pto.alloc_tile(l0b_tile_type) + c_l0 = pto.alloc_tile(l0c_tile_type) + + pto.load(sv_i_neg, i_neg_l1) + # I = (-I) @ (-I) is batch-invariant, so compute it once. + tile.mov(i_neg_l1, a_l0) + tile.mov(i_neg_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, i_l1) + + for b_idx in pto.range(b_start, b_end, c1): + row_offset = b_idx * n_c + sv_m = pto.slice_view( + in_subtensor, + source=tv_m, + offsets=[row_offset, c0], + sizes=[n_c, n_c], + ) + sv_out = pto.slice_view( + out_subtensor, + source=tv_out, + offsets=[row_offset, c0], + sizes=[n_c, n_c], + ) + + # in_ptr carries A = M - I, where M is the dense matrix to invert. + pto.load(sv_m, y_l1) + + tile.mov(y_l1, a_l0) + tile.mov(y_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y_l1) # y = A @ A + + tile.mov(i_neg_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) # c = -A + + tile.mov(i_neg_l1, a_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) # c = I - A + tile.mov(c_l0, x_l1) # x = I - A + + # Mirrors: + # for i in range(log2_c - 1): + # X, Y = (X + X @ Y, Y @ Y) + for iter_idx in pto.range(c0, log2_blocksize, c1): + tile.mov(x_l1, a_l0) + tile.mov(i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + + tile.mov(y_l1, b_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) # x + x @ y + + with pto.if_context(iter_idx + c1 < log2_blocksize): + tile.mov(c_l0, x_l1) + tile.mov(y_l1, a_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y_l1) # y = y @ y + + pto.store(c_l0, sv_out) + + return tri_inv_trick_fp16 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--matrix-size", + type=int, + choices=SUPPORTED_MATRIX_SIZES, + default=64, + help="Compile-time specialized dense matrix size.", + ) + args = parser.parse_args() + module = build_kernel(args.matrix_size) + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/compile.sh new file mode 100644 index 00000000..0ec07350 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./inverse_builder.py --matrix-size 64 > ./inverse_block_inversion_64.pto +ptoas --enable-insert-sync ./inverse_block_inversion_64.pto -o ./inverse_block_inversion_64.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.cpp b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.cpp new file mode 100644 index 00000000..25b800f2 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.cpp @@ -0,0 +1,275 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void tri_inv_block2x2_fp16(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4) { + unsigned v5 = 32; + unsigned v6 = 0; + const int32_t v7 = 0; + const int32_t v8 = 1; + const int32_t v9 = 64; + const int32_t v10 = 32; + const int64_t v11 = 4096; + const int64_t v12 = 8192; + const int64_t v13 = 6144; + const int64_t v14 = 14336; + const int64_t v15 = 10240; + const int64_t v16 = 0; + const int64_t v17 = 2048; + const int64_t v18 = 12288; + using T = float; + size_t v19 = (size_t) v8; + size_t v20 = (size_t) v7; + + #if defined(__DAV_CUBE__) + int32_t v21 = (int32_t) ((uint32_t) v4 - (uint32_t) v8); + size_t v22 = (size_t) v21; + int64_t v23 = get_block_idx(); + int64_t v24 = get_block_num(); + int32_t v25 = (int32_t) ((uint32_t) ((int32_t) (int64_t) v24) * (uint32_t) v9); + int32_t v26 = (int32_t) ((uint32_t) ((int32_t) (int64_t) v23) * (uint32_t) v9); + int32_t v27 = (int32_t) ((uint32_t) v26 + (uint32_t) v10); + pto::Shape<1, 1, 1, 32, 32> v28 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v29 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v30 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v10 + v6 * (unsigned) v8), v28, v29); + pto::Shape<1, 1, 1, 32, 32> v31 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v32 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v33 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v26 * (unsigned) v9 + v6 * (unsigned) v8), v31, v32); + pto::Shape<1, 1, 1, 32, 32> v34 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v35 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v36 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v27 * (unsigned) v9 + v6 * (unsigned) v8), v34, v35); + pto::Shape<1, 1, 1, 32, 32> v37 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v38 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v39 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v27 * (unsigned) v9 + v5 * (unsigned) v8), v37, v38); + pto::Shape<1, 1, 1, 32, 32> v40 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v41 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v42 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v26 * (unsigned) v9 + v6 * (unsigned) v8), v40, v41); + pto::Shape<1, 1, 1, 32, 32> v43 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v44 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v45 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v27 * (unsigned) v9 + v6 * (unsigned) v8), v43, v44); + pto::Shape<1, 1, 1, 32, 32> v46 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<2048, 2048, 2048, 64, 1> v47 = pto::Stride<2048, 2048, 2048, 64, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND> v48 = GlobalTensor, pto::Stride<2048, 2048, 2048, 64, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v27 * (unsigned) v9 + v5 * (unsigned) v8), v46, v47); + Tile v49; + TASSIGN(v49, v11); + Tile v50; + TASSIGN(v50, v12); + Tile v51; + TASSIGN(v51, v13); + Tile v52; + TASSIGN(v52, v14); + Tile v53; + TASSIGN(v53, v15); + Tile v54; + TASSIGN(v54, v16); + Tile v55; + TASSIGN(v55, v17); + Tile v56; + TASSIGN(v56, v18); + Tile v57; + TASSIGN(v57, v16); + Tile v58; + TASSIGN(v58, v16); + Tile v59; + TASSIGN(v59, v16); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + TLOAD(v54, v30); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v54); + TMOV(v58, v54); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v55, v59); + TMOV(v49, v59); + TMOV(v51, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TLOAD(v50, v33); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v50); + TMOV(v58, v54); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TMOV(v50, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + for (size_t v60 = v20; v60 < v22; v60 += v19) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v57, v49); + TMOV(v58, v55); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TMOV(v58, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v59, v59, v57, v58); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + if ((int32_t) ((uint32_t) ((int32_t) v60) + (uint32_t) v8) < v21) { + TMOV(v49, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMOV(v57, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + TMOV(v50, v59); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + } + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TMOV(v49, v59); + TSTORE(v42, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TLOAD(v52, v39); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v57, v52); + TMOV(v58, v54); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TMOV(v52, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + for (size_t v61 = v20; v61 < v22; v61 += v19) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v57, v51); + TMOV(v58, v55); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TMOV(v58, v52); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + TMATMUL_ACC(v59, v59, v57, v58); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + if ((int32_t) ((uint32_t) ((int32_t) v61) + (uint32_t) v8) < v21) { + TMOV(v51, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TMOV(v57, v52); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + TMOV(v52, v59); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + } + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v51, v59); + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + TSTORE(v48, v59); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TLOAD(v53, v36); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v51); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TMOV(v58, v53); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v56, v59); + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); + TMOV(v57, v56); + TMOV(v58, v49); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TMOV(v56, v59); + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v54); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID2); + TMOV(v58, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v59, v57, v58); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v45, v59); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.pto b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.pto new file mode 100644 index 00000000..f57eedd9 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_block_inversion_64.pto @@ -0,0 +1,106 @@ +module { + func.func @tri_inv_block2x2_fp16(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.subi %0, %c1 : index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = pto.get_block_num + %5 = arith.index_cast %4 : i64 to index + %6 = arith.muli %5, %c64 : index + %7 = arith.muli %3, %c64 : index + %8 = arith.addi %7, %c32 : index + %9 = pto.make_tensor_view %arg1, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view + %10 = pto.make_tensor_view %arg0, shape = [%6, %c64], strides = [%c64, %c1] : !pto.tensor_view + %11 = pto.make_tensor_view %arg2, shape = [%c32, %c32], strides = [%c32, %c1] : !pto.tensor_view + %12 = pto.partition_view %11, offsets = [%c0, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16> + %13 = pto.partition_view %9, offsets = [%7, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16> + %14 = pto.partition_view %9, offsets = [%8, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16> + %15 = pto.partition_view %9, offsets = [%8, %c32], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf16> + %16 = pto.partition_view %10, offsets = [%7, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %17 = pto.partition_view %10, offsets = [%8, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %18 = pto.partition_view %10, offsets = [%8, %c32], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + %25 = pto.alloc_tile : !pto.tile_buf + %26 = pto.alloc_tile : !pto.tile_buf + %27 = pto.alloc_tile : !pto.tile_buf + %28 = pto.alloc_tile : !pto.tile_buf + %29 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%12 : !pto.partition_tensor_view<32x32xf16>) outs(%24 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tload ins(%13 : !pto.partition_tensor_view<32x32xf16>) outs(%20 : !pto.tile_buf) + pto.tmov ins(%20 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%20 : !pto.tile_buf) + scf.for %arg4 = %c0 to %1 step %c1 { + pto.tmov ins(%19 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%25 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%20 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul.acc ins(%29, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + %30 = arith.addi %arg4, %c1 : index + %31 = arith.cmpi slt, %30, %1 : index + scf.if %31 { + pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf) + pto.tmov ins(%20 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%20 : !pto.tile_buf) + } + } + pto.tmov ins(%29 : !pto.tile_buf) outs(%19 : !pto.tile_buf) + pto.tstore ins(%29 : !pto.tile_buf) outs(%16 : !pto.partition_tensor_view<32x32xf32>) + pto.tload ins(%15 : !pto.partition_tensor_view<32x32xf16>) outs(%22 : !pto.tile_buf) + pto.tmov ins(%22 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + scf.for %arg4 = %c0 to %1 step %c1 { + pto.tmov ins(%21 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%25 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%22 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul.acc ins(%29, %27, %28 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + %30 = arith.addi %arg4, %c1 : index + %31 = arith.cmpi slt, %30, %1 : index + scf.if %31 { + pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tmov ins(%22 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%22 : !pto.tile_buf) + } + } + pto.tmov ins(%29 : !pto.tile_buf) outs(%21 : !pto.tile_buf) + pto.tstore ins(%29 : !pto.tile_buf) outs(%18 : !pto.partition_tensor_view<32x32xf32>) + pto.tload ins(%14 : !pto.partition_tensor_view<32x32xf16>) outs(%23 : !pto.tile_buf) + pto.tmov ins(%21 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%23 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmov ins(%26 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%19 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%26 : !pto.tile_buf) + pto.tmov ins(%24 : !pto.tile_buf) outs(%27 : !pto.tile_buf) + pto.tmov ins(%26 : !pto.tile_buf) outs(%28 : !pto.tile_buf) + pto.tmatmul ins(%27, %28 : !pto.tile_buf, !pto.tile_buf) outs(%29 : !pto.tile_buf) + pto.tstore ins(%29 : !pto.tile_buf) outs(%17 : !pto.partition_tensor_view<32x32xf32>) + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_builder.py new file mode 100644 index 00000000..1aa4ec4a --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/fast_inverse/block_inversion/inverse_block_inversion_64/inverse_builder.py @@ -0,0 +1,235 @@ +# pyright: reportUndefinedVariable=false +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) + + +def make_meta_data(n: int): + h = n // 2 + + def meta_data(): + in_dtype = pto.float16 + out_dtype = pto.float32 + i32 = pto.int32 + + in_ptr_type = pto.PtrType(in_dtype) + out_ptr_type = pto.PtrType(out_dtype) + in_tensor_type = pto.TensorType(rank=2, dtype=in_dtype) + out_tensor_type = pto.TensorType(rank=2, dtype=out_dtype) + + in_subtensor_h = pto.SubTensorType(shape=[h, h], dtype=in_dtype) + out_subtensor_h = pto.SubTensorType(shape=[h, h], dtype=out_dtype) + + l1_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="MAT" + ) + l0a_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="LEFT" + ) + l0b_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="RIGHT" + ) + l0c_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=out_dtype, memory_space="ACC" + ) + + return { + "in_ptr_type": in_ptr_type, + "out_ptr_type": out_ptr_type, + "i32": i32, + "in_tensor_type": in_tensor_type, + "out_tensor_type": out_tensor_type, + "in_subtensor_h": in_subtensor_h, + "out_subtensor_h": out_subtensor_h, + "l1_tile_type": l1_tile_type, + "l0a_tile_type": l0a_tile_type, + "l0b_tile_type": l0b_tile_type, + "l0c_tile_type": l0c_tile_type, + } + + return meta_data + + +def build_kernel(matrix_size: int): + assert matrix_size % 2 == 0 and matrix_size >= 16 + + @to_ir_module(meta_data=make_meta_data(matrix_size)) + def tri_inv_block2x2_fp16( + out_ptr: "out_ptr_type", + in_ptr: "in_ptr_type", + i_neg_ptr: "in_ptr_type", + log2_blocksize_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + n_c = const(matrix_size) + h_c = const(matrix_size // 2) + + log2_half = s.index_cast(log2_blocksize_i32) - c1 + block_idx = s.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + + total_rows = num_blocks * n_c + row_offset = block_idx * n_c + row_offset_h = row_offset + h_c + + tv_in = pto.as_tensor( + in_tensor_type, ptr=in_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_out = pto.as_tensor( + out_tensor_type, ptr=out_ptr, shape=[total_rows, n_c], strides=[n_c, c1] + ) + tv_i_neg = pto.as_tensor( + in_tensor_type, ptr=i_neg_ptr, shape=[h_c, h_c], strides=[h_c, c1] + ) + sv_i_neg = pto.slice_view( + in_subtensor_h, source=tv_i_neg, offsets=[c0, c0], sizes=[h_c, h_c] + ) + + sv_a11 = pto.slice_view( + in_subtensor_h, source=tv_in, offsets=[row_offset, c0], sizes=[h_c, h_c] + ) + sv_a21 = pto.slice_view( + in_subtensor_h, + source=tv_in, + offsets=[row_offset_h, c0], + sizes=[h_c, h_c], + ) + sv_a22 = pto.slice_view( + in_subtensor_h, + source=tv_in, + offsets=[row_offset_h, h_c], + sizes=[h_c, h_c], + ) + + sv_out11 = pto.slice_view( + out_subtensor_h, + source=tv_out, + offsets=[row_offset, c0], + sizes=[h_c, h_c], + ) + sv_out21 = pto.slice_view( + out_subtensor_h, + source=tv_out, + offsets=[row_offset_h, c0], + sizes=[h_c, h_c], + ) + sv_out22 = pto.slice_view( + out_subtensor_h, + source=tv_out, + offsets=[row_offset_h, h_c], + sizes=[h_c, h_c], + ) + + x11_l1 = pto.alloc_tile(l1_tile_type) + y11_l1 = pto.alloc_tile(l1_tile_type) + x22_l1 = pto.alloc_tile(l1_tile_type) + y22_l1 = pto.alloc_tile(l1_tile_type) + a21_l1 = pto.alloc_tile(l1_tile_type) + neg_i_l1 = pto.alloc_tile(l1_tile_type) + pos_i_l1 = pto.alloc_tile(l1_tile_type) + tmp_l1 = pto.alloc_tile(l1_tile_type) + + a_l0 = pto.alloc_tile(l0a_tile_type) + b_l0 = pto.alloc_tile(l0b_tile_type) + c_l0 = pto.alloc_tile(l0c_tile_type) + + # Build +/- identity tiles for half-size blocks. + # Also seed x11 = x22 = I for the recurrence below. + pto.load(sv_i_neg, neg_i_l1) + tile.mov(neg_i_l1, a_l0) + tile.mov(neg_i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, pos_i_l1) + tile.mov(c_l0, x11_l1) # x11 = I + tile.mov(c_l0, x22_l1) # x22 = I + + # Invert (I + A11): start the recurrence with y11 = -A11, x11 = I. + # The loop then computes x_{k+1} = x_k(I + y_k), y_{k+1} = y_k^2 + # which gives (I + A11)^{-1} after log2_half steps. + pto.load(sv_a11, y11_l1) + tile.mov(y11_l1, a_l0) + tile.mov(neg_i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) # c = -A11 + tile.mov(c_l0, y11_l1) # y11 = -A11 + + for iter_idx in pto.range(c0, log2_half, c1): + tile.mov(x11_l1, a_l0) + tile.mov(pos_i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + + tile.mov(y11_l1, b_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + + with pto.if_context(iter_idx + c1 < log2_half): + tile.mov(c_l0, x11_l1) + tile.mov(y11_l1, a_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y11_l1) + + tile.mov(c_l0, x11_l1) + pto.store(c_l0, sv_out11) + + # Invert (I + A22): start with y22 = -A22, x22 = I (already set above). + pto.load(sv_a22, y22_l1) + tile.mov(y22_l1, a_l0) + tile.mov(neg_i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) # c = -A22 + tile.mov(c_l0, y22_l1) # y22 = -A22 + + for iter_idx in pto.range(c0, log2_half, c1): + tile.mov(x22_l1, a_l0) + tile.mov(pos_i_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + + tile.mov(y22_l1, b_l0) + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + + with pto.if_context(iter_idx + c1 < log2_half): + tile.mov(c_l0, x22_l1) + tile.mov(y22_l1, a_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, y22_l1) + + tile.mov(c_l0, x22_l1) + pto.store(c_l0, sv_out22) + + # A21 term in block inversion: + # X21 = - X22 @ A21 @ X11 + pto.load(sv_a21, a21_l1) + + tile.mov(x22_l1, a_l0) + tile.mov(a21_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, tmp_l1) + + tile.mov(tmp_l1, a_l0) + tile.mov(x11_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + tile.mov(c_l0, tmp_l1) + + tile.mov(neg_i_l1, a_l0) + tile.mov(tmp_l1, b_l0) + tile.matmul(a_l0, b_l0, c_l0) + pto.store(c_l0, sv_out21) + + return tri_inv_block2x2_fp16 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--matrix-size", + type=int, + choices=SUPPORTED_MATRIX_SIZES, + default=64, + help="Compile-time specialized matrix size.", + ) + args = parser.parse_args() + module = build_kernel(args.matrix_size) + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/compile.sh new file mode 100644 index 00000000..011431e4 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./matmul_builder.py > matmul.pto +ptoas matmul.pto -o matmul.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.cpp b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.cpp new file mode 100644 index 00000000..373e6725 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.cpp @@ -0,0 +1,1450 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void matmul_kernel_ABt(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5, int32_t v6, int32_t v7, int32_t v8) { + unsigned v9 = 128; + unsigned v10 = 0; + const int32_t v11 = 0; + const int32_t v12 = 1; + const int32_t v13 = 2; + const int32_t v14 = 128; + const int32_t v15 = 256; + const int32_t v16 = 512; + const int32_t v17 = 64; + const int32_t v18 = 192; + const int32_t v19 = 320; + const int32_t v20 = 384; + const int32_t v21 = 448; + const int64_t v22 = 0; + const int64_t v23 = 131072; + const int64_t v24 = 262144; + const int64_t v25 = 393216; + const int64_t v26 = 16384; + const int64_t v27 = 32768; + const int64_t v28 = 196608; + const int64_t v29 = 327680; + using T = float; + size_t v30 = (size_t) v12; + size_t v31 = (size_t) v11; + + #if defined(__DAV_CUBE__) + int64_t v32 = get_block_num(); + int32_t v33 = (int32_t) ((int64_t) v32); + int64_t v34 = get_block_idx(); + int32_t v35 = (int32_t) ((int64_t) v34); + int32_t v36 = v8 > v11 ? v8 : v12; + int32_t v37 = (int32_t) ((uint32_t) v36 - (uint32_t) v12); + int32_t v38 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v5 + (uint32_t) v15) - (uint32_t) v12) / v15; + int32_t v39 = v4 / v14; + int32_t v40 = (int32_t) ((uint32_t) v38 * (uint32_t) v39); + int32_t v41 = v6 / v16; + size_t v42 = (size_t) v41; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + for (size_t v43 = (size_t) v35; v43 < ((size_t) v40); v43 += (size_t) v33) { + int32_t v44 = (int32_t) v43; + if (v7 == v11) { + int32_t v45 = (int32_t) ((uint32_t) v36 * (uint32_t) v38); + int32_t v46 = v44 / v45; + int32_t v47 = v44 % v45; + int32_t v48 = (int32_t) ((uint32_t) v36 * (uint32_t) v46); + int32_t v49 = v46 == (int32_t) ((uint32_t) ((int32_t) ((uint32_t) v39 + (uint32_t) v37) / v36) - (uint32_t) v12) ? (int32_t) ((uint32_t) v39 - (uint32_t) v48) : v36; + int32_t v50 = v47 / v49; + int32_t v51 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v48 + (uint32_t) (v47 % v49)) * (uint32_t) v14); + int32_t v52 = (int32_t) ((uint32_t) (v46 % v13 == v12 ? (int32_t) ((uint32_t) ((int32_t) (uint32_t) v38 - (uint32_t) v50) - (uint32_t) v12) : v50) * (uint32_t) v15); + if (((int32_t) ((uint32_t) v52 + (uint32_t) v15) > v5 ? v14 : v15) == v15) { + Tile v53; + TASSIGN(v53, v22); + Tile v54; + TASSIGN(v54, v23); + Tile v55; + TASSIGN(v55, v24); + Tile v56; + TASSIGN(v56, v25); + Tile v57; + TASSIGN(v57, v22); + Tile v58; + TASSIGN(v58, v26); + Tile v59; + TASSIGN(v59, v27); + Tile v60; + TASSIGN(v60, v22); + Tile v61; + TASSIGN(v61, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v62 = (unsigned) v6; + unsigned v63 = v9 * v62; + pto::Shape<1, 1, 1, 128, 512> v64 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v65 = pto::Stride<-1, -1, -1, -1, 1>(v63, v63, v63, v62); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v66 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + v10 * (unsigned) v12), v64, v65); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v53, v66); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v67 = v31; v67 < v42; v67 += v30) { + int32_t v68 = (int32_t) v67; + int32_t v69 = (int32_t) ((uint32_t) v68 * (uint32_t) v16); + if (v68 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 256> v70 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v71 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v72 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v69 * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v70, v71); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v55, v72); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v53, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v59, v55, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v68 == v11) { + TMATMUL(v61, v57, v59); + } else { + TMATMUL_ACC(v61, v61, v57, v59); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v53, v11, v17); + TEXTRACT(v60, v55, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v53, v11, v14); + TEXTRACT(v59, v55, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v53, v11, v18); + TEXTRACT(v60, v55, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v73 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v74 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v75 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v69 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v73, v74); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v56, v75); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v53, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v59, v56, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v53, v11, v19); + TEXTRACT(v60, v56, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v53, v11, v20); + TEXTRACT(v59, v56, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v53, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v60, v56, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v68 + (uint32_t) v12) < v41) { + unsigned v76 = (unsigned) v6; + unsigned v77 = v9 * v76; + pto::Shape<1, 1, 1, 128, 512> v78 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v79 = pto::Stride<-1, -1, -1, -1, 1>(v77, v77, v77, v76); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v80 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v69 + (uint32_t) v16) * (unsigned) v12), v78, v79); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v54, v80); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v81 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v82 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v83 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v69 * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v81, v82); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v55, v83); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v57, v54, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v59, v55, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v68 == v11) { + TMATMUL(v61, v57, v59); + } else { + TMATMUL_ACC(v61, v61, v57, v59); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v54, v11, v17); + TEXTRACT(v60, v55, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v54, v11, v14); + TEXTRACT(v59, v55, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v54, v11, v18); + TEXTRACT(v60, v55, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v84 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v85 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v86 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v69 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v84, v85); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v56, v86); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v54, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v59, v56, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v54, v11, v19); + TEXTRACT(v60, v56, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v57, v54, v11, v20); + TEXTRACT(v59, v56, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v57, v59); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v58, v54, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v60, v56, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v61, v61, v58, v60); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v68 + (uint32_t) v12) < v41) { + unsigned v87 = (unsigned) v6; + unsigned v88 = v9 * v87; + pto::Shape<1, 1, 1, 128, 512> v89 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v90 = pto::Stride<-1, -1, -1, -1, 1>(v88, v88, v88, v87); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v91 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v69 + (uint32_t) v16) * (unsigned) v12), v89, v90); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v53, v91); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v92 = (unsigned) v5; + unsigned v93 = v9 * v92; + pto::Shape<1, 1, 1, 128, 256> v94 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v95 = pto::Stride<-1, -1, -1, -1, 1>(v93, v93, v93, v92); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v96 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v51 * (unsigned) v5 + (unsigned) v52 * (unsigned) v12), v94, v95); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v96, v61); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + } else { + Tile v97; + TASSIGN(v97, v22); + Tile v98; + TASSIGN(v98, v24); + Tile v99; + TASSIGN(v99, v23); + Tile v100; + TASSIGN(v100, v28); + Tile v101; + TASSIGN(v101, v26); + Tile v102; + TASSIGN(v102, v22); + Tile v103; + TASSIGN(v103, v22); + Tile v104; + TASSIGN(v104, v26); + Tile v105; + TASSIGN(v105, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v106 = (unsigned) v6; + unsigned v107 = v9 * v106; + pto::Shape<1, 1, 1, 128, 512> v108 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v109 = pto::Stride<-1, -1, -1, -1, 1>(v107, v107, v107, v106); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v110 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + v10 * (unsigned) v12), v108, v109); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v97, v110); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v111 = v31; v111 < v42; v111 += v30) { + int32_t v112 = (int32_t) v111; + int32_t v113 = (int32_t) ((uint32_t) v112 * (uint32_t) v16); + if (v112 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 128> v114 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v115 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v116 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v113 * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v114, v115); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v99, v116); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v97, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v103, v99, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v112 == v11) { + TMATMUL(v105, v101, v103); + } else { + TMATMUL_ACC(v105, v105, v101, v103); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v97, v11, v17); + TEXTRACT(v104, v99, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v97, v11, v14); + TEXTRACT(v103, v99, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v97, v11, v18); + TEXTRACT(v104, v99, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v117 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v118 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v119 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v113 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v117, v118); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v100, v119); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v97, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v103, v100, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v97, v11, v19); + TEXTRACT(v104, v100, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v97, v11, v20); + TEXTRACT(v103, v100, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v97, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v104, v100, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v112 + (uint32_t) v12) < v41) { + unsigned v120 = (unsigned) v6; + unsigned v121 = v9 * v120; + pto::Shape<1, 1, 1, 128, 512> v122 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v123 = pto::Stride<-1, -1, -1, -1, 1>(v121, v121, v121, v120); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v124 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v113 + (uint32_t) v16) * (unsigned) v12), v122, v123); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v98, v124); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 128> v125 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v126 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v127 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v113 * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v125, v126); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v99, v127); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v101, v98, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v103, v99, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v112 == v11) { + TMATMUL(v105, v101, v103); + } else { + TMATMUL_ACC(v105, v105, v101, v103); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v98, v11, v17); + TEXTRACT(v104, v99, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v98, v11, v14); + TEXTRACT(v103, v99, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v98, v11, v18); + TEXTRACT(v104, v99, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v128 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v129 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v130 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v113 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v52 * (unsigned) v6), v128, v129); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v100, v130); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v98, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v103, v100, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v98, v11, v19); + TEXTRACT(v104, v100, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v98, v11, v20); + TEXTRACT(v103, v100, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v101, v103); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v102, v98, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v104, v100, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v105, v105, v102, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v112 + (uint32_t) v12) < v41) { + unsigned v131 = (unsigned) v6; + unsigned v132 = v9 * v131; + pto::Shape<1, 1, 1, 128, 512> v133 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v134 = pto::Stride<-1, -1, -1, -1, 1>(v132, v132, v132, v131); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v135 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v51 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v113 + (uint32_t) v16) * (unsigned) v12), v133, v134); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v97, v135); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v136 = (unsigned) v5; + unsigned v137 = v9 * v136; + pto::Shape<1, 1, 1, 128, 128> v138 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<-1, -1, -1, -1, 1> v139 = pto::Stride<-1, -1, -1, -1, 1>(v137, v137, v137, v136); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v140 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v51 * (unsigned) v5 + (unsigned) v52 * (unsigned) v12), v138, v139); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v140, v105); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + }; + } else { + if (v7 == v12) { + int32_t v141 = (int32_t) ((uint32_t) v36 * (uint32_t) v39); + int32_t v142 = v44 / v141; + int32_t v143 = v44 % v141; + int32_t v144 = (int32_t) ((uint32_t) v36 * (uint32_t) v142); + int32_t v145 = v142 == (int32_t) ((uint32_t) ((int32_t) ((uint32_t) v38 + (uint32_t) v37) / v36) - (uint32_t) v12) ? (int32_t) ((uint32_t) v38 - (uint32_t) v144) : v36; + int32_t v146 = v143 / v145; + int32_t v147 = (int32_t) ((uint32_t) (v142 % v13 == v12 ? (int32_t) ((uint32_t) ((int32_t) (uint32_t) v39 - (uint32_t) v146) - (uint32_t) v12) : v146) * (uint32_t) v14); + int32_t v148 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v144 + (uint32_t) (v143 % v145)) * (uint32_t) v15); + if (((int32_t) ((uint32_t) v148 + (uint32_t) v15) > v5 ? v14 : v15) == v15) { + Tile v149; + TASSIGN(v149, v22); + Tile v150; + TASSIGN(v150, v23); + Tile v151; + TASSIGN(v151, v24); + Tile v152; + TASSIGN(v152, v25); + Tile v153; + TASSIGN(v153, v26); + Tile v154; + TASSIGN(v154, v22); + Tile v155; + TASSIGN(v155, v22); + Tile v156; + TASSIGN(v156, v27); + Tile v157; + TASSIGN(v157, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v158 = (unsigned) v6; + unsigned v159 = v9 * v158; + pto::Shape<1, 1, 1, 128, 512> v160 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v161 = pto::Stride<-1, -1, -1, -1, 1>(v159, v159, v159, v158); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v162 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + v10 * (unsigned) v12), v160, v161); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v149, v162); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v163 = v31; v163 < v42; v163 += v30) { + int32_t v164 = (int32_t) v163; + int32_t v165 = (int32_t) ((uint32_t) v164 * (uint32_t) v16); + if (v164 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 256> v166 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v167 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v168 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v165 * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v166, v167); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v151, v168); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v149, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v155, v151, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v164 == v11) { + TMATMUL(v157, v153, v155); + } else { + TMATMUL_ACC(v157, v157, v153, v155); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v149, v11, v17); + TEXTRACT(v156, v151, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v149, v11, v14); + TEXTRACT(v155, v151, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v149, v11, v18); + TEXTRACT(v156, v151, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v169 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v170 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v171 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v165 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v169, v170); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v152, v171); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v149, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v155, v152, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v149, v11, v19); + TEXTRACT(v156, v152, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v149, v11, v20); + TEXTRACT(v155, v152, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v149, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v156, v152, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v164 + (uint32_t) v12) < v41) { + unsigned v172 = (unsigned) v6; + unsigned v173 = v9 * v172; + pto::Shape<1, 1, 1, 128, 512> v174 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v175 = pto::Stride<-1, -1, -1, -1, 1>(v173, v173, v173, v172); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v176 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v165 + (uint32_t) v16) * (unsigned) v12), v174, v175); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v150, v176); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v177 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v178 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v179 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v165 * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v177, v178); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v151, v179); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v153, v150, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v155, v151, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v164 == v11) { + TMATMUL(v157, v153, v155); + } else { + TMATMUL_ACC(v157, v157, v153, v155); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v150, v11, v17); + TEXTRACT(v156, v151, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v150, v11, v14); + TEXTRACT(v155, v151, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v150, v11, v18); + TEXTRACT(v156, v151, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v180 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v181 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v182 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v165 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v180, v181); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v152, v182); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v150, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v155, v152, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v150, v11, v19); + TEXTRACT(v156, v152, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v153, v150, v11, v20); + TEXTRACT(v155, v152, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v153, v155); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v154, v150, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v156, v152, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v157, v157, v154, v156); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v164 + (uint32_t) v12) < v41) { + unsigned v183 = (unsigned) v6; + unsigned v184 = v9 * v183; + pto::Shape<1, 1, 1, 128, 512> v185 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v186 = pto::Stride<-1, -1, -1, -1, 1>(v184, v184, v184, v183); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v187 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v165 + (uint32_t) v16) * (unsigned) v12), v185, v186); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v149, v187); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v188 = (unsigned) v5; + unsigned v189 = v9 * v188; + pto::Shape<1, 1, 1, 128, 256> v190 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v191 = pto::Stride<-1, -1, -1, -1, 1>(v189, v189, v189, v188); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v192 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v147 * (unsigned) v5 + (unsigned) v148 * (unsigned) v12), v190, v191); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v192, v157); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + } else { + Tile v193; + TASSIGN(v193, v22); + Tile v194; + TASSIGN(v194, v23); + Tile v195; + TASSIGN(v195, v29); + Tile v196; + TASSIGN(v196, v24); + Tile v197; + TASSIGN(v197, v22); + Tile v198; + TASSIGN(v198, v26); + Tile v199; + TASSIGN(v199, v26); + Tile v200; + TASSIGN(v200, v22); + Tile v201; + TASSIGN(v201, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v202 = (unsigned) v6; + unsigned v203 = v9 * v202; + pto::Shape<1, 1, 1, 128, 512> v204 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v205 = pto::Stride<-1, -1, -1, -1, 1>(v203, v203, v203, v202); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v206 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + v10 * (unsigned) v12), v204, v205); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v193, v206); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v207 = v31; v207 < v42; v207 += v30) { + int32_t v208 = (int32_t) v207; + int32_t v209 = (int32_t) ((uint32_t) v208 * (uint32_t) v16); + if (v208 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 128> v210 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v211 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v212 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v209 * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v210, v211); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v195, v212); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v193, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v199, v195, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v208 == v11) { + TMATMUL(v201, v197, v199); + } else { + TMATMUL_ACC(v201, v201, v197, v199); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v193, v11, v17); + TEXTRACT(v200, v195, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v193, v11, v14); + TEXTRACT(v199, v195, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v193, v11, v18); + TEXTRACT(v200, v195, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v213 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v214 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v215 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v209 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v213, v214); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v196, v215); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v193, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v199, v196, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v193, v11, v19); + TEXTRACT(v200, v196, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v193, v11, v20); + TEXTRACT(v199, v196, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v193, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v200, v196, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v208 + (uint32_t) v12) < v41) { + unsigned v216 = (unsigned) v6; + unsigned v217 = v9 * v216; + pto::Shape<1, 1, 1, 128, 512> v218 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v219 = pto::Stride<-1, -1, -1, -1, 1>(v217, v217, v217, v216); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v220 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v209 + (uint32_t) v16) * (unsigned) v12), v218, v219); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v194, v220); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 128> v221 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v222 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v223 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v209 * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v221, v222); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v195, v223); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v197, v194, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v199, v195, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v208 == v11) { + TMATMUL(v201, v197, v199); + } else { + TMATMUL_ACC(v201, v201, v197, v199); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v194, v11, v17); + TEXTRACT(v200, v195, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v194, v11, v14); + TEXTRACT(v199, v195, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v194, v11, v18); + TEXTRACT(v200, v195, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v224 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v225 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v226 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v209 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v148 * (unsigned) v6), v224, v225); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v196, v226); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v194, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v199, v196, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v194, v11, v19); + TEXTRACT(v200, v196, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v197, v194, v11, v20); + TEXTRACT(v199, v196, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v197, v199); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v198, v194, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v200, v196, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v201, v201, v198, v200); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v208 + (uint32_t) v12) < v41) { + unsigned v227 = (unsigned) v6; + unsigned v228 = v9 * v227; + pto::Shape<1, 1, 1, 128, 512> v229 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v230 = pto::Stride<-1, -1, -1, -1, 1>(v228, v228, v228, v227); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v231 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v147 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v209 + (uint32_t) v16) * (unsigned) v12), v229, v230); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v193, v231); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v232 = (unsigned) v5; + unsigned v233 = v9 * v232; + pto::Shape<1, 1, 1, 128, 128> v234 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<-1, -1, -1, -1, 1> v235 = pto::Stride<-1, -1, -1, -1, 1>(v233, v233, v233, v232); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v236 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v147 * (unsigned) v5 + (unsigned) v148 * (unsigned) v12), v234, v235); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v236, v201); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + }; + } else { + int32_t v237 = (int32_t) ((uint32_t) (v44 / v38) * (uint32_t) v14); + int32_t v238 = (int32_t) ((uint32_t) (v44 % v38) * (uint32_t) v15); + if (((int32_t) ((uint32_t) v238 + (uint32_t) v15) > v5 ? v14 : v15) == v15) { + Tile v239; + TASSIGN(v239, v22); + Tile v240; + TASSIGN(v240, v25); + Tile v241; + TASSIGN(v241, v23); + Tile v242; + TASSIGN(v242, v24); + Tile v243; + TASSIGN(v243, v22); + Tile v244; + TASSIGN(v244, v26); + Tile v245; + TASSIGN(v245, v27); + Tile v246; + TASSIGN(v246, v22); + Tile v247; + TASSIGN(v247, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v248 = (unsigned) v6; + unsigned v249 = v9 * v248; + pto::Shape<1, 1, 1, 128, 512> v250 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v251 = pto::Stride<-1, -1, -1, -1, 1>(v249, v249, v249, v248); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v252 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + v10 * (unsigned) v12), v250, v251); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v239, v252); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v253 = v31; v253 < v42; v253 += v30) { + int32_t v254 = (int32_t) v253; + int32_t v255 = (int32_t) ((uint32_t) v254 * (uint32_t) v16); + if (v254 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 256> v256 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v257 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v258 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v255 * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v256, v257); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v241, v258); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v239, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v245, v241, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v254 == v11) { + TMATMUL(v247, v243, v245); + } else { + TMATMUL_ACC(v247, v247, v243, v245); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v239, v11, v17); + TEXTRACT(v246, v241, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v239, v11, v14); + TEXTRACT(v245, v241, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v239, v11, v18); + TEXTRACT(v246, v241, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v259 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v260 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v261 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v255 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v259, v260); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v242, v261); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v239, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v245, v242, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v239, v11, v19); + TEXTRACT(v246, v242, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v239, v11, v20); + TEXTRACT(v245, v242, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v239, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v246, v242, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v254 + (uint32_t) v12) < v41) { + unsigned v262 = (unsigned) v6; + unsigned v263 = v9 * v262; + pto::Shape<1, 1, 1, 128, 512> v264 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v265 = pto::Stride<-1, -1, -1, -1, 1>(v263, v263, v263, v262); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v266 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v255 + (uint32_t) v16) * (unsigned) v12), v264, v265); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v240, v266); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v267 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v268 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v269 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v255 * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v267, v268); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v241, v269); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v243, v240, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v245, v241, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v254 == v11) { + TMATMUL(v247, v243, v245); + } else { + TMATMUL_ACC(v247, v247, v243, v245); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v240, v11, v17); + TEXTRACT(v246, v241, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v240, v11, v14); + TEXTRACT(v245, v241, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v240, v11, v18); + TEXTRACT(v246, v241, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v270 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v271 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v272 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v255 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v270, v271); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v242, v272); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v240, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v245, v242, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v240, v11, v19); + TEXTRACT(v246, v242, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v243, v240, v11, v20); + TEXTRACT(v245, v242, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v243, v245); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v244, v240, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v246, v242, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v247, v247, v244, v246); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v254 + (uint32_t) v12) < v41) { + unsigned v273 = (unsigned) v6; + unsigned v274 = v9 * v273; + pto::Shape<1, 1, 1, 128, 512> v275 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v276 = pto::Stride<-1, -1, -1, -1, 1>(v274, v274, v274, v273); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v277 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v255 + (uint32_t) v16) * (unsigned) v12), v275, v276); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v239, v277); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v278 = (unsigned) v5; + unsigned v279 = v9 * v278; + pto::Shape<1, 1, 1, 128, 256> v280 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v281 = pto::Stride<-1, -1, -1, -1, 1>(v279, v279, v279, v278); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v282 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v237 * (unsigned) v5 + (unsigned) v238 * (unsigned) v12), v280, v281); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v282, v247); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + } else { + Tile v283; + TASSIGN(v283, v22); + Tile v284; + TASSIGN(v284, v24); + Tile v285; + TASSIGN(v285, v28); + Tile v286; + TASSIGN(v286, v23); + Tile v287; + TASSIGN(v287, v22); + Tile v288; + TASSIGN(v288, v26); + Tile v289; + TASSIGN(v289, v26); + Tile v290; + TASSIGN(v290, v22); + Tile v291; + TASSIGN(v291, v22); + if (v44 != v35) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v292 = (unsigned) v6; + unsigned v293 = v9 * v292; + pto::Shape<1, 1, 1, 128, 512> v294 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v295 = pto::Stride<-1, -1, -1, -1, 1>(v293, v293, v293, v292); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v296 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + v10 * (unsigned) v12), v294, v295); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v283, v296); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v297 = v31; v297 < v42; v297 += v30) { + int32_t v298 = (int32_t) v297; + int32_t v299 = (int32_t) ((uint32_t) v298 * (uint32_t) v16); + if (v298 % v13 == v11) { + pto::Shape<1, 1, 1, 256, 128> v300 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v301 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v302 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v299 * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v300, v301); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v285, v302); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v283, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v289, v285, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v298 == v11) { + TMATMUL(v291, v287, v289); + } else { + TMATMUL_ACC(v291, v291, v287, v289); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v283, v11, v17); + TEXTRACT(v290, v285, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v283, v11, v14); + TEXTRACT(v289, v285, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v283, v11, v18); + TEXTRACT(v290, v285, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v303 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v304 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v305 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v299 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v303, v304); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v286, v305); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v283, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v289, v286, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v283, v11, v19); + TEXTRACT(v290, v286, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v283, v11, v20); + TEXTRACT(v289, v286, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v283, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v290, v286, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v298 + (uint32_t) v12) < v41) { + unsigned v306 = (unsigned) v6; + unsigned v307 = v9 * v306; + pto::Shape<1, 1, 1, 128, 512> v308 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v309 = pto::Stride<-1, -1, -1, -1, 1>(v307, v307, v307, v306); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v310 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v299 + (uint32_t) v16) * (unsigned) v12), v308, v309); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v284, v310); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 128> v311 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v312 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v313 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) v299 * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v311, v312); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v285, v313); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v287, v284, v11, v11); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v289, v285, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v298 == v11) { + TMATMUL(v291, v287, v289); + } else { + TMATMUL_ACC(v291, v291, v287, v289); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v284, v11, v17); + TEXTRACT(v290, v285, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v284, v11, v14); + TEXTRACT(v289, v285, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v284, v11, v18); + TEXTRACT(v290, v285, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 128> v314 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<256, 256, 256, 1, -1> v315 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v316 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) v299 + (uint32_t) v15) * (unsigned) v12 + (unsigned) v238 * (unsigned) v6), v314, v315); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v286, v316); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v284, v11, v15); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v289, v286, v11, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v284, v11, v19); + TEXTRACT(v290, v286, v17, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v287, v284, v11, v20); + TEXTRACT(v289, v286, v14, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v287, v289); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v288, v284, v11, v21); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v290, v286, v18, v11); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v291, v291, v288, v290); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v298 + (uint32_t) v12) < v41) { + unsigned v317 = (unsigned) v6; + unsigned v318 = v9 * v317; + pto::Shape<1, 1, 1, 128, 512> v319 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v320 = pto::Stride<-1, -1, -1, -1, 1>(v318, v318, v318, v317); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v321 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v237 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v299 + (uint32_t) v16) * (unsigned) v12), v319, v320); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v283, v321); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v322 = (unsigned) v5; + unsigned v323 = v9 * v322; + pto::Shape<1, 1, 1, 128, 128> v324 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<-1, -1, -1, -1, 1> v325 = pto::Stride<-1, -1, -1, -1, 1>(v323, v323, v323, v322); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v326 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v10 + (unsigned) v237 * (unsigned) v5 + (unsigned) v238 * (unsigned) v12), v324, v325); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v326, v291); + if ((int32_t) ((uint32_t) v44 + (uint32_t) v33) < v40) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + }; + }; + }; + } + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + #endif // __DAV_CUBE__ + + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.pto b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.pto new file mode 100644 index 00000000..6721cae9 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul.pto @@ -0,0 +1,1637 @@ +module { + func.func @matmul_kernel_ABt(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c128_0 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = arith.index_cast %arg5 : i32 to index + %3 = arith.index_cast %arg6 : i32 to index + %4 = arith.index_cast %arg7 : i32 to index + %5 = pto.get_block_num + %6 = arith.index_cast %5 : i64 to index + %7 = pto.get_block_idx + %8 = arith.index_cast %7 : i64 to index + %9 = arith.cmpi sgt, %4, %c0 : index + %10 = arith.select %9, %4, %c1 : index + %11 = arith.subi %10, %c1 : index + %12 = arith.addi %1, %c256 : index + %13 = arith.subi %12, %c1 : index + %14 = arith.divsi %13, %c256 : index + %15 = arith.divsi %0, %c128 : index + %16 = arith.muli %14, %15 : index + %17 = arith.divsi %2, %c512 : index + %18 = pto.make_tensor_view %arg0, shape = [%0, %2], strides = [%2, %c1] : !pto.tensor_view + %19 = pto.make_tensor_view %arg1, shape = [%2, %1], strides = [%c1, %2] {layout = #pto.layout} : !pto.tensor_view + %20 = pto.make_tensor_view %arg2, shape = [%0, %1], strides = [%1, %c1] : !pto.tensor_view + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg8 = %8 to %16 step %6 { + %21 = arith.cmpi eq, %3, %c0 : index + scf.if %21 { + %22 = arith.addi %15, %11 : index + %23 = arith.divsi %22, %10 : index + %24 = arith.muli %10, %14 : index + %25 = arith.divsi %arg8, %24 : index + %26 = arith.remsi %arg8, %24 : index + %27 = arith.subi %23, %c1 : index + %28 = arith.cmpi eq, %25, %27 : index + %29 = arith.muli %10, %25 : index + %30 = arith.subi %15, %29 : index + %31 = arith.select %28, %30, %10 : index + %32 = arith.muli %25, %10 : index + %33 = arith.remsi %26, %31 : index + %34 = arith.addi %32, %33 : index + %35 = arith.divsi %26, %31 : index + %36 = arith.remsi %25, %c2 : index + %37 = arith.cmpi eq, %36, %c1 : index + %38 = arith.subi %14, %35 : index + %39 = arith.subi %38, %c1 : index + %40 = arith.select %37, %39, %35 : index + %41 = arith.muli %34, %c128 : index + %42 = arith.muli %40, %c256 : index + %43 = arith.addi %42, %c256 : index + %44 = arith.cmpi sgt, %43, %1 : index + %45 = arith.select %44, %c128_0, %c256 : index + %46 = arith.cmpi eq, %45, %c256 : index + scf.if %46 { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c256_6 = arith.constant 256 : index + %47 = pto.alloc_tile : !pto.tile_buf + %48 = pto.alloc_tile : !pto.tile_buf + %49 = pto.alloc_tile : !pto.tile_buf + %50 = pto.alloc_tile : !pto.tile_buf + %51 = pto.alloc_tile : !pto.tile_buf + %52 = pto.alloc_tile : !pto.tile_buf + %53 = pto.alloc_tile : !pto.tile_buf + %54 = pto.alloc_tile : !pto.tile_buf + %55 = pto.alloc_tile : !pto.tile_buf + %56 = arith.cmpi ne, %arg8, %8 : index + scf.if %56 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %57 = pto.partition_view %18, offsets = [%41, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%57 : !pto.partition_tensor_view<128x512xf16>) outs(%47 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %61 = arith.muli %arg9, %c512_5 : index + %62 = arith.remsi %arg9, %c2_3 : index + %63 = arith.cmpi eq, %62, %c0_1 : index + scf.if %63 { + %64 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %65 = arith.addi %61, %c0_9 : index + %66 = pto.partition_view %19, offsets = [%65, %42], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%66 : !pto.partition_tensor_view<256x256xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %64 { + pto.tmatmul ins(%51, %53 : !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%49, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %67 = arith.addi %61, %c256_16 : index + %68 = pto.partition_view %19, offsets = [%67, %42], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%68 : !pto.partition_tensor_view<256x256xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%50, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %69 = arith.addi %arg9, %c1_2 : index + %70 = arith.cmpi slt, %69, %17 : index + scf.if %70 { + %71 = arith.addi %61, %c512_5 : index + %c128_22 = arith.constant 128 : index + %72 = pto.partition_view %18, offsets = [%41, %71], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%72 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %64 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %65 = arith.addi %61, %c0_9 : index + %66 = pto.partition_view %19, offsets = [%65, %42], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%66 : !pto.partition_tensor_view<256x256xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %64 { + pto.tmatmul ins(%51, %53 : !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%49, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %67 = arith.addi %61, %c256_16 : index + %68 = pto.partition_view %19, offsets = [%67, %42], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%68 : !pto.partition_tensor_view<256x256xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%50, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %69 = arith.addi %arg9, %c1_2 : index + %70 = arith.cmpi slt, %69, %17 : index + scf.if %70 { + %71 = arith.addi %61, %c512_5 : index + %c128_22 = arith.constant 128 : index + %72 = pto.partition_view %18, offsets = [%41, %71], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%72 : !pto.partition_tensor_view<128x512xf16>) outs(%47 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %58 = pto.partition_view %20, offsets = [%41, %42], sizes = [%c128_8, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%55 : !pto.tile_buf) outs(%58 : !pto.partition_tensor_view<128x256xf16>) + %59 = arith.addi %arg8, %6 : index + %60 = arith.cmpi slt, %59, %16 : index + scf.if %60 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c128_6 = arith.constant 128 : index + %47 = pto.alloc_tile : !pto.tile_buf + %48 = pto.alloc_tile : !pto.tile_buf + %49 = pto.alloc_tile : !pto.tile_buf + %50 = pto.alloc_tile : !pto.tile_buf + %51 = pto.alloc_tile : !pto.tile_buf + %52 = pto.alloc_tile : !pto.tile_buf + %53 = pto.alloc_tile : !pto.tile_buf + %54 = pto.alloc_tile : !pto.tile_buf + %55 = pto.alloc_tile : !pto.tile_buf + %56 = arith.cmpi ne, %arg8, %8 : index + scf.if %56 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %57 = pto.partition_view %18, offsets = [%41, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%57 : !pto.partition_tensor_view<128x512xf16>) outs(%47 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %61 = arith.muli %arg9, %c512_5 : index + %62 = arith.remsi %arg9, %c2_3 : index + %63 = arith.cmpi eq, %62, %c0_1 : index + scf.if %63 { + %64 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %65 = arith.addi %61, %c0_9 : index + %66 = pto.partition_view %19, offsets = [%65, %42], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%66 : !pto.partition_tensor_view<256x128xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %64 { + pto.tmatmul ins(%51, %53 : !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%49, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %67 = arith.addi %61, %c256_16 : index + %68 = pto.partition_view %19, offsets = [%67, %42], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%68 : !pto.partition_tensor_view<256x128xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%50, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%47, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %69 = arith.addi %arg9, %c1_2 : index + %70 = arith.cmpi slt, %69, %17 : index + scf.if %70 { + %71 = arith.addi %61, %c512_5 : index + %c128_22 = arith.constant 128 : index + %72 = pto.partition_view %18, offsets = [%41, %71], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%72 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %64 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %65 = arith.addi %61, %c0_9 : index + %66 = pto.partition_view %19, offsets = [%65, %42], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%66 : !pto.partition_tensor_view<256x128xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %64 { + pto.tmatmul ins(%51, %53 : !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%49, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%49, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %67 = arith.addi %61, %c256_16 : index + %68 = pto.partition_view %19, offsets = [%67, %42], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%68 : !pto.partition_tensor_view<256x128xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%51 : !pto.tile_buf) + pto.textract ins(%50, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %51, %53 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%55, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %69 = arith.addi %arg9, %c1_2 : index + %70 = arith.cmpi slt, %69, %17 : index + scf.if %70 { + %71 = arith.addi %61, %c512_5 : index + %c128_22 = arith.constant 128 : index + %72 = pto.partition_view %18, offsets = [%41, %71], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%72 : !pto.partition_tensor_view<128x512xf16>) outs(%47 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %58 = pto.partition_view %20, offsets = [%41, %42], sizes = [%c128_8, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%55 : !pto.tile_buf) outs(%58 : !pto.partition_tensor_view<128x128xf16>) + %59 = arith.addi %arg8, %6 : index + %60 = arith.cmpi slt, %59, %16 : index + scf.if %60 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } else { + %22 = arith.cmpi eq, %3, %c1 : index + scf.if %22 { + %23 = arith.addi %14, %11 : index + %24 = arith.divsi %23, %10 : index + %25 = arith.muli %10, %15 : index + %26 = arith.divsi %arg8, %25 : index + %27 = arith.remsi %arg8, %25 : index + %28 = arith.subi %24, %c1 : index + %29 = arith.cmpi eq, %26, %28 : index + %30 = arith.muli %10, %26 : index + %31 = arith.subi %14, %30 : index + %32 = arith.select %29, %31, %10 : index + %33 = arith.divsi %27, %32 : index + %34 = arith.muli %26, %10 : index + %35 = arith.remsi %27, %32 : index + %36 = arith.addi %34, %35 : index + %37 = arith.remsi %26, %c2 : index + %38 = arith.cmpi eq, %37, %c1 : index + %39 = arith.subi %15, %33 : index + %40 = arith.subi %39, %c1 : index + %41 = arith.select %38, %40, %33 : index + %42 = arith.muli %41, %c128 : index + %43 = arith.muli %36, %c256 : index + %44 = arith.addi %43, %c256 : index + %45 = arith.cmpi sgt, %44, %1 : index + %46 = arith.select %45, %c128_0, %c256 : index + %47 = arith.cmpi eq, %46, %c256 : index + scf.if %47 { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c256_6 = arith.constant 256 : index + %48 = pto.alloc_tile : !pto.tile_buf + %49 = pto.alloc_tile : !pto.tile_buf + %50 = pto.alloc_tile : !pto.tile_buf + %51 = pto.alloc_tile : !pto.tile_buf + %52 = pto.alloc_tile : !pto.tile_buf + %53 = pto.alloc_tile : !pto.tile_buf + %54 = pto.alloc_tile : !pto.tile_buf + %55 = pto.alloc_tile : !pto.tile_buf + %56 = pto.alloc_tile : !pto.tile_buf + %57 = arith.cmpi ne, %arg8, %8 : index + scf.if %57 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %58 = pto.partition_view %18, offsets = [%42, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%58 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %62 = arith.muli %arg9, %c512_5 : index + %63 = arith.remsi %arg9, %c2_3 : index + %64 = arith.cmpi eq, %63, %c0_1 : index + scf.if %64 { + %65 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %66 = arith.addi %62, %c0_9 : index + %67 = pto.partition_view %19, offsets = [%66, %43], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%67 : !pto.partition_tensor_view<256x256xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %65 { + pto.tmatmul ins(%52, %54 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %68 = arith.addi %62, %c256_16 : index + %69 = pto.partition_view %19, offsets = [%68, %43], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%69 : !pto.partition_tensor_view<256x256xf16>) outs(%51 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%51, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%51, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %70 = arith.addi %arg9, %c1_2 : index + %71 = arith.cmpi slt, %70, %17 : index + scf.if %71 { + %72 = arith.addi %62, %c512_5 : index + %c128_22 = arith.constant 128 : index + %73 = pto.partition_view %18, offsets = [%42, %72], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%73 : !pto.partition_tensor_view<128x512xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %65 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %66 = arith.addi %62, %c0_9 : index + %67 = pto.partition_view %19, offsets = [%66, %43], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%67 : !pto.partition_tensor_view<256x256xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %65 { + pto.tmatmul ins(%52, %54 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %68 = arith.addi %62, %c256_16 : index + %69 = pto.partition_view %19, offsets = [%68, %43], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%69 : !pto.partition_tensor_view<256x256xf16>) outs(%51 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%51, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%51, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %70 = arith.addi %arg9, %c1_2 : index + %71 = arith.cmpi slt, %70, %17 : index + scf.if %71 { + %72 = arith.addi %62, %c512_5 : index + %c128_22 = arith.constant 128 : index + %73 = pto.partition_view %18, offsets = [%42, %72], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%73 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %59 = pto.partition_view %20, offsets = [%42, %43], sizes = [%c128_8, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%56 : !pto.tile_buf) outs(%59 : !pto.partition_tensor_view<128x256xf16>) + %60 = arith.addi %arg8, %6 : index + %61 = arith.cmpi slt, %60, %16 : index + scf.if %61 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c128_6 = arith.constant 128 : index + %48 = pto.alloc_tile : !pto.tile_buf + %49 = pto.alloc_tile : !pto.tile_buf + %50 = pto.alloc_tile : !pto.tile_buf + %51 = pto.alloc_tile : !pto.tile_buf + %52 = pto.alloc_tile : !pto.tile_buf + %53 = pto.alloc_tile : !pto.tile_buf + %54 = pto.alloc_tile : !pto.tile_buf + %55 = pto.alloc_tile : !pto.tile_buf + %56 = pto.alloc_tile : !pto.tile_buf + %57 = arith.cmpi ne, %arg8, %8 : index + scf.if %57 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %58 = pto.partition_view %18, offsets = [%42, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%58 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %62 = arith.muli %arg9, %c512_5 : index + %63 = arith.remsi %arg9, %c2_3 : index + %64 = arith.cmpi eq, %63, %c0_1 : index + scf.if %64 { + %65 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %66 = arith.addi %62, %c0_9 : index + %67 = pto.partition_view %19, offsets = [%66, %43], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%67 : !pto.partition_tensor_view<256x128xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %65 { + pto.tmatmul ins(%52, %54 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %68 = arith.addi %62, %c256_16 : index + %69 = pto.partition_view %19, offsets = [%68, %43], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%69 : !pto.partition_tensor_view<256x128xf16>) outs(%51 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%51, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%51, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%48, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %70 = arith.addi %arg9, %c1_2 : index + %71 = arith.cmpi slt, %70, %17 : index + scf.if %71 { + %72 = arith.addi %62, %c512_5 : index + %c128_22 = arith.constant 128 : index + %73 = pto.partition_view %18, offsets = [%42, %72], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%73 : !pto.partition_tensor_view<128x512xf16>) outs(%49 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %65 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %66 = arith.addi %62, %c0_9 : index + %67 = pto.partition_view %19, offsets = [%66, %43], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%67 : !pto.partition_tensor_view<256x128xf16>) outs(%50 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%50, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %65 { + pto.tmatmul ins(%52, %54 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%50, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%50, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %68 = arith.addi %62, %c256_16 : index + %69 = pto.partition_view %19, offsets = [%68, %43], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%69 : !pto.partition_tensor_view<256x128xf16>) outs(%51 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.textract ins(%51, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%52 : !pto.tile_buf) + pto.textract ins(%51, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%54 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %52, %54 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%49, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%53 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%51, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%55 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%56, %53, %55 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %70 = arith.addi %arg9, %c1_2 : index + %71 = arith.cmpi slt, %70, %17 : index + scf.if %71 { + %72 = arith.addi %62, %c512_5 : index + %c128_22 = arith.constant 128 : index + %73 = pto.partition_view %18, offsets = [%42, %72], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%73 : !pto.partition_tensor_view<128x512xf16>) outs(%48 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %59 = pto.partition_view %20, offsets = [%42, %43], sizes = [%c128_8, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%56 : !pto.tile_buf) outs(%59 : !pto.partition_tensor_view<128x128xf16>) + %60 = arith.addi %arg8, %6 : index + %61 = arith.cmpi slt, %60, %16 : index + scf.if %61 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } else { + %23 = arith.divsi %arg8, %14 : index + %24 = arith.remsi %arg8, %14 : index + %25 = arith.muli %23, %c128 : index + %26 = arith.muli %24, %c256 : index + %27 = arith.addi %26, %c256 : index + %28 = arith.cmpi sgt, %27, %1 : index + %29 = arith.select %28, %c128_0, %c256 : index + %30 = arith.cmpi eq, %29, %c256 : index + scf.if %30 { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c256_6 = arith.constant 256 : index + %31 = pto.alloc_tile : !pto.tile_buf + %32 = pto.alloc_tile : !pto.tile_buf + %33 = pto.alloc_tile : !pto.tile_buf + %34 = pto.alloc_tile : !pto.tile_buf + %35 = pto.alloc_tile : !pto.tile_buf + %36 = pto.alloc_tile : !pto.tile_buf + %37 = pto.alloc_tile : !pto.tile_buf + %38 = pto.alloc_tile : !pto.tile_buf + %39 = pto.alloc_tile : !pto.tile_buf + %40 = arith.cmpi ne, %arg8, %8 : index + scf.if %40 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %41 = pto.partition_view %18, offsets = [%25, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%41 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %45 = arith.muli %arg9, %c512_5 : index + %46 = arith.remsi %arg9, %c2_3 : index + %47 = arith.cmpi eq, %46, %c0_1 : index + scf.if %47 { + %48 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %49 = arith.addi %45, %c0_9 : index + %50 = pto.partition_view %19, offsets = [%49, %26], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%50 : !pto.partition_tensor_view<256x256xf16>) outs(%33 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%33, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %48 { + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%33, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %51 = arith.addi %45, %c256_16 : index + %52 = pto.partition_view %19, offsets = [%51, %26], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%52 : !pto.partition_tensor_view<256x256xf16>) outs(%34 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%34, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%34, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %53 = arith.addi %arg9, %c1_2 : index + %54 = arith.cmpi slt, %53, %17 : index + scf.if %54 { + %55 = arith.addi %45, %c512_5 : index + %c128_22 = arith.constant 128 : index + %56 = pto.partition_view %18, offsets = [%25, %55], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%56 : !pto.partition_tensor_view<128x512xf16>) outs(%32 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %48 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %49 = arith.addi %45, %c0_9 : index + %50 = pto.partition_view %19, offsets = [%49, %26], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%50 : !pto.partition_tensor_view<256x256xf16>) outs(%33 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%33, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %48 { + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%33, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %51 = arith.addi %45, %c256_16 : index + %52 = pto.partition_view %19, offsets = [%51, %26], sizes = [%c256_4, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%52 : !pto.partition_tensor_view<256x256xf16>) outs(%34 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%34, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%34, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %53 = arith.addi %arg9, %c1_2 : index + %54 = arith.cmpi slt, %53, %17 : index + scf.if %54 { + %55 = arith.addi %45, %c512_5 : index + %c128_22 = arith.constant 128 : index + %56 = pto.partition_view %18, offsets = [%25, %55], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%56 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %42 = pto.partition_view %20, offsets = [%25, %26], sizes = [%c128_8, %c256_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%39 : !pto.tile_buf) outs(%42 : !pto.partition_tensor_view<128x256xf16>) + %43 = arith.addi %arg8, %6 : index + %44 = arith.cmpi slt, %43, %16 : index + scf.if %44 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %c0_1 = arith.constant 0 : index + %c1_2 = arith.constant 1 : index + %c2_3 = arith.constant 2 : index + %c256_4 = arith.constant 256 : index + %c512_5 = arith.constant 512 : index + %c128_6 = arith.constant 128 : index + %31 = pto.alloc_tile : !pto.tile_buf + %32 = pto.alloc_tile : !pto.tile_buf + %33 = pto.alloc_tile : !pto.tile_buf + %34 = pto.alloc_tile : !pto.tile_buf + %35 = pto.alloc_tile : !pto.tile_buf + %36 = pto.alloc_tile : !pto.tile_buf + %37 = pto.alloc_tile : !pto.tile_buf + %38 = pto.alloc_tile : !pto.tile_buf + %39 = pto.alloc_tile : !pto.tile_buf + %40 = arith.cmpi ne, %arg8, %8 : index + scf.if %40 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_7 = arith.constant 128 : index + %41 = pto.partition_view %18, offsets = [%25, %c0_1], sizes = [%c128_7, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%41 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg9 = %c0_1 to %17 step %c1_2 { + %45 = arith.muli %arg9, %c512_5 : index + %46 = arith.remsi %arg9, %c2_3 : index + %47 = arith.cmpi eq, %46, %c0_1 : index + scf.if %47 { + %48 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %49 = arith.addi %45, %c0_9 : index + %50 = pto.partition_view %19, offsets = [%49, %26], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%50 : !pto.partition_tensor_view<256x128xf16>) outs(%33 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%33, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %48 { + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%33, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %51 = arith.addi %45, %c256_16 : index + %52 = pto.partition_view %19, offsets = [%51, %26], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%52 : !pto.partition_tensor_view<256x128xf16>) outs(%34 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%34, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%34, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%31, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %53 = arith.addi %arg9, %c1_2 : index + %54 = arith.cmpi slt, %53, %17 : index + scf.if %54 { + %55 = arith.addi %45, %c512_5 : index + %c128_22 = arith.constant 128 : index + %56 = pto.partition_view %18, offsets = [%25, %55], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%56 : !pto.partition_tensor_view<128x512xf16>) outs(%32 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %48 = arith.cmpi eq, %arg9, %c0_1 : index + %c0_9 = arith.constant 0 : index + %49 = arith.addi %45, %c0_9 : index + %50 = pto.partition_view %19, offsets = [%49, %26], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%50 : !pto.partition_tensor_view<256x128xf16>) outs(%33 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_10 = arith.constant 0 : index + %c0_11 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c0_10 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%33, %c0_11, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %48 { + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_12 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c64 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c64_12, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_13 = arith.constant 128 : index + %c128_14 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c128_13 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%33, %c128_14, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_15 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c192 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%33, %c192_15, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_16 = arith.constant 256 : index + %51 = arith.addi %45, %c256_16 : index + %52 = pto.partition_view %19, offsets = [%51, %26], sizes = [%c256_4, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%52 : !pto.partition_tensor_view<256x128xf16>) outs(%34 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_17 = arith.constant 256 : index + %c0_18 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c256_17 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c0_18, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_19 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c320 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.textract ins(%34, %c64_19, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_20 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c384 : !pto.tile_buf, index, index) outs(%35 : !pto.tile_buf) + pto.textract ins(%34, %c128_20, %c0_1 : !pto.tile_buf, index, index) outs(%37 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %35, %37 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_21 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%32, %c0_1, %c448 : !pto.tile_buf, index, index) outs(%36 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%34, %c192_21, %c0_1 : !pto.tile_buf, index, index) outs(%38 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %53 = arith.addi %arg9, %c1_2 : index + %54 = arith.cmpi slt, %53, %17 : index + scf.if %54 { + %55 = arith.addi %45, %c512_5 : index + %c128_22 = arith.constant 128 : index + %56 = pto.partition_view %18, offsets = [%25, %55], sizes = [%c128_22, %c512_5] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%56 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_8 = arith.constant 128 : index + %42 = pto.partition_view %20, offsets = [%25, %26], sizes = [%c128_8, %c128_6] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%39 : !pto.tile_buf) outs(%42 : !pto.partition_tensor_view<128x128xf16>) + %43 = arith.addi %arg8, %6 : index + %44 = arith.cmpi slt, %43, %16 : index + scf.if %44 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + } + } + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul_builder.py new file mode 100644 index 00000000..a807414c --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/experimental/matmul/matmul_builder.py @@ -0,0 +1,364 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def build(): + M_TILE = 128 + K_QTILE = 64 + K_TILE = 256 + K_DTILE = 512 + N_FULL = 256 + N_HALF = 128 + + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_a = pto.TensorType(rank=2, dtype=dtype) + tv_b = pto.TensorType(rank=2, dtype=dtype) + tv_c = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b_256 = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_b_128 = pto.SubTensorType(shape=[K_TILE, N_HALF], dtype=dtype) + tile_view_c_256 = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + tile_view_c_128 = pto.SubTensorType(shape=[M_TILE, N_HALF], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1_256 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_b_l1_128 = pto.TileBufType( + shape=[K_TILE, N_HALF], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0_256 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_b_l0_128 = pto.TileBufType( + shape=[K_QTILE, N_HALF], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c_256 = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + tile_buf_c_128 = pto.TileBufType( + shape=[M_TILE, N_HALF], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_a": tv_a, + "tv_b": tv_b, + "tv_c": tv_c, + "tile_view_a": tile_view_a, + "tile_view_b_256": tile_view_b_256, + "tile_view_b_128": tile_view_b_128, + "tile_view_c_256": tile_view_c_256, + "tile_view_c_128": tile_view_c_128, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1_256": tile_buf_b_l1_256, + "tile_buf_b_l1_128": tile_buf_b_l1_128, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0_256": tile_buf_b_l0_256, + "tile_buf_b_l0_128": tile_buf_b_l0_128, + "tile_buf_c_256": tile_buf_c_256, + "tile_buf_c_128": tile_buf_c_128, + } + + def swizzle_zn(li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2): + tile_block_loop = (m_loop + cSwizzleM1) // cSwizzle + tile_block_span = cSwizzle * n_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_row_tail = m_loop - cSwizzle * tile_block_idx + n_row = s.select(is_last_block, n_row_tail, cSwizzle) + m_idx = tile_block_idx * cSwizzle + (in_tile_block_idx % n_row) + n_idx = in_tile_block_idx // n_row + odd_block = (tile_block_idx % c2) == c1 + flipped_n_idx = n_loop - n_idx - c1 + n_idx = s.select(odd_block, flipped_n_idx, n_idx) + return m_idx, n_idx + + def swizzle_nz(li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2): + tile_block_loop = (n_loop + cSwizzleM1) // cSwizzle + tile_block_span = cSwizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - cSwizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, cSwizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * cSwizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx + + def level1_loop_mn_dynamic_tilesize( + n_tile: int, + b_view_type, + c_view_type, + b_l1_type, + b_l0_type, + c_type, + m_offset, + n_offset, + k_dtile_num, + li, + core_loop, + bid, + num_blocks, + tvA, + tvB, + tvC, + ): + c0 = const(0) + c1 = const(1) + c2 = const(2) + cKT = const(K_TILE) + cKD = const(K_DTILE) + cNT = const(n_tile) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(b_l1_type), pto.alloc_tile(b_l1_type)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(b_l0_type), pto.alloc_tile(b_l0_type)] + c_l0 = pto.alloc_tile(c_type) + + not_first_tile = li != bid + with pto.if_context(not_first_tile): + pto.wait_event("STORE_ACC", "MATMUL", event_id=0) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[m_offset, c0], + sizes=[const(M_TILE), cKD], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.load(sv_a0, a_l1[0]) + pto.record_event("LOAD", "MOV_M2L", event_id=0) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * cKD + is_curr0 = (k_idx % c2) == c0 + + def level2_loop_k(curr_id, next_id, a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + b_evt = 2 + h + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + b_view_type, + source=tvB, + offsets=[k_offset + h_off, n_offset], + sizes=[cKT, cNT], + ) + + pto.wait_event("MOV_M2L", "LOAD", event_id=b_evt) + pto.load(sv_b, b_l1[h]) + pto.record_event("LOAD", "MOV_M2L", event_id=b_evt) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + pto.wait_event("MATMUL", "MOV_M2L", event_id=ping) + if phase == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=curr_id) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + if phase == 7: + pto.record_event("MOV_M2L", "LOAD", event_id=curr_id) + + if quarter == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=b_evt) + + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + pto.record_event("MOV_M2L", "MATMUL", event_id=0) + + if quarter == 3: + pto.record_event("MOV_M2L", "LOAD", event_id=b_evt) + + pto.wait_event("MOV_M2L", "MATMUL", event_id=0) + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul(a_l0[ping], b_l0[ping], c_l0), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + pto.record_event("MATMUL", "MOV_M2L", event_id=ping) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[m_offset, k_offset + cKD], + sizes=[const(M_TILE), cKD], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=next_id) + pto.load(sv_a_next, a_next) + pto.record_event("LOAD", "MOV_M2L", event_id=next_id) + + with pto.if_context(is_curr0, has_else=True) as branch: + level2_loop_k(0, 1, a_l1[0], a_l1[1]) + with branch.else_context(): + level2_loop_k(1, 0, a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + c_view_type, + source=tvC, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), cNT], + ) + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + pto.store(c_l0, sv_c) + + with pto.if_context(li + num_blocks < core_loop): + pto.record_event("STORE_ACC", "MATMUL", event_id=0) + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + swizzle_direction_i32: "i32", + swizzle_count_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c128n = const(N_HALF) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + swizzle_direction = s.index_cast(swizzle_direction_i32) + swizzle_count = s.index_cast(swizzle_count_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + cSwizzle = s.select(swizzle_count > c0, swizzle_count, c1) + cSwizzleM1 = cSwizzle - c1 + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tvA = pto.as_tensor( + tv_a, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tvB = pto.as_tensor( + tv_b, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tvC = pto.as_tensor( + tv_c, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) + pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1, 2, 3]) + + def level1_loop_mn(m_offset, n_offset, li): + # TODO: make a simpler version that only uses full-tile (256) branch, and reduce the types needed in meta_data + n_tile_size = s.select(n_offset + c256 > n_total, c128n, c256) + shared_args = [ + m_offset, + n_offset, + k_dtile_num, + li, + core_loop, + bid, + num_blocks, + tvA, + tvB, + tvC, + ] + with pto.if_context(n_tile_size == c256, has_else=True) as branch: + level1_loop_mn_dynamic_tilesize( + N_FULL, + tile_view_b_256, + tile_view_c_256, + tile_buf_b_l1_256, + tile_buf_b_l0_256, + tile_buf_c_256, + *shared_args, + ) + with branch.else_context(): + level1_loop_mn_dynamic_tilesize( + N_HALF, + tile_view_b_128, + tile_view_c_128, + tile_buf_b_l1_128, + tile_buf_b_l0_128, + tile_buf_c_128, + *shared_args, + ) + + for li in pto.range(bid, core_loop, num_blocks): + with pto.if_context( + swizzle_direction == c0, has_else=True + ) as c0_branch: + m_idx, n_idx = swizzle_zn( + li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2 + ) + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + with c0_branch.else_context(): + with pto.if_context( + swizzle_direction == c1, has_else=True + ) as c1_branch: + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2 + ) + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + with c1_branch.else_context(): + # Default linear mapping, used when swizzle_direction is not 0/1. + m_idx = li // n_loop + n_idx = li % n_loop + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + pto.wait_event("MOV_M2L", "LOAD", event_id=3) + pto.wait_event("MOV_M2L", "LOAD", event_id=2) + pto.wait_event("MOV_M2L", "LOAD", event_id=1) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=1) + + return matmul_kernel_ABt + + +if __name__ == "__main__": + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/common_utils.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/common_utils.py new file mode 100644 index 00000000..58d8b801 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/common_utils.py @@ -0,0 +1,76 @@ +from ptodsl import pto +from ptodsl import scalar as s + +const = s.const + +M_TILE = 128 +K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 +SWIZZLE_COUNT = 5 + + +def build_meta_data(): + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_2d = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_2d": tv_2d, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_c": tile_view_c, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1": tile_buf_b_l1, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0": tile_buf_b_l0, + "tile_buf_c": tile_buf_c, + } + + return meta_data + + +def swizzle_nz(li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2): + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, c_swizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/compile.sh new file mode 100644 index 00000000..ad191598 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./step1_baseline.py > ./step1_baseline.pto +ptoas --enable-insert-sync ./step1_baseline.pto -o ./step1_baseline.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.cpp b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.cpp new file mode 100644 index 00000000..5d4e3701 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.cpp @@ -0,0 +1,202 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void matmul_kernel_step1_baseline(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5, int32_t v6) { + unsigned v7 = 128; + unsigned v8 = 0; + const int32_t v9 = 0; + const int32_t v10 = 1; + const int32_t v11 = 128; + const int32_t v12 = 256; + const int32_t v13 = 512; + const int32_t v14 = 64; + const int32_t v15 = 192; + const int32_t v16 = 320; + const int32_t v17 = 384; + const int32_t v18 = 448; + const int64_t v19 = 0; + const int64_t v20 = 131072; + using T = float; + + #if defined(__DAV_CUBE__) + int64_t v21 = get_block_num(); + int64_t v22 = get_block_idx(); + int32_t v23 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v5 + (uint32_t) v12) - (uint32_t) v10) / v12; + int32_t v24 = v6 / v13; + Tile v25; + TASSIGN(v25, v19); + Tile v26; + TASSIGN(v26, v20); + Tile v27; + TASSIGN(v27, v19); + Tile v28; + TASSIGN(v28, v19); + Tile v29; + TASSIGN(v29, v19); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + for (size_t v30 = (size_t) ((int32_t) (int64_t) v22); v30 < ((size_t) ((int32_t) (uint32_t) v23 * (uint32_t) (v4 / v11))); v30 += (size_t) ((int32_t) (int64_t) v21)) { + int32_t v31 = (int32_t) v30; + int32_t v32 = (int32_t) ((uint32_t) (v31 / v23) * (uint32_t) v11); + int32_t v33 = (int32_t) ((uint32_t) (v31 % v23) * (uint32_t) v12); + unsigned v34 = (unsigned) v6; + unsigned v35 = v7 * v34; + pto::Shape<1, 1, 1, 128, 512> v36 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v37 = pto::Stride<-1, -1, -1, -1, 1>(v35, v35, v35, v34); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v38 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v32 * (unsigned) v6 + v8 * (unsigned) v10), v36, v37); + pipe_barrier(PIPE_MTE2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v25, v38); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + for (size_t v39 = (size_t) v9; v39 < ((size_t) v24); v39 += (size_t) v10) { + int32_t v40 = (int32_t) v39; + int32_t v41 = (int32_t) ((uint32_t) v40 * (uint32_t) v13); + pto::Shape<1, 1, 1, 256, 256> v42 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v43 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v44 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v41 * (unsigned) v10 + (unsigned) v33 * (unsigned) v6), v42, v43); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v26, v44); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v27, v25, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v28, v26, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v40 == v9) { + TMATMUL(v29, v27, v28); + } else { + TMATMUL_ACC(v29, v29, v27, v28); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v27, v25, v9, v14); + TEXTRACT(v28, v26, v14, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v27, v25, v9, v11); + TEXTRACT(v28, v26, v11, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v27, v25, v9, v15); + TEXTRACT(v28, v26, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + pto::Shape<1, 1, 1, 256, 256> v45 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v46 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v47 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v41 + (uint32_t) v12) * (unsigned) v10 + (unsigned) v33 * (unsigned) v6), v45, v46); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + TLOAD(v26, v47); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v27, v25, v9, v12); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v28, v26, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TEXTRACT(v27, v25, v9, v16); + TEXTRACT(v28, v26, v14, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v27, v25, v9, v17); + TEXTRACT(v28, v26, v11, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v27, v25, v9, v18); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + TEXTRACT(v28, v26, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + TMATMUL_ACC(v29, v29, v27, v28); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + if ((int32_t) ((uint32_t) v40 + (uint32_t) v10) < v24) { + unsigned v48 = (unsigned) v6; + unsigned v49 = v7 * v48; + pto::Shape<1, 1, 1, 128, 512> v50 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v51 = pto::Stride<-1, -1, -1, -1, 1>(v49, v49, v49, v48); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v52 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v32 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v41 + (uint32_t) v13) * (unsigned) v10), v50, v51); + TLOAD(v25, v52); + }; + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + unsigned v53 = (unsigned) v5; + unsigned v54 = v7 * v53; + pto::Shape<1, 1, 1, 128, 256> v55 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v56 = pto::Stride<-1, -1, -1, -1, 1>(v54, v54, v54, v53); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v57 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v8 + (unsigned) v32 * (unsigned) v5 + (unsigned) v33 * (unsigned) v10), v55, v56); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v57, v29); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + } + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.pto b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.pto new file mode 100644 index 00000000..23111626 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.pto @@ -0,0 +1,113 @@ +module { + func.func @matmul_kernel_step1_baseline(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = arith.index_cast %arg5 : i32 to index + %3 = pto.get_block_num + %4 = arith.index_cast %3 : i64 to index + %5 = pto.get_block_idx + %6 = arith.index_cast %5 : i64 to index + %7 = arith.addi %1, %c256 : index + %8 = arith.subi %7, %c1 : index + %9 = arith.divsi %8, %c256 : index + %10 = arith.divsi %0, %c128 : index + %11 = arith.muli %9, %10 : index + %12 = arith.divsi %2, %c512 : index + %13 = pto.make_tensor_view %arg0, shape = [%0, %2], strides = [%2, %c1] : !pto.tensor_view + %14 = pto.make_tensor_view %arg1, shape = [%2, %1], strides = [%c1, %2] {layout = #pto.layout} : !pto.tensor_view + %15 = pto.make_tensor_view %arg2, shape = [%0, %1], strides = [%1, %c1] : !pto.tensor_view + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + scf.for %arg6 = %6 to %11 step %4 { + %21 = arith.divsi %arg6, %9 : index + %22 = arith.remsi %arg6, %9 : index + %23 = arith.muli %21, %c128 : index + %24 = arith.muli %22, %c256 : index + %c256_0 = arith.constant 256 : index + %c512_1 = arith.constant 512 : index + %c256_2 = arith.constant 256 : index + %c128_3 = arith.constant 128 : index + %25 = pto.partition_view %13, offsets = [%23, %c0], sizes = [%c128_3, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%25 : !pto.partition_tensor_view<128x512xf16>) outs(%16 : !pto.tile_buf) + scf.for %arg7 = %c0 to %12 step %c1 { + %27 = arith.muli %arg7, %c512_1 : index + %28 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %29 = arith.addi %27, %c0_5 : index + %30 = pto.partition_view %14, offsets = [%29, %24], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%30 : !pto.partition_tensor_view<256x256xf16>) outs(%17 : !pto.tile_buf) + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.textract ins(%16, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + scf.if %28 { + pto.tmatmul ins(%18, %19 : !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + } + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.textract ins(%16, %c0, %c64 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.textract ins(%16, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.textract ins(%16, %c0, %c192 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c256_12 = arith.constant 256 : index + %31 = arith.addi %27, %c256_12 : index + %32 = pto.partition_view %14, offsets = [%31, %24], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%32 : !pto.partition_tensor_view<256x256xf16>) outs(%17 : !pto.tile_buf) + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.textract ins(%16, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.textract ins(%16, %c0, %c320 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.textract ins(%16, %c0, %c384 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.textract ins(%16, %c0, %c448 : !pto.tile_buf, index, index) outs(%18 : !pto.tile_buf) + pto.textract ins(%17, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%19 : !pto.tile_buf) + pto.tmatmul.acc ins(%20, %18, %19 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%20 : !pto.tile_buf) + %33 = arith.addi %arg7, %c1 : index + %34 = arith.cmpi slt, %33, %12 : index + scf.if %34 { + %35 = arith.addi %27, %c512_1 : index + %c128_18 = arith.constant 128 : index + %36 = pto.partition_view %13, offsets = [%23, %35], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%36 : !pto.partition_tensor_view<128x512xf16>) outs(%16 : !pto.tile_buf) + } + } + %c128_4 = arith.constant 128 : index + %26 = pto.partition_view %15, offsets = [%23, %24], sizes = [%c128_4, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.tstore ins(%20 : !pto.tile_buf) outs(%26 : !pto.partition_tensor_view<128x256xf16>) + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.py new file mode 100644 index 00000000..e192dd77 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step1_baseline/step1_baseline.py @@ -0,0 +1,138 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + build_meta_data, + const, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_step1_baseline( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = pto.alloc_tile(tile_buf_a_l1) + b_l1 = pto.alloc_tile(tile_buf_b_l1) + a_l0 = pto.alloc_tile(tile_buf_a_l0) + b_l0 = pto.alloc_tile(tile_buf_b_l0) + c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx = li // n_loop + n_idx = li % n_loop + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + is_first_k_tile = k_idx == c0 + + for phase in range(8): + if phase % 4 == 0: + b_half = phase // 4 + h_off = const(b_half * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1) + + a_col = const(phase * K_QTILE) + b_row = const((phase % 4) * K_QTILE) + tile.extract(a_l1, c0, a_col, a_l0) + tile.extract(b_l1, b_row, c0, b_l0) + + if phase == 0: + with pto.if_context( + is_first_k_tile, has_else=True + ) as branch: + tile.matmul(a_l0, b_l0, c_l0) + with branch.else_context(): + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + else: + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_l1) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_step1_baseline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/common_utils.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/common_utils.py new file mode 100644 index 00000000..58d8b801 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/common_utils.py @@ -0,0 +1,76 @@ +from ptodsl import pto +from ptodsl import scalar as s + +const = s.const + +M_TILE = 128 +K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 +SWIZZLE_COUNT = 5 + + +def build_meta_data(): + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_2d = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_2d": tv_2d, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_c": tile_view_c, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1": tile_buf_b_l1, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0": tile_buf_b_l0, + "tile_buf_c": tile_buf_c, + } + + return meta_data + + +def swizzle_nz(li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2): + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, c_swizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/compile.sh new file mode 100644 index 00000000..0b52a06a --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./step2_doublebuffer.py > ./step2_doublebuffer.pto +ptoas --enable-insert-sync ./step2_doublebuffer.pto -o ./step2_doublebuffer.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.cpp b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.cpp new file mode 100644 index 00000000..923b9e2d --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.cpp @@ -0,0 +1,360 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void matmul_kernel_ABt_autosync(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5, int32_t v6) { + unsigned v7 = 128; + unsigned v8 = 0; + const int32_t v9 = 0; + const int32_t v10 = 1; + const int32_t v11 = 2; + const int32_t v12 = 128; + const int32_t v13 = 256; + const int32_t v14 = 512; + const int32_t v15 = 64; + const int32_t v16 = 192; + const int32_t v17 = 320; + const int32_t v18 = 384; + const int32_t v19 = 448; + const int64_t v20 = 393216; + const int64_t v21 = 131072; + const int64_t v22 = 0; + const int64_t v23 = 262144; + const int64_t v24 = 16384; + const int64_t v25 = 32768; + using T = float; + + #if defined(__DAV_CUBE__) + int64_t v26 = get_block_num(); + int64_t v27 = get_block_idx(); + int32_t v28 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v5 + (uint32_t) v13) - (uint32_t) v10) / v13; + int32_t v29 = v6 / v14; + Tile v30; + TASSIGN(v30, v20); + Tile v31; + TASSIGN(v31, v21); + Tile v32; + TASSIGN(v32, v22); + Tile v33; + TASSIGN(v33, v23); + Tile v34; + TASSIGN(v34, v24); + Tile v35; + TASSIGN(v35, v22); + Tile v36; + TASSIGN(v36, v22); + Tile v37; + TASSIGN(v37, v25); + Tile v38; + TASSIGN(v38, v22); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + for (size_t v39 = (size_t) ((int32_t) (int64_t) v27); v39 < ((size_t) ((int32_t) (uint32_t) v28 * (uint32_t) (v4 / v12))); v39 += (size_t) ((int32_t) (int64_t) v26)) { + int32_t v40 = (int32_t) v39; + int32_t v41 = (int32_t) ((uint32_t) (v40 / v28) * (uint32_t) v12); + int32_t v42 = (int32_t) ((uint32_t) (v40 % v28) * (uint32_t) v13); + unsigned v43 = (unsigned) v6; + unsigned v44 = v7 * v43; + pto::Shape<1, 1, 1, 128, 512> v45 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v46 = pto::Stride<-1, -1, -1, -1, 1>(v44, v44, v44, v43); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v47 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v41 * (unsigned) v6 + v8 * (unsigned) v10), v45, v46); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + pipe_barrier(PIPE_MTE2); + TLOAD(v30, v47); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + for (size_t v48 = (size_t) v9; v48 < ((size_t) v29); v48 += (size_t) v10) { + int32_t v49 = (int32_t) v48; + int32_t v50 = (int32_t) ((uint32_t) v49 * (uint32_t) v14); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + if (v49 % v11 == v9) { + pto::Shape<1, 1, 1, 256, 256> v51 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v52 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v53 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v50 * (unsigned) v10 + (unsigned) v42 * (unsigned) v6), v51, v52); + TLOAD(v32, v53); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v34, v30, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + TEXTRACT(v36, v32, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v49 == v9) { + pipe_barrier(PIPE_M); + TMATMUL(v38, v34, v36); + } else { + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TEXTRACT(v35, v30, v9, v15); + TEXTRACT(v37, v32, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v34, v30, v9, v12); + TEXTRACT(v36, v32, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v30, v9, v16); + TEXTRACT(v37, v32, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + pto::Shape<1, 1, 1, 256, 256> v54 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v55 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v56 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v50 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v42 * (unsigned) v6), v54, v55); + TLOAD(v33, v56); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v34, v30, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v36, v33, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v30, v9, v17); + TEXTRACT(v37, v33, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v34, v30, v9, v18); + TEXTRACT(v36, v33, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v30, v9, v19); + TEXTRACT(v37, v33, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + if ((int32_t) ((uint32_t) v49 + (uint32_t) v10) < v29) { + unsigned v57 = (unsigned) v6; + unsigned v58 = v7 * v57; + pto::Shape<1, 1, 1, 128, 512> v59 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v60 = pto::Stride<-1, -1, -1, -1, 1>(v58, v58, v58, v57); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v61 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v41 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v50 + (uint32_t) v14) * (unsigned) v10), v59, v60); + pipe_barrier(PIPE_MTE2); + TLOAD(v31, v61); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v62 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v63 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v64 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v50 * (unsigned) v10 + (unsigned) v42 * (unsigned) v6), v62, v63); + TLOAD(v32, v64); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v34, v31, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v36, v32, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v49 == v9) { + pipe_barrier(PIPE_M); + TMATMUL(v38, v34, v36); + } else { + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v35, v31, v9, v15); + TEXTRACT(v37, v32, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v34, v31, v9, v12); + TEXTRACT(v36, v32, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v31, v9, v16); + TEXTRACT(v37, v32, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + pto::Shape<1, 1, 1, 256, 256> v65 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v66 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v67 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v50 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v42 * (unsigned) v6), v65, v66); + TLOAD(v33, v67); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v34, v31, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v36, v33, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v31, v9, v17); + TEXTRACT(v37, v33, v15, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v34, v31, v9, v18); + TEXTRACT(v36, v33, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v34, v36); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v35, v31, v9, v19); + TEXTRACT(v37, v33, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v38, v38, v35, v37); + if ((int32_t) ((uint32_t) v49 + (uint32_t) v10) < v29) { + unsigned v68 = (unsigned) v6; + unsigned v69 = v7 * v68; + pto::Shape<1, 1, 1, 128, 512> v70 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v71 = pto::Stride<-1, -1, -1, -1, 1>(v69, v69, v69, v68); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v72 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v41 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v50 + (uint32_t) v14) * (unsigned) v10), v70, v71); + pipe_barrier(PIPE_MTE2); + TLOAD(v30, v72); + }; + }; + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + unsigned v73 = (unsigned) v5; + unsigned v74 = v7 * v73; + pto::Shape<1, 1, 1, 128, 256> v75 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v76 = pto::Stride<-1, -1, -1, -1, 1>(v74, v74, v74, v73); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v77 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v8 + (unsigned) v41 * (unsigned) v5 + (unsigned) v42 * (unsigned) v10), v75, v76); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pipe_barrier(PIPE_FIX); + TSTORE(v77, v38); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + } + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.pto b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.pto new file mode 100644 index 00000000..ee622d69 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.pto @@ -0,0 +1,184 @@ +module { + func.func @matmul_kernel_ABt_autosync(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = arith.index_cast %arg5 : i32 to index + %3 = pto.get_block_num + %4 = arith.index_cast %3 : i64 to index + %5 = pto.get_block_idx + %6 = arith.index_cast %5 : i64 to index + %7 = arith.addi %1, %c256 : index + %8 = arith.subi %7, %c1 : index + %9 = arith.divsi %8, %c256 : index + %10 = arith.divsi %0, %c128 : index + %11 = arith.muli %9, %10 : index + %12 = arith.divsi %2, %c512 : index + %13 = pto.make_tensor_view %arg0, shape = [%0, %2], strides = [%2, %c1] : !pto.tensor_view + %14 = pto.make_tensor_view %arg1, shape = [%2, %1], strides = [%c1, %2] {layout = #pto.layout} : !pto.tensor_view + %15 = pto.make_tensor_view %arg2, shape = [%0, %1], strides = [%1, %c1] : !pto.tensor_view + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + scf.for %arg6 = %6 to %11 step %4 { + %25 = arith.divsi %arg6, %9 : index + %26 = arith.remsi %arg6, %9 : index + %27 = arith.muli %25, %c128 : index + %28 = arith.muli %26, %c256 : index + %c256_0 = arith.constant 256 : index + %c512_1 = arith.constant 512 : index + %c256_2 = arith.constant 256 : index + %c128_3 = arith.constant 128 : index + %29 = pto.partition_view %13, offsets = [%27, %c0], sizes = [%c128_3, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%29 : !pto.partition_tensor_view<128x512xf16>) outs(%16 : !pto.tile_buf) + scf.for %arg7 = %c0 to %12 step %c1 { + %31 = arith.muli %arg7, %c512_1 : index + %32 = arith.remsi %arg7, %c2 : index + %33 = arith.cmpi eq, %32, %c0 : index + scf.if %33 { + %34 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %35 = arith.addi %31, %c0_5 : index + %36 = pto.partition_view %14, offsets = [%35, %28], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%36 : !pto.partition_tensor_view<256x256xf16>) outs(%18 : !pto.tile_buf) + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.textract ins(%16, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%18, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + scf.if %34 { + pto.tmatmul ins(%20, %22 : !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + } + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.textract ins(%16, %c0, %c64 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%18, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.textract ins(%16, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%18, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.textract ins(%16, %c0, %c192 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%18, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c256_12 = arith.constant 256 : index + %37 = arith.addi %31, %c256_12 : index + %38 = pto.partition_view %14, offsets = [%37, %28], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%38 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.textract ins(%16, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%19, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.textract ins(%16, %c0, %c320 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.textract ins(%16, %c0, %c384 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%19, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.textract ins(%16, %c0, %c448 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %39 = arith.addi %arg7, %c1 : index + %40 = arith.cmpi slt, %39, %12 : index + scf.if %40 { + %41 = arith.addi %31, %c512_1 : index + %c128_18 = arith.constant 128 : index + %42 = pto.partition_view %13, offsets = [%27, %41], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%42 : !pto.partition_tensor_view<128x512xf16>) outs(%17 : !pto.tile_buf) + } + } else { + %34 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %35 = arith.addi %31, %c0_5 : index + %36 = pto.partition_view %14, offsets = [%35, %28], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%36 : !pto.partition_tensor_view<256x256xf16>) outs(%18 : !pto.tile_buf) + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.textract ins(%17, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%18, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + scf.if %34 { + pto.tmatmul ins(%20, %22 : !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + } + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.textract ins(%17, %c0, %c64 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%18, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.textract ins(%17, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%18, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.textract ins(%17, %c0, %c192 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%18, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c256_12 = arith.constant 256 : index + %37 = arith.addi %31, %c256_12 : index + %38 = pto.partition_view %14, offsets = [%37, %28], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%38 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.textract ins(%17, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%19, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.textract ins(%17, %c0, %c320 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.textract ins(%17, %c0, %c384 : !pto.tile_buf, index, index) outs(%20 : !pto.tile_buf) + pto.textract ins(%19, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %20, %22 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.textract ins(%17, %c0, %c448 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%24, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%24 : !pto.tile_buf) + %39 = arith.addi %arg7, %c1 : index + %40 = arith.cmpi slt, %39, %12 : index + scf.if %40 { + %41 = arith.addi %31, %c512_1 : index + %c128_18 = arith.constant 128 : index + %42 = pto.partition_view %13, offsets = [%27, %41], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%42 : !pto.partition_tensor_view<128x512xf16>) outs(%16 : !pto.tile_buf) + } + } + } + %c128_4 = arith.constant 128 : index + %30 = pto.partition_view %15, offsets = [%27, %28], sizes = [%c128_4, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.tstore ins(%24 : !pto.tile_buf) outs(%30 : !pto.partition_tensor_view<128x256xf16>) + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.py new file mode 100644 index 00000000..cb2344e6 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step2_doublebuffer/step2_doublebuffer.py @@ -0,0 +1,153 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + build_meta_data, + const, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt_autosync( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx = li // n_loop + n_idx = li % n_loop + + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1[0]) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1[h]) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_next) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_ABt_autosync + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/common_utils.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/common_utils.py new file mode 100644 index 00000000..58d8b801 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/common_utils.py @@ -0,0 +1,76 @@ +from ptodsl import pto +from ptodsl import scalar as s + +const = s.const + +M_TILE = 128 +K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 +SWIZZLE_COUNT = 5 + + +def build_meta_data(): + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_2d = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_2d": tv_2d, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_c": tile_view_c, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1": tile_buf_b_l1, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0": tile_buf_b_l0, + "tile_buf_c": tile_buf_c, + } + + return meta_data + + +def swizzle_nz(li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2): + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, c_swizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/compile.sh new file mode 100644 index 00000000..35ecfbca --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./step3_swizzle.py > ./step3_swizzle.pto +ptoas --enable-insert-sync ./step3_swizzle.pto -o ./step3_swizzle.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.cpp b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.cpp new file mode 100644 index 00000000..4b4f9cdc --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.cpp @@ -0,0 +1,369 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void matmul_kernel_ABt_autosync(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5, int32_t v6) { + unsigned v7 = 128; + unsigned v8 = 0; + const int32_t v9 = 0; + const int32_t v10 = 1; + const int32_t v11 = 2; + const int32_t v12 = 128; + const int32_t v13 = 256; + const int32_t v14 = 512; + const int32_t v15 = 5; + const int32_t v16 = 64; + const int32_t v17 = 192; + const int32_t v18 = 320; + const int32_t v19 = 384; + const int32_t v20 = 448; + const int32_t v21 = 4; + const int64_t v22 = 393216; + const int64_t v23 = 131072; + const int64_t v24 = 262144; + const int64_t v25 = 0; + const int64_t v26 = 16384; + const int64_t v27 = 32768; + using T = float; + + #if defined(__DAV_CUBE__) + int64_t v28 = get_block_num(); + int64_t v29 = get_block_idx(); + int32_t v30 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v5 + (uint32_t) v13) - (uint32_t) v10) / v13; + int32_t v31 = v4 / v12; + int32_t v32 = v6 / v14; + Tile v33; + TASSIGN(v33, v22); + Tile v34; + TASSIGN(v34, v23); + Tile v35; + TASSIGN(v35, v24); + Tile v36; + TASSIGN(v36, v25); + Tile v37; + TASSIGN(v37, v25); + Tile v38; + TASSIGN(v38, v26); + Tile v39; + TASSIGN(v39, v25); + Tile v40; + TASSIGN(v40, v27); + Tile v41; + TASSIGN(v41, v25); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + for (size_t v42 = (size_t) ((int32_t) (int64_t) v29); v42 < ((size_t) ((int32_t) (uint32_t) v30 * (uint32_t) v31)); v42 += (size_t) ((int32_t) (int64_t) v28)) { + int32_t v43 = (int32_t) v42; + int32_t v44 = (int32_t) ((uint32_t) v31 * (uint32_t) v15); + int32_t v45 = v43 / v44; + int32_t v46 = v43 % v44; + int32_t v47 = (int32_t) ((uint32_t) v45 * (uint32_t) v15); + int32_t v48 = v45 == (int32_t) ((uint32_t) ((int32_t) ((uint32_t) v30 + (uint32_t) v21) / v15) - (uint32_t) v10) ? (int32_t) ((uint32_t) v30 - (uint32_t) v47) : v15; + int32_t v49 = v46 / v48; + int32_t v50 = (int32_t) ((uint32_t) (v45 % v11 == v10 ? (int32_t) ((uint32_t) ((int32_t) (uint32_t) v31 - (uint32_t) v49) - (uint32_t) v10) : v49) * (uint32_t) v12); + int32_t v51 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v47 + (uint32_t) (v46 % v48)) * (uint32_t) v13); + unsigned v52 = (unsigned) v6; + unsigned v53 = v7 * v52; + pto::Shape<1, 1, 1, 128, 512> v54 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v55 = pto::Stride<-1, -1, -1, -1, 1>(v53, v53, v53, v52); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v56 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v50 * (unsigned) v6 + v8 * (unsigned) v10), v54, v55); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + pipe_barrier(PIPE_MTE2); + TLOAD(v33, v56); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + for (size_t v57 = (size_t) v9; v57 < ((size_t) v32); v57 += (size_t) v10) { + int32_t v58 = (int32_t) v57; + int32_t v59 = (int32_t) ((uint32_t) v58 * (uint32_t) v14); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + if (v58 % v11 == v9) { + pto::Shape<1, 1, 1, 256, 256> v60 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v61 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v62 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v59 * (unsigned) v10 + (unsigned) v51 * (unsigned) v6), v60, v61); + TLOAD(v35, v62); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v37, v33, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + TEXTRACT(v39, v35, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v58 == v9) { + pipe_barrier(PIPE_M); + TMATMUL(v41, v37, v39); + } else { + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TEXTRACT(v38, v33, v9, v16); + TEXTRACT(v40, v35, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v37, v33, v9, v12); + TEXTRACT(v39, v35, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v33, v9, v17); + TEXTRACT(v40, v35, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + pto::Shape<1, 1, 1, 256, 256> v63 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v64 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v65 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v59 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v51 * (unsigned) v6), v63, v64); + TLOAD(v36, v65); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v37, v33, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v39, v36, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v33, v9, v18); + TEXTRACT(v40, v36, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v37, v33, v9, v19); + TEXTRACT(v39, v36, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v33, v9, v20); + TEXTRACT(v40, v36, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + if ((int32_t) ((uint32_t) v58 + (uint32_t) v10) < v32) { + unsigned v66 = (unsigned) v6; + unsigned v67 = v7 * v66; + pto::Shape<1, 1, 1, 128, 512> v68 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v69 = pto::Stride<-1, -1, -1, -1, 1>(v67, v67, v67, v66); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v70 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v50 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v59 + (uint32_t) v14) * (unsigned) v10), v68, v69); + pipe_barrier(PIPE_MTE2); + TLOAD(v34, v70); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v71 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v72 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v73 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v59 * (unsigned) v10 + (unsigned) v51 * (unsigned) v6), v71, v72); + TLOAD(v35, v73); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v37, v34, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v39, v35, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v58 == v9) { + pipe_barrier(PIPE_M); + TMATMUL(v41, v37, v39); + } else { + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v38, v34, v9, v16); + TEXTRACT(v40, v35, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v37, v34, v9, v12); + TEXTRACT(v39, v35, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v34, v9, v17); + TEXTRACT(v40, v35, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + pto::Shape<1, 1, 1, 256, 256> v74 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v75 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v76 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v59 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v51 * (unsigned) v6), v74, v75); + TLOAD(v36, v76); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v37, v34, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v39, v36, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v34, v9, v18); + TEXTRACT(v40, v36, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TEXTRACT(v37, v34, v9, v19); + TEXTRACT(v39, v36, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v37, v39); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v38, v34, v9, v20); + TEXTRACT(v40, v36, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v41, v41, v38, v40); + if ((int32_t) ((uint32_t) v58 + (uint32_t) v10) < v32) { + unsigned v77 = (unsigned) v6; + unsigned v78 = v7 * v77; + pto::Shape<1, 1, 1, 128, 512> v79 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v80 = pto::Stride<-1, -1, -1, -1, 1>(v78, v78, v78, v77); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v81 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v50 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v59 + (uint32_t) v14) * (unsigned) v10), v79, v80); + pipe_barrier(PIPE_MTE2); + TLOAD(v33, v81); + }; + }; + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + unsigned v82 = (unsigned) v5; + unsigned v83 = v7 * v82; + pto::Shape<1, 1, 1, 128, 256> v84 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v85 = pto::Stride<-1, -1, -1, -1, 1>(v83, v83, v83, v82); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v86 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v8 + (unsigned) v50 * (unsigned) v5 + (unsigned) v51 * (unsigned) v10), v84, v85); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pipe_barrier(PIPE_FIX); + TSTORE(v86, v41); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + } + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.pto b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.pto new file mode 100644 index 00000000..74a5b9f5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.pto @@ -0,0 +1,203 @@ +module { + func.func @matmul_kernel_ABt_autosync(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = arith.index_cast %arg5 : i32 to index + %3 = pto.get_block_num + %4 = arith.index_cast %3 : i64 to index + %5 = pto.get_block_idx + %6 = arith.index_cast %5 : i64 to index + %7 = arith.addi %1, %c256 : index + %8 = arith.subi %7, %c1 : index + %9 = arith.divsi %8, %c256 : index + %10 = arith.divsi %0, %c128 : index + %11 = arith.muli %9, %10 : index + %12 = arith.divsi %2, %c512 : index + %c5 = arith.constant 5 : index + %13 = arith.subi %c5, %c1 : index + %14 = pto.make_tensor_view %arg0, shape = [%0, %2], strides = [%2, %c1] : !pto.tensor_view + %15 = pto.make_tensor_view %arg1, shape = [%2, %1], strides = [%c1, %2] {layout = #pto.layout} : !pto.tensor_view + %16 = pto.make_tensor_view %arg2, shape = [%0, %1], strides = [%1, %c1] : !pto.tensor_view + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + %25 = pto.alloc_tile : !pto.tile_buf + scf.for %arg6 = %6 to %11 step %4 { + %26 = arith.addi %9, %13 : index + %27 = arith.divsi %26, %c5 : index + %28 = arith.muli %c5, %10 : index + %29 = arith.divsi %arg6, %28 : index + %30 = arith.remsi %arg6, %28 : index + %31 = arith.subi %27, %c1 : index + %32 = arith.cmpi eq, %29, %31 : index + %33 = arith.muli %c5, %29 : index + %34 = arith.subi %9, %33 : index + %35 = arith.select %32, %34, %c5 : index + %36 = arith.divsi %30, %35 : index + %37 = arith.muli %29, %c5 : index + %38 = arith.remsi %30, %35 : index + %39 = arith.addi %37, %38 : index + %40 = arith.remsi %29, %c2 : index + %41 = arith.cmpi eq, %40, %c1 : index + %42 = arith.subi %10, %36 : index + %43 = arith.subi %42, %c1 : index + %44 = arith.select %41, %43, %36 : index + %45 = arith.muli %44, %c128 : index + %46 = arith.muli %39, %c256 : index + %c256_0 = arith.constant 256 : index + %c512_1 = arith.constant 512 : index + %c256_2 = arith.constant 256 : index + %c128_3 = arith.constant 128 : index + %47 = pto.partition_view %14, offsets = [%45, %c0], sizes = [%c128_3, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%47 : !pto.partition_tensor_view<128x512xf16>) outs(%17 : !pto.tile_buf) + scf.for %arg7 = %c0 to %12 step %c1 { + %49 = arith.muli %arg7, %c512_1 : index + %50 = arith.remsi %arg7, %c2 : index + %51 = arith.cmpi eq, %50, %c0 : index + scf.if %51 { + %52 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %53 = arith.addi %49, %c0_5 : index + %54 = pto.partition_view %15, offsets = [%53, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%54 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.textract ins(%17, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + scf.if %52 { + pto.tmatmul ins(%21, %23 : !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.textract ins(%17, %c0, %c64 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.textract ins(%17, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.textract ins(%17, %c0, %c192 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c256_12 = arith.constant 256 : index + %55 = arith.addi %49, %c256_12 : index + %56 = pto.partition_view %15, offsets = [%55, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%56 : !pto.partition_tensor_view<256x256xf16>) outs(%20 : !pto.tile_buf) + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.textract ins(%17, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.textract ins(%17, %c0, %c320 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.textract ins(%17, %c0, %c384 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.textract ins(%17, %c0, %c448 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %57 = arith.addi %arg7, %c1 : index + %58 = arith.cmpi slt, %57, %12 : index + scf.if %58 { + %59 = arith.addi %49, %c512_1 : index + %c128_18 = arith.constant 128 : index + %60 = pto.partition_view %14, offsets = [%45, %59], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%60 : !pto.partition_tensor_view<128x512xf16>) outs(%18 : !pto.tile_buf) + } + } else { + %52 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %53 = arith.addi %49, %c0_5 : index + %54 = pto.partition_view %15, offsets = [%53, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%54 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.textract ins(%18, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + scf.if %52 { + pto.tmatmul ins(%21, %23 : !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.textract ins(%18, %c0, %c64 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.textract ins(%18, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.textract ins(%18, %c0, %c192 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c256_12 = arith.constant 256 : index + %55 = arith.addi %49, %c256_12 : index + %56 = pto.partition_view %15, offsets = [%55, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.tload ins(%56 : !pto.partition_tensor_view<256x256xf16>) outs(%20 : !pto.tile_buf) + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.textract ins(%18, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.textract ins(%18, %c0, %c320 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.textract ins(%18, %c0, %c384 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.textract ins(%18, %c0, %c448 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + %57 = arith.addi %arg7, %c1 : index + %58 = arith.cmpi slt, %57, %12 : index + scf.if %58 { + %59 = arith.addi %49, %c512_1 : index + %c128_18 = arith.constant 128 : index + %60 = pto.partition_view %14, offsets = [%45, %59], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.tload ins(%60 : !pto.partition_tensor_view<128x512xf16>) outs(%17 : !pto.tile_buf) + } + } + } + %c128_4 = arith.constant 128 : index + %48 = pto.partition_view %16, offsets = [%45, %46], sizes = [%c128_4, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.tstore ins(%25 : !pto.tile_buf) outs(%48 : !pto.partition_tensor_view<128x256xf16>) + } + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.py new file mode 100644 index 00000000..93c0cf35 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step3_swizzle/step3_swizzle.py @@ -0,0 +1,157 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + SWIZZLE_COUNT, + build_meta_data, + const, + swizzle_nz, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt_autosync( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + c_swizzle = const(SWIZZLE_COUNT) + c_swizzle_m1 = c_swizzle - c1 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2 + ) + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1[0]) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1[h]) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_next) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_ABt_autosync + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/common_utils.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/common_utils.py new file mode 100644 index 00000000..58d8b801 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/common_utils.py @@ -0,0 +1,76 @@ +from ptodsl import pto +from ptodsl import scalar as s + +const = s.const + +M_TILE = 128 +K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 +SWIZZLE_COUNT = 5 + + +def build_meta_data(): + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_2d = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_2d": tv_2d, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_c": tile_view_c, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1": tile_buf_b_l1, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0": tile_buf_b_l0, + "tile_buf_c": tile_buf_c, + } + + return meta_data + + +def swizzle_nz(li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2): + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, c_swizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/compile.sh new file mode 100644 index 00000000..e53612e7 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./step4_manual_pipelining.py > ./step4_manual_pipelining.pto +ptoas ./step4_manual_pipelining.pto -o ./step4_manual_pipelining.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.cpp b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.cpp new file mode 100644 index 00000000..f17463da --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.cpp @@ -0,0 +1,305 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void matmul_kernel_ABt(__gm__ half* v1, __gm__ half* v2, __gm__ half* v3, int32_t v4, int32_t v5, int32_t v6) { + unsigned v7 = 128; + unsigned v8 = 0; + const int32_t v9 = 0; + const int32_t v10 = 1; + const int32_t v11 = 2; + const int32_t v12 = 128; + const int32_t v13 = 256; + const int32_t v14 = 512; + const int32_t v15 = 5; + const int32_t v16 = 64; + const int32_t v17 = 192; + const int32_t v18 = 320; + const int32_t v19 = 384; + const int32_t v20 = 448; + const int32_t v21 = 4; + const int64_t v22 = 0; + const int64_t v23 = 393216; + const int64_t v24 = 262144; + const int64_t v25 = 131072; + const int64_t v26 = 16384; + const int64_t v27 = 32768; + using T = float; + + #if defined(__DAV_CUBE__) + int64_t v28 = get_block_num(); + int32_t v29 = (int32_t) ((int64_t) v28); + int64_t v30 = get_block_idx(); + int32_t v31 = (int32_t) ((int64_t) v30); + int32_t v32 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v5 + (uint32_t) v13) - (uint32_t) v10) / v13; + int32_t v33 = v4 / v12; + int32_t v34 = (int32_t) ((uint32_t) v32 * (uint32_t) v33); + int32_t v35 = v6 / v14; + Tile v36; + TASSIGN(v36, v22); + Tile v37; + TASSIGN(v37, v23); + Tile v38; + TASSIGN(v38, v24); + Tile v39; + TASSIGN(v39, v25); + Tile v40; + TASSIGN(v40, v22); + Tile v41; + TASSIGN(v41, v26); + Tile v42; + TASSIGN(v42, v27); + Tile v43; + TASSIGN(v43, v22); + Tile v44; + TASSIGN(v44, v22); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + for (size_t v45 = (size_t) v31; v45 < ((size_t) v34); v45 += (size_t) v29) { + int32_t v46 = (int32_t) v45; + int32_t v47 = (int32_t) ((uint32_t) v33 * (uint32_t) v15); + int32_t v48 = v46 / v47; + int32_t v49 = v46 % v47; + int32_t v50 = (int32_t) ((uint32_t) v48 * (uint32_t) v15); + int32_t v51 = v48 == (int32_t) ((uint32_t) ((int32_t) ((uint32_t) v32 + (uint32_t) v21) / v15) - (uint32_t) v10) ? (int32_t) ((uint32_t) v32 - (uint32_t) v50) : v15; + int32_t v52 = v49 / v51; + int32_t v53 = (int32_t) ((uint32_t) (v48 % v11 == v10 ? (int32_t) ((uint32_t) ((int32_t) (uint32_t) v33 - (uint32_t) v52) - (uint32_t) v10) : v52) * (uint32_t) v12); + int32_t v54 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v50 + (uint32_t) (v49 % v51)) * (uint32_t) v13); + if (v46 != v31) { + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + unsigned v55 = (unsigned) v6; + unsigned v56 = v7 * v55; + pto::Shape<1, 1, 1, 128, 512> v57 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v58 = pto::Stride<-1, -1, -1, -1, 1>(v56, v56, v56, v55); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v59 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v53 * (unsigned) v6 + v8 * (unsigned) v10), v57, v58); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v36, v59); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v60 = (size_t) v9; v60 < ((size_t) v35); v60 += (size_t) v10) { + int32_t v61 = (int32_t) v60; + int32_t v62 = (int32_t) ((uint32_t) v61 * (uint32_t) v14); + if (v61 % v11 == v9) { + pto::Shape<1, 1, 1, 256, 256> v63 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v64 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v65 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v62 * (unsigned) v10 + (unsigned) v54 * (unsigned) v6), v63, v64); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v38, v65); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v36, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v42, v38, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v61 == v9) { + TMATMUL(v44, v40, v42); + } else { + TMATMUL_ACC(v44, v44, v40, v42); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v36, v9, v16); + TEXTRACT(v43, v38, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v36, v9, v12); + TEXTRACT(v42, v38, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v36, v9, v17); + TEXTRACT(v43, v38, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v66 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v67 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v68 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v54 * (unsigned) v6), v66, v67); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v39, v68); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v36, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v42, v39, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v36, v9, v18); + TEXTRACT(v43, v39, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v36, v9, v19); + TEXTRACT(v42, v39, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v36, v9, v20); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TEXTRACT(v43, v39, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v61 + (uint32_t) v10) < v35) { + unsigned v69 = (unsigned) v6; + unsigned v70 = v7 * v69; + pto::Shape<1, 1, 1, 128, 512> v71 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v72 = pto::Stride<-1, -1, -1, -1, 1>(v70, v70, v70, v69); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v73 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v53 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) v14) * (unsigned) v10), v71, v72); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v37, v73); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + }; + } else { + pto::Shape<1, 1, 1, 256, 256> v74 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v75 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v76 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) v62 * (unsigned) v10 + (unsigned) v54 * (unsigned) v6), v74, v75); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v38, v76); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v40, v37, v9, v9); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + TEXTRACT(v42, v38, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v61 == v9) { + TMATMUL(v44, v40, v42); + } else { + TMATMUL_ACC(v44, v44, v40, v42); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v37, v9, v16); + TEXTRACT(v43, v38, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v37, v9, v12); + TEXTRACT(v42, v38, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v37, v9, v17); + TEXTRACT(v43, v38, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 256, 256> v77 = pto::Shape<1, 1, 1, 256, 256>(); + pto::Stride<256, 256, 256, 1, -1> v78 = pto::Stride<256, 256, 256, 1, -1>((unsigned) v6); + GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN> v79 = GlobalTensor, pto::Stride<256, 256, 256, 1, -1>, pto::Layout::DN>(v2 + (v8 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) v13) * (unsigned) v10 + (unsigned) v54 * (unsigned) v6), v77, v78); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v39, v79); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v37, v9, v13); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v42, v39, v9, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v37, v9, v18); + TEXTRACT(v43, v39, v16, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v40, v37, v9, v19); + TEXTRACT(v42, v39, v12, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v40, v42); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v37, v9, v20); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TEXTRACT(v43, v39, v17, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v44, v44, v41, v43); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + if ((int32_t) ((uint32_t) v61 + (uint32_t) v10) < v35) { + unsigned v80 = (unsigned) v6; + unsigned v81 = v7 * v80; + pto::Shape<1, 1, 1, 128, 512> v82 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<-1, -1, -1, -1, 1> v83 = pto::Stride<-1, -1, -1, -1, 1>(v81, v81, v81, v80); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v84 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v1 + (v8 + (unsigned) v53 * (unsigned) v6 + (unsigned) ((int32_t) (uint32_t) v62 + (uint32_t) v14) * (unsigned) v10), v82, v83); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v36, v84); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + }; + }; + }; + unsigned v85 = (unsigned) v5; + unsigned v86 = v7 * v85; + pto::Shape<1, 1, 1, 128, 256> v87 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<-1, -1, -1, -1, 1> v88 = pto::Stride<-1, -1, -1, -1, 1>(v86, v86, v86, v85); + GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND> v89 = GlobalTensor, pto::Stride<-1, -1, -1, -1, 1>, pto::Layout::ND>(v3 + (v8 + (unsigned) v53 * (unsigned) v5 + (unsigned) v54 * (unsigned) v10), v87, v88); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v89, v44); + if ((int32_t) ((uint32_t) v46 + (uint32_t) v29) < v34) { + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + }; + } + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + #endif // __DAV_CUBE__ + + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.pto b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.pto new file mode 100644 index 00000000..8597ef04 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.pto @@ -0,0 +1,316 @@ +module { + func.func @matmul_kernel_ABt(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %0 = arith.index_cast %arg3 : i32 to index + %1 = arith.index_cast %arg4 : i32 to index + %2 = arith.index_cast %arg5 : i32 to index + %3 = pto.get_block_num + %4 = arith.index_cast %3 : i64 to index + %5 = pto.get_block_idx + %6 = arith.index_cast %5 : i64 to index + %7 = arith.addi %1, %c256 : index + %8 = arith.subi %7, %c1 : index + %9 = arith.divsi %8, %c256 : index + %10 = arith.divsi %0, %c128 : index + %11 = arith.muli %9, %10 : index + %12 = arith.divsi %2, %c512 : index + %c5 = arith.constant 5 : index + %13 = arith.subi %c5, %c1 : index + %14 = pto.make_tensor_view %arg0, shape = [%0, %2], strides = [%2, %c1] : !pto.tensor_view + %15 = pto.make_tensor_view %arg1, shape = [%2, %1], strides = [%c1, %2] {layout = #pto.layout} : !pto.tensor_view + %16 = pto.make_tensor_view %arg2, shape = [%0, %1], strides = [%1, %c1] : !pto.tensor_view + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.alloc_tile : !pto.tile_buf + %21 = pto.alloc_tile : !pto.tile_buf + %22 = pto.alloc_tile : !pto.tile_buf + %23 = pto.alloc_tile : !pto.tile_buf + %24 = pto.alloc_tile : !pto.tile_buf + %25 = pto.alloc_tile : !pto.tile_buf + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg6 = %6 to %11 step %4 { + %26 = arith.addi %9, %13 : index + %27 = arith.divsi %26, %c5 : index + %28 = arith.muli %c5, %10 : index + %29 = arith.divsi %arg6, %28 : index + %30 = arith.remsi %arg6, %28 : index + %31 = arith.subi %27, %c1 : index + %32 = arith.cmpi eq, %29, %31 : index + %33 = arith.muli %c5, %29 : index + %34 = arith.subi %9, %33 : index + %35 = arith.select %32, %34, %c5 : index + %36 = arith.divsi %30, %35 : index + %37 = arith.muli %29, %c5 : index + %38 = arith.remsi %30, %35 : index + %39 = arith.addi %37, %38 : index + %40 = arith.remsi %29, %c2 : index + %41 = arith.cmpi eq, %40, %c1 : index + %42 = arith.subi %10, %36 : index + %43 = arith.subi %42, %c1 : index + %44 = arith.select %41, %43, %36 : index + %45 = arith.muli %44, %c128 : index + %46 = arith.muli %39, %c256 : index + %c256_0 = arith.constant 256 : index + %c512_1 = arith.constant 512 : index + %c256_2 = arith.constant 256 : index + %47 = arith.cmpi ne, %arg6, %6 : index + scf.if %47 { + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + %c128_3 = arith.constant 128 : index + %48 = pto.partition_view %14, offsets = [%45, %c0], sizes = [%c128_3, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%48 : !pto.partition_tensor_view<128x512xf16>) outs(%17 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.for %arg7 = %c0 to %12 step %c1 { + %52 = arith.muli %arg7, %c512_1 : index + %53 = arith.remsi %arg7, %c2 : index + %54 = arith.cmpi eq, %53, %c0 : index + scf.if %54 { + %55 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %56 = arith.addi %52, %c0_5 : index + %57 = pto.partition_view %15, offsets = [%56, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%57 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%19, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %55 { + pto.tmatmul ins(%21, %23 : !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c64 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c192 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_12 = arith.constant 256 : index + %58 = arith.addi %52, %c256_12 : index + %59 = pto.partition_view %15, offsets = [%58, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%59 : !pto.partition_tensor_view<256x256xf16>) outs(%20 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%20, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c320 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c384 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%17, %c0, %c448 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%20, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %60 = arith.addi %arg7, %c1 : index + %61 = arith.cmpi slt, %60, %12 : index + scf.if %61 { + %62 = arith.addi %52, %c512_1 : index + %c128_18 = arith.constant 128 : index + %63 = pto.partition_view %14, offsets = [%45, %62], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%63 : !pto.partition_tensor_view<128x512xf16>) outs(%18 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } else { + %55 = arith.cmpi eq, %arg7, %c0 : index + %c0_5 = arith.constant 0 : index + %56 = arith.addi %52, %c0_5 : index + %57 = pto.partition_view %15, offsets = [%56, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%57 : !pto.partition_tensor_view<256x256xf16>) outs(%19 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c0_6 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%19, %c0_7, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + scf.if %55 { + pto.tmatmul ins(%21, %23 : !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } else { + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + } + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c64 = arith.constant 64 : index + %c64_8 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c64 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c64_8, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c128_9 = arith.constant 128 : index + %c128_10 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c128_9 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%19, %c128_10, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c192 = arith.constant 192 : index + %c192_11 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c192 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%19, %c192_11, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_12 = arith.constant 256 : index + %58 = arith.addi %52, %c256_12 : index + %59 = pto.partition_view %15, offsets = [%58, %46], sizes = [%c256_0, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<256x256xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%59 : !pto.partition_tensor_view<256x256xf16>) outs(%20 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c256_13 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%20, %c0_14, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c320 = arith.constant 320 : index + %c64_15 = arith.constant 64 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c320 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.textract ins(%20, %c64_15, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c384 = arith.constant 384 : index + %c128_16 = arith.constant 128 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c384 : !pto.tile_buf, index, index) outs(%21 : !pto.tile_buf) + pto.textract ins(%20, %c128_16, %c0 : !pto.tile_buf, index, index) outs(%23 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %21, %23 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %c448 = arith.constant 448 : index + %c192_17 = arith.constant 192 : index + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%18, %c0, %c448 : !pto.tile_buf, index, index) outs(%22 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.textract ins(%20, %c192_17, %c0 : !pto.tile_buf, index, index) outs(%24 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tmatmul.acc ins(%25, %22, %24 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + %60 = arith.addi %arg7, %c1 : index + %61 = arith.cmpi slt, %60, %12 : index + scf.if %61 { + %62 = arith.addi %52, %c512_1 : index + %c128_18 = arith.constant 128 : index + %63 = pto.partition_view %14, offsets = [%45, %62], sizes = [%c128_18, %c512_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x512xf16> + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tload ins(%63 : !pto.partition_tensor_view<128x512xf16>) outs(%17 : !pto.tile_buf) + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + } + %c128_4 = arith.constant 128 : index + %49 = pto.partition_view %16, offsets = [%45, %46], sizes = [%c128_4, %c256_2] : !pto.tensor_view -> !pto.partition_tensor_view<128x256xf16> + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.tstore ins(%25 : !pto.tile_buf) outs(%49 : !pto.partition_tensor_view<128x256xf16>) + %50 = arith.addi %arg6, %4 : index + %51 = arith.cmpi slt, %50, %11 : index + scf.if %51 { + pto.record_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + } + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + pto.wait_event[#pto.sync_op_type, #pto.sync_op_type, ] + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.py b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.py new file mode 100644 index 00000000..a92d7ed5 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/matmul_optimization_guide/step4_manual_pipelining/step4_manual_pipelining.py @@ -0,0 +1,202 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + SWIZZLE_COUNT, + build_meta_data, + const, + swizzle_nz, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + c_swizzle = const(SWIZZLE_COUNT) + c_swizzle_m1 = c_swizzle - c1 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) + pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1, 2, 3]) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2 + ) + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + not_first_tile = li != bid + with pto.if_context(not_first_tile): + pto.wait_event("STORE_ACC", "MATMUL", event_id=0) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.load(sv_a0, a_l1[0]) + pto.record_event("LOAD", "MOV_M2L", event_id=0) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(curr_id, next_id, a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + b_evt = 2 + h + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + + pto.wait_event("MOV_M2L", "LOAD", event_id=b_evt) + pto.load(sv_b, b_l1[h]) + pto.record_event("LOAD", "MOV_M2L", event_id=b_evt) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + pto.wait_event("MATMUL", "MOV_M2L", event_id=ping) + if phase == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=curr_id) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + if phase == 7: + pto.record_event( + "MOV_M2L", "LOAD", event_id=curr_id + ) + + if quarter == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=b_evt) + + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + pto.record_event("MOV_M2L", "MATMUL", event_id=0) + + if quarter == 3: + pto.record_event("MOV_M2L", "LOAD", event_id=b_evt) + + pto.wait_event("MOV_M2L", "MATMUL", event_id=0) + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + pto.record_event("MATMUL", "MOV_M2L", event_id=ping) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=next_id) + pto.load(sv_a_next, a_next) + pto.record_event("LOAD", "MOV_M2L", event_id=next_id) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(0, 1, a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(1, 0, a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + pto.store(c_l0, sv_c) + + with pto.if_context(li + num_blocks < core_loop): + pto.record_event("STORE_ACC", "MATMUL", event_id=0) + + pto.wait_event("MOV_M2L", "LOAD", event_id=3) + pto.wait_event("MOV_M2L", "LOAD", event_id=2) + pto.wait_event("MOV_M2L", "LOAD", event_id=1) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=1) + + return matmul_kernel_ABt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.cpp b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.cpp new file mode 100644 index 00000000..a2319712 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.cpp @@ -0,0 +1,78 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void vec_add_kernel_2d_dynamic(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4, int32_t v5) { + unsigned v6 = 0; + const int32_t v7 = 32; + const int32_t v8 = 1; + const int64_t v9 = 0; + const int64_t v10 = 4096; + const int64_t v11 = 8192; + using T = float; + int64_t v12 = get_block_idx(); + int64_t v13 = get_subblockid(); + int64_t v14 = get_subblockdim(); + int32_t v15 = (int32_t) ((uint32_t) ((int32_t) (int64_t) ((uint64_t) ((int64_t) (uint64_t) ((int64_t) v12) * (uint64_t) ((int64_t) v14)) + (uint64_t) ((int64_t) v13))) * (uint32_t) v7); + pto::Shape<1, 1, 1, 32, 32> v16 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v17 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v18 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v15 * (unsigned) v7 + v6 * (unsigned) v8), v16, v17); + pto::Shape<1, 1, 1, 32, 32> v19 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v20 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v21 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v15 * (unsigned) v7 + v6 * (unsigned) v8), v19, v20); + pto::Shape<1, 1, 1, 32, 32> v22 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v23 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v24 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v3 + (v6 + (unsigned) v15 * (unsigned) v7 + v6 * (unsigned) v8), v22, v23); + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v25; + TASSIGN(v25, v9); + Tile v26 = Tile(v4, v5); + __ubuf__ float* v27 = v25.data(); + uint64_t v28 = reinterpret_cast(v27); + TASSIGN(v26, v28); + Tile v29; + TASSIGN(v29, v10); + Tile v30 = Tile(v4, v5); + __ubuf__ float* v31 = v29.data(); + uint64_t v32 = reinterpret_cast(v31); + TASSIGN(v30, v32); + Tile v33; + TASSIGN(v33, v11); + Tile v34 = Tile(v4, v5); + __ubuf__ float* v35 = v33.data(); + uint64_t v36 = reinterpret_cast(v35); + TASSIGN(v34, v36); + TLOAD(v26, v18); + TLOAD(v30, v21); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v34, v26, v30); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v24, v34); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.pto b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.pto new file mode 100644 index 00000000..7f152410 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add.pto @@ -0,0 +1,34 @@ +module { + func.func @vec_add_kernel_2d_dynamic(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: i32, %arg4: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1280 = arith.constant 1280 : index + %0 = pto.get_block_idx + %1 = pto.get_subblock_idx + %2 = pto.get_subblock_num + %3 = arith.muli %0, %2 : i64 + %4 = arith.addi %3, %1 : i64 + %5 = arith.index_cast %arg3 : i32 to index + %6 = arith.index_cast %arg4 : i32 to index + %7 = pto.make_tensor_view %arg0, shape = [%c1280, %c32], strides = [%c32, %c1] : !pto.tensor_view + %8 = pto.make_tensor_view %arg1, shape = [%c1280, %c32], strides = [%c32, %c1] : !pto.tensor_view + %9 = pto.make_tensor_view %arg2, shape = [%c1280, %c32], strides = [%c32, %c1] : !pto.tensor_view + %10 = arith.index_cast %4 : i64 to index + %11 = arith.muli %10, %c32 : index + %12 = pto.partition_view %7, offsets = [%11, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %13 = pto.partition_view %8, offsets = [%11, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %14 = pto.partition_view %9, offsets = [%11, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + pto.section.vector { + %15 = pto.alloc_tile valid_row = %5 valid_col = %6 : !pto.tile_buf + %16 = pto.alloc_tile valid_row = %5 valid_col = %6 : !pto.tile_buf + %17 = pto.alloc_tile valid_row = %5 valid_col = %6 : !pto.tile_buf + pto.tload ins(%12 : !pto.partition_tensor_view<32x32xf32>) outs(%15 : !pto.tile_buf) + pto.tload ins(%13 : !pto.partition_tensor_view<32x32xf32>) outs(%16 : !pto.tile_buf) + pto.tadd ins(%15, %16 : !pto.tile_buf, !pto.tile_buf) outs(%17 : !pto.tile_buf) + pto.tstore ins(%17 : !pto.tile_buf) outs(%14 : !pto.partition_tensor_view<32x32xf32>) + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add_builder.py new file mode 100644 index 00000000..1c790077 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/add_builder.py @@ -0,0 +1,87 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + # common, reusable type declarations + dtype = pto.float32 + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + tensor_type = pto.TensorType(rank=2, dtype=dtype) + subtensor_type = pto.SubTensorType( + shape=[32, 32], dtype=dtype + ) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 + tile_cfg = pto.TileBufConfig() + # defaults to pto.TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null") + tile_type = pto.TileBufType( + shape=[32, 32], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + } + + +@to_ir_module(meta_data=meta_data) +def vec_add_kernel_2d_dynamic( + arg0: "ptr_type", + arg1: "ptr_type", + arg2: "ptr_type", + arg_vrow_i32: "index_dtype", + arg_vcol_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c32 = const(32) + c1280 = const(1280) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + cidmul = cid * sub_bnum + vid = cidmul + sub_bid + + v_row_idx = s.index_cast(arg_vrow_i32) + v_col_idx = s.index_cast(arg_vcol_i32) + + tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[c1280, c32], strides=[c32, c1]) + tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1280, c32], strides=[c32, c1]) + tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1280, c32], strides=[c32, c1]) + + vid_idx = s.index_cast(vid) + offset_row = vid_idx * c32 # every core loads 32 rows of data + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv2 = pto.slice_view( + subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32] + ) + + with pto.vector_section(): + tb0 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + tb1 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + tb2 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + + pto.load(sv0, tb0) + pto.load(sv1, tb1) + tile.add(tb0, tb1, tb2) + pto.store(tb2, sv2) + + # `default `return None` maps to `func.ReturnOp([])` + + +if __name__ == "__main__": + module = vec_add_kernel_2d_dynamic + print(module) diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/compile.sh new file mode 100644 index 00000000..2eacc832 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/add_static_multicore/add/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./add_builder.py > ./add.pto +ptoas --enable-insert-sync ./add.pto -o ./add.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/compile.sh b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/compile.sh new file mode 100644 index 00000000..c6fe9a7a --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/compile.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -e +python ./matmul_builder.py > matmul.pto +ptoas --enable-insert-sync matmul.pto -o matmul.cpp diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.cpp b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.cpp new file mode 100644 index 00000000..76aa5f69 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.cpp @@ -0,0 +1,132 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +__global__ AICORE void RunTMATMULSplitK(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, bool v5) { + unsigned v6 = 0; + const int32_t v7 = 0; + const int32_t v8 = 1; + const int32_t v9 = 32; + const int32_t v10 = 256; + const int32_t v11 = 8; + const int64_t v12 = 0; + const int64_t v13 = 4096; + const int64_t v14 = 8192; + using T = float; + + #if defined(__DAV_CUBE__) + Tile v15; + TASSIGN(v15, v12); + Tile v16; + TASSIGN(v16, v13); + Tile v17; + TASSIGN(v17, v14); + Tile v18; + TASSIGN(v18, v12); + Tile v19; + TASSIGN(v19, v12); + Tile v20; + TASSIGN(v20, v12); + Tile v21; + TASSIGN(v21, v12); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + for (size_t v22 = (size_t) v7; v22 < ((size_t) v11); v22 += (size_t) v8) { + int32_t v23 = (int32_t) v22; + int32_t v24 = (int32_t) ((uint32_t) v23 * (uint32_t) v9); + pto::Shape<1, 1, 1, 32, 32> v25 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<8192, 8192, 8192, 256, 1> v26 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v27 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v10 + (unsigned) v24 * (unsigned) v8), v25, v26); + pto::Shape<1, 1, 1, 32, 32> v28 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v29 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v30 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v3 + (v6 + (unsigned) v24 * (unsigned) v9 + v6 * (unsigned) v8), v28, v29); + pto::Shape<1, 1, 1, 1, 32> v31 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v32 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v33 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v4 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v31, v32); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v15, v27); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v16, v30); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + if (v5) { + pipe_barrier(PIPE_MTE2); + TLOAD(v17, v33); + }; + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + pipe_barrier(PIPE_MTE1); + TMOV(v18, v15); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + TMOV(v19, v16); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + if (v5) { + TMOV(v21, v17); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v23 == v7) { + if (v5) { + TMATMUL_BIAS(v20, v18, v19, v21); + } else { + TMATMUL(v20, v18, v19); + }; + } else { + TMATMUL_ACC(v20, v20, v18, v19); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + } + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 32, 32> v34 = pto::Shape<1, 1, 1, 32, 32>(); + pto::Stride<1024, 1024, 1024, 32, 1> v35 = pto::Stride<1024, 1024, 1024, 32, 1>(); + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v36 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v34, v35); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v36, v20); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.pto b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.pto new file mode 100644 index 00000000..9931e8af --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul.pto @@ -0,0 +1,56 @@ +module { + func.func @RunTMATMULSplitK(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: i1) { + pto.section.cube { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c32_0 = arith.constant 32 : index + %c32_1 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c32_2 = arith.constant 32 : index + %c32_3 = arith.constant 32 : index + %0 = pto.make_tensor_view %arg1, shape = [%c32, %c256], strides = [%c256, %c1] : !pto.tensor_view + %1 = pto.make_tensor_view %arg2, shape = [%c256, %c32_0], strides = [%c32_0, %c1] : !pto.tensor_view + %2 = pto.make_tensor_view %arg0, shape = [%c32, %c32_0], strides = [%c32_0, %c1] : !pto.tensor_view + %3 = pto.make_tensor_view %arg3, shape = [%c1, %c32_0], strides = [%c32_0, %c1] : !pto.tensor_view + %4 = pto.alloc_tile : !pto.tile_buf + %5 = pto.alloc_tile : !pto.tile_buf + %6 = pto.alloc_tile : !pto.tile_buf + %7 = pto.alloc_tile : !pto.tile_buf + %8 = pto.alloc_tile : !pto.tile_buf + %9 = pto.alloc_tile : !pto.tile_buf + %10 = pto.alloc_tile : !pto.tile_buf + scf.for %arg5 = %c0 to %c8 step %c1 { + %12 = arith.muli %arg5, %c32_1 : index + %13 = pto.partition_view %0, offsets = [%c0, %12], sizes = [%c32_2, %c32_1] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %14 = pto.partition_view %1, offsets = [%12, %c0], sizes = [%c32_1, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %15 = pto.partition_view %3, offsets = [%c0, %c0], sizes = [%c1, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<1x32xf32> + pto.tload ins(%13 : !pto.partition_tensor_view<32x32xf32>) outs(%4 : !pto.tile_buf) + pto.tload ins(%14 : !pto.partition_tensor_view<32x32xf32>) outs(%5 : !pto.tile_buf) + scf.if %arg4 { + pto.tload ins(%15 : !pto.partition_tensor_view<1x32xf32>) outs(%6 : !pto.tile_buf) + } + pto.tmov ins(%4 : !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.tmov ins(%5 : !pto.tile_buf) outs(%8 : !pto.tile_buf) + scf.if %arg4 { + pto.tmov ins(%6 : !pto.tile_buf) outs(%10 : !pto.tile_buf) + } + %16 = arith.cmpi eq, %arg5, %c0 : index + scf.if %16 { + scf.if %arg4 { + pto.tmatmul.bias ins(%7, %8, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + } else { + pto.tmatmul ins(%7, %8 : !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + } + } else { + pto.tmatmul.acc ins(%9, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + } + } + %11 = pto.partition_view %2, offsets = [%c0, %c0], sizes = [%c32_2, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + pto.tstore ins(%9 : !pto.tile_buf) outs(%11 : !pto.partition_tensor_view<32x32xf32>) + } + return + } +} + diff --git a/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul_builder.py b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul_builder.py new file mode 100644 index 00000000..6b11b015 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/simple_static/matmul_static_singlecore/matmul/matmul_builder.py @@ -0,0 +1,169 @@ +# adapted from https://github.com/zhangstevenunity/PTOAS/blob/a301aa43b388d9b2e1ba0db8773b3a719e8c445b/test/samples/MatMul/tmatmulk.py + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +def build( + M=32, + K=256, + N=32, + validM=32, + validK=256, + validN=32, + BASEK=32, +): + assert K % BASEK == 0 + iters = K // BASEK + + def meta_data(): + dtype = pto.float32 + i1 = pto.bool + ptr_type = pto.PtrType(dtype) + + tensor_type = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M, BASEK], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[BASEK, N], dtype=dtype) + tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) + tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) + + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") + tile_buf_biasTile = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="BIAS" + ) + + return { + "ptr_type": ptr_type, + "i1": i1, + "tensor_type": tensor_type, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_out": tile_view_out, + "tile_view_bias": tile_view_bias, + "tile_buf_aMat": tile_buf_aMat, + "tile_buf_bMat": tile_buf_bMat, + "tile_buf_biasData": tile_buf_biasData, + "tile_buf_aTile": tile_buf_aTile, + "tile_buf_bTile": tile_buf_bTile, + "tile_buf_cTile": tile_buf_cTile, + "tile_buf_biasTile": tile_buf_biasTile, + } + + const = s.const + + @to_ir_module(meta_data=meta_data) + def RunTMATMULSplitK( + out_ptr: "ptr_type", + a_ptr: "ptr_type", + b_ptr: "ptr_type", + bias_ptr: "ptr_type", + isBias: "i1", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + cM = const(validM) + cK = const(validK) + cN = const(validN) + cBASEK = const(BASEK) + cIter = const(iters) + cTileM = const(M) + cTileN = const(N) + + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) + + aMatTile = pto.alloc_tile(tile_buf_aMat) + bMatTile = pto.alloc_tile(tile_buf_bMat) + biasDataTile = pto.alloc_tile(tile_buf_biasData) + aTile = pto.alloc_tile(tile_buf_aTile) + bTile = pto.alloc_tile(tile_buf_bTile) + cTile = pto.alloc_tile(tile_buf_cTile) + biasTile = pto.alloc_tile(tile_buf_biasTile) + + for i in pto.range(c0, cIter, c1): + kOff = i * cBASEK + svA = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[c0, kOff], + sizes=[cTileM, cBASEK], + ) + svB = pto.slice_view( + tile_view_b, + source=tvB, + offsets=[kOff, c0], + sizes=[cBASEK, cTileN], + ) + svBias = pto.slice_view( + tile_view_bias, + source=tvBias, + offsets=[c0, c0], + sizes=[c1, cTileN], + ) + + pto.load(svA, aMatTile) + pto.load(svB, bMatTile) + with pto.if_context(isBias): + pto.load(svBias, biasDataTile) + + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) + with pto.if_context(isBias): + tile.mov(biasDataTile, biasTile) + + is_i0 = s.eq(i, c0) + + def _first_iter(): + pto.cond( + isBias, + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), + ) + + pto.cond( + is_i0, + _first_iter, + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), + ) + + svOut = pto.slice_view( + tile_view_out, + source=tvOut, + offsets=[c0, c0], + sizes=[cTileM, cTileN], + ) + pto.store(cTile, svOut) + + module = RunTMATMULSplitK + return module + + +if __name__ == "__main__": + print(build())