-
Notifications
You must be signed in to change notification settings - Fork 49
WIP: ep_dispatch_combine: full DSV4 decode single-layer chain #772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
zhangqi-chen
wants to merge
5
commits into
hw-native-sys:main
Choose a base branch
from
zhangqi-chen:dsv4_moe_demo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
a772ebf
ep_dispatch_combine: scale demo to production decode dims (D=4096, L=…
zhangqi-chen aa86d04
ep_dispatch_combine: replace local_expert placeholder with moe_expert
zhangqi-chen 512641b
ep_dispatch_combine: prepend moe_router to the pipeline
zhangqi-chen 8c2b39f
ep_dispatch_combine: append ffn_add (ffn_out = routed_y + sh) stage
zhangqi-chen b974a95
ep_dispatch_combine: append hc_post stage (next-layer x_hc)
zhangqi-chen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
402 changes: 402 additions & 0 deletions
402
examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_gate_up_matmul.cpp
Large diffs are not rendered by default.
Oops, something went wrong.
245 changes: 245 additions & 0 deletions
245
examples/workers/l3/ep_dispatch_combine/kernels/aic/exp_w2_matmul.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| // Kernel Function: exp_w2_matmul | ||
| // Generated by PyPTO IR Compiler (PTO backend) | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #ifndef __gm__ | ||
| #define __gm__ | ||
| #endif | ||
|
|
||
| #ifndef __aicore__ | ||
| #if defined(__CPU_SIM) | ||
| #define __aicore__ | ||
| #else | ||
| #define __aicore__ [aicore] | ||
| #endif | ||
| #endif | ||
|
|
||
| #include <pto/pto-inst.hpp> | ||
| #include "tensor.h" | ||
|
|
||
|
|
||
| using namespace pto; | ||
|
|
||
|
|
||
| // --- ptoas-generated code --- | ||
|
|
||
| enum class PTOAutoSyncTailMode : int { | ||
| kBarrierAll = 0, | ||
| kSetWaitMte3ToSEvent0 = 1, | ||
| }; | ||
|
|
||
| static __aicore__ inline void ptoas_auto_sync_tail( | ||
| PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { | ||
| switch (mode) { | ||
| case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: | ||
| set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); | ||
| wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); | ||
| break; | ||
| case PTOAutoSyncTailMode::kBarrierAll: | ||
| default: | ||
| pipe_barrier(PIPE_ALL); | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| static __aicore__ void exp_w2_matmul(__gm__ int8_t* v1, __gm__ int8_t* v2, __gm__ int32_t* v3, int32_t v4, int32_t v5) { | ||
| unsigned v6 = 0; | ||
| const int32_t v7 = 64; | ||
| const int32_t v8 = 128; | ||
| const int32_t v9 = 0; | ||
| const int32_t v10 = 16777216; | ||
| const int32_t v11 = 512; | ||
| const int32_t v12 = 1; | ||
| const int32_t v13 = 4096; | ||
| const int32_t v14 = 16; | ||
| const int64_t v15 = 32768; | ||
| const int64_t v16 = 1024; | ||
| const int64_t v17 = 8192; | ||
| const int64_t v18 = 0; | ||
| const int32_t v19 = 8192; | ||
| using T = float; | ||
|
|
||
| #if defined(__DAV_CUBE__) | ||
| size_t v20 = (size_t) v11; | ||
| size_t v21 = (size_t) v9; | ||
| size_t v22 = (size_t) v8; | ||
| Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v23 = Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v24 = (uint64_t) v18; | ||
| TASSIGN(v23, v24); | ||
| pto::Shape<1, 1, 1, 16, 512> v25 = pto::Shape<1, 1, 1, 16, 512>(); | ||
| pto::Stride<65536, 65536, 65536, 4096, 1> v26 = pto::Stride<65536, 65536, 65536, 4096, 1>(); | ||
| GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v27 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + v6 * (unsigned) v12), v25, v26); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); | ||
| set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); | ||
| TLOAD(v23, v27); | ||
| Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v28 = Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v11, v11); | ||
| uint64_t v29 = (uint64_t) v17; | ||
| TASSIGN(v28, v29); | ||
| pto::Shape<1, 1, 1, 512, 512> v30 = pto::Shape<1, 1, 1, 512, 512>(); | ||
| pto::Stride<16777216, 16777216, 16777216, 1, 4096> v31 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); | ||
| GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v32 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + v6 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v30, v31); | ||
| TLOAD(v28, v32); | ||
| set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v33 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v34 = (uint64_t) v18; | ||
| TASSIGN(v33, v34); | ||
| wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); | ||
| for (size_t v35 = v21; v35 < v20; v35 += v22) { | ||
| int32_t v36 = (int32_t) v35; | ||
| Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v37 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7); | ||
| uint64_t v38 = (uint64_t) v18; | ||
| TASSIGN(v37, v38); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); | ||
| pipe_barrier(PIPE_MTE1); | ||
| TEXTRACT(v37, v23, v9, v35); | ||
| Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v39 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11); | ||
| uint64_t v40 = (uint64_t) v18; | ||
| TASSIGN(v39, v40); | ||
| TEXTRACT(v39, v28, v35, v9); | ||
| set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); | ||
| Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v41 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7); | ||
| uint64_t v42 = (uint64_t) v16; | ||
| TASSIGN(v41, v42); | ||
| int32_t v43 = (int32_t) ((uint32_t) v36 + (uint32_t) v7); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); | ||
| TEXTRACT(v41, v23, v9, v43); | ||
| Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v44 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11); | ||
| uint64_t v45 = (uint64_t) v15; | ||
| TASSIGN(v44, v45); | ||
| TEXTRACT(v44, v28, v43, v9); | ||
| set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); | ||
| wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); | ||
| if (v36 == v9) { | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v46 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v47 = (uint64_t) v18; | ||
| TASSIGN(v46, v47); | ||
| pipe_barrier(PIPE_M); | ||
| TMATMUL(v46, v37, v39); | ||
| } else { | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v48 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v49 = (uint64_t) v18; | ||
| TASSIGN(v48, v49); | ||
| pipe_barrier(PIPE_M); | ||
| TMATMUL_ACC(v48, v48, v37, v39); | ||
| }; | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v50 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v51 = (uint64_t) v18; | ||
| TASSIGN(v50, v51); | ||
| pipe_barrier(PIPE_M); | ||
| wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); | ||
| TMATMUL_ACC(v50, v50, v41, v44); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); | ||
| } | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); | ||
| set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); | ||
| wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); | ||
| for (size_t v52 = v20; v52 < ((size_t) v13); v52 += v20) { | ||
| int32_t v53 = (int32_t) v52; | ||
| Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v54 = Tile<TileType::Mat, int8_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v55 = (uint64_t) v18; | ||
| TASSIGN(v54, v55); | ||
| pto::Shape<1, 1, 1, 16, 512> v56 = pto::Shape<1, 1, 1, 16, 512>(); | ||
| pto::Stride<65536, 65536, 65536, 4096, 1> v57 = pto::Stride<65536, 65536, 65536, 4096, 1>(); | ||
| GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND> v58 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<65536, 65536, 65536, 4096, 1>, pto::Layout::ND>(v1 + (v6 + v6 * (unsigned) v13 + (unsigned) v53 * (unsigned) v12), v56, v57); | ||
| pipe_barrier(PIPE_MTE2); | ||
| wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); | ||
| TLOAD(v54, v58); | ||
| Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v59 = Tile<TileType::Mat, int8_t, 512, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v11, v11); | ||
| uint64_t v60 = (uint64_t) v17; | ||
| TASSIGN(v59, v60); | ||
| pto::Shape<1, 1, 1, 512, 512> v61 = pto::Shape<1, 1, 1, 512, 512>(); | ||
| pto::Stride<16777216, 16777216, 16777216, 1, 4096> v62 = pto::Stride<16777216, 16777216, 16777216, 1, 4096>(); | ||
| GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN> v63 = GlobalTensor<int8_t, pto::Shape<1, 1, 1, 512, 512>, pto::Stride<16777216, 16777216, 16777216, 1, 4096>, pto::Layout::DN>(v2 + ((v6 + (unsigned) v4 * (unsigned) v10) + (unsigned) v53 * (unsigned) v12 + (unsigned) v5 * (unsigned) v13), v61, v62); | ||
| TLOAD(v59, v63); | ||
| set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); | ||
| wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); | ||
| for (size_t v64 = v21; v64 < v20; v64 += v22) { | ||
| Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v65 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7); | ||
| uint64_t v66 = (uint64_t) v18; | ||
| TASSIGN(v65, v66); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); | ||
| TEXTRACT(v65, v54, v9, v64); | ||
| Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v67 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11); | ||
| uint64_t v68 = (uint64_t) v18; | ||
| TASSIGN(v67, v68); | ||
| TEXTRACT(v67, v59, v64, v9); | ||
| set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); | ||
| Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v69 = Tile<TileType::Left, int8_t, 16, 64, BLayout::RowMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>(v14, v7); | ||
| uint64_t v70 = (uint64_t) v16; | ||
| TASSIGN(v69, v70); | ||
| int32_t v71 = (int32_t) ((uint32_t) ((int32_t) v64) + (uint32_t) v7); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); | ||
| TEXTRACT(v69, v54, v9, v71); | ||
| Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v72 = Tile<TileType::Right, int8_t, 64, 512, BLayout::RowMajor, -1, -1, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null>(v7, v11); | ||
| uint64_t v73 = (uint64_t) v15; | ||
| TASSIGN(v72, v73); | ||
| TEXTRACT(v72, v59, v71, v9); | ||
| set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v74 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v75 = (uint64_t) v18; | ||
| TASSIGN(v74, v75); | ||
| wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); | ||
| pipe_barrier(PIPE_M); | ||
| TMATMUL_ACC(v74, v74, v65, v67); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); | ||
| Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null> v76 = Tile<TileType::Acc, int32_t, 16, 512, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 1024, PadValue::Null, CompactMode::Null>(v14, v11); | ||
| uint64_t v77 = (uint64_t) v18; | ||
| TASSIGN(v76, v77); | ||
| pipe_barrier(PIPE_M); | ||
| wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); | ||
| TMATMUL_ACC(v76, v76, v69, v72); | ||
| set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); | ||
| }; | ||
| set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); | ||
| } | ||
| set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); | ||
| pto::Shape<1, 1, 1, 16, 512> v78 = pto::Shape<1, 1, 1, 16, 512>(); | ||
| pto::Stride<8192, 8192, 8192, 512, 1> v79 = pto::Stride<8192, 8192, 8192, 512, 1>(); | ||
| GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND> v80 = GlobalTensor<int32_t, pto::Shape<1, 1, 1, 16, 512>, pto::Stride<8192, 8192, 8192, 512, 1>, pto::Layout::ND>(v3 + ((v6 + v6 * (unsigned) v19) + v6 * (unsigned) v11 + v6 * (unsigned) v12), v78, v79); | ||
| wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); | ||
| TSTORE(v80, v33); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); | ||
| wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); | ||
| wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); | ||
| #endif // __DAV_CUBE__ | ||
|
|
||
| ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); | ||
| return; | ||
| } | ||
|
|
||
| // --- Kernel entry point --- | ||
| extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) | ||
| { | ||
| // Unpack tensor: h_tile_i8_inline92__rv_v2 | ||
| __gm__ Tensor* h_tile_i8_inline92__rv_v2_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); | ||
| __gm__ int8_t* h_tile_i8_inline92__rv_v2 = reinterpret_cast<__gm__ int8_t*>(h_tile_i8_inline92__rv_v2_tensor->buffer.addr) + h_tile_i8_inline92__rv_v2_tensor->start_offset; | ||
|
|
||
| // Unpack tensor: expert_w2__ssa_v0 | ||
| __gm__ Tensor* expert_w2__ssa_v0_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); | ||
| __gm__ int8_t* expert_w2__ssa_v0 = reinterpret_cast<__gm__ int8_t*>(expert_w2__ssa_v0_tensor->buffer.addr) + expert_w2__ssa_v0_tensor->start_offset; | ||
|
|
||
| // Unpack tensor: ret0__out | ||
| __gm__ Tensor* ret0__out_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); | ||
| __gm__ int32_t* ret0__out = reinterpret_cast<__gm__ int32_t*>(ret0__out_tensor->buffer.addr) + ret0__out_tensor->start_offset; | ||
|
|
||
| // Unpack scalar: local_i_inline67__idx_v0 | ||
| union { uint64_t u64; int64_t val; } local_i_inline67__idx_v0_conv; | ||
| local_i_inline67__idx_v0_conv.u64 = args[3]; | ||
| int64_t local_i_inline67__idx_v0 = local_i_inline67__idx_v0_conv.val; | ||
|
|
||
| // Unpack scalar: d0_inline49__idx_v0 | ||
| union { uint64_t u64; int64_t val; } d0_inline49__idx_v0_conv; | ||
| d0_inline49__idx_v0_conv.u64 = args[4]; | ||
| int64_t d0_inline49__idx_v0 = d0_inline49__idx_v0_conv.val; | ||
|
|
||
| // Forward to ptoas-generated function | ||
| exp_w2_matmul(h_tile_i8_inline92__rv_v2, expert_w2__ssa_v0, ret0__out, local_i_inline67__idx_v0, d0_inline49__idx_v0); | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a potential data loss due to type mismatch. The variables
local_i_inline67__idx_v0andd0_inline49__idx_v0are of typeint64_t, but the functionexp_w2_matmulexpectsint32_tfor its 4th and 5th arguments. This could lead to truncation and unexpected behavior if the values exceed theint32_trange. It's safer to explicitly cast them toint32_tto ensure consistency and avoid narrowing conversion warnings.References