Skip to content

Commit b5cab17

Browse files
authored
Performance enhancement for Volta Tensor Cores TN layout (NVIDIA#53)
* Fixed performance defect with indirect access to pointer array for Volta TensorCores TN arrangement. * Updated patch version and changelog. * Updated patch version and changelog. * Added link to changelog in readme. * Fixed markdown link
1 parent eb41735 commit b5cab17

File tree

5 files changed

+18
-15
lines changed

5 files changed

+18
-15
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# NVIDIA CUTLASS Changelog
22

3+
## [1.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.2) (2019-07-09)
4+
* Performance improvement for Volta Tensor Cores TN and TT layouts.
5+
36
## [1.3.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.1) (2019-04-09)
47
* Corrected NVRTC unit tests.
58

README.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# CUTLASS 1.3
44

5-
_CUTLASS 1.3.1 - April 2019_
5+
_CUTLASS 1.3.2 - July 2019_
66

77
CUTLASS is a collection of CUDA C++ template abstractions for implementing
88
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -28,9 +28,6 @@ CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the acco
2828
We describe the structure of an efficient GEMM in our talk at the
2929
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
3030

31-
# What's New in CUTLASS 1.3.1
32-
_April 2019_
33-
* CUTLASS 1.3.1 corrected NVRTC unit tests..
3431

3532
# What's New in CUTLASS 1.3
3633
_March 2019_
@@ -60,6 +57,8 @@ _September 2018_
6057
* [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
6158
* Added `HostMatrix<>` for simplified matrix creation
6259

60+
For all updates, see the [CUTLASS changelog](CHANGELOG.md).
61+
6362
# Performance
6463

6564
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>

cutlass/cutlass.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
#define CUTLASS_MAJOR 1
3636
#define CUTLASS_MINOR 3
37-
#define CUTLASS_PATCH 1
37+
#define CUTLASS_PATCH 2
3838
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
3939

4040
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))

cutlass/gemm/volta884_shared_tile_crosswise.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
237237
Coord<4> offset = offset_func(ptr_idx);
238238
pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
239239
}
240+
241+
if (((threadIdx.x >> 5) * Iterations::kD) & 2) {
242+
Scalar *tmp = pointer[0];
243+
pointer[0] = pointer[1];
244+
pointer[1] = tmp;
245+
}
240246
}
241247

242248
/// Stores a fragment
@@ -254,16 +260,12 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
254260
CUTLASS_PRAGMA_UNROLL
255261
for (int w = 0; w < Iterations::kW; ++w) { // 2x STS operations per LDG
256262

257-
int warp_id = (threadIdx.x >> 5);
258-
259-
int ldg_idx = d + warp_id * Iterations::kD;
260263
int k_idx = w + h * 8;
261264
int smem_row = (d >> 1);
262265

263266
// Two store pointers
264-
int ptr_idx = ((ldg_idx & 1) ^ ((ldg_idx >> 1) & 1));
265-
266-
Scalar *_pointer = pointer[ptr_idx];
267+
Scalar *_pointer = pointer[(d & 1) ^ ((d >> 1) & 1)];
268+
267269
Coord<4> sts_offset = make_Coord(k_idx, smem_row, 0, 0);
268270

269271
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
@@ -277,6 +279,7 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
277279

278280
/// Increments store iterator to next tile
279281
__device__ Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
282+
CUTLASS_PRAGMA_UNROLL
280283
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
281284
pointer[ptr_idx] +=
282285
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(stride);
@@ -293,6 +296,7 @@ struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
293296

294297
/// Increments store iterator to previous tile
295298
__device__ Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
299+
CUTLASS_PRAGMA_UNROLL
296300
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
297301
pointer[ptr_idx] -=
298302
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(stride);

tools/test/unit/gemm/volta884_gemm.cu

+1-4
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ TEST(Volta884_f16_s884gemm_128x128x32_tt, short_480x280x224) {
183183
// Contiguous - s884gemm
184184
//
185185
////////////////////////////////////////////////////////////////////////////////////////////////////
186-
#if 0
186+
187187
TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x32) {
188188

189189
typedef cutlass::gemm::Volta884GemmTraits<
@@ -218,7 +218,6 @@ TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x30_residue) {
218218
run_gemm<GemmTraits>(64, 64, 30);
219219
}
220220

221-
#if 0
222221
////////////////////////////////////////////////////////////////////////////////////////////////////
223222

224223
TEST(Volta884_f16_s884gemm_64x64x32_nt, 64x64x64) {
@@ -874,7 +873,6 @@ TEST(Volta884_f16_s884gemm_128x128x32_nn, 392x264x192) {
874873

875874
run_gemm<GemmTraits>(392, 264, 192);
876875
}
877-
#endif
878876

879877
////////////////////////////////////////////////////////////////////////////////////////////////////
880878

@@ -1281,7 +1279,6 @@ TEST(Volta884_f16_s884gemm_f16_128x256x32_tn, 480x280x224) {
12811279

12821280
run_gemm<GemmTraits>(480, 280, 224);
12831281
}
1284-
#endif
12851282
////////////////////////////////////////////////////////////////////////////////////////////////////
12861283

12871284
#endif // if defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)

0 commit comments

Comments
 (0)