Skip to content
Open

Tpush #115

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions include/pto/cpu/TPush.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum class TransferDir : uint8_t
/// Compile-time predicate on a producer tile type.
///
/// @tparam TileProd  Tile type exposing a static `Loc` member comparable to `TileType`.
/// @return true when `TileProd::Loc` is `TileType::Acc` or `TileType::Mat`, false otherwise.
template <typename TileProd>
PTO_INTERNAL constexpr bool IsC2VProducerTile()
{
    // Single return statement: an Acc-only check placed ahead of this one would make the
    // widened (Acc or Mat) predicate unreachable.
    return TileProd::Loc == TileType::Acc || TileProd::Loc == TileType::Mat;
}

template <typename TileProd>
Expand Down Expand Up @@ -693,14 +693,14 @@ struct TPipe {
if (fifo.GM_SLOT_BUFFER != nullptr) {
popTileFromGMFiFo<TileCons, Split>(fifo, tile);
return true;
} else if constexpr (TPipe::is_c2v) {
} else if constexpr (TPipe::is_c2v && TileCons::Loc == TileType::Vec) { // && TileCons::Loc != TileType::Vec
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The trailing comment is confusing and should be removed.

            } else if constexpr (TPipe::is_c2v && TileCons::Loc == TileType::Vec) {

if constexpr (Split == TileSplitAxis::TILE_NO_SPLIT) {
popTileFromVecFiFo<TileCons, Split>(fifo, tile);
popTileFromVecFiFoSplit<TileCons, Split>(fifo, tile); // popTileFromVecFiFoSplit
} else {
popTileFromVecFiFoSplit<TileCons, Split>(fifo, tile);
}
return false;
} else if constexpr (TPipe::is_v2c) {
} else if constexpr (TPipe::is_v2c && TileCons::Loc != TileType::Vec) { // && TileCons::Loc == TileType::Vec
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The trailing comment is confusing and should be removed.

            } else if constexpr (TPipe::is_v2c && TileCons::Loc != TileType::Vec) {

popTileFromMatFiFo<TileCons, Split>(fifo, tile);
return false;
}
Expand Down Expand Up @@ -728,10 +728,18 @@ PTO_INTERNAL void TPush_c2v(Pipe &pipe, TileProd &tile, size_t entryBase, size_t
(Split == TileSplitAxis::TILE_LEFT_RIGHT) ? (TileProd::Cols / 2) : static_cast<int>(TileProd::Cols);

if constexpr (Split == TileSplitAxis::TILE_NO_SPLIT) {
using SlotTile = Tile<TileType::Vec, T, consRows, consCols, BLayout::RowMajor, consRows, consCols>;
SlotTile slotTile;
TASSIGN(slotTile, static_cast<uint64_t>(pipe.fifo.C2V_CONSUMER_BUF + entryBase));
cpu_pipe::CopyTileWindow(slotTile, tile, 0, 0);
// using SlotTile = Tile<TileType::Vec, T, consRows, consCols, BLayout::RowMajor, consRows, consCols>;
// SlotTile slotTile;
// TASSIGN(slotTile, static_cast<uint64_t>(pipe.fifo.C2V_CONSUMER_BUF + entryBase));
// cpu_pipe::CopyTileWindow(slotTile, tile, 0, 0);
Comment on lines +731 to +734
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The commented-out code is dead code and should be removed to improve maintainability.

auto &slotStorage = Pipe::GetSharedState().local_slot_storage[slotIndex];
for (uint32_t splitIndex = 0; splitIndex < cpu_pipe::GetSplitCount<Split>(); ++splitIndex) {
auto *slotPtr = reinterpret_cast<T *>(slotStorage.data() + splitIndex * Pipe::RingFiFo::SLOT_SIZE +
pipe.prod.entryOffset);
const uint32_t rowOffset = (Split == TileSplitAxis::TILE_UP_DOWN) ? splitIndex * consRows : 0;
const uint32_t colOffset = (Split == TileSplitAxis::TILE_LEFT_RIGHT) ? splitIndex * consCols : 0;
cpu_pipe::CopyTileWindowToLinear(slotPtr, consCols, tile, consRows, rowOffset, colOffset);
}
} else {
auto &slotStorage = Pipe::GetSharedState().local_slot_storage[slotIndex];
for (uint32_t splitIndex = 0; splitIndex < cpu_pipe::GetSplitCount<Split>(); ++splitIndex) {
Expand Down Expand Up @@ -795,15 +803,15 @@ PTO_INTERNAL void TPUSH_IMPL(Pipe &pipe, TileProd &tile)
TSTORE(globalData, tile);
} else {
using GlobalData = GlobalTensor<T, Shape<1, 1, 1, rows, cols>, Stride<1, 1, 1, cols, 1>>;
auto *addr =
auto *addr =
reinterpret_cast<__gm__ T *>(reinterpret_cast<std::uintptr_t>(pipe.fifo.GM_SLOT_BUFFER) + entryBase);
GlobalData globalData(addr);
TSTORE(globalData, tile);
}
} else if constexpr (Pipe::is_c2v) {
TPush_c2v<Pipe, TileProd, Split>(pipe, tile, entryBase, slotIndex);
} else if constexpr (Pipe::is_v2c) {
} else if constexpr (Pipe::is_v2c && TileProd::Loc == TileType::Vec) {
TPush_v2c<Pipe, TileProd, Split>(pipe, tile, entryBase);
} else if constexpr (Pipe::is_c2v && TileProd::Loc != TileType::Vec) {
TPush_c2v<Pipe, TileProd, Split>(pipe, tile, entryBase, slotIndex);
}
if (pipe.prod.getRecordStatus()) {
pipe.prod.template record<TileProd, Split>();
Expand Down
230 changes: 116 additions & 114 deletions tests/cpu/st/testcase/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,120 +37,122 @@ endfunction()
find_package(Threads REQUIRED)

# Names of every test case directory that the foreach() below iterates over.
# Keep the whole suite enabled: commenting out existing entries in order to run only
# the newly added cases is a test-coverage regression. Append new cases at the end.
set(ALL_TESTCASES
    hashfind
    mgather
    mscatter
    setgetval
    tabs
    tadd
    taddc
    tadds
    taddsc
    tassign_alias
    tand
    tands
    targreduceop
    taxpy
    tbroadcast
    tci
    tcmp
    tcmps
    tcolexpand
    tcolexpandop
    tcolmax
    tcolmin
    tcolprod
    tcolreduceidx
    tcolsum
    tconcat
    tcvt
    tdequant
    tdiv
    tdivs
    texp
    texpands
    textract
    tfillpad
    tflashattn
    tfmod
    tfmods
    tgather
    tgatherb
    tget
    tget_async
    tget_scale_addr
    tgetscaleaddr
    thistogram
    timg2col
    tinsert
    tload
    tloadconv
    tlog
    tlrelu
    tmatmul
    tmatmul_layout
    #tmatmul_mx
    tmax
    tmaxs
    tmin
    tmins
    tmov
    tmrgsort
    tmul
    tmuls
    tneg
    tnot
    tnotify
    tor
    tors
    tpartadd
    tpartmul
    tpartmax
    tpartmin
    tprefetch
    tprelu
    tpushpop
    tput
    tput_async
    tquant
    trandom
    trecip
    treduce
    trelu
    trem
    trems
    treshape
    trowexpand
    trowexpandop
    trowmax
    trowmin
    trowreduceidx
    trowsum
    trsqrt
    tscatter
    tsel
    tsels
    tshl
    tshls
    tshr
    tshrs
    tsort32
    tsqrt
    tstore
    tsub
    tsubview
    tsubc
    tsubs
    tsubsc
    ttest
    ttrans
    ttri
    twait
    txor
    txors
    tpushpop_cv_nosplit
    tpushpop_cv
    tpushpop_vc_nosplit
    tpush_a3
    tpush_a5
)

foreach(TESTCASE ${ALL_TESTCASES})
Expand Down
1 change: 1 addition & 0 deletions tests/cpu/st/testcase/tpush_a3/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Declares the tpush_a3 test case through the project's pto_cpu_sim_st() helper.
# NOTE(review): the helper is defined outside this file; presumably it creates the
# test target for this directory — confirm against its definition.
pto_cpu_sim_st(tpush_a3)
27 changes: 27 additions & 0 deletions tests/cpu/st/testcase/tpush_a3/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
"""Generate input and golden data for the tpush_a3 test case.

Writes four raw float32 files (a.bin, b.bin, c.bin, golden.bin) into the
directory ``TPUSH_A3Test.case_1`` under the current working directory,
creating that directory if it does not exist yet.
"""
import os

import numpy as np

# Random float32 operands: a is 32x64, b is 64x512, c_prev is 32x512.
a = np.random.randn(32, 64).astype(np.float32)
b = np.random.randn(64, 512).astype(np.float32)
c_prev = np.random.randn(32, 512).astype(np.float32)

# Golden output: c = c_prev + matmul(a + 1, b)
c_golden = c_prev + np.matmul(a + 1.0, b)

case_name = "TPUSH_A3Test.case_1"

# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the same directory.
os.makedirs(case_name, exist_ok=True)

# Save as raw binary files. Explicit paths are used instead of os.chdir() so the
# process working directory is never left changed if one of the writes fails.
a.tofile(os.path.join(case_name, "a.bin"))
b.tofile(os.path.join(case_name, "b.bin"))
c_prev.tofile(os.path.join(case_name, "c.bin"))
c_golden.tofile(os.path.join(case_name, "golden.bin"))
Loading