@@ -2056,27 +2056,33 @@ class joint_matrix {
2056
2056
const size_t num_elements;
2057
2057
};
2058
2058
2059
- // / Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2059
+ // / Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2060
+ // / local memory per sub-group.
2060
2061
// / Requires the sub-group size of kernel calling this function to be 32.
2061
- // / Each of the first 8 work items contain the starting address of their
2062
- // / respective matrix row.
2063
- // / Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2064
- // / of 128 bytes.
2065
- // / Row Major: Each row of the matrix is stored by a group of 4 work items
2066
- // / r0: t0 t1 t2 t3
2067
- // / r1: t4 t5 t6 t7
2062
+ // / 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2063
+ // / work items of sub-group contain the starting address of their respective
2064
+ // / matrix row in 'addr'.
2065
+ // / After distributing addresses to other work items, each of the 32 work items
2066
+ // / store 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2067
+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2068
+ // / item like below
2069
+ // / Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2070
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2071
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2068
2072
// / ...
2069
- // / r7: t24 t25 t26 t27
2070
- // / r7: t28 t29 t30 t31
2071
- // / Col Major: Each col of the matrix is stored by a group of 4 work items
2072
- // / r0: t0 t4 t8 ... t28
2073
- // / r1: t0 t4 t8 ... t28
2073
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2074
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2075
+ // / Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2076
+ // / row-0: wi0 wi4 wi8 ... wi28
2077
+ // / row-1: wi0 wi4 wi8 ... wi28
2074
2078
// / ...
2075
- // / r6: t3 t7 t11 ... t31
2076
- // / r7: t3 t7 t11 ... t31
2077
- // / \tparam [in] T The type of matrix elements
2078
- // / \param [in] addr The address of the matrix in local memory
2079
- // / \param [in] m The private memory containing data of matrix
2079
+ // / row-6: wi3 wi7 wi11 ... wi31
2080
+ // / row-7: wi3 wi7 wi11 ... wi31
2081
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2082
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2083
+ // / item in local memory
2084
+ // / \param [in] m The private memory to store the matrix. It points to 2 b16
2085
+ // / type elements.
2080
2086
// / \param [in] trans Indicates whether the matrix to be stored transposed
2081
2087
// / \param [in] mat The matrix index to be stored
2082
2088
template <typename T>
@@ -2126,16 +2132,35 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2126
2132
}
2127
2133
}
2128
2134
2129
- // / Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2135
+ // / Collectively stores 2 8x8 b16 (256 bytes) matrix from private memory to
2136
+ // / local memory per sub-group.
2130
2137
// / Requires the sub-group size of kernel calling this function to be 32.
2131
- // / Each of the first 16 work items contain the starting address of their
2132
- // / respective matrix row.
2133
- // / Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2134
- // / of 256 bytes.
2135
- // / \tparam [in] T The type of matrix elements
2136
- // / \param [in] addr The address of the matrix in local memory
2137
- // / \param [in] m1 The private memory containing data of 1st matrix
2138
- // / \param [in] m2 The private memory containing data of 2nd matrix
2138
+ // / The first 16 work items of sub-group contain the starting address of their
2139
+ // / respective matrix row in 'addr'.
2140
+ // / After distributing addresses to other work items, each of the 32 work items
2141
+ // / store 64-bits (32-bits per matrix) into 'm1' & 'm2' for a total of 256
2142
+ // / bytes.
2143
+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2144
+ // / item like below
2145
+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2146
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2147
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2148
+ // / ...
2149
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2150
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2151
+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2152
+ // / row-0: wi0 wi4 wi8 ... wi28
2153
+ // / row-1: wi0 wi4 wi8 ... wi28
2154
+ // / ...
2155
+ // / row-6: wi3 wi7 wi11 ... wi31
2156
+ // / row-7: wi3 wi7 wi11 ... wi31
2157
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2158
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2159
+ // / item in local memory
2160
+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2161
+ // / to 2 b16 type elements.
2162
+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2163
+ // / to 2 b16 type elements.
2139
2164
// / \param [in] trans Indicates whether the matrix to be stored transposed
2140
2165
template <typename T>
2141
2166
void stmatrix (uintptr_t addr, T m1, T m2, bool trans = false ) {
@@ -2145,18 +2170,39 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
2145
2170
stmatrix (addr, m2, trans, 1 );
2146
2171
}
2147
2172
2148
- // / Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2173
+ // / Collectively stores 4 8x8 b16 (512 bytes) matrix from private memory to
2174
+ // / local memory per sub-group.
2149
2175
// / Requires the sub-group size of kernel calling this function to be 32.
2150
- // / Each of the 32 work items contain the starting address of their
2151
- // / respective matrix row.
2152
- // / Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2176
+ // / Each work item of sub-group contains the starting address of their
2177
+ // / respective matrix row in 'addr'.
2178
+ // / After distributing addresses to other work items, each of the 32 work items
2179
+ // / store 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2153
2180
// / of 512 bytes.
2154
- // / \tparam [in] T The type of matrix elements
2155
- // / \param [in] addr The address of the matrix in local memory
2156
- // / \param [in] m1 The private memory containing data of 1st matrix
2157
- // / \param [in] m2 The private memory containing data of 2nd matrix
2158
- // / \param [in] m3 The private memory containing data of 3rd matrix
2159
- // / \param [in] m4 The private memory containing data of 4th matrix
2181
+ // / 'trans' specifies to perform a transposed/non-transposed store by each work
2182
+ // / item like below
2183
+ // / Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2184
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2185
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2186
+ // / ...
2187
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2188
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2189
+ // / Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2190
+ // / row-0: wi0 wi4 wi8 ... wi28
2191
+ // / row-1: wi0 wi4 wi8 ... wi28
2192
+ // / ...
2193
+ // / row-6: wi3 wi7 wi11 ... wi31
2194
+ // / row-7: wi3 wi7 wi11 ... wi31
2195
+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2196
+ // / \param [in] addr The starting address of corresponding matrix row for a work
2197
+ // / item in local memory
2198
+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2199
+ // / to 2 b16 type elements.
2200
+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2201
+ // / to 2 b16 type elements.
2202
+ // / \param [in] m3 The private memory to store the data of 3rd matrix. It points
2203
+ // / to 2 b16 type elements.
2204
+ // / \param [in] m4 The private memory to store the data of 4th matrix. It points
2205
+ // / to 2 b16 type elements.
2160
2206
// / \param [in] trans Indicates whether the matrix to be stored transposed
2161
2207
template <typename T>
2162
2208
void stmatrix (uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false ) {
0 commit comments