Persistent GEMM Kernel

This is a relatively minimalistic persistent reimplementation (for educational purposes) of NVIDIA's Dense Batched GEMM example, C = A @ B, written in CuTeDSL for 16-bit and 8-bit floating-point types.
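
The operation itself is plain batched dense matmul with low-precision inputs and a wider accumulator. As a rough reference point (not the repo's API; tensor names and shapes below are illustrative), a PyTorch correctness baseline for one of the benchmarked configurations could look like:

```python
import torch

# Illustrative shapes; L is the batch dimension (L=1 in the benchmarks below).
L, M, N, K = 1, 8192, 8192, 8192

# Low-precision inputs (BFloat16 here), accumulation in Float32.
A = torch.randn(L, M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(L, K, N, device="cuda", dtype=torch.bfloat16)

# Reference result computed in Float32, matching the kernel's accumulator dtype.
C_ref = A.float() @ B.float()  # shape (L, M, N), dtype float32
```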

Performance Results

M=8192, N=8192, K=8192, L=1

| Implementation | Data Types (input → accumulator) | Time (ms) | TFLOPs |
|----------------|----------------------------------|-----------|--------|
| dense_gemm.py  | Float8E4M3FN → Float16           | 0.713     | 1543.15 |
| dense_gemm.py  | Float16 → Float32                | 1.363     | 806.89  |
| gemm.py (ours) | Float8E4M3FN → Float16           | 1.006     | 1092.92 |
| gemm.py (ours) | BFloat16 → Float32               | 1.624     | 677.07  |
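
The TFLOPs column follows from the standard 2·M·N·K flop count for dense GEMM. A quick sanity check in Python, using the times from the table (small differences against the reported TFLOPs come from rounding of the measured times):

```python
M = N = K = 8192
flops = 2 * M * N * K  # ~1.10e12 floating-point operations for one GEMM

for name, time_ms in [("dense_gemm.py fp8", 0.713), ("gemm.py fp8 (ours)", 1.006)]:
    tflops = flops / (time_ms * 1e-3) / 1e12
    print(f"{name}: {tflops:.1f} TFLOPs")
# dense_gemm.py fp8: ~1542 TFLOPs
# gemm.py fp8 (ours): ~1093 TFLOPs
```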

Reimplementation details:

  • Explicit warp specialization: producer warps handle TMA loads, consumer warps handle WGMMA
  • Cluster-aware PersistentTileScheduler (a sketch of the persistent pattern follows this list)
  • CTAs stay resident and process multiple output tiles before retiring
  • Direct register-to-global copy of the accumulator (no TMA store)
  • Supports Float16, BFloat16, Float8E4M3FN, Float8E5M2 data types
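
As a rough illustration of the persistent pattern referenced above, each CTA is launched once and then strides over output tiles instead of exiting after a single tile. The following is a simplified Python sketch of that scheduling idea only; identifiers like num_ctas and persistent_schedule are illustrative and not the repo's CuTeDSL code:

```python
import math

def persistent_schedule(M, N, tile_m, tile_n, num_ctas):
    """Return, for each resident CTA, the list of output tiles it processes."""
    tiles_m = math.ceil(M / tile_m)
    tiles_n = math.ceil(N / tile_n)
    total_tiles = tiles_m * tiles_n

    work = {cta: [] for cta in range(num_ctas)}
    for cta in range(num_ctas):
        # Each CTA starts at its own index and strides by the grid size,
        # staying resident until all output tiles are consumed.
        tile = cta
        while tile < total_tiles:
            m_idx, n_idx = divmod(tile, tiles_n)
            work[cta].append((m_idx, n_idx))
            tile += num_ctas
    return work

# Example: 8192x8192 output, (128, 256) tiles, 132 CTAs (e.g. one per H100 SM).
schedule = persistent_schedule(8192, 8192, 128, 256, 132)
print(len(schedule[0]))  # ~16 output tiles handled by CTA 0
```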

Configuration

  • Tile shape: (128, 256, 64) - M=128, N=256, K=64
  • Cluster shape: (2, 1, 1) - 2 CTAs in M dimension
  • Accumulator: Float32/Float16
  • Layout: K-major for A/B, N-major for C
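
Under this configuration and the 8192³ problem above, the work decomposition is straightforward to compute (a back-of-the-envelope check, not code from the repo):

```python
M = N = K = 8192
tile_m, tile_n, tile_k = 128, 256, 64

tiles_m = M // tile_m             # 64 tile rows along M
tiles_n = N // tile_n             # 32 tile columns along N
output_tiles = tiles_m * tiles_n  # 2048 output tiles to schedule persistently
k_steps = K // tile_k             # 128 mainloop iterations per output tile

print(output_tiles, k_steps)      # 2048 128
```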

There are still further optimizations that can be made to at least match the performance of the baseline. Feel free to use this code for educational purposes, or contact me if you find any mistakes :)
