Skip to content

Commit

Permalink
Fix issue with setAttribute and int8_t vs int32_t variables (pytorch#…
Browse files Browse the repository at this point in the history
…143693)

Test Plan: Sandcastle

Pull Request resolved: pytorch#143693
Approved by: https://github.com/huydhn
  • Loading branch information
r-barnes authored and pytorchmergebot committed Dec 21, 2024
1 parent 518b505 commit 9f3c291
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/CUDABlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1453,7 +1453,8 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
#ifndef USE_ROCM
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, use_fast_accum ? 1 : 0);
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
#endif
CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
Expand Down

0 comments on commit 9f3c291

Please sign in to comment.