Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/pto/common/pto_instr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,15 @@ PTO_INST RecordEvent TGATHER(DstTileData &dst, SrcTileData &src, WaitEvents &...
return {};
}

template <MaskPattern maskPattern = MaskPattern::P1111, typename DstTileData, typename SrcTileData,
typename... WaitEvents>
PTO_INST RecordEvent TSCATTER(DstTileData &dst, SrcTileData &src, WaitEvents &...events)
{
TSYNC(events...);
TSCATTER_IMPL<maskPattern>(dst, src);
return {};
}

template <typename TileDataDst, typename TileDataSrc0, typename TileDataSrc1, typename... WaitEvents>
PTO_INST RecordEvent TPARTADD(TileDataDst &dst, TileDataSrc0 &src0, TileDataSrc1 &src1, WaitEvents &...events)
{
Expand Down
347 changes: 177 additions & 170 deletions include/pto/costmodel/pto_instr.hpp

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions include/pto/cpu/TScatter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ See LICENSE in the root of the software repository for the full text of the Lice
#define TSCATTER_HPP

#include "pto/cpu/tile_offsets.hpp"
#include "pto/cpu/TGather.hpp"
#include <pto/common/pto_tile.hpp>
#include <type_traits>

Expand Down Expand Up @@ -40,6 +41,43 @@ PTO_INTERNAL void TSCATTER_IMPL(TileDataDst &dst, TileDataSrc &src, TileInd &ind
}
}

template <MaskPattern maskPattern, typename DstTileData, typename SrcTileData>
PTO_INTERNAL void TScatter(typename DstTileData::TileDType dst, typename SrcTileData::TileDType src, unsigned validRow,
unsigned validCol)
{
unsigned sR = 0;
unsigned sC = 0;
for (unsigned r = 0; r < validRow; r++) {
for (unsigned c = 0; c < validCol; c++) {
const size_t didx = GetTileElementOffset<DstTileData>(r, c);
if (MaskSelect(maskPattern, c)) {
const size_t sidx = GetTileElementOffset<SrcTileData>(sR, sC);
dst[didx] = static_cast<typename DstTileData::DType>(src[sidx]);
if (++sC == SrcTileData::Cols) {
sC = 0;
sR++;
}
} else {
dst[didx] = static_cast<typename DstTileData::DType>(0);
}
}
}
}

template <MaskPattern maskPattern, typename DstTileData, typename SrcTileData>
PTO_INTERNAL void TSCATTER_IMPL(DstTileData &dst, SrcTileData &src)
{
using T = typename SrcTileData::DType;
static_assert(sizeof(T) == 2 || sizeof(T) == 4, "TSCATTER: src element type must be 16 or 32-bit wide");
static_assert((DstTileData::Loc == TileType::Vec) && (SrcTileData::Loc == TileType::Vec),
"TSCATTER: expect vec TileType");
static_assert((DstTileData::isRowMajor && SrcTileData::isRowMajor), "TSCATTER: expect row major");
static_assert((sizeof(typename DstTileData::DType) == sizeof(T)),
"TSCATTER: expect same type size for dst and src");
assert(src.GetValidCol() == SrcTileData::Cols);
TScatter<maskPattern, DstTileData, SrcTileData>(dst.data(), src.data(), src.GetValidRow(), dst.GetValidCol());
}

} // namespace pto

#endif
140 changes: 140 additions & 0 deletions tests/cpu/st/testcase/tscatter/gen_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,42 @@

np.random.seed(19)

P0101 = 1
P1010 = 2
P0001 = 3
P0010 = 4
P0100 = 5
P1000 = 6
P1111 = 7

FLOAT_P0101_ROW = 4
FLOAT_P0101_COL = 64
FLOAT_P1010_ROW = 7
FLOAT_P1010_COL = 1024
FLOAT_P0001_ROW = 3
FLOAT_P0001_COL = 1056
FLOAT_P0010_ROW = 4
FLOAT_P0010_COL = 128
FLOAT_P0100_ROW = 5
FLOAT_P0100_COL = 256
FLOAT_P1000_ROW = 6
FLOAT_P1000_COL = 288
FLOAT_P1111_ROW = 7
FLOAT_P1111_COL = 320

