Skip to content

Commit 4dac749

Browse files
authored
Typoes (NVIDIA#107)
* Update splitk_gemm.cu * Update gemm_bias_relu.cu * Update mma_sm75.h
1 parent fd7e058 commit 4dac749

File tree

3 files changed

+3
-4
lines changed

3 files changed

+3
-4
lines changed

examples/06_splitK_gemm/splitk_gemm.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ int run() {
205205
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
206206
problem_size.mk()); // <- Create matrix A with dimensions M x K
207207
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
208-
problem_size.nk()); // <- Create matrix B with dimensions N x K
208+
problem_size.kn()); // <- Create matrix B with dimensions K x N
209209
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
210210
problem_size.mn()); // <- Create matrix C with dimensions M x N
211211
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(

examples/12_gemm_bias_relu/gemm_bias_relu.cu

+1-2
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ int run() {
132132
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
133133
problem_size.mk()); // <- Create matrix A with dimensions M x K
134134
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
135-
problem_size.nk()); // <- Create matrix B with dimensions N x K
135+
problem_size.kn()); // <- Create matrix B with dimensions K x N
136136

137137
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
138138
{problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
@@ -234,7 +234,6 @@ int run() {
234234
tensor_a.device_ref(),
235235
tensor_b.device_ref(),
236236
0,
237-
tensor_c_bias.device_ref(),
238237
tensor_ref_d.device_ref());
239238

240239
// Wait for kernels to finish

include/cutlass/arch/mma_sm75.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ struct Mma<
823823
int const *C = reinterpret_cast<int const *>(&c);
824824
int *D = reinterpret_cast<int *>(&d);
825825

826-
asm volatile("_mma.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
826+
asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
827827
: "=r"(D[0]), "=r"(D[1])
828828
: "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
829829

0 commit comments

Comments
 (0)