Skip to content

Add CDNA3 FP8 support#52

Open
Parxd wants to merge 4 commits into
HazyResearch:cdna3from
Parxd:cdna3-fp8
Open

Add CDNA3 FP8 support#52
Parxd wants to merge 4 commits into
HazyResearch:cdna3from
Parxd:cdna3-fp8

Conversation

@Parxd
Copy link
Copy Markdown

@Parxd Parxd commented Apr 14, 2026

This PR resolves #34 and adds:

  • base fp8e4m3_fnuz type
    • utilities for packing and initialization
  • fp8 support for rt_base, st, st_subtile
  • fp8 support for MMA w/ 16x16x32 MFMA
  • 8192x8192x8192 fp8 GEMM kernel example (non-scaled)
    • PyTorch reference (correctness-only, see hipBLASLt performance reference below)

The "tile" concepts (e.g. rt_base and st) currently use the generic TILE_*_DIM to determine base tile shapes. This PR adds a template specialization based on a tile's majorness (i.e. row v. col) for this value, since the use of the fp8 16x16x32 MFMA instruction introduces the need for a non-square base tile.

Modified st and st_subtile to store their majorness as well to accurately track width in subtiles & for swizzling, with a default of row-major (i.e. no change; same for TILE_*_DIM) as to not break any existing code.

All row-major st tiles remain bank conflict-free, but I noticed conflicts for some GEMM configs (both fp8 and bf16) that require col-major RTs; this change could help specialize the idx function for col-major st to setup conflict-free tiles later.

HipKittens benchmark
(.venv) root@6:/workdir/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32# make
# @mkdir -p 
/opt/rocm-7.2.1/bin/hipcc 256_256_64_32.cpp -DKITTENS_CDNA3 --offload-arch=gfx942 -std=c++20 -w -I/workdir/include -I/workdir/prototype -I/usr/include/python3.12 -I/opt/venv/lib/python3.12/site-packages/pybind11/include -L/usr/lib/python3.12/config-3.12-x86_64-linux-gnu -L/usr/lib/x86_64-linux-gnu  -ldl  -lm  -shared -fPIC -Rpass-analysis=kernel-resource-usage -I/workdir/include -I/opt/rocm-7.2.1/include -I/opt/rocm-7.2.1/include/hip  \
    -o tk_kernel.cpython-312-x86_64-linux-gnu.so 2>&1 # | tee 