HALF_P0101_ROW = 5
HALF_P0101_COL = 128
HALF_P1010_ROW = 7
HALF_P1010_COL = 1024
HALF_P0001_ROW = 3
HALF_P0001_COL = 1024
HALF_P0010_ROW = 4
HALF_P0010_COL = 128
HALF_P0100_ROW = 5
HALF_P0100_COL = 256
HALF_P1000_ROW = 6
HALF_P1000_COL = 256


def gen_case(case_dir: str, rows: int, cols: int):
os.makedirs(case_dir, exist_ok=True)
Expand All @@ -34,6 +70,110 @@ def gen_case(case_dir: str, rows: int, cols: int):
os.chdir("..")


class TScatterParamsMasked:
def __init__(self, name, src_type, row, dst_col, pattern):
self.testname = name
self.src_type = src_type
self.row = row
self.dst_col = dst_col
self.pattern = pattern


def gen_masked_scatter_golden(param: TScatterParamsMasked):
original_dir = os.getcwd()
os.makedirs(param.testname, exist_ok=True)
os.chdir(param.testname)

row = param.row
dst_col = param.dst_col
pattern = param.pattern

if pattern == P0101:
src_col = dst_col // 2
mask_indices = set(range(0, dst_col, 2))
elif pattern == P1010:
src_col = dst_col // 2
mask_indices = set(range(1, dst_col, 2))
elif pattern == P0001:
src_col = dst_col // 4
mask_indices = set(range(0, dst_col, 4))
elif pattern == P0010:
src_col = dst_col // 4
mask_indices = set(range(1, dst_col, 4))
elif pattern == P0100:
src_col = dst_col // 4
mask_indices = set(range(2, dst_col, 4))
elif pattern == P1000:
src_col = dst_col // 4
mask_indices = set(range(3, dst_col, 4))
elif pattern == P1111:
src_col = dst_col
mask_indices = set(range(0, dst_col))
else:
raise ValueError(f"Unsupported pattern: {pattern}")

src = np.random.randint(1, 100, [row, src_col]).astype(param.src_type)
dst = np.zeros([row, dst_col], dtype=param.src_type)

for r in range(row):
sidx = 0
for c in range(dst_col):
if c in mask_indices:
dst[r, c] = src.flat[r * src_col + sidx]
sidx += 1

src.tofile("./x1_gm.bin")
dst.tofile("./golden.bin")
os.chdir(original_dir)


if __name__ == "__main__":
gen_case("TSCATTERTest.case_float_16x16_16x16_16x16", 16, 16)

