Skip to content

Commit d37eade

Browse files
Removed item_ct1 in favor of free functions
1 parent 403371e commit d37eade

File tree

3 files changed

+27
-34
lines changed

3 files changed

+27
-34
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,15 +1344,13 @@ class SYCLGen : public SYCLGenBase {
13441344
if (emitStmt(Dst)) {
13451345
return SYCLGenError();
13461346
}
1347-
OS() << ", ";
13481347
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
13491348
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
13501349
continue;
1350+
OS() << ", ";
13511351
if (emitStmt(VE->getElement(Inst)))
13521352
return SYCLGenError();
1353-
OS() << ", ";
13541353
}
1355-
OS() << DpctGlobalInfo::getItem(GAS);
13561354
if (Inst->hasAttr(InstAttr::trans))
13571355
OS() << ", true";
13581356
OS() << ");";

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,13 +2061,12 @@ class joint_matrix {
20612061
/// \tparam [in] T The type of matrix elements
20622062
/// \param [in] addr The address of the matrix in local memory
20632063
/// \param [in] m The private memory containing data of matrix
2064-
/// \param [in] item The sycl::nd_item index space class
20652064
/// \param [in] trans Indicates whether the matrix is to be stored transposed
20662065
/// \param [in] mat The matrix index to be stored
2067-
template <typename T, typename ItemT>
2068-
void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2069-
unsigned mat = 0) {
2070-
int lane = item.get_sub_group().get_local_linear_id();
2066+
template <typename T>
2067+
void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2068+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
2069+
int lane = sg.get_local_linear_id();
20712070

20722071
int lane_group8_row = lane / 8;
20732072
int lane_group8_col = lane % 8;
@@ -2079,8 +2078,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
20792078
src_lane += 1;
20802079

20812080
// Broadcast the address from the source lane
2082-
auto recv_addr_uintp = dpct::select_from_sub_group(
2083-
item.get_sub_group(), addr, mat * 8 + src_lane);
2081+
auto recv_addr_uintp =
2082+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
20842083

20852084
// Cast the received address from uintptr_t to the type of 'm'
20862085
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
20922091
int src_lane = (lane % 4) * 2;
20932092

20942093
// Broadcast the address from the source lane
2095-
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2096-
item.get_sub_group(), addr, mat * 8 + src_lane);
2097-
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2098-
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
2094+
auto recv_addr_uintp_1 =
2095+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
2096+
auto recv_addr_uintp_2 =
2097+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
20992098

21002099
// Cast the received address from uintptr_t to 'half *'
21012100
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
@@ -2117,15 +2116,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
21172116
/// \param [in] addr The address of the matrix in local memory
21182117
/// \param [in] m1 The private memory containing data of 1st matrix
21192118
/// \param [in] m2 The private memory containing data of 2nd matrix
2120-
/// \param [in] item The sycl::nd_item index space class
21212119
/// \param [in] trans Indicates whether the matrix is to be stored transposed
2122-
template <typename T, typename ItemT>
2123-
void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
2124-
bool trans = false) {
2120+
template <typename T>
2121+
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
21252122
// Store 1st matrix
2126-
stmatrix(addr, m1, item, trans, 0);
2123+
stmatrix(addr, m1, trans, 0);
21272124
// Store 2nd matrix
2128-
stmatrix(addr, m2, item, trans, 1);
2125+
stmatrix(addr, m2, trans, 1);
21292126
}
21302127

21312128
/// Stores 4 8x8 b16 matrices from private memory to local memory (32-bits per wi)
@@ -2136,19 +2133,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
21362133
/// \param [in] m2 The private memory containing data of 2nd matrix
21372134
/// \param [in] m3 The private memory containing data of 3rd matrix
21382135
/// \param [in] m4 The private memory containing data of 4th matrix
2139-
/// \param [in] item The sycl::nd_item index space class
21402136
/// \param [in] trans Indicates whether the matrix is to be stored transposed
2141-
template <typename T, typename ItemT>
2142-
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2143-
bool trans = false) {
2137+
template <typename T>
2138+
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {
21442139
// Store 1st matrix
2145-
stmatrix(addr, m1, item, trans, 0);
2140+
stmatrix(addr, m1, trans, 0);
21462141
// Store 2nd matrix
2147-
stmatrix(addr, m2, item, trans, 1);
2142+
stmatrix(addr, m2, trans, 1);
21482143
// Store 3rd matrix
2149-
stmatrix(addr, m3, item, trans, 2);
2144+
stmatrix(addr, m3, trans, 2);
21502145
// Store 4th matrix
2151-
stmatrix(addr, m4, item, trans, 3);
2146+
stmatrix(addr, m4, trans, 3);
21522147
}
21532148

21542149
} // namespace matrix

clang/test/dpct/asm/stmatrix.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ __device__ void store_matrix_x1(void *sh_r_addr, int *r) {
2222
// CHECK: auto addr = sh_r_addr;
2323
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
2424

25-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1);
25+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0]);
2626
asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n"
2727
:
2828
: "r"(addr), "r"(r[0]));
@@ -32,7 +32,7 @@ __device__ void store_matrix_x2(void *sh_r_addr, int *r) {
3232
// CHECK: auto addr = sh_r_addr;
3333
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
3434

35-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1);
35+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1]);
3636
asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n"
3737
:
3838
: "r"(addr), "r"(r[0]), "r"(r[1]));
@@ -42,7 +42,7 @@ __device__ void store_matrix_x4(void *sh_r_addr, int *r) {
4242
// CHECK: auto addr = sh_r_addr;
4343
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
4444

45-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1);
45+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3]);
4646
asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
4747
:
4848
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));
@@ -52,7 +52,7 @@ __device__ void store_matrix_x1_trans(void *sh_r_addr, int *r) {
5252
// CHECK: auto addr = sh_r_addr;
5353
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
5454

55-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1, true);
55+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], true);
5656
asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n"
5757
:
5858
: "r"(addr), "r"(r[0]));
@@ -62,7 +62,7 @@ __device__ void store_matrix_x2_trans(void *sh_r_addr, int *r) {
6262
// CHECK: auto addr = sh_r_addr;
6363
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
6464

65-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1, true);
65+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], true);
6666
asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n"
6767
:
6868
: "r"(addr), "r"(r[0]), "r"(r[1]));
@@ -72,7 +72,7 @@ __device__ void store_matrix_x4_trans(void *sh_r_addr, int *r) {
7272
// CHECK: auto addr = sh_r_addr;
7373
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
7474

75-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1, true);
75+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], true);
7676
asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n"
7777
:
7878
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));

0 commit comments

Comments
 (0)