Commit 1ab1027

Updated mma_sm80.h to avoid perf penalty due to reinterpret_cast<>. (NVIDIA#100)
- Updated mma_sm80.h to avoid perf penalty due to reinterpret_cast<>.
- Enhancement to CUTLASS Utility Library's HostTensorPlanarComplex template to support copy-in and copy-out.
- Added test_examples target to build and test all CUTLASS examples.
- Minor edits to documentation to point to GTC 2020 webinar.
1 parent 86931fe commit 1ab1027

11 files changed (+213, -33 lines)

CHANGELOG.md (+1)

```diff
@@ -9,6 +9,7 @@
 * Tensor Float 32, BFloat16, and double-precision data types
 * Mixed integer data types (int8, int4, bin1)
 * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
+  * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
 * Features:
 * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
 * Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
```

README.md (+1)

```diff
@@ -37,6 +37,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding:
 - Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
 - Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
 - Deep software pipelines using asynchronous copy
+  - Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
 - Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)

 # What's New in CUTLASS 2.1
```

examples/03_visualize_layout/CMakeLists.txt (+1, -7)

```diff
@@ -20,15 +20,9 @@
 # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-cutlass_add_executable(
+cutlass_example_add_executable(
   03_visualize_layout
   visualize_layout.cpp
   register_layout.cu
   )

-target_link_libraries(
-  03_visualize_layout
-  PRIVATE
-  CUTLASS
-  cutlass_tools_util_includes
-  )
```

examples/06_splitK_gemm/splitk_gemm.cu (+5, -3)

```diff
@@ -182,10 +182,12 @@ int run() {
     return -1;
   }

-  if (!(props.major >= 7)) {
-    std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
+  if (props.major != 7) {
+    std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75."
               << std::endl;
-    return -1;
+
+    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
+    return 0;
   }

   //
```
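The changed guard sits inside the example's run() function, just after its CUDA device-property query. Below is a hedged sketch of how the whole check reads after this change, pulled out into a standalone helper for illustration; the helper name and the reconstructed query lines are assumptions, not part of the diff.

```c++
#include <iostream>
#include <cuda_runtime.h>

// Illustrative helper (not part of the example, which performs this check
// inline in run()): returns 0 to skip gracefully on non-Volta-class GPUs,
// 1 when supported, and -1 on a genuine runtime error.
int volta_guard() {
  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: "
              << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  // The check admits compute capabilities 7.0, 7.2, and 7.5, matching the message.
  if (props.major != 7) {
    std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75."
              << std::endl;

    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
    return 0;
  }

  return 1;
}
```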

examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu (+5, -3)

```diff
@@ -198,10 +198,12 @@ int run() {
     return -1;
   }

-  if (!(props.major >= 7)) {
-    std::cerr << "Volta Tensor Ops must be run on a machine with compute capability at least 70."
+  if (props.major != 7) {
+    std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75."
               << std::endl;
-    return -1;
+
+    // Return 0 so tests are considered passing if run on unsupported architectures or CUDA Toolkits.
+    return 0;
   }

   const int length_m = 5120;
```

examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu (+3, -1)

```diff
@@ -208,7 +208,9 @@ int run() {
   if (!((props.major * 10 + props.minor) >= 75)) {
     std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 75."
               << std::endl;
-    return -1;
+
+    // Return 0 so tests are considered passing if run on unsupported platforms.
+    return 0;
   }

   const int length_m = 5120;
```

examples/CMakeLists.txt (+10)

```diff
@@ -44,9 +44,18 @@ function(cutlass_example_add_executable NAME)
     ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR}
     )

+  add_custom_target(
+    test_${NAME}
+    COMMAND
+    ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $<TARGET_FILE:${NAME}>
+    DEPENDS
+    ${NAME}
+    )
+
 endfunction()

 add_custom_target(cutlass_examples)
+add_custom_target(test_examples)

 foreach(EXAMPLE
   00_basic_gemm
@@ -66,5 +75,6 @@ foreach(EXAMPLE

   add_subdirectory(${EXAMPLE})
   add_dependencies(cutlass_examples ${EXAMPLE})
+  add_dependencies(test_examples test_${EXAMPLE})

 endforeach()
```

include/cutlass/arch/mma.h (+1)

```diff
@@ -164,4 +164,5 @@ struct Mma<gemm::GemmShape<1, 1, 1>, 1, ElementA, LayoutA, ElementB, LayoutB, El
 #include "cutlass/arch/mma_sm61.h"
 #include "cutlass/arch/mma_sm70.h"
 #include "cutlass/arch/mma_sm75.h"
+#include "cutlass/arch/mma_sm80.h"
 /////////////////////////////////////////////////////////////////////////////////////////////////
```

include/cutlass/arch/mma_sm80.h (+18, -18)

```diff
@@ -98,17 +98,17 @@ struct Mma<

     uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
     uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
-    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
-    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+    float const *C = reinterpret_cast<float const *>(&c);
+    float *D = reinterpret_cast<float *>(&d);

     asm(
       "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
       "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
-        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
         :
         "r"(A[0]), "r"(A[1]),
         "r"(B[0]),
-        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
+        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
     );

 #else
@@ -341,15 +341,15 @@ struct Mma<

     uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
     uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
-    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
-    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+    float const *C = reinterpret_cast<float const *>(&c);
+    float *D = reinterpret_cast<float *>(&d);

     asm volatile(
       "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
       "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
-          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

 #else
     assert(0);
@@ -402,15 +402,15 @@ struct Mma<

     uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
     uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
-    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
-    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+    float const *C = reinterpret_cast<float const *>(&c);
+    float *D = reinterpret_cast<float *>(&d);

     asm volatile(
       "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
       "{%10,%11,%12,%13};\n"
-        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
         : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
-          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));

 #else
     assert(0);
@@ -461,15 +461,15 @@ struct Mma<

 #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)

-    uint64_t const & A = reinterpret_cast<uint64_t const &>(a);
-    uint64_t const & B = reinterpret_cast<uint64_t const &>(b);
+    double const & A = reinterpret_cast<double const &>(a);
+    double const & B = reinterpret_cast<double const &>(b);

-    uint64_t const *C = reinterpret_cast<uint64_t const *>(&c);
-    uint64_t *D = reinterpret_cast<uint64_t *>(&d);
+    double const *C = reinterpret_cast<double const *>(&c);
+    double *D = reinterpret_cast<double *>(&d);

     asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
-                 : "=l"(D[0]), "=l"(D[1])
-                 : "l"(A), "l"(B), "l"(C[0]), "l"(C[1]));
+                 : "=d"(D[0]), "=d"(D[1])
+                 : "d"(A), "d"(B), "d"(C[0]), "d"(C[1]));

 #else
     assert(0);
```
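The pattern behind the change: bind floating-point accumulator fragments directly with "f" (or "d" for FP64) register constraints rather than reinterpreting them as unsigned integers and using "r"/"l", which avoids the extra register moves behind the reported perf penalty. Below is a standalone sketch for the m16n8k8 BF16 shape; the free function and raw-array parameters are illustrative, not CUTLASS's fragment types.

```c++
#include <cstdint>

// Illustrative free function, not part of CUTLASS: issue the SM80 m16n8k8
// BF16 MMA with the accumulator bound as float registers ("f" constraints).
__device__ void mma_m16n8k8_bf16_f32(
    uint32_t const (&a)[2],   // A fragment: 4 bf16 values packed into two 32-bit registers
    uint32_t const (&b)[1],   // B fragment: 2 bf16 values packed into one 32-bit register
    float const (&c)[4],      // C accumulator fragment
    float (&d)[4]) {          // D accumulator fragment

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  asm volatile(
    "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
    "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
      : "r"(a[0]), "r"(a[1]),
        "r"(b[0]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
}
```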

media/docs/quickstart.md (+5, -1)

````diff
@@ -161,6 +161,7 @@ compiled as C++11 or greater.
 #include <iostream>
 #include <cutlass/cutlass.h>
 #include <cutlass/numeric_types.h>
+#include <cutlass/core_io.h>

 int main() {

@@ -174,10 +175,13 @@ int main() {

 ## Launching a GEMM kernel in CUDA

-**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores.
+**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores.
+
+_Note, this example uses CUTLASS Utilities. Be sure `tools/util/include` is listed as an include path._
 ```c++
 #include <cutlass/numeric_types.h>
 #include <cutlass/gemm/device/gemm.h>
+
 #include <cutlass/util/host_tensor.h>

 int main() {
````
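The quickstart's example targets Turing Tensor Cores with mixed precision. As a smaller companion sketch using the same headers and the `HostTensor` utility from `tools/util/include`: a single-precision GEMM with the device-level template's defaults. The problem size, fill, and scalars here are illustrative, not taken from the quickstart.

```c++
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/device/gemm.h>

#include <cutlass/util/host_tensor.h>

int main() {

  // Single-precision GEMM using the default (column-major) configuration
  // of the device-level template.
  using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::ColumnMajor,   // B
    float, cutlass::layout::ColumnMajor    // C and D
  >;

  int M = 128, N = 128, K = 64;

  // HostTensor allocates both host and device memory for each operand.
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A({M, K});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({K, N});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> C({M, N});

  // (Fill A.host_view(), B.host_view(), C.host_view() as needed, then...)
  A.sync_device();
  B.sync_device();
  C.sync_device();

  // Launch the GEMM: D = alpha * A * B + beta * C, written in place of C here.
  Gemm gemm_op;
  cutlass::Status status = gemm_op({
    {M, N, K},
    A.device_ref(),
    B.device_ref(),
    C.device_ref(),
    C.device_ref(),
    {1.0f, 0.0f}       // alpha, beta
  });

  // Copy the result back to host memory for inspection.
  C.sync_host();

  return status == cutlass::Status::kSuccess ? 0 : -1;
}
```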

tools/util/include/cutlass/util/host_tensor_planar_complex.h (+163)

```diff
@@ -276,6 +276,9 @@ class HostTensorPlanarComplex {
   /// Gets pointer to device data with a pointer offset
   Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; }

+  /// Gets a pointer to the device data imaginary part
+  Element * device_data_imag() { return device_.get() + imaginary_stride(); }
+
   /// Accesses the tensor reference pointing to data
   TensorRef host_ref(LongIndex ptr_element_offset=0) {
     return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride());
@@ -416,6 +419,166 @@ class HostTensorPlanarComplex {
       device_data(), host_data(), imaginary_stride() * 2);
     }
   }
+
+  /// Copies data from caller-supplied device pointers into host memory.
+  void copy_in_device_to_host(
+    Element const* ptr_device_real,     ///< source device memory (real plane)
+    Element const* ptr_device_imag,     ///< source device memory (imaginary plane)
+    LongIndex count = -1) {             ///< number of elements to transfer per plane; if negative, the entire tensor is overwritten.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_to_host(
+      host_data(), ptr_device_real, count);
+
+    device_memory::copy_to_host(
+      host_data_imag(), ptr_device_imag, count);
+  }
+
+  /// Copies data from caller-supplied device pointers into device memory.
+  void copy_in_device_to_device(
+    Element const* ptr_device_real,     ///< source device memory (real plane)
+    Element const* ptr_device_imag,     ///< source device memory (imaginary plane)
+    LongIndex count = -1) {             ///< number of elements to transfer per plane; if negative, the entire tensor is overwritten.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_device_to_device(
+      device_data(), ptr_device_real, count);
+
+    device_memory::copy_device_to_device(
+      device_data_imag(), ptr_device_imag, count);
+  }
+
+  /// Copies data from caller-supplied host pointers into device memory.
+  void copy_in_host_to_device(
+    Element const* ptr_host_real,       ///< source host memory (real plane)
+    Element const* ptr_host_imag,       ///< source host memory (imaginary plane)
+    LongIndex count = -1) {             ///< number of elements to transfer per plane; if negative, the entire tensor is overwritten.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_to_device(
+      device_data(), ptr_host_real, count);
+
+    device_memory::copy_to_device(
+      device_data_imag(), ptr_host_imag, count);
+  }
+
+  /// Copies data from caller-supplied host pointers into host memory.
+  void copy_in_host_to_host(
+    Element const* ptr_host_real,       ///< source host memory (real plane)
+    Element const* ptr_host_imag,       ///< source host memory (imaginary plane)
+    LongIndex count = -1) {             ///< number of elements to transfer per plane; if negative, the entire tensor is overwritten.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_host_to_host(
+      host_data(), ptr_host_real, count);
+
+    device_memory::copy_host_to_host(
+      host_data_imag(), ptr_host_imag, count);
+  }
+
+  /// Copies device data out to caller-supplied host pointers.
+  void copy_out_device_to_host(
+    Element * ptr_host_real,            ///< destination host memory (real plane)
+    Element * ptr_host_imag,            ///< destination host memory (imaginary plane)
+    LongIndex count = -1) const {       ///< number of elements to transfer per plane; if negative, the entire tensor is copied.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_to_host(
+      ptr_host_real, device_data(), count);
+
+    device_memory::copy_to_host(
+      ptr_host_imag, device_data_imag(), count);
+  }
+
+  /// Copies device data out to caller-supplied device pointers.
+  void copy_out_device_to_device(
+    Element * ptr_device_real,          ///< destination device memory (real plane)
+    Element * ptr_device_imag,          ///< destination device memory (imaginary plane)
+    LongIndex count = -1) const {       ///< number of elements to transfer per plane; if negative, the entire tensor is copied.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_device_to_device(
+      ptr_device_real, device_data(), count);
+
+    device_memory::copy_device_to_device(
+      ptr_device_imag, device_data_imag(), count);
+  }
+
+  /// Copies host data out to caller-supplied device pointers.
+  void copy_out_host_to_device(
+    Element * ptr_device_real,          ///< destination device memory (real plane)
+    Element * ptr_device_imag,          ///< destination device memory (imaginary plane)
+    LongIndex count = -1) const {       ///< number of elements to transfer per plane; if negative, the entire tensor is copied.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_to_device(
+      ptr_device_real, host_data(), count);
+
+    device_memory::copy_to_device(
+      ptr_device_imag, host_data_imag(), count);
+  }
+
+  /// Copies host data out to caller-supplied host pointers.
+  void copy_out_host_to_host(
+    Element * ptr_host_real,            ///< destination host memory (real plane)
+    Element * ptr_host_imag,            ///< destination host memory (imaginary plane)
+    LongIndex count = -1) const {       ///< number of elements to transfer per plane; if negative, the entire tensor is copied.
+
+    if (count < 0) {
+      count = capacity();
+    }
+    else {
+      count = __NV_STD_MIN(capacity(), count);
+    }
+
+    device_memory::copy_host_to_host(
+      ptr_host_real, host_data(), count);
+
+    device_memory::copy_host_to_host(
+      ptr_host_imag, host_data_imag(), count);
+  }
 };

 ///////////////////////////////////////////////////////////////////////////////////////////////////
```
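A brief usage sketch of the new copy-in/copy-out members. The buffer names, element type, and problem size are illustrative; per the added code, each plane transfers `capacity()` elements by default, and the imaginary plane sits `imaginary_stride()` elements past the real plane.

```c++
#include <vector>

#include <cutlass/layout/matrix.h>
#include <cutlass/util/host_tensor_planar_complex.h>

int main() {

  // Planar-complex tensor: real and imaginary parts are stored as two planes
  // separated by imaginary_stride(), in both host and device memory.
  cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> tensor({128, 128});

  // Caller-owned host buffers, one per plane.
  std::vector<float> real_in(tensor.capacity(), 1.0f);
  std::vector<float> imag_in(tensor.capacity(), 0.0f);

  // Copy-in: both planes go from the caller's host buffers to the tensor's device allocation.
  tensor.copy_in_host_to_device(real_in.data(), imag_in.data());

  // ... run a planar-complex kernel against tensor.device_data() / tensor.device_data_imag() ...

  // Copy-out: both planes come back from device memory into caller-owned host buffers.
  std::vector<float> real_out(tensor.capacity());
  std::vector<float> imag_out(tensor.capacity());
  tensor.copy_out_device_to_host(real_out.data(), imag_out.data());

  return 0;
}
```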
