Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

245 changes: 245 additions & 0 deletions examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_w2_matmul.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
// Kernel Function: exp_w2_matmul
// Generated by PyPTO IR Compiler (PTO backend)

#include <cstdint>

// If the toolchain has not already defined __gm__, define it as empty so that
// declarations written with it (e.g. `__gm__ int8_t*`) still parse.
#ifndef __gm__
#define __gm__
#endif

// If __aicore__ is not already defined, provide one: empty when __CPU_SIM is
// defined, otherwise the `[aicore]` function attribute.
#ifndef __aicore__
#if defined(__CPU_SIM)
#define __aicore__
#else
#define __aicore__ [aicore]
#endif
#endif

#include <pto/pto-inst.hpp>
#include "tensor.h"


using namespace pto;


// --- ptoas-generated code ---

// Selects which synchronization ptoas_auto_sync_tail() issues.
enum class PTOAutoSyncTailMode : int {
    // Issue pipe_barrier(PIPE_ALL).
    kBarrierAll = 0,

    // Set, then wait on, the PIPE_MTE3 -> PIPE_S flag for EVENT_ID0.
    kSetWaitMte3ToSEvent0 = 1,
};

// Issues the synchronization selected by `mode`.
//   kSetWaitMte3ToSEvent0 : set_flag followed by wait_flag on
//                           (PIPE_MTE3, PIPE_S, EVENT_ID0).
//   kBarrierAll / other   : pipe_barrier(PIPE_ALL).
// `mode` defaults to kBarrierAll.
static __aicore__ inline void ptoas_auto_sync_tail(
    PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) {
    if (mode == PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0) {
        set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
        wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
        return;
    }

    // kBarrierAll, and any value that is not a named enumerator.
    pipe_barrier(PIPE_ALL);
}

