Benchmarks: micro benchmarks - add int8 support for cublaslt function (#574)

Description

Add int8 support for the cublaslt function.
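With int8 enabled, a generated benchmark command simply passes the new type through the existing -t flag. Following the command format exercised by the updated test below (the shape and batch values here are illustrative only):

    cublaslt_gemm -m 128 -n 128 -k 128 -b 4 -w 20 -i 50 -t int8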
yukirora authored Nov 20, 2023
1 parent c7800bb commit f53d941
Showing 4 changed files with 11 additions and 4 deletions.
@@ -23,7 +23,7 @@ def __init__(self, name, parameters=''):
         super().__init__(name, parameters)

         self._bin_name = 'cublaslt_gemm'
-        self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2']
+        self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8']

     def mrange(self, start, stop=-1, multiplication_factor=2):
         """Range constructor with multiplication factor.
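Appending 'int8' to self._in_types is the only Python-side change needed: this list appears to define the values accepted by the benchmark's --in_types argument, and the mrange helper shown truncated above presumably expands start:stop:multiplication_factor specs (e.g. 2:16:2 into 2, 4, 8, 16) independently of the element type.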
@@ -16,6 +16,7 @@ using fp16 = half;
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
+using int8 = int8_t;

 struct Args {
   int m = 16;
@@ -84,6 +85,8 @@ template <typename T> cudaDataType_t get_datatype() {
     return CUDA_R_8F_E4M3;
   if (std::is_same<T, fp8e5m2>::value)
     return CUDA_R_8F_E5M2;
+  if (std::is_same<T, int8>::value)
+    return CUDA_R_8I;
   throw std::invalid_argument("Unknown type");
 }
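CUDA_R_8I is the cudaDataType_t enumerator for real signed 8-bit integer data, matching the int8 = int8_t alias introduced in the first hunk.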

@@ -162,6 +165,8 @@ int main(int argc, char **argv) {
     run<fp8e4m3, fp8e4m3, fp16>(&args);
   else if (args.in_type == "fp8e5m2")
     run<fp8e5m2, fp8e4m3, fp16>(&args);
+  else if (args.in_type == "int8")
+    run<int8>(&args);
   else
     throw std::invalid_argument("Unknown type " + args.in_type);
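Note the asymmetry at the call sites: the fp8 paths spell out all three template arguments (run<fp8e4m3, fp8e4m3, fp16>, with an e4m3 B operand even for e5m2 input), while the int8 path passes only the input type. run's definition is outside this diff, but the call sites suggest its later template parameters default to the input type; a hypothetical sketch of such a signature:

    // Assumed shape only, inferred from the call sites above: output and
    // scale types default to the input type, so run<int8>(&args) would
    // instantiate an int8-in/int8-out GEMM.
    template <typename TIn, typename TOut = TIn, typename TScale = TIn>
    void run(Args *args);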
@@ -62,6 +62,8 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
     gemm_compute_type = CUBLAS_COMPUTE_32F;
   if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
     gemm_compute_type = CUBLAS_COMPUTE_64F;
+  if (a_type == CUDA_R_8I)
+    gemm_compute_type = CUBLAS_COMPUTE_32I;

   cublasLtMatmulDesc_t op_desc = nullptr;
   CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
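cuBLASLt accumulates CUDA_R_8I operands in 32-bit integers, hence CUBLAS_COMPUTE_32I. Restated as a standalone helper for clarity (a sketch mirroring the logic above; the function name is invented, not part of the patch):

    #include <cublasLt.h>

    // Mirrors the selection in cublasLtGemm::Setup: fp64 operands use the
    // 64-bit float pipeline, int8 operands accumulate in int32, and other
    // types fall back to 32-bit float accumulation.
    cublasComputeType_t select_compute_type(cudaDataType_t a_type, cudaDataType_t b_type) {
      cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
      if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
        compute_type = CUBLAS_COMPUTE_64F;
      if (a_type == CUDA_R_8I)
        compute_type = CUBLAS_COMPUTE_32I;
      return compute_type;
    }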
6 changes: 3 additions & 3 deletions tests/benchmarks/micro_benchmarks/test_cublaslt_function.py
@@ -63,15 +63,15 @@ def test_cublaslt_gemm_command_generation(self):
         (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
         benchmark = benchmark_cls(
             self.benchmark_name,
-            parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64',
+            parameters='--batch 2:16:2 --shapes 2:4,4:8,8:32 32:128:4,128,128 --in_types fp16 fp32 fp64 int8',
         )
         self.assertTrue(benchmark._preprocess())
-        self.assertEqual(4 * (2 * 2 * 3 + 2) * 3, len(benchmark._commands))
+        self.assertEqual(4 * (2 * 2 * 3 + 2) * len(benchmark._args.in_types), len(benchmark._commands))

         def cmd(t, b, m, n, k):
             return f'{benchmark._CublasLtBenchmark__bin_path} -m {m} -n {n} -k {k} -b {b} -w 20 -i 50 -t {t}'

-        for _t in ['fp16', 'fp32', 'fp64']:
+        for _t in ['fp16', 'fp32', 'fp64', 'int8']:
             for _b in [2, 4, 8, 16]:
                 for _m in [2, 4]:
                     for _n in [4, 8]:
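The updated count assertion works out as follows: --batch 2:16:2 expands to 4 values (2, 4, 8, 16); the shape spec 2:4,4:8,8:32 yields 2 × 2 × 3 = 12 (m, n, k) combinations and 32:128:4,128,128 yields 2 more, i.e. 14 shapes per batch size; across len(in_types) = 4 types the total is 4 × (2 · 2 · 3 + 2) × 4 = 224 commands.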
