@@ -2061,13 +2061,12 @@ class joint_matrix {
2061
2061
// / \tparam [in] T The type of matrix elements
2062
2062
// / \param [in] addr The address of the matrix in local memory
2063
2063
// / \param [in] m The private memory containing the data of the matrix
2064
- // / \param [in] item The sycl::nd_item index space class
2065
2064
// / \param [in] trans Indicates whether the matrix is to be stored transposed
2066
2065
// / \param [in] mat The index of the matrix to be stored
2067
- template <typename T, typename ItemT >
2068
- void stmatrix (uintptr_t addr, T m, const ItemT &item, bool trans = false ,
2069
- unsigned mat = 0 ) {
2070
- int lane = item. get_sub_group () .get_local_linear_id ();
2066
+ template <typename T>
2067
+ void stmatrix (uintptr_t addr, T m, bool trans = false , unsigned mat = 0 ) {
2068
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2069
+ int lane = sg .get_local_linear_id ();
2071
2070
2072
2071
int lane_group8_row = lane / 8 ;
2073
2072
int lane_group8_col = lane % 8 ;
@@ -2079,8 +2078,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2079
2078
src_lane += 1 ;
2080
2079
2081
2080
// Broadcast the address from the source lane
2082
- auto recv_addr_uintp = dpct::select_from_sub_group (
2083
- item. get_sub_group () , addr, mat * 8 + src_lane);
2081
+ auto recv_addr_uintp =
2082
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2084
2083
2085
2084
// Cast the received address from uintptr_t to a pointer to the type of 'm'
2086
2085
auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2092
2091
int src_lane = (lane % 4 ) * 2 ;
2093
2092
2094
2093
// Broadcast the address from the source lane
2095
- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096
- item. get_sub_group () , addr, mat * 8 + src_lane);
2097
- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098
- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2094
+ auto recv_addr_uintp_1 =
2095
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2096
+ auto recv_addr_uintp_2 =
2097
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
2099
2098
2100
2099
// Cast the received address from uintptr_t to 'half *'
2101
2100
auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2117,15 +2116,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2117
2116
// / \param [in] addr The address of the matrix in local memory
2118
2117
// / \param [in] m1 The private memory containing data of 1st matrix
2119
2118
// / \param [in] m2 The private memory containing data of 2nd matrix
2120
- // / \param [in] item The sycl::nd_item index space class
2121
2119
// / \param [in] trans Indicates whether the matrix is to be stored transposed
2122
- template <typename T, typename ItemT>
2123
- void stmatrix (uintptr_t addr, T m1, T m2, const ItemT &item,
2124
- bool trans = false ) {
2120
+ template <typename T>
2121
+ void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
2125
2122
// Store 1st matrix
2126
- stmatrix (addr, m1, item, trans, 0 );
2123
+ stmatrix (addr, m1, trans, 0 );
2127
2124
// Store 2nd matrix
2128
- stmatrix (addr, m2, item, trans, 1 );
2125
+ stmatrix (addr, m2, trans, 1 );
2129
2126
}
2130
2127
2131
2128
// / Stores four 8x8 b16 matrices from private memory to local memory (32 bits per work-item)
@@ -2136,19 +2133,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
2136
2133
// / \param [in] m2 The private memory containing data of 2nd matrix
2137
2134
// / \param [in] m3 The private memory containing data of 3rd matrix
2138
2135
// / \param [in] m4 The private memory containing data of 4th matrix
2139
- // / \param [in] item The sycl::nd_item index space class
2140
2136
// / \param [in] trans Indicates whether the matrix is to be stored transposed
2141
- template <typename T, typename ItemT>
2142
- void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2143
- bool trans = false ) {
2137
+ template <typename T>
2138
+ void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
2144
2139
// Store 1st matrix
2145
- stmatrix (addr, m1, item, trans, 0 );
2140
+ stmatrix (addr, m1, trans, 0 );
2146
2141
// Store 2nd matrix
2147
- stmatrix (addr, m2, item, trans, 1 );
2142
+ stmatrix (addr, m2, trans, 1 );
2148
2143
// Store 3rd matrix
2149
- stmatrix (addr, m3, item, trans, 2 );
2144
+ stmatrix (addr, m3, trans, 2 );
2150
2145
// Store 4th matrix
2151
- stmatrix (addr, m4, item, trans, 3 );
2146
+ stmatrix (addr, m4, trans, 3 );
2152
2147
}
2153
2148
2154
2149
} // namespace matrix
0 commit comments