Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 14 additions & 25 deletions include/pto/cpu/TPush.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum class TransferDir : uint8_t
template <typename TileProd>
PTO_INTERNAL constexpr bool IsC2VProducerTile()
{
return TileProd::Loc == TileType::Acc;
return TileProd::Loc == TileType::Acc || TileProd::Loc == TileType::Mat;
}

template <typename TileProd>
Expand Down Expand Up @@ -694,14 +694,10 @@ struct TPipe {
if (fifo.GM_SLOT_BUFFER != nullptr) {
popTileFromGMFiFo<TileCons, Split>(fifo, tile);
return true;
} else if constexpr (TPipe::is_c2v) {
if constexpr (Split == TileSplitAxis::TILE_NO_SPLIT) {
popTileFromVecFiFo<TileCons, Split>(fifo, tile);
} else {
popTileFromVecFiFoSplit<TileCons, Split>(fifo, tile);
}
} else if constexpr (TPipe::is_c2v && TileCons::Loc == TileType::Vec) {
popTileFromVecFiFoSplit<TileCons, Split>(fifo, tile);
return false;
} else if constexpr (TPipe::is_v2c) {
} else if constexpr (TPipe::is_v2c && TileCons::Loc != TileType::Vec) {
popTileFromMatFiFo<TileCons, Split>(fifo, tile);
return false;
}
Expand All @@ -728,20 +724,13 @@ PTO_INTERNAL void TPush_c2v(Pipe &pipe, TileProd &tile, size_t entryBase, size_t
constexpr int consCols =
(Split == TileSplitAxis::TILE_LEFT_RIGHT) ? (TileProd::Cols / 2) : static_cast<int>(TileProd::Cols);

if constexpr (Split == TileSplitAxis::TILE_NO_SPLIT) {
using SlotTile = Tile<TileType::Vec, T, consRows, consCols, BLayout::RowMajor, consRows, consCols>;
SlotTile slotTile;
TASSIGN(slotTile, static_cast<uint64_t>(pipe.fifo.C2V_CONSUMER_BUF + entryBase));
cpu_pipe::CopyTileWindow(slotTile, tile, 0, 0);
} else {
auto &slotStorage = Pipe::GetSharedState().local_slot_storage[slotIndex];
for (uint32_t splitIndex = 0; splitIndex < cpu_pipe::GetSplitCount<Split>(); ++splitIndex) {
auto *slotPtr = reinterpret_cast<T *>(slotStorage.data() + splitIndex * Pipe::RingFiFo::SLOT_SIZE +
pipe.prod.entryOffset);
const uint32_t rowOffset = (Split == TileSplitAxis::TILE_UP_DOWN) ? splitIndex * consRows : 0;
const uint32_t colOffset = (Split == TileSplitAxis::TILE_LEFT_RIGHT) ? splitIndex * consCols : 0;
cpu_pipe::CopyTileWindowToLinear(slotPtr, consCols, tile, consRows, rowOffset, colOffset);
}
auto &slotStorage = Pipe::GetSharedState().local_slot_storage[slotIndex];
for (uint32_t splitIndex = 0; splitIndex < cpu_pipe::GetSplitCount<Split>(); ++splitIndex) {
auto *slotPtr =
reinterpret_cast<T *>(slotStorage.data() + splitIndex * Pipe::RingFiFo::SLOT_SIZE + pipe.prod.entryOffset);
const uint32_t rowOffset = (Split == TileSplitAxis::TILE_UP_DOWN) ? splitIndex * consRows : 0;
const uint32_t colOffset = (Split == TileSplitAxis::TILE_LEFT_RIGHT) ? splitIndex * consCols : 0;
cpu_pipe::CopyTileWindowToLinear(slotPtr, consCols, tile, consRows, rowOffset, colOffset);
}
}

Expand Down Expand Up @@ -801,10 +790,10 @@ PTO_INTERNAL void TPUSH_IMPL(Pipe &pipe, TileProd &tile)
GlobalData globalData(addr);
TSTORE(globalData, tile);
}
} else if constexpr (Pipe::is_c2v) {
TPush_c2v<Pipe, TileProd, Split>(pipe, tile, entryBase, slotIndex);
} else if constexpr (Pipe::is_v2c) {
} else if constexpr (Pipe::is_v2c && TileProd::Loc == TileType::Vec) {
TPush_v2c<Pipe, TileProd, Split>(pipe, tile, entryBase);
} else if constexpr (Pipe::is_c2v && TileProd::Loc != TileType::Vec) {
TPush_c2v<Pipe, TileProd, Split>(pipe, tile, entryBase, slotIndex);
Comment thread
azizbek-khabibov marked this conversation as resolved.
}
if (pipe.prod.getRecordStatus()) {
pipe.prod.template record<TileProd, Split>();
Expand Down
1 change: 1 addition & 0 deletions tests/cpu/st/testcase/tpush_a3/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pto_cpu_sim_st(tpush_a3)
27 changes: 27 additions & 0 deletions tests/cpu/st/testcase/tpush_a3/gen_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
import numpy as np
import os

# Generate random inputs
a = np.random.randn(32, 64).astype(np.float32)
b = np.random.randn(64, 512).astype(np.float32)
c_prev = np.random.randn(32, 512).astype(np.float32)

# Compute golden output: c = c_prev + matmul(a + 1, b)
c_golden = c_prev + np.matmul(a + 1.0, b)

# Save as raw binary files


case_name = "TPUSH_A3Test.case_1"
if not os.path.exists(case_name):
os.makedirs(case_name)
original_dir = os.getcwd()
os.chdir(case_name)

a.tofile("a.bin")
b.tofile("b.bin")
c_prev.tofile("c.bin")
c_golden.tofile("golden.bin")

os.chdir(original_dir)
Loading
Loading