diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_gate_up_matmul.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_gate_up_matmul.cpp new file mode 100644 index 000000000..93eaef9ad --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_gate_up_matmul.cpp @@ -0,0 +1,402 @@ +// Kernel Function: exp_gate_up_matmul +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void exp_gate_up_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int8_t* v3, __gm__ int32_t* v4, __gm__ int32_t* v5, int32_t v6, int32_t v7) { + unsigned v8 = 0; + const int32_t v9 = 128; + const int32_t v10 = 0; + const int32_t v11 = 512; + const int32_t v12 = 16777216; + const int32_t v13 = 256; + const int32_t v14 = 1; + const int32_t v15 = 4096; + const int32_t v16 = 16; + const int64_t v17 = 32768; + const int64_t v18 = 2048; + const int64_t v19 = 16384; + const int64_t v20 = 8192; + const int64_t v21 = 0; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v22 = (size_t) v13; + size_t v23 = (size_t) v11; + size_t v24 = (size_t) v10; + Tile v25 = Tile(v16, v11); + uint64_t v26 = (uint64_t) v21; + TASSIGN(v25, v26); + pto::Shape<1, 1, 1, 16, 512> v27 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v28 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v29 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v8 + v8 * (unsigned) v15 + v8 * (unsigned) v14), v27, v28); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v25, v29); + Tile v30 = Tile(v16, v11); + uint64_t v31 = (uint64_t) v21; + TASSIGN(v30, v31); + pipe_barrier(PIPE_MTE2); + TLOAD(v30, v29); + Tile v32 = Tile(v11, v13); + uint64_t v33 = (uint64_t) v20; + TASSIGN(v32, v33); + pto::Shape<1, 1, 1, 512, 256> v34 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v35 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v36 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v8 + (unsigned) v6 * (unsigned) v12) + v8 * (unsigned) v14 + (unsigned) v7 * (unsigned) v15), v34, v35); + TLOAD(v32, v36); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + Tile v37 = Tile(v16, v13); + uint64_t v38 = (uint64_t) v19; + TASSIGN(v37, v38); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + for (size_t v39 = v24; v39 < v23; v39 += v22) { + int32_t v40 = (int32_t) v39; + Tile v41 = Tile(v16, v9); + uint64_t v42 = (uint64_t) v18; + TASSIGN(v41, v42); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + 
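+      // K-reduction over the 512-deep tile in steps of 256: each iteration
+      // extracts two 16x128 activation sub-tiles from v30 and two 128x256
+      // weight sub-tiles from v32, double-buffered through the
+      // EVENT_ID0/EVENT_ID1 ping-pong between the MTE1 and M pipes. The
+      // first slice (v40 == 0) initializes the accumulator with TMATMUL;
+      // every later slice accumulates with TMATMUL_ACC.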
pipe_barrier(PIPE_MTE1); + TEXTRACT(v41, v30, v10, v39); + Tile v43 = Tile(v9, v13); + uint64_t v44 = (uint64_t) v21; + TASSIGN(v43, v44); + TEXTRACT(v43, v32, v39, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + Tile v45 = Tile(v16, v9); + uint64_t v46 = (uint64_t) v21; + TASSIGN(v45, v46); + int32_t v47 = (int32_t) ((uint32_t) v40 + (uint32_t) v9); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v45, v30, v10, v47); + Tile v48 = Tile(v9, v13); + uint64_t v49 = (uint64_t) v17; + TASSIGN(v48, v49); + TEXTRACT(v48, v32, v47, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v40 == v10) { + Tile v50 = Tile(v16, v13); + uint64_t v51 = (uint64_t) v19; + TASSIGN(v50, v51); + pipe_barrier(PIPE_M); + TMATMUL(v50, v41, v43); + } else { + Tile v52 = Tile(v16, v13); + uint64_t v53 = (uint64_t) v19; + TASSIGN(v52, v53); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v52, v52, v41, v43); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v54 = Tile(v16, v13); + uint64_t v55 = (uint64_t) v19; + TASSIGN(v54, v55); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL_ACC(v54, v54, v45, v48); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + Tile v56 = Tile(v16, v11); + uint64_t v57 = (uint64_t) v21; + TASSIGN(v56, v57); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + pipe_barrier(PIPE_MTE2); + TLOAD(v56, v29); + Tile v58 = Tile(v11, v13); + uint64_t v59 = (uint64_t) v20; + TASSIGN(v58, v59); + pto::Shape<1, 1, 1, 512, 256> v60 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v61 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v62 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v3 + ((v8 + (unsigned) v6 * (unsigned) v12) + v8 * (unsigned) v14 + (unsigned) v7 * (unsigned) v15), v60, v61); + TLOAD(v58, v62); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + Tile v63 = Tile(v16, v13); + uint64_t v64 = (uint64_t) v21; + TASSIGN(v63, v64); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + for (size_t v65 = v24; v65 < v23; v65 += v22) { + int32_t v66 = (int32_t) v65; + Tile v67 = Tile(v16, v9); + uint64_t v68 = (uint64_t) v18; + TASSIGN(v67, v68); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v67, v56, v10, v65); + Tile v69 = Tile(v9, v13); + uint64_t v70 = (uint64_t) v21; + TASSIGN(v69, v70); + TEXTRACT(v69, v58, v65, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + Tile v71 = Tile(v16, v9); + uint64_t v72 = (uint64_t) v21; + TASSIGN(v71, v72); + int32_t v73 = (int32_t) ((uint32_t) v66 + (uint32_t) v9); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v71, v56, v10, v73); + Tile v74 = Tile(v9, v13); + uint64_t v75 = (uint64_t) v17; + TASSIGN(v74, v75); + TEXTRACT(v74, v58, v73, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + if (v66 == v10) { + Tile v76 = Tile(v16, v13); + uint64_t v77 = (uint64_t) v21; + TASSIGN(v76, v77); + pipe_barrier(PIPE_M); + TMATMUL(v76, v67, v69); + } else { + Tile v78 = Tile(v16, v13); + uint64_t v79 = (uint64_t) v21; + TASSIGN(v78, v79); + pipe_barrier(PIPE_M); + 
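+        // Same first-slice-vs-accumulate split for the second (w3/up)
+        // projection: its accumulator lives at offset v21 (0), separate
+        // from the w1/gate accumulator at v19 (16384), so both results
+        // survive until the final TSTOREs.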
TMATMUL_ACC(v78, v78, v67, v69); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + Tile v80 = Tile(v16, v13); + uint64_t v81 = (uint64_t) v21; + TASSIGN(v80, v81); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v80, v80, v71, v74); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + for (size_t v82 = v23; v82 < ((size_t) v15); v82 += v23) { + int32_t v83 = (int32_t) v82; + Tile v84 = Tile(v16, v11); + uint64_t v85 = (uint64_t) v21; + TASSIGN(v84, v85); + pto::Shape<1, 1, 1, 16, 512> v86 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v87 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v88 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v8 + v8 * (unsigned) v15 + (unsigned) v83 * (unsigned) v14), v86, v87); + pipe_barrier(PIPE_MTE2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v84, v88); + Tile v89 = Tile(v16, v11); + uint64_t v90 = (uint64_t) v21; + TASSIGN(v89, v90); + pipe_barrier(PIPE_MTE2); + TLOAD(v89, v88); + Tile v91 = Tile(v11, v13); + uint64_t v92 = (uint64_t) v20; + TASSIGN(v91, v92); + pto::Shape<1, 1, 1, 512, 256> v93 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v94 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v95 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v8 + (unsigned) v6 * (unsigned) v12) + (unsigned) v83 * (unsigned) v14 + (unsigned) v7 * (unsigned) v15), v93, v94); + TLOAD(v91, v95); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + for (size_t v96 = v24; v96 < v23; v96 += v22) { + Tile v97 = Tile(v16, v9); + uint64_t v98 = (uint64_t) v18; + TASSIGN(v97, v98); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v97, v89, v10, v96); + Tile v99 = Tile(v9, v13); + uint64_t v100 = (uint64_t) v21; + TASSIGN(v99, v100); + TEXTRACT(v99, v91, v96, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + Tile v101 = Tile(v16, v9); + uint64_t v102 = (uint64_t) v21; + TASSIGN(v101, v102); + int32_t v103 = (int32_t) ((uint32_t) ((int32_t) v96) + (uint32_t) v9); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v101, v89, v10, v103); + Tile v104 = Tile(v9, v13); + uint64_t v105 = (uint64_t) v17; + TASSIGN(v104, v105); + TEXTRACT(v104, v91, v103, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + Tile v106 = Tile(v16, v13); + uint64_t v107 = (uint64_t) v19; + TASSIGN(v106, v107); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v106, v106, v97, v99); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + Tile v108 = Tile(v16, v13); + uint64_t v109 = (uint64_t) v19; + TASSIGN(v108, v109); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v108, v108, v101, v104); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + }; + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, 
PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v110 = Tile(v16, v11); + uint64_t v111 = (uint64_t) v21; + TASSIGN(v110, v111); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + pipe_barrier(PIPE_MTE2); + TLOAD(v110, v88); + Tile v112 = Tile(v11, v13); + uint64_t v113 = (uint64_t) v20; + TASSIGN(v112, v113); + pto::Shape<1, 1, 1, 512, 256> v114 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v115 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v116 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v3 + ((v8 + (unsigned) v6 * (unsigned) v12) + (unsigned) v83 * (unsigned) v14 + (unsigned) v7 * (unsigned) v15), v114, v115); + TLOAD(v112, v116); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + for (size_t v117 = v24; v117 < v23; v117 += v22) { + Tile v118 = Tile(v16, v9); + uint64_t v119 = (uint64_t) v18; + TASSIGN(v118, v119); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v118, v110, v10, v117); + Tile v120 = Tile(v9, v13); + uint64_t v121 = (uint64_t) v21; + TASSIGN(v120, v121); + TEXTRACT(v120, v112, v117, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + Tile v122 = Tile(v16, v9); + uint64_t v123 = (uint64_t) v21; + TASSIGN(v122, v123); + int32_t v124 = (int32_t) ((uint32_t) ((int32_t) v117) + (uint32_t) v9); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v122, v110, v10, v124); + Tile v125 = Tile(v9, v13); + uint64_t v126 = (uint64_t) v17; + TASSIGN(v125, v126); + TEXTRACT(v125, v112, v124, v10); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + Tile v127 = Tile(v16, v13); + uint64_t v128 = (uint64_t) v21; + TASSIGN(v127, v128); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v127, v127, v118, v120); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v129 = Tile(v16, v13); + uint64_t v130 = (uint64_t) v21; + TASSIGN(v129, v130); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + TMATMUL_ACC(v129, v129, v122, v125); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + }; + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v131 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v132 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v133 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v4 + ((v8 + v8 * (unsigned) v15) + v8 * (unsigned) v13 + v8 * (unsigned) v14), v131, v132); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v133, v37); + pto::Shape<1, 1, 1, 16, 256> v134 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v135 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v136 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v5 + ((v8 + v8 * (unsigned) v15) + v8 * (unsigned) v13 + v8 * (unsigned) v14), v134, v135); + TSTORE(v136, v63); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + #endif // __DAV_CUBE__ + + 
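+    // Emitted outside the #ifdef: even builds where the __DAV_CUBE__ body
+    // is compiled out (e.g. __CPU_SIM) still drain all pipes via
+    // pipe_barrier(PIPE_ALL) before control returns to kernel_entry.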
ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: recv_x_tile_i8_inline75__rv_v2 + __gm__ Tensor* recv_x_tile_i8_inline75__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int8_t* recv_x_tile_i8_inline75__rv_v2 = reinterpret_cast<__gm__ int8_t*>(recv_x_tile_i8_inline75__rv_v2_tensor->buffer.addr) + recv_x_tile_i8_inline75__rv_v2_tensor->start_offset; + + // Unpack tensor: expert_w1__ssa_v0 + __gm__ Tensor* expert_w1__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* expert_w1__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(expert_w1__ssa_v0_tensor->buffer.addr) + expert_w1__ssa_v0_tensor->start_offset; + + // Unpack tensor: expert_w3__ssa_v0 + __gm__ Tensor* expert_w3__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int8_t* expert_w3__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(expert_w3__ssa_v0_tensor->buffer.addr) + expert_w3__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack tensor: ret1__out + __gm__ Tensor* ret1__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ int32_t* ret1__out = reinterpret_cast<__gm__ int32_t*>(ret1__out_tensor->buffer.addr) + ret1__out_tensor->start_offset; + + // Unpack scalar: local_i_inline67__idx_v0 + union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; + local_i_inline67__idx_v0_conv.u64 = args[5]; + int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; + + // Unpack scalar: n0_inline72__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline72__idx_v0_conv; + n0_inline72__idx_v0_conv.u64 = args[6]; + int64_t n0_inline72__idx_v0 = n0_inline72__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_gate_up_matmul(recv_x_tile_i8_inline75__rv_v2, expert_w1__ssa_v0, expert_w3__ssa_v0, ret0__out, ret1__out, local_i_inline67__idx_v0, n0_inline72__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_w2_matmul.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_w2_matmul.cpp new file mode 100644 index 000000000..a62a84424 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_w2_matmul.cpp @@ -0,0 +1,245 @@ +// Kernel Function: exp_w2_matmul +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void exp_w2_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int32_t* v3, int32_t v4, int32_t v5) { + unsigned v6 = 0; + const int32_t v7 = 64; + const int32_t v8 = 
128; + const int32_t v9 = 0; + const int32_t v10 = 16777216; + const int32_t v11 = 512; + const int32_t v12 = 1; + const int32_t v13 = 4096; + const int32_t v14 = 16; + const int64_t v15 = 32768; + const int64_t v16 = 1024; + const int64_t v17 = 8192; + const int64_t v18 = 0; + const int32_t v19 = 8192; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v20 = (size_t) v11; + size_t v21 = (size_t) v9; + size_t v22 = (size_t) v8; + Tile v23 = Tile(v14, v11); + uint64_t v24 = (uint64_t) v18; + TASSIGN(v23, v24); + pto::Shape<1, 1, 1, 16, 512> v25 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v26 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v27 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + v6 * (unsigned) v12), v25, v26); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TLOAD(v23, v27); + Tile v28 = Tile(v11, v11); + uint64_t v29 = (uint64_t) v17; + TASSIGN(v28, v29); + pto::Shape<1, 1, 1, 512, 512> v30 = pto::Shape<1, 1, 1, 512, 512>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v31 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v32 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + v6 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v30, v31); + TLOAD(v28, v32); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + Tile v33 = Tile(v14, v11); + uint64_t v34 = (uint64_t) v18; + TASSIGN(v33, v34); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v35 = v21; v35 < v20; v35 += v22) { + int32_t v36 = (int32_t) v35; + Tile v37 = Tile(v14, v7); + uint64_t v38 = (uint64_t) v18; + TASSIGN(v37, v38); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v37, v23, v9, v35); + Tile v39 = Tile(v7, v11); + uint64_t v40 = (uint64_t) v18; + TASSIGN(v39, v40); + TEXTRACT(v39, v28, v35, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + Tile v41 = Tile(v14, v7); + uint64_t v42 = (uint64_t) v16; + TASSIGN(v41, v42); + int32_t v43 = (int32_t) ((uint32_t) v36 + (uint32_t) v7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v41, v23, v9, v43); + Tile v44 = Tile(v7, v11); + uint64_t v45 = (uint64_t) v15; + TASSIGN(v44, v45); + TEXTRACT(v44, v28, v43, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v36 == v9) { + Tile v46 = Tile(v14, v11); + uint64_t v47 = (uint64_t) v18; + TASSIGN(v46, v47); + pipe_barrier(PIPE_M); + TMATMUL(v46, v37, v39); + } else { + Tile v48 = Tile(v14, v11); + uint64_t v49 = (uint64_t) v18; + TASSIGN(v48, v49); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v48, v48, v37, v39); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v50 = Tile(v14, v11); + uint64_t v51 = (uint64_t) v18; + TASSIGN(v50, v51); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL_ACC(v50, v50, v41, v44); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + for (size_t v52 = v20; v52 < ((size_t) v13); v52 += v20) { + int32_t v53 = 
(int32_t) v52; + Tile v54 = Tile(v14, v11); + uint64_t v55 = (uint64_t) v18; + TASSIGN(v54, v55); + pto::Shape<1, 1, 1, 16, 512> v56 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v57 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v58 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + (unsigned) v53 * (unsigned) v12), v56, v57); + pipe_barrier(PIPE_MTE2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v54, v58); + Tile v59 = Tile(v11, v11); + uint64_t v60 = (uint64_t) v17; + TASSIGN(v59, v60); + pto::Shape<1, 1, 1, 512, 512> v61 = pto::Shape<1, 1, 1, 512, 512>(); + pto::Stride<16777216, 16777216, 16777216, 1, 4096> v62 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); + GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v63 = GlobalTensor, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + (unsigned) v53 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v61, v62); + TLOAD(v59, v63); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + for (size_t v64 = v21; v64 < v20; v64 += v22) { + Tile v65 = Tile(v14, v7); + uint64_t v66 = (uint64_t) v18; + TASSIGN(v65, v66); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v65, v54, v9, v64); + Tile v67 = Tile(v7, v11); + uint64_t v68 = (uint64_t) v18; + TASSIGN(v67, v68); + TEXTRACT(v67, v59, v64, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + Tile v69 = Tile(v14, v7); + uint64_t v70 = (uint64_t) v16; + TASSIGN(v69, v70); + int32_t v71 = (int32_t) ((uint32_t) ((int32_t) v64) + (uint32_t) v7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v69, v54, v9, v71); + Tile v72 = Tile(v7, v11); + uint64_t v73 = (uint64_t) v15; + TASSIGN(v72, v73); + TEXTRACT(v72, v59, v71, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + Tile v74 = Tile(v14, v11); + uint64_t v75 = (uint64_t) v18; + TASSIGN(v74, v75); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v74, v74, v65, v67); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + Tile v76 = Tile(v14, v11); + uint64_t v77 = (uint64_t) v18; + TASSIGN(v76, v77); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v76, v76, v69, v72); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + } + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v78 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v79 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v80 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v3 + ((v6 + v6 * (unsigned) v19) + v6 * (unsigned) v11 + v6 * (unsigned) v12), v78, v79); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v80, v33); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: h_tile_i8_inline92__rv_v2 + __gm__ Tensor* h_tile_i8_inline92__rv_v2_tensor 
= reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int8_t* h_tile_i8_inline92__rv_v2 = reinterpret_cast<__gm__ int8_t*>(h_tile_i8_inline92__rv_v2_tensor->buffer.addr) + h_tile_i8_inline92__rv_v2_tensor->start_offset; + + // Unpack tensor: expert_w2__ssa_v0 + __gm__ Tensor* expert_w2__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* expert_w2__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(expert_w2__ssa_v0_tensor->buffer.addr) + expert_w2__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: local_i_inline67__idx_v0 + union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; + local_i_inline67__idx_v0_conv.u64 = args[3]; + int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; + + // Unpack scalar: d0_inline49__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline49__idx_v0_conv; + d0_inline49__idx_v0_conv.u64 = args[4]; + int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, local_i_inline67__idx_v0, d0_inline49__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_gate_up_matmul.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_gate_up_matmul.cpp new file mode 100644 index 000000000..5bbc279d0 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_gate_up_matmul.cpp @@ -0,0 +1,373 @@ +// Kernel Function: sh_gate_up_matmul +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_gate_up_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int8_t* v3, __gm__ int32_t* v4, __gm__ int32_t* v5, int32_t v6) { + unsigned v7 = 0; + const int32_t v8 = 128; + const int32_t v9 = 0; + const int32_t v10 = 512; + const int32_t v11 = 256; + const int32_t v12 = 1; + const int32_t v13 = 4096; + const int32_t v14 = 16; + const int64_t v15 = 32768; + const int64_t v16 = 2048; + const int64_t v17 = 16384; + const int64_t v18 = 139264; + const int64_t v19 = 8192; + const int64_t v20 = 0; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v21 = (size_t) v11; + size_t v22 = (size_t) v10; + size_t v23 = (size_t) v9; + Tile v24 = Tile(v14, v10); + uint64_t v25 = (uint64_t) v20; + TASSIGN(v24, v25); + pto::Shape<1, 1, 1, 16, 512> v26 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v27 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v28 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, 
pto::Layout::ND>(v1 + (v7 + v7 * (unsigned) v13 + v7 * (unsigned) v12), v26, v27); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v24, v28); + Tile v29 = Tile(v10, v11); + uint64_t v30 = (uint64_t) v19; + TASSIGN(v29, v30); + pto::Shape<1, 1, 1, 512, 256> v31 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<512, 512, 512, 1, 4096> v32 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v33 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v2 + (v7 + v7 * (unsigned) v12 + (unsigned) v6 * (unsigned) v13), v31, v32); + TLOAD(v29, v33); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + Tile v34 = Tile(v10, v11); + uint64_t v35 = (uint64_t) v18; + TASSIGN(v34, v35); + pto::Shape<1, 1, 1, 512, 256> v36 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<512, 512, 512, 1, 4096> v37 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v38 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v12 + (unsigned) v6 * (unsigned) v13), v36, v37); + TLOAD(v34, v38); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + Tile v39 = Tile(v14, v11); + uint64_t v40 = (uint64_t) v17; + TASSIGN(v39, v40); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + for (size_t v41 = v23; v41 < v22; v41 += v21) { + int32_t v42 = (int32_t) v41; + Tile v43 = Tile(v14, v8); + uint64_t v44 = (uint64_t) v16; + TASSIGN(v43, v44); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v43, v24, v9, v41); + Tile v45 = Tile(v8, v11); + uint64_t v46 = (uint64_t) v20; + TASSIGN(v45, v46); + TEXTRACT(v45, v29, v41, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + Tile v47 = Tile(v14, v8); + uint64_t v48 = (uint64_t) v20; + TASSIGN(v47, v48); + int32_t v49 = (int32_t) ((uint32_t) v42 + (uint32_t) v8); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v47, v24, v9, v49); + Tile v50 = Tile(v8, v11); + uint64_t v51 = (uint64_t) v15; + TASSIGN(v50, v51); + TEXTRACT(v50, v29, v49, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v42 == v9) { + Tile v52 = Tile(v14, v11); + uint64_t v53 = (uint64_t) v17; + TASSIGN(v52, v53); + pipe_barrier(PIPE_M); + TMATMUL(v52, v43, v45); + } else { + Tile v54 = Tile(v14, v11); + uint64_t v55 = (uint64_t) v17; + TASSIGN(v54, v55); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v54, v54, v43, v45); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v56 = Tile(v14, v11); + uint64_t v57 = (uint64_t) v17; + TASSIGN(v56, v57); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL_ACC(v56, v56, v47, v50); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + Tile v58 = Tile(v14, v11); + uint64_t v59 = (uint64_t) v20; + TASSIGN(v58, v59); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + for (size_t v60 = v23; v60 < v22; v60 += v21) { + int32_t v61 = (int32_t) v60; + Tile v62 = Tile(v14, v8); + uint64_t v63 = (uint64_t) v16; + TASSIGN(v62, v63); + pipe_barrier(PIPE_MTE1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v62, v24, v9, v60); + Tile v64 = Tile(v8, v11); + uint64_t v65 = (uint64_t) v20; + TASSIGN(v64, v65); + 
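+            // Second K-loop reuses the same activation tile v24 but reads
+            // its B-operand from the prefetched w3 tile v34 and accumulates
+            // at offset v20 (0) instead of v17 (16384), keeping the gate
+            // and up results in distinct accumulator regions.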
TEXTRACT(v64, v34, v60, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + Tile v66 = Tile(v14, v8); + uint64_t v67 = (uint64_t) v20; + TASSIGN(v66, v67); + int32_t v68 = (int32_t) ((uint32_t) v61 + (uint32_t) v8); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v66, v24, v9, v68); + Tile v69 = Tile(v8, v11); + uint64_t v70 = (uint64_t) v15; + TASSIGN(v69, v70); + TEXTRACT(v69, v34, v68, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + if (v61 == v9) { + Tile v71 = Tile(v14, v11); + uint64_t v72 = (uint64_t) v20; + TASSIGN(v71, v72); + pipe_barrier(PIPE_M); + TMATMUL(v71, v62, v64); + } else { + Tile v73 = Tile(v14, v11); + uint64_t v74 = (uint64_t) v20; + TASSIGN(v73, v74); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v73, v73, v62, v64); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + Tile v75 = Tile(v14, v11); + uint64_t v76 = (uint64_t) v20; + TASSIGN(v75, v76); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v75, v75, v66, v69); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + for (size_t v77 = v22; v77 < ((size_t) v13); v77 += v22) { + int32_t v78 = (int32_t) v77; + Tile v79 = Tile(v14, v10); + uint64_t v80 = (uint64_t) v20; + TASSIGN(v79, v80); + pto::Shape<1, 1, 1, 16, 512> v81 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v82 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v83 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v7 + v7 * (unsigned) v13 + (unsigned) v78 * (unsigned) v12), v81, v82); + pipe_barrier(PIPE_MTE2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v79, v83); + Tile v84 = Tile(v10, v11); + uint64_t v85 = (uint64_t) v19; + TASSIGN(v84, v85); + pto::Shape<1, 1, 1, 512, 256> v86 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<512, 512, 512, 1, 4096> v87 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v88 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v2 + (v7 + (unsigned) v78 * (unsigned) v12 + (unsigned) v6 * (unsigned) v13), v86, v87); + TLOAD(v84, v88); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + Tile v89 = Tile(v10, v11); + uint64_t v90 = (uint64_t) v18; + TASSIGN(v89, v90); + pto::Shape<1, 1, 1, 512, 256> v91 = pto::Shape<1, 1, 1, 512, 256>(); + pto::Stride<512, 512, 512, 1, 4096> v92 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v93 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v3 + (v7 + (unsigned) v78 * (unsigned) v12 + (unsigned) v6 * (unsigned) v13), v91, v92); + TLOAD(v89, v93); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + for (size_t v94 = v23; v94 < v22; v94 += v21) { + Tile v95 = Tile(v14, v8); + uint64_t v96 = (uint64_t) v16; + TASSIGN(v95, v96); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TEXTRACT(v95, v79, v9, v94); + Tile v97 = Tile(v8, v11); + uint64_t v98 = (uint64_t) v20; + 
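+              // TASSIGN apparently binds each tile handle to a fixed
+              // local-buffer offset (v20 == 0 here) before TEXTRACT fills
+              // it; the generated code manages on-chip placement purely
+              // through these constant offsets rather than an allocator.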
TASSIGN(v97, v98); + TEXTRACT(v97, v84, v94, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + Tile v99 = Tile(v14, v8); + uint64_t v100 = (uint64_t) v20; + TASSIGN(v99, v100); + int32_t v101 = (int32_t) ((uint32_t) ((int32_t) v94) + (uint32_t) v8); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v99, v79, v9, v101); + Tile v102 = Tile(v8, v11); + uint64_t v103 = (uint64_t) v15; + TASSIGN(v102, v103); + TEXTRACT(v102, v84, v101, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + Tile v104 = Tile(v14, v11); + uint64_t v105 = (uint64_t) v17; + TASSIGN(v104, v105); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v104, v104, v95, v97); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + Tile v106 = Tile(v14, v11); + uint64_t v107 = (uint64_t) v17; + TASSIGN(v106, v107); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v106, v106, v99, v102); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + }; + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + for (size_t v108 = v23; v108 < v22; v108 += v21) { + Tile v109 = Tile(v14, v8); + uint64_t v110 = (uint64_t) v16; + TASSIGN(v109, v110); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TEXTRACT(v109, v79, v9, v108); + Tile v111 = Tile(v8, v11); + uint64_t v112 = (uint64_t) v20; + TASSIGN(v111, v112); + TEXTRACT(v111, v89, v108, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + Tile v113 = Tile(v14, v8); + uint64_t v114 = (uint64_t) v20; + TASSIGN(v113, v114); + int32_t v115 = (int32_t) ((uint32_t) ((int32_t) v108) + (uint32_t) v8); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v113, v79, v9, v115); + Tile v116 = Tile(v8, v11); + uint64_t v117 = (uint64_t) v15; + TASSIGN(v116, v117); + TEXTRACT(v116, v89, v115, v9); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + Tile v118 = Tile(v14, v11); + uint64_t v119 = (uint64_t) v20; + TASSIGN(v118, v119); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v118, v118, v109, v111); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v120 = Tile(v14, v11); + uint64_t v121 = (uint64_t) v20; + TASSIGN(v120, v121); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + TMATMUL_ACC(v120, v120, v113, v116); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + }; + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v122 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v123 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v124 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v4 + (v7 + v7 * (unsigned) v11 + v7 * (unsigned) v12), v122, v123); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v124, v39); + pto::Shape<1, 1, 1, 16, 256> v125 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v126 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v127 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v5 + (v7 + v7 * (unsigned) v11 + v7 * 
(unsigned) v12), v125, v126); + TSTORE(v127, v58); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_local_i8_inline43__rv_v2 + __gm__ Tensor* x_local_i8_inline43__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int8_t* x_local_i8_inline43__rv_v2 = reinterpret_cast<__gm__ int8_t*>(x_local_i8_inline43__rv_v2_tensor->buffer.addr) + x_local_i8_inline43__rv_v2_tensor->start_offset; + + // Unpack tensor: shared_w1__ssa_v0 + __gm__ Tensor* shared_w1__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* shared_w1__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(shared_w1__ssa_v0_tensor->buffer.addr) + shared_w1__ssa_v0_tensor->start_offset; + + // Unpack tensor: shared_w3__ssa_v0 + __gm__ Tensor* shared_w3__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int8_t* shared_w3__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(shared_w3__ssa_v0_tensor->buffer.addr) + shared_w3__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack tensor: ret1__out + __gm__ Tensor* ret1__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ int32_t* ret1__out = reinterpret_cast<__gm__ int32_t*>(ret1__out_tensor->buffer.addr) + ret1__out_tensor->start_offset; + + // Unpack scalar: n0_inline113__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline113__idx_v0_conv; + n0_inline113__idx_v0_conv.u64 = args[5]; + int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_gate_up_matmul(x_local_i8_inline43__rv_v2, shared_w1__ssa_v0, shared_w3__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_w2_matmul.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_w2_matmul.cpp new file mode 100644 index 000000000..3c3299c6e --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aic/sh_w2_matmul.cpp @@ -0,0 +1,238 @@ +// Kernel Function: sh_w2_matmul +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_w2_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int32_t* v3, int32_t v4) { + unsigned v5 = 0; + const int32_t v6 = 64; + const int32_t v7 = 128; + const int32_t v8 = 0; + const int32_t v9 = 512; + const int32_t v10 = 1; + const int32_t v11 = 4096; + const int32_t v12 = 16; + const int64_t v13 = 32768; + const 
int64_t v14 = 1024; + const int64_t v15 = 8192; + const int64_t v16 = 0; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v17 = (size_t) v9; + size_t v18 = (size_t) v8; + size_t v19 = (size_t) v7; + Tile v20 = Tile(v12, v9); + uint64_t v21 = (uint64_t) v16; + TASSIGN(v20, v21); + pto::Shape<1, 1, 1, 16, 512> v22 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v23 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v24 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v11 + v5 * (unsigned) v10), v22, v23); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TLOAD(v20, v24); + Tile v25 = Tile(v9, v9); + uint64_t v26 = (uint64_t) v15; + TASSIGN(v25, v26); + pto::Shape<1, 1, 1, 512, 512> v27 = pto::Shape<1, 1, 1, 512, 512>(); + pto::Stride<512, 512, 512, 1, 4096> v28 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v29 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v2 + (v5 + v5 * (unsigned) v10 + (unsigned) v4 * (unsigned) v11), v27, v28); + TLOAD(v25, v29); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + Tile v30 = Tile(v12, v9); + uint64_t v31 = (uint64_t) v16; + TASSIGN(v30, v31); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + for (size_t v32 = v18; v32 < v17; v32 += v19) { + int32_t v33 = (int32_t) v32; + Tile v34 = Tile(v12, v6); + uint64_t v35 = (uint64_t) v16; + TASSIGN(v34, v35); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + pipe_barrier(PIPE_MTE1); + TEXTRACT(v34, v20, v8, v32); + Tile v36 = Tile(v6, v9); + uint64_t v37 = (uint64_t) v16; + TASSIGN(v36, v37); + TEXTRACT(v36, v25, v32, v8); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + Tile v38 = Tile(v12, v6); + uint64_t v39 = (uint64_t) v14; + TASSIGN(v38, v39); + int32_t v40 = (int32_t) ((uint32_t) v33 + (uint32_t) v6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TEXTRACT(v38, v20, v8, v40); + Tile v41 = Tile(v6, v9); + uint64_t v42 = (uint64_t) v13; + TASSIGN(v41, v42); + TEXTRACT(v41, v25, v40, v8); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + if (v33 == v8) { + Tile v43 = Tile(v12, v9); + uint64_t v44 = (uint64_t) v16; + TASSIGN(v43, v44); + pipe_barrier(PIPE_M); + TMATMUL(v43, v34, v36); + } else { + Tile v45 = Tile(v12, v9); + uint64_t v46 = (uint64_t) v16; + TASSIGN(v45, v46); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v45, v45, v34, v36); + }; + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + Tile v47 = Tile(v12, v9); + uint64_t v48 = (uint64_t) v16; + TASSIGN(v47, v48); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL_ACC(v47, v47, v38, v41); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + for (size_t v49 = v17; v49 < ((size_t) v11); v49 += v17) { + int32_t v50 = (int32_t) v49; + Tile v51 = Tile(v12, v9); + uint64_t v52 = (uint64_t) v16; + TASSIGN(v51, v52); + pto::Shape<1, 1, 1, 16, 512> v53 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v54 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> 
v55 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v11 + (unsigned) v50 * (unsigned) v10), v53, v54); + pipe_barrier(PIPE_MTE2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v51, v55); + Tile v56 = Tile(v9, v9); + uint64_t v57 = (uint64_t) v15; + TASSIGN(v56, v57); + pto::Shape<1, 1, 1, 512, 512> v58 = pto::Shape<1, 1, 1, 512, 512>(); + pto::Stride<512, 512, 512, 1, 4096> v59 = pto::Stride<512, 512, 512, 1, 4096>(); + GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN> v60 = GlobalTensor, pto::Stride<512, 512, 512, 1, 4096>, pto::Layout::DN>(v2 + (v5 + (unsigned) v50 * (unsigned) v10 + (unsigned) v4 * (unsigned) v11), v58, v59); + TLOAD(v56, v60); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + for (size_t v61 = v18; v61 < v17; v61 += v19) { + Tile v62 = Tile(v12, v6); + uint64_t v63 = (uint64_t) v16; + TASSIGN(v62, v63); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TEXTRACT(v62, v51, v8, v61); + Tile v64 = Tile(v6, v9); + uint64_t v65 = (uint64_t) v16; + TASSIGN(v64, v65); + TEXTRACT(v64, v56, v61, v8); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + Tile v66 = Tile(v12, v6); + uint64_t v67 = (uint64_t) v14; + TASSIGN(v66, v67); + int32_t v68 = (int32_t) ((uint32_t) ((int32_t) v61) + (uint32_t) v6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TEXTRACT(v66, v51, v8, v68); + Tile v69 = Tile(v6, v9); + uint64_t v70 = (uint64_t) v13; + TASSIGN(v69, v70); + TEXTRACT(v69, v56, v68, v8); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + Tile v71 = Tile(v12, v9); + uint64_t v72 = (uint64_t) v16; + TASSIGN(v71, v72); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + pipe_barrier(PIPE_M); + TMATMUL_ACC(v71, v71, v62, v64); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + Tile v73 = Tile(v12, v9); + uint64_t v74 = (uint64_t) v16; + TASSIGN(v73, v74); + pipe_barrier(PIPE_M); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL_ACC(v73, v73, v66, v69); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + }; + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + } + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v75 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v76 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v77 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v3 + (v5 + v5 * (unsigned) v9 + v5 * (unsigned) v10), v75, v76); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v77, v30); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: sh_tile_i8_inline126__rv_v2 + __gm__ Tensor* sh_tile_i8_inline126__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int8_t* sh_tile_i8_inline126__rv_v2 = reinterpret_cast<__gm__ int8_t*>(sh_tile_i8_inline126__rv_v2_tensor->buffer.addr) + sh_tile_i8_inline126__rv_v2_tensor->start_offset; + + // Unpack tensor: shared_w2__ssa_v0 + __gm__ Tensor* shared_w2__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* shared_w2__ssa_v0 = reinterpret_cast<__gm__ 
int8_t*>(shared_w2__ssa_v0_tensor->buffer.addr) + shared_w2__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: d0_inline64__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline64__idx_v0_conv; + d0_inline64__idx_v0_conv.u64 = args[3]; + int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_w2_matmul(sh_tile_i8_inline126__rv_v2, shared_w2__ssa_v0, ret0__out, d0_inline64__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/cast_x.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/cast_x.cpp new file mode 100644 index 000000000..66adc9169 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/cast_x.cpp @@ -0,0 +1,103 @@ +// Kernel Function: cast_x +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void cast_x(__gm__ bfloat16_t* v1, __gm__ float* v2, int32_t v3) { + RoundMode v4 = RoundMode::CAST_ROUND; + unsigned v5 = 0; + const int32_t v6 = 512; + const int32_t v7 = 1; + const int32_t v8 = 16384; + const int32_t v9 = 16; + const int64_t v10 = 16384; + const int64_t v11 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v12 = Tile(v9, v6); + uint64_t v13 = (uint64_t) v11; + TASSIGN(v12, v13); + pto::Shape<1, 1, 1, 16, 512> v14 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v15 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v16 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v8 + (unsigned) v3 * (unsigned) v7), v14, v15); + TLOAD(v12, v16); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v17 = Tile(v9, v6); + uint64_t v18 = (uint64_t) v10; + TASSIGN(v17, v18); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v17, v12, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v19 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v20 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v21 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v2 + (v5 + v5 * (unsigned) v8 + (unsigned) v3 * (unsigned) v7), v19, v20); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v21, v17); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ 
__attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_flat_inline36__ssa_v0 + __gm__ Tensor* x_flat_inline36__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* x_flat_inline36__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(x_flat_inline36__ssa_v0_tensor->buffer.addr) + x_flat_inline36__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_flat_fp32_inline49__iter_v1 + __gm__ Tensor* x_flat_fp32_inline49__iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* x_flat_fp32_inline49__iter_v1 = reinterpret_cast<__gm__ float*>(x_flat_fp32_inline49__iter_v1_tensor->buffer.addr) + x_flat_fp32_inline49__iter_v1_tensor->start_offset; + + // Unpack scalar: k0_inline59__ssa_v0 + union { uint64_t u64; int64_t val; } k0_inline59__ssa_v0_conv; + k0_inline59__ssa_v0_conv.u64 = args[2]; + int64_t k0_inline59__ssa_v0 = k0_inline59__ssa_v0_conv.val; + + // Forward to ptoas-generated function + cast_x(x_flat_inline36__ssa_v0, x_flat_fp32_inline49__iter_v1, k0_inline59__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/comb_sinkhorn.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/comb_sinkhorn.cpp new file mode 100644 index 000000000..5279a3ac9 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/comb_sinkhorn.cpp @@ -0,0 +1,3326 @@ +// Kernel Function: comb_sinkhorn +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void comb_sinkhorn(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 12; + unsigned v4 = 8; + unsigned v5 = 4; + unsigned v6 = 0; + const int32_t v7 = 255; + const int32_t v8 = 251; + const int32_t v9 = 247; + const int32_t v10 = 243; + const int32_t v11 = 254; + const int32_t v12 = 250; + const int32_t v13 = 246; + const int32_t v14 = 242; + const int32_t v15 = 253; + const int32_t v16 = 249; + const int32_t v17 = 245; + const int32_t v18 = 241; + const int32_t v19 = 252; + const int32_t v20 = 248; + const int32_t v21 = 244; + const int32_t v22 = 240; + const int32_t v23 = 239; + const int32_t v24 = 235; + const int32_t v25 = 231; + const int32_t v26 = 227; + const int32_t v27 = 238; + const int32_t v28 = 234; + const int32_t v29 = 230; + const int32_t v30 = 226; + const int32_t v31 = 237; + const int32_t v32 = 233; + const int32_t v33 = 229; + const int32_t v34 = 225; + const int32_t v35 = 236; + const int32_t v36 = 232; + const int32_t v37 = 228; + const int32_t v38 = 224; + const int32_t v39 = 223; + const int32_t v40 = 219; + const int32_t v41 = 215; + const int32_t v42 = 211; + const int32_t v43 = 222; + const int32_t v44 = 218; + const int32_t v45 = 214; + const int32_t v46 = 210; + const int32_t v47 = 221; + const int32_t v48 = 217; + const int32_t v49 = 213; + const int32_t v50 = 209; + const 
int32_t v51 = 220; + const int32_t v52 = 216; + const int32_t v53 = 212; + const int32_t v54 = 208; + const int32_t v55 = 207; + const int32_t v56 = 203; + const int32_t v57 = 199; + const int32_t v58 = 195; + const int32_t v59 = 206; + const int32_t v60 = 202; + const int32_t v61 = 198; + const int32_t v62 = 194; + const int32_t v63 = 205; + const int32_t v64 = 201; + const int32_t v65 = 197; + const int32_t v66 = 193; + const int32_t v67 = 204; + const int32_t v68 = 200; + const int32_t v69 = 196; + const int32_t v70 = 192; + const int32_t v71 = 191; + const int32_t v72 = 187; + const int32_t v73 = 183; + const int32_t v74 = 179; + const int32_t v75 = 190; + const int32_t v76 = 186; + const int32_t v77 = 182; + const int32_t v78 = 178; + const int32_t v79 = 189; + const int32_t v80 = 185; + const int32_t v81 = 181; + const int32_t v82 = 177; + const int32_t v83 = 188; + const int32_t v84 = 184; + const int32_t v85 = 180; + const int32_t v86 = 176; + const int32_t v87 = 175; + const int32_t v88 = 171; + const int32_t v89 = 167; + const int32_t v90 = 163; + const int32_t v91 = 174; + const int32_t v92 = 170; + const int32_t v93 = 166; + const int32_t v94 = 162; + const int32_t v95 = 173; + const int32_t v96 = 169; + const int32_t v97 = 165; + const int32_t v98 = 161; + const int32_t v99 = 172; + const int32_t v100 = 168; + const int32_t v101 = 164; + const int32_t v102 = 160; + const int32_t v103 = 159; + const int32_t v104 = 155; + const int32_t v105 = 151; + const int32_t v106 = 147; + const int32_t v107 = 158; + const int32_t v108 = 154; + const int32_t v109 = 150; + const int32_t v110 = 146; + const int32_t v111 = 157; + const int32_t v112 = 153; + const int32_t v113 = 149; + const int32_t v114 = 145; + const int32_t v115 = 156; + const int32_t v116 = 152; + const int32_t v117 = 148; + const int32_t v118 = 144; + const int32_t v119 = 143; + const int32_t v120 = 139; + const int32_t v121 = 135; + const int32_t v122 = 131; + const int32_t v123 = 142; + const int32_t v124 = 138; + const int32_t v125 = 134; + const int32_t v126 = 130; + const int32_t v127 = 141; + const int32_t v128 = 137; + const int32_t v129 = 133; + const int32_t v130 = 129; + const int32_t v131 = 140; + const int32_t v132 = 136; + const int32_t v133 = 132; + const int32_t v134 = 128; + const int32_t v135 = 127; + const int32_t v136 = 123; + const int32_t v137 = 119; + const int32_t v138 = 115; + const int32_t v139 = 126; + const int32_t v140 = 122; + const int32_t v141 = 118; + const int32_t v142 = 114; + const int32_t v143 = 125; + const int32_t v144 = 121; + const int32_t v145 = 117; + const int32_t v146 = 113; + const int32_t v147 = 124; + const int32_t v148 = 120; + const int32_t v149 = 116; + const int32_t v150 = 112; + const int32_t v151 = 111; + const int32_t v152 = 107; + const int32_t v153 = 103; + const int32_t v154 = 99; + const int32_t v155 = 110; + const int32_t v156 = 106; + const int32_t v157 = 102; + const int32_t v158 = 98; + const int32_t v159 = 109; + const int32_t v160 = 105; + const int32_t v161 = 101; + const int32_t v162 = 97; + const int32_t v163 = 108; + const int32_t v164 = 104; + const int32_t v165 = 100; + const int32_t v166 = 96; + const int32_t v167 = 95; + const int32_t v168 = 91; + const int32_t v169 = 87; + const int32_t v170 = 83; + const int32_t v171 = 94; + const int32_t v172 = 90; + const int32_t v173 = 86; + const int32_t v174 = 82; + const int32_t v175 = 93; + const int32_t v176 = 89; + const int32_t v177 = 85; + const int32_t v178 = 81; + const int32_t v179 = 92; + const int32_t 
+ const int32_t v181 = 84;
+ const int32_t v182 = 80;
+ const int32_t v183 = 79;
+ const int32_t v184 = 75;
+ const int32_t v185 = 71;
+ const int32_t v186 = 67;
+ const int32_t v187 = 78;
+ const int32_t v188 = 74;
+ const int32_t v189 = 70;
+ const int32_t v190 = 66;
+ const int32_t v191 = 77;
+ const int32_t v192 = 73;
+ const int32_t v193 = 69;
+ const int32_t v194 = 65;
+ const int32_t v195 = 76;
+ const int32_t v196 = 72;
+ const int32_t v197 = 68;
+ const int32_t v198 = 64;
+ const int32_t v199 = 63;
+ const int32_t v200 = 59;
+ const int32_t v201 = 55;
+ const int32_t v202 = 51;
+ const int32_t v203 = 62;
+ const int32_t v204 = 58;
+ const int32_t v205 = 54;
+ const int32_t v206 = 50;
+ const int32_t v207 = 61;
+ const int32_t v208 = 57;
+ const int32_t v209 = 53;
+ const int32_t v210 = 49;
+ const int32_t v211 = 60;
+ const int32_t v212 = 56;
+ const int32_t v213 = 52;
+ const int32_t v214 = 48;
+ const int32_t v215 = 47;
+ const int32_t v216 = 43;
+ const int32_t v217 = 39;
+ const int32_t v218 = 35;
+ const int32_t v219 = 46;
+ const int32_t v220 = 42;
+ const int32_t v221 = 38;
+ const int32_t v222 = 34;
+ const int32_t v223 = 45;
+ const int32_t v224 = 41;
+ const int32_t v225 = 37;
+ const int32_t v226 = 33;
+ const int32_t v227 = 44;
+ const int32_t v228 = 40;
+ const int32_t v229 = 36;
+ const int32_t v230 = 32;
+ const int32_t v231 = 31;
+ const int32_t v232 = 27;
+ const int32_t v233 = 23;
+ const int32_t v234 = 19;
+ const int32_t v235 = 30;
+ const int32_t v236 = 26;
+ const int32_t v237 = 22;
+ const int32_t v238 = 18;
+ const int32_t v239 = 29;
+ const int32_t v240 = 25;
+ const int32_t v241 = 21;
+ const int32_t v242 = 17;
+ const int32_t v243 = 28;
+ const int32_t v244 = 24;
+ const int32_t v245 = 20;
+ const int32_t v246 = 15;
+ const int32_t v247 = 11;
+ const int32_t v248 = 7;
+ const int32_t v249 = 3;
+ const int32_t v250 = 14;
+ const int32_t v251 = 10;
+ const int32_t v252 = 6;
+ const int32_t v253 = 2;
+ const int32_t v254 = 13;
+ const int32_t v255 = 9;
+ const int32_t v256 = 5;
+ const float v257 = 9.99999997E-7f;
+ const int32_t v258 = 12;
+ const int32_t v259 = 8;
+ const int32_t v260 = 0;
+ const int32_t v261 = 4;
+ const int32_t v262 = 1;
+ const int32_t v263 = 16;
+ const int64_t v264 = 4800;
+ const int64_t v265 = 3776;
+ const int64_t v266 = 3264;
+ const int64_t v267 = 2752;
+ const int64_t v268 = 2240;
+ const int64_t v269 = 1728;
+ const int64_t v270 = 1216;
+ const int64_t v271 = 128;
+ const int64_t v272 = 64;
+ const int64_t v273 = 0;
+ const int64_t v274 = 5376;
+ const int64_t v275 = 4864;
+ const int64_t v276 = 4288;
+ const int64_t v277 = 704;
+ const int64_t v278 = 192;
+ using T = float;
+
+ #if defined(__DAV_VEC__)
+ set_mask_norm();
+ set_vector_mask(-1, -1);
+ // Vector path: stream four 16x4 tiles in from GM, zero-padding each to 16x8 in UB.
+ Tile v279 = Tile(v263, v261);
+ uint64_t v280 = (uint64_t) v278;
+ TASSIGN(v279, v280);
+ pto::Shape<1, 1, 1, 16, 4> v281 = pto::Shape<1, 1, 1, 16, 4>();
+ pto::Stride<256, 256, 256, 16, 1> v282 = pto::Stride<256, 256, 256, 16, 1>();
+ GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v283 = GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v263 + v6 * (unsigned) v262), v281, v282);
+ TLOAD(v279, v283);
+ set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+ Tile v284 = Tile(v263, v259);
+ uint64_t v285 = (uint64_t) v277;
+ TASSIGN(v284, v285);
+ wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+ TFILLPAD(v284, v279);
+ set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+ Tile v286 = Tile(v263, v261);
+ uint64_t v287 = (uint64_t) v278;
+ TASSIGN(v286, v287);
+ pto::Shape<1, 1, 1, 16, 4> v288 = pto::Shape<1, 1, 1, 16, 4>();
+ pto::Stride<256, 256, 256, 16, 1> v289 = pto::Stride<256, 256, 256, 16, 1>();
+ GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v290 = GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v263 + v5 * (unsigned) v262), v288, v289);
+ wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+ TLOAD(v286, v290);
+ set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+ Tile v291 = Tile(v263, v259);
+ uint64_t v292 = (uint64_t) v276;
+ TASSIGN(v291, v292);
+ wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+ TFILLPAD(v291, v286);
+ set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
+ Tile v293 = Tile(v263, v261);
+ uint64_t v294 = (uint64_t) v278;
+ TASSIGN(v293, v294);
+ pto::Shape<1, 1, 1, 16, 4> v295 = pto::Shape<1, 1, 1, 16, 4>();
+ pto::Stride<256, 256, 256, 16, 1> v296 = pto::Stride<256, 256, 256, 16, 1>();
+ GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v297 = GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v263 + v4 * (unsigned) v262), v295, v296);
+ wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
+ TLOAD(v293, v297);
+ set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
+ Tile v298 = Tile(v263, v259);
+ uint64_t v299 = (uint64_t) v275;
+ TASSIGN(v298, v299);
+ wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
+ TFILLPAD(v298, v293);
+ set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+ Tile v300 = Tile(v263, v261);
+ uint64_t v301 = (uint64_t) v278;
+ TASSIGN(v300, v301);
+ pto::Shape<1, 1, 1, 16, 4> v302 = pto::Shape<1, 1, 1, 16, 4>();
+ pto::Stride<256, 256, 256, 16, 1> v303 = pto::Stride<256, 256, 256, 16, 1>();
+ GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v304 = GlobalTensor<T, pto::Shape<1, 1, 1, 16, 4>, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v263 + v3 * (unsigned) v262), v302, v303);
+ wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+ TLOAD(v300, v304);
+ set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
+ Tile v305 = Tile(v263, v259);
+ uint64_t v306 = (uint64_t) v274;
+ TASSIGN(v305, v306);
+ wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
+ TFILLPAD(v305, v300);
+ // Numerically stable softmax on each padded tile: row max, broadcast subtract, exp.
+ Tile v307 = Tile(v263, v262);
+ uint64_t v308 = (uint64_t) v273;
+ TASSIGN(v307, v308);
+ Tile v309 = Tile(v263, v262);
+ uint64_t v310 = (uint64_t) v272;
+ TASSIGN(v309, v310);
+ Tile v311 = Tile(v263, v262);
+ uint64_t v312 = (uint64_t) v271;
+ TASSIGN(v311, v312);
+ TROWMAX(v311, v284, v307);
+ Tile v313 = Tile(v263, v259);
+ uint64_t v314 = (uint64_t) v277;
+ TASSIGN(v313, v314);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDSUB(v313, v284, v311);
+ Tile v315 = Tile(v263, v259);
+ uint64_t v316 = (uint64_t) v277;
+ TASSIGN(v315, v316);
+ pipe_barrier(PIPE_V);
+ TEXP(v315, v313);
+ Tile v317 = Tile(v263, v262);
+ uint64_t v318 = (uint64_t) v271;
+ TASSIGN(v317, v318);
+ TROWMAX(v317, v291, v307);
+ Tile v319 = Tile(v263, v259);
+ uint64_t v320 = (uint64_t) v276;
+ TASSIGN(v319, v320);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDSUB(v319, v291, v317);
+ Tile v321 = Tile(v263, v259);
+ uint64_t v322 = (uint64_t) v276;
+ TASSIGN(v321, v322);
+ pipe_barrier(PIPE_V);
+ TEXP(v321, v319);
+ Tile v323 = Tile(v263, v262);
+ uint64_t v324 = (uint64_t) v271;
+ TASSIGN(v323, v324);
+ TROWMAX(v323, v298, v307);
+ Tile v325 = Tile(v263, v259);
+ uint64_t v326 = (uint64_t) v275;
+ TASSIGN(v325, v326);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDSUB(v325, v298, v323);
+ Tile v327 = Tile(v263, v259);
+ uint64_t v328 = (uint64_t) v275;
+ TASSIGN(v327, v328);
+ pipe_barrier(PIPE_V);
+ TEXP(v327, v325);
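+ // The fourth tile goes through the same row-max / subtract / exp sequence below.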
+ Tile v329 = Tile(v263, v262);
+ uint64_t v330 = (uint64_t) v273;
+ TASSIGN(v329, v330);
+ TROWMAX(v329, v305, v307);
+ Tile v331 = Tile(v263, v259);
+ uint64_t v332 = (uint64_t) v274;
+ TASSIGN(v331, v332);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDSUB(v331, v305, v329);
+ Tile v333 = Tile(v263, v259);
+ uint64_t v334 = (uint64_t) v274;
+ TASSIGN(v333, v334);
+ pipe_barrier(PIPE_V);
+ TEXP(v333, v331);
+ // Normalize each exponentiated tile by its row sum, then add eps (v257) to avoid zeros.
+ Tile v335 = Tile(v263, v262);
+ uint64_t v336 = (uint64_t) v273;
+ TASSIGN(v335, v336);
+ TROWSUM(v335, v315, v309);
+ Tile v337 = Tile(v263, v259);
+ uint64_t v338 = (uint64_t) v277;
+ TASSIGN(v337, v338);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v337, v315, v335);
+ Tile v339 = Tile(v263, v259);
+ uint64_t v340 = (uint64_t) v277;
+ TASSIGN(v339, v340);
+ pipe_barrier(PIPE_V);
+ TADDS(v339, v337, v257);
+ Tile v341 = Tile(v263, v262);
+ uint64_t v342 = (uint64_t) v273;
+ TASSIGN(v341, v342);
+ TROWSUM(v341, v321, v309);
+ Tile v343 = Tile(v263, v259);
+ uint64_t v344 = (uint64_t) v276;
+ TASSIGN(v343, v344);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v343, v321, v341);
+ Tile v345 = Tile(v263, v259);
+ uint64_t v346 = (uint64_t) v276;
+ TASSIGN(v345, v346);
+ pipe_barrier(PIPE_V);
+ TADDS(v345, v343, v257);
+ Tile v347 = Tile(v263, v262);
+ uint64_t v348 = (uint64_t) v273;
+ TASSIGN(v347, v348);
+ TROWSUM(v347, v327, v309);
+ Tile v349 = Tile(v263, v259);
+ uint64_t v350 = (uint64_t) v275;
+ TASSIGN(v349, v350);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v349, v327, v347);
+ Tile v351 = Tile(v263, v259);
+ uint64_t v352 = (uint64_t) v275;
+ TASSIGN(v351, v352);
+ pipe_barrier(PIPE_V);
+ TADDS(v351, v349, v257);
+ Tile v353 = Tile(v263, v262);
+ uint64_t v354 = (uint64_t) v273;
+ TASSIGN(v353, v354);
+ TROWSUM(v353, v333, v309);
+ Tile v355 = Tile(v263, v259);
+ uint64_t v356 = (uint64_t) v274;
+ TASSIGN(v355, v356);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v355, v333, v353);
+ Tile v357 = Tile(v263, v259);
+ uint64_t v358 = (uint64_t) v274;
+ TASSIGN(v357, v358);
+ pipe_barrier(PIPE_V);
+ TADDS(v357, v355, v257);
+ // Shrink the normalized tiles back to 16x4, re-pad, sum all four, and rescale each by the eps-adjusted total.
+ v339.SetValidShape(v263, v261);
+ Tile v359 = Tile(v263, v259);
+ uint64_t v360 = (uint64_t) v270;
+ TASSIGN(v359, v360);
+ TFILLPAD(v359, v339);
+ v345.SetValidShape(v263, v261);
+ Tile v361 = Tile(v263, v259);
+ uint64_t v362 = (uint64_t) v269;
+ TASSIGN(v361, v362);
+ TFILLPAD(v361, v345);
+ v351.SetValidShape(v263, v261);
+ Tile v363 = Tile(v263, v259);
+ uint64_t v364 = (uint64_t) v268;
+ TASSIGN(v363, v364);
+ TFILLPAD(v363, v351);
+ v357.SetValidShape(v263, v261);
+ Tile v365 = Tile(v263, v259);
+ uint64_t v366 = (uint64_t) v267;
+ TASSIGN(v365, v366);
+ pipe_barrier(PIPE_V);
+ TFILLPAD(v365, v357);
+ Tile v367 = Tile(v263, v262);
+ uint64_t v368 = (uint64_t) v273;
+ TASSIGN(v367, v368);
+ Tile v369 = Tile(v263, v259);
+ uint64_t v370 = (uint64_t) v266;
+ TASSIGN(v369, v370);
+ TADD(v369, v359, v361);
+ Tile v371 = Tile(v263, v259);
+ uint64_t v372 = (uint64_t) v265;
+ TASSIGN(v371, v372);
+ pipe_barrier(PIPE_V);
+ TADD(v371, v363, v365);
+ Tile v373 = Tile(v263, v259);
+ uint64_t v374 = (uint64_t) v266;
+ TASSIGN(v373, v374);
+ pipe_barrier(PIPE_V);
+ TADD(v373, v369, v371);
+ Tile v375 = Tile(v263, v259);
+ uint64_t v376 = (uint64_t) v266;
+ TASSIGN(v375, v376);
+ pipe_barrier(PIPE_V);
+ TADDS(v375, v373, v257);
+ Tile v377 = Tile(v263, v259);
+ uint64_t v378 = (uint64_t) v270;
+ TASSIGN(v377, v378);
+ pipe_barrier(PIPE_V);
+ TDIV(v377, v359, v375);
+ Tile v379 = Tile(v263, v259);
+ uint64_t v380 = (uint64_t) v269;
+ TASSIGN(v379, v380);
+ TDIV(v379, v361, v375);
+ Tile v381 = Tile(v263, v259);
+ uint64_t v382 = (uint64_t) v268;
+ TASSIGN(v381, v382);
+ TDIV(v381, v363, v375);
+ Tile v383 = Tile(v263, v259);
+ uint64_t v384 = (uint64_t) v267;
+ TASSIGN(v383, v384);
+ TDIV(v383, v365, v375);
+ Tile v385 = Tile(v263, v262);
+ uint64_t v386 = (uint64_t) v272;
+ TASSIGN(v385, v386);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v385, v377, v367);
+ Tile v387 = Tile(v262, v263);
+ uint64_t v388 = (uint64_t) v272;
+ TASSIGN(v387, v388);
+ Tile v389 = Tile(v262, v263);
+ uint64_t v390 = (uint64_t) v264;
+ TASSIGN(v389, v390);
+ pipe_barrier(PIPE_V);
+ TADDS(v389, v387, v257);
+ Tile v391 = Tile(v263, v262);
+ uint64_t v392 = (uint64_t) v264;
+ TASSIGN(v391, v392);
+ Tile v393 = Tile(v263, v259);
+ uint64_t v394 = (uint64_t) v270;
+ TASSIGN(v393, v394);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v393, v377, v391);
+ Tile v395 = Tile(v263, v262);
+ uint64_t v396 = (uint64_t) v272;
+ TASSIGN(v395, v396);
+ TROWSUM(v395, v379, v367);
+ Tile v397 = Tile(v262, v263);
+ uint64_t v398 = (uint64_t) v272;
+ TASSIGN(v397, v398);
+ Tile v399 = Tile(v262, v263);
+ uint64_t v400 = (uint64_t) v264;
+ TASSIGN(v399, v400);
+ pipe_barrier(PIPE_V);
+ TADDS(v399, v397, v257);
+ Tile v401 = Tile(v263, v262);
+ uint64_t v402 = (uint64_t) v264;
+ TASSIGN(v401, v402);
+ Tile v403 = Tile(v263, v259);
+ uint64_t v404 = (uint64_t) v269;
+ TASSIGN(v403, v404);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v403, v379, v401);
+ Tile v405 = Tile(v263, v262);
+ uint64_t v406 = (uint64_t) v272;
+ TASSIGN(v405, v406);
+ TROWSUM(v405, v381, v367);
+ Tile v407 = Tile(v262, v263);
+ uint64_t v408 = (uint64_t) v272;
+ TASSIGN(v407, v408);
+ Tile v409 = Tile(v262, v263);
+ uint64_t v410 = (uint64_t) v264;
+ TASSIGN(v409, v410);
+ pipe_barrier(PIPE_V);
+ TADDS(v409, v407, v257);
+ Tile v411 = Tile(v263, v262);
+ uint64_t v412 = (uint64_t) v264;
+ TASSIGN(v411, v412);
+ Tile v413 = Tile(v263, v259);
+ uint64_t v414 = (uint64_t) v268;
+ TASSIGN(v413, v414);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v413, v381, v411);
+ Tile v415 = Tile(v263, v262);
+ uint64_t v416 = (uint64_t) v272;
+ TASSIGN(v415, v416);
+ TROWSUM(v415, v383, v367);
+ Tile v417 = Tile(v262, v263);
+ uint64_t v418 = (uint64_t) v272;
+ TASSIGN(v417, v418);
+ Tile v419 = Tile(v262, v263);
+ uint64_t v420 = (uint64_t) v264;
+ TASSIGN(v419, v420);
+ pipe_barrier(PIPE_V);
+ TADDS(v419, v417, v257);
+ Tile v421 = Tile(v263, v262);
+ uint64_t v422 = (uint64_t) v264;
+ TASSIGN(v421, v422);
+ Tile v423 = Tile(v263, v259);
+ uint64_t v424 = (uint64_t) v267;
+ TASSIGN(v423, v424);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v423, v383, v421);
+ Tile v425 = Tile(v263, v259);
+ uint64_t v426 = (uint64_t) v266;
+ TASSIGN(v425, v426);
+ TADD(v425, v393, v403);
+ Tile v427 = Tile(v263, v259);
+ uint64_t v428 = (uint64_t) v265;
+ TASSIGN(v427, v428);
+ pipe_barrier(PIPE_V);
+ TADD(v427, v413, v423);
+ Tile v429 = Tile(v263, v259);
+ uint64_t v430 = (uint64_t) v266;
+ TASSIGN(v429, v430);
+ pipe_barrier(PIPE_V);
+ TADD(v429, v425, v427);
+ Tile v431 = Tile(v263, v259);
+ uint64_t v432 = (uint64_t) v266;
+ TASSIGN(v431, v432);
+ pipe_barrier(PIPE_V);
+ TADDS(v431, v429, v257);
+ Tile v433 = Tile(v263, v259);
+ uint64_t v434 = (uint64_t) v270;
+ TASSIGN(v433, v434);
+ pipe_barrier(PIPE_V);
+ TDIV(v433, v393, v431);
+ Tile v435 = Tile(v263, v259);
+ uint64_t v436 = (uint64_t) v269;
+ TASSIGN(v435, v436);
+ TDIV(v435, v403, v431);
+ Tile v437 = Tile(v263, v259);
+ uint64_t v438 = (uint64_t) v268;
+ TASSIGN(v437, v438);
+ TDIV(v437, v413, v431);
+ Tile v439 = Tile(v263, v259);
+ uint64_t v440 = (uint64_t) v267;
+ TASSIGN(v439, v440);
+ TDIV(v439, v423, v431);
+ Tile v441 = Tile(v263, v262);
+ uint64_t v442 = (uint64_t) v272;
+ TASSIGN(v441, v442);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v441, v433, v367);
+ Tile v443 = Tile(v262, v263);
+ uint64_t v444 = (uint64_t) v272;
+ TASSIGN(v443, v444);
+ Tile v445 = Tile(v262, v263);
+ uint64_t v446 = (uint64_t) v264;
+ TASSIGN(v445, v446);
+ pipe_barrier(PIPE_V);
+ TADDS(v445, v443, v257);
+ Tile v447 = Tile(v263, v262);
+ uint64_t v448 = (uint64_t) v264;
+ TASSIGN(v447, v448);
+ Tile v449 = Tile(v263, v259);
+ uint64_t v450 = (uint64_t) v270;
+ TASSIGN(v449, v450);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v449, v433, v447);
+ Tile v451 = Tile(v263, v262);
+ uint64_t v452 = (uint64_t) v272;
+ TASSIGN(v451, v452);
+ TROWSUM(v451, v435, v367);
+ Tile v453 = Tile(v262, v263);
+ uint64_t v454 = (uint64_t) v272;
+ TASSIGN(v453, v454);
+ Tile v455 = Tile(v262, v263);
+ uint64_t v456 = (uint64_t) v264;
+ TASSIGN(v455, v456);
+ pipe_barrier(PIPE_V);
+ TADDS(v455, v453, v257);
+ Tile v457 = Tile(v263, v262);
+ uint64_t v458 = (uint64_t) v264;
+ TASSIGN(v457, v458);
+ Tile v459 = Tile(v263, v259);
+ uint64_t v460 = (uint64_t) v269;
+ TASSIGN(v459, v460);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v459, v435, v457);
+ Tile v461 = Tile(v263, v262);
+ uint64_t v462 = (uint64_t) v272;
+ TASSIGN(v461, v462);
+ TROWSUM(v461, v437, v367);
+ Tile v463 = Tile(v262, v263);
+ uint64_t v464 = (uint64_t) v272;
+ TASSIGN(v463, v464);
+ Tile v465 = Tile(v262, v263);
+ uint64_t v466 = (uint64_t) v264;
+ TASSIGN(v465, v466);
+ pipe_barrier(PIPE_V);
+ TADDS(v465, v463, v257);
+ Tile v467 = Tile(v263, v262);
+ uint64_t v468 = (uint64_t) v264;
+ TASSIGN(v467, v468);
+ Tile v469 = Tile(v263, v259);
+ uint64_t v470 = (uint64_t) v268;
+ TASSIGN(v469, v470);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v469, v437, v467);
+ Tile v471 = Tile(v263, v262);
+ uint64_t v472 = (uint64_t) v272;
+ TASSIGN(v471, v472);
+ TROWSUM(v471, v439, v367);
+ Tile v473 = Tile(v262, v263);
+ uint64_t v474 = (uint64_t) v272;
+ TASSIGN(v473, v474);
+ Tile v475 = Tile(v262, v263);
+ uint64_t v476 = (uint64_t) v264;
+ TASSIGN(v475, v476);
+ pipe_barrier(PIPE_V);
+ TADDS(v475, v473, v257);
+ Tile v477 = Tile(v263, v262);
+ uint64_t v478 = (uint64_t) v264;
+ TASSIGN(v477, v478);
+ Tile v479 = Tile(v263, v259);
+ uint64_t v480 = (uint64_t) v267;
+ TASSIGN(v479, v480);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v479, v439, v477);
+ Tile v481 = Tile(v263, v259);
+ uint64_t v482 = (uint64_t) v266;
+ TASSIGN(v481, v482);
+ TADD(v481, v449, v459);
+ Tile v483 = Tile(v263, v259);
+ uint64_t v484 = (uint64_t) v265;
+ TASSIGN(v483, v484);
+ pipe_barrier(PIPE_V);
+ TADD(v483, v469, v479);
+ Tile v485 = Tile(v263, v259);
+ uint64_t v486 = (uint64_t) v266;
+ TASSIGN(v485, v486);
+ pipe_barrier(PIPE_V);
+ TADD(v485, v481, v483);
+ Tile v487 = Tile(v263, v259);
+ uint64_t v488 = (uint64_t) v266;
+ TASSIGN(v487, v488);
+ pipe_barrier(PIPE_V);
+ TADDS(v487, v485, v257);
+ Tile v489 = Tile(v263, v259);
+ uint64_t v490 = (uint64_t) v270;
+ TASSIGN(v489, v490);
+ pipe_barrier(PIPE_V);
+ TDIV(v489, v449, v487);
+ Tile v491 = Tile(v263, v259);
+ uint64_t v492 = (uint64_t) v269;
+ TASSIGN(v491, v492);
+ TDIV(v491, v459, v487);
+ Tile v493 = Tile(v263, v259);
+ uint64_t v494 = (uint64_t) v268;
+ TASSIGN(v493, v494);
+ TDIV(v493, v469, v487);
+ Tile v495 = Tile(v263, v259);
+ uint64_t v496 = (uint64_t) v267;
+ TASSIGN(v495, v496);
+ TDIV(v495, v479, v487);
+ Tile v497 = Tile(v263, v262);
+ uint64_t v498 = (uint64_t) v272;
+ TASSIGN(v497, v498);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v497, v489, v367);
+ Tile v499 = Tile(v262, v263);
+ uint64_t v500 = (uint64_t) v272;
+ TASSIGN(v499, v500);
+ Tile v501 = Tile(v262, v263);
+ uint64_t v502 = (uint64_t) v264;
+ TASSIGN(v501, v502);
+ pipe_barrier(PIPE_V);
+ TADDS(v501, v499, v257);
+ Tile v503 = Tile(v263, v262);
+ uint64_t v504 = (uint64_t) v264;
+ TASSIGN(v503, v504);
+ Tile v505 = Tile(v263, v259);
+ uint64_t v506 = (uint64_t) v270;
+ TASSIGN(v505, v506);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v505, v489, v503);
+ Tile v507 = Tile(v263, v262);
+ uint64_t v508 = (uint64_t) v272;
+ TASSIGN(v507, v508);
+ TROWSUM(v507, v491, v367);
+ Tile v509 = Tile(v262, v263);
+ uint64_t v510 = (uint64_t) v272;
+ TASSIGN(v509, v510);
+ Tile v511 = Tile(v262, v263);
+ uint64_t v512 = (uint64_t) v264;
+ TASSIGN(v511, v512);
+ pipe_barrier(PIPE_V);
+ TADDS(v511, v509, v257);
+ Tile v513 = Tile(v263, v262);
+ uint64_t v514 = (uint64_t) v264;
+ TASSIGN(v513, v514);
+ Tile v515 = Tile(v263, v259);
+ uint64_t v516 = (uint64_t) v269;
+ TASSIGN(v515, v516);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v515, v491, v513);
+ Tile v517 = Tile(v263, v262);
+ uint64_t v518 = (uint64_t) v272;
+ TASSIGN(v517, v518);
+ TROWSUM(v517, v493, v367);
+ Tile v519 = Tile(v262, v263);
+ uint64_t v520 = (uint64_t) v272;
+ TASSIGN(v519, v520);
+ Tile v521 = Tile(v262, v263);
+ uint64_t v522 = (uint64_t) v264;
+ TASSIGN(v521, v522);
+ pipe_barrier(PIPE_V);
+ TADDS(v521, v519, v257);
+ Tile v523 = Tile(v263, v262);
+ uint64_t v524 = (uint64_t) v264;
+ TASSIGN(v523, v524);
+ Tile v525 = Tile(v263, v259);
+ uint64_t v526 = (uint64_t) v268;
+ TASSIGN(v525, v526);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v525, v493, v523);
+ Tile v527 = Tile(v263, v262);
+ uint64_t v528 = (uint64_t) v272;
+ TASSIGN(v527, v528);
+ TROWSUM(v527, v495, v367);
+ Tile v529 = Tile(v262, v263);
+ uint64_t v530 = (uint64_t) v272;
+ TASSIGN(v529, v530);
+ Tile v531 = Tile(v262, v263);
+ uint64_t v532 = (uint64_t) v264;
+ TASSIGN(v531, v532);
+ pipe_barrier(PIPE_V);
+ TADDS(v531, v529, v257);
+ Tile v533 = Tile(v263, v262);
+ uint64_t v534 = (uint64_t) v264;
+ TASSIGN(v533, v534);
+ Tile v535 = Tile(v263, v259);
+ uint64_t v536 = (uint64_t) v267;
+ TASSIGN(v535, v536);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v535, v495, v533);
+ Tile v537 = Tile(v263, v259);
+ uint64_t v538 = (uint64_t) v266;
+ TASSIGN(v537, v538);
+ TADD(v537, v505, v515);
+ Tile v539 = Tile(v263, v259);
+ uint64_t v540 = (uint64_t) v265;
+ TASSIGN(v539, v540);
+ pipe_barrier(PIPE_V);
+ TADD(v539, v525, v535);
+ Tile v541 = Tile(v263, v259);
+ uint64_t v542 = (uint64_t) v266;
+ TASSIGN(v541, v542);
+ pipe_barrier(PIPE_V);
+ TADD(v541, v537, v539);
+ Tile v543 = Tile(v263, v259);
+ uint64_t v544 = (uint64_t) v266;
+ TASSIGN(v543, v544);
+ pipe_barrier(PIPE_V);
+ TADDS(v543, v541, v257);
+ Tile v545 = Tile(v263, v259);
+ uint64_t v546 = (uint64_t) v270;
+ TASSIGN(v545, v546);
+ pipe_barrier(PIPE_V);
+ TDIV(v545, v505, v543);
+ Tile v547 = Tile(v263, v259);
+ uint64_t v548 = (uint64_t) v269;
+ TASSIGN(v547, v548);
+ TDIV(v547, v515, v543);
+ Tile v549 = Tile(v263, v259);
+ uint64_t v550 = (uint64_t) v268;
+ TASSIGN(v549, v550);
+ TDIV(v549, v525, v543);
+ Tile v551 = Tile(v263, v259);
+ uint64_t v552 = (uint64_t) v267;
+ TASSIGN(v551, v552);
+ TDIV(v551, v535, v543);
+ Tile v553 = Tile(v263, v262);
+ uint64_t v554 = (uint64_t) v272;
+ TASSIGN(v553, v554);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v553, v545, v367);
+ Tile v555 = Tile(v262, v263);
+ uint64_t v556 = (uint64_t) v272;
+ TASSIGN(v555, v556);
+ Tile v557 = Tile(v262, v263);
+ uint64_t v558 = (uint64_t) v264;
+ TASSIGN(v557, v558);
+ pipe_barrier(PIPE_V);
+ TADDS(v557, v555, v257);
+ Tile v559 = Tile(v263, v262);
+ uint64_t v560 = (uint64_t) v264;
+ TASSIGN(v559, v560);
+ Tile v561 = Tile(v263, v259);
+ uint64_t v562 = (uint64_t) v270;
+ TASSIGN(v561, v562);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v561, v545, v559);
+ Tile v563 = Tile(v263, v262);
+ uint64_t v564 = (uint64_t) v272;
+ TASSIGN(v563, v564);
+ TROWSUM(v563, v547, v367);
+ Tile v565 = Tile(v262, v263);
+ uint64_t v566 = (uint64_t) v272;
+ TASSIGN(v565, v566);
+ Tile v567 = Tile(v262, v263);
+ uint64_t v568 = (uint64_t) v264;
+ TASSIGN(v567, v568);
+ pipe_barrier(PIPE_V);
+ TADDS(v567, v565, v257);
+ Tile v569 = Tile(v263, v262);
+ uint64_t v570 = (uint64_t) v264;
+ TASSIGN(v569, v570);
+ Tile v571 = Tile(v263, v259);
+ uint64_t v572 = (uint64_t) v269;
+ TASSIGN(v571, v572);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v571, v547, v569);
+ Tile v573 = Tile(v263, v262);
+ uint64_t v574 = (uint64_t) v272;
+ TASSIGN(v573, v574);
+ TROWSUM(v573, v549, v367);
+ Tile v575 = Tile(v262, v263);
+ uint64_t v576 = (uint64_t) v272;
+ TASSIGN(v575, v576);
+ Tile v577 = Tile(v262, v263);
+ uint64_t v578 = (uint64_t) v264;
+ TASSIGN(v577, v578);
+ pipe_barrier(PIPE_V);
+ TADDS(v577, v575, v257);
+ Tile v579 = Tile(v263, v262);
+ uint64_t v580 = (uint64_t) v264;
+ TASSIGN(v579, v580);
+ Tile v581 = Tile(v263, v259);
+ uint64_t v582 = (uint64_t) v268;
+ TASSIGN(v581, v582);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v581, v549, v579);
+ Tile v583 = Tile(v263, v262);
+ uint64_t v584 = (uint64_t) v272;
+ TASSIGN(v583, v584);
+ TROWSUM(v583, v551, v367);
+ Tile v585 = Tile(v262, v263);
+ uint64_t v586 = (uint64_t) v272;
+ TASSIGN(v585, v586);
+ Tile v587 = Tile(v262, v263);
+ uint64_t v588 = (uint64_t) v264;
+ TASSIGN(v587, v588);
+ pipe_barrier(PIPE_V);
+ TADDS(v587, v585, v257);
+ Tile v589 = Tile(v263, v262);
+ uint64_t v590 = (uint64_t) v264;
+ TASSIGN(v589, v590);
+ Tile v591 = Tile(v263, v259);
+ uint64_t v592 = (uint64_t) v267;
+ TASSIGN(v591, v592);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v591, v551, v589);
+ Tile v593 = Tile(v263, v259);
+ uint64_t v594 = (uint64_t) v266;
+ TASSIGN(v593, v594);
+ TADD(v593, v561, v571);
+ Tile v595 = Tile(v263, v259);
+ uint64_t v596 = (uint64_t) v265;
+ TASSIGN(v595, v596);
+ pipe_barrier(PIPE_V);
+ TADD(v595, v581, v591);
+ Tile v597 = Tile(v263, v259);
+ uint64_t v598 = (uint64_t) v266;
+ TASSIGN(v597, v598);
+ pipe_barrier(PIPE_V);
+ TADD(v597, v593, v595);
+ Tile v599 = Tile(v263, v259);
+ uint64_t v600 = (uint64_t) v266;
+ TASSIGN(v599, v600);
+ pipe_barrier(PIPE_V);
+ TADDS(v599, v597, v257);
+ Tile v601 = Tile(v263, v259);
+ uint64_t v602 = (uint64_t) v270;
+ TASSIGN(v601, v602);
+ pipe_barrier(PIPE_V);
+ TDIV(v601, v561, v599);
+ Tile v603 = Tile(v263, v259);
+ uint64_t v604 = (uint64_t) v269;
+ TASSIGN(v603, v604);
+ TDIV(v603, v571, v599);
+ Tile v605 = Tile(v263, v259);
+ uint64_t v606 = (uint64_t) v268;
+ TASSIGN(v605, v606);
+ TDIV(v605, v581, v599);
+ Tile v607 = Tile(v263, v259);
+ uint64_t v608 = (uint64_t) v267;
+ TASSIGN(v607, v608);
+ TDIV(v607, v591, v599);
+ Tile v609 = Tile(v263, v262);
+ uint64_t v610 = (uint64_t) v272;
+ TASSIGN(v609, v610);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v609, v601, v367);
+ Tile v611 = Tile(v262, v263);
+ uint64_t v612 = (uint64_t) v272;
+ TASSIGN(v611, v612);
+ Tile v613 = Tile(v262, v263);
+ uint64_t v614 = (uint64_t) v264;
+ TASSIGN(v613, v614);
+ pipe_barrier(PIPE_V);
+ TADDS(v613, v611, v257);
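+ // One sweep of the unrolled renormalization: per-tile TROWSUM, TADDS with eps (v257), TROWEXPANDDIV, then pairwise TADDs and TDIVs recombine and rescale the four tiles.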
+ Tile v615 = Tile(v263, v262);
+ uint64_t v616 = (uint64_t) v264;
+ TASSIGN(v615, v616);
+ Tile v617 = Tile(v263, v259);
+ uint64_t v618 = (uint64_t) v270;
+ TASSIGN(v617, v618);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v617, v601, v615);
+ Tile v619 = Tile(v263, v262);
+ uint64_t v620 = (uint64_t) v272;
+ TASSIGN(v619, v620);
+ TROWSUM(v619, v603, v367);
+ Tile v621 = Tile(v262, v263);
+ uint64_t v622 = (uint64_t) v272;
+ TASSIGN(v621, v622);
+ Tile v623 = Tile(v262, v263);
+ uint64_t v624 = (uint64_t) v264;
+ TASSIGN(v623, v624);
+ pipe_barrier(PIPE_V);
+ TADDS(v623, v621, v257);
+ Tile v625 = Tile(v263, v262);
+ uint64_t v626 = (uint64_t) v264;
+ TASSIGN(v625, v626);
+ Tile v627 = Tile(v263, v259);
+ uint64_t v628 = (uint64_t) v269;
+ TASSIGN(v627, v628);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v627, v603, v625);
+ Tile v629 = Tile(v263, v262);
+ uint64_t v630 = (uint64_t) v272;
+ TASSIGN(v629, v630);
+ TROWSUM(v629, v605, v367);
+ Tile v631 = Tile(v262, v263);
+ uint64_t v632 = (uint64_t) v272;
+ TASSIGN(v631, v632);
+ Tile v633 = Tile(v262, v263);
+ uint64_t v634 = (uint64_t) v264;
+ TASSIGN(v633, v634);
+ pipe_barrier(PIPE_V);
+ TADDS(v633, v631, v257);
+ Tile v635 = Tile(v263, v262);
+ uint64_t v636 = (uint64_t) v264;
+ TASSIGN(v635, v636);
+ Tile v637 = Tile(v263, v259);
+ uint64_t v638 = (uint64_t) v268;
+ TASSIGN(v637, v638);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v637, v605, v635);
+ Tile v639 = Tile(v263, v262);
+ uint64_t v640 = (uint64_t) v272;
+ TASSIGN(v639, v640);
+ TROWSUM(v639, v607, v367);
+ Tile v641 = Tile(v262, v263);
+ uint64_t v642 = (uint64_t) v272;
+ TASSIGN(v641, v642);
+ Tile v643 = Tile(v262, v263);
+ uint64_t v644 = (uint64_t) v264;
+ TASSIGN(v643, v644);
+ pipe_barrier(PIPE_V);
+ TADDS(v643, v641, v257);
+ Tile v645 = Tile(v263, v262);
+ uint64_t v646 = (uint64_t) v264;
+ TASSIGN(v645, v646);
+ Tile v647 = Tile(v263, v259);
+ uint64_t v648 = (uint64_t) v267;
+ TASSIGN(v647, v648);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v647, v607, v645);
+ Tile v649 = Tile(v263, v259);
+ uint64_t v650 = (uint64_t) v266;
+ TASSIGN(v649, v650);
+ TADD(v649, v617, v627);
+ Tile v651 = Tile(v263, v259);
+ uint64_t v652 = (uint64_t) v265;
+ TASSIGN(v651, v652);
+ pipe_barrier(PIPE_V);
+ TADD(v651, v637, v647);
+ Tile v653 = Tile(v263, v259);
+ uint64_t v654 = (uint64_t) v266;
+ TASSIGN(v653, v654);
+ pipe_barrier(PIPE_V);
+ TADD(v653, v649, v651);
+ Tile v655 = Tile(v263, v259);
+ uint64_t v656 = (uint64_t) v266;
+ TASSIGN(v655, v656);
+ pipe_barrier(PIPE_V);
+ TADDS(v655, v653, v257);
+ Tile v657 = Tile(v263, v259);
+ uint64_t v658 = (uint64_t) v270;
+ TASSIGN(v657, v658);
+ pipe_barrier(PIPE_V);
+ TDIV(v657, v617, v655);
+ Tile v659 = Tile(v263, v259);
+ uint64_t v660 = (uint64_t) v269;
+ TASSIGN(v659, v660);
+ TDIV(v659, v627, v655);
+ Tile v661 = Tile(v263, v259);
+ uint64_t v662 = (uint64_t) v268;
+ TASSIGN(v661, v662);
+ TDIV(v661, v637, v655);
+ Tile v663 = Tile(v263, v259);
+ uint64_t v664 = (uint64_t) v267;
+ TASSIGN(v663, v664);
+ TDIV(v663, v647, v655);
+ Tile v665 = Tile(v263, v262);
+ uint64_t v666 = (uint64_t) v272;
+ TASSIGN(v665, v666);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v665, v657, v367);
+ Tile v667 = Tile(v262, v263);
+ uint64_t v668 = (uint64_t) v272;
+ TASSIGN(v667, v668);
+ Tile v669 = Tile(v262, v263);
+ uint64_t v670 = (uint64_t) v264;
+ TASSIGN(v669, v670);
+ pipe_barrier(PIPE_V);
+ TADDS(v669, v667, v257);
+ Tile v671 = Tile(v263, v262);
+ uint64_t v672 = (uint64_t) v264;
+ TASSIGN(v671, v672);
+ Tile v673 = Tile(v263, v259);
+ uint64_t v674 = (uint64_t) v270;
+ TASSIGN(v673, v674);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v673, v657, v671);
+ Tile v675 = Tile(v263, v262);
+ uint64_t v676 = (uint64_t) v272;
+ TASSIGN(v675, v676);
+ TROWSUM(v675, v659, v367);
+ Tile v677 = Tile(v262, v263);
+ uint64_t v678 = (uint64_t) v272;
+ TASSIGN(v677, v678);
+ Tile v679 = Tile(v262, v263);
+ uint64_t v680 = (uint64_t) v264;
+ TASSIGN(v679, v680);
+ pipe_barrier(PIPE_V);
+ TADDS(v679, v677, v257);
+ Tile v681 = Tile(v263, v262);
+ uint64_t v682 = (uint64_t) v264;
+ TASSIGN(v681, v682);
+ Tile v683 = Tile(v263, v259);
+ uint64_t v684 = (uint64_t) v269;
+ TASSIGN(v683, v684);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v683, v659, v681);
+ Tile v685 = Tile(v263, v262);
+ uint64_t v686 = (uint64_t) v272;
+ TASSIGN(v685, v686);
+ TROWSUM(v685, v661, v367);
+ Tile v687 = Tile(v262, v263);
+ uint64_t v688 = (uint64_t) v272;
+ TASSIGN(v687, v688);
+ Tile v689 = Tile(v262, v263);
+ uint64_t v690 = (uint64_t) v264;
+ TASSIGN(v689, v690);
+ pipe_barrier(PIPE_V);
+ TADDS(v689, v687, v257);
+ Tile v691 = Tile(v263, v262);
+ uint64_t v692 = (uint64_t) v264;
+ TASSIGN(v691, v692);
+ Tile v693 = Tile(v263, v259);
+ uint64_t v694 = (uint64_t) v268;
+ TASSIGN(v693, v694);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v693, v661, v691);
+ Tile v695 = Tile(v263, v262);
+ uint64_t v696 = (uint64_t) v272;
+ TASSIGN(v695, v696);
+ TROWSUM(v695, v663, v367);
+ Tile v697 = Tile(v262, v263);
+ uint64_t v698 = (uint64_t) v272;
+ TASSIGN(v697, v698);
+ Tile v699 = Tile(v262, v263);
+ uint64_t v700 = (uint64_t) v264;
+ TASSIGN(v699, v700);
+ pipe_barrier(PIPE_V);
+ TADDS(v699, v697, v257);
+ Tile v701 = Tile(v263, v262);
+ uint64_t v702 = (uint64_t) v264;
+ TASSIGN(v701, v702);
+ Tile v703 = Tile(v263, v259);
+ uint64_t v704 = (uint64_t) v267;
+ TASSIGN(v703, v704);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v703, v663, v701);
+ Tile v705 = Tile(v263, v259);
+ uint64_t v706 = (uint64_t) v266;
+ TASSIGN(v705, v706);
+ TADD(v705, v673, v683);
+ Tile v707 = Tile(v263, v259);
+ uint64_t v708 = (uint64_t) v265;
+ TASSIGN(v707, v708);
+ pipe_barrier(PIPE_V);
+ TADD(v707, v693, v703);
+ Tile v709 = Tile(v263, v259);
+ uint64_t v710 = (uint64_t) v266;
+ TASSIGN(v709, v710);
+ pipe_barrier(PIPE_V);
+ TADD(v709, v705, v707);
+ Tile v711 = Tile(v263, v259);
+ uint64_t v712 = (uint64_t) v266;
+ TASSIGN(v711, v712);
+ pipe_barrier(PIPE_V);
+ TADDS(v711, v709, v257);
+ Tile v713 = Tile(v263, v259);
+ uint64_t v714 = (uint64_t) v270;
+ TASSIGN(v713, v714);
+ pipe_barrier(PIPE_V);
+ TDIV(v713, v673, v711);
+ Tile v715 = Tile(v263, v259);
+ uint64_t v716 = (uint64_t) v269;
+ TASSIGN(v715, v716);
+ TDIV(v715, v683, v711);
+ Tile v717 = Tile(v263, v259);
+ uint64_t v718 = (uint64_t) v268;
+ TASSIGN(v717, v718);
+ TDIV(v717, v693, v711);
+ Tile v719 = Tile(v263, v259);
+ uint64_t v720 = (uint64_t) v267;
+ TASSIGN(v719, v720);
+ TDIV(v719, v703, v711);
+ Tile v721 = Tile(v263, v262);
+ uint64_t v722 = (uint64_t) v272;
+ TASSIGN(v721, v722);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v721, v713, v367);
+ Tile v723 = Tile(v262, v263);
+ uint64_t v724 = (uint64_t) v272;
+ TASSIGN(v723, v724);
+ Tile v725 = Tile(v262, v263);
+ uint64_t v726 = (uint64_t) v264;
+ TASSIGN(v725, v726);
+ pipe_barrier(PIPE_V);
+ TADDS(v725, v723, v257);
+ Tile v727 = Tile(v263, v262);
+ uint64_t v728 = (uint64_t) v264;
+ TASSIGN(v727, v728);
+ Tile v729 = Tile(v263, v259);
+ uint64_t v730 = (uint64_t) v270;
+ TASSIGN(v729, v730);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v729, v713, v727);
+ Tile v731 = Tile(v263, v262);
+ uint64_t v732 = (uint64_t) v272;
+ TASSIGN(v731, v732);
+ TROWSUM(v731, v715, v367);
+ Tile v733 = Tile(v262, v263);
+ uint64_t v734 = (uint64_t) v272;
+ TASSIGN(v733, v734);
+ Tile v735 = Tile(v262, v263);
+ uint64_t v736 = (uint64_t) v264;
+ TASSIGN(v735, v736);
+ pipe_barrier(PIPE_V);
+ TADDS(v735, v733, v257);
+ Tile v737 = Tile(v263, v262);
+ uint64_t v738 = (uint64_t) v264;
+ TASSIGN(v737, v738);
+ Tile v739 = Tile(v263, v259);
+ uint64_t v740 = (uint64_t) v269;
+ TASSIGN(v739, v740);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v739, v715, v737);
+ Tile v741 = Tile(v263, v262);
+ uint64_t v742 = (uint64_t) v272;
+ TASSIGN(v741, v742);
+ TROWSUM(v741, v717, v367);
+ Tile v743 = Tile(v262, v263);
+ uint64_t v744 = (uint64_t) v272;
+ TASSIGN(v743, v744);
+ Tile v745 = Tile(v262, v263);
+ uint64_t v746 = (uint64_t) v264;
+ TASSIGN(v745, v746);
+ pipe_barrier(PIPE_V);
+ TADDS(v745, v743, v257);
+ Tile v747 = Tile(v263, v262);
+ uint64_t v748 = (uint64_t) v264;
+ TASSIGN(v747, v748);
+ Tile v749 = Tile(v263, v259);
+ uint64_t v750 = (uint64_t) v268;
+ TASSIGN(v749, v750);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v749, v717, v747);
+ Tile v751 = Tile(v263, v262);
+ uint64_t v752 = (uint64_t) v272;
+ TASSIGN(v751, v752);
+ TROWSUM(v751, v719, v367);
+ Tile v753 = Tile(v262, v263);
+ uint64_t v754 = (uint64_t) v272;
+ TASSIGN(v753, v754);
+ Tile v755 = Tile(v262, v263);
+ uint64_t v756 = (uint64_t) v264;
+ TASSIGN(v755, v756);
+ pipe_barrier(PIPE_V);
+ TADDS(v755, v753, v257);
+ Tile v757 = Tile(v263, v262);
+ uint64_t v758 = (uint64_t) v264;
+ TASSIGN(v757, v758);
+ Tile v759 = Tile(v263, v259);
+ uint64_t v760 = (uint64_t) v267;
+ TASSIGN(v759, v760);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v759, v719, v757);
+ Tile v761 = Tile(v263, v259);
+ uint64_t v762 = (uint64_t) v266;
+ TASSIGN(v761, v762);
+ TADD(v761, v729, v739);
+ Tile v763 = Tile(v263, v259);
+ uint64_t v764 = (uint64_t) v265;
+ TASSIGN(v763, v764);
+ pipe_barrier(PIPE_V);
+ TADD(v763, v749, v759);
+ Tile v765 = Tile(v263, v259);
+ uint64_t v766 = (uint64_t) v266;
+ TASSIGN(v765, v766);
+ pipe_barrier(PIPE_V);
+ TADD(v765, v761, v763);
+ Tile v767 = Tile(v263, v259);
+ uint64_t v768 = (uint64_t) v266;
+ TASSIGN(v767, v768);
+ pipe_barrier(PIPE_V);
+ TADDS(v767, v765, v257);
+ Tile v769 = Tile(v263, v259);
+ uint64_t v770 = (uint64_t) v270;
+ TASSIGN(v769, v770);
+ pipe_barrier(PIPE_V);
+ TDIV(v769, v729, v767);
+ Tile v771 = Tile(v263, v259);
+ uint64_t v772 = (uint64_t) v269;
+ TASSIGN(v771, v772);
+ TDIV(v771, v739, v767);
+ Tile v773 = Tile(v263, v259);
+ uint64_t v774 = (uint64_t) v268;
+ TASSIGN(v773, v774);
+ TDIV(v773, v749, v767);
+ Tile v775 = Tile(v263, v259);
+ uint64_t v776 = (uint64_t) v267;
+ TASSIGN(v775, v776);
+ TDIV(v775, v759, v767);
+ Tile v777 = Tile(v263, v262);
+ uint64_t v778 = (uint64_t) v272;
+ TASSIGN(v777, v778);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v777, v769, v367);
+ Tile v779 = Tile(v262, v263);
+ uint64_t v780 = (uint64_t) v272;
+ TASSIGN(v779, v780);
+ Tile v781 = Tile(v262, v263);
+ uint64_t v782 = (uint64_t) v264;
+ TASSIGN(v781, v782);
+ pipe_barrier(PIPE_V);
+ TADDS(v781, v779, v257);
+ Tile v783 = Tile(v263, v262);
+ uint64_t v784 = (uint64_t) v264;
+ TASSIGN(v783, v784);
+ Tile v785 = Tile(v263, v259);
+ uint64_t v786 = (uint64_t) v270;
+ TASSIGN(v785, v786);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v785, v769, v783);
+ Tile v787 = Tile(v263, v262);
+ uint64_t v788 = (uint64_t) v272;
+ TASSIGN(v787, v788);
+ TROWSUM(v787, v771, v367);
+ Tile v789 = Tile(v262, v263);
+ uint64_t v790 = (uint64_t) v272;
+ TASSIGN(v789, v790);
+ Tile v791 = Tile(v262, v263);
+ uint64_t v792 = (uint64_t) v264;
+ TASSIGN(v791, v792);
+ pipe_barrier(PIPE_V);
+ TADDS(v791, v789, v257);
+ Tile v793 = Tile(v263, v262);
+ uint64_t v794 = (uint64_t) v264;
+ TASSIGN(v793, v794);
+ Tile v795 = Tile(v263, v259);
+ uint64_t v796 = (uint64_t) v269;
+ TASSIGN(v795, v796);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v795, v771, v793);
+ Tile v797 = Tile(v263, v262);
+ uint64_t v798 = (uint64_t) v272;
+ TASSIGN(v797, v798);
+ TROWSUM(v797, v773, v367);
+ Tile v799 = Tile(v262, v263);
+ uint64_t v800 = (uint64_t) v272;
+ TASSIGN(v799, v800);
+ Tile v801 = Tile(v262, v263);
+ uint64_t v802 = (uint64_t) v264;
+ TASSIGN(v801, v802);
+ pipe_barrier(PIPE_V);
+ TADDS(v801, v799, v257);
+ Tile v803 = Tile(v263, v262);
+ uint64_t v804 = (uint64_t) v264;
+ TASSIGN(v803, v804);
+ Tile v805 = Tile(v263, v259);
+ uint64_t v806 = (uint64_t) v268;
+ TASSIGN(v805, v806);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v805, v773, v803);
+ Tile v807 = Tile(v263, v262);
+ uint64_t v808 = (uint64_t) v272;
+ TASSIGN(v807, v808);
+ TROWSUM(v807, v775, v367);
+ Tile v809 = Tile(v262, v263);
+ uint64_t v810 = (uint64_t) v272;
+ TASSIGN(v809, v810);
+ Tile v811 = Tile(v262, v263);
+ uint64_t v812 = (uint64_t) v264;
+ TASSIGN(v811, v812);
+ pipe_barrier(PIPE_V);
+ TADDS(v811, v809, v257);
+ Tile v813 = Tile(v263, v262);
+ uint64_t v814 = (uint64_t) v264;
+ TASSIGN(v813, v814);
+ Tile v815 = Tile(v263, v259);
+ uint64_t v816 = (uint64_t) v267;
+ TASSIGN(v815, v816);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v815, v775, v813);
+ Tile v817 = Tile(v263, v259);
+ uint64_t v818 = (uint64_t) v266;
+ TASSIGN(v817, v818);
+ TADD(v817, v785, v795);
+ Tile v819 = Tile(v263, v259);
+ uint64_t v820 = (uint64_t) v265;
+ TASSIGN(v819, v820);
+ pipe_barrier(PIPE_V);
+ TADD(v819, v805, v815);
+ Tile v821 = Tile(v263, v259);
+ uint64_t v822 = (uint64_t) v266;
+ TASSIGN(v821, v822);
+ pipe_barrier(PIPE_V);
+ TADD(v821, v817, v819);
+ Tile v823 = Tile(v263, v259);
+ uint64_t v824 = (uint64_t) v266;
+ TASSIGN(v823, v824);
+ pipe_barrier(PIPE_V);
+ TADDS(v823, v821, v257);
+ Tile v825 = Tile(v263, v259);
+ uint64_t v826 = (uint64_t) v270;
+ TASSIGN(v825, v826);
+ pipe_barrier(PIPE_V);
+ TDIV(v825, v785, v823);
+ Tile v827 = Tile(v263, v259);
+ uint64_t v828 = (uint64_t) v269;
+ TASSIGN(v827, v828);
+ TDIV(v827, v795, v823);
+ Tile v829 = Tile(v263, v259);
+ uint64_t v830 = (uint64_t) v268;
+ TASSIGN(v829, v830);
+ TDIV(v829, v805, v823);
+ Tile v831 = Tile(v263, v259);
+ uint64_t v832 = (uint64_t) v267;
+ TASSIGN(v831, v832);
+ TDIV(v831, v815, v823);
+ Tile v833 = Tile(v263, v262);
+ uint64_t v834 = (uint64_t) v272;
+ TASSIGN(v833, v834);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v833, v825, v367);
+ Tile v835 = Tile(v262, v263);
+ uint64_t v836 = (uint64_t) v272;
+ TASSIGN(v835, v836);
+ Tile v837 = Tile(v262, v263);
+ uint64_t v838 = (uint64_t) v264;
+ TASSIGN(v837, v838);
+ pipe_barrier(PIPE_V);
+ TADDS(v837, v835, v257);
+ Tile v839 = Tile(v263, v262);
+ uint64_t v840 = (uint64_t) v264;
+ TASSIGN(v839, v840);
+ Tile v841 = Tile(v263, v259);
+ uint64_t v842 = (uint64_t) v270;
+ TASSIGN(v841, v842);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v841, v825, v839);
+ Tile v843 = Tile(v263, v262);
+ uint64_t v844 = (uint64_t) v272;
+ TASSIGN(v843, v844);
+ TROWSUM(v843, v827, v367);
+ Tile v845 = Tile(v262, v263);
+ uint64_t v846 = (uint64_t) v272;
+ TASSIGN(v845, v846);
+ Tile v847 = Tile(v262, v263);
+ uint64_t v848 = (uint64_t) v264;
+ TASSIGN(v847, v848);
+ pipe_barrier(PIPE_V);
+ TADDS(v847, v845, v257);
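+ // The same sweep repeats, fully unrolled, through the remainder of this block.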
+ Tile v849 = Tile(v263, v262);
+ uint64_t v850 = (uint64_t) v264;
+ TASSIGN(v849, v850);
+ Tile v851 = Tile(v263, v259);
+ uint64_t v852 = (uint64_t) v269;
+ TASSIGN(v851, v852);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v851, v827, v849);
+ Tile v853 = Tile(v263, v262);
+ uint64_t v854 = (uint64_t) v272;
+ TASSIGN(v853, v854);
+ TROWSUM(v853, v829, v367);
+ Tile v855 = Tile(v262, v263);
+ uint64_t v856 = (uint64_t) v272;
+ TASSIGN(v855, v856);
+ Tile v857 = Tile(v262, v263);
+ uint64_t v858 = (uint64_t) v264;
+ TASSIGN(v857, v858);
+ pipe_barrier(PIPE_V);
+ TADDS(v857, v855, v257);
+ Tile v859 = Tile(v263, v262);
+ uint64_t v860 = (uint64_t) v264;
+ TASSIGN(v859, v860);
+ Tile v861 = Tile(v263, v259);
+ uint64_t v862 = (uint64_t) v268;
+ TASSIGN(v861, v862);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v861, v829, v859);
+ Tile v863 = Tile(v263, v262);
+ uint64_t v864 = (uint64_t) v272;
+ TASSIGN(v863, v864);
+ TROWSUM(v863, v831, v367);
+ Tile v865 = Tile(v262, v263);
+ uint64_t v866 = (uint64_t) v272;
+ TASSIGN(v865, v866);
+ Tile v867 = Tile(v262, v263);
+ uint64_t v868 = (uint64_t) v264;
+ TASSIGN(v867, v868);
+ pipe_barrier(PIPE_V);
+ TADDS(v867, v865, v257);
+ Tile v869 = Tile(v263, v262);
+ uint64_t v870 = (uint64_t) v264;
+ TASSIGN(v869, v870);
+ Tile v871 = Tile(v263, v259);
+ uint64_t v872 = (uint64_t) v267;
+ TASSIGN(v871, v872);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v871, v831, v869);
+ Tile v873 = Tile(v263, v259);
+ uint64_t v874 = (uint64_t) v266;
+ TASSIGN(v873, v874);
+ TADD(v873, v841, v851);
+ Tile v875 = Tile(v263, v259);
+ uint64_t v876 = (uint64_t) v265;
+ TASSIGN(v875, v876);
+ pipe_barrier(PIPE_V);
+ TADD(v875, v861, v871);
+ Tile v877 = Tile(v263, v259);
+ uint64_t v878 = (uint64_t) v266;
+ TASSIGN(v877, v878);
+ pipe_barrier(PIPE_V);
+ TADD(v877, v873, v875);
+ Tile v879 = Tile(v263, v259);
+ uint64_t v880 = (uint64_t) v266;
+ TASSIGN(v879, v880);
+ pipe_barrier(PIPE_V);
+ TADDS(v879, v877, v257);
+ Tile v881 = Tile(v263, v259);
+ uint64_t v882 = (uint64_t) v270;
+ TASSIGN(v881, v882);
+ pipe_barrier(PIPE_V);
+ TDIV(v881, v841, v879);
+ Tile v883 = Tile(v263, v259);
+ uint64_t v884 = (uint64_t) v269;
+ TASSIGN(v883, v884);
+ TDIV(v883, v851, v879);
+ Tile v885 = Tile(v263, v259);
+ uint64_t v886 = (uint64_t) v268;
+ TASSIGN(v885, v886);
+ TDIV(v885, v861, v879);
+ Tile v887 = Tile(v263, v259);
+ uint64_t v888 = (uint64_t) v267;
+ TASSIGN(v887, v888);
+ TDIV(v887, v871, v879);
+ Tile v889 = Tile(v263, v262);
+ uint64_t v890 = (uint64_t) v272;
+ TASSIGN(v889, v890);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v889, v881, v367);
+ Tile v891 = Tile(v262, v263);
+ uint64_t v892 = (uint64_t) v272;
+ TASSIGN(v891, v892);
+ Tile v893 = Tile(v262, v263);
+ uint64_t v894 = (uint64_t) v264;
+ TASSIGN(v893, v894);
+ pipe_barrier(PIPE_V);
+ TADDS(v893, v891, v257);
+ Tile v895 = Tile(v263, v262);
+ uint64_t v896 = (uint64_t) v264;
+ TASSIGN(v895, v896);
+ Tile v897 = Tile(v263, v259);
+ uint64_t v898 = (uint64_t) v270;
+ TASSIGN(v897, v898);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v897, v881, v895);
+ Tile v899 = Tile(v263, v262);
+ uint64_t v900 = (uint64_t) v272;
+ TASSIGN(v899, v900);
+ TROWSUM(v899, v883, v367);
+ Tile v901 = Tile(v262, v263);
+ uint64_t v902 = (uint64_t) v272;
+ TASSIGN(v901, v902);
+ Tile v903 = Tile(v262, v263);
+ uint64_t v904 = (uint64_t) v264;
+ TASSIGN(v903, v904);
+ pipe_barrier(PIPE_V);
+ TADDS(v903, v901, v257);
+ Tile v905 = Tile(v263, v262);
+ uint64_t v906 = (uint64_t) v264;
+ TASSIGN(v905, v906);
+ Tile v907 = Tile(v263, v259);
+ uint64_t v908 = (uint64_t) v269;
+ TASSIGN(v907, v908);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v907, v883, v905);
+ Tile v909 = Tile(v263, v262);
+ uint64_t v910 = (uint64_t) v272;
+ TASSIGN(v909, v910);
+ TROWSUM(v909, v885, v367);
+ Tile v911 = Tile(v262, v263);
+ uint64_t v912 = (uint64_t) v272;
+ TASSIGN(v911, v912);
+ Tile v913 = Tile(v262, v263);
+ uint64_t v914 = (uint64_t) v264;
+ TASSIGN(v913, v914);
+ pipe_barrier(PIPE_V);
+ TADDS(v913, v911, v257);
+ Tile v915 = Tile(v263, v262);
+ uint64_t v916 = (uint64_t) v264;
+ TASSIGN(v915, v916);
+ Tile v917 = Tile(v263, v259);
+ uint64_t v918 = (uint64_t) v268;
+ TASSIGN(v917, v918);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v917, v885, v915);
+ Tile v919 = Tile(v263, v262);
+ uint64_t v920 = (uint64_t) v272;
+ TASSIGN(v919, v920);
+ TROWSUM(v919, v887, v367);
+ Tile v921 = Tile(v262, v263);
+ uint64_t v922 = (uint64_t) v272;
+ TASSIGN(v921, v922);
+ Tile v923 = Tile(v262, v263);
+ uint64_t v924 = (uint64_t) v264;
+ TASSIGN(v923, v924);
+ pipe_barrier(PIPE_V);
+ TADDS(v923, v921, v257);
+ Tile v925 = Tile(v263, v262);
+ uint64_t v926 = (uint64_t) v264;
+ TASSIGN(v925, v926);
+ Tile v927 = Tile(v263, v259);
+ uint64_t v928 = (uint64_t) v267;
+ TASSIGN(v927, v928);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v927, v887, v925);
+ Tile v929 = Tile(v263, v259);
+ uint64_t v930 = (uint64_t) v266;
+ TASSIGN(v929, v930);
+ TADD(v929, v897, v907);
+ Tile v931 = Tile(v263, v259);
+ uint64_t v932 = (uint64_t) v265;
+ TASSIGN(v931, v932);
+ pipe_barrier(PIPE_V);
+ TADD(v931, v917, v927);
+ Tile v933 = Tile(v263, v259);
+ uint64_t v934 = (uint64_t) v266;
+ TASSIGN(v933, v934);
+ pipe_barrier(PIPE_V);
+ TADD(v933, v929, v931);
+ Tile v935 = Tile(v263, v259);
+ uint64_t v936 = (uint64_t) v266;
+ TASSIGN(v935, v936);
+ pipe_barrier(PIPE_V);
+ TADDS(v935, v933, v257);
+ Tile v937 = Tile(v263, v259);
+ uint64_t v938 = (uint64_t) v270;
+ TASSIGN(v937, v938);
+ pipe_barrier(PIPE_V);
+ TDIV(v937, v897, v935);
+ Tile v939 = Tile(v263, v259);
+ uint64_t v940 = (uint64_t) v269;
+ TASSIGN(v939, v940);
+ TDIV(v939, v907, v935);
+ Tile v941 = Tile(v263, v259);
+ uint64_t v942 = (uint64_t) v268;
+ TASSIGN(v941, v942);
+ TDIV(v941, v917, v935);
+ Tile v943 = Tile(v263, v259);
+ uint64_t v944 = (uint64_t) v267;
+ TASSIGN(v943, v944);
+ TDIV(v943, v927, v935);
+ Tile v945 = Tile(v263, v262);
+ uint64_t v946 = (uint64_t) v272;
+ TASSIGN(v945, v946);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v945, v937, v367);
+ Tile v947 = Tile(v262, v263);
+ uint64_t v948 = (uint64_t) v272;
+ TASSIGN(v947, v948);
+ Tile v949 = Tile(v262, v263);
+ uint64_t v950 = (uint64_t) v264;
+ TASSIGN(v949, v950);
+ pipe_barrier(PIPE_V);
+ TADDS(v949, v947, v257);
+ Tile v951 = Tile(v263, v262);
+ uint64_t v952 = (uint64_t) v264;
+ TASSIGN(v951, v952);
+ Tile v953 = Tile(v263, v259);
+ uint64_t v954 = (uint64_t) v270;
+ TASSIGN(v953, v954);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v953, v937, v951);
+ Tile v955 = Tile(v263, v262);
+ uint64_t v956 = (uint64_t) v272;
+ TASSIGN(v955, v956);
+ TROWSUM(v955, v939, v367);
+ Tile v957 = Tile(v262, v263);
+ uint64_t v958 = (uint64_t) v272;
+ TASSIGN(v957, v958);
+ Tile v959 = Tile(v262, v263);
+ uint64_t v960 = (uint64_t) v264;
+ TASSIGN(v959, v960);
+ pipe_barrier(PIPE_V);
+ TADDS(v959, v957, v257);
+ Tile v961 = Tile(v263, v262);
+ uint64_t v962 = (uint64_t) v264;
+ TASSIGN(v961, v962);
+ Tile v963 = Tile(v263, v259);
+ uint64_t v964 = (uint64_t) v269;
+ TASSIGN(v963, v964);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v963, v939, v961);
+ Tile v965 = Tile(v263, v262);
+ uint64_t v966 = (uint64_t) v272;
+ TASSIGN(v965, v966);
+ TROWSUM(v965, v941, v367);
+ Tile v967 = Tile(v262, v263);
+ uint64_t v968 = (uint64_t) v272;
+ TASSIGN(v967, v968);
+ Tile v969 = Tile(v262, v263);
+ uint64_t v970 = (uint64_t) v264;
+ TASSIGN(v969, v970);
+ pipe_barrier(PIPE_V);
+ TADDS(v969, v967, v257);
+ Tile v971 = Tile(v263, v262);
+ uint64_t v972 = (uint64_t) v264;
+ TASSIGN(v971, v972);
+ Tile v973 = Tile(v263, v259);
+ uint64_t v974 = (uint64_t) v268;
+ TASSIGN(v973, v974);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v973, v941, v971);
+ Tile v975 = Tile(v263, v262);
+ uint64_t v976 = (uint64_t) v272;
+ TASSIGN(v975, v976);
+ TROWSUM(v975, v943, v367);
+ Tile v977 = Tile(v262, v263);
+ uint64_t v978 = (uint64_t) v272;
+ TASSIGN(v977, v978);
+ Tile v979 = Tile(v262, v263);
+ uint64_t v980 = (uint64_t) v264;
+ TASSIGN(v979, v980);
+ pipe_barrier(PIPE_V);
+ TADDS(v979, v977, v257);
+ Tile v981 = Tile(v263, v262);
+ uint64_t v982 = (uint64_t) v264;
+ TASSIGN(v981, v982);
+ Tile v983 = Tile(v263, v259);
+ uint64_t v984 = (uint64_t) v267;
+ TASSIGN(v983, v984);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v983, v943, v981);
+ Tile v985 = Tile(v263, v259);
+ uint64_t v986 = (uint64_t) v266;
+ TASSIGN(v985, v986);
+ TADD(v985, v953, v963);
+ Tile v987 = Tile(v263, v259);
+ uint64_t v988 = (uint64_t) v265;
+ TASSIGN(v987, v988);
+ pipe_barrier(PIPE_V);
+ TADD(v987, v973, v983);
+ Tile v989 = Tile(v263, v259);
+ uint64_t v990 = (uint64_t) v266;
+ TASSIGN(v989, v990);
+ pipe_barrier(PIPE_V);
+ TADD(v989, v985, v987);
+ Tile v991 = Tile(v263, v259);
+ uint64_t v992 = (uint64_t) v266;
+ TASSIGN(v991, v992);
+ pipe_barrier(PIPE_V);
+ TADDS(v991, v989, v257);
+ Tile v993 = Tile(v263, v259);
+ uint64_t v994 = (uint64_t) v270;
+ TASSIGN(v993, v994);
+ pipe_barrier(PIPE_V);
+ TDIV(v993, v953, v991);
+ Tile v995 = Tile(v263, v259);
+ uint64_t v996 = (uint64_t) v269;
+ TASSIGN(v995, v996);
+ TDIV(v995, v963, v991);
+ Tile v997 = Tile(v263, v259);
+ uint64_t v998 = (uint64_t) v268;
+ TASSIGN(v997, v998);
+ TDIV(v997, v973, v991);
+ Tile v999 = Tile(v263, v259);
+ uint64_t v1000 = (uint64_t) v267;
+ TASSIGN(v999, v1000);
+ TDIV(v999, v983, v991);
+ Tile v1001 = Tile(v263, v262);
+ uint64_t v1002 = (uint64_t) v272;
+ TASSIGN(v1001, v1002);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v1001, v993, v367);
+ Tile v1003 = Tile(v262, v263);
+ uint64_t v1004 = (uint64_t) v272;
+ TASSIGN(v1003, v1004);
+ Tile v1005 = Tile(v262, v263);
+ uint64_t v1006 = (uint64_t) v264;
+ TASSIGN(v1005, v1006);
+ pipe_barrier(PIPE_V);
+ TADDS(v1005, v1003, v257);
+ Tile v1007 = Tile(v263, v262);
+ uint64_t v1008 = (uint64_t) v264;
+ TASSIGN(v1007, v1008);
+ Tile v1009 = Tile(v263, v259);
+ uint64_t v1010 = (uint64_t) v270;
+ TASSIGN(v1009, v1010);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1009, v993, v1007);
+ Tile v1011 = Tile(v263, v262);
+ uint64_t v1012 = (uint64_t) v272;
+ TASSIGN(v1011, v1012);
+ TROWSUM(v1011, v995, v367);
+ Tile v1013 = Tile(v262, v263);
+ uint64_t v1014 = (uint64_t) v272;
+ TASSIGN(v1013, v1014);
+ Tile v1015 = Tile(v262, v263);
+ uint64_t v1016 = (uint64_t) v264;
+ TASSIGN(v1015, v1016);
+ pipe_barrier(PIPE_V);
+ TADDS(v1015, v1013, v257);
+ Tile v1017 = Tile(v263, v262);
+ uint64_t v1018 = (uint64_t) v264;
+ TASSIGN(v1017, v1018);
+ Tile v1019 = Tile(v263, v259);
+ uint64_t v1020 = (uint64_t) v269;
+ TASSIGN(v1019, v1020);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1019, v995, v1017);
+ Tile v1021 = Tile(v263, v262);
+ uint64_t v1022 = (uint64_t) v272;
+ TASSIGN(v1021, v1022);
+ TROWSUM(v1021, v997, v367);
+ Tile v1023 = Tile(v262, v263);
+ uint64_t v1024 = (uint64_t) v272;
+ TASSIGN(v1023, v1024);
+ Tile v1025 = Tile(v262, v263);
+ uint64_t v1026 = (uint64_t) v264;
+ TASSIGN(v1025, v1026);
+ pipe_barrier(PIPE_V);
+ TADDS(v1025, v1023, v257);
+ Tile v1027 = Tile(v263, v262);
+ uint64_t v1028 = (uint64_t) v264;
+ TASSIGN(v1027, v1028);
+ Tile v1029 = Tile(v263, v259);
+ uint64_t v1030 = (uint64_t) v268;
+ TASSIGN(v1029, v1030);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1029, v997, v1027);
+ Tile v1031 = Tile(v263, v262);
+ uint64_t v1032 = (uint64_t) v272;
+ TASSIGN(v1031, v1032);
+ TROWSUM(v1031, v999, v367);
+ Tile v1033 = Tile(v262, v263);
+ uint64_t v1034 = (uint64_t) v272;
+ TASSIGN(v1033, v1034);
+ Tile v1035 = Tile(v262, v263);
+ uint64_t v1036 = (uint64_t) v264;
+ TASSIGN(v1035, v1036);
+ pipe_barrier(PIPE_V);
+ TADDS(v1035, v1033, v257);
+ Tile v1037 = Tile(v263, v262);
+ uint64_t v1038 = (uint64_t) v264;
+ TASSIGN(v1037, v1038);
+ Tile v1039 = Tile(v263, v259);
+ uint64_t v1040 = (uint64_t) v267;
+ TASSIGN(v1039, v1040);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1039, v999, v1037);
+ Tile v1041 = Tile(v263, v259);
+ uint64_t v1042 = (uint64_t) v266;
+ TASSIGN(v1041, v1042);
+ TADD(v1041, v1009, v1019);
+ Tile v1043 = Tile(v263, v259);
+ uint64_t v1044 = (uint64_t) v265;
+ TASSIGN(v1043, v1044);
+ pipe_barrier(PIPE_V);
+ TADD(v1043, v1029, v1039);
+ Tile v1045 = Tile(v263, v259);
+ uint64_t v1046 = (uint64_t) v266;
+ TASSIGN(v1045, v1046);
+ pipe_barrier(PIPE_V);
+ TADD(v1045, v1041, v1043);
+ Tile v1047 = Tile(v263, v259);
+ uint64_t v1048 = (uint64_t) v266;
+ TASSIGN(v1047, v1048);
+ pipe_barrier(PIPE_V);
+ TADDS(v1047, v1045, v257);
+ Tile v1049 = Tile(v263, v259);
+ uint64_t v1050 = (uint64_t) v270;
+ TASSIGN(v1049, v1050);
+ pipe_barrier(PIPE_V);
+ TDIV(v1049, v1009, v1047);
+ Tile v1051 = Tile(v263, v259);
+ uint64_t v1052 = (uint64_t) v269;
+ TASSIGN(v1051, v1052);
+ TDIV(v1051, v1019, v1047);
+ Tile v1053 = Tile(v263, v259);
+ uint64_t v1054 = (uint64_t) v268;
+ TASSIGN(v1053, v1054);
+ TDIV(v1053, v1029, v1047);
+ Tile v1055 = Tile(v263, v259);
+ uint64_t v1056 = (uint64_t) v267;
+ TASSIGN(v1055, v1056);
+ TDIV(v1055, v1039, v1047);
+ Tile v1057 = Tile(v263, v262);
+ uint64_t v1058 = (uint64_t) v272;
+ TASSIGN(v1057, v1058);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v1057, v1049, v367);
+ Tile v1059 = Tile(v262, v263);
+ uint64_t v1060 = (uint64_t) v272;
+ TASSIGN(v1059, v1060);
+ Tile v1061 = Tile(v262, v263);
+ uint64_t v1062 = (uint64_t) v264;
+ TASSIGN(v1061, v1062);
+ pipe_barrier(PIPE_V);
+ TADDS(v1061, v1059, v257);
+ Tile v1063 = Tile(v263, v262);
+ uint64_t v1064 = (uint64_t) v264;
+ TASSIGN(v1063, v1064);
+ Tile v1065 = Tile(v263, v259);
+ uint64_t v1066 = (uint64_t) v270;
+ TASSIGN(v1065, v1066);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1065, v1049, v1063);
+ Tile v1067 = Tile(v263, v262);
+ uint64_t v1068 = (uint64_t) v272;
+ TASSIGN(v1067, v1068);
+ TROWSUM(v1067, v1051, v367);
+ Tile v1069 = Tile(v262, v263);
+ uint64_t v1070 = (uint64_t) v272;
+ TASSIGN(v1069, v1070);
+ Tile v1071 = Tile(v262, v263);
+ uint64_t v1072 = (uint64_t) v264;
+ TASSIGN(v1071, v1072);
+ pipe_barrier(PIPE_V);
+ TADDS(v1071, v1069, v257);
+ Tile v1073 = Tile(v263, v262);
+ uint64_t v1074 = (uint64_t) v264;
+ TASSIGN(v1073, v1074);
+ Tile v1075 = Tile(v263, v259);
+ uint64_t v1076 = (uint64_t) v269;
+ TASSIGN(v1075, v1076);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1075, v1051, v1073);
+ Tile v1077 = Tile(v263, v262);
+ uint64_t v1078 = (uint64_t) v272;
+ TASSIGN(v1077, v1078);
+ TROWSUM(v1077, v1053, v367);
+ Tile v1079 = Tile(v262, v263);
+ uint64_t v1080 = (uint64_t) v272;
+ TASSIGN(v1079, v1080);
+ Tile v1081 = Tile(v262, v263);
+ uint64_t v1082 = (uint64_t) v264;
+ TASSIGN(v1081, v1082);
+ pipe_barrier(PIPE_V);
+ TADDS(v1081, v1079, v257);
+ Tile v1083 = Tile(v263, v262);
+ uint64_t v1084 = (uint64_t) v264;
+ TASSIGN(v1083, v1084);
+ Tile v1085 = Tile(v263, v259);
+ uint64_t v1086 = (uint64_t) v268;
+ TASSIGN(v1085, v1086);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1085, v1053, v1083);
+ Tile v1087 = Tile(v263, v262);
+ uint64_t v1088 = (uint64_t) v272;
+ TASSIGN(v1087, v1088);
+ TROWSUM(v1087, v1055, v367);
+ Tile v1089 = Tile(v262, v263);
+ uint64_t v1090 = (uint64_t) v272;
+ TASSIGN(v1089, v1090);
+ Tile v1091 = Tile(v262, v263);
+ uint64_t v1092 = (uint64_t) v264;
+ TASSIGN(v1091, v1092);
+ pipe_barrier(PIPE_V);
+ TADDS(v1091, v1089, v257);
+ Tile v1093 = Tile(v263, v262);
+ uint64_t v1094 = (uint64_t) v264;
+ TASSIGN(v1093, v1094);
+ Tile v1095 = Tile(v263, v259);
+ uint64_t v1096 = (uint64_t) v267;
+ TASSIGN(v1095, v1096);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1095, v1055, v1093);
+ Tile v1097 = Tile(v263, v259);
+ uint64_t v1098 = (uint64_t) v266;
+ TASSIGN(v1097, v1098);
+ TADD(v1097, v1065, v1075);
+ Tile v1099 = Tile(v263, v259);
+ uint64_t v1100 = (uint64_t) v265;
+ TASSIGN(v1099, v1100);
+ pipe_barrier(PIPE_V);
+ TADD(v1099, v1085, v1095);
+ Tile v1101 = Tile(v263, v259);
+ uint64_t v1102 = (uint64_t) v266;
+ TASSIGN(v1101, v1102);
+ pipe_barrier(PIPE_V);
+ TADD(v1101, v1097, v1099);
+ Tile v1103 = Tile(v263, v259);
+ uint64_t v1104 = (uint64_t) v266;
+ TASSIGN(v1103, v1104);
+ pipe_barrier(PIPE_V);
+ TADDS(v1103, v1101, v257);
+ Tile v1105 = Tile(v263, v259);
+ uint64_t v1106 = (uint64_t) v270;
+ TASSIGN(v1105, v1106);
+ pipe_barrier(PIPE_V);
+ TDIV(v1105, v1065, v1103);
+ Tile v1107 = Tile(v263, v259);
+ uint64_t v1108 = (uint64_t) v269;
+ TASSIGN(v1107, v1108);
+ TDIV(v1107, v1075, v1103);
+ Tile v1109 = Tile(v263, v259);
+ uint64_t v1110 = (uint64_t) v268;
+ TASSIGN(v1109, v1110);
+ TDIV(v1109, v1085, v1103);
+ Tile v1111 = Tile(v263, v259);
+ uint64_t v1112 = (uint64_t) v267;
+ TASSIGN(v1111, v1112);
+ TDIV(v1111, v1095, v1103);
+ Tile v1113 = Tile(v263, v262);
+ uint64_t v1114 = (uint64_t) v272;
+ TASSIGN(v1113, v1114);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v1113, v1105, v367);
+ Tile v1115 = Tile(v262, v263);
+ uint64_t v1116 = (uint64_t) v272;
+ TASSIGN(v1115, v1116);
+ Tile v1117 = Tile(v262, v263);
+ uint64_t v1118 = (uint64_t) v264;
+ TASSIGN(v1117, v1118);
+ pipe_barrier(PIPE_V);
+ TADDS(v1117, v1115, v257);
+ Tile v1119 = Tile(v263, v262);
+ uint64_t v1120 = (uint64_t) v264;
+ TASSIGN(v1119, v1120);
+ Tile v1121 = Tile(v263, v259);
+ uint64_t v1122 = (uint64_t) v270;
+ TASSIGN(v1121, v1122);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1121, v1105, v1119);
+ Tile v1123 = Tile(v263, v262);
+ uint64_t v1124 = (uint64_t) v272;
+ TASSIGN(v1123, v1124);
+ TROWSUM(v1123, v1107, v367);
+ Tile v1125 = Tile(v262, v263);
+ uint64_t v1126 = (uint64_t) v272;
+ TASSIGN(v1125, v1126);
+ Tile v1127 = Tile(v262, v263);
+ uint64_t v1128 = (uint64_t) v264;
+ TASSIGN(v1127, v1128);
+ pipe_barrier(PIPE_V);
+ TADDS(v1127, v1125, v257);
+ Tile v1129 = Tile(v263, v262);
+ uint64_t v1130 = (uint64_t) v264;
+ TASSIGN(v1129, v1130);
+ Tile v1131 = Tile(v263, v259);
+ uint64_t v1132 = (uint64_t) v269;
+ TASSIGN(v1131, v1132);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1131, v1107, v1129);
+ Tile v1133 = Tile(v263, v262);
+ uint64_t v1134 = (uint64_t) v272;
+ TASSIGN(v1133, v1134);
+ TROWSUM(v1133, v1109, v367);
+ Tile v1135 = Tile(v262, v263);
+ uint64_t v1136 = (uint64_t) v272;
+ TASSIGN(v1135, v1136);
+ Tile v1137 = Tile(v262, v263);
+ uint64_t v1138 = (uint64_t) v264;
+ TASSIGN(v1137, v1138);
+ pipe_barrier(PIPE_V);
+ TADDS(v1137, v1135, v257);
+ Tile v1139 = Tile(v263, v262);
+ uint64_t v1140 = (uint64_t) v264;
+ TASSIGN(v1139, v1140);
+ Tile v1141 = Tile(v263, v259);
+ uint64_t v1142 = (uint64_t) v268;
+ TASSIGN(v1141, v1142);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1141, v1109, v1139);
+ Tile v1143 = Tile(v263, v262);
+ uint64_t v1144 = (uint64_t) v272;
+ TASSIGN(v1143, v1144);
+ TROWSUM(v1143, v1111, v367);
+ Tile v1145 = Tile(v262, v263);
+ uint64_t v1146 = (uint64_t) v272;
+ TASSIGN(v1145, v1146);
+ Tile v1147 = Tile(v262, v263);
+ uint64_t v1148 = (uint64_t) v264;
+ TASSIGN(v1147, v1148);
+ pipe_barrier(PIPE_V);
+ TADDS(v1147, v1145, v257);
+ Tile v1149 = Tile(v263, v262);
+ uint64_t v1150 = (uint64_t) v264;
+ TASSIGN(v1149, v1150);
+ Tile v1151 = Tile(v263, v259);
+ uint64_t v1152 = (uint64_t) v267;
+ TASSIGN(v1151, v1152);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1151, v1111, v1149);
+ Tile v1153 = Tile(v263, v259);
+ uint64_t v1154 = (uint64_t) v266;
+ TASSIGN(v1153, v1154);
+ TADD(v1153, v1121, v1131);
+ Tile v1155 = Tile(v263, v259);
+ uint64_t v1156 = (uint64_t) v265;
+ TASSIGN(v1155, v1156);
+ pipe_barrier(PIPE_V);
+ TADD(v1155, v1141, v1151);
+ Tile v1157 = Tile(v263, v259);
+ uint64_t v1158 = (uint64_t) v266;
+ TASSIGN(v1157, v1158);
+ pipe_barrier(PIPE_V);
+ TADD(v1157, v1153, v1155);
+ Tile v1159 = Tile(v263, v259);
+ uint64_t v1160 = (uint64_t) v266;
+ TASSIGN(v1159, v1160);
+ pipe_barrier(PIPE_V);
+ TADDS(v1159, v1157, v257);
+ Tile v1161 = Tile(v263, v259);
+ uint64_t v1162 = (uint64_t) v270;
+ TASSIGN(v1161, v1162);
+ pipe_barrier(PIPE_V);
+ TDIV(v1161, v1121, v1159);
+ Tile v1163 = Tile(v263, v259);
+ uint64_t v1164 = (uint64_t) v269;
+ TASSIGN(v1163, v1164);
+ TDIV(v1163, v1131, v1159);
+ Tile v1165 = Tile(v263, v259);
+ uint64_t v1166 = (uint64_t) v268;
+ TASSIGN(v1165, v1166);
+ TDIV(v1165, v1141, v1159);
+ Tile v1167 = Tile(v263, v259);
+ uint64_t v1168 = (uint64_t) v267;
+ TASSIGN(v1167, v1168);
+ TDIV(v1167, v1151, v1159);
+ Tile v1169 = Tile(v263, v262);
+ uint64_t v1170 = (uint64_t) v272;
+ TASSIGN(v1169, v1170);
+ pipe_barrier(PIPE_V);
+ TROWSUM(v1169, v1161, v367);
+ Tile v1171 = Tile(v262, v263);
+ uint64_t v1172 = (uint64_t) v272;
+ TASSIGN(v1171, v1172);
+ Tile v1173 = Tile(v262, v263);
+ uint64_t v1174 = (uint64_t) v264;
+ TASSIGN(v1173, v1174);
+ pipe_barrier(PIPE_V);
+ TADDS(v1173, v1171, v257);
+ Tile v1175 = Tile(v263, v262);
+ uint64_t v1176 = (uint64_t) v264;
+ TASSIGN(v1175, v1176);
+ Tile v1177 = Tile(v263, v259);
+ uint64_t v1178 = (uint64_t) v270;
+ TASSIGN(v1177, v1178);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1177, v1161, v1175);
+ Tile v1179 = Tile(v263, v262);
+ uint64_t v1180 = (uint64_t) v272;
+ TASSIGN(v1179, v1180);
+ TROWSUM(v1179, v1163, v367);
+ Tile v1181 = Tile(v262, v263);
+ uint64_t v1182 = (uint64_t) v272;
+ TASSIGN(v1181, v1182);
+ Tile v1183 = Tile(v262, v263);
+ uint64_t v1184 = (uint64_t) v264;
+ TASSIGN(v1183, v1184);
+ pipe_barrier(PIPE_V);
+ TADDS(v1183, v1181, v257);
+ Tile v1185 = Tile(v263, v262);
+ uint64_t v1186 = (uint64_t) v264;
+ TASSIGN(v1185, v1186);
+ Tile v1187 = Tile(v263, v259);
+ uint64_t v1188 = (uint64_t) v269;
+ TASSIGN(v1187, v1188);
+ pipe_barrier(PIPE_V);
+ TROWEXPANDDIV(v1187, v1163, v1185);
+ Tile v1189 = Tile(v263, v262);
+ uint64_t v1190 = (uint64_t) v272;
+ TASSIGN(v1189, v1190);
+ TROWSUM(v1189, v1165, v367);
v1165, v367); + Tile v1191 = Tile(v262, v263); + uint64_t v1192 = (uint64_t) v272; + TASSIGN(v1191, v1192); + Tile v1193 = Tile(v262, v263); + uint64_t v1194 = (uint64_t) v264; + TASSIGN(v1193, v1194); + pipe_barrier(PIPE_V); + TADDS(v1193, v1191, v257); + Tile v1195 = Tile(v263, v262); + uint64_t v1196 = (uint64_t) v264; + TASSIGN(v1195, v1196); + Tile v1197 = Tile(v263, v259); + uint64_t v1198 = (uint64_t) v268; + TASSIGN(v1197, v1198); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1197, v1165, v1195); + Tile v1199 = Tile(v263, v262); + uint64_t v1200 = (uint64_t) v272; + TASSIGN(v1199, v1200); + TROWSUM(v1199, v1167, v367); + Tile v1201 = Tile(v262, v263); + uint64_t v1202 = (uint64_t) v272; + TASSIGN(v1201, v1202); + Tile v1203 = Tile(v262, v263); + uint64_t v1204 = (uint64_t) v264; + TASSIGN(v1203, v1204); + pipe_barrier(PIPE_V); + TADDS(v1203, v1201, v257); + Tile v1205 = Tile(v263, v262); + uint64_t v1206 = (uint64_t) v264; + TASSIGN(v1205, v1206); + Tile v1207 = Tile(v263, v259); + uint64_t v1208 = (uint64_t) v267; + TASSIGN(v1207, v1208); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1207, v1167, v1205); + Tile v1209 = Tile(v263, v259); + uint64_t v1210 = (uint64_t) v266; + TASSIGN(v1209, v1210); + TADD(v1209, v1177, v1187); + Tile v1211 = Tile(v263, v259); + uint64_t v1212 = (uint64_t) v265; + TASSIGN(v1211, v1212); + pipe_barrier(PIPE_V); + TADD(v1211, v1197, v1207); + Tile v1213 = Tile(v263, v259); + uint64_t v1214 = (uint64_t) v266; + TASSIGN(v1213, v1214); + pipe_barrier(PIPE_V); + TADD(v1213, v1209, v1211); + Tile v1215 = Tile(v263, v259); + uint64_t v1216 = (uint64_t) v266; + TASSIGN(v1215, v1216); + pipe_barrier(PIPE_V); + TADDS(v1215, v1213, v257); + Tile v1217 = Tile(v263, v259); + uint64_t v1218 = (uint64_t) v270; + TASSIGN(v1217, v1218); + pipe_barrier(PIPE_V); + TDIV(v1217, v1177, v1215); + Tile v1219 = Tile(v263, v259); + uint64_t v1220 = (uint64_t) v269; + TASSIGN(v1219, v1220); + TDIV(v1219, v1187, v1215); + Tile v1221 = Tile(v263, v259); + uint64_t v1222 = (uint64_t) v268; + TASSIGN(v1221, v1222); + TDIV(v1221, v1197, v1215); + Tile v1223 = Tile(v263, v259); + uint64_t v1224 = (uint64_t) v267; + TASSIGN(v1223, v1224); + TDIV(v1223, v1207, v1215); + Tile v1225 = Tile(v263, v262); + uint64_t v1226 = (uint64_t) v272; + TASSIGN(v1225, v1226); + pipe_barrier(PIPE_V); + TROWSUM(v1225, v1217, v367); + Tile v1227 = Tile(v262, v263); + uint64_t v1228 = (uint64_t) v272; + TASSIGN(v1227, v1228); + Tile v1229 = Tile(v262, v263); + uint64_t v1230 = (uint64_t) v264; + TASSIGN(v1229, v1230); + pipe_barrier(PIPE_V); + TADDS(v1229, v1227, v257); + Tile v1231 = Tile(v263, v262); + uint64_t v1232 = (uint64_t) v264; + TASSIGN(v1231, v1232); + Tile v1233 = Tile(v263, v259); + uint64_t v1234 = (uint64_t) v270; + TASSIGN(v1233, v1234); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1233, v1217, v1231); + Tile v1235 = Tile(v263, v262); + uint64_t v1236 = (uint64_t) v272; + TASSIGN(v1235, v1236); + TROWSUM(v1235, v1219, v367); + Tile v1237 = Tile(v262, v263); + uint64_t v1238 = (uint64_t) v272; + TASSIGN(v1237, v1238); + Tile v1239 = Tile(v262, v263); + uint64_t v1240 = (uint64_t) v264; + TASSIGN(v1239, v1240); + pipe_barrier(PIPE_V); + TADDS(v1239, v1237, v257); + Tile v1241 = Tile(v263, v262); + uint64_t v1242 = (uint64_t) v264; + TASSIGN(v1241, v1242); + Tile v1243 = Tile(v263, v259); + uint64_t v1244 = (uint64_t) v269; + TASSIGN(v1243, v1244); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1243, v1219, v1241); + Tile v1245 = Tile(v263, v262); + uint64_t v1246 = (uint64_t) v272; + TASSIGN(v1245, v1246); + 
TROWSUM(v1245, v1221, v367); + Tile v1247 = Tile(v262, v263); + uint64_t v1248 = (uint64_t) v272; + TASSIGN(v1247, v1248); + Tile v1249 = Tile(v262, v263); + uint64_t v1250 = (uint64_t) v264; + TASSIGN(v1249, v1250); + pipe_barrier(PIPE_V); + TADDS(v1249, v1247, v257); + Tile v1251 = Tile(v263, v262); + uint64_t v1252 = (uint64_t) v264; + TASSIGN(v1251, v1252); + Tile v1253 = Tile(v263, v259); + uint64_t v1254 = (uint64_t) v268; + TASSIGN(v1253, v1254); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1253, v1221, v1251); + Tile v1255 = Tile(v263, v262); + uint64_t v1256 = (uint64_t) v272; + TASSIGN(v1255, v1256); + TROWSUM(v1255, v1223, v367); + Tile v1257 = Tile(v262, v263); + uint64_t v1258 = (uint64_t) v272; + TASSIGN(v1257, v1258); + Tile v1259 = Tile(v262, v263); + uint64_t v1260 = (uint64_t) v264; + TASSIGN(v1259, v1260); + pipe_barrier(PIPE_V); + TADDS(v1259, v1257, v257); + Tile v1261 = Tile(v263, v262); + uint64_t v1262 = (uint64_t) v264; + TASSIGN(v1261, v1262); + Tile v1263 = Tile(v263, v259); + uint64_t v1264 = (uint64_t) v267; + TASSIGN(v1263, v1264); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1263, v1223, v1261); + Tile v1265 = Tile(v263, v259); + uint64_t v1266 = (uint64_t) v266; + TASSIGN(v1265, v1266); + TADD(v1265, v1233, v1243); + Tile v1267 = Tile(v263, v259); + uint64_t v1268 = (uint64_t) v265; + TASSIGN(v1267, v1268); + pipe_barrier(PIPE_V); + TADD(v1267, v1253, v1263); + Tile v1269 = Tile(v263, v259); + uint64_t v1270 = (uint64_t) v266; + TASSIGN(v1269, v1270); + pipe_barrier(PIPE_V); + TADD(v1269, v1265, v1267); + Tile v1271 = Tile(v263, v259); + uint64_t v1272 = (uint64_t) v266; + TASSIGN(v1271, v1272); + pipe_barrier(PIPE_V); + TADDS(v1271, v1269, v257); + Tile v1273 = Tile(v263, v259); + uint64_t v1274 = (uint64_t) v270; + TASSIGN(v1273, v1274); + pipe_barrier(PIPE_V); + TDIV(v1273, v1233, v1271); + Tile v1275 = Tile(v263, v259); + uint64_t v1276 = (uint64_t) v269; + TASSIGN(v1275, v1276); + TDIV(v1275, v1243, v1271); + Tile v1277 = Tile(v263, v259); + uint64_t v1278 = (uint64_t) v268; + TASSIGN(v1277, v1278); + TDIV(v1277, v1253, v1271); + Tile v1279 = Tile(v263, v259); + uint64_t v1280 = (uint64_t) v267; + TASSIGN(v1279, v1280); + TDIV(v1279, v1263, v1271); + Tile v1281 = Tile(v263, v262); + uint64_t v1282 = (uint64_t) v272; + TASSIGN(v1281, v1282); + pipe_barrier(PIPE_V); + TROWSUM(v1281, v1273, v367); + Tile v1283 = Tile(v262, v263); + uint64_t v1284 = (uint64_t) v272; + TASSIGN(v1283, v1284); + Tile v1285 = Tile(v262, v263); + uint64_t v1286 = (uint64_t) v264; + TASSIGN(v1285, v1286); + pipe_barrier(PIPE_V); + TADDS(v1285, v1283, v257); + Tile v1287 = Tile(v263, v262); + uint64_t v1288 = (uint64_t) v264; + TASSIGN(v1287, v1288); + Tile v1289 = Tile(v263, v259); + uint64_t v1290 = (uint64_t) v270; + TASSIGN(v1289, v1290); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1289, v1273, v1287); + Tile v1291 = Tile(v263, v262); + uint64_t v1292 = (uint64_t) v272; + TASSIGN(v1291, v1292); + TROWSUM(v1291, v1275, v367); + Tile v1293 = Tile(v262, v263); + uint64_t v1294 = (uint64_t) v272; + TASSIGN(v1293, v1294); + Tile v1295 = Tile(v262, v263); + uint64_t v1296 = (uint64_t) v264; + TASSIGN(v1295, v1296); + pipe_barrier(PIPE_V); + TADDS(v1295, v1293, v257); + Tile v1297 = Tile(v263, v262); + uint64_t v1298 = (uint64_t) v264; + TASSIGN(v1297, v1298); + Tile v1299 = Tile(v263, v259); + uint64_t v1300 = (uint64_t) v269; + TASSIGN(v1299, v1300); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1299, v1275, v1297); + Tile v1301 = Tile(v263, v262); + uint64_t v1302 = (uint64_t) v272; + 
TASSIGN(v1301, v1302); + TROWSUM(v1301, v1277, v367); + Tile v1303 = Tile(v262, v263); + uint64_t v1304 = (uint64_t) v272; + TASSIGN(v1303, v1304); + Tile v1305 = Tile(v262, v263); + uint64_t v1306 = (uint64_t) v264; + TASSIGN(v1305, v1306); + pipe_barrier(PIPE_V); + TADDS(v1305, v1303, v257); + Tile v1307 = Tile(v263, v262); + uint64_t v1308 = (uint64_t) v264; + TASSIGN(v1307, v1308); + Tile v1309 = Tile(v263, v259); + uint64_t v1310 = (uint64_t) v268; + TASSIGN(v1309, v1310); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1309, v1277, v1307); + Tile v1311 = Tile(v263, v262); + uint64_t v1312 = (uint64_t) v272; + TASSIGN(v1311, v1312); + TROWSUM(v1311, v1279, v367); + Tile v1313 = Tile(v262, v263); + uint64_t v1314 = (uint64_t) v272; + TASSIGN(v1313, v1314); + Tile v1315 = Tile(v262, v263); + uint64_t v1316 = (uint64_t) v264; + TASSIGN(v1315, v1316); + pipe_barrier(PIPE_V); + TADDS(v1315, v1313, v257); + Tile v1317 = Tile(v263, v262); + uint64_t v1318 = (uint64_t) v264; + TASSIGN(v1317, v1318); + Tile v1319 = Tile(v263, v259); + uint64_t v1320 = (uint64_t) v267; + TASSIGN(v1319, v1320); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1319, v1279, v1317); + Tile v1321 = Tile(v263, v259); + uint64_t v1322 = (uint64_t) v266; + TASSIGN(v1321, v1322); + TADD(v1321, v1289, v1299); + Tile v1323 = Tile(v263, v259); + uint64_t v1324 = (uint64_t) v265; + TASSIGN(v1323, v1324); + pipe_barrier(PIPE_V); + TADD(v1323, v1309, v1319); + Tile v1325 = Tile(v263, v259); + uint64_t v1326 = (uint64_t) v266; + TASSIGN(v1325, v1326); + pipe_barrier(PIPE_V); + TADD(v1325, v1321, v1323); + Tile v1327 = Tile(v263, v259); + uint64_t v1328 = (uint64_t) v266; + TASSIGN(v1327, v1328); + pipe_barrier(PIPE_V); + TADDS(v1327, v1325, v257); + Tile v1329 = Tile(v263, v259); + uint64_t v1330 = (uint64_t) v270; + TASSIGN(v1329, v1330); + pipe_barrier(PIPE_V); + TDIV(v1329, v1289, v1327); + Tile v1331 = Tile(v263, v259); + uint64_t v1332 = (uint64_t) v269; + TASSIGN(v1331, v1332); + TDIV(v1331, v1299, v1327); + Tile v1333 = Tile(v263, v259); + uint64_t v1334 = (uint64_t) v268; + TASSIGN(v1333, v1334); + TDIV(v1333, v1309, v1327); + Tile v1335 = Tile(v263, v259); + uint64_t v1336 = (uint64_t) v267; + TASSIGN(v1335, v1336); + TDIV(v1335, v1319, v1327); + Tile v1337 = Tile(v263, v262); + uint64_t v1338 = (uint64_t) v272; + TASSIGN(v1337, v1338); + pipe_barrier(PIPE_V); + TROWSUM(v1337, v1329, v367); + Tile v1339 = Tile(v262, v263); + uint64_t v1340 = (uint64_t) v272; + TASSIGN(v1339, v1340); + Tile v1341 = Tile(v262, v263); + uint64_t v1342 = (uint64_t) v264; + TASSIGN(v1341, v1342); + pipe_barrier(PIPE_V); + TADDS(v1341, v1339, v257); + Tile v1343 = Tile(v263, v262); + uint64_t v1344 = (uint64_t) v264; + TASSIGN(v1343, v1344); + Tile v1345 = Tile(v263, v259); + uint64_t v1346 = (uint64_t) v270; + TASSIGN(v1345, v1346); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1345, v1329, v1343); + Tile v1347 = Tile(v263, v262); + uint64_t v1348 = (uint64_t) v272; + TASSIGN(v1347, v1348); + TROWSUM(v1347, v1331, v367); + Tile v1349 = Tile(v262, v263); + uint64_t v1350 = (uint64_t) v272; + TASSIGN(v1349, v1350); + Tile v1351 = Tile(v262, v263); + uint64_t v1352 = (uint64_t) v264; + TASSIGN(v1351, v1352); + pipe_barrier(PIPE_V); + TADDS(v1351, v1349, v257); + Tile v1353 = Tile(v263, v262); + uint64_t v1354 = (uint64_t) v264; + TASSIGN(v1353, v1354); + Tile v1355 = Tile(v263, v259); + uint64_t v1356 = (uint64_t) v269; + TASSIGN(v1355, v1356); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1355, v1331, v1353); + Tile v1357 = Tile(v263, v262); + uint64_t v1358 = 
(uint64_t) v272; + TASSIGN(v1357, v1358); + TROWSUM(v1357, v1333, v367); + Tile v1359 = Tile(v262, v263); + uint64_t v1360 = (uint64_t) v272; + TASSIGN(v1359, v1360); + Tile v1361 = Tile(v262, v263); + uint64_t v1362 = (uint64_t) v264; + TASSIGN(v1361, v1362); + pipe_barrier(PIPE_V); + TADDS(v1361, v1359, v257); + Tile v1363 = Tile(v263, v262); + uint64_t v1364 = (uint64_t) v264; + TASSIGN(v1363, v1364); + Tile v1365 = Tile(v263, v259); + uint64_t v1366 = (uint64_t) v268; + TASSIGN(v1365, v1366); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1365, v1333, v1363); + Tile v1367 = Tile(v263, v262); + uint64_t v1368 = (uint64_t) v272; + TASSIGN(v1367, v1368); + TROWSUM(v1367, v1335, v367); + Tile v1369 = Tile(v262, v263); + uint64_t v1370 = (uint64_t) v272; + TASSIGN(v1369, v1370); + Tile v1371 = Tile(v262, v263); + uint64_t v1372 = (uint64_t) v264; + TASSIGN(v1371, v1372); + pipe_barrier(PIPE_V); + TADDS(v1371, v1369, v257); + Tile v1373 = Tile(v263, v262); + uint64_t v1374 = (uint64_t) v264; + TASSIGN(v1373, v1374); + Tile v1375 = Tile(v263, v259); + uint64_t v1376 = (uint64_t) v267; + TASSIGN(v1375, v1376); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1375, v1335, v1373); + Tile v1377 = Tile(v263, v259); + uint64_t v1378 = (uint64_t) v266; + TASSIGN(v1377, v1378); + TADD(v1377, v1345, v1355); + Tile v1379 = Tile(v263, v259); + uint64_t v1380 = (uint64_t) v265; + TASSIGN(v1379, v1380); + pipe_barrier(PIPE_V); + TADD(v1379, v1365, v1375); + Tile v1381 = Tile(v263, v259); + uint64_t v1382 = (uint64_t) v266; + TASSIGN(v1381, v1382); + pipe_barrier(PIPE_V); + TADD(v1381, v1377, v1379); + Tile v1383 = Tile(v263, v259); + uint64_t v1384 = (uint64_t) v266; + TASSIGN(v1383, v1384); + pipe_barrier(PIPE_V); + TADDS(v1383, v1381, v257); + Tile v1385 = Tile(v263, v259); + uint64_t v1386 = (uint64_t) v270; + TASSIGN(v1385, v1386); + pipe_barrier(PIPE_V); + TDIV(v1385, v1345, v1383); + Tile v1387 = Tile(v263, v259); + uint64_t v1388 = (uint64_t) v269; + TASSIGN(v1387, v1388); + TDIV(v1387, v1355, v1383); + Tile v1389 = Tile(v263, v259); + uint64_t v1390 = (uint64_t) v268; + TASSIGN(v1389, v1390); + TDIV(v1389, v1365, v1383); + Tile v1391 = Tile(v263, v259); + uint64_t v1392 = (uint64_t) v267; + TASSIGN(v1391, v1392); + TDIV(v1391, v1375, v1383); + Tile v1393 = Tile(v263, v262); + uint64_t v1394 = (uint64_t) v272; + TASSIGN(v1393, v1394); + pipe_barrier(PIPE_V); + TROWSUM(v1393, v1385, v367); + Tile v1395 = Tile(v262, v263); + uint64_t v1396 = (uint64_t) v272; + TASSIGN(v1395, v1396); + Tile v1397 = Tile(v262, v263); + uint64_t v1398 = (uint64_t) v264; + TASSIGN(v1397, v1398); + pipe_barrier(PIPE_V); + TADDS(v1397, v1395, v257); + Tile v1399 = Tile(v263, v262); + uint64_t v1400 = (uint64_t) v264; + TASSIGN(v1399, v1400); + Tile v1401 = Tile(v263, v259); + uint64_t v1402 = (uint64_t) v270; + TASSIGN(v1401, v1402); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1401, v1385, v1399); + Tile v1403 = Tile(v263, v262); + uint64_t v1404 = (uint64_t) v272; + TASSIGN(v1403, v1404); + TROWSUM(v1403, v1387, v367); + Tile v1405 = Tile(v262, v263); + uint64_t v1406 = (uint64_t) v272; + TASSIGN(v1405, v1406); + Tile v1407 = Tile(v262, v263); + uint64_t v1408 = (uint64_t) v264; + TASSIGN(v1407, v1408); + pipe_barrier(PIPE_V); + TADDS(v1407, v1405, v257); + Tile v1409 = Tile(v263, v262); + uint64_t v1410 = (uint64_t) v264; + TASSIGN(v1409, v1410); + Tile v1411 = Tile(v263, v259); + uint64_t v1412 = (uint64_t) v269; + TASSIGN(v1411, v1412); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1411, v1387, v1409); + Tile v1413 = Tile(v263, v262); + 
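+  // NOTE (added commentary, hedged): the round starting here is the last
+  // one; its four TDIV results (v1441/v1443/v1445/v1447) hold the final
+  // normalized blocks. The set_flag/wait_flag(PIPE_V, PIPE_S, ...) pairs
+  // that follow fence the vector pipe against the scalar pipe so that the
+  // Tile::GetValue() reads observe completed results; each GetValue/store
+  // pair then scatters a single float to GM (v2[...]) in a permuted element
+  // order, presumably the layout the consuming combine stage expects.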
uint64_t v1414 = (uint64_t) v272; + TASSIGN(v1413, v1414); + TROWSUM(v1413, v1389, v367); + Tile v1415 = Tile(v262, v263); + uint64_t v1416 = (uint64_t) v272; + TASSIGN(v1415, v1416); + Tile v1417 = Tile(v262, v263); + uint64_t v1418 = (uint64_t) v264; + TASSIGN(v1417, v1418); + pipe_barrier(PIPE_V); + TADDS(v1417, v1415, v257); + Tile v1419 = Tile(v263, v262); + uint64_t v1420 = (uint64_t) v264; + TASSIGN(v1419, v1420); + Tile v1421 = Tile(v263, v259); + uint64_t v1422 = (uint64_t) v268; + TASSIGN(v1421, v1422); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1421, v1389, v1419); + Tile v1423 = Tile(v263, v262); + uint64_t v1424 = (uint64_t) v273; + TASSIGN(v1423, v1424); + TROWSUM(v1423, v1391, v367); + Tile v1425 = Tile(v262, v263); + uint64_t v1426 = (uint64_t) v273; + TASSIGN(v1425, v1426); + Tile v1427 = Tile(v262, v263); + uint64_t v1428 = (uint64_t) v264; + TASSIGN(v1427, v1428); + pipe_barrier(PIPE_V); + TADDS(v1427, v1425, v257); + Tile v1429 = Tile(v263, v262); + uint64_t v1430 = (uint64_t) v264; + TASSIGN(v1429, v1430); + Tile v1431 = Tile(v263, v259); + uint64_t v1432 = (uint64_t) v267; + TASSIGN(v1431, v1432); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v1431, v1391, v1429); + Tile v1433 = Tile(v263, v259); + uint64_t v1434 = (uint64_t) v266; + TASSIGN(v1433, v1434); + TADD(v1433, v1401, v1411); + Tile v1435 = Tile(v263, v259); + uint64_t v1436 = (uint64_t) v265; + TASSIGN(v1435, v1436); + pipe_barrier(PIPE_V); + TADD(v1435, v1421, v1431); + Tile v1437 = Tile(v263, v259); + uint64_t v1438 = (uint64_t) v266; + TASSIGN(v1437, v1438); + pipe_barrier(PIPE_V); + TADD(v1437, v1433, v1435); + Tile v1439 = Tile(v263, v259); + uint64_t v1440 = (uint64_t) v266; + TASSIGN(v1439, v1440); + pipe_barrier(PIPE_V); + TADDS(v1439, v1437, v257); + Tile v1441 = Tile(v263, v259); + uint64_t v1442 = (uint64_t) v270; + TASSIGN(v1441, v1442); + pipe_barrier(PIPE_V); + TDIV(v1441, v1401, v1439); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + Tile v1443 = Tile(v263, v259); + uint64_t v1444 = (uint64_t) v269; + TASSIGN(v1443, v1444); + TDIV(v1443, v1411, v1439); + set_flag(PIPE_V, PIPE_S, EVENT_ID1); + Tile v1445 = Tile(v263, v259); + uint64_t v1446 = (uint64_t) v268; + TASSIGN(v1445, v1446); + TDIV(v1445, v1421, v1439); + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + Tile v1447 = Tile(v263, v259); + uint64_t v1448 = (uint64_t) v267; + TASSIGN(v1447, v1448); + TDIV(v1447, v1431, v1439); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float v1449 = v1441.GetValue(v260); + v2[v260] = v1449; + wait_flag(PIPE_V, PIPE_S, EVENT_ID1); + float v1450 = v1443.GetValue(v260); + v2[v261] = v1450; + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); + float v1451 = v1445.GetValue(v260); + v2[v259] = v1451; + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + float v1452 = v1447.GetValue(v260); + v2[v258] = v1452; + float v1453 = v1441.GetValue(v262); + v2[v262] = v1453; + float v1454 = v1443.GetValue(v262); + v2[v256] = v1454; + float v1455 = v1445.GetValue(v262); + v2[v255] = v1455; + float v1456 = v1447.GetValue(v262); + v2[v254] = v1456; + float v1457 = v1441.GetValue(v253); + v2[v253] = v1457; + float v1458 = v1443.GetValue(v253); + v2[v252] = v1458; + float v1459 = v1445.GetValue(v253); + v2[v251] = v1459; + float v1460 = v1447.GetValue(v253); + v2[v250] = v1460; + float v1461 = v1441.GetValue(v249); + v2[v249] = v1461; + float v1462 = v1443.GetValue(v249); + v2[v248] = v1462; + float v1463 = v1445.GetValue(v249); + v2[v247] = v1463; + float v1464 = v1447.GetValue(v249); + v2[v246] = v1464; + float v1465 = 
v1441.GetValue(v259); + v2[v263] = v1465; + float v1466 = v1443.GetValue(v259); + v2[v245] = v1466; + float v1467 = v1445.GetValue(v259); + v2[v244] = v1467; + float v1468 = v1447.GetValue(v259); + v2[v243] = v1468; + float v1469 = v1441.GetValue(v255); + v2[v242] = v1469; + float v1470 = v1443.GetValue(v255); + v2[v241] = v1470; + float v1471 = v1445.GetValue(v255); + v2[v240] = v1471; + float v1472 = v1447.GetValue(v255); + v2[v239] = v1472; + float v1473 = v1441.GetValue(v251); + v2[v238] = v1473; + float v1474 = v1443.GetValue(v251); + v2[v237] = v1474; + float v1475 = v1445.GetValue(v251); + v2[v236] = v1475; + float v1476 = v1447.GetValue(v251); + v2[v235] = v1476; + float v1477 = v1441.GetValue(v247); + v2[v234] = v1477; + float v1478 = v1443.GetValue(v247); + v2[v233] = v1478; + float v1479 = v1445.GetValue(v247); + v2[v232] = v1479; + float v1480 = v1447.GetValue(v247); + v2[v231] = v1480; + float v1481 = v1441.GetValue(v263); + v2[v230] = v1481; + float v1482 = v1443.GetValue(v263); + v2[v229] = v1482; + float v1483 = v1445.GetValue(v263); + v2[v228] = v1483; + float v1484 = v1447.GetValue(v263); + v2[v227] = v1484; + float v1485 = v1441.GetValue(v242); + v2[v226] = v1485; + float v1486 = v1443.GetValue(v242); + v2[v225] = v1486; + float v1487 = v1445.GetValue(v242); + v2[v224] = v1487; + float v1488 = v1447.GetValue(v242); + v2[v223] = v1488; + float v1489 = v1441.GetValue(v238); + v2[v222] = v1489; + float v1490 = v1443.GetValue(v238); + v2[v221] = v1490; + float v1491 = v1445.GetValue(v238); + v2[v220] = v1491; + float v1492 = v1447.GetValue(v238); + v2[v219] = v1492; + float v1493 = v1441.GetValue(v234); + v2[v218] = v1493; + float v1494 = v1443.GetValue(v234); + v2[v217] = v1494; + float v1495 = v1445.GetValue(v234); + v2[v216] = v1495; + float v1496 = v1447.GetValue(v234); + v2[v215] = v1496; + float v1497 = v1441.GetValue(v244); + v2[v214] = v1497; + float v1498 = v1443.GetValue(v244); + v2[v213] = v1498; + float v1499 = v1445.GetValue(v244); + v2[v212] = v1499; + float v1500 = v1447.GetValue(v244); + v2[v211] = v1500; + float v1501 = v1441.GetValue(v240); + v2[v210] = v1501; + float v1502 = v1443.GetValue(v240); + v2[v209] = v1502; + float v1503 = v1445.GetValue(v240); + v2[v208] = v1503; + float v1504 = v1447.GetValue(v240); + v2[v207] = v1504; + float v1505 = v1441.GetValue(v236); + v2[v206] = v1505; + float v1506 = v1443.GetValue(v236); + v2[v205] = v1506; + float v1507 = v1445.GetValue(v236); + v2[v204] = v1507; + float v1508 = v1447.GetValue(v236); + v2[v203] = v1508; + float v1509 = v1441.GetValue(v232); + v2[v202] = v1509; + float v1510 = v1443.GetValue(v232); + v2[v201] = v1510; + float v1511 = v1445.GetValue(v232); + v2[v200] = v1511; + float v1512 = v1447.GetValue(v232); + v2[v199] = v1512; + float v1513 = v1441.GetValue(v230); + v2[v198] = v1513; + float v1514 = v1443.GetValue(v230); + v2[v197] = v1514; + float v1515 = v1445.GetValue(v230); + v2[v196] = v1515; + float v1516 = v1447.GetValue(v230); + v2[v195] = v1516; + float v1517 = v1441.GetValue(v226); + v2[v194] = v1517; + float v1518 = v1443.GetValue(v226); + v2[v193] = v1518; + float v1519 = v1445.GetValue(v226); + v2[v192] = v1519; + float v1520 = v1447.GetValue(v226); + v2[v191] = v1520; + float v1521 = v1441.GetValue(v222); + v2[v190] = v1521; + float v1522 = v1443.GetValue(v222); + v2[v189] = v1522; + float v1523 = v1445.GetValue(v222); + v2[v188] = v1523; + float v1524 = v1447.GetValue(v222); + v2[v187] = v1524; + float v1525 = v1441.GetValue(v218); + v2[v186] = v1525; + float v1526 = 
v1443.GetValue(v218); + v2[v185] = v1526; + float v1527 = v1445.GetValue(v218); + v2[v184] = v1527; + float v1528 = v1447.GetValue(v218); + v2[v183] = v1528; + float v1529 = v1441.GetValue(v228); + v2[v182] = v1529; + float v1530 = v1443.GetValue(v228); + v2[v181] = v1530; + float v1531 = v1445.GetValue(v228); + v2[v180] = v1531; + float v1532 = v1447.GetValue(v228); + v2[v179] = v1532; + float v1533 = v1441.GetValue(v224); + v2[v178] = v1533; + float v1534 = v1443.GetValue(v224); + v2[v177] = v1534; + float v1535 = v1445.GetValue(v224); + v2[v176] = v1535; + float v1536 = v1447.GetValue(v224); + v2[v175] = v1536; + float v1537 = v1441.GetValue(v220); + v2[v174] = v1537; + float v1538 = v1443.GetValue(v220); + v2[v173] = v1538; + float v1539 = v1445.GetValue(v220); + v2[v172] = v1539; + float v1540 = v1447.GetValue(v220); + v2[v171] = v1540; + float v1541 = v1441.GetValue(v216); + v2[v170] = v1541; + float v1542 = v1443.GetValue(v216); + v2[v169] = v1542; + float v1543 = v1445.GetValue(v216); + v2[v168] = v1543; + float v1544 = v1447.GetValue(v216); + v2[v167] = v1544; + float v1545 = v1441.GetValue(v214); + v2[v166] = v1545; + float v1546 = v1443.GetValue(v214); + v2[v165] = v1546; + float v1547 = v1445.GetValue(v214); + v2[v164] = v1547; + float v1548 = v1447.GetValue(v214); + v2[v163] = v1548; + float v1549 = v1441.GetValue(v210); + v2[v162] = v1549; + float v1550 = v1443.GetValue(v210); + v2[v161] = v1550; + float v1551 = v1445.GetValue(v210); + v2[v160] = v1551; + float v1552 = v1447.GetValue(v210); + v2[v159] = v1552; + float v1553 = v1441.GetValue(v206); + v2[v158] = v1553; + float v1554 = v1443.GetValue(v206); + v2[v157] = v1554; + float v1555 = v1445.GetValue(v206); + v2[v156] = v1555; + float v1556 = v1447.GetValue(v206); + v2[v155] = v1556; + float v1557 = v1441.GetValue(v202); + v2[v154] = v1557; + float v1558 = v1443.GetValue(v202); + v2[v153] = v1558; + float v1559 = v1445.GetValue(v202); + v2[v152] = v1559; + float v1560 = v1447.GetValue(v202); + v2[v151] = v1560; + float v1561 = v1441.GetValue(v212); + v2[v150] = v1561; + float v1562 = v1443.GetValue(v212); + v2[v149] = v1562; + float v1563 = v1445.GetValue(v212); + v2[v148] = v1563; + float v1564 = v1447.GetValue(v212); + v2[v147] = v1564; + float v1565 = v1441.GetValue(v208); + v2[v146] = v1565; + float v1566 = v1443.GetValue(v208); + v2[v145] = v1566; + float v1567 = v1445.GetValue(v208); + v2[v144] = v1567; + float v1568 = v1447.GetValue(v208); + v2[v143] = v1568; + float v1569 = v1441.GetValue(v204); + v2[v142] = v1569; + float v1570 = v1443.GetValue(v204); + v2[v141] = v1570; + float v1571 = v1445.GetValue(v204); + v2[v140] = v1571; + float v1572 = v1447.GetValue(v204); + v2[v139] = v1572; + float v1573 = v1441.GetValue(v200); + v2[v138] = v1573; + float v1574 = v1443.GetValue(v200); + v2[v137] = v1574; + float v1575 = v1445.GetValue(v200); + v2[v136] = v1575; + float v1576 = v1447.GetValue(v200); + v2[v135] = v1576; + float v1577 = v1441.GetValue(v198); + v2[v134] = v1577; + float v1578 = v1443.GetValue(v198); + v2[v133] = v1578; + float v1579 = v1445.GetValue(v198); + v2[v132] = v1579; + float v1580 = v1447.GetValue(v198); + v2[v131] = v1580; + float v1581 = v1441.GetValue(v194); + v2[v130] = v1581; + float v1582 = v1443.GetValue(v194); + v2[v129] = v1582; + float v1583 = v1445.GetValue(v194); + v2[v128] = v1583; + float v1584 = v1447.GetValue(v194); + v2[v127] = v1584; + float v1585 = v1441.GetValue(v190); + v2[v126] = v1585; + float v1586 = v1443.GetValue(v190); + v2[v125] = v1586; + float v1587 = 
v1445.GetValue(v190); + v2[v124] = v1587; + float v1588 = v1447.GetValue(v190); + v2[v123] = v1588; + float v1589 = v1441.GetValue(v186); + v2[v122] = v1589; + float v1590 = v1443.GetValue(v186); + v2[v121] = v1590; + float v1591 = v1445.GetValue(v186); + v2[v120] = v1591; + float v1592 = v1447.GetValue(v186); + v2[v119] = v1592; + float v1593 = v1441.GetValue(v196); + v2[v118] = v1593; + float v1594 = v1443.GetValue(v196); + v2[v117] = v1594; + float v1595 = v1445.GetValue(v196); + v2[v116] = v1595; + float v1596 = v1447.GetValue(v196); + v2[v115] = v1596; + float v1597 = v1441.GetValue(v192); + v2[v114] = v1597; + float v1598 = v1443.GetValue(v192); + v2[v113] = v1598; + float v1599 = v1445.GetValue(v192); + v2[v112] = v1599; + float v1600 = v1447.GetValue(v192); + v2[v111] = v1600; + float v1601 = v1441.GetValue(v188); + v2[v110] = v1601; + float v1602 = v1443.GetValue(v188); + v2[v109] = v1602; + float v1603 = v1445.GetValue(v188); + v2[v108] = v1603; + float v1604 = v1447.GetValue(v188); + v2[v107] = v1604; + float v1605 = v1441.GetValue(v184); + v2[v106] = v1605; + float v1606 = v1443.GetValue(v184); + v2[v105] = v1606; + float v1607 = v1445.GetValue(v184); + v2[v104] = v1607; + float v1608 = v1447.GetValue(v184); + v2[v103] = v1608; + float v1609 = v1441.GetValue(v182); + v2[v102] = v1609; + float v1610 = v1443.GetValue(v182); + v2[v101] = v1610; + float v1611 = v1445.GetValue(v182); + v2[v100] = v1611; + float v1612 = v1447.GetValue(v182); + v2[v99] = v1612; + float v1613 = v1441.GetValue(v178); + v2[v98] = v1613; + float v1614 = v1443.GetValue(v178); + v2[v97] = v1614; + float v1615 = v1445.GetValue(v178); + v2[v96] = v1615; + float v1616 = v1447.GetValue(v178); + v2[v95] = v1616; + float v1617 = v1441.GetValue(v174); + v2[v94] = v1617; + float v1618 = v1443.GetValue(v174); + v2[v93] = v1618; + float v1619 = v1445.GetValue(v174); + v2[v92] = v1619; + float v1620 = v1447.GetValue(v174); + v2[v91] = v1620; + float v1621 = v1441.GetValue(v170); + v2[v90] = v1621; + float v1622 = v1443.GetValue(v170); + v2[v89] = v1622; + float v1623 = v1445.GetValue(v170); + v2[v88] = v1623; + float v1624 = v1447.GetValue(v170); + v2[v87] = v1624; + float v1625 = v1441.GetValue(v180); + v2[v86] = v1625; + float v1626 = v1443.GetValue(v180); + v2[v85] = v1626; + float v1627 = v1445.GetValue(v180); + v2[v84] = v1627; + float v1628 = v1447.GetValue(v180); + v2[v83] = v1628; + float v1629 = v1441.GetValue(v176); + v2[v82] = v1629; + float v1630 = v1443.GetValue(v176); + v2[v81] = v1630; + float v1631 = v1445.GetValue(v176); + v2[v80] = v1631; + float v1632 = v1447.GetValue(v176); + v2[v79] = v1632; + float v1633 = v1441.GetValue(v172); + v2[v78] = v1633; + float v1634 = v1443.GetValue(v172); + v2[v77] = v1634; + float v1635 = v1445.GetValue(v172); + v2[v76] = v1635; + float v1636 = v1447.GetValue(v172); + v2[v75] = v1636; + float v1637 = v1441.GetValue(v168); + v2[v74] = v1637; + float v1638 = v1443.GetValue(v168); + v2[v73] = v1638; + float v1639 = v1445.GetValue(v168); + v2[v72] = v1639; + float v1640 = v1447.GetValue(v168); + v2[v71] = v1640; + float v1641 = v1441.GetValue(v166); + v2[v70] = v1641; + float v1642 = v1443.GetValue(v166); + v2[v69] = v1642; + float v1643 = v1445.GetValue(v166); + v2[v68] = v1643; + float v1644 = v1447.GetValue(v166); + v2[v67] = v1644; + float v1645 = v1441.GetValue(v162); + v2[v66] = v1645; + float v1646 = v1443.GetValue(v162); + v2[v65] = v1646; + float v1647 = v1445.GetValue(v162); + v2[v64] = v1647; + float v1648 = v1447.GetValue(v162); + v2[v63] = v1648; + float 
v1649 = v1441.GetValue(v158); + v2[v62] = v1649; + float v1650 = v1443.GetValue(v158); + v2[v61] = v1650; + float v1651 = v1445.GetValue(v158); + v2[v60] = v1651; + float v1652 = v1447.GetValue(v158); + v2[v59] = v1652; + float v1653 = v1441.GetValue(v154); + v2[v58] = v1653; + float v1654 = v1443.GetValue(v154); + v2[v57] = v1654; + float v1655 = v1445.GetValue(v154); + v2[v56] = v1655; + float v1656 = v1447.GetValue(v154); + v2[v55] = v1656; + float v1657 = v1441.GetValue(v164); + v2[v54] = v1657; + float v1658 = v1443.GetValue(v164); + v2[v53] = v1658; + float v1659 = v1445.GetValue(v164); + v2[v52] = v1659; + float v1660 = v1447.GetValue(v164); + v2[v51] = v1660; + float v1661 = v1441.GetValue(v160); + v2[v50] = v1661; + float v1662 = v1443.GetValue(v160); + v2[v49] = v1662; + float v1663 = v1445.GetValue(v160); + v2[v48] = v1663; + float v1664 = v1447.GetValue(v160); + v2[v47] = v1664; + float v1665 = v1441.GetValue(v156); + v2[v46] = v1665; + float v1666 = v1443.GetValue(v156); + v2[v45] = v1666; + float v1667 = v1445.GetValue(v156); + v2[v44] = v1667; + float v1668 = v1447.GetValue(v156); + v2[v43] = v1668; + float v1669 = v1441.GetValue(v152); + v2[v42] = v1669; + float v1670 = v1443.GetValue(v152); + v2[v41] = v1670; + float v1671 = v1445.GetValue(v152); + v2[v40] = v1671; + float v1672 = v1447.GetValue(v152); + v2[v39] = v1672; + float v1673 = v1441.GetValue(v150); + v2[v38] = v1673; + float v1674 = v1443.GetValue(v150); + v2[v37] = v1674; + float v1675 = v1445.GetValue(v150); + v2[v36] = v1675; + float v1676 = v1447.GetValue(v150); + v2[v35] = v1676; + float v1677 = v1441.GetValue(v146); + v2[v34] = v1677; + float v1678 = v1443.GetValue(v146); + v2[v33] = v1678; + float v1679 = v1445.GetValue(v146); + v2[v32] = v1679; + float v1680 = v1447.GetValue(v146); + v2[v31] = v1680; + float v1681 = v1441.GetValue(v142); + v2[v30] = v1681; + float v1682 = v1443.GetValue(v142); + v2[v29] = v1682; + float v1683 = v1445.GetValue(v142); + v2[v28] = v1683; + float v1684 = v1447.GetValue(v142); + v2[v27] = v1684; + float v1685 = v1441.GetValue(v138); + v2[v26] = v1685; + float v1686 = v1443.GetValue(v138); + v2[v25] = v1686; + float v1687 = v1445.GetValue(v138); + v2[v24] = v1687; + float v1688 = v1447.GetValue(v138); + v2[v23] = v1688; + float v1689 = v1441.GetValue(v148); + v2[v22] = v1689; + float v1690 = v1443.GetValue(v148); + v2[v21] = v1690; + float v1691 = v1445.GetValue(v148); + v2[v20] = v1691; + float v1692 = v1447.GetValue(v148); + v2[v19] = v1692; + float v1693 = v1441.GetValue(v144); + v2[v18] = v1693; + float v1694 = v1443.GetValue(v144); + v2[v17] = v1694; + float v1695 = v1445.GetValue(v144); + v2[v16] = v1695; + float v1696 = v1447.GetValue(v144); + v2[v15] = v1696; + float v1697 = v1441.GetValue(v140); + v2[v14] = v1697; + float v1698 = v1443.GetValue(v140); + v2[v13] = v1698; + float v1699 = v1445.GetValue(v140); + v2[v12] = v1699; + float v1700 = v1447.GetValue(v140); + v2[v11] = v1700; + float v1701 = v1441.GetValue(v136); + v2[v10] = v1701; + float v1702 = v1443.GetValue(v136); + v2[v9] = v1702; + float v1703 = v1445.GetValue(v136); + v2[v8] = v1703; + float v1704 = v1447.GetValue(v136); + v2[v7] = v1704; + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: comb_logits_inline19__ssa_v1 + __gm__ Tensor* comb_logits_inline19__ssa_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); 
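+  // NOTE (added commentary, hedged): every kernel_entry in this diff shares
+  // this ABI -- args[] is a flat array of 64-bit slots, where a tensor slot
+  // carries a GM pointer to a Tensor descriptor and a scalar slot carries
+  // the value itself. Judging only from the fields used here (tensor.h is
+  // not part of this diff), the descriptor looks roughly like:
+  //
+  //   struct Tensor {                      // sketch; assumed layout
+  //     struct { uint64_t addr; } buffer;  // GM base address
+  //     int64_t start_offset;              // element offset into buffer
+  //   };
+  //
+  // and the payload pointer is (elem_t*)buffer.addr + start_offset.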
+ __gm__ float* comb_logits_inline19__ssa_v1 = reinterpret_cast<__gm__ float*>(comb_logits_inline19__ssa_v1_tensor->buffer.addr) + comb_logits_inline19__ssa_v1_tensor->start_offset; + + // Unpack tensor: comb_flat_inline37__ssa_v0 + __gm__ Tensor* comb_flat_inline37__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* comb_flat_inline37__ssa_v0 = reinterpret_cast<__gm__ float*>(comb_flat_inline37__ssa_v0_tensor->buffer.addr) + comb_flat_inline37__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + comb_sinkhorn(comb_logits_inline19__ssa_v1, comb_flat_inline37__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/combine.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/combine.cpp index 7333af6e7..543bb4770 100644 --- a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/combine.cpp +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/combine.cpp @@ -60,10 +60,10 @@ using namespace pto; // Demo dimensions — must match dispatch.cpp / main.py. static constexpr int N = 2; -static constexpr int T = 8; +static constexpr int T = 16; static constexpr int TOPK = 2; -static constexpr int D = 64; -static constexpr int L = 4; +static constexpr int D = 4096; +static constexpr int L = 8; static constexpr int R = 32; static constexpr int W_PAD = 8; static constexpr int IDX_PAD = 8; @@ -218,7 +218,7 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in TLOAD(add_bf_tile, src_g); pipe_barrier(PIPE_ALL); - TCVT(add_fp_tile, add_bf_tile, RoundMode::CAST_ROUND); + TCVT(add_fp_tile, add_bf_tile, RoundMode::CAST_RINT); pipe_barrier(PIPE_V); TADD(acc_tile, acc_tile, add_fp_tile); pipe_barrier(PIPE_V); diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp index 547a759b6..c4d63eb6d 100644 --- a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp @@ -67,14 +67,16 @@ using namespace pto; -// Demo dimensions — must match main.py. +// Demo dimensions — must match main.py. Sized to mirror the production +// moe_expert decode config (D = hidden_size = 4096, L = N_LOCAL_EXPERTS = 8, +// T = decode tokens per rank = 16, R = RECV_MAX = 32). static constexpr int N = 2; -static constexpr int T = 8; +static constexpr int T = 16; static constexpr int TOPK = 2; -static constexpr int D = 64; -static constexpr int L = 4; +static constexpr int D = 4096; +static constexpr int L = 8; static constexpr int R = 32; -static constexpr int N_ROUTES = T * TOPK; // 16 +static constexpr int N_ROUTES = T * TOPK; // 32 // Weight payload tile width. The protocol contract is one FP32 weight per // (e, slot) — recv_w[L, R] FP32. AIV vector tiles have a hardware minimum @@ -90,15 +92,15 @@ static constexpr int IDX_PAD = 8; // Window region byte sizes — mirror *_BYTES in main.py. 
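+// (Worked sizes for the demo config, added for cross-checking against
+// main.py:
+//   kPubCountsBytes  = N*N*L*4        = 2*2*8*4      = 128 B
+//   kRecvXBytes      = L*R*D*2        = 8*32*4096*2  = 2 MiB
+//   kRecvWBytes      = L*R*W_PAD*4    = 8*32*8*4     = 8 KiB
+//   kRecvIdxBytes    = L*R*IDX_PAD*4  = 8*32*8*4     = 8 KiB
+//   kRoutedYBufBytes = T*TOPK*D*2     = 16*2*4096*2  = 256 KiB)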
 //
-// Layout:
-// pub_counts[N][N][L] INT32 (64 B)
+// Layout (sizes for the D=4096 / L=8 / T=16 / R=32 demo config):
+// pub_counts[N][N][L] INT32 (128 B)
 // count_done_sig[N] INT32 (padded slot, 64 B)
-// recv_x[L][R][D] BF16 (16 KB)
-// recv_w[L][R][W_PAD] FP32 (4 KB; weight at slot [0], rest = 0)
-// recv_idx[L][R][IDX_PAD] INT32 (4 KB; r=t*TOPK+k at slot [0], rest = 0)
+// recv_x[L][R][D] BF16 (2 MiB)
+// recv_w[L][R][W_PAD] FP32 (8 KB; weight at slot [0], rest = 0)
+// recv_idx[L][R][IDX_PAD] INT32 (8 KB; r=t*TOPK+k at slot [0], rest = 0)
 // data_done_sig[N] INT32 (padded slot, 64 B)
 // ---- Cross-rank visible regions consumed by combine.cpp ----
-// routed_y_buf[T][TOPK][D] BF16 (2 KB demo; combine push destination,
+// routed_y_buf[T][TOPK][D] BF16 (256 KB; combine push destination,
 // addressed directly by (t, k))
 // combine_done_sig[N] INT32 (padded slot, 64 B)
 //
@@ -106,10 +108,10 @@ static constexpr int IDX_PAD = 8;
 // as a host-backed device tensor via the orch.
 static constexpr int kPubCountsBytes = N * N * L * 4; // N*N*L INT32
 static constexpr int kSignalBytes = 64;
-static constexpr int kRecvXBytes = L * R * D * 2; // BF16
+static constexpr int kRecvXBytes = L * R * D * 2; // BF16, 2 MiB
 static constexpr int kRecvWBytes = L * R * W_PAD * 4;
 static constexpr int kRecvIdxBytes = L * R * IDX_PAD * 4;
-static constexpr int kRoutedYBufBytes = T * TOPK * D * 2; // BF16
+static constexpr int kRoutedYBufBytes = T * TOPK * D * 2; // BF16, 256 KiB
 static constexpr int kOffPubCounts = 0;
 static constexpr int kOffCountDone = kOffPubCounts + kPubCountsBytes;
@@ -179,7 +181,10 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in
   // ------------------------------------------------------------------
   // histogram: scalar histogram + (dst, loc_e)-sorted route table.
   //
-  // Scalar GM reads of indices[t * TOPK + k] are fine on AIV.
+  // Scalar GM reads of indices[t * TOPK + k] are fine on AIV. The upstream
+  // router task (write_route_outputs) flushes its scalar stores to L2 via
+  // dcci(..., CACHELINE_OUT) + dsb at its own tail, so dispatch picks up
+  // fresh data here without needing a reader-side cache invalidate.
   // Bucket each route by (dst, loc_e) and stable-sort so the payload_push
   // cursor matches each peer's src-major slot_offset rule.
   // ------------------------------------------------------------------
@@ -194,10 +199,17 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in
   int route_loc_e[N_ROUTES];
   int route_r[N_ROUTES];
+
+  // EP routing policy: route slot k of each token goes to peer
+  // (my_rank + k) % N. The router was JIT'd with EP_WORLD_SIZE=1 and writes
+  // local expert IDs in [0, L); we layer the rank component on top here so
+  // the demo spreads tokens across both ranks instead of pinning everything
+  // to rank 0. A production EP system would read a true global ID from the
+  // router output instead.
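+  // Worked example of the policy above (illustrative, added for clarity;
+  // values follow from the loop below): with N=2 and TOPK=2, r = t*TOPK + k,
+  // so on rank 0:
+  //   r=4 -> t=2, k=0 -> dst = (0+0)%2 = 0 (slot 0 stays local)
+  //   r=5 -> t=2, k=1 -> dst = (0+1)%2 = 1 (slot 1 goes to the peer)
+  // The destination rank depends only on the slot k; the router's expert id
+  // only selects loc_e within that rank.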
 for (int r = 0; r < N_ROUTES; ++r) {
-    int eid = indices[r];
-    int dst = eid / L;
-    int loc_e = eid - dst * L;
+    int t = r / TOPK;
+    int k = r - t * TOPK;
+    int dst = (my_rank + k) % N;
+    int loc_e = indices[r]; // local expert id in [0, L)
     send_counts[dst][loc_e] += 1;
     route_dst[r] = dst;
     route_loc_e[r] = loc_e;
diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_gate_up_dequant.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_gate_up_dequant.cpp
new file mode 100644
index 000000000..ce3c89b7b
--- /dev/null
+++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_gate_up_dequant.cpp
@@ -0,0 +1,204 @@
+// Kernel Function: exp_gate_up_dequant
+// Generated by PyPTO IR Compiler (PTO backend)
+
+#include
+
+#ifndef __gm__
+#define __gm__
+#endif
+
+#ifndef __aicore__
+#if defined(__CPU_SIM)
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+#endif
+
+#include
+#include "tensor.h"
+
+
+using namespace pto;
+
+
+// --- ptoas-generated code ---
+
+enum class PTOAutoSyncTailMode : int {
+  kBarrierAll = 0,
+  kSetWaitMte3ToSEvent0 = 1,
+};
+
+static __aicore__ inline void ptoas_auto_sync_tail(
+    PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) {
+  switch (mode) {
+    case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0:
+      set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      break;
+    case PTOAutoSyncTailMode::kBarrierAll:
+    default:
+      pipe_barrier(PIPE_ALL);
+      break;
+  }
+}
+
+static __aicore__ void exp_gate_up_dequant(__gm__ int32_t* v1, __gm__ int32_t* v2, __gm__ float* v3, __gm__ float* v4, __gm__ float* v5, __gm__ float* v6, __gm__ float* v7, int32_t v8, int32_t v9) {
+  RoundMode v10 = RoundMode::CAST_NONE;
+  unsigned v11 = 0;
+  const int32_t v12 = 4096;
+  const int32_t v13 = 256;
+  const int32_t v14 = 16;
+  const int32_t v15 = 1;
+  const int64_t v16 = 18432;
+  const int64_t v17 = 2048;
+  const int64_t v18 = 1024;
+  const int64_t v19 = 0;
+  const int64_t v20 = 67584;
+  const int64_t v21 = 51200;
+  const int64_t v22 = 34816;
+  using T = float;
+
+  #if defined(__DAV_VEC__)
+  set_mask_norm();
+  set_vector_mask(-1, -1);
+  Tile v23 = Tile(v14, v13);
+  uint64_t v24 = (uint64_t) v22;
+  TASSIGN(v23, v24);
+  pto::Shape<1, 1, 1, 16, 256> v25 = pto::Shape<1, 1, 1, 16, 256>();
+  pto::Stride<4096, 4096, 4096, 256, 1> v26 = pto::Stride<4096, 4096, 4096, 256, 1>();
+  GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v27 = GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v1 + ((v11 + v11 * (unsigned) v12) + v11 * (unsigned) v13 + v11 * (unsigned) v15), v25, v26);
+  TLOAD(v23, v27);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+  Tile v28 = Tile(v14, v13);
+  uint64_t v29 = (uint64_t) v21;
+  TASSIGN(v28, v29);
+  pto::Shape<1, 1, 1, 16, 256> v30 = pto::Shape<1, 1, 1, 16, 256>();
+  pto::Stride<4096, 4096, 4096, 256, 1> v31 = pto::Stride<4096, 4096, 4096, 256, 1>();
+  GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v32 = GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v2 + ((v11 + v11 * (unsigned) v12) + v11 * (unsigned) v13 + v11 * (unsigned) v15), v30, v31);
+  TLOAD(v28, v32);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+  Tile v33 = Tile(v14, v15);
+  uint64_t v34 = (uint64_t) v20;
+  TASSIGN(v33, v34);
+  pto::Shape<1, 1, 1, 16, 1> v35 = pto::Shape<1, 1, 1, 16, 1>();
+  pto::Stride<16, 16, 16, 1, 16> v36 = pto::Stride<16, 16, 16, 1, 16>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 16, 1>, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v37 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 1>, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v5 + (v11 + v11 * (unsigned) v15 + v11 * (unsigned) v14), v35, v36);
+  TLOAD(v33, v37);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
+  Tile v38 = Tile(v14, v13);
+  uint64_t v39 = (uint64_t) v22;
+  TASSIGN(v38, v39);
+  Tile v40 = Tile(v14, v13);
+  uint64_t v41 = (uint64_t) v21;
+  TASSIGN(v40, v41);
+  Tile v42 = Tile(v15, v13);
+  uint64_t v43 = (uint64_t) v19;
+  TASSIGN(v42, v43);
+  pto::Shape<1, 1, 1, 1, 256> v44 = pto::Shape<1, 1, 1, 1, 256>();
+  pto::Stride<4096, 4096, 4096, 4096, 1> v45 = pto::Stride<4096, 4096, 4096, 4096, 1>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 1, 256>, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v46 = GlobalTensor<float, pto::Shape<1, 1, 1, 1, 256>, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v3 + (v11 + (unsigned) v8 * (unsigned) v12 + (unsigned) v9 * (unsigned) v15), v44, v45);
+  TLOAD(v42, v46);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
+  Tile v47 = Tile(v15, v13);
+  uint64_t v48 = (uint64_t) v18;
+  TASSIGN(v47, v48);
+  pto::Shape<1, 1, 1, 1, 256> v49 = pto::Shape<1, 1, 1, 1, 256>();
+  pto::Stride<4096, 4096, 4096, 4096, 1> v50 = pto::Stride<4096, 4096, 4096, 4096, 1>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 1, 256>, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v51 = GlobalTensor<float, pto::Shape<1, 1, 1, 1, 256>, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v4 + (v11 + (unsigned) v8 * (unsigned) v12 + (unsigned) v9 * (unsigned) v15), v49, v50);
+  TLOAD(v47, v51);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4);
+  Tile v52 = Tile(v14, v13);
+  uint64_t v53 = (uint64_t) v17;
+  TASSIGN(v52, v53);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+  TCVT(v52, v38, v10);
+  Tile v54 = Tile(v14, v13);
+  uint64_t v55 = (uint64_t) v16;
+  TASSIGN(v54, v55);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+  TCVT(v54, v40, v10);
+  Tile v56 = Tile(v14, v13);
+  uint64_t v57 = (uint64_t) v17;
+  TASSIGN(v56, v57);
+  pipe_barrier(PIPE_V);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
+  TROWEXPANDMUL(v56, v52, v33);
+  Tile v58 = Tile(v14, v13);
+  uint64_t v59 = (uint64_t) v17;
+  TASSIGN(v58, v59);
+  pipe_barrier(PIPE_V);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
+  TCOLEXPANDMUL(v58, v56, v42);
+  set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+  Tile v60 = Tile(v14, v13);
+  uint64_t v61 = (uint64_t) v16;
+  TASSIGN(v60, v61);
+  TROWEXPANDMUL(v60, v54, v33);
+  Tile v62 = Tile(v14, v13);
+  uint64_t v63 = (uint64_t) v16;
+  TASSIGN(v62, v63);
+  pipe_barrier(PIPE_V);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4);
+  TCOLEXPANDMUL(v62, v60, v47);
+  set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
+  pto::Shape<1, 1, 1, 16, 256> v64 = pto::Shape<1, 1, 1, 16, 256>();
+  pto::Stride<4096, 4096, 4096, 256, 1> v65 = pto::Stride<4096, 4096, 4096, 256, 1>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v66 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v6 + (v11 + v11 * (unsigned) v13 + v11 * (unsigned) v15), v64, v65);
+  wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+  TSTORE(v66, v58);
+  pto::Shape<1, 1, 1, 16, 256> v67 = pto::Shape<1, 1, 1, 16, 256>();
+  pto::Stride<4096, 4096, 4096, 256, 1> v68 = pto::Stride<4096, 4096, 4096, 256, 1>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v69 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v7 + (v11 + v11 * (unsigned) v13 + v11 * (unsigned) v15), v67, v68);
+  wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
+  TSTORE(v69, v62);
+  #endif // __DAV_VEC__
+
+  ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);
+  return;
+}
+
+// --- Kernel entry point ---
+extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args)
+{
+  // Unpack tensor:
gate_acc_inline69__rv_v2 + __gm__ Tensor* gate_acc_inline69__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int32_t* gate_acc_inline69__rv_v2 = reinterpret_cast<__gm__ int32_t*>(gate_acc_inline69__rv_v2_tensor->buffer.addr) + gate_acc_inline69__rv_v2_tensor->start_offset; + + // Unpack tensor: up_acc_inline46__rv_v2 + __gm__ Tensor* up_acc_inline46__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int32_t* up_acc_inline46__rv_v2 = reinterpret_cast<__gm__ int32_t*>(up_acc_inline46__rv_v2_tensor->buffer.addr) + up_acc_inline46__rv_v2_tensor->start_offset; + + // Unpack tensor: expert_w1_scale__ssa_v0 + __gm__ Tensor* expert_w1_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* expert_w1_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(expert_w1_scale__ssa_v0_tensor->buffer.addr) + expert_w1_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: expert_w3_scale__ssa_v0 + __gm__ Tensor* expert_w3_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* expert_w3_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(expert_w3_scale__ssa_v0_tensor->buffer.addr) + expert_w3_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: recv_x_scale_dq_inline29__ssa_v0 + __gm__ Tensor* recv_x_scale_dq_inline29__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ float* recv_x_scale_dq_inline29__ssa_v0 = reinterpret_cast<__gm__ float*>(recv_x_scale_dq_inline29__ssa_v0_tensor->buffer.addr) + recv_x_scale_dq_inline29__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[5]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack tensor: ret1__out + __gm__ Tensor* ret1__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[6]); + __gm__ float* ret1__out = reinterpret_cast<__gm__ float*>(ret1__out_tensor->buffer.addr) + ret1__out_tensor->start_offset; + + // Unpack scalar: local_i_inline67__idx_v0 + union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; + local_i_inline67__idx_v0_conv.u64 = args[7]; + int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; + + // Unpack scalar: n0_inline72__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline72__idx_v0_conv; + n0_inline72__idx_v0_conv.u64 = args[8]; + int64_t n0_inline72__idx_v0 = n0_inline72__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_gate_up_dequant(gate_acc_inline69__rv_v2, up_acc_inline46__rv_v2, expert_w1_scale__ssa_v0, expert_w3_scale__ssa_v0, recv_x_scale_dq_inline29__ssa_v0, ret0__out, ret1__out, local_i_inline67__idx_v0, n0_inline72__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_h_q.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_h_q.cpp new file mode 100644 index 000000000..65aabca5b --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_h_q.cpp @@ -0,0 +1,213 @@ +// Kernel Function: exp_h_q +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + 
PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) {
+  switch (mode) {
+    case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0:
+      set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      break;
+    case PTOAutoSyncTailMode::kBarrierAll:
+    default:
+      pipe_barrier(PIPE_ALL);
+      break;
+  }
+}
+
+static __aicore__ void exp_h_q(__gm__ float* v1, __gm__ int8_t* v2, __gm__ float* v3) {
+  RoundMode v4 = RoundMode::CAST_TRUNC;
+  RoundMode v5 = RoundMode::CAST_ROUND;
+  unsigned v6 = 0;
+  const float v7 = 127.0f;
+  const int32_t v8 = 256;
+  const int32_t v9 = 0;
+  const float v10 = 9.99999974E-5f;
+  const int32_t v11 = 1;
+  const int32_t v12 = 4096;
+  const int32_t v13 = 16;
+  const int64_t v14 = 24640;
+  const int64_t v15 = 16448;
+  const int64_t v16 = 64;
+  const int64_t v17 = 0;
+  const int64_t v18 = 61568;
+  const int64_t v19 = 45184;
+  const int64_t v20 = 28800;
+  const int64_t v21 = 28736;
+  using T = float;
+
+  #if defined(__DAV_VEC__)
+  set_mask_norm();
+  set_vector_mask(-1, -1);
+  size_t v22 = (size_t) v12;
+  size_t v23 = (size_t) v9;
+  size_t v24 = (size_t) v8;
+  Tile v25 = Tile(v11, v13);
+  uint64_t v26 = (uint64_t) v21;
+  TASSIGN(v25, v26);
+  set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+  set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+  set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
+  TEXPANDS(v25, v10);
+  for (size_t v27 = v23; v27 < v22; v27 += v24) {
+    Tile v28 = Tile(v13, v8);
+    uint64_t v29 = (uint64_t) v20;
+    TASSIGN(v28, v29);
+    pto::Shape<1, 1, 1, 16, 256> v30 = pto::Shape<1, 1, 1, 16, 256>();
+    pto::Stride<65536, 65536, 65536, 4096, 1> v31 = pto::Stride<65536, 65536, 65536, 4096, 1>();
+    GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v32 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) ((int32_t) v27) * (unsigned) v11), v30, v31);
+    wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+    TLOAD(v28, v32);
+    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+    Tile v33 = Tile(v13, v8);
+    uint64_t v34 = (uint64_t) v19;
+    TASSIGN(v33, v34);
+    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+    TNEG(v33, v28);
+    Tile v35 = Tile(v13, v8);
+    uint64_t v36 = (uint64_t) v20;
+    TASSIGN(v35, v36);
+    pipe_barrier(PIPE_V);
+    TMAX(v35, v28, v33);
+    Tile v37 = Tile(v13, v8);
+    uint64_t v38 = (uint64_t) v19;
+    TASSIGN(v37, v38);
+    Tile v39 = Tile(v13, v11);
+    uint64_t v40 = (uint64_t) v18;
+    TASSIGN(v39, v40);
+    pipe_barrier(PIPE_V);
+    TROWMAX(v39, v35, v37);
+    set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+    Tile v41 = Tile(v11, v13);
+    uint64_t v42 = (uint64_t) v18;
+    TASSIGN(v41, v42);
+    Tile v43 = Tile(v11, v13);
+    uint64_t v44 = (uint64_t) v21;
+    TASSIGN(v43, v44);
+    pipe_barrier(PIPE_V);
+    TMAX(v43, v25, v41);
+  }
+  set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
+  Tile v45 = Tile(v11, v13);
+  uint64_t v46 = (uint64_t) v17;
+  TASSIGN(v45, v46);
+  TEXPANDS(v45, v7);
+  Tile v47 = Tile(v11, v13);
+  uint64_t v48 = (uint64_t) v21;
+  TASSIGN(v47, v48);
+  pipe_barrier(PIPE_V);
+  TDIV(v47, v45, v25);
+  Tile v49 = Tile(v11, v13);
+  uint64_t v50 = (uint64_t) v17;
+  TASSIGN(v49, v50);
+  pipe_barrier(PIPE_V);
+  TRECIP(v49, v47);
+  set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
+  Tile v51 = Tile(v13, v11);
+  uint64_t v52 = (uint64_t) v17;
+  TASSIGN(v51, v52);
+  Tile v53 = Tile(v13, v11);
+  uint64_t v54 = (uint64_t) v21;
+  TASSIGN(v53, v54);
+  wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1);
+  for (size_t v55 = v23; v55 < v22; v55 += v24) {
+    int32_t v56 = (int32_t) v55;
+    Tile v57 = Tile(v13, v8);
+    uint64_t v58 = (uint64_t) v20;
+    TASSIGN(v57, v58);
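+    // NOTE (added commentary; hedged reading of the generated code):
+    // pass 1 above accumulated a per-row max of |x| via TNEG/TMAX/TROWMAX,
+    // seeded with eps = 1e-4 (v10), and the scalar section derived
+    //   scale    = 127.0 / max_abs   (v47, consumed via alias v53)
+    //   dq_scale = max_abs / 127.0   (v49 = TRECIP(v47), stored to ret0)
+    // This second pass re-reads each 16x256 slice, applies TROWEXPANDMUL
+    // by scale, and narrows fp32 to int8 through the three TCVTs below,
+    // i.e. symmetric per-row int8 quantization q = round(x * 127 / max|x|),
+    // with dq_scale exported for the downstream dequant kernel.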
+    pto::Shape<1, 1, 1, 16, 256> v59 = pto::Shape<1, 1, 1, 16, 256>();
+    pto::Stride<65536, 65536, 65536, 4096, 1> v60 = pto::Stride<65536, 65536, 65536, 4096, 1>();
+    GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v61 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) v56 * (unsigned) v11), v59, v60);
+    wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+    TLOAD(v57, v61);
+    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+    Tile v62 = Tile(v13, v8);
+    uint64_t v63 = (uint64_t) v20;
+    TASSIGN(v62, v63);
+    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+    TROWEXPANDMUL(v62, v57, v53);
+    Tile v64 = Tile(v13, v8);
+    uint64_t v65 = (uint64_t) v16;
+    TASSIGN(v64, v65);
+    pipe_barrier(PIPE_V);
+    TCVT(v64, v62, v5);
+    set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+    Tile v66 = Tile(v13, v8);
+    uint64_t v67 = (uint64_t) v15;
+    TASSIGN(v66, v67);
+    pipe_barrier(PIPE_V);
+    TCVT(v66, v64, v5);
+    Tile v68 = Tile(v13, v8);
+    uint64_t v69 = (uint64_t) v14;
+    TASSIGN(v68, v69);
+    pipe_barrier(PIPE_V);
+    wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
+    TCVT(v68, v66, v4);
+    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+    pto::Shape<1, 1, 1, 16, 256> v70 = pto::Shape<1, 1, 1, 16, 256>();
+    pto::Stride<65536, 65536, 65536, 4096, 1> v71 = pto::Stride<65536, 65536, 65536, 4096, 1>();
+    GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v72 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 256>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v12 + (unsigned) v56 * (unsigned) v11), v70, v71);
+    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+    TSTORE(v72, v68);
+    set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
+  }
+  pto::Shape<1, 1, 1, 16, 1> v73 = pto::Shape<1, 1, 1, 16, 1>();
+  pto::Stride<16, 16, 16, 1, 16> v74 = pto::Stride<16, 16, 16, 1, 16>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 16, 1>, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v75 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 1>, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v13), v73, v74);
+  wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
+  TSTORE(v75, v51);
+  wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
+  wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2);
+  wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0);
+  #endif // __DAV_VEC__
+
+  ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);
+  return;
+}
+
+// --- Kernel entry point ---
+extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args)
+{
+  // Unpack tensor: h_tile_fp32_inline18__rv_v2
+  __gm__ Tensor* h_tile_fp32_inline18__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]);
+  __gm__ float* h_tile_fp32_inline18__rv_v2 = reinterpret_cast<__gm__ float*>(h_tile_fp32_inline18__rv_v2_tensor->buffer.addr) + h_tile_fp32_inline18__rv_v2_tensor->start_offset;
+
+  // Unpack tensor: h_tile_i8_inline92__ssa_v0
+  __gm__ Tensor* h_tile_i8_inline92__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]);
+  __gm__ int8_t* h_tile_i8_inline92__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(h_tile_i8_inline92__ssa_v0_tensor->buffer.addr) + h_tile_i8_inline92__ssa_v0_tensor->start_offset;
+
+  // Unpack tensor: ret0__out
+  __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]);
+  __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset;
+
+  // Forward to ptoas-generated function
+  exp_h_q(h_tile_fp32_inline18__rv_v2, h_tile_i8_inline92__ssa_v0, ret0__out);
+}
diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_recv_y_write.cpp
b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_recv_y_write.cpp
new file mode 100644
index 000000000..0135b048e
--- /dev/null
+++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_recv_y_write.cpp
@@ -0,0 +1,108 @@
+// Kernel Function: exp_recv_y_write
+// Generated by PyPTO IR Compiler (PTO backend)
+
+#include
+
+#ifndef __gm__
+#define __gm__
+#endif
+
+#ifndef __aicore__
+#if defined(__CPU_SIM)
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+#endif
+
+#include
+#include "tensor.h"
+
+
+using namespace pto;
+
+
+// --- ptoas-generated code ---
+
+enum class PTOAutoSyncTailMode : int {
+  kBarrierAll = 0,
+  kSetWaitMte3ToSEvent0 = 1,
+};
+
+static __aicore__ inline void ptoas_auto_sync_tail(
+    PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) {
+  switch (mode) {
+    case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0:
+      set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      break;
+    case PTOAutoSyncTailMode::kBarrierAll:
+    default:
+      pipe_barrier(PIPE_ALL);
+      break;
+  }
+}
+
+static __aicore__ void exp_recv_y_write(__gm__ float* v1, __gm__ bfloat16_t* v2, int32_t v3, int32_t v4) {
+  RoundMode v5 = RoundMode::CAST_RINT;
+  unsigned v6 = 0;
+  const int32_t v7 = 4096;
+  const int32_t v8 = 1;
+  const int32_t v9 = 512;
+  const int32_t v10 = 16;
+  const int64_t v11 = 32768;
+  const int64_t v12 = 0;
+  using T = float;
+
+  #if defined(__DAV_VEC__)
+  set_mask_norm();
+  set_vector_mask(-1, -1);
+  Tile v13 = Tile(v10, v9);
+  uint64_t v14 = (uint64_t) v12;
+  TASSIGN(v13, v14);
+  pto::Shape<1, 1, 1, 16, 512> v15 = pto::Shape<1, 1, 1, 16, 512>();
+  pto::Stride<8192, 8192, 8192, 512, 1> v16 = pto::Stride<8192, 8192, 8192, 512, 1>();
+  GlobalTensor<float, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v17 = GlobalTensor<float, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v15, v16);
+  TLOAD(v13, v17);
+  set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+  Tile v18 = Tile(v10, v9);
+  uint64_t v19 = (uint64_t) v11;
+  TASSIGN(v18, v19);
+  wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
+  TCVT(v18, v13, v5);
+  set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+  pto::Shape<1, 1, 1, 16, 512> v20 = pto::Shape<1, 1, 1, 16, 512>();
+  pto::Stride<65536, 65536, 65536, 4096, 1> v21 = pto::Stride<65536, 65536, 65536, 4096, 1>();
+  GlobalTensor<bfloat16_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v22 = GlobalTensor<bfloat16_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) v3 * (unsigned) v7 + (unsigned) v4 * (unsigned) v8), v20, v21);
+  wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+  TSTORE(v22, v18);
+  #endif // __DAV_VEC__
+
+  ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);
+  return;
+}
+
+// --- Kernel entry point ---
+extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args)
+{
+  // Unpack tensor: y_2d_v1_inline17__ssa_v0
+  __gm__ Tensor* y_2d_v1_inline17__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]);
+  __gm__ float* y_2d_v1_inline17__ssa_v0 = reinterpret_cast<__gm__ float*>(y_2d_v1_inline17__ssa_v0_tensor->buffer.addr) + y_2d_v1_inline17__ssa_v0_tensor->start_offset;
+
+  // Unpack tensor: recv_y_flat_inline53__iter_v5
+  __gm__ Tensor* recv_y_flat_inline53__iter_v5_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]);
+  __gm__ bfloat16_t* recv_y_flat_inline53__iter_v5 = reinterpret_cast<__gm__ bfloat16_t*>(recv_y_flat_inline53__iter_v5_tensor->buffer.addr) + recv_y_flat_inline53__iter_v5_tensor->start_offset;
+
+  // Unpack scalar:
flat_t0_inline40__ssa_v0 + union { uint64_t u64; int64_t val; } flat_t0_inline40__ssa_v0_conv; + flat_t0_inline40__ssa_v0_conv.u64 = args[2]; + int64_t flat_t0_inline40__ssa_v0 = flat_t0_inline40__ssa_v0_conv.val; + + // Unpack scalar: d0_inline49__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline49__idx_v0_conv; + d0_inline49__idx_v0_conv.u64 = args[3]; + int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_recv_y_write(y_2d_v1_inline17__ssa_v0, recv_y_flat_inline53__iter_v5, flat_t0_inline40__ssa_v0, d0_inline49__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu.cpp new file mode 100644 index 000000000..fd893f9ed --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu.cpp @@ -0,0 +1,161 @@ +// Kernel Function: exp_swiglu +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void exp_swiglu(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, int32_t v5) { + unsigned v6 = 0; + const float v7 = 1.0f; + const int32_t v8 = 1; + const int32_t v9 = 256; + const int32_t v10 = 16; + const int64_t v11 = 0; + const int64_t v12 = 49216; + const int64_t v13 = 32832; + const int64_t v14 = 16448; + const int64_t v15 = 64; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v16 = Tile(v10, v9); + uint64_t v17 = (uint64_t) v15; + TASSIGN(v16, v17); + pto::Shape<1, 1, 1, 16, 256> v18 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v19 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v20 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v18, v19); + TLOAD(v16, v20); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v21 = Tile(v10, v9); + uint64_t v22 = (uint64_t) v14; + TASSIGN(v21, v22); + pto::Shape<1, 1, 1, 16, 256> v23 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v24 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v25 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v23, v24); + TLOAD(v21, v25); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v26 = Tile(v10, v9); + uint64_t v27 = (uint64_t) v13; + TASSIGN(v26, v27); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TNEG(v26, v16); + Tile v28 = Tile(v10, v9); + uint64_t v29 = (uint64_t) v13; + TASSIGN(v28, v29); + pipe_barrier(PIPE_V); + TEXP(v28, v26); + Tile v30 = Tile(v10, v9); + uint64_t v31 = (uint64_t) v13; + 
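+ // Editorial note: a worked trace of this activation chain, assuming the usual
+ // elementwise semantics for TNEG/TEXP/TADDS/TRECIP/TMUL. With gate = v16 and
+ // up = v21 (both 16x256 fp32 tiles):
+ //   v26 = -gate
+ //   v28 = exp(-gate)
+ //   v30 = 1 + exp(-gate)        (the TADDS just below, v7 = 1.0f)
+ //   v32 = 1 / (1 + exp(-gate))  = sigmoid(gate)
+ //   v34 = gate * sigmoid(gate)  = SiLU(gate)
+ //   v36 = SiLU(gate) * up       (the SwiGLU product)
+ // v38 then loads one routing weight per row from v3 (recv_weights, indexed by
+ // the scalar v5), and TROWEXPANDMUL scales each row of the product by its
+ // weight before the TSTORE to v4.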
TASSIGN(v30, v31); + pipe_barrier(PIPE_V); + TADDS(v30, v28, v7); + Tile v32 = Tile(v10, v9); + uint64_t v33 = (uint64_t) v12; + TASSIGN(v32, v33); + pipe_barrier(PIPE_V); + TRECIP(v32, v30); + Tile v34 = Tile(v10, v9); + uint64_t v35 = (uint64_t) v15; + TASSIGN(v34, v35); + pipe_barrier(PIPE_V); + TMUL(v34, v16, v32); + Tile v36 = Tile(v10, v9); + uint64_t v37 = (uint64_t) v15; + TASSIGN(v36, v37); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMUL(v36, v34, v21); + Tile v38 = Tile(v10, v8); + uint64_t v39 = (uint64_t) v11; + TASSIGN(v38, v39); + pto::Shape<1, 1, 1, 16, 1> v40 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 256> v41 = pto::Stride<16, 16, 16, 1, 256>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 256>, pto::Layout::DN> v42 = GlobalTensor, pto::Stride<16, 16, 16, 1, 256>, pto::Layout::DN>(v3 + (v6 + (unsigned) v5 * (unsigned) v8 + v6 * (unsigned) v9), v40, v41); + TLOAD(v38, v42); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v43 = Tile(v10, v9); + uint64_t v44 = (uint64_t) v15; + TASSIGN(v43, v44); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v43, v36, v38); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v45 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v46 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v47 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v4 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v8), v45, v46); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v47, v43); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: gate_2d_v1_inline7__ssa_v0 + __gm__ Tensor* gate_2d_v1_inline7__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* gate_2d_v1_inline7__ssa_v0 = reinterpret_cast<__gm__ float*>(gate_2d_v1_inline7__ssa_v0_tensor->buffer.addr) + gate_2d_v1_inline7__ssa_v0_tensor->start_offset; + + // Unpack tensor: up_2d_v1_inline78__ssa_v0 + __gm__ Tensor* up_2d_v1_inline78__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* up_2d_v1_inline78__ssa_v0 = reinterpret_cast<__gm__ float*>(up_2d_v1_inline78__ssa_v0_tensor->buffer.addr) + up_2d_v1_inline78__ssa_v0_tensor->start_offset; + + // Unpack tensor: recv_weights_flat_inline62__ssa_v0 + __gm__ Tensor* recv_weights_flat_inline62__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* recv_weights_flat_inline62__ssa_v0 = reinterpret_cast<__gm__ float*>(recv_weights_flat_inline62__ssa_v0_tensor->buffer.addr) + recv_weights_flat_inline62__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: flat_t0_inline40__ssa_v0 + union { uint64_t u64; int64_t val; } flat_t0_inline40__ssa_v0_conv; + flat_t0_inline40__ssa_v0_conv.u64 = args[4]; + int64_t flat_t0_inline40__ssa_v0 = flat_t0_inline40__ssa_v0_conv.val; + + // Forward to ptoas-generated function + exp_swiglu(gate_2d_v1_inline7__ssa_v0, up_2d_v1_inline78__ssa_v0, recv_weights_flat_inline62__ssa_v0, ret0__out, flat_t0_inline40__ssa_v0); +} diff --git 
a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu_mask.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu_mask.cpp new file mode 100644 index 000000000..69645e11b --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_swiglu_mask.cpp @@ -0,0 +1,108 @@ +// Kernel Function: exp_swiglu_mask +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void exp_swiglu_mask(__gm__ float* v1, __gm__ float* v2, int32_t v3, int32_t v4) { + unsigned v5 = 0; + const int32_t v6 = 4096; + const int32_t v7 = 1; + const int32_t v8 = 256; + const int32_t v9 = 16; + const int64_t v10 = 16384; + const int64_t v11 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v12 = Tile(v9, v8); + uint64_t v13 = (uint64_t) v11; + TASSIGN(v12, v13); + pto::Shape<1, 1, 1, 16, 256> v14 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v15 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v16 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v8 + v5 * (unsigned) v7), v14, v15); + TLOAD(v12, v16); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + v12.SetValidShape(v3, v8); + Tile v17 = Tile(v9, v8); + uint64_t v18 = (uint64_t) v10; + TASSIGN(v17, v18); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TFILLPAD(v17, v12); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v19 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v20 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v21 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v5 + v5 * (unsigned) v6 + (unsigned) v4 * (unsigned) v7), v19, v20); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v21, v17); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: h_chunk_inline89__ssa_v0 + __gm__ Tensor* h_chunk_inline89__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* h_chunk_inline89__ssa_v0 = reinterpret_cast<__gm__ float*>(h_chunk_inline89__ssa_v0_tensor->buffer.addr) + h_chunk_inline89__ssa_v0_tensor->start_offset; + + // Unpack tensor: h_tile_fp32_inline18__iter_v1 + __gm__ Tensor* h_tile_fp32_inline18__iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* h_tile_fp32_inline18__iter_v1 = reinterpret_cast<__gm__ float*>(h_tile_fp32_inline18__iter_v1_tensor->buffer.addr) + 
h_tile_fp32_inline18__iter_v1_tensor->start_offset; + + // Unpack scalar: valid_rows_inline73__ssa_v0 + union { uint64_t u64; int64_t val; } valid_rows_inline73__ssa_v0_conv; + valid_rows_inline73__ssa_v0_conv.u64 = args[2]; + int64_t valid_rows_inline73__ssa_v0 = valid_rows_inline73__ssa_v0_conv.val; + + // Unpack scalar: n0_inline72__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline72__idx_v0_conv; + n0_inline72__idx_v0_conv.u64 = args[3]; + int64_t n0_inline72__idx_v0 = n0_inline72__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_swiglu_mask(h_chunk_inline89__ssa_v0, h_tile_fp32_inline18__iter_v1, valid_rows_inline73__ssa_v0, n0_inline72__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_w2_dequant.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_w2_dequant.cpp new file mode 100644 index 000000000..0f31e0015 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/exp_w2_dequant.cpp @@ -0,0 +1,150 @@ +// Kernel Function: exp_w2_dequant +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void exp_w2_dequant(__gm__ int32_t* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, int32_t v5, int32_t v6) { + RoundMode v7 = RoundMode::CAST_NONE; + unsigned v8 = 0; + const int32_t v9 = 8192; + const int32_t v10 = 4096; + const int32_t v11 = 512; + const int32_t v12 = 16; + const int32_t v13 = 1; + const int64_t v14 = 34880; + const int64_t v15 = 32832; + const int64_t v16 = 32768; + const int64_t v17 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v18 = Tile(v12, v11); + uint64_t v19 = (uint64_t) v17; + TASSIGN(v18, v19); + pto::Shape<1, 1, 1, 16, 512> v20 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v21 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v22 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v1 + ((v8 + v8 * (unsigned) v9) + v8 * (unsigned) v11 + v8 * (unsigned) v13), v20, v21); + TLOAD(v18, v22); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v23 = Tile(v12, v13); + uint64_t v24 = (uint64_t) v16; + TASSIGN(v23, v24); + pto::Shape<1, 1, 1, 16, 1> v25 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v26 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v27 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v8 + v8 * (unsigned) v13 + v8 * (unsigned) v12), v25, v26); + TLOAD(v23, v27); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v28 = Tile(v12, v11); + uint64_t v29 = (uint64_t) v17; + TASSIGN(v28, v29); + Tile v30 = Tile(v13, v11); + uint64_t v31 = (uint64_t) v15; + TASSIGN(v30, v31); + 
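+ // Editorial note: the dequant dataflow, as read off the generated code. v18
+ // holds the 16x512 int32 matmul accumulators loaded from v1; v23 holds one
+ // fp32 scale per row (16x1, from v3); v30, loaded just below, holds one fp32
+ // scale per column (1x512, from v2 at offset v5*4096 + v6). v28 appears to
+ // alias v18 (both TASSIGNed to UB offset 0), so the TCVT further down reads
+ // the accumulators through it. Net per-element effect:
+ //   out[r, c] = float(acc[r, c]) * row_scale[r] * col_scale[c]
+ // realized as TCVT -> TROWEXPANDMUL -> TCOLEXPANDMUL before the TSTORE.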
pto::Shape<1, 1, 1, 1, 512> v32 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<4096, 4096, 4096, 4096, 1> v33 = pto::Stride<4096, 4096, 4096, 4096, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v34 = GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v2 + (v8 + (unsigned) v5 * (unsigned) v10 + (unsigned) v6 * (unsigned) v13), v32, v33); + TLOAD(v30, v34); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v35 = Tile(v12, v11); + uint64_t v36 = (uint64_t) v14; + TASSIGN(v35, v36); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v35, v28, v7); + Tile v37 = Tile(v12, v11); + uint64_t v38 = (uint64_t) v14; + TASSIGN(v37, v38); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDMUL(v37, v35, v23); + Tile v39 = Tile(v12, v11); + uint64_t v40 = (uint64_t) v14; + TASSIGN(v39, v40); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TCOLEXPANDMUL(v39, v37, v30); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v41 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v42 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v43 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v4 + (v8 + v8 * (unsigned) v11 + v8 * (unsigned) v13), v41, v42); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v43, v39); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: y_acc_inline109__rv_v2 + __gm__ Tensor* y_acc_inline109__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int32_t* y_acc_inline109__rv_v2 = reinterpret_cast<__gm__ int32_t*>(y_acc_inline109__rv_v2_tensor->buffer.addr) + y_acc_inline109__rv_v2_tensor->start_offset; + + // Unpack tensor: expert_w2_scale__ssa_v0 + __gm__ Tensor* expert_w2_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* expert_w2_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(expert_w2_scale__ssa_v0_tensor->buffer.addr) + expert_w2_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: h_tile_scale_dq_inline63__ssa_v0 + __gm__ Tensor* h_tile_scale_dq_inline63__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* h_tile_scale_dq_inline63__ssa_v0 = reinterpret_cast<__gm__ float*>(h_tile_scale_dq_inline63__ssa_v0_tensor->buffer.addr) + h_tile_scale_dq_inline63__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: local_i_inline67__idx_v0 + union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; + local_i_inline67__idx_v0_conv.u64 = args[4]; + int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; + + // Unpack scalar: d0_inline49__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline49__idx_v0_conv; + d0_inline49__idx_v0_conv.u64 = args[5]; + int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val; + + // Forward to ptoas-generated function + exp_w2_dequant(y_acc_inline109__rv_v2, expert_w2_scale__ssa_v0, h_tile_scale_dq_inline63__ssa_v0, ret0__out, local_i_inline67__idx_v0, d0_inline49__idx_v0); +} diff --git 
a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_add.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_add.cpp new file mode 100644 index 000000000..e51770d93 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_add.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * FFN residual add: ``ffn_out = routed_y + sh``. + * + * routed_y : FP32 [T, D] (combine output) + * sh : BF16 [T, D] (moe_expert shared-expert output) + * ffn_out : BF16 [T, D] + */ + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include + +#include +#include "tensor.h" + +using namespace pto; + +static constexpr int T = 16; +static constexpr int D = 4096; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *routed_y_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *sh_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *ffn_out_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + + __gm__ float *routed_y = + reinterpret_cast<__gm__ float *>(routed_y_tensor->buffer.addr) + routed_y_tensor->start_offset; + __gm__ bfloat16_t *sh = + reinterpret_cast<__gm__ bfloat16_t *>(sh_tensor->buffer.addr) + sh_tensor->start_offset; + __gm__ bfloat16_t *ffn_out = + reinterpret_cast<__gm__ bfloat16_t *>(ffn_out_tensor->buffer.addr) + ffn_out_tensor->start_offset; + + using RowFpG = GlobalTensor, Stride>; + using RowBfG = GlobalTensor, Stride>; + using RowFpTile = Tile; + using RowBfTile = Tile; + + // Four tiles laid out within the 192 KiB AIV UB: a fourth 64 KiB-pitched + // slot would start at 0x30000, the end of UB, so the last tile uses a + // tighter 0x28000 offset instead.
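+ // Editorial note on the budget, assuming the row tiles are 1 x 4096 (so an
+ // fp32 row is 16 KiB and a bf16 row 8 KiB): routed_fp @ 0x0, sh_fp @ 0x10000
+ // and sh_bf @ 0x20000 each fit their slot with room to spare, and out_bf @
+ // 0x28000 ends at 0x2A000, inside the 0x30000 (192 KiB) limit.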
+ RowFpTile routed_fp; + RowFpTile sh_fp; + RowBfTile sh_bf; + RowBfTile out_bf; + TASSIGN(routed_fp, 0x0); + TASSIGN(sh_fp, 0x10000); + TASSIGN(sh_bf, 0x20000); + TASSIGN(out_bf, 0x28000); + + for (int t = 0; t < T; ++t) { + RowFpG r_g(routed_y + t * D); + RowBfG s_g(sh + t * D); + RowBfG o_g(ffn_out + t * D); + + TLOAD(routed_fp, r_g); + TLOAD(sh_bf, s_g); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TCVT(sh_fp, sh_bf, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + TADD(routed_fp, routed_fp, sh_fp); + pipe_barrier(PIPE_V); + TCVT(out_bf, routed_fp, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(o_g, out_bf); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_apply.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_apply.cpp new file mode 100644 index 000000000..885251cb8 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_apply.cpp @@ -0,0 +1,160 @@ +// Kernel Function: ffn_norm_apply +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void ffn_norm_apply(__gm__ float* v1, __gm__ bfloat16_t* v2, __gm__ float* v3, __gm__ bfloat16_t* v4, __gm__ bfloat16_t* v5, int32_t v6) { + RoundMode v7 = RoundMode::CAST_RINT; + RoundMode v8 = RoundMode::CAST_ROUND; + unsigned v9 = 0; + const int32_t v10 = 512; + const int32_t v11 = 4096; + const int32_t v12 = 16; + const int32_t v13 = 1; + const int64_t v14 = 49216; + const int64_t v15 = 16448; + const int64_t v16 = 64; + const int64_t v17 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + int32_t v18 = (int32_t) ((uint32_t) v6 * (uint32_t) v10); + Tile v19 = Tile(v13, v12); + uint64_t v20 = (uint64_t) v17; + TASSIGN(v19, v20); + pto::Shape<1, 1, 1, 1, 16> v21 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<16, 16, 16, 16, 1> v22 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v23 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v1 + (v9 + v9 * (unsigned) v12 + v9 * (unsigned) v13), v21, v22); + TLOAD(v19, v23); + Tile v24 = Tile(v12, v13); + uint64_t v25 = (uint64_t) v17; + TASSIGN(v24, v25); + Tile v26 = Tile(v12, v10); + uint64_t v27 = (uint64_t) v16; + TASSIGN(v26, v27); + pto::Shape<1, 1, 1, 16, 512> v28 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v29 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v30 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 
1>, pto::Layout::ND>(v2 + (v9 + v9 * (unsigned) v11 + (unsigned) v18 * (unsigned) v13), v28, v29); + TLOAD(v26, v30); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v31 = Tile(v12, v10); + uint64_t v32 = (uint64_t) v15; + TASSIGN(v31, v32); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v31, v26, v8); + Tile v33 = Tile(v13, v10); + uint64_t v34 = (uint64_t) v14; + TASSIGN(v33, v34); + pto::Shape<1, 1, 1, 1, 512> v35 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<512, 512, 512, 512, 1> v36 = pto::Stride<512, 512, 512, 512, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 512, 1>, pto::Layout::ND> v37 = GlobalTensor, pto::Stride<512, 512, 512, 512, 1>, pto::Layout::ND>(v3 + (v9 + (unsigned) v18 * (unsigned) v13), v35, v36); + TLOAD(v33, v37); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v38 = Tile(v13, v10); + uint64_t v39 = (uint64_t) v14; + TASSIGN(v38, v39); + Tile v40 = Tile(v12, v10); + uint64_t v41 = (uint64_t) v15; + TASSIGN(v40, v41); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v40, v31, v24); + Tile v42 = Tile(v12, v10); + uint64_t v43 = (uint64_t) v15; + TASSIGN(v42, v43); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCOLEXPANDMUL(v42, v40, v38); + Tile v44 = Tile(v12, v10); + uint64_t v45 = (uint64_t) v16; + TASSIGN(v44, v45); + pipe_barrier(PIPE_V); + TCVT(v44, v42, v7); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v46 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v47 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v48 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v4 + (v9 + v9 * (unsigned) v11 + (unsigned) v18 * (unsigned) v13), v46, v47); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v48, v44); + pto::Shape<1, 1, 1, 16, 512> v49 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v50 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v51 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v5 + (v9 + v9 * (unsigned) v11 + (unsigned) v18 * (unsigned) v13), v49, v50); + TSTORE(v51, v44); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: inv_rms_inline126__ssa_v1 + __gm__ Tensor* inv_rms_inline126__ssa_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* inv_rms_inline126__ssa_v1 = reinterpret_cast<__gm__ float*>(inv_rms_inline126__ssa_v1_tensor->buffer.addr) + inv_rms_inline126__ssa_v1_tensor->start_offset; + + // Unpack tensor: x_mixed_flat_inline127__ssa_v0 + __gm__ Tensor* x_mixed_flat_inline127__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ bfloat16_t* x_mixed_flat_inline127__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(x_mixed_flat_inline127__ssa_v0_tensor->buffer.addr) + x_mixed_flat_inline127__ssa_v0_tensor->start_offset; + + // Unpack tensor: norm_w__ssa_v0 + __gm__ Tensor* norm_w__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* norm_w__ssa_v0 = reinterpret_cast<__gm__ float*>(norm_w__ssa_v0_tensor->buffer.addr) + norm_w__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_norm_bf16_inline111__iter_v1 + __gm__ Tensor* x_norm_bf16_inline111__iter_v1_tensor = reinterpret_cast<__gm__ 
Tensor*>(args[3]); + __gm__ bfloat16_t* x_norm_bf16_inline111__iter_v1 = reinterpret_cast<__gm__ bfloat16_t*>(x_norm_bf16_inline111__iter_v1_tensor->buffer.addr) + x_norm_bf16_inline111__iter_v1_tensor->start_offset; + + // Unpack tensor: x_norm__iter_v1 + __gm__ Tensor* x_norm__iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ bfloat16_t* x_norm__iter_v1 = reinterpret_cast<__gm__ bfloat16_t*>(x_norm__iter_v1_tensor->buffer.addr) + x_norm__iter_v1_tensor->start_offset; + + // Unpack scalar: db_inline125__idx_v0 + union { uint64_t u64; int64_t val; } db_inline125__idx_v0_conv; + db_inline125__idx_v0_conv.u64 = args[5]; + int64_t db_inline125__idx_v0 = db_inline125__idx_v0_conv.val; + + // Forward to ptoas-generated function + ffn_norm_apply(inv_rms_inline126__ssa_v1, x_mixed_flat_inline127__ssa_v0, norm_w__ssa_v0, x_norm_bf16_inline111__iter_v1, x_norm__iter_v1, db_inline125__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_rms.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_rms.cpp new file mode 100644 index 000000000..c648ef28c --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/ffn_norm_rms.cpp @@ -0,0 +1,279 @@ +// Kernel Function: ffn_norm_rms +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void ffn_norm_rms(__gm__ bfloat16_t* v1, __gm__ float* v2) { + RoundMode v3 = RoundMode::CAST_ROUND; + unsigned v4 = 0; + const float v5 = 9.99999997E-7f; + const float v6 = 2.44140625E-4f; + const int32_t v7 = 1536; + const int32_t v8 = 1024; + const int32_t v9 = 512; + const int32_t v10 = 4; + const int32_t v11 = 8; + const int32_t v12 = 0; + const float v13 = 0.0f; + const int32_t v14 = 1; + const int32_t v15 = 4096; + const int32_t v16 = 16; + const int64_t v17 = 64; + const int64_t v18 = 0; + const int64_t v19 = 98496; + const int64_t v20 = 65728; + const int64_t v21 = 49344; + const int64_t v22 = 32960; + const int64_t v23 = 16576; + const int64_t v24 = 192; + const int64_t v25 = 128; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v26 = Tile(v14, v16); + uint64_t v27 = (uint64_t) v25; + TASSIGN(v26, v27); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TEXPANDS(v26, v13); + for (size_t v28 = (size_t) v12; v28 < ((size_t) v11); v28 += (size_t) v10) { + int32_t v29 = (int32_t) ((uint32_t) ((int32_t) v28) * (uint32_t) v9); + Tile v30 = Tile(v16, v9); + uint64_t v31 = (uint64_t) v24; + TASSIGN(v30, v31); + pto::Shape<1, 1, 1, 16, 512> v32 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v33 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, 
pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v34 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v4 + v4 * (unsigned) v15 + (unsigned) v29 * (unsigned) v14), v32, v33); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v30, v34); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v35 = Tile(v16, v9); + uint64_t v36 = (uint64_t) v23; + TASSIGN(v35, v36); + pto::Shape<1, 1, 1, 16, 512> v37 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v38 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v39 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v4 + v4 * (unsigned) v15 + (unsigned) ((int32_t) (uint32_t) v29 + (uint32_t) v9) * (unsigned) v14), v37, v38); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v35, v39); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v40 = Tile(v16, v9); + uint64_t v41 = (uint64_t) v22; + TASSIGN(v40, v41); + pto::Shape<1, 1, 1, 16, 512> v42 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v43 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v44 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v4 + v4 * (unsigned) v15 + (unsigned) ((int32_t) (uint32_t) v29 + (uint32_t) v8) * (unsigned) v14), v42, v43); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v40, v44); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v45 = Tile(v16, v9); + uint64_t v46 = (uint64_t) v21; + TASSIGN(v45, v46); + pto::Shape<1, 1, 1, 16, 512> v47 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v48 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v49 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v4 + v4 * (unsigned) v15 + (unsigned) ((int32_t) (uint32_t) v29 + (uint32_t) v7) * (unsigned) v14), v47, v48); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TLOAD(v45, v49); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + Tile v50 = Tile(v16, v9); + uint64_t v51 = (uint64_t) v20; + TASSIGN(v50, v51); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v50, v30, v3); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v52 = Tile(v16, v9); + uint64_t v53 = (uint64_t) v20; + TASSIGN(v52, v53); + pipe_barrier(PIPE_V); + TMUL(v52, v50, v50); + Tile v54 = Tile(v16, v9); + uint64_t v55 = (uint64_t) v19; + TASSIGN(v54, v55); + Tile v56 = Tile(v16, v14); + uint64_t v57 = (uint64_t) v18; + TASSIGN(v56, v57); + pipe_barrier(PIPE_V); + TROWSUM(v56, v52, v54); + Tile v58 = Tile(v14, v16); + uint64_t v59 = (uint64_t) v18; + TASSIGN(v58, v59); + Tile v60 = Tile(v14, v16); + uint64_t v61 = (uint64_t) v17; + TASSIGN(v60, v61); + pipe_barrier(PIPE_V); + TADD(v60, v26, v58); + Tile v62 = Tile(v16, v9); + uint64_t v63 = (uint64_t) v20; + TASSIGN(v62, v63); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCVT(v62, v35, v3); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v64 = Tile(v16, v9); + uint64_t v65 = (uint64_t) v20; + TASSIGN(v64, v65); + pipe_barrier(PIPE_V); + TMUL(v64, v62, v62); + Tile v66 = Tile(v16, v9); + uint64_t v67 = (uint64_t) v19; + TASSIGN(v66, v67); + Tile v68 = Tile(v16, v14); + uint64_t v69 = (uint64_t) v18; + TASSIGN(v68, v69); + pipe_barrier(PIPE_V); + TROWSUM(v68, v64, v66); + Tile v70 = Tile(v14, v16); + uint64_t v71 = (uint64_t) v18; + TASSIGN(v70, v71); + 
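+ // Editorial note: this is the running sum-of-squares for RMSNorm over a
+ // 16 x 4096 bf16 block, processed as eight 16x512 chunks (four per loop
+ // iteration). Each chunk is upcast (TCVT), squared (TMUL x, x), row-reduced
+ // (TROWSUM) into a 16x1 tile at UB offset 0, re-viewed as 1x16 and TADDed
+ // into the running total, which ping-pongs between offsets 64 and 128 and
+ // lands back in v26's slot (128) at the end of each iteration. After the
+ // loop: inv_rms = 1 / sqrt(mean(x^2) + eps), with the mean taken via TMULS
+ // by v6 = 2.44140625e-4 = 1/4096 and eps = v5 ~= 1e-6.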
Tile v72 = Tile(v14, v16); + uint64_t v73 = (uint64_t) v17; + TASSIGN(v72, v73); + pipe_barrier(PIPE_V); + TADD(v72, v60, v70); + Tile v74 = Tile(v16, v9); + uint64_t v75 = (uint64_t) v20; + TASSIGN(v74, v75); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TCVT(v74, v40, v3); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + Tile v76 = Tile(v16, v9); + uint64_t v77 = (uint64_t) v20; + TASSIGN(v76, v77); + pipe_barrier(PIPE_V); + TMUL(v76, v74, v74); + Tile v78 = Tile(v16, v9); + uint64_t v79 = (uint64_t) v19; + TASSIGN(v78, v79); + Tile v80 = Tile(v16, v14); + uint64_t v81 = (uint64_t) v18; + TASSIGN(v80, v81); + pipe_barrier(PIPE_V); + TROWSUM(v80, v76, v78); + Tile v82 = Tile(v14, v16); + uint64_t v83 = (uint64_t) v18; + TASSIGN(v82, v83); + Tile v84 = Tile(v14, v16); + uint64_t v85 = (uint64_t) v17; + TASSIGN(v84, v85); + pipe_barrier(PIPE_V); + TADD(v84, v72, v82); + Tile v86 = Tile(v16, v9); + uint64_t v87 = (uint64_t) v20; + TASSIGN(v86, v87); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TCVT(v86, v45, v3); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + Tile v88 = Tile(v16, v9); + uint64_t v89 = (uint64_t) v20; + TASSIGN(v88, v89); + pipe_barrier(PIPE_V); + TMUL(v88, v86, v86); + Tile v90 = Tile(v16, v9); + uint64_t v91 = (uint64_t) v19; + TASSIGN(v90, v91); + Tile v92 = Tile(v16, v14); + uint64_t v93 = (uint64_t) v18; + TASSIGN(v92, v93); + pipe_barrier(PIPE_V); + TROWSUM(v92, v88, v90); + Tile v94 = Tile(v14, v16); + uint64_t v95 = (uint64_t) v18; + TASSIGN(v94, v95); + Tile v96 = Tile(v14, v16); + uint64_t v97 = (uint64_t) v25; + TASSIGN(v96, v97); + pipe_barrier(PIPE_V); + TADD(v96, v84, v94); + } + Tile v98 = Tile(v14, v16); + uint64_t v99 = (uint64_t) v25; + TASSIGN(v98, v99); + pipe_barrier(PIPE_V); + TMULS(v98, v26, v6); + Tile v100 = Tile(v14, v16); + uint64_t v101 = (uint64_t) v25; + TASSIGN(v100, v101); + pipe_barrier(PIPE_V); + TADDS(v100, v98, v5); + Tile v102 = Tile(v14, v16); + uint64_t v103 = (uint64_t) v25; + TASSIGN(v102, v103); + pipe_barrier(PIPE_V); + TSQRT(v102, v100); + Tile v104 = Tile(v14, v16); + uint64_t v105 = (uint64_t) v17; + TASSIGN(v104, v105); + pipe_barrier(PIPE_V); + TRECIP(v104, v102); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v106 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<16, 16, 16, 16, 1> v107 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v108 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v2 + (v4 + v4 * (unsigned) v16 + v4 * (unsigned) v14), v106, v107); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v108, v104); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_mixed_flat_inline127__ssa_v0 + __gm__ Tensor* x_mixed_flat_inline127__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* x_mixed_flat_inline127__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(x_mixed_flat_inline127__ssa_v0_tensor->buffer.addr) + x_mixed_flat_inline127__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) 
+ ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + ffn_norm_rms(x_mixed_flat_inline127__ssa_v0, ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot.cpp new file mode 100644 index 000000000..6bfda5720 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot.cpp @@ -0,0 +1,83 @@ +// Kernel Function: gate_dot +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void gate_dot(__gm__ float* v1) { + unsigned v2 = 0; + const float v3 = -1.00000002E+30f; + const int32_t v4 = 1; + const int32_t v5 = 32; + const int32_t v6 = 16; + const int64_t v7 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v8 = Tile(v6, v5); + uint64_t v9 = (uint64_t) v7; + TASSIGN(v8, v9); + TEXPANDS(v8, v3); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 32> v10 = pto::Shape<1, 1, 1, 16, 32>(); + pto::Stride<512, 512, 512, 32, 1> v11 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v12 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v2 + v2 * (unsigned) v5 + v2 * (unsigned) v4), v10, v11); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v12, v8); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + gate_dot(ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot_0.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot_0.cpp new file mode 100644 index 000000000..b59045037 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/gate_dot_0.cpp @@ -0,0 +1,317 @@ +// Kernel Function: gate_dot_0 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case 
PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void gate_dot_0(__gm__ float* v1, __gm__ bfloat16_t* v2, __gm__ float* v3, __gm__ float* v4, __gm__ float* v5) { + RoundMode v6 = RoundMode::CAST_ROUND; + unsigned v7 = 0; + const int32_t v8 = 480; + const int32_t v9 = 15; + const int32_t v10 = 448; + const int32_t v11 = 14; + const int32_t v12 = 416; + const int32_t v13 = 13; + const int32_t v14 = 384; + const int32_t v15 = 12; + const int32_t v16 = 352; + const int32_t v17 = 11; + const int32_t v18 = 320; + const int32_t v19 = 10; + const int32_t v20 = 288; + const int32_t v21 = 9; + const int32_t v22 = 256; + const int32_t v23 = 224; + const int32_t v24 = 7; + const int32_t v25 = 192; + const int32_t v26 = 6; + const int32_t v27 = 160; + const int32_t v28 = 5; + const int32_t v29 = 128; + const int32_t v30 = 4; + const int32_t v31 = 96; + const int32_t v32 = 3; + const int32_t v33 = 64; + const int32_t v34 = 2; + const int32_t v35 = 32; + const float v36 = 1.0f; + const float v37 = 0.0f; + const int32_t v38 = 0; + const int32_t v39 = 512; + const int32_t v40 = 8; + const int32_t v41 = 4096; + const int32_t v42 = 16; + const int32_t v43 = 1; + const int64_t v44 = 32896; + const int64_t v45 = 32832; + const int64_t v46 = 32768; + const int64_t v47 = 0; + const int64_t v48 = 82176; + const int64_t v49 = 49408; + const int64_t v50 = 33024; + const int64_t v51 = 32960; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v52 = (size_t) v43; + size_t v53 = (size_t) v40; + size_t v54 = (size_t) v38; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + for (size_t v55 = v54; v55 < v53; v55 += v52) { + int32_t v56 = (int32_t) v55; + Tile v57 = Tile(v43, v42); + uint64_t v58 = (uint64_t) v51; + TASSIGN(v57, v58); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + TEXPANDS(v57, v37); + for (size_t v59 = v54; v59 < v53; v59 += v52) { + int32_t v60 = (int32_t) ((uint32_t) ((int32_t) v59) * (uint32_t) v39); + Tile v61 = Tile(v42, v39); + uint64_t v62 = (uint64_t) v50; + TASSIGN(v61, v62); + pto::Shape<1, 1, 1, 16, 512> v63 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v64 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v65 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v7 + v7 * (unsigned) v41 + (unsigned) v60 * (unsigned) v43), v63, v64); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v61, v65); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v66 = Tile(v42, v39); + uint64_t v67 = (uint64_t) v49; + TASSIGN(v66, v67); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v66, v61, v6); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v68 = Tile(v43, v39); + uint64_t v69 = (uint64_t) v48; + TASSIGN(v68, v69); + pto::Shape<1, 1, 1, 1, 512> v70 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<4096, 4096, 4096, 4096, 1> v71 = pto::Stride<4096, 4096, 4096, 4096, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v72 = GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v3 + (v7 + (unsigned) v56 * (unsigned) v41 + (unsigned) v60 * (unsigned) v43), v70, v71); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v68, 
v72); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v73 = Tile(v42, v39); + uint64_t v74 = (uint64_t) v49; + TASSIGN(v73, v74); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + pipe_barrier(PIPE_V); + TCOLEXPANDMUL(v73, v66, v68); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v75 = Tile(v42, v39); + uint64_t v76 = (uint64_t) v47; + TASSIGN(v75, v76); + Tile v77 = Tile(v42, v43); + uint64_t v78 = (uint64_t) v46; + TASSIGN(v77, v78); + pipe_barrier(PIPE_V); + TROWSUM(v77, v73, v75); + Tile v79 = Tile(v43, v42); + uint64_t v80 = (uint64_t) v46; + TASSIGN(v79, v80); + Tile v81 = Tile(v43, v42); + uint64_t v82 = (uint64_t) v51; + TASSIGN(v81, v82); + pipe_barrier(PIPE_V); + TADD(v81, v57, v79); + }; + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v83 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<16, 16, 16, 16, 1> v84 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v85 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v1 + (v7 + v7 * (unsigned) v42 + v7 * (unsigned) v43), v83, v84); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v85, v57); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + float v86 = v4[v55]; + Tile v87 = Tile(v43, v42); + uint64_t v88 = (uint64_t) v51; + TASSIGN(v87, v88); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(v87, v85); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v89 = Tile(v43, v42); + uint64_t v90 = (uint64_t) v45; + TASSIGN(v89, v90); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TMULS(v89, v87, v37); + Tile v91 = Tile(v43, v42); + uint64_t v92 = (uint64_t) v45; + TASSIGN(v91, v92); + pipe_barrier(PIPE_V); + TMAX(v91, v87, v89); + Tile v93 = Tile(v43, v42); + uint64_t v94 = (uint64_t) v44; + TASSIGN(v93, v94); + TNEG(v93, v87); + Tile v95 = Tile(v43, v42); + uint64_t v96 = (uint64_t) v51; + TASSIGN(v95, v96); + pipe_barrier(PIPE_V); + TMAX(v95, v87, v93); + Tile v97 = Tile(v43, v42); + uint64_t v98 = (uint64_t) v51; + TASSIGN(v97, v98); + pipe_barrier(PIPE_V); + TNEG(v97, v95); + Tile v99 = Tile(v43, v42); + uint64_t v100 = (uint64_t) v51; + TASSIGN(v99, v100); + pipe_barrier(PIPE_V); + TEXP(v99, v97); + Tile v101 = Tile(v43, v42); + uint64_t v102 = (uint64_t) v51; + TASSIGN(v101, v102); + pipe_barrier(PIPE_V); + TADDS(v101, v99, v36); + Tile v103 = Tile(v43, v42); + uint64_t v104 = (uint64_t) v51; + TASSIGN(v103, v104); + pipe_barrier(PIPE_V); + TLOG(v103, v101); + Tile v105 = Tile(v43, v42); + uint64_t v106 = (uint64_t) v51; + TASSIGN(v105, v106); + pipe_barrier(PIPE_V); + TADD(v105, v91, v103); + Tile v107 = Tile(v43, v42); + uint64_t v108 = (uint64_t) v51; + TASSIGN(v107, v108); + pipe_barrier(PIPE_V); + TSQRT(v107, v105); + Tile v109 = Tile(v43, v42); + uint64_t v110 = (uint64_t) v51; + TASSIGN(v109, v110); + pipe_barrier(PIPE_V); + TADDS(v109, v107, v86); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + Tile v111 = Tile(v43, v42); + uint64_t v112 = (uint64_t) v51; + TASSIGN(v111, v112); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float v113 = v111.GetValue(v38); + v5[v55] = v113; + float v114 = v111.GetValue(v43); + int32_t v115 = (int32_t) ((uint32_t) v56 + (uint32_t) v35); + v5[v115] = v114; + float v116 = v111.GetValue(v34); + int32_t v117 = (int32_t) ((uint32_t) v56 + (uint32_t) v33); + v5[v117] = v116; + float v118 = v111.GetValue(v32); + int32_t v119 = (int32_t) ((uint32_t) v56 + (uint32_t) v31); + v5[v119] = v118; + float v120 = v111.GetValue(v30); + int32_t v121 = (int32_t) ((uint32_t) v56 + (uint32_t) v29); + v5[v121] = v120; + float v122 = 
v111.GetValue(v28); + int32_t v123 = (int32_t) ((uint32_t) v56 + (uint32_t) v27); + v5[v123] = v122; + float v124 = v111.GetValue(v26); + int32_t v125 = (int32_t) ((uint32_t) v56 + (uint32_t) v25); + v5[v125] = v124; + float v126 = v111.GetValue(v24); + int32_t v127 = (int32_t) ((uint32_t) v56 + (uint32_t) v23); + v5[v127] = v126; + float v128 = v111.GetValue(v40); + int32_t v129 = (int32_t) ((uint32_t) v56 + (uint32_t) v22); + v5[v129] = v128; + float v130 = v111.GetValue(v21); + int32_t v131 = (int32_t) ((uint32_t) v56 + (uint32_t) v20); + v5[v131] = v130; + float v132 = v111.GetValue(v19); + int32_t v133 = (int32_t) ((uint32_t) v56 + (uint32_t) v18); + v5[v133] = v132; + float v134 = v111.GetValue(v17); + int32_t v135 = (int32_t) ((uint32_t) v56 + (uint32_t) v16); + v5[v135] = v134; + float v136 = v111.GetValue(v15); + int32_t v137 = (int32_t) ((uint32_t) v56 + (uint32_t) v14); + v5[v137] = v136; + float v138 = v111.GetValue(v13); + int32_t v139 = (int32_t) ((uint32_t) v56 + (uint32_t) v12); + v5[v139] = v138; + float v140 = v111.GetValue(v11); + int32_t v141 = (int32_t) ((uint32_t) v56 + (uint32_t) v10); + v5[v141] = v140; + float v142 = v111.GetValue(v9); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + int32_t v143 = (int32_t) ((uint32_t) v56 + (uint32_t) v8); + v5[v143] = v142; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: score_acc_buf_inline117__ssa_v0 + __gm__ Tensor* score_acc_buf_inline117__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* score_acc_buf_inline117__ssa_v0 = reinterpret_cast<__gm__ float*>(score_acc_buf_inline117__ssa_v0_tensor->buffer.addr) + score_acc_buf_inline117__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_norm_bf16_inline111__rv_v2 + __gm__ Tensor* x_norm_bf16_inline111__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ bfloat16_t* x_norm_bf16_inline111__rv_v2 = reinterpret_cast<__gm__ bfloat16_t*>(x_norm_bf16_inline111__rv_v2_tensor->buffer.addr) + x_norm_bf16_inline111__rv_v2_tensor->start_offset; + + // Unpack tensor: gate_w__ssa_v0 + __gm__ Tensor* gate_w__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* gate_w__ssa_v0 = reinterpret_cast<__gm__ float*>(gate_w__ssa_v0_tensor->buffer.addr) + gate_w__ssa_v0_tensor->start_offset; + + // Unpack tensor: gate_bias__ssa_v0 + __gm__ Tensor* gate_bias__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* gate_bias__ssa_v0 = reinterpret_cast<__gm__ float*>(gate_bias__ssa_v0_tensor->buffer.addr) + gate_bias__ssa_v0_tensor->start_offset; + + // Unpack tensor: biased_flat_inline133__ssa_v0 + __gm__ Tensor* biased_flat_inline133__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ float* biased_flat_inline133__ssa_v0 = reinterpret_cast<__gm__ float*>(biased_flat_inline133__ssa_v0_tensor->buffer.addr) + biased_flat_inline133__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + gate_dot_0(score_acc_buf_inline117__ssa_v0, x_norm_bf16_inline111__rv_v2, gate_w__ssa_v0, gate_bias__ssa_v0, biased_flat_inline133__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/hc_post.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/hc_post.cpp new file 
mode 100644 index 000000000..b4900a193 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/hc_post.cpp @@ -0,0 +1,188 @@ +// Kernel Function: hc_post +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void hc_post(__gm__ bfloat16_t* v1, __gm__ float* v2, __gm__ bfloat16_t* v3, __gm__ float* v4, __gm__ bfloat16_t* v5, int32_t v6, int32_t v7) { + RoundMode v8 = RoundMode::CAST_ROUND; + unsigned v9 = 0; + const int32_t v10 = 512; + const int32_t v11 = 8; + const int32_t v12 = 4; + const int32_t v13 = 0; + const int32_t v14 = 4096; + const int32_t v15 = 1; + const int32_t v16 = 16384; + const int32_t v17 = 16; + const int64_t v18 = 3072; + const int64_t v19 = 1024; + const int64_t v20 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v21 = (size_t) v15; + size_t v22 = (size_t) v13; + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + for (size_t v23 = v22; v23 < ((size_t) v17); v23 += v21) { + int32_t v24 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v6 * (uint32_t) v17) + (uint32_t) ((int32_t) v23)); + __gm__ bfloat16_t* v25; + if (v24 < v17) { + float v26 = v2[(int32_t) ((uint32_t) ((int32_t) (uint32_t) v24 * (uint32_t) v12) + (uint32_t) v7)]; + for (size_t v27 = v22; v27 < ((size_t) v11); v27 += v21) { + int32_t v28 = (int32_t) ((uint32_t) ((int32_t) v27) * (uint32_t) v10); + Tile v29 = Tile(v15, v10); + uint64_t v30 = (uint64_t) v20; + TASSIGN(v29, v30); + pto::Shape<1, 1, 1, 1, 512> v31 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<4096, 4096, 4096, 4096, 1> v32 = pto::Stride<4096, 4096, 4096, 4096, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v33 = GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v3 + (v9 + (unsigned) v24 * (unsigned) v14 + (unsigned) v28 * (unsigned) v15), v31, v32); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + TLOAD(v29, v33); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v34 = Tile(v15, v10); + uint64_t v35 = (uint64_t) v19; + TASSIGN(v34, v35); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v34, v29, v8); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v36 = Tile(v15, v10); + uint64_t v37 = (uint64_t) v19; + TASSIGN(v36, v37); + pipe_barrier(PIPE_V); + TMULS(v36, v34, v26); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (size_t v38 = v22; v38 < ((size_t) v12); v38 += v21) { + int32_t v39 = (int32_t) v38; + float v40 = v4[(int32_t) ((uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v24 * (uint32_t) v17) + (uint32_t) ((int32_t) (uint32_t) v39 * (uint32_t) v12)) + (uint32_t) v7)]; + Tile v41 = Tile(v15, v10); + uint64_t v42 = (uint64_t) v20; + TASSIGN(v41, v42); + pto::Shape<1, 1, 1, 1, 512> v43 = pto::Shape<1, 1, 1, 1, 512>(); + 
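+ // Editorial note: per 512-wide chunk (column offset v28) of row t = v24,
+ // this inner loop accumulates the combine
+ //   y[t, v7, :] = post[t*4 + v7] * x[t, :]
+ //               + sum_{k=0..3} comb[t*16 + 4*k + v7] * resid[t, k, :]
+ // The TMULS above seeds the accumulator with post[t*4 + v7] * x[t, :];
+ // v50 is TASSIGNed to the same UB offset (1024) as v36, so each TADD
+ // below accumulates in place across the k (= v39) loop.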
pto::Stride<16384, 16384, 16384, 16384, 1> v44 = pto::Stride<16384, 16384, 16384, 16384, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND> v45 = GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND>(v5 + (v9 + (unsigned) v24 * (unsigned) v16 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v39 * (uint32_t) v14) + (uint32_t) v28) * (unsigned) v15), v43, v44); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v41, v45); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v46 = Tile(v15, v10); + uint64_t v47 = (uint64_t) v18; + TASSIGN(v46, v47); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + pipe_barrier(PIPE_V); + TCVT(v46, v41, v8); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v48 = Tile(v15, v10); + uint64_t v49 = (uint64_t) v18; + TASSIGN(v48, v49); + pipe_barrier(PIPE_V); + TMULS(v48, v46, v40); + Tile v50 = Tile(v15, v10); + uint64_t v51 = (uint64_t) v19; + TASSIGN(v50, v51); + pipe_barrier(PIPE_V); + TADD(v50, v36, v48); + }; + Tile v52 = Tile(v15, v10); + uint64_t v53 = (uint64_t) v20; + TASSIGN(v52, v53); + pipe_barrier(PIPE_V); + TCVT(v52, v36, v8); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 512> v54 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<16384, 16384, 16384, 16384, 1> v55 = pto::Stride<16384, 16384, 16384, 16384, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND> v56 = GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND>(v1 + (v9 + (unsigned) v24 * (unsigned) v16 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v7 * (uint32_t) v14) + (uint32_t) v28) * (unsigned) v15), v54, v55); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v56, v52); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + }; + v25 = v1; + } else { + v25 = v1; + }; + } + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: y_flat_inline7__co_l0_iter_v3 + __gm__ Tensor* y_flat_inline7__co_l0_iter_v3_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* y_flat_inline7__co_l0_iter_v3 = reinterpret_cast<__gm__ bfloat16_t*>(y_flat_inline7__co_l0_iter_v3_tensor->buffer.addr) + y_flat_inline7__co_l0_iter_v3_tensor->start_offset; + + // Unpack tensor: post_flat_inline8__ssa_v0 + __gm__ Tensor* post_flat_inline8__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* post_flat_inline8__ssa_v0 = reinterpret_cast<__gm__ float*>(post_flat_inline8__ssa_v0_tensor->buffer.addr) + post_flat_inline8__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_flat_inline2__ssa_v0 + __gm__ Tensor* x_flat_inline2__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ bfloat16_t* x_flat_inline2__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(x_flat_inline2__ssa_v0_tensor->buffer.addr) + x_flat_inline2__ssa_v0_tensor->start_offset; + + // Unpack tensor: comb_flat_inline9__ssa_v0 + __gm__ Tensor* comb_flat_inline9__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* comb_flat_inline9__ssa_v0 = reinterpret_cast<__gm__ float*>(comb_flat_inline9__ssa_v0_tensor->buffer.addr) + comb_flat_inline9__ssa_v0_tensor->start_offset; + + // Unpack tensor: residual_flat_inline6__ssa_v0 + __gm__ Tensor* residual_flat_inline6__ssa_v0_tensor = 
reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ bfloat16_t* residual_flat_inline6__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(residual_flat_inline6__ssa_v0_tensor->buffer.addr) + residual_flat_inline6__ssa_v0_tensor->start_offset; + + // Unpack scalar: t_inline10__co_idx_v0 + union { uint64_t u64; int64_t val; } t_inline10__co_idx_v0_conv; + t_inline10__co_idx_v0_conv.u64 = args[5]; + int64_t t_inline10__co_idx_v0 = t_inline10__co_idx_v0_conv.val; + + // Unpack scalar: out_h_inline3__idx_v0 + union { uint64_t u64; int64_t val; } out_h_inline3__idx_v0_conv; + out_h_inline3__idx_v0_conv.u64 = args[6]; + int64_t out_h_inline3__idx_v0 = out_h_inline3__idx_v0_conv.val; + + // Forward to ptoas-generated function + hc_post(y_flat_inline7__co_l0_iter_v3, post_flat_inline8__ssa_v0, x_flat_inline2__ssa_v0, comb_flat_inline9__ssa_v0, residual_flat_inline6__ssa_v0, t_inline10__co_idx_v0, out_h_inline3__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/linear.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/linear.cpp new file mode 100644 index 000000000..e81e8a388 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/linear.cpp @@ -0,0 +1,244 @@ +// Kernel Function: linear +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void linear(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4) { + unsigned v5 = 0; + const int32_t v6 = 480; + const int32_t v7 = 15; + const int32_t v8 = 448; + const int32_t v9 = 14; + const int32_t v10 = 416; + const int32_t v11 = 13; + const int32_t v12 = 384; + const int32_t v13 = 12; + const int32_t v14 = 352; + const int32_t v15 = 11; + const int32_t v16 = 320; + const int32_t v17 = 10; + const int32_t v18 = 288; + const int32_t v19 = 9; + const int32_t v20 = 256; + const int32_t v21 = 8; + const int32_t v22 = 224; + const int32_t v23 = 7; + const int32_t v24 = 192; + const int32_t v25 = 6; + const int32_t v26 = 160; + const int32_t v27 = 5; + const int32_t v28 = 128; + const int32_t v29 = 4; + const int32_t v30 = 96; + const int32_t v31 = 3; + const int32_t v32 = 64; + const int32_t v33 = 2; + const int32_t v34 = 32; + const float v35 = 0.0f; + const int32_t v36 = 0; + const int32_t v37 = 512; + const int32_t v38 = 24; + const int32_t v39 = 1; + const int32_t v40 = 16384; + const int32_t v41 = 16; + const int64_t v42 = 0; + const int64_t v43 = 35008; + const int64_t v44 = 32960; + const int64_t v45 = 192; + const int64_t v46 = 128; + const int64_t v47 = 64; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v48 = (size_t) v39; + size_t v49 = (size_t) v36; + Tile v50 = Tile(v39, v41); + uint64_t v51 = (uint64_t) v47; + TASSIGN(v50, v51); + pto::Shape<1, 1, 1, 1, 16> v52 = pto::Shape<1, 1, 1, 1, 
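/*
 * Reviewer's note on the entry ABI, with a hypothetical host-side mirror
 * (the field names come from the unpack code just above, which every
 * kernel_entry in this diff repeats verbatim; the packing side is an
 * illustration only). Tensor arguments travel as pointers to a Tensor
 * record and are rebased via buffer.addr + start_offset; scalar arguments
 * travel as raw 64-bit payloads reinterpreted through a union:
 *
 *   int64_t args[7];
 *   args[0] = reinterpret_cast<int64_t>(&y_tensor);   // Tensor* slot
 *   // ... args[1..4]: the remaining Tensor* slots ...
 *   union { uint64_t u64; int64_t val; } s;
 *   s.val = t_index;                                  // scalar slot
 *   args[5] = (int64_t)s.u64;
 */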
16>(); + pto::Stride<16, 16, 16, 16, 1> v53 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v54 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v3 + (v5 + v5 * (unsigned) v41 + v5 * (unsigned) v39), v52, v53); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v50, v54); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + for (size_t v55 = v49; v55 < ((size_t) v38); v55 += v48) { + int32_t v56 = (int32_t) v55; + Tile v57 = Tile(v39, v41); + uint64_t v58 = (uint64_t) v46; + TASSIGN(v57, v58); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + TEXPANDS(v57, v35); + for (size_t v59 = v49; v59 < ((size_t) v34); v59 += v48) { + int32_t v60 = (int32_t) ((uint32_t) ((int32_t) v59) * (uint32_t) v37); + Tile v61 = Tile(v41, v37); + uint64_t v62 = (uint64_t) v45; + TASSIGN(v61, v62); + pto::Shape<1, 1, 1, 16, 512> v63 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v64 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v65 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v40 + (unsigned) v60 * (unsigned) v39), v63, v64); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v61, v65); + Tile v66 = Tile(v39, v37); + uint64_t v67 = (uint64_t) v44; + TASSIGN(v66, v67); + pto::Shape<1, 1, 1, 1, 512> v68 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<16384, 16384, 16384, 16384, 1> v69 = pto::Stride<16384, 16384, 16384, 16384, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND> v70 = GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND>(v2 + (v5 + (unsigned) v56 * (unsigned) v40 + (unsigned) v60 * (unsigned) v39), v68, v69); + TLOAD(v66, v70); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v71 = Tile(v41, v37); + uint64_t v72 = (uint64_t) v45; + TASSIGN(v71, v72); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCOLEXPANDMUL(v71, v61, v66); + Tile v73 = Tile(v41, v37); + uint64_t v74 = (uint64_t) v43; + TASSIGN(v73, v74); + Tile v75 = Tile(v41, v39); + uint64_t v76 = (uint64_t) v42; + TASSIGN(v75, v76); + pipe_barrier(PIPE_V); + TROWSUM(v75, v71, v73); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v77 = Tile(v39, v41); + uint64_t v78 = (uint64_t) v42; + TASSIGN(v77, v78); + Tile v79 = Tile(v39, v41); + uint64_t v80 = (uint64_t) v46; + TASSIGN(v79, v80); + pipe_barrier(PIPE_V); + TADD(v79, v57, v77); + }; + Tile v81 = Tile(v39, v41); + uint64_t v82 = (uint64_t) v46; + TASSIGN(v81, v82); + pipe_barrier(PIPE_V); + TMUL(v81, v57, v50); + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + Tile v83 = Tile(v39, v41); + uint64_t v84 = (uint64_t) v46; + TASSIGN(v83, v84); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float v85 = v83.GetValue(v36); + v4[v55] = v85; + float v86 = v83.GetValue(v39); + int32_t v87 = (int32_t) ((uint32_t) v56 + (uint32_t) v34); + v4[v87] = v86; + float v88 = v83.GetValue(v33); + int32_t v89 = (int32_t) ((uint32_t) v56 + (uint32_t) v32); + v4[v89] = v88; + float v90 = v83.GetValue(v31); + int32_t v91 = (int32_t) ((uint32_t) v56 + (uint32_t) v30); + v4[v91] = v90; + float v92 = v83.GetValue(v29); + int32_t v93 = (int32_t) ((uint32_t) v56 + (uint32_t) v28); + v4[v93] = v92; + float v94 = v83.GetValue(v27); + int32_t v95 = (int32_t) ((uint32_t) v56 + (uint32_t) v26); + v4[v95] = v94; + float v96 = v83.GetValue(v25); + int32_t v97 = (int32_t) ((uint32_t) 
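/*
 * Reviewer's sketch for linear (assumed semantics). The accumulation above
 * (TCOLEXPANDMUL then TROWSUM over 32 chunks of 512, a final TMUL by the
 * 1x16 inv_rms tile, and the stride-32 GetValue scatter running above and
 * below this note) appears to compute scaled dot products; names follow
 * the kernel_entry comments:
 *
 *   for (int j = 0; j < 24; ++j)        // hc_ffn rows
 *     for (int i = 0; i < 16; ++i) {    // x rows
 *       float acc = 0.f;
 *       for (int d = 0; d < 16384; ++d)
 *         acc += x[i * 16384 + d] * hc_ffn[j * 16384 + d];
 *       mixes[i * 32 + j] = acc * inv_rms[i];   // stride-32 scatter,
 *     }                                         // slots 24..31 unused
 */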
v56 + (uint32_t) v24); + v4[v97] = v96; + float v98 = v83.GetValue(v23); + int32_t v99 = (int32_t) ((uint32_t) v56 + (uint32_t) v22); + v4[v99] = v98; + float v100 = v83.GetValue(v21); + int32_t v101 = (int32_t) ((uint32_t) v56 + (uint32_t) v20); + v4[v101] = v100; + float v102 = v83.GetValue(v19); + int32_t v103 = (int32_t) ((uint32_t) v56 + (uint32_t) v18); + v4[v103] = v102; + float v104 = v83.GetValue(v17); + int32_t v105 = (int32_t) ((uint32_t) v56 + (uint32_t) v16); + v4[v105] = v104; + float v106 = v83.GetValue(v15); + int32_t v107 = (int32_t) ((uint32_t) v56 + (uint32_t) v14); + v4[v107] = v106; + float v108 = v83.GetValue(v13); + int32_t v109 = (int32_t) ((uint32_t) v56 + (uint32_t) v12); + v4[v109] = v108; + float v110 = v83.GetValue(v11); + int32_t v111 = (int32_t) ((uint32_t) v56 + (uint32_t) v10); + v4[v111] = v110; + float v112 = v83.GetValue(v9); + int32_t v113 = (int32_t) ((uint32_t) v56 + (uint32_t) v8); + v4[v113] = v112; + float v114 = v83.GetValue(v7); + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + int32_t v115 = (int32_t) ((uint32_t) v56 + (uint32_t) v6); + v4[v115] = v114; + } + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_flat_fp32_inline49__rv_v2 + __gm__ Tensor* x_flat_fp32_inline49__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* x_flat_fp32_inline49__rv_v2 = reinterpret_cast<__gm__ float*>(x_flat_fp32_inline49__rv_v2_tensor->buffer.addr) + x_flat_fp32_inline49__rv_v2_tensor->start_offset; + + // Unpack tensor: hc_ffn_fn__ssa_v0 + __gm__ Tensor* hc_ffn_fn__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* hc_ffn_fn__ssa_v0 = reinterpret_cast<__gm__ float*>(hc_ffn_fn__ssa_v0_tensor->buffer.addr) + hc_ffn_fn__ssa_v0_tensor->start_offset; + + // Unpack tensor: inv_rms_inline75__ssa_v1 + __gm__ Tensor* inv_rms_inline75__ssa_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* inv_rms_inline75__ssa_v1 = reinterpret_cast<__gm__ float*>(inv_rms_inline75__ssa_v1_tensor->buffer.addr) + inv_rms_inline75__ssa_v1_tensor->start_offset; + + // Unpack tensor: mixes_flat_inline45__ssa_v0 + __gm__ Tensor* mixes_flat_inline45__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* mixes_flat_inline45__ssa_v0 = reinterpret_cast<__gm__ float*>(mixes_flat_inline45__ssa_v0_tensor->buffer.addr) + mixes_flat_inline45__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + linear(x_flat_fp32_inline49__rv_v2, hc_ffn_fn__ssa_v0, inv_rms_inline75__ssa_v1, mixes_flat_inline45__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/local_expert.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/local_expert.cpp deleted file mode 100644 index 2d519b9b2..000000000 --- a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/local_expert.cpp +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) PyPTO Contributors. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. 
- * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ----------------------------------------------------------------------------------------------------------- - */ -/** - * Local-expert kernel — placeholder for the production moe_expert in this - * 2-rank demo. Replaced by the real moe_expert without changing the - * surrounding orchestration / combine wiring. - * - * Behavior (one element-wise multiply per row): - * recv_y[e, s, :] = recv_x[e, s, :] * recv_w[e, s] for s in [0, recv_count[e]) - * - * recv_x : BF16 [L, R, D] (dispatch OUTPUT_EXISTING, reused as INPUT) - * recv_w : FP32 [L, R] (dispatch OUTPUT_EXISTING, reused as INPUT) - * recv_count : INT32 [L, 1] (dispatch OUTPUT_EXISTING, reused as INPUT) - * recv_y : BF16 [L, R, D] (this kernel's OUTPUT_EXISTING) - * - * Per-expert n_rows = recv_count[e] decides how many rows to process. We - * skip padding rows entirely — they stay whatever recv_y was previously - * initialized to, but combine reads them only via pub_counts-driven slabs - * that never reach into padding, so the value doesn't matter. - * - * Pure local — no CommRemotePtr, no signals, no scratch. BF16 round-trip - * happens once per row at `cast(x*w, bf16)`. - */ - -#ifndef __gm__ -#define __gm__ -#endif - -#ifndef __aicore__ -#define __aicore__ [aicore] -#endif - -#include - -#include -#include "platform_comm/comm_context.h" -#include "tensor.h" - -using namespace pto; - -// Demo dimensions — must match dispatch.cpp / main.py. -static constexpr int L = 4; -static constexpr int R = 32; -static constexpr int D = 64; - -extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { - __gm__ Tensor *recv_x_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); - __gm__ Tensor *recv_w_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); - __gm__ Tensor *recv_count_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); - __gm__ Tensor *recv_y_tensor = reinterpret_cast<__gm__ Tensor *>(args[3]); - __gm__ CommContext *comm_ctx = reinterpret_cast<__gm__ CommContext *>(args[4]); - (void)comm_ctx; // unused; kept for ABI symmetry with dispatch / combine - - __gm__ bfloat16_t *recv_x = - reinterpret_cast<__gm__ bfloat16_t *>(recv_x_tensor->buffer.addr) + recv_x_tensor->start_offset; - __gm__ float *recv_w = reinterpret_cast<__gm__ float *>(recv_w_tensor->buffer.addr) + recv_w_tensor->start_offset; - __gm__ int32_t *recv_count = - reinterpret_cast<__gm__ int32_t *>(recv_count_tensor->buffer.addr) + recv_count_tensor->start_offset; - __gm__ bfloat16_t *recv_y = - reinterpret_cast<__gm__ bfloat16_t *>(recv_y_tensor->buffer.addr) + recv_y_tensor->start_offset; - - using XShape = Shape<1, 1, 1, 1, D>; - using XStride = Stride; - using XGlobalBF = GlobalTensor; - using XTileBF = Tile; - using XTileF = Tile; - - XTileBF x_bf; - XTileF x_f; - TASSIGN(x_bf, 0x0); - TASSIGN(x_f, 0x10000); - - for (int e = 0; e < L; ++e) { - int n_rows = recv_count[e]; - for (int slot = 0; slot < n_rows; ++slot) { - int row = e * R + slot; - float w = recv_w[row]; - - __gm__ bfloat16_t *x_src = recv_x + row * D; - __gm__ bfloat16_t *y_dst = recv_y + row * D; - XGlobalBF x_g(x_src); - XGlobalBF y_g(y_dst); - - TLOAD(x_bf, x_g); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - TCVT(x_f, x_bf, 
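/*
 * Reviewer's reference for the removal: the tile pipeline being deleted
 * here reduces to the scalar loop below. This only restates the formula in
 * the deleted header comment; L, R, D are the demo constants declared
 * above:
 *
 *   for (int e = 0; e < L; ++e)
 *     for (int s = 0; s < recv_count[e]; ++s)
 *       for (int d = 0; d < D; ++d) {
 *         int row = e * R + s;
 *         recv_y[row * D + d] =
 *             (bfloat16_t)((float)recv_x[row * D + d] * recv_w[row]);
 *       }
 */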
RoundMode::CAST_ROUND); - pipe_barrier(PIPE_V); - TMULS(x_f, x_f, w); - pipe_barrier(PIPE_V); - TCVT(x_bf, x_f, RoundMode::CAST_ROUND); - pipe_barrier(PIPE_V); - - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - TSTORE(y_g, x_bf); - set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); - } - } - pipe_barrier(PIPE_ALL); -} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/mix_x.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/mix_x.cpp new file mode 100644 index 000000000..46e1605c4 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/mix_x.cpp @@ -0,0 +1,154 @@ +// Kernel Function: mix_x +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void mix_x(__gm__ bfloat16_t* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4) { + RoundMode v5 = RoundMode::CAST_RINT; + unsigned v6 = 0; + const int32_t v7 = 4; + const float v8 = 0.0f; + const int32_t v9 = 512; + const int32_t v10 = 8; + const int32_t v11 = 0; + const int32_t v12 = 16384; + const int32_t v13 = 1; + const int32_t v14 = 4096; + const int32_t v15 = 16; + const int64_t v16 = 4096; + const int64_t v17 = 2048; + const int64_t v18 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v19 = (size_t) v13; + size_t v20 = (size_t) v11; + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + for (size_t v21 = v20; v21 < ((size_t) v15); v21 += v19) { + int32_t v22 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v4 * (uint32_t) v15) + (uint32_t) ((int32_t) v21)); + __gm__ bfloat16_t* v23; + if (v22 < v15) { + for (size_t v24 = v20; v24 < ((size_t) v10); v24 += v19) { + int32_t v25 = (int32_t) ((uint32_t) ((int32_t) v24) * (uint32_t) v9); + Tile v26 = Tile(v13, v9); + uint64_t v27 = (uint64_t) v18; + TASSIGN(v26, v27); + pipe_barrier(PIPE_V); + TEXPANDS(v26, v8); + for (size_t v28 = v20; v28 < ((size_t) v7); v28 += v19) { + int32_t v29 = (int32_t) v28; + float v30 = v2[(int32_t) ((uint32_t) ((int32_t) (uint32_t) v22 * (uint32_t) v10) + (uint32_t) v29)]; + Tile v31 = Tile(v13, v9); + uint64_t v32 = (uint64_t) v17; + TASSIGN(v31, v32); + pto::Shape<1, 1, 1, 1, 512> v33 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<16384, 16384, 16384, 16384, 1> v34 = pto::Stride<16384, 16384, 16384, 16384, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND> v35 = GlobalTensor, pto::Stride<16384, 16384, 16384, 16384, 1>, pto::Layout::ND>(v3 + (v6 + (unsigned) v22 * (unsigned) v12 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v29 * (uint32_t) v14) + (uint32_t) v25) * (unsigned) v13), v33, v34); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v31, v35); + set_flag(PIPE_MTE2, PIPE_V, 
EVENT_ID0); + Tile v36 = Tile(v13, v9); + uint64_t v37 = (uint64_t) v17; + TASSIGN(v36, v37); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v36, v31, v30); + Tile v38 = Tile(v13, v9); + uint64_t v39 = (uint64_t) v18; + TASSIGN(v38, v39); + pipe_barrier(PIPE_V); + TADD(v38, v26, v36); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + }; + Tile v40 = Tile(v13, v9); + uint64_t v41 = (uint64_t) v16; + TASSIGN(v40, v41); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v40, v26, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 512> v42 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<4096, 4096, 4096, 4096, 1> v43 = pto::Stride<4096, 4096, 4096, 4096, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND> v44 = GlobalTensor, pto::Stride<4096, 4096, 4096, 4096, 1>, pto::Layout::ND>(v1 + (v6 + (unsigned) v22 * (unsigned) v14 + (unsigned) v25 * (unsigned) v13), v42, v43); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v44, v40); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + }; + v23 = v1; + } else { + v23 = v1; + }; + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_mixed_view_inline55__co_l0_iter_v1 + __gm__ Tensor* x_mixed_view_inline55__co_l0_iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* x_mixed_view_inline55__co_l0_iter_v1 = reinterpret_cast<__gm__ bfloat16_t*>(x_mixed_view_inline55__co_l0_iter_v1_tensor->buffer.addr) + x_mixed_view_inline55__co_l0_iter_v1_tensor->start_offset; + + // Unpack tensor: pre_val_flat_inline44__ssa_v0 + __gm__ Tensor* pre_val_flat_inline44__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* pre_val_flat_inline44__ssa_v0 = reinterpret_cast<__gm__ float*>(pre_val_flat_inline44__ssa_v0_tensor->buffer.addr) + pre_val_flat_inline44__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_flat_fp32_inline49__rv_v2 + __gm__ Tensor* x_flat_fp32_inline49__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* x_flat_fp32_inline49__rv_v2 = reinterpret_cast<__gm__ float*>(x_flat_fp32_inline49__rv_v2_tensor->buffer.addr) + x_flat_fp32_inline49__rv_v2_tensor->start_offset; + + // Unpack scalar: t_inline5__co_idx_v0 + union { uint64_t u64; int64_t val; } t_inline5__co_idx_v0_conv; + t_inline5__co_idx_v0_conv.u64 = args[3]; + int64_t t_inline5__co_idx_v0 = t_inline5__co_idx_v0_conv.val; + + // Forward to ptoas-generated function + mix_x(x_mixed_view_inline55__co_l0_iter_v1, pre_val_flat_inline44__ssa_v0, x_flat_fp32_inline49__rv_v2, t_inline5__co_idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/recv_x_q.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/recv_x_q.cpp new file mode 100644 index 000000000..bbe47ab23 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/recv_x_q.cpp @@ -0,0 +1,241 @@ +// Kernel Function: recv_x_q +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + 
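/*
 * Reviewer's sketch for mix_x above (assumed semantics; names follow its
 * kernel_entry comments). The guarded loop nest appears to mix four
 * 4096-wide fp32 segments into one bf16 segment per row, reading weights
 * from an 8-wide pre_val row but using only the first four entries:
 *
 *   for (int i = 0; i < 16; ++i) {
 *     int row = t * 16 + i;
 *     if (row >= 16) continue;               // generated guard
 *     for (int d = 0; d < 4096; ++d) {
 *       float acc = 0.f;
 *       for (int k = 0; k < 4; ++k)
 *         acc += pre_val[row * 8 + k] * x[row * 16384 + k * 4096 + d];
 *       x_mixed[row * 4096 + d] = (bfloat16_t)acc;   // CAST_RINT
 *     }
 *   }
 */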
kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void recv_x_q(__gm__ bfloat16_t* v1, __gm__ int8_t* v2, __gm__ float* v3, int32_t v4, int32_t v5) { + RoundMode v6 = RoundMode::CAST_TRUNC; + RoundMode v7 = RoundMode::CAST_ROUND; + unsigned v8 = 0; + const int32_t v9 = 131072; + const float v10 = 127.0f; + const int32_t v11 = 256; + const int32_t v12 = 0; + const float v13 = 9.99999974E-5f; + const int32_t v14 = 16; + const int32_t v15 = 1; + const int32_t v16 = 4096; + const int64_t v17 = 24640; + const int64_t v18 = 16448; + const int64_t v19 = 64; + const int64_t v20 = 0; + const int64_t v21 = 69760; + const int64_t v22 = 53376; + const int64_t v23 = 36992; + const int64_t v24 = 28800; + const int64_t v25 = 28736; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v26 = (size_t) v16; + size_t v27 = (size_t) v12; + size_t v28 = (size_t) v11; + Tile v29 = Tile(v15, v14); + uint64_t v30 = (uint64_t) v25; + TASSIGN(v29, v30); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TEXPANDS(v29, v13); + for (size_t v31 = v27; v31 < v26; v31 += v28) { + Tile v32 = Tile(v14, v11); + uint64_t v33 = (uint64_t) v24; + TASSIGN(v32, v33); + pto::Shape<1, 1, 1, 16, 256> v34 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<131072, 131072, 131072, 4096, 1> v35 = pto::Stride<131072, 131072, 131072, 4096, 1>(); + GlobalTensor, pto::Stride<131072, 131072, 131072, 4096, 1>, pto::Layout::ND> v36 = GlobalTensor, pto::Stride<131072, 131072, 131072, 4096, 1>, pto::Layout::ND>(v1 + ((v8 + (unsigned) v4 * (unsigned) v9) + (unsigned) v5 * (unsigned) v16 + (unsigned) ((int32_t) v31) * (unsigned) v15), v34, v35); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v32, v36); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v37 = Tile(v14, v11); + uint64_t v38 = (uint64_t) v23; + TASSIGN(v37, v38); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v37, v32, v7); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v39 = Tile(v14, v11); + uint64_t v40 = (uint64_t) v23; + TASSIGN(v39, v40); + Tile v41 = Tile(v14, v11); + uint64_t v42 = (uint64_t) v22; + TASSIGN(v41, v42); + pipe_barrier(PIPE_V); + TNEG(v41, v39); + Tile v43 = Tile(v14, v11); + uint64_t v44 = (uint64_t) v23; + TASSIGN(v43, v44); + pipe_barrier(PIPE_V); + TMAX(v43, v39, v41); + Tile v45 = Tile(v14, v11); + uint64_t v46 = (uint64_t) v22; + TASSIGN(v45, v46); + Tile v47 = Tile(v14, v15); + uint64_t v48 = (uint64_t) v21; + TASSIGN(v47, v48); + pipe_barrier(PIPE_V); + TROWMAX(v47, v43, v45); + Tile v49 = Tile(v15, v14); + uint64_t v50 = (uint64_t) v21; + TASSIGN(v49, v50); + Tile v51 = Tile(v15, v14); + uint64_t v52 = (uint64_t) v25; + TASSIGN(v51, v52); + pipe_barrier(PIPE_V); + TMAX(v51, v29, v49); + } + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v53 = Tile(v15, v14); + uint64_t v54 = (uint64_t) v20; + TASSIGN(v53, v54); + TEXPANDS(v53, v10); + Tile v55 = Tile(v15, v14); + uint64_t v56 = (uint64_t) v25; + TASSIGN(v55, v56); + pipe_barrier(PIPE_V); + TDIV(v55, v53, v29); + Tile v57 = Tile(v15, v14); + uint64_t v58 = (uint64_t) v20; + TASSIGN(v57, v58); + 
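/*
 * Reviewer's note on the synchronization idiom used throughout these
 * generated kernels: set_flag(SRC, DST, EVENT_IDn) posts an event after
 * the preceding SRC-pipe instruction, and the matching
 * wait_flag(SRC, DST, EVENT_IDn) stalls the DST pipe until it arrives,
 * forming producer/consumer pairs such as
 *
 *   TLOAD(tile, gsrc);                        // MTE2: GM -> local buffer
 *   set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);   // producer posts
 *   wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);  // vector pipe blocks
 *   TMULS(tile, tile, s);                     // V: now safe to consume
 *
 * pipe_barrier(PIPE_X) orders instructions within a single pipe, and the
 * ptoas_auto_sync_tail(kBarrierAll) epilogue drains every pipe before
 * kernel exit. Flags set before a loop are re-waited inside it, which is
 * why each kernel ends with a trailing block of wait_flag calls.
 */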
pipe_barrier(PIPE_V); + TRECIP(v57, v55); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + Tile v59 = Tile(v14, v15); + uint64_t v60 = (uint64_t) v20; + TASSIGN(v59, v60); + Tile v61 = Tile(v14, v15); + uint64_t v62 = (uint64_t) v25; + TASSIGN(v61, v62); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + for (size_t v63 = v27; v63 < v26; v63 += v28) { + int32_t v64 = (int32_t) v63; + Tile v65 = Tile(v14, v11); + uint64_t v66 = (uint64_t) v24; + TASSIGN(v65, v66); + pto::Shape<1, 1, 1, 16, 256> v67 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<131072, 131072, 131072, 4096, 1> v68 = pto::Stride<131072, 131072, 131072, 4096, 1>(); + GlobalTensor, pto::Stride<131072, 131072, 131072, 4096, 1>, pto::Layout::ND> v69 = GlobalTensor, pto::Stride<131072, 131072, 131072, 4096, 1>, pto::Layout::ND>(v1 + ((v8 + (unsigned) v4 * (unsigned) v9) + (unsigned) v5 * (unsigned) v16 + (unsigned) v64 * (unsigned) v15), v67, v68); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v65, v69); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v70 = Tile(v14, v11); + uint64_t v71 = (uint64_t) v23; + TASSIGN(v70, v71); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCVT(v70, v65, v7); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + Tile v72 = Tile(v14, v11); + uint64_t v73 = (uint64_t) v23; + TASSIGN(v72, v73); + Tile v74 = Tile(v14, v11); + uint64_t v75 = (uint64_t) v23; + TASSIGN(v74, v75); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v74, v72, v61); + Tile v76 = Tile(v14, v11); + uint64_t v77 = (uint64_t) v19; + TASSIGN(v76, v77); + pipe_barrier(PIPE_V); + TCVT(v76, v74, v7); + Tile v78 = Tile(v14, v11); + uint64_t v79 = (uint64_t) v18; + TASSIGN(v78, v79); + pipe_barrier(PIPE_V); + TCVT(v78, v76, v7); + Tile v80 = Tile(v14, v11); + uint64_t v81 = (uint64_t) v17; + TASSIGN(v80, v81); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v80, v78, v6); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v82 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v83 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v84 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v8 + v8 * (unsigned) v16 + (unsigned) v64 * (unsigned) v15), v82, v83); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v84, v80); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + pto::Shape<1, 1, 1, 16, 1> v85 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v86 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v87 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v8 + v8 * (unsigned) v15 + v8 * (unsigned) v14), v85, v86); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v87, v59); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: recv_x__ssa_v0 + __gm__ Tensor* recv_x__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* recv_x__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(recv_x__ssa_v0_tensor->buffer.addr) + recv_x__ssa_v0_tensor->start_offset; + + // Unpack tensor: recv_x_tile_i8_inline75__ssa_v0 + __gm__ Tensor* recv_x_tile_i8_inline75__ssa_v0_tensor = 
reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* recv_x_tile_i8_inline75__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(recv_x_tile_i8_inline75__ssa_v0_tensor->buffer.addr) + recv_x_tile_i8_inline75__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: local_i_inline67__idx_v0 + union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; + local_i_inline67__idx_v0_conv.u64 = args[3]; + int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; + + // Unpack scalar: t0_inline47__ssa_v0 + union { uint64_t u64; int64_t val; } t0_inline47__ssa_v0_conv; + t0_inline47__ssa_v0_conv.u64 = args[4]; + int64_t t0_inline47__ssa_v0 = t0_inline47__ssa_v0_conv.val; + + // Forward to ptoas-generated function + recv_x_q(recv_x__ssa_v0, recv_x_tile_i8_inline75__ssa_v0, ret0__out, local_i_inline67__idx_v0, t0_inline47__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/rms.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/rms.cpp new file mode 100644 index 000000000..3ae1379a6 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/rms.cpp @@ -0,0 +1,245 @@ +// Kernel Function: rms +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void rms(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 0; + const float v4 = 9.99999997E-7f; + const float v5 = 6.10351563E-5f; + const int32_t v6 = 1536; + const int32_t v7 = 1024; + const int32_t v8 = 512; + const int32_t v9 = 4; + const int32_t v10 = 32; + const int32_t v11 = 0; + const float v12 = 0.0f; + const int32_t v13 = 1; + const int32_t v14 = 16384; + const int32_t v15 = 16; + const int64_t v16 = 0; + const int64_t v17 = 163968; + const int64_t v18 = 131200; + const int64_t v19 = 98432; + const int64_t v20 = 65664; + const int64_t v21 = 32896; + const int64_t v22 = 128; + const int64_t v23 = 64; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v24 = Tile(v13, v15); + uint64_t v25 = (uint64_t) v23; + TASSIGN(v24, v25); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TEXPANDS(v24, v12); + for (size_t v26 = (size_t) v11; v26 < ((size_t) v10); v26 += (size_t) v9) { + int32_t v27 = (int32_t) ((uint32_t) ((int32_t) v26) * (uint32_t) v8); + Tile v28 = Tile(v15, v8); + uint64_t v29 = (uint64_t) v22; + TASSIGN(v28, v29); + pto::Shape<1, 1, 1, 16, 512> v30 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v31 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v32 = 
GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v3 + v3 * (unsigned) v14 + (unsigned) v27 * (unsigned) v13), v30, v31); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v28, v32); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v33 = Tile(v15, v8); + uint64_t v34 = (uint64_t) v21; + TASSIGN(v33, v34); + pto::Shape<1, 1, 1, 16, 512> v35 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v36 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v37 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v3 + v3 * (unsigned) v14 + (unsigned) ((int32_t) (uint32_t) v27 + (uint32_t) v8) * (unsigned) v13), v35, v36); + TLOAD(v33, v37); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v38 = Tile(v15, v8); + uint64_t v39 = (uint64_t) v20; + TASSIGN(v38, v39); + pto::Shape<1, 1, 1, 16, 512> v40 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v41 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v42 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v3 + v3 * (unsigned) v14 + (unsigned) ((int32_t) (uint32_t) v27 + (uint32_t) v7) * (unsigned) v13), v40, v41); + TLOAD(v38, v42); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v43 = Tile(v15, v8); + uint64_t v44 = (uint64_t) v19; + TASSIGN(v43, v44); + pto::Shape<1, 1, 1, 16, 512> v45 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<262144, 262144, 262144, 16384, 1> v46 = pto::Stride<262144, 262144, 262144, 16384, 1>(); + GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND> v47 = GlobalTensor, pto::Stride<262144, 262144, 262144, 16384, 1>, pto::Layout::ND>(v1 + (v3 + v3 * (unsigned) v14 + (unsigned) ((int32_t) (uint32_t) v27 + (uint32_t) v6) * (unsigned) v13), v45, v46); + TLOAD(v43, v47); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + Tile v48 = Tile(v15, v8); + uint64_t v49 = (uint64_t) v22; + TASSIGN(v48, v49); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMUL(v48, v28, v28); + Tile v50 = Tile(v15, v8); + uint64_t v51 = (uint64_t) v18; + TASSIGN(v50, v51); + Tile v52 = Tile(v15, v13); + uint64_t v53 = (uint64_t) v17; + TASSIGN(v52, v53); + pipe_barrier(PIPE_V); + TROWSUM(v52, v48, v50); + Tile v54 = Tile(v13, v15); + uint64_t v55 = (uint64_t) v17; + TASSIGN(v54, v55); + Tile v56 = Tile(v13, v15); + uint64_t v57 = (uint64_t) v16; + TASSIGN(v56, v57); + pipe_barrier(PIPE_V); + TADD(v56, v24, v54); + Tile v58 = Tile(v15, v8); + uint64_t v59 = (uint64_t) v22; + TASSIGN(v58, v59); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMUL(v58, v33, v33); + Tile v60 = Tile(v15, v8); + uint64_t v61 = (uint64_t) v21; + TASSIGN(v60, v61); + Tile v62 = Tile(v15, v13); + uint64_t v63 = (uint64_t) v17; + TASSIGN(v62, v63); + pipe_barrier(PIPE_V); + TROWSUM(v62, v58, v60); + Tile v64 = Tile(v13, v15); + uint64_t v65 = (uint64_t) v17; + TASSIGN(v64, v65); + Tile v66 = Tile(v13, v15); + uint64_t v67 = (uint64_t) v16; + TASSIGN(v66, v67); + pipe_barrier(PIPE_V); + TADD(v66, v56, v64); + Tile v68 = Tile(v15, v8); + uint64_t v69 = (uint64_t) v22; + TASSIGN(v68, v69); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TMUL(v68, v38, v38); + Tile v70 = Tile(v15, v8); + uint64_t v71 = (uint64_t) v21; + TASSIGN(v70, v71); + Tile v72 = Tile(v15, v13); + uint64_t v73 = (uint64_t) v17; + TASSIGN(v72, v73); + pipe_barrier(PIPE_V); + TROWSUM(v72, 
v68, v70); + Tile v74 = Tile(v13, v15); + uint64_t v75 = (uint64_t) v17; + TASSIGN(v74, v75); + Tile v76 = Tile(v13, v15); + uint64_t v77 = (uint64_t) v16; + TASSIGN(v76, v77); + pipe_barrier(PIPE_V); + TADD(v76, v66, v74); + Tile v78 = Tile(v15, v8); + uint64_t v79 = (uint64_t) v22; + TASSIGN(v78, v79); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TMUL(v78, v43, v43); + Tile v80 = Tile(v15, v8); + uint64_t v81 = (uint64_t) v21; + TASSIGN(v80, v81); + Tile v82 = Tile(v15, v13); + uint64_t v83 = (uint64_t) v17; + TASSIGN(v82, v83); + pipe_barrier(PIPE_V); + TROWSUM(v82, v78, v80); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v84 = Tile(v13, v15); + uint64_t v85 = (uint64_t) v17; + TASSIGN(v84, v85); + Tile v86 = Tile(v13, v15); + uint64_t v87 = (uint64_t) v23; + TASSIGN(v86, v87); + pipe_barrier(PIPE_V); + TADD(v86, v76, v84); + } + Tile v88 = Tile(v13, v15); + uint64_t v89 = (uint64_t) v23; + TASSIGN(v88, v89); + pipe_barrier(PIPE_V); + TMULS(v88, v24, v5); + Tile v90 = Tile(v13, v15); + uint64_t v91 = (uint64_t) v23; + TASSIGN(v90, v91); + pipe_barrier(PIPE_V); + TADDS(v90, v88, v4); + Tile v92 = Tile(v13, v15); + uint64_t v93 = (uint64_t) v23; + TASSIGN(v92, v93); + pipe_barrier(PIPE_V); + TSQRT(v92, v90); + Tile v94 = Tile(v13, v15); + uint64_t v95 = (uint64_t) v16; + TASSIGN(v94, v95); + pipe_barrier(PIPE_V); + TRECIP(v94, v92); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v96 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<16, 16, 16, 16, 1> v97 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v98 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v2 + (v3 + v3 * (unsigned) v15 + v3 * (unsigned) v13), v96, v97); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v98, v94); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_flat_fp32_inline49__rv_v2 + __gm__ Tensor* x_flat_fp32_inline49__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* x_flat_fp32_inline49__rv_v2 = reinterpret_cast<__gm__ float*>(x_flat_fp32_inline49__rv_v2_tensor->buffer.addr) + x_flat_fp32_inline49__rv_v2_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + rms(x_flat_fp32_inline49__rv_v2, ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_extract_top2.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_extract_top2.cpp new file mode 100644 index 000000000..1a1fe0447 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_extract_top2.cpp @@ -0,0 +1,796 @@ +// Kernel Function: route_extract_top2 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void 
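/*
 * Reviewer's sketch for rms above. The constants pin the shape down:
 * 8 outer iterations x 4 tiles x 512 columns = 16384 per row,
 * 6.10351563e-5 = 1/16384, eps = 9.99999997e-7, so the kernel accumulates
 * row-wise sums of squares and emits the inverse RMS:
 *
 *   for (int i = 0; i < 16; ++i) {
 *     float ss = 0.f;
 *     for (int d = 0; d < 16384; ++d)
 *       ss += x[i * 16384 + d] * x[i * 16384 + d];
 *     inv_rms[i] = 1.f / sqrtf(ss * 6.10351563e-5f + 9.99999997e-7f);
 *   }
 *
 * (x = x_flat_fp32..., inv_rms = ret0__out, per the kernel_entry names.)
 */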
ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void route_extract_top2(__gm__ float* v1, __gm__ float* v2, __gm__ int32_t* v3) { + unsigned v4 = 15; + unsigned v5 = 14; + unsigned v6 = 13; + unsigned v7 = 12; + unsigned v8 = 11; + unsigned v9 = 10; + unsigned v10 = 9; + unsigned v11 = 8; + unsigned v12 = 7; + unsigned v13 = 6; + unsigned v14 = 5; + unsigned v15 = 4; + unsigned v16 = 3; + unsigned v17 = 2; + unsigned v18 = 1; + unsigned v19 = 0; + const int32_t v20 = 2; + const float v21 = 0.0f; + const int32_t v22 = 64; + const int32_t v23 = 1; + const int32_t v24 = 32; + const int32_t v25 = 16; + const int64_t v26 = 2304; + const int64_t v27 = 2240; + const int64_t v28 = 2176; + const int64_t v29 = 2048; + const int64_t v30 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v31 = Tile(v25, v24); + uint64_t v32 = (uint64_t) v30; + TASSIGN(v31, v32); + TEXPANDS(v31, v21); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 32> v33 = pto::Shape<1, 1, 1, 16, 32>(); + pto::Stride<512, 512, 512, 32, 1> v34 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v35 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v19 + v19 * (unsigned) v24 + v19 * (unsigned) v23), v33, v34); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v35, v31); + Tile v36 = Tile(v23, v24); + uint64_t v37 = (uint64_t) v29; + TASSIGN(v36, v37); + pto::Shape<1, 1, 1, 1, 32> v38 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v39 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v40 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v19 * (unsigned) v22 + v19 * (unsigned) v23), v38, v39); + TLOAD(v36, v40); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v41 = Tile(v23, v25); + uint64_t v42 = (uint64_t) v28; + TASSIGN(v41, v42); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P0101>(v41, v36); + Tile v43 = Tile(v23, v25); + uint64_t v44 = (uint64_t) v27; + TASSIGN(v43, v44); + TGATHER, Tile, MaskPattern::P1010>(v43, v36); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v45 = Tile(v23, v20); + uint64_t v46 = (uint64_t) v28; + TASSIGN(v45, v46); + pipe_barrier(PIPE_V); + TMOV(v45, v41); + v45.SetValidShape(v23, v20); + Tile v47 = Tile(v23, v25); + uint64_t v48 = (uint64_t) v26; + TASSIGN(v47, v48); + pipe_barrier(PIPE_V); + TFILLPAD(v47, v45); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pto::Shape<1, 1, 1, 1, 16> v49 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v50 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v51 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v19 * (unsigned) v24 + v19 * (unsigned) v23), v49, v50); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pipe_barrier(PIPE_MTE3); + TSTORE(v51, v47); + pto::Shape<1, 1, 1, 1, 16> v52 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v53 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v54 = GlobalTensor, pto::Stride<32, 32, 32, 
32, 1>, pto::Layout::ND>(v3 + (v19 + v19 * (unsigned) v24 + v19 * (unsigned) v23), v52, v53); + TSTORE(v54, v43); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v55 = Tile(v23, v24); + uint64_t v56 = (uint64_t) v29; + TASSIGN(v55, v56); + pto::Shape<1, 1, 1, 1, 32> v57 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v58 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v59 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v18 * (unsigned) v22 + v19 * (unsigned) v23), v57, v58); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v55, v59); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v60 = Tile(v23, v25); + uint64_t v61 = (uint64_t) v28; + TASSIGN(v60, v61); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v60, v55); + Tile v62 = Tile(v23, v25); + uint64_t v63 = (uint64_t) v27; + TASSIGN(v62, v63); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v62, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v64 = Tile(v23, v20); + uint64_t v65 = (uint64_t) v28; + TASSIGN(v64, v65); + pipe_barrier(PIPE_V); + TMOV(v64, v60); + v64.SetValidShape(v23, v20); + Tile v66 = Tile(v23, v25); + uint64_t v67 = (uint64_t) v26; + TASSIGN(v66, v67); + pipe_barrier(PIPE_V); + TFILLPAD(v66, v64); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + pto::Shape<1, 1, 1, 1, 16> v68 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v69 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v70 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v18 * (unsigned) v24 + v19 * (unsigned) v23), v68, v69); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + TSTORE(v70, v66); + pto::Shape<1, 1, 1, 1, 16> v71 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v72 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v73 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v18 * (unsigned) v24 + v19 * (unsigned) v23), v71, v72); + TSTORE(v73, v62); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + Tile v74 = Tile(v23, v24); + uint64_t v75 = (uint64_t) v29; + TASSIGN(v74, v75); + pto::Shape<1, 1, 1, 1, 32> v76 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v77 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v78 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v17 * (unsigned) v22 + v19 * (unsigned) v23), v76, v77); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v74, v78); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v79 = Tile(v23, v25); + uint64_t v80 = (uint64_t) v28; + TASSIGN(v79, v80); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v79, v74); + Tile v81 = Tile(v23, v25); + uint64_t v82 = (uint64_t) v27; + TASSIGN(v81, v82); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TGATHER, Tile, MaskPattern::P1010>(v81, v74); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + Tile v83 = Tile(v23, v20); + uint64_t v84 = (uint64_t) v28; + TASSIGN(v83, v84); + pipe_barrier(PIPE_V); + TMOV(v83, v79); + v83.SetValidShape(v23, v20); + Tile v85 = Tile(v23, v25); + uint64_t v86 = (uint64_t) v26; + TASSIGN(v85, v86); + pipe_barrier(PIPE_V); + TFILLPAD(v85, v83); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + pto::Shape<1, 1, 1, 1, 16> v87 = pto::Shape<1, 1, 
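/*
 * Reviewer's sketch for route_extract_top2 (assumed semantics; the kernel
 * is fully unrolled, one block like the one above per row r = 0..15, and
 * several details are assumptions: P0101 is taken to select the score
 * lanes and P1010 the id lanes, ids are taken at face value through the
 * int32 view, and TFILLPAD is taken to zero-pad beyond the top-2 kept by
 * SetValidShape(1, 2)):
 *
 *   for (int r = 0; r < 16; ++r)
 *     for (int j = 0; j < 16; ++j) {
 *       float score = buf[r * 64 + 2 * j];            // even lane (assumed)
 *       int32_t id  = ids_view[r * 64 + 2 * j + 1];   // odd lane (assumed)
 *       scores[r * 32 + j] = (j < 2) ? score : 0.f;   // top-2 kept
 *       ids[r * 32 + j]    = id;
 *     }
 *
 * i.e. de-interleave per-row (score, id) pairs and keep the top-2 scores;
 * the 16x32 score buffer is zeroed up front by the TEXPANDS + TSTORE
 * prologue at the top of the function.
 */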
1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v88 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v89 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v17 * (unsigned) v24 + v19 * (unsigned) v23), v87, v88); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + TSTORE(v89, v85); + pto::Shape<1, 1, 1, 1, 16> v90 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v91 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v92 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v17 * (unsigned) v24 + v19 * (unsigned) v23), v90, v91); + TSTORE(v92, v81); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + Tile v93 = Tile(v23, v24); + uint64_t v94 = (uint64_t) v29; + TASSIGN(v93, v94); + pto::Shape<1, 1, 1, 1, 32> v95 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v96 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v97 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v16 * (unsigned) v22 + v19 * (unsigned) v23), v95, v96); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v93, v97); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + Tile v98 = Tile(v23, v25); + uint64_t v99 = (uint64_t) v28; + TASSIGN(v98, v99); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v98, v93); + Tile v100 = Tile(v23, v25); + uint64_t v101 = (uint64_t) v27; + TASSIGN(v100, v101); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TGATHER, Tile, MaskPattern::P1010>(v100, v93); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + Tile v102 = Tile(v23, v20); + uint64_t v103 = (uint64_t) v28; + TASSIGN(v102, v103); + pipe_barrier(PIPE_V); + TMOV(v102, v98); + v102.SetValidShape(v23, v20); + Tile v104 = Tile(v23, v25); + uint64_t v105 = (uint64_t) v26; + TASSIGN(v104, v105); + pipe_barrier(PIPE_V); + TFILLPAD(v104, v102); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + pto::Shape<1, 1, 1, 1, 16> v106 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v107 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v108 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v16 * (unsigned) v24 + v19 * (unsigned) v23), v106, v107); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + TSTORE(v108, v104); + pto::Shape<1, 1, 1, 1, 16> v109 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v110 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v111 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v16 * (unsigned) v24 + v19 * (unsigned) v23), v109, v110); + TSTORE(v111, v100); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + Tile v112 = Tile(v23, v24); + uint64_t v113 = (uint64_t) v29; + TASSIGN(v112, v113); + pto::Shape<1, 1, 1, 1, 32> v114 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v115 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v116 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v15 * (unsigned) v22 + v19 * (unsigned) v23), v114, v115); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TLOAD(v112, v116); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + Tile v117 = Tile(v23, v25); + uint64_t v118 = (uint64_t) v28; + TASSIGN(v117, v118); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + pipe_barrier(PIPE_V); 
+ TGATHER, Tile, MaskPattern::P0101>(v117, v112); + Tile v119 = Tile(v23, v25); + uint64_t v120 = (uint64_t) v27; + TASSIGN(v119, v120); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TGATHER, Tile, MaskPattern::P1010>(v119, v112); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + Tile v121 = Tile(v23, v20); + uint64_t v122 = (uint64_t) v28; + TASSIGN(v121, v122); + pipe_barrier(PIPE_V); + TMOV(v121, v117); + v121.SetValidShape(v23, v20); + Tile v123 = Tile(v23, v25); + uint64_t v124 = (uint64_t) v26; + TASSIGN(v123, v124); + pipe_barrier(PIPE_V); + TFILLPAD(v123, v121); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + pto::Shape<1, 1, 1, 1, 16> v125 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v126 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v127 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v15 * (unsigned) v24 + v19 * (unsigned) v23), v125, v126); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + TSTORE(v127, v123); + pto::Shape<1, 1, 1, 1, 16> v128 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v129 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v130 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v15 * (unsigned) v24 + v19 * (unsigned) v23), v128, v129); + TSTORE(v130, v119); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + Tile v131 = Tile(v23, v24); + uint64_t v132 = (uint64_t) v29; + TASSIGN(v131, v132); + pto::Shape<1, 1, 1, 1, 32> v133 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v134 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v135 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v14 * (unsigned) v22 + v19 * (unsigned) v23), v133, v134); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TLOAD(v131, v135); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + Tile v136 = Tile(v23, v25); + uint64_t v137 = (uint64_t) v28; + TASSIGN(v136, v137); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v136, v131); + Tile v138 = Tile(v23, v25); + uint64_t v139 = (uint64_t) v27; + TASSIGN(v138, v139); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TGATHER, Tile, MaskPattern::P1010>(v138, v131); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + Tile v140 = Tile(v23, v20); + uint64_t v141 = (uint64_t) v28; + TASSIGN(v140, v141); + pipe_barrier(PIPE_V); + TMOV(v140, v136); + v140.SetValidShape(v23, v20); + Tile v142 = Tile(v23, v25); + uint64_t v143 = (uint64_t) v26; + TASSIGN(v142, v143); + pipe_barrier(PIPE_V); + TFILLPAD(v142, v140); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + pto::Shape<1, 1, 1, 1, 16> v144 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v145 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v146 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v14 * (unsigned) v24 + v19 * (unsigned) v23), v144, v145); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + TSTORE(v146, v142); + pto::Shape<1, 1, 1, 1, 16> v147 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v148 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v149 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v14 * (unsigned) v24 + v19 * (unsigned) v23), v147, v148); + TSTORE(v149, v138); + set_flag(PIPE_MTE3, PIPE_V, 
EVENT_ID5); + Tile v150 = Tile(v23, v24); + uint64_t v151 = (uint64_t) v29; + TASSIGN(v150, v151); + pto::Shape<1, 1, 1, 1, 32> v152 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v153 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v154 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v13 * (unsigned) v22 + v19 * (unsigned) v23), v152, v153); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TLOAD(v150, v154); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + Tile v155 = Tile(v23, v25); + uint64_t v156 = (uint64_t) v28; + TASSIGN(v155, v156); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v155, v150); + Tile v157 = Tile(v23, v25); + uint64_t v158 = (uint64_t) v27; + TASSIGN(v157, v158); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TGATHER, Tile, MaskPattern::P1010>(v157, v150); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + Tile v159 = Tile(v23, v20); + uint64_t v160 = (uint64_t) v28; + TASSIGN(v159, v160); + pipe_barrier(PIPE_V); + TMOV(v159, v155); + v159.SetValidShape(v23, v20); + Tile v161 = Tile(v23, v25); + uint64_t v162 = (uint64_t) v26; + TASSIGN(v161, v162); + pipe_barrier(PIPE_V); + TFILLPAD(v161, v159); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + pto::Shape<1, 1, 1, 1, 16> v163 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v164 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v165 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v13 * (unsigned) v24 + v19 * (unsigned) v23), v163, v164); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + TSTORE(v165, v161); + pto::Shape<1, 1, 1, 1, 16> v166 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v167 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v168 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v13 * (unsigned) v24 + v19 * (unsigned) v23), v166, v167); + TSTORE(v168, v157); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + Tile v169 = Tile(v23, v24); + uint64_t v170 = (uint64_t) v29; + TASSIGN(v169, v170); + pto::Shape<1, 1, 1, 1, 32> v171 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v172 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v173 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v12 * (unsigned) v22 + v19 * (unsigned) v23), v171, v172); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TLOAD(v169, v173); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v174 = Tile(v23, v25); + uint64_t v175 = (uint64_t) v28; + TASSIGN(v174, v175); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v174, v169); + Tile v176 = Tile(v23, v25); + uint64_t v177 = (uint64_t) v27; + TASSIGN(v176, v177); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TGATHER, Tile, MaskPattern::P1010>(v176, v169); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + Tile v178 = Tile(v23, v20); + uint64_t v179 = (uint64_t) v28; + TASSIGN(v178, v179); + pipe_barrier(PIPE_V); + TMOV(v178, v174); + v178.SetValidShape(v23, v20); + Tile v180 = Tile(v23, v25); + uint64_t v181 = (uint64_t) v26; + TASSIGN(v180, v181); + pipe_barrier(PIPE_V); + TFILLPAD(v180, v178); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v182 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v183 = 
pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v184 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v12 * (unsigned) v24 + v19 * (unsigned) v23), v182, v183); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v184, v180); + pto::Shape<1, 1, 1, 1, 16> v185 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v186 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v187 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v12 * (unsigned) v24 + v19 * (unsigned) v23), v185, v186); + TSTORE(v187, v176); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + Tile v188 = Tile(v23, v24); + uint64_t v189 = (uint64_t) v29; + TASSIGN(v188, v189); + pto::Shape<1, 1, 1, 1, 32> v190 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v191 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v192 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v11 * (unsigned) v22 + v19 * (unsigned) v23), v190, v191); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TLOAD(v188, v192); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v193 = Tile(v23, v25); + uint64_t v194 = (uint64_t) v28; + TASSIGN(v193, v194); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v193, v188); + Tile v195 = Tile(v23, v25); + uint64_t v196 = (uint64_t) v27; + TASSIGN(v195, v196); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + TGATHER, Tile, MaskPattern::P1010>(v195, v188); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v197 = Tile(v23, v20); + uint64_t v198 = (uint64_t) v28; + TASSIGN(v197, v198); + pipe_barrier(PIPE_V); + TMOV(v197, v193); + v197.SetValidShape(v23, v20); + Tile v199 = Tile(v23, v25); + uint64_t v200 = (uint64_t) v26; + TASSIGN(v199, v200); + pipe_barrier(PIPE_V); + TFILLPAD(v199, v197); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v201 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v202 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v203 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v11 * (unsigned) v24 + v19 * (unsigned) v23), v201, v202); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v203, v199); + pto::Shape<1, 1, 1, 1, 16> v204 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v205 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v206 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v11 * (unsigned) v24 + v19 * (unsigned) v23), v204, v205); + TSTORE(v206, v195); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v207 = Tile(v23, v24); + uint64_t v208 = (uint64_t) v29; + TASSIGN(v207, v208); + pto::Shape<1, 1, 1, 1, 32> v209 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v210 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v211 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v10 * (unsigned) v22 + v19 * (unsigned) v23), v209, v210); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v207, v211); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v212 = Tile(v23, v25); + uint64_t v213 = (uint64_t) v28; + TASSIGN(v212, v213); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, 
MaskPattern::P0101>(v212, v207); + Tile v214 = Tile(v23, v25); + uint64_t v215 = (uint64_t) v27; + TASSIGN(v214, v215); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v214, v207); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v216 = Tile(v23, v20); + uint64_t v217 = (uint64_t) v28; + TASSIGN(v216, v217); + pipe_barrier(PIPE_V); + TMOV(v216, v212); + v216.SetValidShape(v23, v20); + Tile v218 = Tile(v23, v25); + uint64_t v219 = (uint64_t) v26; + TASSIGN(v218, v219); + pipe_barrier(PIPE_V); + TFILLPAD(v218, v216); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v220 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v221 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v222 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v10 * (unsigned) v24 + v19 * (unsigned) v23), v220, v221); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v222, v218); + pto::Shape<1, 1, 1, 1, 16> v223 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v224 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v225 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v10 * (unsigned) v24 + v19 * (unsigned) v23), v223, v224); + TSTORE(v225, v214); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v226 = Tile(v23, v24); + uint64_t v227 = (uint64_t) v29; + TASSIGN(v226, v227); + pto::Shape<1, 1, 1, 1, 32> v228 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v229 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v230 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v9 * (unsigned) v22 + v19 * (unsigned) v23), v228, v229); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v226, v230); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v231 = Tile(v23, v25); + uint64_t v232 = (uint64_t) v28; + TASSIGN(v231, v232); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v231, v226); + Tile v233 = Tile(v23, v25); + uint64_t v234 = (uint64_t) v27; + TASSIGN(v233, v234); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v233, v226); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v235 = Tile(v23, v20); + uint64_t v236 = (uint64_t) v28; + TASSIGN(v235, v236); + pipe_barrier(PIPE_V); + TMOV(v235, v231); + v235.SetValidShape(v23, v20); + Tile v237 = Tile(v23, v25); + uint64_t v238 = (uint64_t) v26; + TASSIGN(v237, v238); + pipe_barrier(PIPE_V); + TFILLPAD(v237, v235); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v239 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v240 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v241 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v9 * (unsigned) v24 + v19 * (unsigned) v23), v239, v240); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v241, v237); + pto::Shape<1, 1, 1, 1, 16> v242 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v243 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v244 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v9 * (unsigned) v24 + v19 * (unsigned) v23), v242, v243); + TSTORE(v244, v233); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v245 = 
Tile(v23, v24); + uint64_t v246 = (uint64_t) v29; + TASSIGN(v245, v246); + pto::Shape<1, 1, 1, 1, 32> v247 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v248 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v249 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v8 * (unsigned) v22 + v19 * (unsigned) v23), v247, v248); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v245, v249); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v250 = Tile(v23, v25); + uint64_t v251 = (uint64_t) v28; + TASSIGN(v250, v251); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v250, v245); + Tile v252 = Tile(v23, v25); + uint64_t v253 = (uint64_t) v27; + TASSIGN(v252, v253); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v252, v245); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v254 = Tile(v23, v20); + uint64_t v255 = (uint64_t) v28; + TASSIGN(v254, v255); + pipe_barrier(PIPE_V); + TMOV(v254, v250); + v254.SetValidShape(v23, v20); + Tile v256 = Tile(v23, v25); + uint64_t v257 = (uint64_t) v26; + TASSIGN(v256, v257); + pipe_barrier(PIPE_V); + TFILLPAD(v256, v254); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v258 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v259 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v260 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v8 * (unsigned) v24 + v19 * (unsigned) v23), v258, v259); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v260, v256); + pto::Shape<1, 1, 1, 1, 16> v261 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v262 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v263 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v8 * (unsigned) v24 + v19 * (unsigned) v23), v261, v262); + TSTORE(v263, v252); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v264 = Tile(v23, v24); + uint64_t v265 = (uint64_t) v29; + TASSIGN(v264, v265); + pto::Shape<1, 1, 1, 1, 32> v266 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v267 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v268 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v7 * (unsigned) v22 + v19 * (unsigned) v23), v266, v267); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v264, v268); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v269 = Tile(v23, v25); + uint64_t v270 = (uint64_t) v28; + TASSIGN(v269, v270); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v269, v264); + Tile v271 = Tile(v23, v25); + uint64_t v272 = (uint64_t) v27; + TASSIGN(v271, v272); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v271, v264); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v273 = Tile(v23, v20); + uint64_t v274 = (uint64_t) v28; + TASSIGN(v273, v274); + pipe_barrier(PIPE_V); + TMOV(v273, v269); + v273.SetValidShape(v23, v20); + Tile v275 = Tile(v23, v25); + uint64_t v276 = (uint64_t) v26; + TASSIGN(v275, v276); + pipe_barrier(PIPE_V); + TFILLPAD(v275, v273); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v277 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v278 = pto::Stride<32, 32, 32, 32, 1>(); + 
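+// Each unrolled block in this kernel repeats one pattern: TLOAD a 1x32
+// window of a 64-wide sorted row from v2, split it with the two
+// complementary TGATHER masks, and TSTORE the two halves. Since
+// route_sort_top2 packs each sorted row as (value, index) pairs, the
+// P0101 gather appears to collect the values (stored to v1) and P1010
+// the indices (stored to v3); only the value path passes through
+// TMOV + SetValidShape(v23, v20) + TFILLPAD, which restricts the row
+// to its valid top-k prefix and pads the remainder (v20 is a constant
+// defined earlier in this kernel and not visible here; given the name
+// route_extract_top2 it is plausibly 2). A scalar sketch of one row,
+// under those assumptions:
+//   for (int c = 0; c < 16; ++c) {
+//     vals[c] = sorted[2 * c];      // P0101 gather -> v1
+//     idxs[c] = sorted[2 * c + 1];  // P1010 gather -> v3
+//   }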
GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v279 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v7 * (unsigned) v24 + v19 * (unsigned) v23), v277, v278); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v279, v275); + pto::Shape<1, 1, 1, 1, 16> v280 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v281 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v282 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v7 * (unsigned) v24 + v19 * (unsigned) v23), v280, v281); + TSTORE(v282, v271); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v283 = Tile(v23, v24); + uint64_t v284 = (uint64_t) v29; + TASSIGN(v283, v284); + pto::Shape<1, 1, 1, 1, 32> v285 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v286 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v287 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v6 * (unsigned) v22 + v19 * (unsigned) v23), v285, v286); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v283, v287); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v288 = Tile(v23, v25); + uint64_t v289 = (uint64_t) v28; + TASSIGN(v288, v289); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v288, v283); + Tile v290 = Tile(v23, v25); + uint64_t v291 = (uint64_t) v27; + TASSIGN(v290, v291); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v290, v283); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v292 = Tile(v23, v20); + uint64_t v293 = (uint64_t) v28; + TASSIGN(v292, v293); + pipe_barrier(PIPE_V); + TMOV(v292, v288); + v292.SetValidShape(v23, v20); + Tile v294 = Tile(v23, v25); + uint64_t v295 = (uint64_t) v26; + TASSIGN(v294, v295); + pipe_barrier(PIPE_V); + TFILLPAD(v294, v292); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v296 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v297 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v298 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v6 * (unsigned) v24 + v19 * (unsigned) v23), v296, v297); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v298, v294); + pto::Shape<1, 1, 1, 1, 16> v299 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v300 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v301 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v6 * (unsigned) v24 + v19 * (unsigned) v23), v299, v300); + TSTORE(v301, v290); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v302 = Tile(v23, v24); + uint64_t v303 = (uint64_t) v29; + TASSIGN(v302, v303); + pto::Shape<1, 1, 1, 1, 32> v304 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v305 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v306 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v5 * (unsigned) v22 + v19 * (unsigned) v23), v304, v305); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v302, v306); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v307 = Tile(v23, v25); + uint64_t v308 = (uint64_t) v28; + TASSIGN(v307, v308); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v307, v302); + Tile v309 
= Tile(v23, v25); + uint64_t v310 = (uint64_t) v27; + TASSIGN(v309, v310); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v309, v302); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v311 = Tile(v23, v20); + uint64_t v312 = (uint64_t) v28; + TASSIGN(v311, v312); + pipe_barrier(PIPE_V); + TMOV(v311, v307); + v311.SetValidShape(v23, v20); + Tile v313 = Tile(v23, v25); + uint64_t v314 = (uint64_t) v26; + TASSIGN(v313, v314); + pipe_barrier(PIPE_V); + TFILLPAD(v313, v311); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v315 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v316 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v317 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v5 * (unsigned) v24 + v19 * (unsigned) v23), v315, v316); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v317, v313); + pto::Shape<1, 1, 1, 1, 16> v318 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v319 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v320 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v5 * (unsigned) v24 + v19 * (unsigned) v23), v318, v319); + TSTORE(v320, v309); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v321 = Tile(v23, v24); + uint64_t v322 = (uint64_t) v29; + TASSIGN(v321, v322); + pto::Shape<1, 1, 1, 1, 32> v323 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<64, 64, 64, 64, 1> v324 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v325 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v19 + v4 * (unsigned) v22 + v19 * (unsigned) v23), v323, v324); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v321, v325); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v326 = Tile(v23, v25); + uint64_t v327 = (uint64_t) v28; + TASSIGN(v326, v327); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TGATHER, Tile, MaskPattern::P0101>(v326, v321); + Tile v328 = Tile(v23, v25); + uint64_t v329 = (uint64_t) v27; + TASSIGN(v328, v329); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TGATHER, Tile, MaskPattern::P1010>(v328, v321); + Tile v330 = Tile(v23, v20); + uint64_t v331 = (uint64_t) v28; + TASSIGN(v330, v331); + pipe_barrier(PIPE_V); + TMOV(v330, v326); + v330.SetValidShape(v23, v20); + Tile v332 = Tile(v23, v25); + uint64_t v333 = (uint64_t) v26; + TASSIGN(v332, v333); + pipe_barrier(PIPE_V); + TFILLPAD(v332, v330); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 16> v334 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v335 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v336 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v19 + v4 * (unsigned) v24 + v19 * (unsigned) v23), v334, v335); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v336, v332); + pto::Shape<1, 1, 1, 1, 16> v337 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<32, 32, 32, 32, 1> v338 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v339 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v3 + (v19 + v4 * (unsigned) v24 + v19 * (unsigned) v23), v337, v338); + TSTORE(v339, v328); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" 
__aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: topk_vals_pad_inline110__ssa_v0 + __gm__ Tensor* topk_vals_pad_inline110__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* topk_vals_pad_inline110__ssa_v0 = reinterpret_cast<__gm__ float*>(topk_vals_pad_inline110__ssa_v0_tensor->buffer.addr) + topk_vals_pad_inline110__ssa_v0_tensor->start_offset; + + // Unpack tensor: sorted_rows_inline97__ssa_v16 + __gm__ Tensor* sorted_rows_inline97__ssa_v16_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* sorted_rows_inline97__ssa_v16 = reinterpret_cast<__gm__ float*>(sorted_rows_inline97__ssa_v16_tensor->buffer.addr) + sorted_rows_inline97__ssa_v16_tensor->start_offset; + + // Unpack tensor: topk_idx_pad_inline104__ssa_v0 + __gm__ Tensor* topk_idx_pad_inline104__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* topk_idx_pad_inline104__ssa_v0 = reinterpret_cast<__gm__ int32_t*>(topk_idx_pad_inline104__ssa_v0_tensor->buffer.addr) + topk_idx_pad_inline104__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + route_extract_top2(topk_vals_pad_inline110__ssa_v0, sorted_rows_inline97__ssa_v16, topk_idx_pad_inline104__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_normalize_weights.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_normalize_weights.cpp new file mode 100644 index 000000000..c925b3e7f --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_normalize_weights.cpp @@ -0,0 +1,115 @@ +// Kernel Function: route_normalize_weights +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void route_normalize_weights(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 0; + const float v4 = 1.0f; + const int32_t v5 = 128; + const int32_t v6 = 1; + const int32_t v7 = 32; + const int32_t v8 = 16; + const int64_t v9 = 10240; + const int64_t v10 = 2048; + const int64_t v11 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v12 = Tile(v8, v7); + uint64_t v13 = (uint64_t) v11; + TASSIGN(v12, v13); + pto::Shape<1, 1, 1, 16, 32> v14 = pto::Shape<1, 1, 1, 16, 32>(); + pto::Stride<512, 512, 512, 32, 1> v15 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v16 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v3 + v3 * (unsigned) v7 + v3 * (unsigned) v6), v14, v15); + TLOAD(v12, v16); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v17 = Tile(v8, v5); + uint64_t v18 = (uint64_t) v10; + TASSIGN(v17, v18); + Tile v19 = Tile(v8, v6); + uint64_t v20 = (uint64_t) v9; + TASSIGN(v19, v20); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + 
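+// The ops below renormalize the 16x32 tile of routing weights so that
+// each row sums to 1: TROWSUM reduces every row into a 16x1 tile (v19,
+// read back as the aliased v21 at the same UB offset v9),
+// TROWEXPANDDIV broadcasts the per-row sum across the 32 columns, and
+// TMULS applies v4 == 1.0f, an identity scale that looks like a hook
+// for a routed scaling factor. Scalar sketch:
+//   for (int r = 0; r < 16; ++r) {
+//     float s = 0.0f;
+//     for (int c = 0; c < 32; ++c) s += w[r][c];              // TROWSUM
+//     for (int c = 0; c < 32; ++c) out[r][c] = w[r][c] / s;   // TROWEXPANDDIV
+//   }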
TROWSUM(v19, v12, v17); + Tile v21 = Tile(v8, v6); + uint64_t v22 = (uint64_t) v9; + TASSIGN(v21, v22); + Tile v23 = Tile(v8, v7); + uint64_t v24 = (uint64_t) v11; + TASSIGN(v23, v24); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v23, v12, v21); + Tile v25 = Tile(v8, v7); + uint64_t v26 = (uint64_t) v11; + TASSIGN(v25, v26); + pipe_barrier(PIPE_V); + TMULS(v25, v23, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 32> v27 = pto::Shape<1, 1, 1, 16, 32>(); + pto::Stride<512, 512, 512, 32, 1> v28 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v29 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v2 + (v3 + v3 * (unsigned) v7 + v3 * (unsigned) v6), v27, v28); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v29, v25); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: topk_vals_pad_inline110__ssa_v17 + __gm__ Tensor* topk_vals_pad_inline110__ssa_v17_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* topk_vals_pad_inline110__ssa_v17 = reinterpret_cast<__gm__ float*>(topk_vals_pad_inline110__ssa_v17_tensor->buffer.addr) + topk_vals_pad_inline110__ssa_v17_tensor->start_offset; + + // Unpack tensor: weight_out_pad_inline108__ssa_v0 + __gm__ Tensor* weight_out_pad_inline108__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* weight_out_pad_inline108__ssa_v0 = reinterpret_cast<__gm__ float*>(weight_out_pad_inline108__ssa_v0_tensor->buffer.addr) + weight_out_pad_inline108__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + route_normalize_weights(topk_vals_pad_inline110__ssa_v17, weight_out_pad_inline108__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_sort_top2.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_sort_top2.cpp new file mode 100644 index 000000000..d4ea158f0 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/route_sort_top2.cpp @@ -0,0 +1,554 @@ +// Kernel Function: route_sort_top2 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void route_sort_top2(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 15; + unsigned v4 = 14; + unsigned v5 = 13; + unsigned v6 = 12; + unsigned v7 = 11; + unsigned v8 = 10; + unsigned v9 = 9; + unsigned v10 = 8; + unsigned v11 = 7; + unsigned v12 = 6; + unsigned v13 = 5; + unsigned v14 = 4; + unsigned v15 = 3; + unsigned v16 = 2; + unsigned v17 = 1; + unsigned v18 = 0; + const int32_t v19 = 0; + const int32_t v20 = 64; + const int32_t v21 = 1; + const int32_t v22 = 32; + const int64_t v23 
= 256; + const int64_t v24 = 128; + const int64_t v25 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + uint32_t v26 = (uint32_t) ((int32_t) v19); + Tile v27 = Tile(v21, v22); + uint64_t v28 = (uint64_t) v25; + TASSIGN(v27, v28); + pto::Shape<1, 1, 1, 1, 32> v29 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v30 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v31 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v18 * (unsigned) v22 + v18 * (unsigned) v21), v29, v30); + TLOAD(v27, v31); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v32 = Tile(v21, v22); + uint64_t v33 = (uint64_t) v24; + TASSIGN(v32, v33); + TCI, uint32_t, 0>(v32, v26); + Tile v34 = Tile(v21, v20); + uint64_t v35 = (uint64_t) v23; + TASSIGN(v34, v35); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TSORT32(v34, v27, v32); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v36 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v37 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v38 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v18 * (unsigned) v20 + v18 * (unsigned) v21), v36, v37); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v38, v34); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v39 = Tile(v21, v22); + uint64_t v40 = (uint64_t) v25; + TASSIGN(v39, v40); + pto::Shape<1, 1, 1, 1, 32> v41 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v42 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v43 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v17 * (unsigned) v22 + v18 * (unsigned) v21), v41, v42); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v39, v43); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v44 = Tile(v21, v22); + uint64_t v45 = (uint64_t) v24; + TASSIGN(v44, v45); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v44, v26); + Tile v46 = Tile(v21, v20); + uint64_t v47 = (uint64_t) v23; + TASSIGN(v46, v47); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v46, v39, v44); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + pto::Shape<1, 1, 1, 1, 64> v48 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v49 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v50 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v17 * (unsigned) v20 + v18 * (unsigned) v21), v48, v49); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v50, v46); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + Tile v51 = Tile(v21, v22); + uint64_t v52 = (uint64_t) v25; + TASSIGN(v51, v52); + pto::Shape<1, 1, 1, 1, 32> v53 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v54 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v55 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v16 * (unsigned) v22 + v18 * (unsigned) v21), v53, v54); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v51, v55); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v56 = Tile(v21, v22); + uint64_t v57 = (uint64_t) v24; + TASSIGN(v56, v57); + pipe_barrier(PIPE_V); + 
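+// Each unrolled block in this kernel sorts one row of the 16x32 score
+// matrix: TLOAD pulls a 1x32 row, TCI materializes the lane indices
+// starting from v26 == 0, and TSORT32 emits a 1x64 tile packing the
+// sorted (value, index) pairs that route_extract_top2 later splits
+// apart. Assuming the usual descending TSORT32 convention, a scalar
+// sketch of one block:
+//   std::array<std::pair<float, uint32_t>, 32> p;
+//   for (uint32_t c = 0; c < 32; ++c) p[c] = {row[c], c};        // TCI
+//   std::sort(p.begin(), p.end(),
+//             [](auto a, auto b) { return a.first > b.first; }); // TSORT32
+// The rotating EVENT_ID pairs let row k+1's TLOAD overlap row k's sort
+// and TSTORE across the MTE2 / V / MTE3 queues.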
TCI, uint32_t, 0>(v56, v26); + Tile v58 = Tile(v21, v20); + uint64_t v59 = (uint64_t) v23; + TASSIGN(v58, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TSORT32(v58, v51, v56); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + pto::Shape<1, 1, 1, 1, 64> v60 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v61 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v62 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v16 * (unsigned) v20 + v18 * (unsigned) v21), v60, v61); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + TSTORE(v62, v58); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + Tile v63 = Tile(v21, v22); + uint64_t v64 = (uint64_t) v25; + TASSIGN(v63, v64); + pto::Shape<1, 1, 1, 1, 32> v65 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v66 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v67 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v15 * (unsigned) v22 + v18 * (unsigned) v21), v65, v66); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v63, v67); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + Tile v68 = Tile(v21, v22); + uint64_t v69 = (uint64_t) v24; + TASSIGN(v68, v69); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v68, v26); + Tile v70 = Tile(v21, v20); + uint64_t v71 = (uint64_t) v23; + TASSIGN(v70, v71); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TSORT32(v70, v63, v68); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + pto::Shape<1, 1, 1, 1, 64> v72 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v73 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v74 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v15 * (unsigned) v20 + v18 * (unsigned) v21), v72, v73); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + TSTORE(v74, v70); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + Tile v75 = Tile(v21, v22); + uint64_t v76 = (uint64_t) v25; + TASSIGN(v75, v76); + pto::Shape<1, 1, 1, 1, 32> v77 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v78 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v79 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v14 * (unsigned) v22 + v18 * (unsigned) v21), v77, v78); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TLOAD(v75, v79); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + Tile v80 = Tile(v21, v22); + uint64_t v81 = (uint64_t) v24; + TASSIGN(v80, v81); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v80, v26); + Tile v82 = Tile(v21, v20); + uint64_t v83 = (uint64_t) v23; + TASSIGN(v82, v83); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TSORT32(v82, v75, v80); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + pto::Shape<1, 1, 1, 1, 64> v84 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v85 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v86 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v14 * (unsigned) v20 + v18 * (unsigned) v21), v84, v85); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + 
TSTORE(v86, v82); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + Tile v87 = Tile(v21, v22); + uint64_t v88 = (uint64_t) v25; + TASSIGN(v87, v88); + pto::Shape<1, 1, 1, 1, 32> v89 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v90 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v91 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v13 * (unsigned) v22 + v18 * (unsigned) v21), v89, v90); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TLOAD(v87, v91); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + Tile v92 = Tile(v21, v22); + uint64_t v93 = (uint64_t) v24; + TASSIGN(v92, v93); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v92, v26); + Tile v94 = Tile(v21, v20); + uint64_t v95 = (uint64_t) v23; + TASSIGN(v94, v95); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TSORT32(v94, v87, v92); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + pto::Shape<1, 1, 1, 1, 64> v96 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v97 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v98 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v13 * (unsigned) v20 + v18 * (unsigned) v21), v96, v97); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + TSTORE(v98, v94); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + Tile v99 = Tile(v21, v22); + uint64_t v100 = (uint64_t) v25; + TASSIGN(v99, v100); + pto::Shape<1, 1, 1, 1, 32> v101 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v102 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v103 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v12 * (unsigned) v22 + v18 * (unsigned) v21), v101, v102); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TLOAD(v99, v103); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + Tile v104 = Tile(v21, v22); + uint64_t v105 = (uint64_t) v24; + TASSIGN(v104, v105); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v104, v26); + Tile v106 = Tile(v21, v20); + uint64_t v107 = (uint64_t) v23; + TASSIGN(v106, v107); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TSORT32(v106, v99, v104); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + pto::Shape<1, 1, 1, 1, 64> v108 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v109 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v110 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v12 * (unsigned) v20 + v18 * (unsigned) v21), v108, v109); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + TSTORE(v110, v106); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + Tile v111 = Tile(v21, v22); + uint64_t v112 = (uint64_t) v25; + TASSIGN(v111, v112); + pto::Shape<1, 1, 1, 1, 32> v113 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v114 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v115 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v11 * (unsigned) v22 + v18 * (unsigned) v21), v113, v114); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TLOAD(v111, v115); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v116 = Tile(v21, v22); + uint64_t v117 = (uint64_t) v24; + TASSIGN(v116, v117); + 
pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v116, v26); + Tile v118 = Tile(v21, v20); + uint64_t v119 = (uint64_t) v23; + TASSIGN(v118, v119); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TSORT32(v118, v111, v116); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + pto::Shape<1, 1, 1, 1, 64> v120 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v121 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v122 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v11 * (unsigned) v20 + v18 * (unsigned) v21), v120, v121); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + TSTORE(v122, v118); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + Tile v123 = Tile(v21, v22); + uint64_t v124 = (uint64_t) v25; + TASSIGN(v123, v124); + pto::Shape<1, 1, 1, 1, 32> v125 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v126 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v127 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v10 * (unsigned) v22 + v18 * (unsigned) v21), v125, v126); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TLOAD(v123, v127); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v128 = Tile(v21, v22); + uint64_t v129 = (uint64_t) v24; + TASSIGN(v128, v129); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v128, v26); + Tile v130 = Tile(v21, v20); + uint64_t v131 = (uint64_t) v23; + TASSIGN(v130, v131); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + TSORT32(v130, v123, v128); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v132 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v133 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v134 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v10 * (unsigned) v20 + v18 * (unsigned) v21), v132, v133); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v134, v130); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v135 = Tile(v21, v22); + uint64_t v136 = (uint64_t) v25; + TASSIGN(v135, v136); + pto::Shape<1, 1, 1, 1, 32> v137 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v138 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v139 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v9 * (unsigned) v22 + v18 * (unsigned) v21), v137, v138); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v135, v139); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v140 = Tile(v21, v22); + uint64_t v141 = (uint64_t) v24; + TASSIGN(v140, v141); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v140, v26); + Tile v142 = Tile(v21, v20); + uint64_t v143 = (uint64_t) v23; + TASSIGN(v142, v143); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v142, v135, v140); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v144 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v145 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v146 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v9 * 
(unsigned) v20 + v18 * (unsigned) v21), v144, v145); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v146, v142); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v147 = Tile(v21, v22); + uint64_t v148 = (uint64_t) v25; + TASSIGN(v147, v148); + pto::Shape<1, 1, 1, 1, 32> v149 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v150 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v151 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v8 * (unsigned) v22 + v18 * (unsigned) v21), v149, v150); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v147, v151); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v152 = Tile(v21, v22); + uint64_t v153 = (uint64_t) v24; + TASSIGN(v152, v153); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v152, v26); + Tile v154 = Tile(v21, v20); + uint64_t v155 = (uint64_t) v23; + TASSIGN(v154, v155); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v154, v147, v152); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v156 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v157 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v158 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v8 * (unsigned) v20 + v18 * (unsigned) v21), v156, v157); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v158, v154); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v159 = Tile(v21, v22); + uint64_t v160 = (uint64_t) v25; + TASSIGN(v159, v160); + pto::Shape<1, 1, 1, 1, 32> v161 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v162 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v163 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v7 * (unsigned) v22 + v18 * (unsigned) v21), v161, v162); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v159, v163); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v164 = Tile(v21, v22); + uint64_t v165 = (uint64_t) v24; + TASSIGN(v164, v165); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v164, v26); + Tile v166 = Tile(v21, v20); + uint64_t v167 = (uint64_t) v23; + TASSIGN(v166, v167); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v166, v159, v164); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v168 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v169 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v170 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v7 * (unsigned) v20 + v18 * (unsigned) v21), v168, v169); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v170, v166); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v171 = Tile(v21, v22); + uint64_t v172 = (uint64_t) v25; + TASSIGN(v171, v172); + pto::Shape<1, 1, 1, 1, 32> v173 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v174 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v175 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v6 * (unsigned) v22 + v18 * (unsigned) v21), v173, v174); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v171, v175); + 
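+// Note: the earlier unrolled rows rotate through distinct EVENT_IDs,
+// but from about the eighth row on every set_flag/wait_flag pair here
+// reuses EVENT_ID0. That is still correct (each set is matched by
+// exactly one wait before the next set on the same pipe pair), but it
+// appears to serialize the remaining rows' load/sort/store stages,
+// since a single event ID cannot keep two iterations in flight.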
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v176 = Tile(v21, v22); + uint64_t v177 = (uint64_t) v24; + TASSIGN(v176, v177); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v176, v26); + Tile v178 = Tile(v21, v20); + uint64_t v179 = (uint64_t) v23; + TASSIGN(v178, v179); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v178, v171, v176); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v180 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v181 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v182 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v6 * (unsigned) v20 + v18 * (unsigned) v21), v180, v181); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v182, v178); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v183 = Tile(v21, v22); + uint64_t v184 = (uint64_t) v25; + TASSIGN(v183, v184); + pto::Shape<1, 1, 1, 1, 32> v185 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v186 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v187 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v5 * (unsigned) v22 + v18 * (unsigned) v21), v185, v186); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v183, v187); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v188 = Tile(v21, v22); + uint64_t v189 = (uint64_t) v24; + TASSIGN(v188, v189); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v188, v26); + Tile v190 = Tile(v21, v20); + uint64_t v191 = (uint64_t) v23; + TASSIGN(v190, v191); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v190, v183, v188); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v192 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v193 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v194 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v5 * (unsigned) v20 + v18 * (unsigned) v21), v192, v193); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v194, v190); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v195 = Tile(v21, v22); + uint64_t v196 = (uint64_t) v25; + TASSIGN(v195, v196); + pto::Shape<1, 1, 1, 1, 32> v197 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v198 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v199 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v4 * (unsigned) v22 + v18 * (unsigned) v21), v197, v198); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v195, v199); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v200 = Tile(v21, v22); + uint64_t v201 = (uint64_t) v24; + TASSIGN(v200, v201); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v200, v26); + Tile v202 = Tile(v21, v20); + uint64_t v203 = (uint64_t) v23; + TASSIGN(v202, v203); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v202, v195, v200); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v204 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v205 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 
64, 64, 1>, pto::Layout::ND> v206 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v4 * (unsigned) v20 + v18 * (unsigned) v21), v204, v205); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v206, v202); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v207 = Tile(v21, v22); + uint64_t v208 = (uint64_t) v25; + TASSIGN(v207, v208); + pto::Shape<1, 1, 1, 1, 32> v209 = pto::Shape<1, 1, 1, 1, 32>(); + pto::Stride<32, 32, 32, 32, 1> v210 = pto::Stride<32, 32, 32, 32, 1>(); + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v211 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v1 + (v18 + v3 * (unsigned) v22 + v18 * (unsigned) v21), v209, v210); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v207, v211); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v212 = Tile(v21, v22); + uint64_t v213 = (uint64_t) v24; + TASSIGN(v212, v213); + pipe_barrier(PIPE_V); + TCI, uint32_t, 0>(v212, v26); + Tile v214 = Tile(v21, v20); + uint64_t v215 = (uint64_t) v23; + TASSIGN(v214, v215); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TSORT32(v214, v207, v212); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 1, 64> v216 = pto::Shape<1, 1, 1, 1, 64>(); + pto::Stride<64, 64, 64, 64, 1> v217 = pto::Stride<64, 64, 64, 64, 1>(); + GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v218 = GlobalTensor, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v18 + v3 * (unsigned) v20 + v18 * (unsigned) v21), v216, v217); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v218, v214); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: biased_scores_v1_inline136__ssa_v0 + __gm__ Tensor* biased_scores_v1_inline136__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* biased_scores_v1_inline136__ssa_v0 = reinterpret_cast<__gm__ float*>(biased_scores_v1_inline136__ssa_v0_tensor->buffer.addr) + biased_scores_v1_inline136__ssa_v0_tensor->start_offset; + + // Unpack tensor: sorted_rows_inline97__ssa_v0 + __gm__ Tensor* sorted_rows_inline97__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* sorted_rows_inline97__ssa_v0 = reinterpret_cast<__gm__ float*>(sorted_rows_inline97__ssa_v0_tensor->buffer.addr) + sorted_rows_inline97__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + route_sort_top2(biased_scores_v1_inline136__ssa_v0, sorted_rows_inline97__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_gate_up_dequant.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_gate_up_dequant.cpp new file mode 100644 index 000000000..157809f0b --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_gate_up_dequant.cpp @@ -0,0 +1,198 @@ +// Kernel Function: sh_gate_up_dequant +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = 
PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_gate_up_dequant(__gm__ float* v1, __gm__ float* v2, __gm__ int32_t* v3, __gm__ int32_t* v4, __gm__ float* v5, __gm__ float* v6, __gm__ float* v7, int32_t v8) { + RoundMode v9 = RoundMode::CAST_NONE; + unsigned v10 = 0; + const int32_t v11 = 256; + const int32_t v12 = 16; + const int32_t v13 = 1; + const int64_t v14 = 18432; + const int64_t v15 = 2048; + const int64_t v16 = 1024; + const int64_t v17 = 0; + const int64_t v18 = 67584; + const int64_t v19 = 51200; + const int64_t v20 = 34816; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v21 = Tile(v12, v11); + uint64_t v22 = (uint64_t) v20; + TASSIGN(v21, v22); + pto::Shape<1, 1, 1, 16, 256> v23 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v24 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v25 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v3 + (v10 + v10 * (unsigned) v11 + v10 * (unsigned) v13), v23, v24); + TLOAD(v21, v25); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v26 = Tile(v12, v11); + uint64_t v27 = (uint64_t) v19; + TASSIGN(v26, v27); + pto::Shape<1, 1, 1, 16, 256> v28 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v29 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v30 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v4 + (v10 + v10 * (unsigned) v11 + v10 * (unsigned) v13), v28, v29); + TLOAD(v26, v30); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v31 = Tile(v12, v13); + uint64_t v32 = (uint64_t) v18; + TASSIGN(v31, v32); + pto::Shape<1, 1, 1, 16, 1> v33 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v34 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v35 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v5 + (v10 + v10 * (unsigned) v13 + v10 * (unsigned) v12), v33, v34); + TLOAD(v31, v35); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v36 = Tile(v13, v11); + uint64_t v37 = (uint64_t) v17; + TASSIGN(v36, v37); + pto::Shape<1, 1, 1, 1, 256> v38 = pto::Shape<1, 1, 1, 1, 256>(); + pto::Stride<256, 256, 256, 256, 1> v39 = pto::Stride<256, 256, 256, 256, 1>(); + GlobalTensor, pto::Stride<256, 256, 256, 256, 1>, pto::Layout::ND> v40 = GlobalTensor, pto::Stride<256, 256, 256, 256, 1>, pto::Layout::ND>(v1 + (v10 + (unsigned) v8 * (unsigned) v13), v38, v39); + TLOAD(v36, v40); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + Tile v41 = Tile(v13, v11); + uint64_t v42 = (uint64_t) v17; + TASSIGN(v41, v42); + Tile v43 = Tile(v13, v11); + uint64_t v44 = (uint64_t) v16; + TASSIGN(v43, v44); + pto::Shape<1, 1, 1, 1, 256> v45 = pto::Shape<1, 1, 1, 1, 256>(); + pto::Stride<256, 256, 256, 256, 1> v46 = pto::Stride<256, 256, 256, 256, 1>(); + GlobalTensor, pto::Stride<256, 256, 256, 256, 1>, pto::Layout::ND> v47 = GlobalTensor, pto::Stride<256, 256, 256, 256, 1>, pto::Layout::ND>(v2 + (v10 + (unsigned) v8 * (unsigned) v13), v45, v46); + TLOAD(v43, v47); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + Tile v48 = Tile(v13, v11); + uint64_t v49 = (uint64_t) v16; 
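+// v48 (assigned next) reuses UB offset v16 and therefore aliases the
+// per-column scale row loaded into v43 above, just as v41 aliases v36
+// at offset v17 -- the generator reuses offsets rather than names.
+// The block below then dequantizes the two int32 GEMM accumulators:
+// TCVT to float, TROWEXPANDMUL by the per-row activation scale (v31,
+// loaded from v5), TCOLEXPANDMUL by the per-column weight scale rows
+// (loaded from v1 and v2 at column offset v8). Scalar sketch for one
+// accumulator:
+//   for (int r = 0; r < 16; ++r)
+//     for (int c = 0; c < 256; ++c)
+//       out[r][c] = (float) acc[r][c] * row_scale[r] * col_scale[c];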
+ TASSIGN(v48, v49); + Tile v50 = Tile(v12, v11); + uint64_t v51 = (uint64_t) v15; + TASSIGN(v50, v51); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v50, v21, v9); + Tile v52 = Tile(v12, v11); + uint64_t v53 = (uint64_t) v14; + TASSIGN(v52, v53); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCVT(v52, v26, v9); + Tile v54 = Tile(v12, v11); + uint64_t v55 = (uint64_t) v15; + TASSIGN(v54, v55); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TROWEXPANDMUL(v54, v50, v31); + Tile v56 = Tile(v12, v11); + uint64_t v57 = (uint64_t) v15; + TASSIGN(v56, v57); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TCOLEXPANDMUL(v56, v54, v41); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + Tile v58 = Tile(v12, v11); + uint64_t v59 = (uint64_t) v14; + TASSIGN(v58, v59); + TROWEXPANDMUL(v58, v52, v31); + Tile v60 = Tile(v12, v11); + uint64_t v61 = (uint64_t) v14; + TASSIGN(v60, v61); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TCOLEXPANDMUL(v60, v58, v48); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pto::Shape<1, 1, 1, 16, 256> v62 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v63 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v64 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v6 + (v10 + v10 * (unsigned) v11 + v10 * (unsigned) v13), v62, v63); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v64, v56); + pto::Shape<1, 1, 1, 16, 256> v65 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v66 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v67 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v7 + (v10 + v10 * (unsigned) v11 + v10 * (unsigned) v13), v65, v66); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v67, v60); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: shared_w1_scale__ssa_v0 + __gm__ Tensor* shared_w1_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* shared_w1_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(shared_w1_scale__ssa_v0_tensor->buffer.addr) + shared_w1_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: shared_w3_scale__ssa_v0 + __gm__ Tensor* shared_w3_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* shared_w3_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(shared_w3_scale__ssa_v0_tensor->buffer.addr) + shared_w3_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: sh_gate_acc_inline95__rv_v2 + __gm__ Tensor* sh_gate_acc_inline95__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* sh_gate_acc_inline95__rv_v2 = reinterpret_cast<__gm__ int32_t*>(sh_gate_acc_inline95__rv_v2_tensor->buffer.addr) + sh_gate_acc_inline95__rv_v2_tensor->start_offset; + + // Unpack tensor: sh_up_acc_inline118__rv_v2 + __gm__ Tensor* sh_up_acc_inline118__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ int32_t* sh_up_acc_inline118__rv_v2 = reinterpret_cast<__gm__ int32_t*>(sh_up_acc_inline118__rv_v2_tensor->buffer.addr) + sh_up_acc_inline118__rv_v2_tensor->start_offset; + + // Unpack tensor: x_local_scale_dq_inline32__ssa_v0 + __gm__ Tensor* x_local_scale_dq_inline32__ssa_v0_tensor = 
reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ float* x_local_scale_dq_inline32__ssa_v0 = reinterpret_cast<__gm__ float*>(x_local_scale_dq_inline32__ssa_v0_tensor->buffer.addr) + x_local_scale_dq_inline32__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[5]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack tensor: ret1__out + __gm__ Tensor* ret1__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[6]); + __gm__ float* ret1__out = reinterpret_cast<__gm__ float*>(ret1__out_tensor->buffer.addr) + ret1__out_tensor->start_offset; + + // Unpack scalar: n0_inline113__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline113__idx_v0_conv; + n0_inline113__idx_v0_conv.u64 = args[7]; + int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_gate_up_dequant(shared_w1_scale__ssa_v0, shared_w3_scale__ssa_v0, sh_gate_acc_inline95__rv_v2, sh_up_acc_inline118__rv_v2, x_local_scale_dq_inline32__ssa_v0, ret0__out, ret1__out, n0_inline113__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_h_q.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_h_q.cpp new file mode 100644 index 000000000..358916169 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_h_q.cpp @@ -0,0 +1,213 @@ +// Kernel Function: sh_h_q +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_h_q(__gm__ float* v1, __gm__ int8_t* v2, __gm__ float* v3) { + RoundMode v4 = RoundMode::CAST_TRUNC; + RoundMode v5 = RoundMode::CAST_ROUND; + unsigned v6 = 0; + const float v7 = 127.0f; + const int32_t v8 = 256; + const int32_t v9 = 0; + const float v10 = 9.99999974E-5f; + const int32_t v11 = 1; + const int32_t v12 = 4096; + const int32_t v13 = 16; + const int64_t v14 = 24640; + const int64_t v15 = 16448; + const int64_t v16 = 64; + const int64_t v17 = 0; + const int64_t v18 = 61568; + const int64_t v19 = 45184; + const int64_t v20 = 28800; + const int64_t v21 = 28736; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v22 = (size_t) v12; + size_t v23 = (size_t) v9; + size_t v24 = (size_t) v8; + Tile v25 = Tile(v11, v13); + uint64_t v26 = (uint64_t) v21; + TASSIGN(v25, v26); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TEXPANDS(v25, v10); + for (size_t v27 = v23; v27 < v22; v27 += v24) { + Tile v28 = Tile(v13, v8); + uint64_t v29 = (uint64_t) v20; + TASSIGN(v28, v29); + pto::Shape<1, 1, 1, 16, 256> v30 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v31 = 
pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v32 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) ((int32_t) v27) * (unsigned) v11), v30, v31); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v28, v32); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v33 = Tile(v13, v8); + uint64_t v34 = (uint64_t) v19; + TASSIGN(v33, v34); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TNEG(v33, v28); + Tile v35 = Tile(v13, v8); + uint64_t v36 = (uint64_t) v20; + TASSIGN(v35, v36); + pipe_barrier(PIPE_V); + TMAX(v35, v28, v33); + Tile v37 = Tile(v13, v8); + uint64_t v38 = (uint64_t) v19; + TASSIGN(v37, v38); + Tile v39 = Tile(v13, v11); + uint64_t v40 = (uint64_t) v18; + TASSIGN(v39, v40); + pipe_barrier(PIPE_V); + TROWMAX(v39, v35, v37); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v41 = Tile(v11, v13); + uint64_t v42 = (uint64_t) v18; + TASSIGN(v41, v42); + Tile v43 = Tile(v11, v13); + uint64_t v44 = (uint64_t) v21; + TASSIGN(v43, v44); + pipe_barrier(PIPE_V); + TMAX(v43, v25, v41); + } + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v45 = Tile(v11, v13); + uint64_t v46 = (uint64_t) v17; + TASSIGN(v45, v46); + TEXPANDS(v45, v7); + Tile v47 = Tile(v11, v13); + uint64_t v48 = (uint64_t) v21; + TASSIGN(v47, v48); + pipe_barrier(PIPE_V); + TDIV(v47, v45, v25); + Tile v49 = Tile(v11, v13); + uint64_t v50 = (uint64_t) v17; + TASSIGN(v49, v50); + pipe_barrier(PIPE_V); + TRECIP(v49, v47); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + Tile v51 = Tile(v13, v11); + uint64_t v52 = (uint64_t) v17; + TASSIGN(v51, v52); + Tile v53 = Tile(v13, v11); + uint64_t v54 = (uint64_t) v21; + TASSIGN(v53, v54); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + for (size_t v55 = v23; v55 < v22; v55 += v24) { + int32_t v56 = (int32_t) v55; + Tile v57 = Tile(v13, v8); + uint64_t v58 = (uint64_t) v20; + TASSIGN(v57, v58); + pto::Shape<1, 1, 1, 16, 256> v59 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v60 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v61 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) v56 * (unsigned) v11), v59, v60); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v57, v61); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v62 = Tile(v13, v8); + uint64_t v63 = (uint64_t) v20; + TASSIGN(v62, v63); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDMUL(v62, v57, v53); + Tile v64 = Tile(v13, v8); + uint64_t v65 = (uint64_t) v16; + TASSIGN(v64, v65); + pipe_barrier(PIPE_V); + TCVT(v64, v62, v5); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + Tile v66 = Tile(v13, v8); + uint64_t v67 = (uint64_t) v15; + TASSIGN(v66, v67); + pipe_barrier(PIPE_V); + TCVT(v66, v64, v5); + Tile v68 = Tile(v13, v8); + uint64_t v69 = (uint64_t) v14; + TASSIGN(v68, v69); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v68, v66, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v70 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v71 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v72 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v12 + (unsigned) v56 * (unsigned) v11), v70, v71); + 
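+// The store below writes the quantized int8 tile. Overall, the first
+// loop of sh_h_q accumulated a per-row absolute maximum m[r] over the
+// 4096 columns in 256-wide blocks (TNEG + TMAX forms |x|, TROWMAX
+// reduces each block, TMAX merges into the running max seeded with
+// 1e-4f so m[r] is never zero); v47 = 127/m[r] is the quantization
+// scale and v49 = TRECIP(v47) = m[r]/127 the dequantization scale
+// stored to ret0 through the aliased v51. This second loop applies
+// the scale per row (via v53, aliasing v47) and narrows through the
+// cascaded TCVTs (CAST_ROUND twice, then CAST_TRUNC; presumably
+// f32 -> f16 -> int16 -> int8, the intermediate tile types are not
+// visible here). Scalar sketch of one row:
+//   float s = 127.0f / m[r];
+//   for (int c = 0; c < 4096; ++c)
+//     q[r][c] = (int8_t) roundf(x[r][c] * s);
+//   dequant_scale[r] = m[r] / 127.0f;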
PIPE_MTE3, EVENT_ID0); + TSTORE(v72, v68); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + pto::Shape<1, 1, 1, 16, 1> v73 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v74 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v75 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v13), v73, v74); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v75, v51); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: sh_tile_fp32_inline88__rv_v2 + __gm__ Tensor* sh_tile_fp32_inline88__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* sh_tile_fp32_inline88__rv_v2 = reinterpret_cast<__gm__ float*>(sh_tile_fp32_inline88__rv_v2_tensor->buffer.addr) + sh_tile_fp32_inline88__rv_v2_tensor->start_offset; + + // Unpack tensor: sh_tile_i8_inline126__ssa_v0 + __gm__ Tensor* sh_tile_i8_inline126__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* sh_tile_i8_inline126__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(sh_tile_i8_inline126__ssa_v0_tensor->buffer.addr) + sh_tile_i8_inline126__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + sh_h_q(sh_tile_fp32_inline88__rv_v2, sh_tile_i8_inline126__ssa_v0, ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_swiglu.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_swiglu.cpp new file mode 100644 index 000000000..644df3f38 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_swiglu.cpp @@ -0,0 +1,143 @@ +// Kernel Function: sh_swiglu +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_swiglu(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4) { + unsigned v5 = 0; + const float v6 = 1.0f; + const int32_t v7 = 4096; + const int32_t v8 = 1; + const int32_t v9 = 256; + const int32_t v10 = 16; + const int64_t v11 = 49152; + const int64_t v12 = 32768; + const int64_t v13 = 16384; + const int64_t v14 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v15 = Tile(v10, v9); + uint64_t v16 = (uint64_t) v14; + TASSIGN(v15, v16); + pto::Shape<1, 1, 1, 16, 
256> v17 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v18 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v19 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v9 + v5 * (unsigned) v8), v17, v18); + TLOAD(v15, v19); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v20 = Tile(v10, v9); + uint64_t v21 = (uint64_t) v13; + TASSIGN(v20, v21); + pto::Shape<1, 1, 1, 16, 256> v22 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<4096, 4096, 4096, 256, 1> v23 = pto::Stride<4096, 4096, 4096, 256, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND> v24 = GlobalTensor, pto::Stride<4096, 4096, 4096, 256, 1>, pto::Layout::ND>(v2 + (v5 + v5 * (unsigned) v9 + v5 * (unsigned) v8), v22, v23); + TLOAD(v20, v24); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v25 = Tile(v10, v9); + uint64_t v26 = (uint64_t) v12; + TASSIGN(v25, v26); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TNEG(v25, v15); + Tile v27 = Tile(v10, v9); + uint64_t v28 = (uint64_t) v12; + TASSIGN(v27, v28); + pipe_barrier(PIPE_V); + TEXP(v27, v25); + Tile v29 = Tile(v10, v9); + uint64_t v30 = (uint64_t) v12; + TASSIGN(v29, v30); + pipe_barrier(PIPE_V); + TADDS(v29, v27, v6); + Tile v31 = Tile(v10, v9); + uint64_t v32 = (uint64_t) v11; + TASSIGN(v31, v32); + pipe_barrier(PIPE_V); + TRECIP(v31, v29); + Tile v33 = Tile(v10, v9); + uint64_t v34 = (uint64_t) v14; + TASSIGN(v33, v34); + pipe_barrier(PIPE_V); + TMUL(v33, v15, v31); + Tile v35 = Tile(v10, v9); + uint64_t v36 = (uint64_t) v14; + TASSIGN(v35, v36); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMUL(v35, v33, v20); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v37 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v38 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v39 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v3 + (v5 + v5 * (unsigned) v7 + (unsigned) v4 * (unsigned) v8), v37, v38); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v39, v35); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: sh_gate_v1_inline50__ssa_v0 + __gm__ Tensor* sh_gate_v1_inline50__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* sh_gate_v1_inline50__ssa_v0 = reinterpret_cast<__gm__ float*>(sh_gate_v1_inline50__ssa_v0_tensor->buffer.addr) + sh_gate_v1_inline50__ssa_v0_tensor->start_offset; + + // Unpack tensor: sh_up_v1_inline9__ssa_v0 + __gm__ Tensor* sh_up_v1_inline9__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* sh_up_v1_inline9__ssa_v0 = reinterpret_cast<__gm__ float*>(sh_up_v1_inline9__ssa_v0_tensor->buffer.addr) + sh_up_v1_inline9__ssa_v0_tensor->start_offset; + + // Unpack tensor: sh_tile_fp32_inline88__iter_v1 + __gm__ Tensor* sh_tile_fp32_inline88__iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* sh_tile_fp32_inline88__iter_v1 = reinterpret_cast<__gm__ float*>(sh_tile_fp32_inline88__iter_v1_tensor->buffer.addr) + sh_tile_fp32_inline88__iter_v1_tensor->start_offset; + + // Unpack scalar: n0_inline113__idx_v0 + union { uint64_t u64; int64_t val; } n0_inline113__idx_v0_conv; + 
n0_inline113__idx_v0_conv.u64 = args[3]; + int64_t n0_inline113__idx_v0 = n0_inline113__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_swiglu(sh_gate_v1_inline50__ssa_v0, sh_up_v1_inline9__ssa_v0, sh_tile_fp32_inline88__iter_v1, n0_inline113__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_w2_dequant.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_w2_dequant.cpp new file mode 100644 index 000000000..2cdb6bce2 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_w2_dequant.cpp @@ -0,0 +1,143 @@ +// Kernel Function: sh_w2_dequant +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_w2_dequant(__gm__ float* v1, __gm__ int32_t* v2, __gm__ float* v3, __gm__ float* v4, int32_t v5) { + RoundMode v6 = RoundMode::CAST_NONE; + unsigned v7 = 0; + const int32_t v8 = 512; + const int32_t v9 = 16; + const int32_t v10 = 1; + const int64_t v11 = 34880; + const int64_t v12 = 32832; + const int64_t v13 = 32768; + const int64_t v14 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v15 = Tile(v9, v8); + uint64_t v16 = (uint64_t) v14; + TASSIGN(v15, v16); + pto::Shape<1, 1, 1, 16, 512> v17 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v18 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v19 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v2 + (v7 + v7 * (unsigned) v8 + v7 * (unsigned) v10), v17, v18); + TLOAD(v15, v19); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v20 = Tile(v9, v10); + uint64_t v21 = (uint64_t) v13; + TASSIGN(v20, v21); + pto::Shape<1, 1, 1, 16, 1> v22 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v23 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v24 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v10 + v7 * (unsigned) v9), v22, v23); + TLOAD(v20, v24); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v25 = Tile(v10, v8); + uint64_t v26 = (uint64_t) v12; + TASSIGN(v25, v26); + pto::Shape<1, 1, 1, 1, 512> v27 = pto::Shape<1, 1, 1, 1, 512>(); + pto::Stride<512, 512, 512, 512, 1> v28 = pto::Stride<512, 512, 512, 512, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 512, 1>, pto::Layout::ND> v29 = GlobalTensor, pto::Stride<512, 512, 512, 512, 1>, pto::Layout::ND>(v1 + (v7 + (unsigned) v5 * (unsigned) v10), v27, v28); + TLOAD(v25, v29); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + Tile v30 = Tile(v10, v8); + uint64_t v31 = (uint64_t) v12; + TASSIGN(v30, v31); + Tile v32 = Tile(v9, v8); + uint64_t v33 = (uint64_t) v11; + TASSIGN(v32, v33); + wait_flag(PIPE_MTE2, PIPE_V, 
EVENT_ID0); + TCVT(v32, v15, v6); + Tile v34 = Tile(v9, v8); + uint64_t v35 = (uint64_t) v11; + TASSIGN(v34, v35); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TROWEXPANDMUL(v34, v32, v20); + Tile v36 = Tile(v9, v8); + uint64_t v37 = (uint64_t) v11; + TASSIGN(v36, v37); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TCOLEXPANDMUL(v36, v34, v30); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v38 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v39 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v40 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v4 + (v7 + v7 * (unsigned) v8 + v7 * (unsigned) v10), v38, v39); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v40, v36); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: shared_w2_scale__ssa_v0 + __gm__ Tensor* shared_w2_scale__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* shared_w2_scale__ssa_v0 = reinterpret_cast<__gm__ float*>(shared_w2_scale__ssa_v0_tensor->buffer.addr) + shared_w2_scale__ssa_v0_tensor->start_offset; + + // Unpack tensor: sh_y_acc_inline4__rv_v2 + __gm__ Tensor* sh_y_acc_inline4__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int32_t* sh_y_acc_inline4__rv_v2 = reinterpret_cast<__gm__ int32_t*>(sh_y_acc_inline4__rv_v2_tensor->buffer.addr) + sh_y_acc_inline4__rv_v2_tensor->start_offset; + + // Unpack tensor: sh_tile_scale_dq_inline99__ssa_v0 + __gm__ Tensor* sh_tile_scale_dq_inline99__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* sh_tile_scale_dq_inline99__ssa_v0 = reinterpret_cast<__gm__ float*>(sh_tile_scale_dq_inline99__ssa_v0_tensor->buffer.addr) + sh_tile_scale_dq_inline99__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: d0_inline64__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline64__idx_v0_conv; + d0_inline64__idx_v0_conv.u64 = args[4]; + int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_w2_dequant(shared_w2_scale__ssa_v0, sh_y_acc_inline4__rv_v2, sh_tile_scale_dq_inline99__ssa_v0, ret0__out, d0_inline64__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_write.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_write.cpp new file mode 100644 index 000000000..492672961 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/sh_write.cpp @@ -0,0 +1,103 @@ +// Kernel Function: sh_write +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + 
switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void sh_write(__gm__ float* v1, __gm__ bfloat16_t* v2, int32_t v3) { + RoundMode v4 = RoundMode::CAST_RINT; + unsigned v5 = 0; + const int32_t v6 = 4096; + const int32_t v7 = 1; + const int32_t v8 = 512; + const int32_t v9 = 16; + const int64_t v10 = 32768; + const int64_t v11 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v12 = Tile(v9, v8); + uint64_t v13 = (uint64_t) v11; + TASSIGN(v12, v13); + pto::Shape<1, 1, 1, 16, 512> v14 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<8192, 8192, 8192, 512, 1> v15 = pto::Stride<8192, 8192, 8192, 512, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v16 = GlobalTensor, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v1 + (v5 + v5 * (unsigned) v8 + v5 * (unsigned) v7), v14, v15); + TLOAD(v12, v16); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v17 = Tile(v9, v8); + uint64_t v18 = (uint64_t) v10; + TASSIGN(v17, v18); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v17, v12, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 512> v19 = pto::Shape<1, 1, 1, 16, 512>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v20 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v21 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v5 + v5 * (unsigned) v6 + (unsigned) v3 * (unsigned) v7), v19, v20); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v21, v17); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: sh_y_v1_inline114__ssa_v0 + __gm__ Tensor* sh_y_v1_inline114__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* sh_y_v1_inline114__ssa_v0 = reinterpret_cast<__gm__ float*>(sh_y_v1_inline114__ssa_v0_tensor->buffer.addr) + sh_y_v1_inline114__ssa_v0_tensor->start_offset; + + // Unpack tensor: sh__iter_v1 + __gm__ Tensor* sh__iter_v1_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ bfloat16_t* sh__iter_v1 = reinterpret_cast<__gm__ bfloat16_t*>(sh__iter_v1_tensor->buffer.addr) + sh__iter_v1_tensor->start_offset; + + // Unpack scalar: d0_inline64__idx_v0 + union { uint64_t u64; int64_t val; } d0_inline64__idx_v0_conv; + d0_inline64__idx_v0_conv.u64 = args[2]; + int64_t d0_inline64__idx_v0 = d0_inline64__idx_v0_conv.val; + + // Forward to ptoas-generated function + sh_write(sh_y_v1_inline114__ssa_v0, sh__iter_v1, d0_inline64__idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post.cpp new file mode 100644 index 000000000..453ae14d7 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post.cpp @@ -0,0 +1,83 @@ +// Kernel Function: split_pre_post +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using 
namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void split_pre_post(__gm__ float* v1) { + unsigned v2 = 0; + const float v3 = 1.0f; + const int32_t v4 = 1; + const int32_t v5 = 8; + const int32_t v6 = 16; + const int64_t v7 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v8 = Tile(v6, v5); + uint64_t v9 = (uint64_t) v7; + TASSIGN(v8, v9); + TEXPANDS(v8, v3); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 8> v10 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<128, 128, 128, 8, 1> v11 = pto::Stride<128, 128, 128, 8, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND> v12 = GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND>(v1 + (v2 + v2 * (unsigned) v5 + v2 * (unsigned) v4), v10, v11); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v12, v8); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + split_pre_post(ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_0.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_0.cpp new file mode 100644 index 000000000..bf6a0f230 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_0.cpp @@ -0,0 +1,163 @@ +// Kernel Function: split_pre_post_0 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void split_pre_post_0(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, float v5) { + unsigned v6 = 0; + const float v7 = 9.99999997E-7f; + const float v8 = 1.0f; + const int32_t v9 = 1; + const int32_t v10 = 32; + const int32_t v11 = 8; + const int32_t v12 = 16; + const int64_t v13 = 1024; + const int64_t v14 = 512; + const int64_t v15 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v16 = Tile(v12, v11); 
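+    // Editorial note (inferred from the SSA below, not part of the generated
+    // source): this body evaluates
+    //     ret0 = sigmoid(scale0 * t + pre_base) + 1e-6
+    // with pre_base broadcast row-wise through the ones_hc tile
+    // (TCOLEXPANDMUL), and the TNEG/TEXP/TADDS/TRECIP chain forming
+    // 1 / (1 + exp(-x)).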
+ uint64_t v17 = (uint64_t) v15; + TASSIGN(v16, v17); + pto::Shape<1, 1, 1, 16, 8> v18 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<512, 512, 512, 32, 1> v19 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v20 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v10 + v6 * (unsigned) v9), v18, v19); + TLOAD(v16, v20); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v21 = Tile(v12, v11); + uint64_t v22 = (uint64_t) v14; + TASSIGN(v21, v22); + pto::Shape<1, 1, 1, 16, 8> v23 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<128, 128, 128, 8, 1> v24 = pto::Stride<128, 128, 128, 8, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND> v25 = GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v9), v23, v24); + TLOAD(v21, v25); + Tile v26 = Tile(v9, v11); + uint64_t v27 = (uint64_t) v13; + TASSIGN(v26, v27); + pto::Shape<1, 1, 1, 1, 8> v28 = pto::Shape<1, 1, 1, 1, 8>(); + pto::Stride<8, 8, 8, 8, 1> v29 = pto::Stride<8, 8, 8, 8, 1>(); + GlobalTensor, pto::Stride<8, 8, 8, 8, 1>, pto::Layout::ND> v30 = GlobalTensor, pto::Stride<8, 8, 8, 8, 1>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v9), v28, v29); + TLOAD(v26, v30); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v31 = Tile(v12, v11); + uint64_t v32 = (uint64_t) v15; + TASSIGN(v31, v32); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v31, v16, v5); + Tile v33 = Tile(v12, v11); + uint64_t v34 = (uint64_t) v14; + TASSIGN(v33, v34); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCOLEXPANDMUL(v33, v21, v26); + Tile v35 = Tile(v12, v11); + uint64_t v36 = (uint64_t) v15; + TASSIGN(v35, v36); + pipe_barrier(PIPE_V); + TADD(v35, v31, v33); + Tile v37 = Tile(v12, v11); + uint64_t v38 = (uint64_t) v15; + TASSIGN(v37, v38); + pipe_barrier(PIPE_V); + TNEG(v37, v35); + Tile v39 = Tile(v12, v11); + uint64_t v40 = (uint64_t) v15; + TASSIGN(v39, v40); + pipe_barrier(PIPE_V); + TEXP(v39, v37); + Tile v41 = Tile(v12, v11); + uint64_t v42 = (uint64_t) v15; + TASSIGN(v41, v42); + pipe_barrier(PIPE_V); + TADDS(v41, v39, v8); + Tile v43 = Tile(v12, v11); + uint64_t v44 = (uint64_t) v14; + TASSIGN(v43, v44); + pipe_barrier(PIPE_V); + TRECIP(v43, v41); + Tile v45 = Tile(v12, v11); + uint64_t v46 = (uint64_t) v15; + TASSIGN(v45, v46); + pipe_barrier(PIPE_V); + TADDS(v45, v43, v7); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 8> v47 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<128, 128, 128, 8, 1> v48 = pto::Stride<128, 128, 128, 8, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND> v49 = GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND>(v4 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v9), v47, v48); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v49, v45); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: t__tmp_v26 + __gm__ Tensor* t__tmp_v26_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* t__tmp_v26 = reinterpret_cast<__gm__ float*>(t__tmp_v26_tensor->buffer.addr) + t__tmp_v26_tensor->start_offset; + + // Unpack tensor: ones_hc_inline79__ssa_v0 + __gm__ Tensor* ones_hc_inline79__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* ones_hc_inline79__ssa_v0 = 
reinterpret_cast<__gm__ float*>(ones_hc_inline79__ssa_v0_tensor->buffer.addr) + ones_hc_inline79__ssa_v0_tensor->start_offset; + + // Unpack tensor: pre_base_inline76__ssa_v0 + __gm__ Tensor* pre_base_inline76__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* pre_base_inline76__ssa_v0 = reinterpret_cast<__gm__ float*>(pre_base_inline76__ssa_v0_tensor->buffer.addr) + pre_base_inline76__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: scale0_inline21__ssa_v0 + union { uint64_t u64; float val; } scale0_inline21__ssa_v0_conv; + scale0_inline21__ssa_v0_conv.u64 = args[4]; + float scale0_inline21__ssa_v0 = scale0_inline21__ssa_v0_conv.val; + + // Forward to ptoas-generated function + split_pre_post_0(t__tmp_v26, ones_hc_inline79__ssa_v0, pre_base_inline76__ssa_v0, ret0__out, scale0_inline21__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_1.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_1.cpp new file mode 100644 index 000000000..d60eb2b83 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_1.cpp @@ -0,0 +1,178 @@ +// Kernel Function: split_pre_post_1 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void split_pre_post_1(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, __gm__ float* v5, float v6) { + unsigned v7 = 0; + const float v8 = 2.0f; + const float v9 = 1.0f; + const int32_t v10 = 1; + const int32_t v11 = 32; + const int32_t v12 = 8; + const int32_t v13 = 16; + const int64_t v14 = 0; + const int64_t v15 = 2048; + const int64_t v16 = 1536; + const int64_t v17 = 1024; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v18 = Tile(v13, v12); + uint64_t v19 = (uint64_t) v17; + TASSIGN(v18, v19); + pto::Shape<1, 1, 1, 16, 8> v20 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<512, 512, 512, 32, 1> v21 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v22 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v7 + v7 * (unsigned) v11 + v7 * (unsigned) v10), v20, v21); + TLOAD(v18, v22); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v23 = Tile(v13, v12); + uint64_t v24 = (uint64_t) v16; + TASSIGN(v23, v24); + pto::Shape<1, 1, 1, 16, 8> v25 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<128, 128, 128, 8, 1> v26 = pto::Stride<128, 128, 128, 8, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND> v27 = GlobalTensor, 
pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND>(v2 + (v7 + v7 * (unsigned) v12 + v7 * (unsigned) v10), v25, v26); + TLOAD(v23, v27); + Tile v28 = Tile(v10, v12); + uint64_t v29 = (uint64_t) v15; + TASSIGN(v28, v29); + pto::Shape<1, 1, 1, 1, 8> v30 = pto::Shape<1, 1, 1, 1, 8>(); + pto::Stride<8, 8, 8, 8, 1> v31 = pto::Stride<8, 8, 8, 8, 1>(); + GlobalTensor, pto::Stride<8, 8, 8, 8, 1>, pto::Layout::ND> v32 = GlobalTensor, pto::Stride<8, 8, 8, 8, 1>, pto::Layout::ND>(v3 + (v7 + v7 * (unsigned) v12 + v7 * (unsigned) v10), v30, v31); + TLOAD(v28, v32); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v33 = Tile(v13, v12); + uint64_t v34 = (uint64_t) v17; + TASSIGN(v33, v34); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v33, v18, v6); + Tile v35 = Tile(v13, v12); + uint64_t v36 = (uint64_t) v16; + TASSIGN(v35, v36); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCOLEXPANDMUL(v35, v23, v28); + Tile v37 = Tile(v13, v12); + uint64_t v38 = (uint64_t) v17; + TASSIGN(v37, v38); + pipe_barrier(PIPE_V); + TADD(v37, v33, v35); + Tile v39 = Tile(v13, v12); + uint64_t v40 = (uint64_t) v17; + TASSIGN(v39, v40); + pipe_barrier(PIPE_V); + TNEG(v39, v37); + Tile v41 = Tile(v13, v12); + uint64_t v42 = (uint64_t) v17; + TASSIGN(v41, v42); + pipe_barrier(PIPE_V); + TEXP(v41, v39); + Tile v43 = Tile(v13, v12); + uint64_t v44 = (uint64_t) v17; + TASSIGN(v43, v44); + pipe_barrier(PIPE_V); + TADDS(v43, v41, v9); + Tile v45 = Tile(v13, v12); + uint64_t v46 = (uint64_t) v16; + TASSIGN(v45, v46); + pipe_barrier(PIPE_V); + TRECIP(v45, v43); + Tile v47 = Tile(v13, v12); + uint64_t v48 = (uint64_t) v17; + TASSIGN(v47, v48); + pipe_barrier(PIPE_V); + TMULS(v47, v45, v8); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + Tile v49 = Tile(v13, v13); + uint64_t v50 = (uint64_t) v14; + TASSIGN(v49, v50); + TEXPANDS(v49, v9); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + pto::Shape<1, 1, 1, 16, 8> v51 = pto::Shape<1, 1, 1, 16, 8>(); + pto::Stride<128, 128, 128, 8, 1> v52 = pto::Stride<128, 128, 128, 8, 1>(); + GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND> v53 = GlobalTensor, pto::Stride<128, 128, 128, 8, 1>, pto::Layout::ND>(v4 + (v7 + v7 * (unsigned) v12 + v7 * (unsigned) v10), v51, v52); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v53, v47); + pto::Shape<1, 1, 1, 16, 16> v54 = pto::Shape<1, 1, 1, 16, 16>(); + pto::Stride<256, 256, 256, 16, 1> v55 = pto::Stride<256, 256, 256, 16, 1>(); + GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v56 = GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v5 + (v7 + v7 * (unsigned) v13 + v7 * (unsigned) v10), v54, v55); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v56, v49); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: t__tmp_v34 + __gm__ Tensor* t__tmp_v34_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* t__tmp_v34 = reinterpret_cast<__gm__ float*>(t__tmp_v34_tensor->buffer.addr) + t__tmp_v34_tensor->start_offset; + + // Unpack tensor: ones_hc_inline79__ssa_v0 + __gm__ Tensor* ones_hc_inline79__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* ones_hc_inline79__ssa_v0 = reinterpret_cast<__gm__ float*>(ones_hc_inline79__ssa_v0_tensor->buffer.addr) + ones_hc_inline79__ssa_v0_tensor->start_offset; + + // Unpack tensor: post_base_inline50__ssa_v0 + __gm__ Tensor* 
post_base_inline50__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* post_base_inline50__ssa_v0 = reinterpret_cast<__gm__ float*>(post_base_inline50__ssa_v0_tensor->buffer.addr) + post_base_inline50__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack tensor: ret1__out + __gm__ Tensor* ret1__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[4]); + __gm__ float* ret1__out = reinterpret_cast<__gm__ float*>(ret1__out_tensor->buffer.addr) + ret1__out_tensor->start_offset; + + // Unpack scalar: scale1_inline17__ssa_v0 + union { uint64_t u64; float val; } scale1_inline17__ssa_v0_conv; + scale1_inline17__ssa_v0_conv.u64 = args[5]; + float scale1_inline17__ssa_v0 = scale1_inline17__ssa_v0_conv.val; + + // Forward to ptoas-generated function + split_pre_post_1(t__tmp_v34, ones_hc_inline79__ssa_v0, post_base_inline50__ssa_v0, ret0__out, ret1__out, scale1_inline17__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_2.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_2.cpp new file mode 100644 index 000000000..6f0848188 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/split_pre_post_2.cpp @@ -0,0 +1,135 @@ +// Kernel Function: split_pre_post_2 +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void split_pre_post_2(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, float v5) { + unsigned v6 = 0; + const int32_t v7 = 1; + const int32_t v8 = 32; + const int32_t v9 = 16; + const int64_t v10 = 2048; + const int64_t v11 = 1024; + const int64_t v12 = 0; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + Tile v13 = Tile(v9, v9); + uint64_t v14 = (uint64_t) v12; + TASSIGN(v13, v14); + pto::Shape<1, 1, 1, 16, 16> v15 = pto::Shape<1, 1, 1, 16, 16>(); + pto::Stride<512, 512, 512, 32, 1> v16 = pto::Stride<512, 512, 512, 32, 1>(); + GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND> v17 = GlobalTensor, pto::Stride<512, 512, 512, 32, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v8 + v6 * (unsigned) v7), v15, v16); + TLOAD(v13, v17); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v18 = Tile(v9, v9); + uint64_t v19 = (uint64_t) v11; + TASSIGN(v18, v19); + pto::Shape<1, 1, 1, 16, 16> v20 = pto::Shape<1, 1, 1, 16, 16>(); + pto::Stride<256, 256, 256, 16, 1> v21 = pto::Stride<256, 256, 256, 16, 1>(); + GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v22 = GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v2 + (v6 + v6 * 
(unsigned) v9 + v6 * (unsigned) v7), v20, v21); + TLOAD(v18, v22); + Tile v23 = Tile(v7, v9); + uint64_t v24 = (uint64_t) v10; + TASSIGN(v23, v24); + pto::Shape<1, 1, 1, 1, 16> v25 = pto::Shape<1, 1, 1, 1, 16>(); + pto::Stride<16, 16, 16, 16, 1> v26 = pto::Stride<16, 16, 16, 16, 1>(); + GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND> v27 = GlobalTensor, pto::Stride<16, 16, 16, 16, 1>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v7), v25, v26); + TLOAD(v23, v27); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v28 = Tile(v9, v9); + uint64_t v29 = (uint64_t) v12; + TASSIGN(v28, v29); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v28, v13, v5); + Tile v30 = Tile(v9, v9); + uint64_t v31 = (uint64_t) v11; + TASSIGN(v30, v31); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCOLEXPANDMUL(v30, v18, v23); + Tile v32 = Tile(v9, v9); + uint64_t v33 = (uint64_t) v12; + TASSIGN(v32, v33); + pipe_barrier(PIPE_V); + TADD(v32, v28, v30); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 16> v34 = pto::Shape<1, 1, 1, 16, 16>(); + pto::Stride<256, 256, 256, 16, 1> v35 = pto::Stride<256, 256, 256, 16, 1>(); + GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND> v36 = GlobalTensor, pto::Stride<256, 256, 256, 16, 1>, pto::Layout::ND>(v4 + (v6 + v6 * (unsigned) v9 + v6 * (unsigned) v7), v34, v35); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v36, v32); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: comb_mix_inline41__ssa_v0 + __gm__ Tensor* comb_mix_inline41__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* comb_mix_inline41__ssa_v0 = reinterpret_cast<__gm__ float*>(comb_mix_inline41__ssa_v0_tensor->buffer.addr) + comb_mix_inline41__ssa_v0_tensor->start_offset; + + // Unpack tensor: ones_comb_inline67__ssa_v0 + __gm__ Tensor* ones_comb_inline67__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* ones_comb_inline67__ssa_v0 = reinterpret_cast<__gm__ float*>(ones_comb_inline67__ssa_v0_tensor->buffer.addr) + ones_comb_inline67__ssa_v0_tensor->start_offset; + + // Unpack tensor: comb_base_inline70__ssa_v0 + __gm__ Tensor* comb_base_inline70__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* comb_base_inline70__ssa_v0 = reinterpret_cast<__gm__ float*>(comb_base_inline70__ssa_v0_tensor->buffer.addr) + comb_base_inline70__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Unpack scalar: scale2_inline26__ssa_v0 + union { uint64_t u64; float val; } scale2_inline26__ssa_v0_conv; + scale2_inline26__ssa_v0_conv.u64 = args[4]; + float scale2_inline26__ssa_v0 = scale2_inline26__ssa_v0_conv.val; + + // Forward to ptoas-generated function + split_pre_post_2(comb_mix_inline41__ssa_v0, ones_comb_inline67__ssa_v0, comb_base_inline70__ssa_v0, ret0__out, scale2_inline26__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_post.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_post.cpp new file mode 100644 index 000000000..6fb3e9515 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_post.cpp @@ -0,0 
+1,101 @@ +// Kernel Function: write_post +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void write_post(__gm__ float* v1, __gm__ float* v2, int32_t v3) { + const int32_t v4 = 3; + const int32_t v5 = 2; + const int32_t v6 = 4; + const int32_t v7 = 8; + const int32_t v8 = 16; + const int32_t v9 = 0; + const int32_t v10 = 1; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + for (size_t v11 = (size_t) v9; v11 < ((size_t) v8); v11 += (size_t) v10) { + int32_t v12 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v3 * (uint32_t) v8) + (uint32_t) ((int32_t) v11)); + if (v12 < v8) { + int32_t v13 = (int32_t) ((uint32_t) v12 * (uint32_t) v7); + float v14 = v1[v13]; + int32_t v15 = (int32_t) ((uint32_t) v12 * (uint32_t) v6); + v2[v15] = v14; + float v16 = v1[(int32_t) ((uint32_t) v13 + (uint32_t) v10)]; + int32_t v17 = (int32_t) ((uint32_t) v15 + (uint32_t) v10); + v2[v17] = v16; + float v18 = v1[(int32_t) ((uint32_t) v13 + (uint32_t) v5)]; + int32_t v19 = (int32_t) ((uint32_t) v15 + (uint32_t) v5); + v2[v19] = v18; + float v20 = v1[(int32_t) ((uint32_t) v13 + (uint32_t) v4)]; + int32_t v21 = (int32_t) ((uint32_t) v15 + (uint32_t) v4); + v2[v21] = v20; + }; + } + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: post_pad_flat_inline73__ssa_v0 + __gm__ Tensor* post_pad_flat_inline73__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ float* post_pad_flat_inline73__ssa_v0 = reinterpret_cast<__gm__ float*>(post_pad_flat_inline73__ssa_v0_tensor->buffer.addr) + post_pad_flat_inline73__ssa_v0_tensor->start_offset; + + // Unpack tensor: post_flat_inline51__ssa_v0 + __gm__ Tensor* post_flat_inline51__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ float* post_flat_inline51__ssa_v0 = reinterpret_cast<__gm__ float*>(post_flat_inline51__ssa_v0_tensor->buffer.addr) + post_flat_inline51__ssa_v0_tensor->start_offset; + + // Unpack scalar: t_inline7__co_idx_v0 + union { uint64_t u64; int64_t val; } t_inline7__co_idx_v0_conv; + t_inline7__co_idx_v0_conv.u64 = args[2]; + int64_t t_inline7__co_idx_v0 = t_inline7__co_idx_v0_conv.val; + + // Forward to ptoas-generated function + write_post(post_pad_flat_inline73__ssa_v0, post_flat_inline51__ssa_v0, t_inline7__co_idx_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_route_outputs.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_route_outputs.cpp new file mode 100644 index 000000000..af070e9b6 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/write_route_outputs.cpp @@ -0,0 +1,281 @@ +// 
Kernel Function: write_route_outputs +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" +#include "aicore/aicore.h" // dcci — flush scalar writes out of L1 so the + // downstream dispatch kernel sees them via scalar + // reads. Without this, indices/weights stay in + // this core's data cache and dispatch races + // ~60% of runs in onboard mode. + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void write_route_outputs(__gm__ int32_t* v1, __gm__ int32_t* v2, __gm__ float* v3, __gm__ float* v4) { + const int32_t v5 = 31; + const int32_t v6 = 481; + const int32_t v7 = 30; + const int32_t v8 = 480; + const int32_t v9 = 29; + const int32_t v10 = 449; + const int32_t v11 = 28; + const int32_t v12 = 448; + const int32_t v13 = 27; + const int32_t v14 = 417; + const int32_t v15 = 26; + const int32_t v16 = 416; + const int32_t v17 = 25; + const int32_t v18 = 385; + const int32_t v19 = 24; + const int32_t v20 = 384; + const int32_t v21 = 23; + const int32_t v22 = 353; + const int32_t v23 = 22; + const int32_t v24 = 352; + const int32_t v25 = 21; + const int32_t v26 = 321; + const int32_t v27 = 20; + const int32_t v28 = 320; + const int32_t v29 = 19; + const int32_t v30 = 289; + const int32_t v31 = 18; + const int32_t v32 = 288; + const int32_t v33 = 17; + const int32_t v34 = 257; + const int32_t v35 = 16; + const int32_t v36 = 256; + const int32_t v37 = 15; + const int32_t v38 = 225; + const int32_t v39 = 14; + const int32_t v40 = 224; + const int32_t v41 = 13; + const int32_t v42 = 193; + const int32_t v43 = 12; + const int32_t v44 = 192; + const int32_t v45 = 11; + const int32_t v46 = 161; + const int32_t v47 = 10; + const int32_t v48 = 160; + const int32_t v49 = 9; + const int32_t v50 = 129; + const int32_t v51 = 8; + const int32_t v52 = 128; + const int32_t v53 = 7; + const int32_t v54 = 97; + const int32_t v55 = 6; + const int32_t v56 = 96; + const int32_t v57 = 5; + const int32_t v58 = 65; + const int32_t v59 = 4; + const int32_t v60 = 64; + const int32_t v61 = 3; + const int32_t v62 = 33; + const int32_t v63 = 2; + const int32_t v64 = 0; + const int32_t v65 = 32; + const int32_t v66 = 1; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + int32_t v67 = v1[v64]; + v2[v64] = v67; + float v68 = v3[v64]; + v4[v64] = v68; + int32_t v69 = v1[v66]; + v2[v66] = v69; + float v70 = v3[v66]; + v4[v66] = v70; + int32_t v71 = v1[v65]; + v2[v63] = v71; + float v72 = v3[v65]; + v4[v63] = v72; + int32_t v73 = v1[v62]; + v2[v61] = v73; + float v74 = v3[v62]; + v4[v61] = v74; + int32_t v75 = v1[v60]; + v2[v59] = v75; + float v76 = v3[v60]; + v4[v59] = v76; + int32_t v77 = v1[v58]; + v2[v57] = v77; + float v78 = v3[v58]; + v4[v57] = v78; + int32_t v79 = v1[v56]; + v2[v55] = v79; + float v80 = v3[v56]; + v4[v55] = v80; + int32_t v81 = v1[v54]; + v2[v53] = v81; + float v82 = 
v3[v54]; + v4[v53] = v82; + int32_t v83 = v1[v52]; + v2[v51] = v83; + float v84 = v3[v52]; + v4[v51] = v84; + int32_t v85 = v1[v50]; + v2[v49] = v85; + float v86 = v3[v50]; + v4[v49] = v86; + int32_t v87 = v1[v48]; + v2[v47] = v87; + float v88 = v3[v48]; + v4[v47] = v88; + int32_t v89 = v1[v46]; + v2[v45] = v89; + float v90 = v3[v46]; + v4[v45] = v90; + int32_t v91 = v1[v44]; + v2[v43] = v91; + float v92 = v3[v44]; + v4[v43] = v92; + int32_t v93 = v1[v42]; + v2[v41] = v93; + float v94 = v3[v42]; + v4[v41] = v94; + int32_t v95 = v1[v40]; + v2[v39] = v95; + float v96 = v3[v40]; + v4[v39] = v96; + int32_t v97 = v1[v38]; + v2[v37] = v97; + float v98 = v3[v38]; + v4[v37] = v98; + int32_t v99 = v1[v36]; + v2[v35] = v99; + float v100 = v3[v36]; + v4[v35] = v100; + int32_t v101 = v1[v34]; + v2[v33] = v101; + float v102 = v3[v34]; + v4[v33] = v102; + int32_t v103 = v1[v32]; + v2[v31] = v103; + float v104 = v3[v32]; + v4[v31] = v104; + int32_t v105 = v1[v30]; + v2[v29] = v105; + float v106 = v3[v30]; + v4[v29] = v106; + int32_t v107 = v1[v28]; + v2[v27] = v107; + float v108 = v3[v28]; + v4[v27] = v108; + int32_t v109 = v1[v26]; + v2[v25] = v109; + float v110 = v3[v26]; + v4[v25] = v110; + int32_t v111 = v1[v24]; + v2[v23] = v111; + float v112 = v3[v24]; + v4[v23] = v112; + int32_t v113 = v1[v22]; + v2[v21] = v113; + float v114 = v3[v22]; + v4[v21] = v114; + int32_t v115 = v1[v20]; + v2[v19] = v115; + float v116 = v3[v20]; + v4[v19] = v116; + int32_t v117 = v1[v18]; + v2[v17] = v117; + float v118 = v3[v18]; + v4[v17] = v118; + int32_t v119 = v1[v16]; + v2[v15] = v119; + float v120 = v3[v16]; + v4[v15] = v120; + int32_t v121 = v1[v14]; + v2[v13] = v121; + float v122 = v3[v14]; + v4[v13] = v122; + int32_t v123 = v1[v12]; + v2[v11] = v123; + float v124 = v3[v12]; + v4[v11] = v124; + int32_t v125 = v1[v10]; + v2[v9] = v125; + float v126 = v3[v10]; + v4[v9] = v126; + int32_t v127 = v1[v8]; + v2[v7] = v127; + float v128 = v3[v8]; + v4[v7] = v128; + int32_t v129 = v1[v6]; + v2[v5] = v129; + float v130 = v3[v6]; + v4[v5] = v130; + // Flush scalar writes (indices, weights) out of this AICore's L1 so the + // downstream dispatch task sees them when it scalar-reads `indices`. The + // PyPTO-emitted body has no tail sync, and dispatch's reader-side dcci on + // its own only invalidates the reader cache — both sides must participate. 
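+  // Editorial gloss of the call order below (inferred, not documented):
+  // drain all pipes so the scalar stores above have retired, issue dcci
+  // write-backs covering the two output buffers, then dsb so the flush
+  // completes before the task is reported done.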
+ pipe_barrier(PIPE_ALL); + dcci(v2, ENTIRE_DATA_CACHE, CACHELINE_OUT); + dcci(v4, ENTIRE_DATA_CACHE, CACHELINE_OUT); + dsb((mem_dsb_t)0); + #endif // __DAV_VEC__ + + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: topk_idx_flat_inline87__ssa_v0 + __gm__ Tensor* topk_idx_flat_inline87__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ int32_t* topk_idx_flat_inline87__ssa_v0 = reinterpret_cast<__gm__ int32_t*>(topk_idx_flat_inline87__ssa_v0_tensor->buffer.addr) + topk_idx_flat_inline87__ssa_v0_tensor->start_offset; + + // Unpack tensor: indices_flat_inline89__ssa_v0 + __gm__ Tensor* indices_flat_inline89__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int32_t* indices_flat_inline89__ssa_v0 = reinterpret_cast<__gm__ int32_t*>(indices_flat_inline89__ssa_v0_tensor->buffer.addr) + indices_flat_inline89__ssa_v0_tensor->start_offset; + + // Unpack tensor: weight_out_flat_inline86__ssa_v0 + __gm__ Tensor* weight_out_flat_inline86__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* weight_out_flat_inline86__ssa_v0 = reinterpret_cast<__gm__ float*>(weight_out_flat_inline86__ssa_v0_tensor->buffer.addr) + weight_out_flat_inline86__ssa_v0_tensor->start_offset; + + // Unpack tensor: weights_flat_inline88__ssa_v0 + __gm__ Tensor* weights_flat_inline88__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + __gm__ float* weights_flat_inline88__ssa_v0 = reinterpret_cast<__gm__ float*>(weights_flat_inline88__ssa_v0_tensor->buffer.addr) + weights_flat_inline88__ssa_v0_tensor->start_offset; + + // Forward to ptoas-generated function + write_route_outputs(topk_idx_flat_inline87__ssa_v0, indices_flat_inline89__ssa_v0, weight_out_flat_inline86__ssa_v0, weights_flat_inline88__ssa_v0); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/x_local_q.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/x_local_q.cpp new file mode 100644 index 000000000..9e4b38721 --- /dev/null +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/x_local_q.cpp @@ -0,0 +1,224 @@ +// Kernel Function: x_local_q +// Generated by PyPTO IR Compiler (PTO backend) + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#if defined(__CPU_SIM) +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif +#endif + +#include +#include "tensor.h" + + +using namespace pto; + + +// --- ptoas-generated code --- + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static __aicore__ inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +static __aicore__ void x_local_q(__gm__ bfloat16_t* v1, __gm__ int8_t* v2, __gm__ float* v3) { + RoundMode v4 = RoundMode::CAST_TRUNC; + RoundMode v5 = RoundMode::CAST_ROUND; + unsigned v6 = 0; + const float v7 = 127.0f; + const int32_t v8 = 256; + const int32_t v9 = 0; + const float v10 = 9.99999974E-5f; + const int32_t v11 = 1; + const int32_t v12 = 4096; + const int32_t v13 = 16; + const int64_t v14 = 24640; + const int64_t v15 = 16448; + const int64_t v16 = 64; + const int64_t v17 = 0; + const int64_t v18 = 69760; + const int64_t v19 = 53376; + 
const int64_t v20 = 36992; + const int64_t v21 = 28800; + const int64_t v22 = 28736; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v23 = (size_t) v12; + size_t v24 = (size_t) v9; + size_t v25 = (size_t) v8; + Tile v26 = Tile(v11, v13); + uint64_t v27 = (uint64_t) v22; + TASSIGN(v26, v27); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TEXPANDS(v26, v10); + for (size_t v28 = v24; v28 < v23; v28 += v25) { + Tile v29 = Tile(v13, v8); + uint64_t v30 = (uint64_t) v21; + TASSIGN(v29, v30); + pto::Shape<1, 1, 1, 16, 256> v31 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v32 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v33 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) ((int32_t) v28) * (unsigned) v11), v31, v32); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v29, v33); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + Tile v34 = Tile(v13, v8); + uint64_t v35 = (uint64_t) v20; + TASSIGN(v34, v35); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TCVT(v34, v29, v5); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Tile v36 = Tile(v13, v8); + uint64_t v37 = (uint64_t) v19; + TASSIGN(v36, v37); + pipe_barrier(PIPE_V); + TNEG(v36, v34); + Tile v38 = Tile(v13, v8); + uint64_t v39 = (uint64_t) v20; + TASSIGN(v38, v39); + pipe_barrier(PIPE_V); + TMAX(v38, v34, v36); + Tile v40 = Tile(v13, v8); + uint64_t v41 = (uint64_t) v19; + TASSIGN(v40, v41); + Tile v42 = Tile(v13, v11); + uint64_t v43 = (uint64_t) v18; + TASSIGN(v42, v43); + pipe_barrier(PIPE_V); + TROWMAX(v42, v38, v40); + Tile v44 = Tile(v11, v13); + uint64_t v45 = (uint64_t) v18; + TASSIGN(v44, v45); + Tile v46 = Tile(v11, v13); + uint64_t v47 = (uint64_t) v22; + TASSIGN(v46, v47); + pipe_barrier(PIPE_V); + TMAX(v46, v26, v44); + } + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + Tile v48 = Tile(v11, v13); + uint64_t v49 = (uint64_t) v17; + TASSIGN(v48, v49); + TEXPANDS(v48, v7); + Tile v50 = Tile(v11, v13); + uint64_t v51 = (uint64_t) v22; + TASSIGN(v50, v51); + pipe_barrier(PIPE_V); + TDIV(v50, v48, v26); + Tile v52 = Tile(v11, v13); + uint64_t v53 = (uint64_t) v17; + TASSIGN(v52, v53); + pipe_barrier(PIPE_V); + TRECIP(v52, v50); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + Tile v54 = Tile(v13, v11); + uint64_t v55 = (uint64_t) v17; + TASSIGN(v54, v55); + Tile v56 = Tile(v13, v11); + uint64_t v57 = (uint64_t) v22; + TASSIGN(v56, v57); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + for (size_t v58 = v24; v58 < v23; v58 += v25) { + int32_t v59 = (int32_t) v58; + Tile v60 = Tile(v13, v8); + uint64_t v61 = (uint64_t) v21; + TASSIGN(v60, v61); + pto::Shape<1, 1, 1, 16, 256> v62 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v63 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v64 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v12 + (unsigned) v59 * (unsigned) v11), v62, v63); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v60, v64); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + Tile v65 = Tile(v13, v8); + uint64_t v66 = (uint64_t) v20; + TASSIGN(v65, v66); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TCVT(v65, v60, v5); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + Tile v67 = Tile(v13, v8); + 
uint64_t v68 = (uint64_t) v20; + TASSIGN(v67, v68); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v67, v65, v56); + Tile v69 = Tile(v13, v8); + uint64_t v70 = (uint64_t) v16; + TASSIGN(v69, v70); + pipe_barrier(PIPE_V); + TCVT(v69, v67, v5); + Tile v71 = Tile(v13, v8); + uint64_t v72 = (uint64_t) v15; + TASSIGN(v71, v72); + pipe_barrier(PIPE_V); + TCVT(v71, v69, v5); + Tile v73 = Tile(v13, v8); + uint64_t v74 = (uint64_t) v14; + TASSIGN(v73, v74); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v73, v71, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 16, 256> v75 = pto::Shape<1, 1, 1, 16, 256>(); + pto::Stride<65536, 65536, 65536, 4096, 1> v76 = pto::Stride<65536, 65536, 65536, 4096, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v77 = GlobalTensor, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v2 + (v6 + v6 * (unsigned) v12 + (unsigned) v59 * (unsigned) v11), v75, v76); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v77, v73); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + pto::Shape<1, 1, 1, 16, 1> v78 = pto::Shape<1, 1, 1, 16, 1>(); + pto::Stride<16, 16, 16, 1, 16> v79 = pto::Stride<16, 16, 16, 1, 16>(); + GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN> v80 = GlobalTensor, pto::Stride<16, 16, 16, 1, 16>, pto::Layout::DN>(v3 + (v6 + v6 * (unsigned) v11 + v6 * (unsigned) v13), v78, v79); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v80, v54); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +// --- Kernel entry point --- +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) +{ + // Unpack tensor: x_local__ssa_v0 + __gm__ Tensor* x_local__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ bfloat16_t* x_local__ssa_v0 = reinterpret_cast<__gm__ bfloat16_t*>(x_local__ssa_v0_tensor->buffer.addr) + x_local__ssa_v0_tensor->start_offset; + + // Unpack tensor: x_local_i8_inline43__ssa_v0 + __gm__ Tensor* x_local_i8_inline43__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int8_t* x_local_i8_inline43__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(x_local_i8_inline43__ssa_v0_tensor->buffer.addr) + x_local_i8_inline43__ssa_v0_tensor->start_offset; + + // Unpack tensor: ret0__out + __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ float* ret0__out = reinterpret_cast<__gm__ float*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; + + // Forward to ptoas-generated function + x_local_q(x_local__ssa_v0, x_local_i8_inline43__ssa_v0, ret0__out); +} diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/orchestration/ep_dispatch_combine_orch.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/orchestration/ep_dispatch_combine_orch.cpp index cf2f10a1e..5e00419ce 100644 --- a/examples/workers/l3/ep_dispatch_combine/kernels/orchestration/ep_dispatch_combine_orch.cpp +++ b/examples/workers/l3/ep_dispatch_combine/kernels/orchestration/ep_dispatch_combine_orch.cpp @@ -9,42 +9,74 @@ * ----------------------------------------------------------------------------------------------------------- */ /** - * EP dispatch + local-expert + combine orchestration. + * moe_router + EP dispatch + moe_expert + combine orchestration. 
* - * Three child AIV kernels chained via the runtime; cross-kernel data flows - * through host-backed device tensors (recv_*_out / recv_y) rather than the - * HCCL window — only routed_y_buf and combine_done need cross-rank - * visibility, both still in the shared scratch. + * func_id 0..17 moe_router kernels ffn half-compress pre-mix (hc_pre) + + * RMSNorm + learned-score gate + top-k + + * weight normalize. Produces x_norm, + * indices, weights (plus post_ffn / + * comb_ffn for hc_post). + * func_id 18 dispatch.cpp EP count exchange + 3-channel push; + * reads chip-produced x_norm + indices, + * plus host-packed w_padded / idx_padded. + * func_id 19..35 moe_expert kernels routed local experts + shared expert + * (task ids +19; recv_expert_count read + * from host-known tensor(11)). + * func_id 36 combine.cpp TPUT recv_y rows -> routed_y_buf; + * reduce_sum -> routed_y. + * func_id 37 ffn_add.cpp ffn_out = routed_y + sh + * (single-layer spec, model.py:644-645). + * func_id 38 hc_post.cpp hc_post(ffn_out, x_hc, post_ffn, + * comb_ffn) -> y (next-layer x_hc). * - * func_id 0 dispatch.cpp count exchange + 3-channel push - * + stage-out + recv_count emission - * func_id 1 local_expert.cpp placeholder for moe_expert: - * recv_y[e, s, :] = recv_x[e, s, :] * recv_w[e, s] - * func_id 2 combine.cpp TPUT recv_y rows by recv_idx_out into - * routed_y_buf (relies on HCCL window - * zero-init), barrier, reduce_sum along - * TOPK -> routed_y FP32 - * - * tensor(0) indices INPUT [T, TOPK] INT32 - * tensor(1) x_norm INPUT [T, D] BF16 - * tensor(2) w_padded INPUT [T*TOPK, W_PAD=8] FP32 - * tensor(3) idx_padded INPUT [T*TOPK, IDX_PAD=8] INT32 - * tensor(4) recv_x_out OUTPUT_EXISTING [L, R, D] BF16 - * tensor(5) recv_w_out OUTPUT_EXISTING [L, R] FP32 - * tensor(6) recv_idx_out OUTPUT_EXISTING [L, R] INT32 - * tensor(7) recv_count_out OUTPUT_EXISTING [L, 1] INT32 - * tensor(8) recv_y OUTPUT_EXISTING [L, R, D] BF16 - * tensor(9) routed_y OUTPUT_EXISTING [T, D] FP32 - * tensor(10) scratch INOUT HCCL window slot - * scalar(0) nranks - * scalar(1) CommContext device pointer - * - * Tasks run sequentially because rt_submit_aiv_task dispatches in order. - * Cross-rank synchronization: dispatch's data_done barrier ends the dispatch - * step; combine's combine_done barrier (inside combine.cpp) ends combine. 
+ * tensor(0) x_hc INPUT (host) [B, S, HC_MULT, D] BF16 + * tensor(1) hc_ffn_fn INPUT (host) [MIX_HC, HC_DIM] FP32 + * tensor(2) hc_ffn_scale INPUT (host) [3] FP32 (orch reads .data_as) + * tensor(3) hc_ffn_base INPUT (host) [MIX_HC] FP32 + * tensor(4) norm_w INPUT (host) [D] FP32 + * tensor(5) gate_w INPUT (host) [N_EXPERTS, D] FP32 + * tensor(6) gate_bias INPUT (host) [N_EXPERTS] FP32 + * tensor(7) tid2eid INPUT (host) [VOCAB, TOPK] INT32 (unused at LAYER_ID >= N_HASH_LAYERS) + * tensor(8) input_ids INPUT (host) [B, S] INT64 (ditto) + * tensor(9) w_padded INPUT (host) [T*TOPK, W_PAD=8] FP32 (packed from golden weights) + * tensor(10) idx_padded INPUT (host) [T*TOPK, IDX_PAD=8] INT32 (packed from golden indices) + * tensor(11) recv_count_host INPUT (host) [L, 1] INT32 (orch reads .data_as for moe_expert loop bounds) + * tensor(12) expert_w1 INPUT (host) [L, MOE_INTER, D] INT8 + * tensor(13) expert_w1_scale INPUT (host) [L, MOE_INTER] FP32 + * tensor(14) expert_w3 INPUT (host) [L, MOE_INTER, D] INT8 + * tensor(15) expert_w3_scale INPUT (host) [L, MOE_INTER] FP32 + * tensor(16) expert_w2 INPUT (host) [L, D, MOE_INTER] INT8 + * tensor(17) expert_w2_scale INPUT (host) [L, D] FP32 + * tensor(18) shared_w1 INPUT (host) [MOE_INTER, D] INT8 + * tensor(19) shared_w1_scale INPUT (host) [MOE_INTER] FP32 + * tensor(20) shared_w3 INPUT (host) [MOE_INTER, D] INT8 + * tensor(21) shared_w3_scale INPUT (host) [MOE_INTER] FP32 + * tensor(22) shared_w2 INPUT (host) [D, MOE_INTER] INT8 + * tensor(23) shared_w2_scale INPUT (host) [D] FP32 + * tensor(24) x_norm OUTPUT_EXISTING [T, D] BF16 (router out -> dispatch + moe_expert) + * tensor(25) indices OUTPUT_EXISTING [T, TOPK] INT32 (router out -> dispatch) + * tensor(26) weights OUTPUT_EXISTING [T, TOPK] FP32 (router out; verification only) + * tensor(27) post_ffn OUTPUT_EXISTING [B, S, HC_MULT] FP32 (router out; unused) + * tensor(28) comb_ffn OUTPUT_EXISTING [B, S, HC_MULT, HC_MULT] FP32 (router out; unused) + * tensor(29) recv_x_out OUTPUT_EXISTING [L, R, D] BF16 + * tensor(30) recv_w_out OUTPUT_EXISTING [L, R] FP32 + * tensor(31) recv_idx_out OUTPUT_EXISTING [L, R] INT32 + * tensor(32) recv_count_out OUTPUT_EXISTING [L, 1] INT32 + * tensor(33) recv_y OUTPUT_EXISTING [L, R, D] BF16 + * tensor(34) sh OUTPUT_EXISTING [T, D] BF16 + * tensor(35) routed_y OUTPUT_EXISTING [T, D] FP32 + * tensor(36) ffn_out OUTPUT_EXISTING [T, D] BF16 (routed_y + sh) + * tensor(37) y OUTPUT_EXISTING [B, S, HC_MULT, D] BF16 (hc_post out) + * tensor(38) scratch INOUT HCCL window slot + * scalar(0) nranks + * scalar(1) CommContext device pointer */ +#include "runtime.h" + +#include #include +#include #include "pto_orchestration_api.h" @@ -54,24 +86,362 @@ __attribute__((visibility("default"))) PTO2OrchestrationConfig ep_dispatch_combine_orchestration_config(const ChipStorageTaskArgs &orch_args) { (void)orch_args; return PTO2OrchestrationConfig{ - .expected_arg_count = 13, // 11 tensors + 2 scalars + .expected_arg_count = 41, // 39 tensors + 2 scalars }; } __attribute__((visibility("default"))) void ep_dispatch_combine_orchestration(const ChipStorageTaskArgs &orch_args) { - Tensor indices = from_tensor_arg(orch_args.tensor(0)); - Tensor x_norm = from_tensor_arg(orch_args.tensor(1)); - Tensor w_padded = from_tensor_arg(orch_args.tensor(2)); - Tensor idx_padded = from_tensor_arg(orch_args.tensor(3)); - Tensor recv_x_out = from_tensor_arg(orch_args.tensor(4)); - Tensor recv_w_out = from_tensor_arg(orch_args.tensor(5)); - Tensor recv_idx_out = from_tensor_arg(orch_args.tensor(6)); - Tensor 
recv_count_out = from_tensor_arg(orch_args.tensor(7)); - Tensor recv_y = from_tensor_arg(orch_args.tensor(8)); - Tensor routed_y = from_tensor_arg(orch_args.tensor(9)); - Tensor scratch = from_tensor_arg(orch_args.tensor(10)); - - // child 0: dispatch + // ---- bind external tensors (kept under the orch's view) ---- + Tensor ext_x_hc = from_tensor_arg(orch_args.tensor(0)); + Tensor ext_hc_ffn_fn = from_tensor_arg(orch_args.tensor(1)); + Tensor ext_hc_ffn_scale = from_tensor_arg(orch_args.tensor(2)); + Tensor ext_hc_ffn_base = from_tensor_arg(orch_args.tensor(3)); + Tensor ext_norm_w = from_tensor_arg(orch_args.tensor(4)); + Tensor ext_gate_w = from_tensor_arg(orch_args.tensor(5)); + Tensor ext_gate_bias = from_tensor_arg(orch_args.tensor(6)); + Tensor ext_tid2eid = from_tensor_arg(orch_args.tensor(7)); + Tensor ext_input_ids = from_tensor_arg(orch_args.tensor(8)); + Tensor w_padded = from_tensor_arg(orch_args.tensor(9)); + Tensor idx_padded = from_tensor_arg(orch_args.tensor(10)); + // tensor(11) recv_count_host is read directly via .data_as inside the moe_expert loops. + Tensor x_norm = from_tensor_arg(orch_args.tensor(24)); + Tensor indices = from_tensor_arg(orch_args.tensor(25)); + Tensor recv_x_out = from_tensor_arg(orch_args.tensor(29)); + Tensor recv_w_out = from_tensor_arg(orch_args.tensor(30)); + Tensor recv_idx_out = from_tensor_arg(orch_args.tensor(31)); + Tensor recv_count_out = from_tensor_arg(orch_args.tensor(32)); + Tensor recv_y = from_tensor_arg(orch_args.tensor(33)); + Tensor sh = from_tensor_arg(orch_args.tensor(34)); + Tensor routed_y = from_tensor_arg(orch_args.tensor(35)); + Tensor ffn_out = from_tensor_arg(orch_args.tensor(36)); + Tensor y = from_tensor_arg(orch_args.tensor(37)); + Tensor scratch = from_tensor_arg(orch_args.tensor(38)); + + // Router output names (PyPTO body refs them as ext_*). + Tensor ext_x_norm = x_norm; + Tensor ext_indices = indices; + Tensor ext_weights = from_tensor_arg(orch_args.tensor(26)); + Tensor ext_post_ffn = from_tensor_arg(orch_args.tensor(27)); + Tensor ext_comb_ffn = from_tensor_arg(orch_args.tensor(28)); + + // moe_expert external tensors (names preserved from the generated orch). 
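+    // These ext_* aliases are plain Tensor copies (same buffer, offset, and
+    // shape as the tensors bound above), so unlike the reshape views called
+    // out in the write_route_outputs NOTE further down they presumably keep
+    // the runtime's identity-based dependency tracking intact.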
+ Tensor ext_recv_x = recv_x_out; + Tensor ext_recv_weights = recv_w_out; + Tensor ext_x_local = x_norm; + Tensor ext_expert_w1 = from_tensor_arg(orch_args.tensor(12)); + Tensor ext_expert_w1_scale = from_tensor_arg(orch_args.tensor(13)); + Tensor ext_expert_w3 = from_tensor_arg(orch_args.tensor(14)); + Tensor ext_expert_w3_scale = from_tensor_arg(orch_args.tensor(15)); + Tensor ext_expert_w2 = from_tensor_arg(orch_args.tensor(16)); + Tensor ext_expert_w2_scale = from_tensor_arg(orch_args.tensor(17)); + Tensor ext_shared_w1 = from_tensor_arg(orch_args.tensor(18)); + Tensor ext_shared_w1_scale = from_tensor_arg(orch_args.tensor(19)); + Tensor ext_shared_w3 = from_tensor_arg(orch_args.tensor(20)); + Tensor ext_shared_w3_scale = from_tensor_arg(orch_args.tensor(21)); + Tensor ext_shared_w2 = from_tensor_arg(orch_args.tensor(22)); + Tensor ext_shared_w2_scale = from_tensor_arg(orch_args.tensor(23)); + Tensor ext_recv_y = recv_y; + Tensor ext_sh = sh; + + // ---- func_id 0..17: moe_router (transplanted PyPTO-generated body) ---- + PTO2_SCOPE() { + uint32_t x_mixed_ci_shapes[3] = {16, 1, 4096}; + TensorCreateInfo x_mixed_ci(x_mixed_ci_shapes, 3, DataType::BFLOAT16); + uint32_t x_flat_fp32_inline49_ci_shapes[2] = {16, 16384}; + TensorCreateInfo x_flat_fp32_inline49_ci(x_flat_fp32_inline49_ci_shapes, 2, DataType::FLOAT32); + uint32_t inv_rms_inline75_ci_shapes[2] = {1, 16}; + TensorCreateInfo inv_rms_inline75_ci(inv_rms_inline75_ci_shapes, 2, DataType::FLOAT32); + uint32_t mixes_inline34_ci_shapes[2] = {16, 32}; + TensorCreateInfo mixes_inline34_ci(mixes_inline34_ci_shapes, 2, DataType::FLOAT32); + uint32_t comb_logits_inline19_ci_shapes[2] = {16, 16}; + TensorCreateInfo comb_logits_inline19_ci(comb_logits_inline19_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret0__out_ci_shapes[2] = {16, 8}; + TensorCreateInfo ret0__out_ci(ret0__out_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret0__out_1_ci_shapes[2] = {16, 8}; + TensorCreateInfo ret0__out_1_ci(ret0__out_1_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret0__out_2_ci_shapes[2] = {16, 8}; + TensorCreateInfo ret0__out_2_ci(ret0__out_2_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret1__out_ci_shapes[2] = {16, 16}; + TensorCreateInfo ret1__out_ci(ret1__out_ci_shapes, 2, DataType::FLOAT32); + uint32_t inv_rms_inline126_ci_shapes[2] = {1, 16}; + TensorCreateInfo inv_rms_inline126_ci(inv_rms_inline126_ci_shapes, 2, DataType::FLOAT32); + uint32_t x_norm_bf16_inline111_ci_shapes[2] = {16, 4096}; + TensorCreateInfo x_norm_bf16_inline111_ci(x_norm_bf16_inline111_ci_shapes, 2, DataType::BFLOAT16); + uint32_t biased_scores_inline131_ci_shapes[2] = {16, 32}; + TensorCreateInfo biased_scores_inline131_ci(biased_scores_inline131_ci_shapes, 2, DataType::FLOAT32); + uint32_t score_acc_buf_inline117_ci_shapes[2] = {1, 16}; + TensorCreateInfo score_acc_buf_inline117_ci(score_acc_buf_inline117_ci_shapes, 2, DataType::FLOAT32); + uint32_t sorted_rows_inline97_ci_shapes[2] = {16, 64}; + TensorCreateInfo sorted_rows_inline97_ci(sorted_rows_inline97_ci_shapes, 2, DataType::FLOAT32); + uint32_t topk_vals_pad_inline110_ci_shapes[2] = {16, 32}; + TensorCreateInfo topk_vals_pad_inline110_ci(topk_vals_pad_inline110_ci_shapes, 2, DataType::FLOAT32); + uint32_t topk_idx_pad_inline104_ci_shapes[2] = {16, 32}; + TensorCreateInfo topk_idx_pad_inline104_ci(topk_idx_pad_inline104_ci_shapes, 2, DataType::INT32); + TaskOutputTensors alloc_0 = alloc_tensors(x_mixed_ci, x_flat_fp32_inline49_ci, inv_rms_inline75_ci, mixes_inline34_ci, comb_logits_inline19_ci, ret0__out_ci, ret0__out_1_ci, 
ret0__out_2_ci, ret1__out_ci, inv_rms_inline126_ci, x_norm_bf16_inline111_ci, biased_scores_inline131_ci, score_acc_buf_inline117_ci, sorted_rows_inline97_ci, topk_vals_pad_inline110_ci, topk_idx_pad_inline104_ci); + const Tensor& x_mixed = alloc_0.get_ref(0); + const Tensor& x_flat_fp32_inline49 = alloc_0.get_ref(1); + const Tensor& inv_rms_inline75 = alloc_0.get_ref(2); + const Tensor& mixes_inline34 = alloc_0.get_ref(3); + const Tensor& comb_logits_inline19 = alloc_0.get_ref(4); + const Tensor& ret0__out = alloc_0.get_ref(5); + const Tensor& ret0__out_1 = alloc_0.get_ref(6); + const Tensor& ret0__out_2 = alloc_0.get_ref(7); + const Tensor& ret1__out = alloc_0.get_ref(8); + const Tensor& inv_rms_inline126 = alloc_0.get_ref(9); + const Tensor& x_norm_bf16_inline111 = alloc_0.get_ref(10); + const Tensor& biased_scores_inline131 = alloc_0.get_ref(11); + const Tensor& score_acc_buf_inline117 = alloc_0.get_ref(12); + const Tensor& sorted_rows_inline97 = alloc_0.get_ref(13); + const Tensor& topk_vals_pad_inline110 = alloc_0.get_ref(14); + const Tensor& topk_idx_pad_inline104 = alloc_0.get_ref(15); + uint32_t weight_out_pad_inline108_ci_shapes[2] = {16, 32}; + TensorCreateInfo weight_out_pad_inline108_ci(weight_out_pad_inline108_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_1 = alloc_tensors(weight_out_pad_inline108_ci); + const Tensor& weight_out_pad_inline108 = alloc_1.get_ref(0); + uint32_t x_flat_inline36_shapes[2] = {16, 16384}; + Tensor x_flat_inline36 = ext_x_hc.reshape(x_flat_inline36_shapes, 2); + uint32_t post_flat_inline51_shapes[1] = {64}; + Tensor post_flat_inline51 = ext_post_ffn.reshape(post_flat_inline51_shapes, 1); + uint32_t comb_flat_inline37_shapes[1] = {256}; + Tensor comb_flat_inline37 = ext_comb_ffn.reshape(comb_flat_inline37_shapes, 1); + for (int64_t kb_inline48 = 0; kb_inline48 < 32; kb_inline48 += 1) { + PTO2_SCOPE() { + int64_t k0_inline59 = (kb_inline48 * 512); + + // Task 0: cast_x + Arg params_t0; + params_t0.add_input(x_flat_inline36); + params_t0.add_output(x_flat_fp32_inline49); + params_t0.add_scalar(k0_inline59); + rt_submit_aiv_task(0, params_t0); + const Tensor& x_flat_fp32_inline49__ssa_v3 = x_flat_fp32_inline49; + } + } + uint32_t ret0__out_3_shapes[2] = {1, 16}; + uint32_t ret0__out_3_offsets[2] = {0, 0}; + Tensor ret0__out_3 = inv_rms_inline75.view(ret0__out_3_shapes, ret0__out_3_offsets); + + // Task 1: rms + Arg params_t1; + params_t1.add_input(x_flat_fp32_inline49); + params_t1.add_output(ret0__out_3); + rt_submit_aiv_task(1, params_t1); + uint32_t mixes_flat_inline45_shapes[1] = {512}; + Tensor mixes_flat_inline45 = mixes_inline34.reshape(mixes_flat_inline45_shapes, 1); + + // Task 2: linear + Arg params_t2; + params_t2.add_input(x_flat_fp32_inline49); + params_t2.add_input(ext_hc_ffn_fn); + params_t2.add_input(inv_rms_inline75); + params_t2.add_output(mixes_flat_inline45); + rt_submit_aiv_task(2, params_t2); + uint32_t mixes_v1_inline58_shapes[2] = {16, 32}; + Tensor mixes_v1_inline58 = mixes_flat_inline45.reshape(mixes_v1_inline58_shapes, 2); + float scale0_inline21 = static_cast(orch_args.tensor(2).data_as())[0]; + float scale1_inline17 = static_cast(orch_args.tensor(2).data_as())[1]; + float scale2_inline26 = static_cast(orch_args.tensor(2).data_as())[2]; + + // Task 3: split_pre_post + Arg params_t3; + params_t3.add_output(ret0__out); + rt_submit_aiv_task(3, params_t3); + const Tensor& ones_hc_inline79 = ret0__out; + uint32_t t_shapes[1] = {8}; + uint32_t t_offsets[1] = {0}; + Tensor t = ext_hc_ffn_base.view(t_shapes, t_offsets); 
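+    // hc_ffn_base is laid out [pre 0:4 | post 4:8 | comb 8:24] (MIX_HC = 24).
+    // Each slice below is taken as a view at the slice start with the minimum
+    // 8- or 16-element tile width, so the pre and post reads deliberately
+    // over-read past their 4 live entries.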
+ uint32_t pre_base_inline76_shapes[2] = {1, 8}; + Tensor pre_base_inline76 = t.reshape(pre_base_inline76_shapes, 2); + uint32_t t__tmp_v26_shapes[2] = {16, 8}; + uint32_t t__tmp_v26_offsets[2] = {0, 0}; + Tensor t__tmp_v26 = mixes_v1_inline58.view(t__tmp_v26_shapes, t__tmp_v26_offsets); + + // Task 4: split_pre_post_0 + Arg params_t4; + params_t4.add_input(t__tmp_v26); + params_t4.add_input(ones_hc_inline79); + params_t4.add_input(pre_base_inline76); + params_t4.add_output(ret0__out_1); + params_t4.add_scalar(to_u64(scale0_inline21)); + rt_submit_aiv_task(4, params_t4); + const Tensor& pre_val_inline62 = ret0__out_1; + uint32_t t__tmp_v33_shapes[1] = {8}; + uint32_t t__tmp_v33_offsets[1] = {4}; + Tensor t__tmp_v33 = ext_hc_ffn_base.view(t__tmp_v33_shapes, t__tmp_v33_offsets); + uint32_t post_base_inline50_shapes[2] = {1, 8}; + Tensor post_base_inline50 = t__tmp_v33.reshape(post_base_inline50_shapes, 2); + uint32_t t__tmp_v34_shapes[2] = {16, 8}; + uint32_t t__tmp_v34_offsets[2] = {0, 4}; + Tensor t__tmp_v34 = mixes_v1_inline58.view(t__tmp_v34_shapes, t__tmp_v34_offsets); + + // Task 5: split_pre_post_1 + Arg params_t5; + params_t5.add_input(t__tmp_v34); + params_t5.add_input(ones_hc_inline79); + params_t5.add_input(post_base_inline50); + params_t5.add_output(ret0__out_2); + params_t5.add_output(ret1__out); + params_t5.add_scalar(to_u64(scale1_inline17)); + rt_submit_aiv_task(5, params_t5); + const Tensor& post_pad_inline66 = ret0__out_2; + const Tensor& ones_comb_inline67 = ret1__out; + uint32_t t__tmp_v41_shapes[1] = {16}; + uint32_t t__tmp_v41_offsets[1] = {8}; + Tensor t__tmp_v41 = ext_hc_ffn_base.view(t__tmp_v41_shapes, t__tmp_v41_offsets); + uint32_t comb_base_inline70_shapes[2] = {1, 16}; + Tensor comb_base_inline70 = t__tmp_v41.reshape(comb_base_inline70_shapes, 2); + uint32_t comb_mix_inline41_shapes[2] = {16, 16}; + uint32_t comb_mix_inline41_offsets[2] = {0, 8}; + Tensor comb_mix_inline41 = mixes_v1_inline58.view(comb_mix_inline41_shapes, comb_mix_inline41_offsets); + uint32_t ret0__out_4_shapes[2] = {16, 16}; + uint32_t ret0__out_4_offsets[2] = {0, 0}; + Tensor ret0__out_4 = comb_logits_inline19.view(ret0__out_4_shapes, ret0__out_4_offsets); + + // Task 6: split_pre_post_2 + Arg params_t6; + params_t6.add_input(comb_mix_inline41); + params_t6.add_input(ones_comb_inline67); + params_t6.add_input(comb_base_inline70); + params_t6.add_output(ret0__out_4); + params_t6.add_scalar(to_u64(scale2_inline26)); + rt_submit_aiv_task(6, params_t6); + uint32_t post_pad_flat_inline73_shapes[1] = {128}; + Tensor post_pad_flat_inline73 = post_pad_inline66.reshape(post_pad_flat_inline73_shapes, 1); + + // Task 7: comb_sinkhorn + Arg params_t7; + params_t7.add_input(comb_logits_inline19); + params_t7.add_output(comb_flat_inline37); + rt_submit_aiv_task(7, params_t7); + for (int64_t t_inline7 = 0; t_inline7 < 1; t_inline7 += 1) { + PTO2_SCOPE() { + + // Task 8: write_post + Arg params_t8; + params_t8.add_input(post_pad_flat_inline73); + params_t8.add_output(post_flat_inline51); + params_t8.add_scalar((uint64_t)0); + rt_submit_aiv_task(8, params_t8); + } + } + uint32_t pre_val_flat_inline44_shapes[1] = {128}; + Tensor pre_val_flat_inline44 = pre_val_inline62.reshape(pre_val_flat_inline44_shapes, 1); + uint32_t x_mixed_view_inline55_shapes[2] = {16, 4096}; + Tensor x_mixed_view_inline55 = x_mixed.reshape(x_mixed_view_inline55_shapes, 2); + for (int64_t t_inline5 = 0; t_inline5 < 1; t_inline5 += 1) { + PTO2_SCOPE() { + + // Task 9: mix_x + Arg params_t9; + params_t9.add_output(x_mixed_view_inline55); 
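+          // mix_x folds the pre gates into the hidden state; per the host
+          // golden _golden_hc_pre this computes
+          // x_mixed[t, :] = sum_h pre[t, h] * x_hc[t, h, :] (pre is 8-padded;
+          // only the first HC_MULT = 4 lanes are live).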
+ params_t9.add_input(pre_val_flat_inline44); + params_t9.add_input(x_flat_fp32_inline49); + params_t9.add_scalar((uint64_t)0); + rt_submit_aiv_task(9, params_t9); + const Tensor& x_mixed_view_inline55__co_l1_rv_v1 = x_mixed_view_inline55; + } + } + uint32_t x_mixed_v1_inline0_shapes[3] = {16, 1, 4096}; + Tensor x_mixed_v1_inline0 = x_mixed_view_inline55.reshape(x_mixed_v1_inline0_shapes, 3); + uint32_t x_mixed_flat_inline127_shapes[2] = {16, 4096}; + Tensor x_mixed_flat_inline127 = x_mixed.reshape(x_mixed_flat_inline127_shapes, 2); + uint32_t ret0__out_5_shapes[2] = {1, 16}; + uint32_t ret0__out_5_offsets[2] = {0, 0}; + Tensor ret0__out_5 = inv_rms_inline126.view(ret0__out_5_shapes, ret0__out_5_offsets); + + // Task 10: ffn_norm_rms + Arg params_t10; + params_t10.add_input(x_mixed_flat_inline127); + params_t10.add_output(ret0__out_5); + rt_submit_aiv_task(10, params_t10); + for (int64_t db_inline125 = 0; db_inline125 < 8; db_inline125 += 1) { + PTO2_SCOPE() { + + // Task 11: ffn_norm_apply + Arg params_t11; + params_t11.add_input(inv_rms_inline126); + params_t11.add_input(x_mixed_flat_inline127); + params_t11.add_input(ext_norm_w); + params_t11.add_output(x_norm_bf16_inline111); + params_t11.add_output(ext_x_norm); + params_t11.add_scalar(db_inline125); + rt_submit_aiv_task(11, params_t11); + const Tensor& x_norm_bf16_inline111__ssa_v3 = x_norm_bf16_inline111; + const Tensor& x_norm__ssa_v3 = ext_x_norm; + } + } + uint32_t biased_flat_inline133_shapes[1] = {512}; + Tensor biased_flat_inline133 = biased_scores_inline131.reshape(biased_flat_inline133_shapes, 1); + uint32_t ret0__out_6_shapes[2] = {16, 32}; + uint32_t ret0__out_6_offsets[2] = {0, 0}; + Tensor ret0__out_6 = biased_scores_inline131.view(ret0__out_6_shapes, ret0__out_6_offsets); + + // Task 12: gate_dot + Arg params_t12; + params_t12.add_output(ret0__out_6); + rt_submit_aiv_task(12, params_t12); + + // Task 13: gate_dot_0 + Arg params_t13; + params_t13.add_inout(score_acc_buf_inline117); + params_t13.add_input(x_norm_bf16_inline111); + params_t13.add_input(ext_gate_w); + params_t13.add_input(ext_gate_bias); + params_t13.add_output(biased_flat_inline133); + rt_submit_aiv_task(13, params_t13); + uint32_t biased_scores_v1_inline136_shapes[2] = {16, 32}; + Tensor biased_scores_v1_inline136 = biased_flat_inline133.reshape(biased_scores_v1_inline136_shapes, 2); + + // Task 14: route_sort_top2 + Arg params_t14; + params_t14.add_input(biased_scores_v1_inline136); + params_t14.add_inout(sorted_rows_inline97); + rt_submit_aiv_task(14, params_t14); + const Tensor& sorted_rows_inline97__ssa_v16 = sorted_rows_inline97; + + // Task 15: route_extract_top2 + Arg params_t15; + params_t15.add_inout(topk_vals_pad_inline110); + params_t15.add_input(sorted_rows_inline97__ssa_v16); + params_t15.add_inout(topk_idx_pad_inline104); + rt_submit_aiv_task(15, params_t15); + const Tensor& topk_vals_pad_inline110__ssa_v17 = topk_vals_pad_inline110; + const Tensor& topk_idx_pad_inline104__ssa_v16 = topk_idx_pad_inline104; + + // Task 16: route_normalize_weights + Arg params_t16; + params_t16.add_input(topk_vals_pad_inline110__ssa_v17); + params_t16.add_inout(weight_out_pad_inline108); + rt_submit_aiv_task(16, params_t16); + const Tensor& weight_out_pad_inline108__ssa_v1 = weight_out_pad_inline108; + uint32_t indices_flat_inline89_shapes[1] = {32}; + Tensor indices_flat_inline89 = ext_indices.reshape(indices_flat_inline89_shapes, 1); + uint32_t weights_flat_inline88_shapes[1] = {32}; + Tensor weights_flat_inline88 = 
ext_weights.reshape(weights_flat_inline88_shapes, 1); + uint32_t topk_idx_flat_inline87_shapes[1] = {512}; + Tensor topk_idx_flat_inline87 = topk_idx_pad_inline104__ssa_v16.reshape(topk_idx_flat_inline87_shapes, 1); + uint32_t weight_out_flat_inline86_shapes[1] = {512}; + Tensor weight_out_flat_inline86 = weight_out_pad_inline108__ssa_v1.reshape(weight_out_flat_inline86_shapes, 1); + + // Task 17: write_route_outputs. + // NOTE: declare the writes against the *same* C++ Tensor variables + // that dispatch reads (``indices`` / ``ext_weights``'s storage), + // not their flat reshape views. The L3 runtime tracks deps by + // Tensor-object identity, so writes to the reshape view don't + // establish a happens-before edge from this task to dispatch's + // later ``add_input(indices)`` — without this dispatch races with + // the router and reads stale (mostly-zero) indices on a fraction + // of runs. Buffer addr / offset are unchanged so the generated + // kernel still indexes the same memory. + Arg params_t17; + params_t17.add_input(topk_idx_flat_inline87); + params_t17.add_output(indices); + params_t17.add_input(weight_out_flat_inline86); + params_t17.add_output(ext_weights); + rt_submit_aiv_task(17, params_t17); + } + + // ---- func_id 18: dispatch ---- { Arg p; p.add_input(indices); @@ -85,30 +455,366 @@ __attribute__((visibility("default"))) void ep_dispatch_combine_orchestration(co p.add_inout(scratch); p.add_scalar(orch_args.scalar(0)); // nranks p.add_scalar(orch_args.scalar(1)); // CommContext - rt_submit_aiv_task(0, p); + rt_submit_aiv_task(18, p); } - // child 1: local_expert (pure local, host-backed I/O only — no scratch) - { - Arg p; - p.add_input(recv_x_out); - p.add_input(recv_w_out); - p.add_input(recv_count_out); - p.add_output(recv_y); - p.add_scalar(orch_args.scalar(1)); // CommContext (only for ABI symmetry) - rt_submit_aiv_task(1, p); - } + // ---- func_id 19..35: moe_expert (transplanted PyPTO-generated body) ---- + PTO2_SCOPE() { + uint32_t x_local_i8_inline43_ci_shapes[2] = {16, 4096}; + TensorCreateInfo x_local_i8_inline43_ci(x_local_i8_inline43_ci_shapes, 2, DataType::INT8); + uint32_t ret0__out_ci_shapes[2] = {16, 1}; + TensorCreateInfo ret0__out_ci(ret0__out_ci_shapes, 2, DataType::FLOAT32); + uint32_t sh_tile_fp32_inline88_ci_shapes[2] = {16, 4096}; + TensorCreateInfo sh_tile_fp32_inline88_ci(sh_tile_fp32_inline88_ci_shapes, 2, DataType::FLOAT32); + uint32_t sh_tile_i8_inline126_ci_shapes[2] = {16, 4096}; + TensorCreateInfo sh_tile_i8_inline126_ci(sh_tile_i8_inline126_ci_shapes, 2, DataType::INT8); + uint32_t ret0__out_1_ci_shapes[2] = {16, 1}; + TensorCreateInfo ret0__out_1_ci(ret0__out_1_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_0 = alloc_tensors(x_local_i8_inline43_ci, ret0__out_ci, sh_tile_fp32_inline88_ci, sh_tile_i8_inline126_ci, ret0__out_1_ci); + const Tensor& x_local_i8_inline43 = alloc_0.get_ref(0); + const Tensor& ret0__out = alloc_0.get_ref(1); + const Tensor& sh_tile_fp32_inline88 = alloc_0.get_ref(2); + const Tensor& sh_tile_i8_inline126 = alloc_0.get_ref(3); + const Tensor& ret0__out_1 = alloc_0.get_ref(4); + uint32_t recv_y_flat_inline53_shapes[2] = {256, 4096}; + Tensor recv_y_flat_inline53 = ext_recv_y.reshape(recv_y_flat_inline53_shapes, 2); + uint32_t recv_weights_flat_inline62_shapes[2] = {256, 1}; + Tensor recv_weights_flat_inline62 = ext_recv_weights.reshape(recv_weights_flat_inline62_shapes, 2); - // child 2: combine (push to routed_y_buf in scratch, barrier, reduce_sum) + // Task 19: x_local_q + Arg params_t0; + 
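+    // x_local_q: per-row abs-max INT8 quantization of the shared-expert input
+    // (the router-produced x_norm); emits the INT8 tile plus a [16, 1] FP32
+    // dequant scale later consumed by sh_gate_up_dequant.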
params_t0.add_input(ext_x_local); + params_t0.add_output(x_local_i8_inline43); + params_t0.add_output(ret0__out); + rt_submit_aiv_task(19, params_t0); + const Tensor& x_local_i8_inline43__rv_v2 = x_local_i8_inline43; + const Tensor& x_local_scale_dq_inline32 = ret0__out; + for (int64_t local_i_inline67 = 0; local_i_inline67 < 8; local_i_inline67 += 1) { + PTO2_SCOPE() { + size_t idx_n_rows_inline68 = local_i_inline67 * 1 + 0; + int32_t n_rows_inline68 = static_cast(orch_args.tensor(11).data_as())[idx_n_rows_inline68]; + int64_t n_tiles_inline91 = ((static_cast(n_rows_inline68) + 15) / 16); + int64_t flat_base_inline30 = (local_i_inline67 * 32); + for (int64_t t_inline108 = 0; t_inline108 < n_tiles_inline91; t_inline108 += 1) { + PTO2_SCOPE() { + uint32_t recv_x_tile_i8_inline75_ci_shapes[2] = {16, 4096}; + TensorCreateInfo recv_x_tile_i8_inline75_ci(recv_x_tile_i8_inline75_ci_shapes, 2, DataType::INT8); + uint32_t ret0__out_2_ci_shapes[2] = {16, 1}; + TensorCreateInfo ret0__out_2_ci(ret0__out_2_ci_shapes, 2, DataType::FLOAT32); + uint32_t h_tile_fp32_inline18_ci_shapes[2] = {16, 4096}; + TensorCreateInfo h_tile_fp32_inline18_ci(h_tile_fp32_inline18_ci_shapes, 2, DataType::FLOAT32); + uint32_t h_tile_i8_inline92_ci_shapes[2] = {16, 4096}; + TensorCreateInfo h_tile_i8_inline92_ci(h_tile_i8_inline92_ci_shapes, 2, DataType::INT8); + uint32_t ret0__out_3_ci_shapes[2] = {16, 1}; + TensorCreateInfo ret0__out_3_ci(ret0__out_3_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_1 = alloc_tensors(recv_x_tile_i8_inline75_ci, ret0__out_2_ci, h_tile_fp32_inline18_ci, h_tile_i8_inline92_ci, ret0__out_3_ci); + const Tensor& recv_x_tile_i8_inline75 = alloc_1.get_ref(0); + const Tensor& ret0__out_2 = alloc_1.get_ref(1); + const Tensor& h_tile_fp32_inline18 = alloc_1.get_ref(2); + const Tensor& h_tile_i8_inline92 = alloc_1.get_ref(3); + const Tensor& ret0__out_3 = alloc_1.get_ref(4); + int64_t t0_inline47 = (t_inline108 * 16); + int64_t flat_t0_inline40 = (flat_base_inline30 + t0_inline47); + int64_t valid_rows_inline73 = std::min((static_cast(n_rows_inline68) - t0_inline47), 16); + + // Task 20: recv_x_q + Arg params_t1; + params_t1.add_input(ext_recv_x); + params_t1.add_output(recv_x_tile_i8_inline75); + params_t1.add_output(ret0__out_2); + params_t1.add_scalar(local_i_inline67); + params_t1.add_scalar(t0_inline47); + rt_submit_aiv_task(20, params_t1); + const Tensor& recv_x_tile_i8_inline75__rv_v2 = recv_x_tile_i8_inline75; + const Tensor& recv_x_scale_dq_inline29 = ret0__out_2; + for (int64_t n0_inline72 = 0; n0_inline72 < 4096; n0_inline72 += 256) { + PTO2_SCOPE() { + uint32_t ret0__out_4_ci_shapes[3] = {1, 16, 256}; + TensorCreateInfo ret0__out_4_ci(ret0__out_4_ci_shapes, 3, DataType::INT32); + uint32_t ret1__out_ci_shapes[3] = {1, 16, 256}; + TensorCreateInfo ret1__out_ci(ret1__out_ci_shapes, 3, DataType::INT32); + uint32_t ret0__out_5_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret0__out_5_ci(ret0__out_5_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret1__out_1_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret1__out_1_ci(ret1__out_1_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret0__out_6_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret0__out_6_ci(ret0__out_6_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_2 = alloc_tensors(ret0__out_4_ci, ret1__out_ci, ret0__out_5_ci, ret1__out_1_ci, ret0__out_6_ci); + const Tensor& ret0__out_4 = alloc_2.get_ref(0); + const Tensor& ret1__out = alloc_2.get_ref(1); + const Tensor& ret0__out_5 = alloc_2.get_ref(2); + const Tensor& ret1__out_1 = 
alloc_2.get_ref(3); + const Tensor& ret0__out_6 = alloc_2.get_ref(4); + + // Task 21: exp_gate_up_matmul + Arg params_t2; + params_t2.add_input(recv_x_tile_i8_inline75__rv_v2); + params_t2.add_input(ext_expert_w1); + params_t2.add_input(ext_expert_w3); + params_t2.add_output(ret0__out_4); + params_t2.add_output(ret1__out); + params_t2.add_scalar(local_i_inline67); + params_t2.add_scalar(n0_inline72); + rt_submit_aic_task(21, params_t2); + const Tensor& gate_acc_inline69 = ret0__out_4; + const Tensor& up_acc_inline46 = ret1__out; + + // Task 22: exp_gate_up_dequant + Arg params_t3; + params_t3.add_input(gate_acc_inline69); + params_t3.add_input(up_acc_inline46); + params_t3.add_input(ext_expert_w1_scale); + params_t3.add_input(ext_expert_w3_scale); + params_t3.add_input(recv_x_scale_dq_inline29); + params_t3.add_output(ret0__out_5); + params_t3.add_output(ret1__out_1); + params_t3.add_scalar(local_i_inline67); + params_t3.add_scalar(n0_inline72); + rt_submit_aiv_task(22, params_t3); + const Tensor& gate_2d_v1_inline7 = ret0__out_5; + const Tensor& up_2d_v1_inline78 = ret1__out_1; + + // Task 23: exp_swiglu + Arg params_t4; + params_t4.add_input(gate_2d_v1_inline7); + params_t4.add_input(up_2d_v1_inline78); + params_t4.add_input(recv_weights_flat_inline62); + params_t4.add_output(ret0__out_6); + params_t4.add_scalar(flat_t0_inline40); + rt_submit_aiv_task(23, params_t4); + const Tensor& h_chunk_inline89 = ret0__out_6; + + // Task 24: exp_swiglu_mask + Arg params_t5; + params_t5.add_input(h_chunk_inline89); + params_t5.add_output(h_tile_fp32_inline18); + params_t5.add_scalar(valid_rows_inline73); + params_t5.add_scalar(n0_inline72); + rt_submit_aiv_task(24, params_t5); + const Tensor& h_tile_fp32_inline18__ssa_v3 = h_tile_fp32_inline18; + } + } + + // Task 25: exp_h_q + Arg params_t6; + params_t6.add_input(h_tile_fp32_inline18); + params_t6.add_output(h_tile_i8_inline92); + params_t6.add_output(ret0__out_3); + rt_submit_aiv_task(25, params_t6); + const Tensor& h_tile_i8_inline92__rv_v2 = h_tile_i8_inline92; + const Tensor& h_tile_scale_dq_inline63 = ret0__out_3; + for (int64_t d0_inline49 = 0; d0_inline49 < 4096; d0_inline49 += 512) { + PTO2_SCOPE() { + uint32_t ret0__out_7_ci_shapes[3] = {1, 16, 512}; + TensorCreateInfo ret0__out_7_ci(ret0__out_7_ci_shapes, 3, DataType::INT32); + uint32_t ret0__out_8_ci_shapes[2] = {16, 512}; + TensorCreateInfo ret0__out_8_ci(ret0__out_8_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_3 = alloc_tensors(ret0__out_7_ci, ret0__out_8_ci); + const Tensor& ret0__out_7 = alloc_3.get_ref(0); + const Tensor& ret0__out_8 = alloc_3.get_ref(1); + + // Task 26: exp_w2_matmul + Arg params_t7; + params_t7.add_input(h_tile_i8_inline92__rv_v2); + params_t7.add_input(ext_expert_w2); + params_t7.add_output(ret0__out_7); + params_t7.add_scalar(local_i_inline67); + params_t7.add_scalar(d0_inline49); + rt_submit_aic_task(26, params_t7); + const Tensor& y_acc_inline109 = ret0__out_7; + + // Task 27: exp_w2_dequant + Arg params_t8; + params_t8.add_input(y_acc_inline109); + params_t8.add_input(ext_expert_w2_scale); + params_t8.add_input(h_tile_scale_dq_inline63); + params_t8.add_output(ret0__out_8); + params_t8.add_scalar(local_i_inline67); + params_t8.add_scalar(d0_inline49); + rt_submit_aiv_task(27, params_t8); + const Tensor& y_2d_v1_inline17 = ret0__out_8; + + // Task 28: exp_recv_y_write + Arg params_t9; + params_t9.add_input(y_2d_v1_inline17); + params_t9.add_output(recv_y_flat_inline53); + params_t9.add_scalar(flat_t0_inline40); + 
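+                // Two placement scalars: the flat destination row
+                // (flat_base + t0) and the 512-wide column offset d0 into
+                // recv_y_flat's D axis.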
params_t9.add_scalar(d0_inline49); + rt_submit_aiv_task(28, params_t9); + const Tensor& recv_y_flat_inline53__ssa_v7 = recv_y_flat_inline53; + } + } + } + } + } + } + for (int64_t n0_inline113 = 0; n0_inline113 < 4096; n0_inline113 += 256) { + PTO2_SCOPE() { + uint32_t ret0__out_9_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret0__out_9_ci(ret0__out_9_ci_shapes, 2, DataType::INT32); + uint32_t ret1__out_2_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret1__out_2_ci(ret1__out_2_ci_shapes, 2, DataType::INT32); + uint32_t ret0__out_10_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret0__out_10_ci(ret0__out_10_ci_shapes, 2, DataType::FLOAT32); + uint32_t ret1__out_3_ci_shapes[2] = {16, 256}; + TensorCreateInfo ret1__out_3_ci(ret1__out_3_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_4 = alloc_tensors(ret0__out_9_ci, ret1__out_2_ci, ret0__out_10_ci, ret1__out_3_ci); + const Tensor& ret0__out_9 = alloc_4.get_ref(0); + const Tensor& ret1__out_2 = alloc_4.get_ref(1); + const Tensor& ret0__out_10 = alloc_4.get_ref(2); + const Tensor& ret1__out_3 = alloc_4.get_ref(3); + + // Task 29: sh_gate_up_matmul + Arg params_t10; + params_t10.add_input(x_local_i8_inline43__rv_v2); + params_t10.add_input(ext_shared_w1); + params_t10.add_input(ext_shared_w3); + params_t10.add_output(ret0__out_9); + params_t10.add_output(ret1__out_2); + params_t10.add_scalar(n0_inline113); + rt_submit_aic_task(29, params_t10); + const Tensor& sh_gate_acc_inline95 = ret0__out_9; + const Tensor& sh_up_acc_inline118 = ret1__out_2; + + // Task 30: sh_gate_up_dequant + Arg params_t11; + params_t11.add_input(ext_shared_w1_scale); + params_t11.add_input(ext_shared_w3_scale); + params_t11.add_input(sh_gate_acc_inline95); + params_t11.add_input(sh_up_acc_inline118); + params_t11.add_input(x_local_scale_dq_inline32); + params_t11.add_output(ret0__out_10); + params_t11.add_output(ret1__out_3); + params_t11.add_scalar(n0_inline113); + rt_submit_aiv_task(30, params_t11); + const Tensor& sh_gate_v1_inline50 = ret0__out_10; + const Tensor& sh_up_v1_inline9 = ret1__out_3; + + // Task 31: sh_swiglu + Arg params_t12; + params_t12.add_input(sh_gate_v1_inline50); + params_t12.add_input(sh_up_v1_inline9); + params_t12.add_output(sh_tile_fp32_inline88); + params_t12.add_scalar(n0_inline113); + rt_submit_aiv_task(31, params_t12); + const Tensor& sh_tile_fp32_inline88__ssa_v3 = sh_tile_fp32_inline88; + } + } + + // Task 32: sh_h_q + Arg params_t13; + params_t13.add_input(sh_tile_fp32_inline88); + params_t13.add_output(sh_tile_i8_inline126); + params_t13.add_output(ret0__out_1); + rt_submit_aiv_task(32, params_t13); + const Tensor& sh_tile_i8_inline126__rv_v2 = sh_tile_i8_inline126; + const Tensor& sh_tile_scale_dq_inline99 = ret0__out_1; + for (int64_t d0_inline64 = 0; d0_inline64 < 4096; d0_inline64 += 512) { + PTO2_SCOPE() { + uint32_t ret0__out_11_ci_shapes[2] = {16, 512}; + TensorCreateInfo ret0__out_11_ci(ret0__out_11_ci_shapes, 2, DataType::INT32); + uint32_t ret0__out_12_ci_shapes[2] = {16, 512}; + TensorCreateInfo ret0__out_12_ci(ret0__out_12_ci_shapes, 2, DataType::FLOAT32); + TaskOutputTensors alloc_5 = alloc_tensors(ret0__out_11_ci, ret0__out_12_ci); + const Tensor& ret0__out_11 = alloc_5.get_ref(0); + const Tensor& ret0__out_12 = alloc_5.get_ref(1); + + // Task 33: sh_w2_matmul + Arg params_t14; + params_t14.add_input(sh_tile_i8_inline126__rv_v2); + params_t14.add_input(ext_shared_w2); + params_t14.add_output(ret0__out_11); + params_t14.add_scalar(d0_inline64); + rt_submit_aic_task(33, params_t14); + const Tensor& sh_y_acc_inline4 
= ret0__out_11; + + // Task 34: sh_w2_dequant + Arg params_t15; + params_t15.add_input(ext_shared_w2_scale); + params_t15.add_input(sh_y_acc_inline4); + params_t15.add_input(sh_tile_scale_dq_inline99); + params_t15.add_output(ret0__out_12); + params_t15.add_scalar(d0_inline64); + rt_submit_aiv_task(34, params_t15); + const Tensor& sh_y_v1_inline114 = ret0__out_12; + + // Task 35: sh_write + Arg params_t16; + params_t16.add_input(sh_y_v1_inline114); + params_t16.add_output(ext_sh); + params_t16.add_scalar(d0_inline64); + rt_submit_aiv_task(35, params_t16); + const Tensor& sh = ext_sh; + } + } + } + + // ---- func_id 36: combine ---- { Arg p; p.add_input(recv_y); p.add_input(recv_idx_out); p.add_output(routed_y); p.add_inout(scratch); - p.add_scalar(orch_args.scalar(0)); - p.add_scalar(orch_args.scalar(1)); - rt_submit_aiv_task(2, p); + p.add_scalar(orch_args.scalar(0)); // nranks + p.add_scalar(orch_args.scalar(1)); // CommContext + rt_submit_aiv_task(36, p); + } + + // ---- func_id 37: ffn_add (ffn_out = routed_y + sh) ---- + { + Arg p; + p.add_input(routed_y); + p.add_input(sh); + p.add_output(ffn_out); + rt_submit_aiv_task(37, p); + } + + // ---- func_id 38: hc_post (transplanted PyPTO-generated body) ---- + // Local aliases for the names the generated body refers to (`ext_x` etc.): + // x ← ffn_out (the MoE FFN result) + // residual ← ext_x_hc (the router's input x_hc carries through as + // residual per the single-layer spec line 99) + // post / comb ← post_ffn / comb_ffn from the router + // y ← chip-allocated output tensor (next-layer x_hc) + { + Tensor ext_x = ffn_out; + Tensor ext_residual = ext_x_hc; + Tensor ext_post = ext_post_ffn; + Tensor ext_comb = ext_comb_ffn; + Tensor ext_y = y; + + PTO2_SCOPE() { + uint32_t x_flat_inline2_shapes[2] = {16, 4096}; + Tensor x_flat_inline2 = ext_x.reshape(x_flat_inline2_shapes, 2); + uint32_t residual_flat_inline6_shapes[2] = {16, 16384}; + Tensor residual_flat_inline6 = ext_residual.reshape(residual_flat_inline6_shapes, 2); + uint32_t post_flat_inline8_shapes[1] = {64}; + Tensor post_flat_inline8 = ext_post.reshape(post_flat_inline8_shapes, 1); + uint32_t comb_flat_inline9_shapes[1] = {256}; + Tensor comb_flat_inline9 = ext_comb.reshape(comb_flat_inline9_shapes, 1); + uint32_t y_flat_inline7_shapes[2] = {16, 16384}; + Tensor y_flat_inline7 = ext_y.reshape(y_flat_inline7_shapes, 2); + for (int64_t out_h_inline3 = 0; out_h_inline3 < 4; out_h_inline3 += 1) { + PTO2_SCOPE() { + for (int64_t t_inline10 = 0; t_inline10 < 1; t_inline10 += 1) { + PTO2_SCOPE() { + + // Task 38: hc_post + Arg params_t0; + params_t0.add_output(y_flat_inline7); + params_t0.add_input(post_flat_inline8); + params_t0.add_input(x_flat_inline2); + params_t0.add_input(comb_flat_inline9); + params_t0.add_input(residual_flat_inline6); + params_t0.add_scalar((uint64_t)0); + params_t0.add_scalar(out_h_inline3); + rt_submit_aiv_task(38, params_t0); + } + } + } + } + } } } diff --git a/examples/workers/l3/ep_dispatch_combine/main.py b/examples/workers/l3/ep_dispatch_combine/main.py index 3ebc60f70..0ec10068f 100644 --- a/examples/workers/l3/ep_dispatch_combine/main.py +++ b/examples/workers/l3/ep_dispatch_combine/main.py @@ -7,51 +7,37 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. 
# ----------------------------------------------------------------------------------------------------------- -"""End-to-end 2-card EP dispatch + local_expert + combine demo. - -A single orchestration runs three child AIV kernels back-to-back over a -shared HCCL window scratch: - - dispatch.cpp count exchange + 3-channel push + per-channel stage-out - local_expert.cpp recv_y[e, s, :] = recv_x[e, s, :] * recv_w[e, s] - (placeholder for the production moe_expert) - combine.cpp TPUT recv_y rows by recv_idx_out into - routed_y_buf[t, k, :] (relies on HCCL window zero-init, - no per-call clear), barrier, reduce_sum along - TOPK -> routed_y FP32 - -Each rank drives the dispatch kernel through these phases (0..4): - - histogram scalar histogram + (dst, loc_e)-sorted route table from indices - publish publish full send_counts table to peers via TNOTIFY(AtomicAdd) - + count_done barrier - prefix_sum local prefix sums over global pub_counts (no comm) - payload_push for each route: TPUT three independent payload tiles — - x [BF16, 1xD] x_norm[t, :] - weight [FP32, 1xW_PAD] w_padded[r, :] = [weight, 0, …, 0] - idx [INT32, 1xIDX_PAD] idx_padded[r, :] = [r, 0, …, 0] - where r = t * TOPK + k - to peer's recv_x[loc_e][slot, :] / recv_w[…] / recv_idx[…] - + data_done barrier - stage_out stage out recv_x / recv_w / recv_idx windows -> host-backed outputs - -Type/shape contract: - - ``x_norm`` and ``recv_x_out`` are **BF16**. Test inputs use small - integer values (≤ 256) that fit BF16 exactly. - - Weight uses a 1xW_PAD=8 FP32 tile per route (the minimum vector tile - granularity = 32 B = one MTE burst). The host pre-packs each row as - [weight, 0, 0, …, 0]; receiver writes recv_w[loc_e][slot, :W_PAD] - and the kernel TROWSUM-compacts to a [L, R] FP32 host output. - - Idx uses the same minimum-tile rationale: 1xIDX_PAD=8 INT32 per - route, actual r=t*TOPK+k at slot [0]; compacted via scalar copy to - [L, R] INT32 host output. Combine reads it to address - routed_y_buf[t, k, :] without a host-built origin_map. - - ``recv_count_out`` is [L, 1] INT32 emitted by dispatch's prefix_sum - phase. +"""End-to-end 2-card moe_router + EP dispatch + moe_expert + combine demo. + +A single orchestration runs four back-to-back stages over a shared HCCL window +scratch: + + moe_router kernels FFN half-compress pre-mix (hc_pre) + RMSNorm + learned- + score gate + top-k + weight normalize. Produces x_norm, + indices, weights (plus post_ffn / comb_ffn for hc_post). + 18 PyPTO-generated AIV kernels. + dispatch.cpp EP count exchange + 3-channel push (x BF16 / weight FP32 / + idx INT32) + per-channel stage-out + recv_count emission. + Now reads the *chip-produced* x_norm and indices, plus + host-packed w_padded / idx_padded. + moe_expert kernels DeepSeek-V4 decode MoE block — routed local experts + (per-tile A8 gate/up matmul → SwiGLU → routing-weight mul + → A8 requant → w2 matmul → recv_y) + shared expert + (x_local A8 → gate/up → SwiGLU → A8 → w2 → sh). 17 + PyPTO-generated incore kernels (4 AIC matmuls + 13 AIV). + combine.cpp TPUT recv_y rows by recv_idx_out into routed_y_buf, + barrier, reduce_sum along TOPK -> routed_y FP32. + +Dimensions mirror the ``DEMO`` decode config: D = hidden_size = 4096, +MOE_INTER = 4096, L = N_LOCAL_EXPERTS = 8, T = 16, TOPK = 2, R = RECV_MAX = 32, +HC_MULT = 4. INT8 expert weight banks + FP32 router fixtures are generated +randomly on host (shared across the two ranks). 
The host golden is ``golden_moe_router`` (ported from
+``models/deepseek/v4/moe_router.py``) → dispatch protocol replay →
+``golden_moe_expert`` → combine reduce; on chip, ffn_add (func_id 37,
+ffn_out = routed_y + sh) and hc_post (func_id 38) then close the layer.
 
 Run:
 
-    python examples/workers/l3/ep_dispatch_combine/main.py -p a2a3sim -d 0-1
+    python examples/workers/l3/ep_dispatch_combine/main.py -p a2a3 -d 0-1
 """
 
 from __future__ import annotations
 
@@ -59,10 +45,12 @@
 import argparse
 import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
 
 os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
 
 import torch  # noqa: E402
+import torch.nn.functional as F  # noqa: E402
 from simpler.task_interface import (  # noqa: E402
     ArgDirection,
     CallConfig,
@@ -85,37 +73,96 @@ from simpler_setup.torch_interop import make_tensor_arg  # noqa: E402
 
 HERE = os.path.dirname(os.path.abspath(__file__))
+RUNTIME = "tensormap_and_ringbuffer"
 
-# Demo dimensions — must mirror constants at the top of the kernel.
+# Demo dimensions — mirror the ``DEMO`` decode config and the constants at the
+# top of dispatch.cpp / combine.cpp.
 N_RANKS = 2
-T = 8
+B = 16  # DECODE_BATCH
+S = 1  # DECODE_SEQ
+T = B * S  # tokens per rank
 TOPK = 2
-D = 64
-L = 4  # N_LOCAL_EXPERTS per rank
-R = 32  # RECV_MAX (single-expert receive upper bound)
-W_PAD = 8  # weight tile width — minimum vector tile (1x8 FP32 = 32 B)
-IDX_PAD = 8  # idx tile width — minimum vector tile (1x8 INT32 = 32 B)
-E_GLOBAL = N_RANKS * L
+D = 4096  # hidden_size
+L = 8  # N_LOCAL_EXPERTS per rank
+N_EXPERTS = L  # global experts in moe_expert's view (its JIT was compiled with EP_WORLD_SIZE=1)
+R = 32  # RECV_MAX
+MOE_INTER = 4096
+SWIGLU_LIMIT = 0.0
+INT8_SCALE_MAX = 127.0
+INT8_AMAX_EPS = 1e-4
+W_PAD = 8
+IDX_PAD = 8
+E_GLOBAL = N_RANKS * L  # global EP expert count
 N_ROUTES = T * TOPK
 
-# Window region byte sizes — mirror k*Bytes / kOff* in the kernels.
-PUB_COUNTS_BYTES = N_RANKS * N_RANKS * L * 4  # N*N*L INT32
-SIGNAL_BYTES = 64  # padded slot per signal area
-RECV_X_BYTES = L * R * D * 2  # 16 KB (BF16)
-RECV_W_BYTES = L * R * W_PAD * 4  # 4 KB (FP32; weight at slot 0)
-RECV_IDX_BYTES = L * R * IDX_PAD * 4  # 4 KB (INT32; r at slot 0)
-ROUTED_Y_BUF_BYTES = T * TOPK * D * 2  # 2 KB (BF16; combine push dest)
+# Router constants — must match the ``DEMO`` config the router was JIT'd against.
+HC_MULT = 4
+HC_DIM = HC_MULT * D  # 16384
+MIX_HC = (2 + HC_MULT) * HC_MULT  # 24
+HC_SINKHORN_ITER = 20
+HC_EPS = 1e-6
+NORM_EPS = 1e-6
+VOCAB = 129280  # vocab_size (tid2eid is unused at LAYER_ID >= N_HASH_LAYERS; still must allocate)
+ROUTE_SCALE = 1.0
+LAYER_ID = 1
+N_HASH_LAYERS = 0  # LAYER_ID >= N_HASH_LAYERS -> learned-score path
+
+# Window region byte sizes — mirror k*Bytes / kOff* in dispatch.cpp / combine.cpp.
+PUB_COUNTS_BYTES = N_RANKS * N_RANKS * L * 4
+SIGNAL_BYTES = 64
+RECV_X_BYTES = L * R * D * 2
+RECV_W_BYTES = L * R * W_PAD * 4
+RECV_IDX_BYTES = L * R * IDX_PAD * 4
+ROUTED_Y_BUF_BYTES = T * TOPK * D * 2
 SCRATCH_NBYTES = (
-    PUB_COUNTS_BYTES
-    + SIGNAL_BYTES  # count_done_sig
-    + RECV_X_BYTES
-    + RECV_W_BYTES
-    + RECV_IDX_BYTES
-    + SIGNAL_BYTES  # data_done_sig
-    + ROUTED_Y_BUF_BYTES  # combine push destination
-    + SIGNAL_BYTES  # combine_done_sig
+    PUB_COUNTS_BYTES + SIGNAL_BYTES + RECV_X_BYTES + RECV_W_BYTES + RECV_IDX_BYTES
+    + SIGNAL_BYTES + ROUTED_Y_BUF_BYTES + SIGNAL_BYTES
 )
 
+# Kernel set: 18 router (0..17), dispatch (18), 17 moe_expert (19..35),
+# combine (36), ffn_add (37), hc_post (38).
+# func_id matches rt_submit_*_task calls in ep_dispatch_combine_orch.cpp.
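+# Illustrative sketch, not used by the demo: assuming the kernels carve the
+# HCCL window in the exact order summed in SCRATCH_NBYTES above, the kOff*
+# constants in dispatch.cpp / combine.cpp fall out as cumulative byte offsets.
+# `_SCRATCH_REGIONS` and `SCRATCH_OFFSETS` are hypothetical names, not part of
+# the kernels' contract; the region names come from the old per-term comments.
+_SCRATCH_REGIONS = [
+    ("pub_counts", PUB_COUNTS_BYTES),
+    ("count_done_sig", SIGNAL_BYTES),
+    ("recv_x", RECV_X_BYTES),
+    ("recv_w", RECV_W_BYTES),
+    ("recv_idx", RECV_IDX_BYTES),
+    ("data_done_sig", SIGNAL_BYTES),
+    ("routed_y_buf", ROUTED_Y_BUF_BYTES),
+    ("combine_done_sig", SIGNAL_BYTES),
+]
+SCRATCH_OFFSETS: dict[str, int] = {}
+_off = 0
+for _name, _nbytes in _SCRATCH_REGIONS:
+    SCRATCH_OFFSETS[_name] = _off  # byte offset of the region inside the window
+    _off += _nbytes
+assert _off == SCRATCH_NBYTES
+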
+KERNELS: list[tuple[int, str, str, str]] = [ + (0, "cast_x", "kernels/aiv/cast_x.cpp", "aiv"), + (1, "rms", "kernels/aiv/rms.cpp", "aiv"), + (2, "linear", "kernels/aiv/linear.cpp", "aiv"), + (3, "split_pre_post", "kernels/aiv/split_pre_post.cpp", "aiv"), + (4, "split_pre_post_0", "kernels/aiv/split_pre_post_0.cpp", "aiv"), + (5, "split_pre_post_1", "kernels/aiv/split_pre_post_1.cpp", "aiv"), + (6, "split_pre_post_2", "kernels/aiv/split_pre_post_2.cpp", "aiv"), + (7, "comb_sinkhorn", "kernels/aiv/comb_sinkhorn.cpp", "aiv"), + (8, "write_post", "kernels/aiv/write_post.cpp", "aiv"), + (9, "mix_x", "kernels/aiv/mix_x.cpp", "aiv"), + (10, "ffn_norm_rms", "kernels/aiv/ffn_norm_rms.cpp", "aiv"), + (11, "ffn_norm_apply", "kernels/aiv/ffn_norm_apply.cpp", "aiv"), + (12, "gate_dot", "kernels/aiv/gate_dot.cpp", "aiv"), + (13, "gate_dot_0", "kernels/aiv/gate_dot_0.cpp", "aiv"), + (14, "route_sort_top2", "kernels/aiv/route_sort_top2.cpp", "aiv"), + (15, "route_extract_top2", "kernels/aiv/route_extract_top2.cpp", "aiv"), + (16, "route_normalize_weights", "kernels/aiv/route_normalize_weights.cpp", "aiv"), + (17, "write_route_outputs", "kernels/aiv/write_route_outputs.cpp", "aiv"), + (18, "dispatch", "kernels/aiv/dispatch.cpp", "aiv"), + (19, "x_local_q", "kernels/aiv/x_local_q.cpp", "aiv"), + (20, "recv_x_q", "kernels/aiv/recv_x_q.cpp", "aiv"), + (21, "exp_gate_up_matmul", "kernels/aic/exp_gate_up_matmul.cpp", "aic"), + (22, "exp_gate_up_dequant", "kernels/aiv/exp_gate_up_dequant.cpp", "aiv"), + (23, "exp_swiglu", "kernels/aiv/exp_swiglu.cpp", "aiv"), + (24, "exp_swiglu_mask", "kernels/aiv/exp_swiglu_mask.cpp", "aiv"), + (25, "exp_h_q", "kernels/aiv/exp_h_q.cpp", "aiv"), + (26, "exp_w2_matmul", "kernels/aic/exp_w2_matmul.cpp", "aic"), + (27, "exp_w2_dequant", "kernels/aiv/exp_w2_dequant.cpp", "aiv"), + (28, "exp_recv_y_write", "kernels/aiv/exp_recv_y_write.cpp", "aiv"), + (29, "sh_gate_up_matmul", "kernels/aic/sh_gate_up_matmul.cpp", "aic"), + (30, "sh_gate_up_dequant", "kernels/aiv/sh_gate_up_dequant.cpp", "aiv"), + (31, "sh_swiglu", "kernels/aiv/sh_swiglu.cpp", "aiv"), + (32, "sh_h_q", "kernels/aiv/sh_h_q.cpp", "aiv"), + (33, "sh_w2_matmul", "kernels/aic/sh_w2_matmul.cpp", "aic"), + (34, "sh_w2_dequant", "kernels/aiv/sh_w2_dequant.cpp", "aiv"), + (35, "sh_write", "kernels/aiv/sh_write.cpp", "aiv"), + (36, "combine", "kernels/aiv/combine.cpp", "aiv"), + (37, "ffn_add", "kernels/aiv/ffn_add.cpp", "aiv"), + (38, "hc_post", "kernels/aiv/hc_post.cpp", "aiv"), +] + def parse_device_range(spec: str) -> list[int]: if "," in spec: @@ -131,146 +178,202 @@ def parse_device_range(spec: str) -> list[int]: def build_chip_callable(platform: str, pto_isa_commit: str | None) -> ChipCallable: - """Compile the dispatch / local_expert / combine AIV kernels + their - shared C++ orchestration shim into a single ChipCallable with three - child callables. - """ + """Compile the 18 router + dispatch + 17 moe_expert + combine kernels and + the merged C++ orchestration into a single ChipCallable. 
ccec runs in a + thread pool because the AIC matmuls alone take real time.""" kc = KernelCompiler(platform=platform) - runtime = "tensormap_and_ringbuffer" pto_isa_root = ensure_pto_isa_root(commit=pto_isa_commit, clone_protocol="https") - include_dirs = kc.get_orchestration_include_dirs(runtime) - kernel_include_dirs = list(include_dirs) + [str(kc.project_root / "src" / "common")] + include_dirs = kc.get_orchestration_include_dirs(RUNTIME) + is_sim = platform.endswith("sim") + # dispatch.cpp uses `dcci(...)` (via aicore/aicore.h) to invalidate the + # D-cache before its scalar read of the router-written `indices`. aicore.h + # transitively includes "inner_kernel.h" which lives under the platform- + # specific aicore subdir, so plumb that in. + arch = "a2a3" if platform.startswith("a2a3") else "a5" + inner_kernel_subdir = "sim" if is_sim else "onboard" + kernel_include_dirs = list(include_dirs) + [ + str(kc.project_root / "src" / "common"), + str(kc.project_root / "src" / arch / "platform" / inner_kernel_subdir / "aicore"), + ] - def compile_aiv(name: str) -> bytes: + def compile_one(rel_src: str, core_type: str) -> bytes: b = kc.compile_incore( - source_path=os.path.join(HERE, "kernels/aiv", name), - core_type="aiv", + source_path=os.path.join(HERE, rel_src), + core_type=core_type, pto_isa_root=pto_isa_root, extra_include_dirs=kernel_include_dirs, ) - if not platform.endswith("sim"): + if not is_sim: b = extract_text_section(b) return b - dispatch_bin = compile_aiv("dispatch.cpp") - local_expert_bin = compile_aiv("local_expert.cpp") - combine_bin = compile_aiv("combine.cpp") + with ThreadPoolExecutor(max_workers=min(8, len(KERNELS))) as ex: + futs = {fid: ex.submit(compile_one, src, ct) for (fid, _name, src, ct) in KERNELS} + bins = {fid: f.result() for fid, f in futs.items()} orch_bytes = kc.compile_orchestration( - runtime_name=runtime, + runtime_name=RUNTIME, source_path=os.path.join(HERE, "kernels/orchestration/ep_dispatch_combine_orch.cpp"), ) - # Per-child signatures — each kernel sees only the args it actually - # consumes (matching the orch's `Arg` packing for that submit). + # dispatch / combine keep explicit per-child sigs (they're hand-written + # comm kernels). PyPTO-generated incore kernels get an empty signature like + # the other generated examples (dependency tracking lives in the orch's + # add_input / add_output / add_inout calls). sig_dispatch = [ - ArgDirection.IN, # indices - ArgDirection.IN, # x_norm - ArgDirection.IN, # w_padded - ArgDirection.IN, # idx_padded - ArgDirection.OUT, # recv_x_out - ArgDirection.OUT, # recv_w_out - ArgDirection.OUT, # recv_idx_out - ArgDirection.OUT, # recv_count_out - ArgDirection.INOUT, # scratch - ] - sig_local_expert = [ - ArgDirection.IN, # recv_x_out (reused as INPUT) - ArgDirection.IN, # recv_w_out (reused as INPUT) - ArgDirection.IN, # recv_count_out (reused as INPUT) - ArgDirection.OUT, # recv_y - ] - sig_combine = [ - ArgDirection.IN, # recv_y (reused as INPUT) - ArgDirection.IN, # recv_idx_out (reused as INPUT) - ArgDirection.OUT, # routed_y - ArgDirection.INOUT, # scratch - ] - - # The orch's view is the union of every child's tensor footprint. 
- sig_orch = [ - ArgDirection.IN, # indices - ArgDirection.IN, # x_norm - ArgDirection.IN, # w_padded - ArgDirection.IN, # idx_padded - ArgDirection.OUT, # recv_x_out - ArgDirection.OUT, # recv_w_out - ArgDirection.OUT, # recv_idx_out - ArgDirection.OUT, # recv_count_out - ArgDirection.OUT, # recv_y - ArgDirection.OUT, # routed_y - ArgDirection.INOUT, # scratch + ArgDirection.IN, ArgDirection.IN, ArgDirection.IN, ArgDirection.IN, + ArgDirection.OUT, ArgDirection.OUT, ArgDirection.OUT, ArgDirection.OUT, + ArgDirection.INOUT, ] + sig_combine = [ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT] + sig_ffn_add = [ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT] + children: list[tuple[int, CoreCallable]] = [] + for fid, name, _src, _ct in KERNELS: + if name == "dispatch": + sig = sig_dispatch + elif name == "combine": + sig = sig_combine + elif name == "ffn_add": + sig = sig_ffn_add + else: + sig = [] + children.append((fid, CoreCallable.build(signature=sig, binary=bins[fid]))) + + # Orchestration arg view: 24 INs (x_hc..shared_w2_scale), 14 OUTPUT_EXISTING + # (x_norm..routed_y, ffn_out, y), 1 INOUT scratch — 39 tensors + 2 scalars. + sig_orch = [ArgDirection.IN] * 24 + [ArgDirection.OUT] * 14 + [ArgDirection.INOUT] return ChipCallable.build( signature=sig_orch, func_name="ep_dispatch_combine_orchestration", config_name="ep_dispatch_combine_orchestration_config", binary=orch_bytes, - children=[ - (0, CoreCallable.build(signature=sig_dispatch, binary=dispatch_bin)), - (1, CoreCallable.build(signature=sig_local_expert, binary=local_expert_bin)), - (2, CoreCallable.build(signature=sig_combine, binary=combine_bin)), - ], + children=children, ) -def generate_routing_indices(seed: int) -> torch.Tensor: - """Generate `indices[N_RANKS][T, TOPK]` so no expert exceeds RECV_MAX. - - Each (t, k) is a global expert id in [0, E_GLOBAL). Top-k entries within - a single token are forced unique. Reseed if any per-expert receive count - would overflow R. - """ - rng = torch.Generator().manual_seed(seed) - while True: - indices = torch.zeros(N_RANKS, T, TOPK, dtype=torch.int32) - for r in range(N_RANKS): - for t in range(T): - perm = torch.randperm(E_GLOBAL, generator=rng)[:TOPK] - indices[r, t, :] = perm.to(torch.int32) - - per_expert = torch.zeros(N_RANKS, L, dtype=torch.int32) - for r in range(N_RANKS): - for t in range(T): - for k in range(TOPK): - eid = int(indices[r, t, k].item()) - dst = eid // L - loc_e = eid % L - per_expert[dst, loc_e] += 1 - if int(per_expert.max().item()) <= R: - return indices - seed += 1 - rng.manual_seed(seed) - - -def compute_golden( - x_norms: list[torch.Tensor], # [N_RANKS] of [T, D] BF16 - indices: torch.Tensor, # [N_RANKS, T, TOPK] INT32 - weights: torch.Tensor, # [N_RANKS, T, TOPK] FP32 -): - """Replay dispatch protocol on host -> per-rank - expected_recv_x[L, R, D] BF16 (x payload) - expected_recv_w[L, R] FP32 (weight payload) - expected_recv_idx[L, R] INT32 (r = t*TOPK+k for each delivered row) - expected_count[L] INT32 - """ +# --------------------------------------------------------------------------- # +# Host golden: moe_router (hc_pre + RMSNorm + learned-score gate + top-k) +# Ported from models/deepseek/v4/{hc_pre,moe_router}.py to keep this example +# self-contained (the pypto-lib model code isn't installed alongside the wheel). 
+# --------------------------------------------------------------------------- # +def _golden_hc_pre(x_hc, hc_ffn_fn, hc_ffn_scale, hc_ffn_base): + """Half-compress pre-mix: produces (x_mixed BF16 [B,S,D], post_ffn FP32 + [B,S,HC_MULT], comb_ffn FP32 [B,S,HC_MULT,HC_MULT]).""" + x = x_hc.float() + hc_fn = hc_ffn_fn.float() + hc_scale = hc_ffn_scale.float() + hc_base = hc_ffn_base.float() + + x_flat = x.flatten(2).reshape(T, HC_DIM) # [T, HC_MULT*D] + sq_sum = (x_flat * x_flat).sum(dim=1, keepdim=True) + rsqrt = torch.rsqrt(sq_sum / HC_DIM + NORM_EPS) + mixes = (x_flat @ hc_fn.T) * rsqrt # [T, MIX_HC] + mixes = mixes.reshape(B, S, MIX_HC) + + pre = torch.sigmoid(mixes[..., :HC_MULT] * hc_scale[0] + hc_base[:HC_MULT]) + HC_EPS + post_t = 2 * torch.sigmoid( + mixes[..., HC_MULT : HC_MULT * 2] * hc_scale[1] + hc_base[HC_MULT : HC_MULT * 2] + ) + comb_t = (mixes[..., HC_MULT * 2 :] * hc_scale[2] + hc_base[HC_MULT * 2 :]).view(B, S, HC_MULT, HC_MULT) + comb_t = torch.softmax(comb_t, dim=-1) + HC_EPS + comb_t = comb_t / (comb_t.sum(-2, keepdim=True) + HC_EPS) + for _ in range(HC_SINKHORN_ITER - 1): + comb_t = comb_t / (comb_t.sum(-1, keepdim=True) + HC_EPS) + comb_t = comb_t / (comb_t.sum(-2, keepdim=True) + HC_EPS) + + y = torch.zeros(B, S, D, dtype=torch.float32) + for h in range(HC_MULT): + y += x[:, :, h, :] * pre[:, :, h : h + 1] + return y.to(torch.bfloat16), post_t, comb_t + + +def golden_moe_router(x_hc, hc_ffn_fn, hc_ffn_scale, hc_ffn_base, norm_w, gate_w, gate_bias): + """Router golden — returns x_norm [T,D] BF16, indices [T,TOPK] INT32, + weights [T,TOPK] FP32, post_ffn [B,S,HC_MULT] FP32, comb_ffn [B,S,HC_MULT,HC_MULT] FP32.""" + x_mixed, post_ffn, comb_ffn = _golden_hc_pre(x_hc, hc_ffn_fn, hc_ffn_scale, hc_ffn_base) + + # FFN RMSNorm (returns BF16 to match the chip's cast-to-BF16 store). + norm_w_f = norm_w.float() + x_f = x_mixed.float() + var = x_f.square().mean(-1, keepdim=True) + x_n = x_f * torch.rsqrt(var + NORM_EPS) + x_norm = (norm_w_f * x_n).to(torch.bfloat16).view(T, D) + + # Learned routing scores + top-k. + scores = F.softplus(x_norm.float() @ gate_w.float().T).sqrt() # [T, N_EXPERTS] + biased = scores + gate_bias.float() + indices = biased.topk(TOPK, dim=-1).indices.to(torch.int32) # [T, TOPK] + weights = scores.gather(1, indices.long()) # [T, TOPK] FP32 + weights = weights / weights.sum(dim=-1, keepdim=True) * ROUTE_SCALE + + # NOTE: the router writes `indices` modulo N_EXPERTS = L (its JIT was built + # with EP_WORLD_SIZE=1 so it only addresses its own 8 local experts). The + # dispatch protocol below treats them as *global* IDs in [0, E_GLOBAL), so + # rebroadcast across ranks by mixing in the rank id at host fixture time. 
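+    # Worked example of that split with N_RANKS = 2, TOPK = 2 (see _route_dst
+    # below): rank 0 keeps every k = 0 route local and sends every k = 1 route
+    # to rank 1, so the (dst, loc_e) pair plays the role of the old global
+    # eid = dst * L + loc_e decomposition.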
+
+
+def build_router_inputs(seed: int):
+    """Per-rank x_hc / input_ids plus shared FFN-norm + gate + hc_pre fixtures."""
+    gen = torch.Generator().manual_seed(seed)
+    x_hcs = [(torch.randn(B, S, HC_MULT, D, generator=gen) * 0.1).to(torch.bfloat16).share_memory_()
+             for _ in range(N_RANKS)]
+    hc_ffn_fn = (torch.randn(MIX_HC, HC_DIM, generator=gen) / HC_DIM**0.5).contiguous().share_memory_()
+    hc_ffn_scale = (torch.ones(3, dtype=torch.float32) * 0.5).contiguous().share_memory_()
+    hc_ffn_base = torch.zeros(MIX_HC, dtype=torch.float32).contiguous().share_memory_()
+    norm_w = torch.ones(D, dtype=torch.float32).contiguous().share_memory_()
+    # gate_w shaped [N_EXPERTS, D] — N_EXPERTS = L = 8 (the moe_expert JIT was
+    # compiled with EP_WORLD_SIZE=1, so the router only addresses local experts).
+    gate_w = (torch.randn(L, D, generator=gen) / D**0.5).contiguous().share_memory_()
+    gate_bias = torch.zeros(L, dtype=torch.float32).contiguous().share_memory_()
+    # tid2eid / input_ids are unused once LAYER_ID >= N_HASH_LAYERS (0 in this
+    # demo) but must still be passed (the orch binds the slots).
+    tid2eid = torch.randint(0, L, (VOCAB, TOPK), generator=gen, dtype=torch.int32).contiguous().share_memory_()
+    input_ids_list = [torch.randint(0, VOCAB, (B, S), generator=gen, dtype=torch.int64).contiguous().share_memory_()
+                      for _ in range(N_RANKS)]
+    return {
+        "x_hcs": x_hcs,
+        "hc_ffn_fn": hc_ffn_fn,
+        "hc_ffn_scale": hc_ffn_scale,
+        "hc_ffn_base": hc_ffn_base,
+        "norm_w": norm_w,
+        "gate_w": gate_w,
+        "gate_bias": gate_bias,
+        "tid2eid": tid2eid,
+        "input_ids_list": input_ids_list,
+    }
+
+
+# --------------------------------------------------------------------------- #
+# Routing / dispatch host model
+# --------------------------------------------------------------------------- #
+def _route_dst(src_rank: int, k: int) -> int:
+    """EP routing policy mirrored from dispatch.cpp: route slot k of rank
+    src_rank goes to peer (src_rank + k) % N_RANKS. The router's chip-produced
+    `indices` are local expert IDs in [0, L); the rank component is layered on
+    here so the demo spreads tokens across both ranks."""
+    return (src_rank + k) % N_RANKS
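# Tiny illustration of the policy (assuming the demo's N_RANKS == 2; the slot
# count here is arbitrary): consecutive route slots alternate between the two
# peers, so the per-peer load is fixed by (src, k) alone, independent of which
# experts the router picked.
for src in range(2):
    print(src, [(src + k) % 2 for k in range(8)])
# 0 [0, 1, 0, 1, 0, 1, 0, 1]
# 1 [1, 0, 1, 0, 1, 0, 1, 0]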
+
+
+def compute_dispatch_golden(x_norms, indices_local, weights):
+    """Replay the dispatch protocol on host. ``indices_local`` is per-rank
+    [N_RANKS, T, TOPK] of local expert IDs in [0, L); destination rank is
+    derived from (src, k) per the EP policy above."""
     expected_recv_x = [torch.zeros(L, R, D, dtype=torch.bfloat16) for _ in range(N_RANKS)]
     expected_recv_w = [torch.zeros(L, R, dtype=torch.float32) for _ in range(N_RANKS)]
     expected_recv_idx = [torch.zeros(L, R, dtype=torch.int32) for _ in range(N_RANKS)]
     expected_count = [torch.zeros(L, dtype=torch.int32) for _ in range(N_RANKS)]
+    route_dest = [[[None] * TOPK for _ in range(T)] for _ in range(N_RANKS)]

     send_counts = torch.zeros(N_RANKS, N_RANKS, L, dtype=torch.int32)
     for src in range(N_RANKS):
         for t in range(T):
             for k in range(TOPK):
-                eid = int(indices[src, t, k].item())
-                dst = eid // L
-                loc_e = eid % L
-                send_counts[src, dst, loc_e] += 1
+                loc_e = int(indices_local[src][t, k].item())
+                send_counts[src, _route_dst(src, k), loc_e] += 1

     for dst in range(N_RANKS):
-        # Per-destination slot_offset[src][e] = sum_{s < src} send_counts[s, dst, e].
         slot_offset = torch.zeros(N_RANKS, L, dtype=torch.int32)
         running = torch.zeros(L, dtype=torch.int32)
         for src in range(N_RANKS):
@@ -281,70 +384,194 @@ def compute_golden(
             cursor = torch.zeros(L, dtype=torch.int32)
             for t in range(T):
                 for k in range(TOPK):
-                    eid = int(indices[src, t, k].item())
-                    if eid // L != dst:
+                    if _route_dst(src, k) != dst:
                         continue
-                    loc_e = eid % L
+                    loc_e = int(indices_local[src][t, k].item())
                     slot = int(slot_offset[src, loc_e].item() + cursor[loc_e].item())
                     cursor[loc_e] += 1
                     expected_recv_x[dst][loc_e, slot, :] = x_norms[src][t, :]
-                    expected_recv_w[dst][loc_e, slot] = weights[src, t, k]
+                    expected_recv_w[dst][loc_e, slot] = weights[src][t, k]
                     expected_recv_idx[dst][loc_e, slot] = t * TOPK + k
+                    route_dest[src][t][k] = (dst, loc_e, slot)

         for e in range(L):
             expected_count[dst][e] = int(running[e].item())

-    return expected_recv_x, expected_recv_w, expected_recv_idx, expected_count
+    return expected_recv_x, expected_recv_w, expected_recv_idx, expected_count, route_dest
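# Standalone sketch of the slot-assignment invariant replayed above (toy sizes,
# not the demo's): each destination lays out a given expert's rows as
# [rows from src 0][rows from src 1]..., where slot_offset[src][e] is the
# prefix sum of earlier sources' counts and a per-(src, e) cursor walks inside it.
import torch

n_src, n_expert = 2, 3
send_counts = torch.tensor([[2, 0, 1],   # rows src 0 sends to each local expert
                            [1, 3, 0]],  # rows src 1 sends to each local expert
                           dtype=torch.int32)
slot_offset = torch.zeros(n_src, n_expert, dtype=torch.int32)
running = torch.zeros(n_expert, dtype=torch.int32)
for src in range(n_src):
    slot_offset[src] = running
    running += send_counts[src]
print(slot_offset.tolist())  # [[0, 0, 0], [2, 0, 1]]
print(running.tolist())      # [3, 3, 1], the final recv_count per expert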
- """ + """[T*TOPK, IDX_PAD] INT32 where row r = (r, 0, …, 0).""" out = torch.zeros(N_ROUTES, IDX_PAD, dtype=torch.int32) for t in range(T): for k in range(TOPK): - r = t * TOPK + k - out[r, 0] = r + out[t * TOPK + k, 0] = t * TOPK + k return out +# --------------------------------------------------------------------------- # +# moe_expert host golden (ported from models/deepseek/v4/moe_expert.py) +# --------------------------------------------------------------------------- # +def _round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) + + +def _int8_quant_per_row(x: torch.Tensor): + rows = x.float().reshape(-1, x.shape[-1]) + amax = rows.abs().amax(dim=-1, keepdim=True).clamp_min(INT8_AMAX_EPS) + scale_quant = INT8_SCALE_MAX / amax + out_i8 = _round_half_away_from_zero(rows * scale_quant).to(torch.int32).to(torch.float16).to(torch.int8) + return out_i8.reshape_as(x), (1.0 / scale_quant).reshape(*x.shape[:-1], 1) + + +def _quant_w_per_channel(w: torch.Tensor): + amax = w.float().abs().amax(dim=-1).clamp_min(INT8_AMAX_EPS) + scale_quant = INT8_SCALE_MAX / amax + w_i8 = (_round_half_away_from_zero(w.float() * scale_quant.unsqueeze(-1)) + .to(torch.int32).to(torch.float16).to(torch.int8)) + return w_i8, (1.0 / scale_quant).float() + + +def build_expert_weights(seed: int): + gen = torch.Generator().manual_seed(seed) + out: dict[str, torch.Tensor] = {} + + def _store(name: str, w_bf16: torch.Tensor) -> None: + w_i8, w_s = _quant_w_per_channel(w_bf16) + out[name] = w_i8.contiguous().share_memory_() + out[name + "_scale"] = w_s.contiguous().share_memory_() + + _store("expert_w1", (torch.randn(L, MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16)) + _store("expert_w3", (torch.randn(L, MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16)) + _store("expert_w2", (torch.randn(L, D, MOE_INTER, generator=gen) / MOE_INTER**0.5).to(torch.bfloat16)) + _store("shared_w1", (torch.randn(MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16)) + _store("shared_w3", (torch.randn(MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16)) + _store("shared_w2", (torch.randn(D, MOE_INTER, generator=gen) / MOE_INTER**0.5).to(torch.bfloat16)) + return out + + +def _dequant_w(w_i8: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor: + return w_i8.to(torch.float32) * w_scale.unsqueeze(-1) + + +def golden_moe_expert(recv_x, recv_weights, recv_count, x_local, w): + """Torch reference for one rank's moe_expert call (see moe_expert.py).""" + recv_x = recv_x.float() + recv_weights = recv_weights.float() + x_local = x_local.float() + w1 = _dequant_w(w["expert_w1"], w["expert_w1_scale"].float()) + w3 = _dequant_w(w["expert_w3"], w["expert_w3_scale"].float()) + w2 = _dequant_w(w["expert_w2"], w["expert_w2_scale"].float()) + sw1 = _dequant_w(w["shared_w1"], w["shared_w1_scale"].float()) + sw3 = _dequant_w(w["shared_w3"], w["shared_w3_scale"].float()) + sw2 = _dequant_w(w["shared_w2"], w["shared_w2_scale"].float()) + + x_local_i8, x_local_sd = _int8_quant_per_row(x_local) + x_local_q = x_local_i8.float() * x_local_sd + + recv_y = torch.zeros(L, R, D, dtype=torch.float32) + for e in range(L): + n_rows = int(recv_count[e].item()) + if n_rows == 0: + continue + x_sub = recv_x[e, :n_rows, :] + w_sub = recv_weights[e, :n_rows] + + x_sub_i8, x_sub_sd = _int8_quant_per_row(x_sub) + x_sub_q = x_sub_i8.float() * x_sub_sd + + gate = x_sub_q @ w1[e].T + up = x_sub_q @ w3[e].T + if SWIGLU_LIMIT > 0: + gate = gate.clamp(max=SWIGLU_LIMIT) + up = 
+
+
+def build_expert_weights(seed: int):
+    gen = torch.Generator().manual_seed(seed)
+    out: dict[str, torch.Tensor] = {}
+
+    def _store(name: str, w_bf16: torch.Tensor) -> None:
+        w_i8, w_s = _quant_w_per_channel(w_bf16)
+        out[name] = w_i8.contiguous().share_memory_()
+        out[name + "_scale"] = w_s.contiguous().share_memory_()
+
+    _store("expert_w1", (torch.randn(L, MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16))
+    _store("expert_w3", (torch.randn(L, MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16))
+    _store("expert_w2", (torch.randn(L, D, MOE_INTER, generator=gen) / MOE_INTER**0.5).to(torch.bfloat16))
+    _store("shared_w1", (torch.randn(MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16))
+    _store("shared_w3", (torch.randn(MOE_INTER, D, generator=gen) / D**0.5).to(torch.bfloat16))
+    _store("shared_w2", (torch.randn(D, MOE_INTER, generator=gen) / MOE_INTER**0.5).to(torch.bfloat16))
+    return out
+
+
+def _dequant_w(w_i8: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
+    return w_i8.to(torch.float32) * w_scale.unsqueeze(-1)
+
+
+def golden_moe_expert(recv_x, recv_weights, recv_count, x_local, w):
+    """Torch reference for one rank's moe_expert call (see moe_expert.py)."""
+    recv_x = recv_x.float()
+    recv_weights = recv_weights.float()
+    x_local = x_local.float()
+    w1 = _dequant_w(w["expert_w1"], w["expert_w1_scale"].float())
+    w3 = _dequant_w(w["expert_w3"], w["expert_w3_scale"].float())
+    w2 = _dequant_w(w["expert_w2"], w["expert_w2_scale"].float())
+    sw1 = _dequant_w(w["shared_w1"], w["shared_w1_scale"].float())
+    sw3 = _dequant_w(w["shared_w3"], w["shared_w3_scale"].float())
+    sw2 = _dequant_w(w["shared_w2"], w["shared_w2_scale"].float())
+
+    x_local_i8, x_local_sd = _int8_quant_per_row(x_local)
+    x_local_q = x_local_i8.float() * x_local_sd
+
+    recv_y = torch.zeros(L, R, D, dtype=torch.float32)
+    for e in range(L):
+        n_rows = int(recv_count[e].item())
+        if n_rows == 0:
+            continue
+        x_sub = recv_x[e, :n_rows, :]
+        w_sub = recv_weights[e, :n_rows]
+
+        x_sub_i8, x_sub_sd = _int8_quant_per_row(x_sub)
+        x_sub_q = x_sub_i8.float() * x_sub_sd
+
+        gate = x_sub_q @ w1[e].T
+        up = x_sub_q @ w3[e].T
+        if SWIGLU_LIMIT > 0:
+            gate = gate.clamp(max=SWIGLU_LIMIT)
+            up = up.clamp(-SWIGLU_LIMIT, SWIGLU_LIMIT)
+        h = F.silu(gate) * up
+        h = h * w_sub.unsqueeze(-1)
+        h_i8, h_sd = _int8_quant_per_row(h)
+        h = h_i8.float() * h_sd
+        recv_y[e, :n_rows, :] = h @ w2[e].T
+
+    sh_gate = x_local_q @ sw1.T
+    sh_up = x_local_q @ sw3.T
+    if SWIGLU_LIMIT > 0:
+        sh_gate = sh_gate.clamp(max=SWIGLU_LIMIT)
+        sh_up = sh_up.clamp(-SWIGLU_LIMIT, SWIGLU_LIMIT)
+    sh_h = F.silu(sh_gate) * sh_up
+    sh_h_i8, sh_h_sd = _int8_quant_per_row(sh_h)
+    sh_h = sh_h_i8.float() * sh_h_sd
+    sh = sh_h @ sw2.T
+
+    return recv_y.to(torch.bfloat16), sh.to(torch.bfloat16)
+
+
+# --------------------------------------------------------------------------- #
+# Verification
+# --------------------------------------------------------------------------- #
+def _verify_router_outputs(
+    nranks, x_norm_goldens, indices_goldens, weights_goldens,
+    x_norm_outs, indices_outs, weights_outs,
+) -> bool:
+    ok = True
+    for r in range(nranks):
+        got_xn = x_norm_outs[r].float()
+        exp_xn = x_norm_goldens[r].float()
+        d_xn = (got_xn - exp_xn).abs()
+        print(f"[ep_dispatch] chip {r}: router x_norm max|diff|={float(d_xn.max()):.3e} (BF16 tol 1e-2)")
+        if not torch.allclose(got_xn, exp_xn, rtol=1e-2, atol=1e-2):
+            ok = False
+            print(f"[ep_dispatch] chip {r}: x_norm mismatch")
+
+        # indices: the chip writes raw local IDs in [0, L), the same domain as
+        # the golden, so no reranking is needed. Compare the sorted per-token
+        # (index, weight) pairs to absorb the tie-break order differences
+        # between the chip's sort32 top-k and torch.topk (see moe_router.py's
+        # topk_pair_compare).
+        got_idx = indices_outs[r]
+        exp_idx_local = indices_goldens[r]
+        idx_mismatches = 0
+        for t in range(T):
+            got_set = sorted((int(got_idx[t, k].item()), round(float(weights_outs[r][t, k].item()), 4)) for k in range(TOPK))
+            exp_set = sorted((int(exp_idx_local[t, k].item()), round(float(weights_goldens[r][t, k].item()), 4)) for k in range(TOPK))
+            # match expert set; weights within fp32 tol
+            if [s[0] for s in got_set] != [s[0] for s in exp_set]:
+                idx_mismatches += 1
+                if idx_mismatches <= 3:
+                    print(f"[ep_dispatch] chip {r} token {t}: indices got={[s[0] for s in got_set]} exp={[s[0] for s in exp_set]}")
+        if idx_mismatches > 0:
+            ok = False
+
+        got_w = weights_outs[r]
+        exp_w = weights_goldens[r]
+        d_w = (got_w - exp_w).abs()
+        print(f"[ep_dispatch] chip {r}: router weights max|diff|={float(d_w.max()):.3e}")
+        if not torch.allclose(got_w, exp_w, rtol=1e-3, atol=1e-3):
+            ok = False
+            print(f"[ep_dispatch] chip {r}: weights mismatch")
+    return ok
+
+
 def _verify_recv_outputs(
-    nranks: int,
-    expected_count: list[torch.Tensor],
-    expected_recv_x: list[torch.Tensor],
-    expected_recv_w: list[torch.Tensor],
-    expected_recv_idx: list[torch.Tensor],
-    recv_count_outs: list[torch.Tensor],
-    recv_x_outs: list[torch.Tensor],
-    recv_w_outs: list[torch.Tensor],
-    recv_idx_outs: list[torch.Tensor],
+    nranks, expected_count, expected_recv_x, expected_recv_w, expected_recv_idx,
+    recv_count_outs, recv_x_outs, recv_w_outs, recv_idx_outs,
 ) -> bool:
-    """Compare dispatch outputs against the host golden, per rank and per expert."""
+    """Dispatch outputs vs. the protocol replay. recv_x now compares with BF16
+    tolerance — its source x_norm goes through router rounding too."""
     ok = True
     for r in range(nranks):
         cnt = expected_count[r]
         print(f"[ep_dispatch] chip {r}: expected counts per expert = {cnt.tolist()}")
-        # recv_count_out is [L, 1] INT32 per the protocol.
         got_count = recv_count_outs[r].squeeze(-1)
         if (got_count - cnt).abs().max().item() != 0:
             ok = False
@@ -353,226 +580,349 @@ def _verify_recv_outputs(
             n = int(cnt[e].item())
             if n == 0:
                 continue
-            # Cast BF16 → FP32 for diff math; values are integer ≤ 256 so
-            # the comparison is bit-exact.
-            got_x = recv_x_outs[r][e, :n, :].to(torch.float32)
-            exp_x = expected_recv_x[r][e, :n, :].to(torch.float32)
-            got_w = recv_w_outs[r][e, :n]
-            exp_w = expected_recv_w[r][e, :n]
-            got_idx = recv_idx_outs[r][e, :n]
-            exp_idx = expected_recv_idx[r][e, :n]
+            got_x = recv_x_outs[r][e, :n, :].float()
+            exp_x = expected_recv_x[r][e, :n, :].float()
             x_diff = (got_x - exp_x).abs().max().item()
-            w_diff = (got_w - exp_w).abs().max().item()
-            idx_diff = (got_idx - exp_idx).abs().max().item()
-            if x_diff > 0 or w_diff > 1e-5 or idx_diff != 0:
+            w_diff = (recv_w_outs[r][e, :n] - expected_recv_w[r][e, :n]).abs().max().item()
+            idx_diff = (recv_idx_outs[r][e, :n] - expected_recv_idx[r][e, :n]).abs().max().item()
+            if (x_diff > 5e-2) or (w_diff > 1e-3) or (idx_diff != 0):
                 ok = False
-                print(
-                    f"[ep_dispatch] chip {r} expert {e}: cnt={n} "
-                    f"x_diff={x_diff:.3e} w_diff={w_diff:.3e} idx_diff={idx_diff}"
-                )
-                if x_diff > 0:
-                    for s in range(min(n, 3)):
-                        print(f"  slot {s}: got x[0]={float(got_x[s, 0])} expected={float(exp_x[s, 0])}")
-                if idx_diff != 0:
-                    for s in range(min(n, 3)):
-                        print(f"  slot {s}: got idx={int(got_idx[s])} expected={int(exp_idx[s])}")
+                print(f"[ep_dispatch] chip {r} expert {e}: cnt={n} x_diff={x_diff:.3e} w_diff={w_diff:.3e} idx_diff={idx_diff}")
     return ok


-def _verify_routed_y(
-    nranks: int,
-    x_norms: list[torch.Tensor],
-    weights: torch.Tensor,
-    routed_y_outs: list[torch.Tensor],
-) -> bool:
-    """Compare combine output routed_y[t, :] against the local-only golden.
-
-    routed_y[t, :] should equal sum_k weights[me, t, k] * x_norms[me][t, :]
-    since local_expert is elementwise x*weight and combine reduces along
-    TOPK. The only BF16 round-trip is local_expert's `cast(x*w, bf16)`;
-    combine's accumulator stays FP32 — mirror that exactly so the model
-    captures every cast and we can assert ~0 diff in steady state.
-    """
+def _verify_expert_outputs(nranks, recv_y_goldens, sh_goldens, expected_count, recv_y_outs, sh_outs) -> bool:
     ok = True
+    # Loose tolerance: chip x_norm differs from golden by up to 1 BF16 ulp
+    # (~1.5e-2 in this fixture), then the INT8 per-row quant amax shifts,
+    # compounded across the gate/up matmul → SwiGLU → A8 requant → w2 chain.
+    rtol, atol = 1e-2, 5e-2
     for r in range(nranks):
-        expected = torch.zeros(T, D, dtype=torch.float32)
-        for t in range(T):
-            for k in range(TOPK):
-                weighted_fp32 = weights[r, t, k] * x_norms[r][t, :].to(torch.float32)
-                expected[t, :] += weighted_fp32.to(torch.bfloat16).to(torch.float32)
-        diff = (routed_y_outs[r] - expected).abs().max().item()
-        rel_diff = diff / (expected.abs().max().item() + 1e-9)
-        print(f"[ep_dispatch] chip {r}: routed_y max|diff|={diff:.3e} (rel={rel_diff:.3e})")
-        # Allow 1e-3 abs as headroom for any fp32 reorder we missed.
-        if diff > 1e-3:
+        for e in range(L):
+            n = int(expected_count[r][e].item())
+            if n == 0:
+                continue
+            got = recv_y_outs[r][e, :n, :].float()
+            exp = recv_y_goldens[r][e, :n, :].float()
+            if not torch.allclose(got, exp, rtol=rtol, atol=atol):
+                ok = False
+                diff = (got - exp).abs()
+                print(f"[ep_dispatch] chip {r} expert {e}: recv_y mismatch n={n} max|diff|={float(diff.max()):.3e}")
+        got_sh, exp_sh = sh_outs[r].float(), sh_goldens[r].float()
+        d_sh = (got_sh - exp_sh).abs()
+        print(f"[ep_dispatch] chip {r}: sh max|diff|={float(d_sh.max()):.3e}")
+        if not torch.allclose(got_sh, exp_sh, rtol=rtol, atol=atol):
             ok = False
-            print(f"[ep_dispatch] chip {r}: routed_y mismatch (tol=1e-3)")
-            per_token_diff = (routed_y_outs[r] - expected).abs().max(dim=1).values
-            for t in range(T):
-                if per_token_diff[t] > 1e-3:
-                    print(
-                        f"  token {t}: got[0]={float(routed_y_outs[r][t, 0]):.4f} "
-                        f"expected[0]={float(expected[t, 0]):.4f}"
-                    )
+            print(f"[ep_dispatch] chip {r}: sh mismatch")
+    return ok
+
+
+def _routed_y_golden(route_dest, recv_y_goldens, rank: int) -> torch.Tensor:
+    """Reconstruct the golden routed_y[rank] = sum_k recv_y_goldens[dst][loc_e, slot, :]
+    in FP32 (mirrors combine's reduce_sum FP32 accumulator)."""
+    out = torch.zeros(T, D, dtype=torch.float32)
+    for t in range(T):
+        for k in range(TOPK):
+            dst, loc_e, slot = route_dest[rank][t][k]
+            out[t, :] += recv_y_goldens[dst][loc_e, slot, :].float()
+    return out
+
+
+def _verify_routed_y(nranks, route_dest, recv_y_goldens, routed_y_outs) -> bool:
+    ok = True
+    # routed_y is the FP32 sum of TOPK BF16 recv_y rows; each row carries the
+    # same kind of compounded noise we tolerate in _verify_expert_outputs, so
+    # propagate that bound (TOPK * 5e-2 plus a small FP32 reorder buffer).
+    atol = TOPK * 5e-2 + 1e-2
+    rtol = 1e-2
+    for me in range(nranks):
+        expected = _routed_y_golden(route_dest, recv_y_goldens, me)
+        got = routed_y_outs[me]
+        diff = (got - expected).abs()
+        print(f"[ep_dispatch] chip {me}: routed_y max|diff|={float(diff.max()):.3e}")
+        if not torch.allclose(got, expected, rtol=rtol, atol=atol):
+            ok = False
+            print(f"[ep_dispatch] chip {me}: routed_y mismatch (rtol={rtol}, atol={atol})")
+    return ok
+
+
+def _golden_hc_post(x_bf16, residual_bf16, post_fp32, comb_fp32):
+    """Torch port of ``models/deepseek/v4/hc_post.py::golden_hc_post``.
+
+    Inputs:
+        x        [B, S, D] BF16           (= ffn_out)
+        residual [B, S, HC_MULT, D] BF16  (= moe_router input x_hc)
+        post     [B, S, HC_MULT] FP32
+        comb     [B, S, HC_MULT, HC_MULT] FP32
+    Returns y [B, S, HC_MULT, D] BF16."""
+    x = x_bf16.float()
+    residual = residual_bf16.float()
+    post = post_fp32.float()
+    comb = comb_fp32.float()
+    y = torch.zeros(B, S, HC_MULT, D, dtype=torch.float32)
+    for out_h in range(HC_MULT):
+        y_row = x * post[:, :, out_h : out_h + 1]
+        for in_h in range(HC_MULT):
+            y_row = y_row + residual[:, :, in_h, :] * comb[:, :, in_h, out_h : out_h + 1]
+        y[:, :, out_h, :] = y_row
+    return y.to(torch.bfloat16)
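# The loop in _golden_hc_post above computes
#     y[b, s, o, :] = x[b, s, :] * post[b, s, o] + sum_i residual[b, s, i, :] * comb[b, s, i, o]
# An equivalent vectorized form, as an illustrative cross-check (not the
# patch's code; the function name is hypothetical):
import torch

def hc_post_vectorized(x, residual, post, comb):
    # x [B,S,D], residual [B,S,H,D], post [B,S,H], comb [B,S,H,H] -> y [B,S,H,D]
    return x.unsqueeze(2) * post.unsqueeze(-1) + torch.einsum("bsid,bsio->bsod", residual, comb)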
+
+
+def _verify_hc_post(nranks, x_hcs, post_ffn_goldens, comb_ffn_goldens,
+                    route_dest, recv_y_goldens, sh_goldens, y_outs) -> bool:
+    """Verify hc_post output. Feed the host-side ffn_out_golden (= routed_y +
+    sh, both golden) through _golden_hc_post and compare against chip y."""
+    ok = True
+    # hc_post is a 4×4 linear combination of (x, residual) with FP32 fma then a
+    # single BF16 cast. Noise budget propagates from ffn_out (~5e-2) scaled by
+    # post (~O(1)) plus the residual contribution; cap at ~1e-1.
+    atol = 1e-1
+    rtol = 1e-2
+    for me in range(nranks):
+        routed_golden = _routed_y_golden(route_dest, recv_y_goldens, me)
+        ffn_out_golden = (routed_golden + sh_goldens[me].float()).to(torch.bfloat16)
+        x_bs = ffn_out_golden.view(B, S, D)
+        expected = _golden_hc_post(x_bs, x_hcs[me], post_ffn_goldens[me], comb_ffn_goldens[me]).float()
+        got = y_outs[me].float()
+        diff = (got - expected).abs()
+        print(f"[ep_dispatch] chip {me}: y (hc_post) max|diff|={float(diff.max()):.3e}")
+        if not torch.allclose(got, expected, rtol=rtol, atol=atol):
+            ok = False
+            print(f"[ep_dispatch] chip {me}: hc_post mismatch (rtol={rtol}, atol={atol})")
+    return ok
+
+
+def _verify_ffn_out(nranks, route_dest, recv_y_goldens, sh_goldens, ffn_out_outs) -> bool:
+    """ffn_out = (routed_y + sh).to(bf16). The chip uses an FP32 add then one
+    BF16 cast (CAST_RINT); the golden mirrors that exactly."""
+    ok = True
+    # Same noise budget as routed_y plus one BF16 ulp from the final cast — keep
+    # the routed_y bound and let the trailing 1e-2 absorb the cast.
+    atol = TOPK * 5e-2 + 2e-2
+    rtol = 1e-2
+    for me in range(nranks):
+        routed_golden = _routed_y_golden(route_dest, recv_y_goldens, me)
+        expected = (routed_golden + sh_goldens[me].float()).to(torch.bfloat16).float()
+        got = ffn_out_outs[me].float()
+        diff = (got - expected).abs()
+        print(f"[ep_dispatch] chip {me}: ffn_out max|diff|={float(diff.max()):.3e}")
+        if not torch.allclose(got, expected, rtol=rtol, atol=atol):
+            ok = False
+            print(f"[ep_dispatch] chip {me}: ffn_out mismatch (rtol={rtol}, atol={atol})")
     return ok


+# --------------------------------------------------------------------------- #
+# Driver
+# --------------------------------------------------------------------------- #
 def run(
     device_ids: list[int],
     platform: str = "a2a3",
     pto_isa_commit: str | None = None,
     build: bool = False,
+    seed: int = 20260513,
 ) -> int:
-    """Core logic — callable from CLI and pytest."""
     nranks = len(device_ids)
     assert nranks == N_RANKS
     window_size = max(SCRATCH_NBYTES, 128 * 1024)
-
     rootinfo_path = f"/tmp/pto_ep_dispatch_rootinfo_{os.getpid()}.bin"
     try:
         os.unlink(rootinfo_path)
     except FileNotFoundError:
         pass
-    print(f"[ep_dispatch] platform={platform} devices={device_ids} nranks={nranks}")
-
-    # x_norm[r, t, d] = r*100 + t*10 + d → max value = 1*100 + 7*10 + 63 = 233.
-    # All values are integers ≤ 256, so they fit BF16 exactly (8-bit mantissa
-    # with hidden bit = exact integers up to 2^8). The host can compare BF16
-    # outputs bit-for-bit against this golden.
-    x_norms = [
-        torch.tensor(
-            [[r * 100 + t * 10 + d for d in range(D)] for t in range(T)],
-            dtype=torch.bfloat16,
-        ).share_memory_()
-        for r in range(nranks)
-    ]
-    weights = torch.tensor(
-        [[[(r + 1) * 0.01 + t * 0.1 + k * 0.001 for k in range(TOPK)] for t in range(T)] for r in range(nranks)],
-        dtype=torch.float32,
-    )
+    print(f"[ep_dispatch] platform={platform} devices={device_ids} nranks={nranks} seed={seed}")

-    indices = generate_routing_indices(seed=20260510)
-    print(f"[ep_dispatch] indices shape={tuple(indices.shape)} (rank,t,k -> global expert id)")
+    # ---- host fixtures: router inputs (per-rank x_hc / input_ids + shared FFN/gate banks) ----
+    R_in = build_router_inputs(seed)
+    # ---- moe_expert weight banks (shared across both ranks) ----
+    weight_banks = build_expert_weights(seed=seed)

-    indices_per_rank = [indices[r].clone().contiguous().share_memory_() for r in range(nranks)]
-    w_padded_list = [pack_weights_padded(weights[r]).share_memory_() for r in range(nranks)]
-    # idx_padded is rank-independent (r = t*TOPK + k is intrinsic), but each
-    # rank gets its own shared-memory copy so the framework can plumb it as a
-    # per-rank input tensor.
+    # ---- host golden: run the router golden per rank to get x_norm / indices / weights ----
+    print("[ep_dispatch] computing host golden (router -> dispatch replay -> moe_expert -> combine)...")
+    x_norm_goldens: list[torch.Tensor] = []
+    indices_goldens: list[torch.Tensor] = []  # local IDs in [0, L)
+    weights_goldens: list[torch.Tensor] = []
+    post_ffn_goldens: list[torch.Tensor] = []
+    comb_ffn_goldens: list[torch.Tensor] = []
+    for r in range(nranks):
+        xn, idx_local, w, pf, cf = golden_moe_router(
+            R_in["x_hcs"][r], R_in["hc_ffn_fn"], R_in["hc_ffn_scale"], R_in["hc_ffn_base"],
+            R_in["norm_w"], R_in["gate_w"], R_in["gate_bias"],
+        )
+        x_norm_goldens.append(xn)
+        indices_goldens.append(idx_local)
+        weights_goldens.append(w)
+        post_ffn_goldens.append(pf)
+        comb_ffn_goldens.append(cf)
+
+    # The router's `indices` are local IDs in [0, L); destination rank is
+    # derived from (src_rank, k) per the EP routing policy that dispatch.cpp
+    # and compute_dispatch_golden share. No host-side rebroadcast needed.
+    (expected_recv_x, expected_recv_w, expected_recv_idx, expected_count, route_dest) = \
+        compute_dispatch_golden(x_norm_goldens, indices_goldens, weights_goldens)
+
+    # ---- moe_expert per-rank golden (recv_y, sh) ----
+    recv_y_goldens, sh_goldens = [], []
+    for r in range(nranks):
+        ry, shg = golden_moe_expert(
+            expected_recv_x[r], expected_recv_w[r], expected_count[r], x_norm_goldens[r], weight_banks,
+        )
+        recv_y_goldens.append(ry)
+        sh_goldens.append(shg)
+
+    # ---- pack dispatch host inputs (w_padded / idx_padded) from the GOLDEN weights/indices ----
+    # The chip's router produces near-identical weights modulo fp32 rounding; we
+    # feed the host-known versions through dispatch so recv_w / recv_idx stay
+    # deterministic.
+    w_padded_list = [pack_weights_padded(weights_goldens[r]).share_memory_() for r in range(nranks)]
     idx_padded_list = [pack_idx_padded().share_memory_() for _ in range(nranks)]
-
-    # Outputs — recv_x_out is BF16 (matches kernel TPUT/TSTORE element type);
-    # recv_w_out / recv_idx_out are compacted to [L, R] inside the kernel.
-    # recv_count_out is [L, 1] INT32 — dispatch's prefix_sum phase fills it
-    # from pub_counts so local_expert can iterate `recv_count[e]` rows per expert.
+    recv_count_host_list = [expected_count[r].reshape(L, 1).clone().contiguous().share_memory_()
+                            for r in range(nranks)]
+
+    # ---- chip output / cross-stage tensors (zero-init, OUTPUT_EXISTING) ----
+    x_norm_outs = [torch.zeros(T, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]
+    indices_outs = [torch.zeros(T, TOPK, dtype=torch.int32).share_memory_() for _ in range(nranks)]
+    weights_outs = [torch.zeros(T, TOPK, dtype=torch.float32).share_memory_() for _ in range(nranks)]
+    post_ffn_outs = [torch.zeros(B, S, HC_MULT, dtype=torch.float32).share_memory_() for _ in range(nranks)]
+    comb_ffn_outs = [torch.zeros(B, S, HC_MULT, HC_MULT, dtype=torch.float32).share_memory_() for _ in range(nranks)]
     recv_x_outs = [torch.zeros(L, R, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]
     recv_w_outs = [torch.zeros(L, R, dtype=torch.float32).share_memory_() for _ in range(nranks)]
     recv_idx_outs = [torch.zeros(L, R, dtype=torch.int32).share_memory_() for _ in range(nranks)]
     recv_count_outs = [torch.zeros(L, 1, dtype=torch.int32).share_memory_() for _ in range(nranks)]
-    # Cross-kernel host-backed tensors:
-    #   recv_y   [L, R, D] BF16 — local_expert output, also visible to host
-    #                             for debug.
-    #   routed_y [T, D] FP32    — combine output; the FP32 reduce accumulator
-    #                             is written out directly without a final
-    #                             cast back to BF16 (mirrors how the
-    #                             production block keeps the FP32 accumulator
-    #                             live until exit).
     recv_y_outs = [torch.zeros(L, R, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]
+    sh_outs = [torch.zeros(T, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]
     routed_y_outs = [torch.zeros(T, D, dtype=torch.float32).share_memory_() for _ in range(nranks)]
-
-    print("[ep_dispatch] computing host golden...")
-    expected_recv_x, expected_recv_w, expected_recv_idx, expected_count = compute_golden(x_norms, indices, weights)
+    # ffn_out = routed_y + sh — final post-MoE add (single-layer spec).
+    ffn_out_outs = [torch.zeros(T, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]
+    # hc_post output — next-layer x_hc [B, S, HC_MULT, D] BF16.
+    y_outs = [torch.zeros(B, S, HC_MULT, D, dtype=torch.bfloat16).share_memory_() for _ in range(nranks)]

     cfgs = [
         ChipBootstrapConfig(
             comm=ChipCommBootstrapConfig(
-                rank=rank,
-                nranks=nranks,
-                rootinfo_path=rootinfo_path,
-                window_size=window_size,
+                rank=rank, nranks=nranks, rootinfo_path=rootinfo_path, window_size=window_size,
             ),
-            buffers=[
-                ChipBufferSpec(
-                    name="scratch",
-                    dtype="float32",
-                    count=SCRATCH_NBYTES // 4,
-                    nbytes=SCRATCH_NBYTES,
-                ),
-            ],
+            buffers=[ChipBufferSpec(name="scratch", dtype="float32", count=SCRATCH_NBYTES // 4, nbytes=SCRATCH_NBYTES)],
         )
         for rank in range(nranks)
     ]

-    print("[ep_dispatch] compiling kernels...")
+    print(f"[ep_dispatch] compiling orchestration + {len(KERNELS)} kernels for {platform}...")
     chip_callable = build_chip_callable(platform, pto_isa_commit)

     worker = Worker(
-        level=3,
-        platform=platform,
-        runtime="tensormap_and_ringbuffer",
-        device_ids=device_ids,
-        num_sub_workers=0,
-        chip_bootstrap_configs=cfgs,
-        build=build,
+        level=3, platform=platform, runtime=RUNTIME, device_ids=device_ids,
+        num_sub_workers=0, chip_bootstrap_configs=cfgs, build=build,
     )
     chip_cid = worker.register(chip_callable)

+    cfg = CallConfig()
+    cfg.block_dim = 24
+    cfg.aicpu_thread_num = 4
+    # Swimlane / L2 perf trace — each chip writes {output_prefix}/l2_perf_records.json,
+    # so per-chip dirs (chip-0 / chip-1) are required to avoid the second rank
+    # overwriting the first. We also emit a func_names.json mapping
+    # func_id -> kernel name so swimlane_converter can label tasks instead of
+    # showing the default ``func_xx`` placeholders.
+    swimlane_base = None
+    if os.environ.get("EP_SWIMLANE", "") == "1":
+        swimlane_base = os.environ.get("EP_SWIMLANE_DIR", os.path.join(HERE, "outputs", "swimlane"))
+        os.makedirs(swimlane_base, exist_ok=True)
+        import json as _json
+        with open(os.path.join(swimlane_base, "func_names.json"), "w") as f:
+            _json.dump({
+                "orchestrator_name": "ep_dispatch_combine_orchestration",
+                "callable_id_to_name": {str(fid): name for (fid, name, _, _) in KERNELS},
+            }, f, indent=2)
+        print(f"[ep_dispatch] swimlane enabled, base={swimlane_base}")
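# For reference, the emitted func_names.json looks roughly like this (the ids
# and names below are illustrative; the real mapping comes from KERNELS):
# {
#   "orchestrator_name": "ep_dispatch_combine_orchestration",
#   "callable_id_to_name": {"0": "dispatch", "1": "combine", "2": "ffn_add"}
# }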
+
     try:
         print("[ep_dispatch] init worker (forks chip children + bootstraps HCCL)...")
         worker.init()
-
         contexts: list[ChipContext] = worker.chip_contexts
         assert len(contexts) == nranks
         for i, ctx in enumerate(contexts):
             print(
                 f"[ep_dispatch] chip {i}: device={ctx.device_id} rank={ctx.rank}/{ctx.nranks} "
-                f"window=[0x{ctx.local_window_base:x} +{ctx.actual_window_size}B] "
-                f"scratch=0x{ctx.buffer_ptrs['scratch']:x}"
+                f"window=[0x{ctx.local_window_base:x} +{ctx.actual_window_size}B] scratch=0x{ctx.buffer_ptrs['scratch']:x}"
             )

-        def orch_fn(orch, _args, cfg):
+        def orch_fn(orch, _args, _cfg):
             for i, ctx in enumerate(contexts):
-                chip_args = TaskArgs()
-                chip_args.add_tensor(make_tensor_arg(indices_per_rank[i]), TensorArgType.INPUT)
-                chip_args.add_tensor(make_tensor_arg(x_norms[i]), TensorArgType.INPUT)
-                chip_args.add_tensor(make_tensor_arg(w_padded_list[i]), TensorArgType.INPUT)
-                chip_args.add_tensor(make_tensor_arg(idx_padded_list[i]), TensorArgType.INPUT)
-                chip_args.add_tensor(make_tensor_arg(recv_x_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(make_tensor_arg(recv_w_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(make_tensor_arg(recv_idx_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(make_tensor_arg(recv_count_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(make_tensor_arg(recv_y_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(make_tensor_arg(routed_y_outs[i]), TensorArgType.OUTPUT_EXISTING)
-                chip_args.add_tensor(
+                a = TaskArgs()
+                # 9 router host inputs (0..8)
+                a.add_tensor(make_tensor_arg(R_in["x_hcs"][i]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["hc_ffn_fn"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["hc_ffn_scale"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["hc_ffn_base"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["norm_w"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["gate_w"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["gate_bias"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["tid2eid"]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(R_in["input_ids_list"][i]), TensorArgType.INPUT)
+                # 2 dispatch host inputs (9..10)
+                a.add_tensor(make_tensor_arg(w_padded_list[i]), TensorArgType.INPUT)
+                a.add_tensor(make_tensor_arg(idx_padded_list[i]), TensorArgType.INPUT)
+                # 1 host-known recv count (11)
+                a.add_tensor(make_tensor_arg(recv_count_host_list[i]), TensorArgType.INPUT)
+                # 12 moe_expert weight tensors (12..23)
+                for name in ("expert_w1", "expert_w1_scale", "expert_w3", "expert_w3_scale",
+                             "expert_w2", "expert_w2_scale", "shared_w1", "shared_w1_scale",
+                             "shared_w3", "shared_w3_scale", "shared_w2", "shared_w2_scale"):
+                    a.add_tensor(make_tensor_arg(weight_banks[name]), TensorArgType.INPUT)
+                # 14 chip OUTPUT_EXISTING tensors (24..37)
+                a.add_tensor(make_tensor_arg(x_norm_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(indices_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(weights_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(post_ffn_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(comb_ffn_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(recv_x_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(recv_w_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(recv_idx_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(recv_count_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(recv_y_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(sh_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(routed_y_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(ffn_out_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                a.add_tensor(make_tensor_arg(y_outs[i]), TensorArgType.OUTPUT_EXISTING)
+                # scratch (38)
+                a.add_tensor(
                     ContinuousTensor.make(
-                        data=ctx.buffer_ptrs["scratch"],
-                        shapes=(SCRATCH_NBYTES // 4,),
-                        dtype=DataType.FLOAT32,
-                        child_memory=True,
+                        data=ctx.buffer_ptrs["scratch"], shapes=(SCRATCH_NBYTES // 4,),
+                        dtype=DataType.FLOAT32, child_memory=True,
                     ),
                     TensorArgType.INOUT,
                 )
-                chip_args.add_scalar(ctx.nranks)
-                chip_args.add_scalar(ctx.device_ctx)
-                orch.submit_next_level(chip_cid, chip_args, cfg, worker=i)
-
-        print("[ep_dispatch] running 2-chip dispatch DAG...")
-        worker.run(orch_fn, args=None, config=CallConfig())
-
-        ok = _verify_recv_outputs(
-            nranks,
-            expected_count,
-            expected_recv_x,
-            expected_recv_w,
-            expected_recv_idx,
-            recv_count_outs,
-            recv_x_outs,
-            recv_w_outs,
-            recv_idx_outs,
+                a.add_scalar(ctx.nranks)
+                a.add_scalar(ctx.device_ctx)
+                # Per-chip cfg so each rank's swimlane lands in its own subdir.
+                cfg_i = CallConfig()
+                cfg_i.block_dim = cfg.block_dim
+                cfg_i.aicpu_thread_num = cfg.aicpu_thread_num
+                if swimlane_base is not None:
+                    cfg_i.enable_l2_swimlane = True
+                    cfg_i.output_prefix = os.path.join(swimlane_base, f"chip-{i}")
+                    os.makedirs(cfg_i.output_prefix, exist_ok=True)
+                orch.submit_next_level(chip_cid, a, cfg_i, worker=i)
+
+        print("[ep_dispatch] running 2-chip router + dispatch + moe_expert + combine DAG...")
+        worker.run(orch_fn, args=None, config=cfg)
+
+        ok = _verify_router_outputs(
+            nranks, x_norm_goldens, indices_goldens, weights_goldens,
+            x_norm_outs, indices_outs, weights_outs,
         )
-        ok = _verify_routed_y(nranks, x_norms, weights, routed_y_outs) and ok
+        ok = _verify_recv_outputs(
+            nranks, expected_count, expected_recv_x, expected_recv_w, expected_recv_idx,
+            recv_count_outs, recv_x_outs, recv_w_outs, recv_idx_outs,
+        ) and ok
+        ok = _verify_expert_outputs(nranks, recv_y_goldens, sh_goldens, expected_count, recv_y_outs, sh_outs) and ok
+        ok = _verify_routed_y(nranks, route_dest, recv_y_goldens, routed_y_outs) and ok
+        ok = _verify_ffn_out(nranks, route_dest, recv_y_goldens, sh_goldens, ffn_out_outs) and ok
+        ok = _verify_hc_post(nranks, R_in["x_hcs"], post_ffn_goldens, comb_ffn_goldens,
+                             route_dest, recv_y_goldens, sh_goldens, y_outs) and ok

         if not ok:
             print("[ep_dispatch] golden check FAILED")
@@ -590,14 +940,14 @@ def orch_fn(orch, _args, cfg):
 def main() -> int:
     parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
     parser.add_argument("-d", "--device", default="0-1", help="Device range, e.g. '0-1'. Two chips required.")
-    parser.add_argument("-p", "--platform", default="a2a3", help="Platform backend, e.g. a2a3 or a2a3sim.")
-    parser.add_argument(
-        "--build", action="store_true", help="Rebuild runtime from source instead of using cached libs."
-    )
+    parser.add_argument("-p", "--platform", default="a2a3", help="Platform backend.")
+    parser.add_argument("--build", action="store_true", help="Rebuild runtime from source instead of using cached libs.")
     parser.add_argument("--pto-isa-commit", default=None, help="Optional PTO ISA commit/tag to fetch before compiling.")
+    parser.add_argument("--seed", type=int, default=20260513, help="Seed for the random input fixture.")
     cli = parser.parse_args()
     return run(
-        parse_device_range(cli.device), platform=cli.platform, pto_isa_commit=cli.pto_isa_commit, build=cli.build
+        parse_device_range(cli.device), platform=cli.platform, pto_isa_commit=cli.pto_isa_commit,
+        build=cli.build, seed=cli.seed,
     )
diff --git a/examples/workers/l3/ep_dispatch_combine/test_ep_dispatch_combine.py b/examples/workers/l3/ep_dispatch_combine/test_ep_dispatch_combine.py
index 44a1bf5e6..e1dc65f5a 100644
--- a/examples/workers/l3/ep_dispatch_combine/test_ep_dispatch_combine.py
+++ b/examples/workers/l3/ep_dispatch_combine/test_ep_dispatch_combine.py
@@ -13,7 +13,9 @@
 from .main import run


-@pytest.mark.platforms(["a2a3sim", "a2a3", "a5sim"])
+# Hardware only — the moe_expert path runs INT8 matmuls at D = MOE_INTER = 4096,
+# which is far too slow under simulation.
+@pytest.mark.platforms(["a2a3"])
 @pytest.mark.runtime("tensormap_and_ringbuffer")
 @pytest.mark.device_count(2)
 def test_ep_dispatch_combine(st_device_ids, st_platform):