// exp_w2_matmul: ptoas-generated kernel body.
//
// Parameters:
//   v1 - int8 input. Read through 16x512 views (row stride 4096, column
//        stride 1) starting at element offsets 0, 512, ..., 3584.
//   v2 - int8 input. Read through 512x512 views (strides 1 and 4096,
//        Layout::DN) starting at v4 * 16777216 + v5 * 4096 plus the same
//        0, 512, ..., 3584 offset.
//   v3 - int32 output. A 16x512 view (row stride 512) at offset 0, written
//        once by TSTORE from the accumulator tile.
//   v4, v5 - indices that only contribute to the v2 base offset.
//
// The whole computation is compiled only when __DAV_CUBE__ is defined;
// otherwise the function just runs the tail barrier and returns.
//
// NOTE(review): the semantics of TASSIGN/TLOAD/TEXTRACT/TMATMUL/TMATMUL_ACC/
// TSTORE and of the pipe flags come from <pto/pto-inst.hpp>, which is not
// visible here. The comments below describe call order and operands only.
// The order of set_flag/wait_flag/pipe_barrier calls must not be changed
// without consulting that header.
static __aicore__ void exp_w2_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int32_t* v3, int32_t v4, int32_t v5) {
// Generator-emitted constants: tile extents, strides, and the values later
// passed to TASSIGN (v15..v18).
unsigned v6 = 0;
const int32_t v7 = 64;
const int32_t v8 = 128;
const int32_t v9 = 0;
const int32_t v10 = 16777216;
const int32_t v11 = 512;
const int32_t v12 = 1;
const int32_t v13 = 4096;
const int32_t v14 = 16;
const int64_t v15 = 32768;
const int64_t v16 = 1024;
const int64_t v17 = 8192;
const int64_t v18 = 0;
const int32_t v19 = 8192;
using T = float;

#if defined(__DAV_CUBE__)
// Loop bounds: v20 = 512 (end), v21 = 0 (start), v22 = 128 (step).
size_t v20 = (size_t) v11;
size_t v21 = (size_t) v9;
size_t v22 = (size_t) v8;
// 16x512 int8 Mat tile (TASSIGN value 0) and the matching view of v1 at offset 0.
Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v23 = Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v24 = (uint64_t) v18;
TASSIGN(v23, v24);
pto::Shape<1, 1, 1, 16, 512> v25 = pto::Shape<1, 1, 1, 16, 512>();
pto::Stride<65536, 65536, 65536, 4096, 1> v26 = pto::Stride<65536, 65536, 65536, 4096, 1>();
GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v27 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + v6 * (unsigned) v12), v25, v26);
// Set the flags that the loops below wait on before their first matching
// set_flag (M->MTE1 events 0/1/3/4 and MTE1->MTE2 event 1).
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4);
TLOAD(v23, v27);
// 512x512 int8 Mat tile (TASSIGN value 8192) and the view of v2 at
// v4 * 16777216 + v5 * 4096.
Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v28 = Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v11, v11);
uint64_t v29 = (uint64_t) v17;
TASSIGN(v28, v29);
pto::Shape<1, 1, 1, 512, 512> v30 = pto::Shape<1, 1, 1, 512, 512>();
pto::Stride<16777216, 16777216, 16777216, 1, 4096> v31 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>();
GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v32 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + v6 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v30, v31);
TLOAD(v28, v32);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// 16x512 int32 Acc tile (TASSIGN value 0); this is the tile stored to v3 at the end.
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v33 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v34 = (uint64_t) v18;
TASSIGN(v33, v34);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
// First 512-wide chunk: v35 = 0, 128, 256, 384. Each iteration extracts and
// multiplies two 64-wide slices, at offsets v35 and v35 + 64.
for (size_t v35 = v21; v35 < v20; v35 += v22) {
int32_t v36 = (int32_t) v35;
// Slice A: Left 16x64 (TASSIGN 0) and Right 64x512 (TASSIGN 0) at offset v35.
Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v37 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7);
uint64_t v38 = (uint64_t) v18;
TASSIGN(v37, v38);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
pipe_barrier(PIPE_MTE1);
TEXTRACT(v37, v23, v9, v35);
Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v39 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11);
uint64_t v40 = (uint64_t) v18;
TASSIGN(v39, v40);
TEXTRACT(v39, v28, v35, v9);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
// Slice B: Left 16x64 (TASSIGN 1024) and Right 64x512 (TASSIGN 32768) at offset v35 + 64.
Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v41 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7);
uint64_t v42 = (uint64_t) v16;
TASSIGN(v41, v42);
int32_t v43 = (int32_t) ((uint32_t) v36 + (uint32_t) v7);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
TEXTRACT(v41, v23, v9, v43);
Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v44 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11);
uint64_t v45 = (uint64_t) v15;
TASSIGN(v44, v45);
TEXTRACT(v44, v28, v43, v9);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
// The very first product (v36 == 0) uses TMATMUL; every later one uses
// TMATMUL_ACC with the same Acc tile as both destination and first source.
if (v36 == v9) {
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v46 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v47 = (uint64_t) v18;
TASSIGN(v46, v47);
pipe_barrier(PIPE_M);
TMATMUL(v46, v37, v39);
} else {
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v48 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v49 = (uint64_t) v18;
TASSIGN(v48, v49);
pipe_barrier(PIPE_M);
TMATMUL_ACC(v48, v48, v37, v39);
};
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
// Accumulate slice B into the same Acc tile.
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v50 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v51 = (uint64_t) v18;
TASSIGN(v50, v51);
pipe_barrier(PIPE_M);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
TMATMUL_ACC(v50, v50, v41, v44);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
}
// Handshake between the first chunk and the reload loop.
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2);
// Remaining chunks at offsets v52 = 512, 1024, ..., 3584: reload both Mat
// tiles from v1 / v2 at that offset, then repeat the 64-wide
// extract/accumulate sequence (always TMATMUL_ACC here).
for (size_t v52 = v20; v52 < ((size_t) v13); v52 += v20) {
int32_t v53 = (int32_t) v52;
Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v54 = Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v55 = (uint64_t) v18;
TASSIGN(v54, v55);
pto::Shape<1, 1, 1, 16, 512> v56 = pto::Shape<1, 1, 1, 16, 512>();
pto::Stride<65536, 65536, 65536, 4096, 1> v57 = pto::Stride<65536, 65536, 65536, 4096, 1>();
GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v58 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + (unsigned) v53 * (unsigned) v12), v56, v57);
pipe_barrier(PIPE_MTE2);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
TLOAD(v54, v58);
Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v59 = Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v11, v11);
uint64_t v60 = (uint64_t) v17;
TASSIGN(v59, v60);
pto::Shape<1, 1, 1, 512, 512> v61 = pto::Shape<1, 1, 1, 512, 512>();
pto::Stride<16777216, 16777216, 16777216, 1, 4096> v62 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>();
GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v63 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + (unsigned) v53 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v61, v62);
TLOAD(v59, v63);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
// Inner loop over the freshly loaded chunk: v64 = 0, 128, 256, 384.
for (size_t v64 = v21; v64 < v20; v64 += v22) {
// Slice A at offset v64.
Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v65 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7);
uint64_t v66 = (uint64_t) v18;
TASSIGN(v65, v66);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
TEXTRACT(v65, v54, v9, v64);
Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v67 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11);
uint64_t v68 = (uint64_t) v18;
TASSIGN(v67, v68);
TEXTRACT(v67, v59, v64, v9);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2);
// Slice B at offset v64 + 64.
Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v69 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7);
uint64_t v70 = (uint64_t) v16;
TASSIGN(v69, v70);
int32_t v71 = (int32_t) ((uint32_t) ((int32_t) v64) + (uint32_t) v7);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4);
TEXTRACT(v69, v54, v9, v71);
Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v72 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11);
uint64_t v73 = (uint64_t) v15;
TASSIGN(v72, v73);
TEXTRACT(v72, v59, v71, v9);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3);
// Accumulate slice A, then slice B, into the Acc tile.
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v74 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v75 = (uint64_t) v18;
TASSIGN(v74, v75);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2);
pipe_barrier(PIPE_M);
TMATMUL_ACC(v74, v74, v65, v67);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v76 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11);
uint64_t v77 = (uint64_t) v18;
TASSIGN(v76, v77);
pipe_barrier(PIPE_M);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3);
TMATMUL_ACC(v76, v76, v69, v72);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4);
};
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
}
// Store the accumulator tile v33 to the 16x512 view of v3 after the
// M -> FIX handshake.
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
pto::Shape<1, 1, 1, 16, 512> v78 = pto::Shape<1, 1, 1, 16, 512>();
pto::Stride<8192, 8192, 8192, 512, 1> v79 = pto::Stride<8192, 8192, 8192, 512, 1>();
GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v80 = GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v3 + ((v6 + v6 * (unsigned) v19) + v6 * (unsigned) v11 + v6 * (unsigned) v12), v78, v79);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
TSTORE(v80, v33);
// Wait on the flags left set by the last loop iterations; these pair with
// the set_flag calls issued before the first TLOAD. Presumably this leaves
// no event pending at kernel exit - TODO confirm against the PTO docs.
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4);
#endif // __DAV_CUBE__