masked_cases = [
# float
TScatterParamsMasked("TSCATTERTest.case_masked_float_P0101",
np.float32, FLOAT_P0101_ROW, FLOAT_P0101_COL, P0101),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P1010",
np.float32, FLOAT_P1010_ROW, FLOAT_P1010_COL, P1010),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P0001",
np.float32, FLOAT_P0001_ROW, FLOAT_P0001_COL, P0001),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P0010",
np.float32, FLOAT_P0010_ROW, FLOAT_P0010_COL, P0010),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P0100",
np.float32, FLOAT_P0100_ROW, FLOAT_P0100_COL, P0100),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P1000",
np.float32, FLOAT_P1000_ROW, FLOAT_P1000_COL, P1000),
TScatterParamsMasked("TSCATTERTest.case_masked_float_P1111",
np.float32, FLOAT_P1111_ROW, FLOAT_P1111_COL, P1111),
# half
TScatterParamsMasked("TSCATTERTest.case_masked_half_P0101",
np.float16, HALF_P0101_ROW, HALF_P0101_COL, P0101),
TScatterParamsMasked("TSCATTERTest.case_masked_half_P1010",
np.float16, HALF_P1010_ROW, HALF_P1010_COL, P1010),
TScatterParamsMasked("TSCATTERTest.case_masked_half_P0001",
np.float16, HALF_P0001_ROW, HALF_P0001_COL, P0001),
TScatterParamsMasked("TSCATTERTest.case_masked_half_P0100",
np.float16, HALF_P0100_ROW, HALF_P0100_COL, P0100),
TScatterParamsMasked("TSCATTERTest.case_masked_half_P1000",
np.float16, HALF_P1000_ROW, HALF_P1000_COL, P1000),
# uint16 / int16
TScatterParamsMasked("TSCATTERTest.case_masked_U16_P0101",
np.uint16, HALF_P0101_ROW, HALF_P0101_COL, P0101),
TScatterParamsMasked("TSCATTERTest.case_masked_U16_P1010",
np.uint16, HALF_P1010_ROW, HALF_P1010_COL, P1010),
TScatterParamsMasked("TSCATTERTest.case_masked_I16_P0001",
np.int16, HALF_P0001_ROW, HALF_P0001_COL, P0001),
TScatterParamsMasked("TSCATTERTest.case_masked_I16_P0010",
np.int16, HALF_P0010_ROW, HALF_P0010_COL, P0010),
# uint32 / int32
TScatterParamsMasked("TSCATTERTest.case_masked_U32_P0100",
np.uint32, FLOAT_P0100_ROW, FLOAT_P0100_COL, P0100),
TScatterParamsMasked("TSCATTERTest.case_masked_I32_P1000",
np.int32, FLOAT_P1000_ROW, FLOAT_P1000_COL, P1000),
TScatterParamsMasked("TSCATTERTest.case_masked_I32_P1111",
np.int32, FLOAT_P1111_ROW, FLOAT_P1111_COL, P1111),
]

for case in masked_cases:
gen_masked_scatter_golden(case)
156 changes: 156 additions & 0 deletions tests/cpu/st/testcase/tscatter/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ See LICENSE in the root of the software repository for the full text of the Lice
*/

#include "test_common.h"
#include "tscatter_common.h"
#include <pto/pto-inst.hpp>
#include <gtest/gtest.h>

Expand Down Expand Up @@ -96,3 +97,158 @@ TEST_F(TSCATTERTest, case_float_16x16_16x16_16x16)
{
test_tscatter<16, 16>();
}

// --- Mask-pattern TSCATTER tests ---

template <int32_t tilingKey>
void launchTSCATTER_masked(uint8_t *out, uint8_t *src, void *stream);

