Add CDNA3 FP8 support#52
Open
Parxd wants to merge 4 commits into
Open
Conversation
kyle-256
added a commit
to kyle-256/HipKittens
that referenced
this pull request
May 11, 2026
DeepSeek-V3 style 1x128 / 128x128 block-scaled FP8 GEMM forward kernel for MI300X (gfx942, CDNA3). Achieves 781 TFLOPS @ 8192^3 — 97% of Primus-Turbo Triton blockwise (808T avg) and 70% of PR HazyResearch#52's unscaled FP8 ceiling (1120T). Files in kernels/gemm/fp8fp32/mi300x/blockwise_8192/: - blockwise.cpp: 8-warp 256x128x128 tile, 2x4 warp grid, per-warp 2 partial + 2 main fp32 col_l 64x32 accumulators. Drains partial->main with scale at end of every K_TILE. Inherits PR HazyResearch#52's 8-cluster software pipelining with register-buffered prefetch. - README.md: full design notes, optimization history (v0=94T -> v2=781T), things tried that didn't help, and the path to 1.15x Triton (~929T). - _metric.py: single-number score (min_TFLOPS x 10 - penalties) for driving A/B optimization rounds. Pattern from feat/attn-d64-gptoss/scripts/. - bench_triton.py: reference benchmark of Primus-Turbo Triton blockwise on the same shape for direct comparison. Key architectural decision: BLOCK_N = 128 (not 256). This aligns 1 b_scale block per warp, halves accumulator footprint vs naive 256x256, and brings spill from 332 -> 0. Matches Triton autotune's pin of BLOCK_N=128 across all blockwise configs. Correctness: SNR 49.6 dB vs fp32 reference (gate >= 48 dB). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kyle-256
added a commit
to kyle-256/HipKittens
that referenced
this pull request
May 11, 2026
Mirror the layout in /wekafs/kyle/remote_sync/Primus-Turbo/scripts/ —
underscore-prefixed utility scripts at repo root, kernel directories
hold only source + Makefile + correctness test.
Naming convention (see scripts/README.md):
_metric_<topic>_<gpu>.py one config → integer score (build + correctness + N×perf)
_fleet_<topic>_<gpu>.py sweep many configs, wrap _metric, summary table
_bench_<topic>_<ref>.py competitor / reference benchmark
_probe_/_compare_/_run.sh other categories (none yet)
Initial scripts (blockwise FP8 GEMM on MI300X):
scripts/_metric_blockwise_fp8_mi300x.py moved from kernel dir
scripts/_bench_blockwise_triton_ref.py moved (Triton reference)
scripts/_fleet_blockwise_fp8_mi300x.py new — A/B sweep over WGM,
launch_bounds, tile reuse,
scale-load placement, etc.
scripts/README.md full pattern doc + usage
Each script resolves its target kernel via a top-level KERNEL_DIR
constant (no relative-cwd magic). Fleet results cache lives at
scripts/.fleet_<name>_cache.json (gitignored).
Also documented the PR HazyResearch#52 dependency in the kernel README — the
kernel uses fp8 mma_ABt overloads added by that unmerged PR; building
against current main fails until PR HazyResearch#52 lands or the branch is rebased
onto Parxd:cdna3-fp8.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kyle-256
added a commit
to kyle-256/HipKittens
that referenced
this pull request
May 18, 2026
End-to-end FP8 blockwise GEMM (DeepSeek-V3 1×128 / 128×128 scaling) for
MI300X (gfx942), covering fwd (RCR/NT), dgrad (RRR/NN), wgrad (CRR/TN)
across the 18-shape production target × 3 sections = 54 configs.
** Kernel (kernels/gemm/fp8fp32/mi300x/blockwise_8192/blockwise.cpp) **
- 8-cluster software pipeline with register-buffered prefetch
- Tunables via -D flags: BLOCK_M/N/K, NUM_WARPS, REG_M, BW_RAW_DRAIN
(typed-float2 raw_partial drain), BW_PRESCALE_BS (cluster-6 b_s
pre-fold), BW_CHIPLET_CHUNK (XCD round-robin swizzle)
- Default 8192³ baseline: 49.59 dB SNR, ~768 TFLOPS, 236 VGPR, 2 wave/SIMD, 0 spill
- dispatch_micro / dispatch_micro_dgrad / dispatch_micro_wgrad bindings.
dgrad routes through fwd via caller-side B.T.contiguous(); wgrad
routes through fwd via K-contig col-T inputs + per-element b_scale
(micro_tk<true>). Header comment documents the routing contract.
** Per-shape JIT autotune **
- tuned_configs.py: 175 entries; per-shape (BM, BN, BK, NW, REG_M, RAW, PRE, CHK)
- test_python.py loader rebuilds the matching .so on demand via `make tuned`
- Manual override via BW_USE_TUNED=1 + env vars (sweep mode)
** Grouped GEMM (kernels/.../grouped_blockwise/) **
- grouped_blockwise.cpp + grouped_wgrad.cpp: persistent 304-WG kernel
for batched/MoE-style workloads with per-group M ranges
** Metrics + autotune daemon (scripts/) **
- _shapes_target.py: 18 production shapes + THIS-machine Triton baseline
- _bench_blockwise_triton_target_shapes.py: re-bench Triton fresh
- _metric_blockwise_fp8_target_shapes.py: full 54-shape score
(HK / 1.25×Tri, capped at 1.0 per shape, mean × 1000)
- _metric_blockwise_fp8_loser_shapes.py: focused subset metric (~8×
per-pair sensitivity vs full)
- auto_optimize_blockwise_fp8.py + launch_*.sh: nohup daemon that
spawns Claude per round, scores via metric, tracks best/streak.
REPO path auto-derived from script location (no hardcoded absolutes).
- _task_blockwise_fp8{,_losers_extra}.md: daemon prompt specs
- _goal_blockwise_fp8.md: north-star, current score, exhausted-knobs
table, remaining multi-week structural attacks
** Current section progress (full 54-shape metric, this commit) **
fwd (RCR/NT) 76.5% | HK avg 603 T vs Triton 637 T (95% Tri)
dgrad (RRR/NN) 98.8% | HK avg 678 T vs Triton 512 T (132% Tri)
wgrad (CRR/TN) 82.3% | HK avg 534 T vs Triton 512 T (104% Tri)
overall 85.9% | HK avg 605 T vs Triton 554 T (109% Tri)
→ score 859 / 1000
Remaining headroom concentrates in 12 loser pairs (HK < 105% Tri,
mostly big-K fwd + big-M wgrad). Single-knob sweep space exhausted
across 130+ daemon rounds; remaining gains require structural rewrites
(KBPT=2 interleaved, dedicated wgrad kernel, persistent kernel,
32×32×16 MFMA) — see scripts/_goal_blockwise_fp8.md for falsified
attempts and the remaining attack list.
** Dependency **
Requires the fp8e4m3 mma_ABt overloads added by upstream PR HazyResearch#52
(cdna3-fp8). Branch is rebased on that PR; do not merge until HazyResearch#52 lands.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author
|
Hi @willhu-jpg just wondering if you had a chance to review this, no rush if CDNA3 stuff is lower priority at the moment + happy to make any changes |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR resolves #34 and adds:
fp8e4m3_fnuztypert_base,st,st_subtile16x16x32MFMA8192x8192x8192fp8 GEMM kernel example (non-scaled)The "tile" concepts (e.g.
rt_baseandst) currently use the genericTILE_*_DIMto determine base tile shapes. This PR adds a template specialization based on a tile's majorness (i.e. row v. col) for this value, since the use of thefp816x16x32MFMA instruction introduces the need for a non-square base tile.Modified
standst_subtileto store their majorness as well to accurately track width in subtiles & for swizzling, with a default of row-major (i.e. no change; same forTILE_*_DIM) as to not break any existing code.All row-major
sttiles remain bank conflict-free, but I noticed conflicts for some GEMM configs (bothfp8andbf16) that require col-major RTs; this change could help specialize theidxfunction for col-majorstto setup conflict-free tiles later.HipKittens benchmark
TLDR:
8192x8192x8192hipBLASLt benchmark
TLDR:
8192x8192x8192fp8fp32(identity scale)Details: