Skip to content

Commit 71108d8

Browse files
committed
address pr comments
1 parent 4b017fe commit 71108d8

2 files changed

Lines changed: 5 additions & 21 deletions

File tree

transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ enum class GPUArch {
8585
};
8686

8787
static inline GPUArch detect_gpu_arch() {
88-
switch (cuda::sm_arch(0)) {
88+
switch (cuda::sm_arch()) {
8989
case 94:
9090
return GPUArch::GFX942;
9191
case 95:

transformer_engine/common/gemm/cublaslt_gemm.cu

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,31 +1157,15 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
11571157
};
11581158
#endif
11591159

1160-
#ifdef __HIP_PLATFORM_AMD__
1161-
auto effective_dtype = [](const transformer_engine::Tensor *t) {
1162-
if (t->has_data()) {
1163-
return t->data.dtype;
1164-
}
1165-
if (t->has_columnwise_data()) {
1166-
return t->columnwise_data.dtype;
1167-
}
1168-
return t->data.dtype;
1169-
};
1170-
#endif
1171-
11721160
auto is_supported_dtype = [&]() -> bool {
11731161
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
11741162
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
11751163
auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
11761164
#ifdef __HIP_PLATFORM_AMD__
1177-
auto effective_dtype = [](const transformer_engine::Tensor* t) {
1178-
if (is_fp8_dtype(t->data.dtype)) {
1179-
return t->data.dtype;
1180-
}
1181-
if (t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype)) {
1182-
return t->columnwise_data.dtype;
1183-
}
1184-
return t->data.dtype;
1165+
auto effective_dtype = [](const transformer_engine::Tensor *t) {
1166+
NVTE_CHECK(t->has_data() || t->has_columnwise_data(),
1167+
"Input tensor has neither row-wise nor column-wise data.");
1168+
return t->has_data() ? t->data.dtype : t->columnwise_data.dtype;
11851169
};
11861170

11871171
auto A_dt = effective_dtype(inputA);

0 commit comments

Comments
 (0)