256_256_64_32.cpp:32:1: remark: Function Name: _Z8micro_tk13micro_globals [-Rpass-analysis=kernel-resource-usage]
   32 | void micro_tk(const micro_globals g) {
      | ^
256_256_64_32.cpp:32:1: remark:     TotalSGPRs: 36 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     VGPRs: 250 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     AGPRs: 0 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     ScratchSize [bytes/lane]: 0 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     Dynamic Stack: False [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     Occupancy [waves/SIMD]: 2 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     SGPRs Spill: 0 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     VGPRs Spill: 0 [-Rpass-analysis=kernel-resource-usage]
256_256_64_32.cpp:32:1: remark:     LDS Size [bytes/block]: 0 [-Rpass-analysis=kernel-resource-usage]

(.venv) root@6:/workdir/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32# python test_python.py 
src: 256_256_64_32.cpp
Warning: tk_kernel.cpython-313-x86_64-linux-gnu.so not found at /workdir/kernels/gemm/fp8fp32/mi300x/8192_256_256_64_32/tk_kernel.cpython-313-x86_64-linux-gnu.so, skipping.
C.dtype=torch.float32
Average execution time: 0.9551 ms
Performance: 1151.20 TFLOPS for 8192x8192 matrix multiplication.

Max error between kernel and reference: 0.0
Max error: 0.0
Mean error: 0.0
Number of large errors (>0.1): 0

TLDR:

  • Shape: 8192x8192x8192
  • Perf.: 1151 TFLOPs
  • Warmup iters: 500
  • Timed iters: 100
hipBLASLt benchmark
(.venv) root@6:/workdir/rocm-libraries/projects/hipblaslt/build/release/clients# ./hipblaslt-bench --api_method c --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 --alpha 1.000000 --beta 0.000000 --transA T --transB N --batch_count 1 --scaleA 1 --scaleB 1 --a_type f8_r --b_type f8_r --c_type f32_r --d_type f32_r --scale_type f32_r --bias_type f32_r --compute_type f32_r --rotating 512 --iters 100 --cold_iters 500 -m 8192 -n 8192 -k 8192 --lda 8192 --ldb 8192 --ldc 8192 --ldd 8192

hipBLASLt version: 100202
hipBLASLt git version: dabb6df2b9
Query device success: there are 1 devices. (Target device ID is 0)
Device ID 0 : AMD Instinct MI300X VF gfx942:sramecc+:xnack-
with 205.8 GB memory, max. SCLK 2100 MHz, max. MCLK 1300 MHz, compute capability 9.4
maxGridDimX 2147483647, sharedMemPerBlock 65.5 KB, maxThreadsPerBlock 1024, warpSize 64

hipblaslt-bench INFO: Using f8_fnuz_r instead of f8_r
hipblaslt-bench INFO: Using f8_fnuz_r instead of f8_r
Rotating buffer 512 MiB. Needed Size: 384 MiB. Needed block count: 2 (Capped to max iters: 500)
Is supported 1 / Total solutions: 1
[0]:transA,transB,grouped_gemm,batch_count,m,n,k,alpha,lda,stride_a,beta,ldb,stride_b,ldc,stride_c,ldd,stride_d,a_type,b_type,c_type,d_type,compute_type,scaleA,scaleB,scaleC,scaleD,amaxD,swizzle_a,swizzle_b,activation_type,bias_vector,bias_type,aux_type,rotating_buffer,flush,use_gpu_timer,hipblaslt-Gflops,hipblaslt-GB/s,us
    T,N,0,1,8192,8192,8192,1,8192,67108864,0,8192,67108864,8192,67108864,8192,67108864,f8_fnuz_r,f8_fnuz_r,f32_r,f32_r,f32_r,1,1,0,0,0,0,0,none,0,f32_r,f32_r,512,0,0,1.12708e+06,384.402,975.54

TLDR:

  • Shape: 8192x8192x8192
  • Type: fp8fp32 (identity scale)
  • Perf.: 1127 TFLOPs
  • Warmup iters: 500
  • Timed iters: 100

Details:

  • Docker container: docker.io/rocm/pytorch
  • ROCm version: 7.2.1
  • Device: MI300X

@Parxd Parxd changed the title Add FP8 support Add CDNA3 FP8 support Apr 14, 2026
kyle-256 added a commit to kyle-256/HipKittens that referenced this pull request May 11, 2026
DeepSeek-V3 style 1x128 / 128x128 block-scaled FP8 GEMM forward kernel for
MI300X (gfx942, CDNA3). Achieves 781 TFLOPS @ 8192^3 — 97% of Primus-Turbo
Triton blockwise (808T avg) and 70% of PR HazyResearch#52's unscaled FP8 ceiling (1120T).

Files in kernels/gemm/fp8fp32/mi300x/blockwise_8192/:
- blockwise.cpp: 8-warp 256x128x128 tile, 2x4 warp grid, per-warp 2 partial +
  2 main fp32 col_l 64x32 accumulators. Drains partial->main with scale at
  end of every K_TILE. Inherits PR HazyResearch#52's 8-cluster software pipelining with
  register-buffered prefetch.
- README.md: full design notes, optimization history (v0=94T -> v2=781T),
  things tried that didn't help, and the path to 1.15x Triton (~929T).
- _metric.py: single-number score (min_TFLOPS x 10 - penalties) for driving
  A/B optimization rounds. Pattern from feat/attn-d64-gptoss/scripts/.
- bench_triton.py: reference benchmark of Primus-Turbo Triton blockwise
  on the same shape for direct comparison.

Key architectural decision: BLOCK_N = 128 (not 256). This aligns 1 b_scale
block per warp, halves accumulator footprint vs naive 256x256, and brings
spill from 332 -> 0. Matches Triton autotune's pin of BLOCK_N=128 across
all blockwise configs.

Correctness: SNR 49.6 dB vs fp32 reference (gate >= 48 dB).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kyle-256 added a commit to kyle-256/HipKittens that referenced this pull request May 11, 2026
Mirror the layout in /wekafs/kyle/remote_sync/Primus-Turbo/scripts/ —
underscore-prefixed utility scripts at repo root, kernel directories
hold only source + Makefile + correctness test.

Naming convention (see scripts/README.md):
  _metric_<topic>_<gpu>.py  one config → integer score (build + correctness + N×perf)
  _fleet_<topic>_<gpu>.py   sweep many configs, wrap _metric, summary table
  _bench_<topic>_<ref>.py   competitor / reference benchmark
  _probe_/_compare_/_run.sh other categories (none yet)

Initial scripts (blockwise FP8 GEMM on MI300X):
  scripts/_metric_blockwise_fp8_mi300x.py   moved from kernel dir
  scripts/_bench_blockwise_triton_ref.py    moved (Triton reference)
  scripts/_fleet_blockwise_fp8_mi300x.py    new — A/B sweep over WGM,
                                             launch_bounds, tile reuse,
                                             scale-load placement, etc.
  scripts/README.md                          full pattern doc + usage

Each script resolves its target kernel via a top-level KERNEL_DIR
constant (no relative-cwd magic). Fleet results cache lives at
scripts/.fleet_<name>_cache.json (gitignored).

Also documented the PR HazyResearch#52 dependency in the kernel README — the
kernel uses fp8 mma_ABt overloads added by that unmerged PR; building
against current main fails until PR HazyResearch#52 lands or the branch is rebased
onto Parxd:cdna3-fp8.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kyle-256 added a commit to kyle-256/HipKittens that referenced this pull request May 18, 2026
End-to-end FP8 blockwise GEMM (DeepSeek-V3 1×128 / 128×128 scaling) for
MI300X (gfx942), covering fwd (RCR/NT), dgrad (RRR/NN), wgrad (CRR/TN)
across the 18-shape production target × 3 sections = 54 configs.

** Kernel (kernels/gemm/fp8fp32/mi300x/blockwise_8192/blockwise.cpp) **
- 8-cluster software pipeline with register-buffered prefetch
- Tunables via -D flags: BLOCK_M/N/K, NUM_WARPS, REG_M, BW_RAW_DRAIN
  (typed-float2 raw_partial drain), BW_PRESCALE_BS (cluster-6 b_s
  pre-fold), BW_CHIPLET_CHUNK (XCD round-robin swizzle)
- Default 8192³ baseline: 49.59 dB SNR, ~768 TFLOPS, 236 VGPR, 2 wave/SIMD, 0 spill
- dispatch_micro / dispatch_micro_dgrad / dispatch_micro_wgrad bindings.
  dgrad routes through fwd via caller-side B.T.contiguous(); wgrad
  routes through fwd via K-contig col-T inputs + per-element b_scale
  (micro_tk<true>). Header comment documents the routing contract.

** Per-shape JIT autotune **
- tuned_configs.py: 175 entries; per-shape (BM, BN, BK, NW, REG_M, RAW, PRE, CHK)
- test_python.py loader rebuilds the matching .so on demand via `make tuned`
- Manual override via BW_USE_TUNED=1 + env vars (sweep mode)

** Grouped GEMM (kernels/.../grouped_blockwise/) **
- grouped_blockwise.cpp + grouped_wgrad.cpp: persistent 304-WG kernel
  for batched/MoE-style workloads with per-group M ranges

** Metrics + autotune daemon (scripts/) **
- _shapes_target.py: 18 production shapes + THIS-machine Triton baseline
- _bench_blockwise_triton_target_shapes.py: re-bench Triton fresh
- _metric_blockwise_fp8_target_shapes.py: full 54-shape score
  (HK / 1.25×Tri, capped at 1.0 per shape, mean × 1000)
- _metric_blockwise_fp8_loser_shapes.py: focused subset metric (~8×
  per-pair sensitivity vs full)
- auto_optimize_blockwise_fp8.py + launch_*.sh: nohup daemon that
  spawns Claude per round, scores via metric, tracks best/streak.
  REPO path auto-derived from script location (no hardcoded absolutes).
- _task_blockwise_fp8{,_losers_extra}.md: daemon prompt specs
- _goal_blockwise_fp8.md: north-star, current score, exhausted-knobs
  table, remaining multi-week structural attacks

** Current section progress (full 54-shape metric, this commit) **
  fwd   (RCR/NT) 76.5%  |  HK avg 603 T vs Triton 637 T  (95% Tri)
  dgrad (RRR/NN) 98.8%  |  HK avg 678 T vs Triton 512 T  (132% Tri)
  wgrad (CRR/TN) 82.3%  |  HK avg 534 T vs Triton 512 T  (104% Tri)
  overall        85.9%  |  HK avg 605 T vs Triton 554 T  (109% Tri)
                                                  → score 859 / 1000

Remaining headroom concentrates in 12 loser pairs (HK < 105% Tri,
mostly big-K fwd + big-M wgrad). Single-knob sweep space exhausted
across 130+ daemon rounds; remaining gains require structural rewrites
(KBPT=2 interleaved, dedicated wgrad kernel, persistent kernel,
32×32×16 MFMA) — see scripts/_goal_blockwise_fp8.md for falsified
attempts and the remaining attack list.

** Dependency **
Requires the fp8e4m3 mma_ABt overloads added by upstream PR HazyResearch#52
(cdna3-fp8). Branch is rebased on that PR; do not merge until HazyResearch#52 lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Parxd
Copy link
Copy Markdown
Author

Parxd commented May 29, 2026

Hi @willhu-jpg just wondering if you had a chance to review this, no rush if CDNA3 stuff is lower priority at the moment + happy to make any changes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant