This is a relatively minimalistic persistent reimplementation (for educational purposes) of NVIDIA's dense batched GEMM C = A @ B in CuTeDSL, for 16- and 8-bit floating point types.
Benchmark problem size: M=8192, N=8192, K=8192, L=1
| Implementation | Data Types (Input → Accumulator) | Time (ms) | TFLOPs |
|---|---|---|---|
| dense_gemm.py | Float8E4M3FN → Float16 | 0.713 | 1543.15 |
| dense_gemm.py | Float16 → Float32 | 1.363 | 806.89 |
| gemm.py (ours) | Float8E4M3FN → Float16 | 1.006 | 1092.92 |
| gemm.py (ours) | BFloat16 → Float32 | 1.624 | 677.07 |
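
For context, the TFLOPs column follows from the standard 2·M·N·K flop count of a dense GEMM divided by the measured runtime; a quick sanity check in plain Python (values taken from the first row of the table):

```python
# Effective GEMM throughput: 2 * M * N * K floating point ops
# (one multiply and one add per inner-product term) divided by runtime.
M = N = K = 8192
time_ms = 0.713  # first row of the table above

flops = 2 * M * N * K                      # total floating point operations
tflops = flops / (time_ms * 1e-3) / 1e12
print(f"{tflops:.2f} TFLOPs")              # ~1542, matching the table up to timing rounding
```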
- Explicit warp specialization: producer warps handle TMA loads, consumer warps handle WGMMA
- Cluster-based PersistentTileScheduler for distributing output tiles
- CTAs stay resident and process multiple output tiles before retiring (see the sketch after this list)
- Direct register-to-global copy of the accumulator (no TMA store)
- Supports Float16, BFloat16, Float8E4M3FN, Float8E5M2 data types
- Tile shape: (128, 256, 64) - M=128, N=256, K=64
- Cluster shape: (2, 1, 1) - 2 CTAs in M dimension
- Accumulator: Float32/Float16
- Layout: K-major for A/B, N-major for C
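
To make the persistence point concrete, here is a plain-Python sketch of how a persistent tile scheduler hands out output tiles to resident CTAs. It is illustrative only: `num_ctas = 132`, the linear-to-2D tile mapping, and the omission of the 2-CTA cluster are simplifying assumptions for clarity, not the actual PersistentTileScheduler logic.

```python
import math

# Plain-Python illustration of persistent tile scheduling (not CuTeDSL and
# not the kernel's actual scheduler): resident CTAs stride over the grid of
# output tiles instead of launching one CTA per tile.
M, N = 8192, 8192
TILE_M, TILE_N = 128, 256            # output tile shape from the list above

tiles_m = math.ceil(M / TILE_M)      # 64 tiles along M
tiles_n = math.ceil(N / TILE_N)      # 32 tiles along N
num_tiles = tiles_m * tiles_n        # 2048 output tiles in total

num_ctas = 132                       # assumption: one resident CTA per H100 SM

def tiles_for_cta(cta_id: int):
    """Yield the output tiles one resident CTA processes before retiring."""
    # Persistent schedule: CTA `cta_id` takes tiles cta_id, cta_id + num_ctas, ...
    # (the real scheduler hands out work per 2-CTA cluster; clustering and the
    # exact linear-to-2D mapping are simplified away here).
    for linear in range(cta_id, num_tiles, num_ctas):
        yield linear % tiles_m, linear // tiles_m  # (tile_m, tile_n)

# Each CTA processes roughly num_tiles / num_ctas ≈ 15-16 tiles.
print(len(list(tiles_for_cta(0))))   # 16
```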
There are still further optimizations that could be made to at least match the performance of the baseline. Feel free to use this code for educational purposes, or contact me if you find any mistakes :)