// Tail synchronization runs whether or not __DAV_CUBE__ is defined.
ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll);
return;
}

// --- Kernel entry point ---
// Unpacks the flat argument array and forwards to exp_w2_matmul.
//
// args layout (as read below):
//   args[0], args[1], args[2] - addresses of Tensor descriptors; the data
//                               pointer of each is buffer.addr advanced by
//                               start_offset elements of the target type.
//   args[3], args[4]          - scalar indices, passed as 64-bit values.
extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args)
{
    // Unpack tensor: h_tile_i8_inline92__rv_v2
    __gm__ Tensor* h_tile_i8_inline92__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]);
    __gm__ int8_t* h_tile_i8_inline92__rv_v2 = reinterpret_cast<__gm__ int8_t*>(h_tile_i8_inline92__rv_v2_tensor->buffer.addr) + h_tile_i8_inline92__rv_v2_tensor->start_offset;

    // Unpack tensor: expert_w2__ssa_v0
    __gm__ Tensor* expert_w2__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]);
    __gm__ int8_t* expert_w2__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(expert_w2__ssa_v0_tensor->buffer.addr) + expert_w2__ssa_v0_tensor->start_offset;

    // Unpack tensor: ret0__out
    __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]);
    __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset;

    // Unpack scalars. args[] is already int64_t, so the values are read
    // directly; no union round trip through uint64_t is needed.
    // Unpack scalar: local_i_inline67__idx_v0
    const int64_t local_i_inline67__idx_v0 = args[3];

    // Unpack scalar: d0_inline49__idx_v0
    const int64_t d0_inline49__idx_v0 = args[4];

    // Forward to ptoas-generated function.
    // exp_w2_matmul takes int32_t indices, so the 64-bit scalars are narrowed
    // explicitly. The conversion is the same one the implicit call performed;
    // the cast makes the truncation visible and avoids narrowing warnings.
    exp_w2_matmul(h_tile_i8_inline92__rv_v2,
                  expert_w2__ssa_v0,
                  ret0__out,
                  static_cast<int32_t>(local_i_inline67__idx_v0),
                  static_cast<int32_t>(d0_inline49__idx_v0));
}
Loading
Loading