template <typename T, uint8_t PATTERN, uint32_t ROW, uint32_t DST_COL, uint32_t MASK_DIVISOR>
void test_scatter_masked()
{
constexpr uint32_t SRC_COL = DST_COL / MASK_DIVISOR;
size_t srcSize = ROW * SRC_COL * sizeof(T);
size_t dstSize = ROW * DST_COL * sizeof(T);

aclInit(nullptr);
aclrtSetDevice(0);
aclrtStream stream;
aclrtCreateStream(&stream);

uint8_t *dstHost, *srcHost;
uint8_t *dstDevice, *srcDevice;

aclrtMallocHost((void **)(&dstHost), dstSize);
aclrtMallocHost((void **)(&srcHost), srcSize);
aclrtMalloc((void **)&dstDevice, dstSize, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc((void **)&srcDevice, srcSize, ACL_MEM_MALLOC_HUGE_FIRST);

size_t readSize = srcSize;
CHECK_RESULT_GTEST(ReadFile(GetGoldenDir() + "/x1_gm.bin", readSize, srcHost, srcSize));

aclrtMemcpy(srcDevice, srcSize, srcHost, srcSize, ACL_MEMCPY_HOST_TO_DEVICE);
launchTSCATTER_masked<PATTERN>(dstDevice, srcDevice, stream);

aclrtSynchronizeStream(stream);
aclrtMemcpy(dstHost, dstSize, dstDevice, dstSize, ACL_MEMCPY_DEVICE_TO_HOST);

WriteFile(GetGoldenDir() + "/output_z.bin", dstHost, dstSize);

aclrtFree(dstDevice);
aclrtFree(srcDevice);
aclrtFreeHost(dstHost);
aclrtFreeHost(srcHost);
aclrtDestroyStream(stream);
aclrtResetDevice(0);
aclFinalize();

constexpr size_t numElements = ROW * DST_COL;
std::vector<T> golden(numElements);
std::vector<T> devFinal(numElements);
readSize = dstSize;
CHECK_RESULT_GTEST(ReadFile(GetGoldenDir() + "/golden.bin", readSize, golden.data(), dstSize));
readSize = dstSize;
CHECK_RESULT_GTEST(ReadFile(GetGoldenDir() + "/output_z.bin", readSize, devFinal.data(), dstSize));

bool ret = ResultCmp<T>(golden, devFinal, 0.001f);
EXPECT_TRUE(ret);
}

// float
TEST_F(TSCATTERTest, case_masked_float_P0101)
{
test_scatter_masked<float, FP0101, FLOAT_P0101_ROW, FLOAT_P0101_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_float_P1010)
{
test_scatter_masked<float, FP1010, FLOAT_P1010_ROW, FLOAT_P1010_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_float_P0001)
{
test_scatter_masked<float, FP0001, FLOAT_P0001_ROW, FLOAT_P0001_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_float_P0010)
{
test_scatter_masked<float, FP0010, FLOAT_P0010_ROW, FLOAT_P0010_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_float_P0100)
{
test_scatter_masked<float, FP0100, FLOAT_P0100_ROW, FLOAT_P0100_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_float_P1000)
{
test_scatter_masked<float, FP1000, FLOAT_P1000_ROW, FLOAT_P1000_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_float_P1111)
{
test_scatter_masked<float, FP1111, FLOAT_P1111_ROW, FLOAT_P1111_COL, 1>();
}

// half
TEST_F(TSCATTERTest, case_masked_half_P0101)
{
test_scatter_masked<half, HP0101, HALF_P0101_ROW, HALF_P0101_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_half_P1010)
{
test_scatter_masked<half, HP1010, HALF_P1010_ROW, HALF_P1010_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_half_P0001)
{
test_scatter_masked<half, HP0001, HALF_P0001_ROW, HALF_P0001_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_half_P0100)
{
test_scatter_masked<half, HP0100, HALF_P0100_ROW, HALF_P0100_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_half_P1000)
{
test_scatter_masked<half, HP1000, HALF_P1000_ROW, HALF_P1000_COL, 4>();
}

// uint16 / int16
TEST_F(TSCATTERTest, case_masked_U16_P0101)
{
test_scatter_masked<uint16_t, U16P0101, HALF_P0101_ROW, HALF_P0101_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_U16_P1010)
{
test_scatter_masked<uint16_t, U16P1010, HALF_P1010_ROW, HALF_P1010_COL, 2>();
}

TEST_F(TSCATTERTest, case_masked_I16_P0001)
{
test_scatter_masked<int16_t, I16P0001, HALF_P0001_ROW, HALF_P0001_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_I16_P0010)
{
test_scatter_masked<int16_t, I16P0010, HALF_P0010_ROW, HALF_P0010_COL, 4>();
}

// uint32 / int32
TEST_F(TSCATTERTest, case_masked_U32_P0100)
{
test_scatter_masked<uint32_t, U32P0100, FLOAT_P0100_ROW, FLOAT_P0100_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_I32_P1000)
{
test_scatter_masked<int32_t, I32P1000, FLOAT_P1000_ROW, FLOAT_P1000_COL, 4>();
}

TEST_F(TSCATTERTest, case_masked_I32_P1111)
{
test_scatter_masked<int32_t, I32P1111, FLOAT_P1111_ROW, FLOAT_P1111_COL, 1>();
}
Loading
Loading