Skip to content

Commit

Permalink
update format
Browse files Browse the repository at this point in the history
  • Loading branch information
yukirora committed Nov 17, 2023
1 parent a69c325 commit 089f0b9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using int8= int8_t;

struct Args {
int m = 16;
int n = 16;
Expand Down Expand Up @@ -85,7 +86,7 @@ template <typename T> cudaDataType_t get_datatype() {
if (std::is_same<T, fp8e5m2>::value)
return CUDA_R_8F_E5M2;
if (std:: is_same<T, int8>::value)
return CUDA_R_8I;
return CUDA_R_8I;
throw std::invalid_argument("Unknown type");
}

Expand Down Expand Up @@ -165,7 +166,7 @@ int main(int argc, char **argv) {
else if (args.in_type == "fp8e5m2")
run<fp8e5m2, fp8e4m3, fp16>(&args);
else if (args.in_type == "int8")
run<int8>(&args);
run<int8>(&args);
else
throw std::invalid_argument("Unknown type " + args.in_type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
gemm_compute_type = CUBLAS_COMPUTE_64F;
if (a_type==CUDA_R_8I)
gemm_compute_type = CUBLAS_COMPUTE_32I;
gemm_compute_type = CUBLAS_COMPUTE_32I;

cublasLtMatmulDesc_t op_desc = nullptr;
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
op_desc_.reset(op_desc);
Expand Down

0 comments on commit 089f0b9

Please sign in to comment.