[feat] Add flydsl based grouped gemm#384
Merged
Pull request overview
Adds a new FlyDSL-based FP8 backend for grouped GEMM (forward/dgrad and variable‑K wgrad) and extends FlyDSL dense FP8 GEMM infrastructure to support int64 SRD re-basing for large-tensor addressing and per-shape autotuning/caching behavior.
Changes:
- Introduces FlyDSL FP8 grouped GEMM kernels (NT/NN) and variable‑K wgrad (TN) with per-shape autotune and CU-capped persistent routing.
- Wires FlyDSL as a selectable backend for FP8 grouped GEMM dispatch (and variable‑K dispatch).
- Updates FlyDSL dense FP8 GEMM primitives to support i64 SRD re-basing, updated loaders, and capture-vs-eager launch mode splitting.
Reviewed changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py | Adds FlyDSL backend entries for FP8 grouped GEMM and variable‑K grouped GEMM dispatch. |
| primus_turbo/pytorch/kernels/gemm/gemm_fp8_impl.py | Tightens FlyDSL dense GEMM eligibility with additional i64 SRD re-base size caps for traversal operands. |
| primus_turbo/flydsl/utils/fp8_gemm_helper.py | Adds SRD re-basing buffer helpers, output-store changes, and loader extensions (base offsets, readfirstlane pinning). |
| primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py | New FlyDSL grouped GEMM kernel implementations + autotune/caching logic for fwd/dgrad and wgrad. |
| primus_turbo/flydsl/grouped_gemm/__init__.py | Package marker for FlyDSL grouped GEMM module. |
| primus_turbo/flydsl/gemm/gemm_fp8_kernel.py | Updates dense FlyDSL GEMM to use i64 re-basing helpers and adds eager-vs-capture launch handling. |
Comments suppressed due to low confidence (1)
primus_turbo/flydsl/utils/fp8_gemm_helper.py:57
`_readfirstlane_i32` (and its docstring) indicates an i32 SGPR pinning helper, but it's being called with i64 values here (`base` and `nr` are cast to i64). Please clarify and/or enforce the intended bit-width: either implement a 64-bit-safe readfirstlane helper for addresses/num_records, or keep this helper strictly i32 and adjust the callers accordingly. As written, it's unclear that >4GB addressing is safe.
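To make the bit-width concern concrete, here is a minimal plain-Python sketch (not FlyDSL). `readfirstlane_i32_model` is a hypothetical stand-in that only masks to 32 bits; it is not the helper in fp8_gemm_helper.py. It shows what a strictly 32-bit helper would drop from a 64-bit address, and how processing the two 32-bit words separately keeps the full value.

```python
MASK32 = 0xFFFFFFFF


def readfirstlane_i32_model(value: int) -> int:
    """Hypothetical stand-in for a 32-bit readfirstlane: keeps only the low 32 bits."""
    return value & MASK32


def readfirstlane_i64_model(value: int) -> int:
    """Pin a non-negative 64-bit value by handling its low and high words separately."""
    lo = readfirstlane_i32_model(value)
    hi = readfirstlane_i32_model(value >> 32)
    return (hi << 32) | lo


addr = 0x0000_0001_2345_6789  # an address above 4 GiB
truncated = readfirstlane_i32_model(addr)  # what a strictly-i32 helper would keep
preserved = readfirstlane_i64_model(addr)  # the two-word version keeps all 64 bits
```

Whether the real helper already lowers to two 32-bit operations for i64 inputs is exactly the point the comment asks to be documented or enforced.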
5ac0623 to 2906e8d
…ad(TN)
Adds non-persistent + persistent grouped FP8 GEMM kernels on gfx950 (mfma_f32_16x16x128_f8f6f4): per-group on-device group-major scan, L2-reuse tile swizzle (XCD-remap / group_m / group_n band), per-shape online autotune, vectorized CShuffle store, and CUDA-graph/eager dispatch mode-split (eager flyc.compile skips @flyc.jit per-call drift-check; graph keeps the raw closure). Reusable primitives consolidated into flydsl/utils/fp8_gemm_helper.py (renamed from gemm_helper.py, shared with the dense kernel).
Routes per-tensor FP8 grouped GEMM (fwd/dgrad) to the FlyDSL backend via GroupedGEMMFP8FlyDSLBackend; num_cu<=0 -> non-persistent full-device kernels, num_cu>0 -> persistent (reserve CUs for comm-overlap).
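The `num_cu` convention above can be sketched as a small routing function. This is an illustration only: the function name, the `None` handling, and the grid sizes (one workgroup per tile when non-persistent, one per reserved CU when persistent) are assumptions, not the backend's actual code.

```python
def select_grouped_launch(num_cu, total_tiles, device_cus):
    """Return (mode, grid) following the num_cu convention in the commit message.

    Assumptions for illustration: None is treated like -1, and the grid sizes
    are simply total_tiles (non-persistent) or the CU cap (persistent).
    """
    if num_cu is None or num_cu <= 0:
        # Full-device default: one tile per workgroup.
        return "non_persistent", total_tiles
    # Persistent: cap the grid so the remaining CUs stay free for comm-overlap.
    return "persistent", min(num_cu, device_cus)
```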
…N kernels
Fold _compile_grouped_nt_8w / _compile_grouped_nn_8w into _compile_grouped_nt / _compile_grouped_nn behind a `persistent` flag: const_expr selects the outer scf.for tile loop (persistent, cap_cu reserves CUs) vs one-tile-per-WG + s_endpgm over-launch guard (non-persistent, full-device default), the prelude barrier (unconditional vs divergent wave_m==1), and the launch grid. Per-mode IR is unchanged, so SNR and kernel TFLOPS are identical (verified on the dsv3/qwen/gpt MoE shapes). Removes ~476 lines of duplicated kernel body.
…ed kernel
A per-shape sweep over the masked kernel's (chunk, group_m, num_xcd) shows it
matches or beats the old persistent scf.for wgrad on every MoE shape (worst 1.0%,
up to +13% on long-contraction). Replace the wgrad autotune with a 3-candidate
masked sweep {(8,4,8),(8,0,8),(4,4,8)} and delete _compile_grouped_tn_wgrad_persistent,
_wgrad_loop_body_pipe, and _wgrad_compile_cfg. wgrad is now one kernel per layout.
num_cu is ignored for wgrad (the masked kernel uses a full G*tiles grid and can't
reserve CUs for comm-overlap). Verified SNR 55.6 + TFLOPS on dsv3/qwen/gpt.
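A per-shape sweep over a fixed candidate list could look roughly like the sketch below. The candidate tuples come from the commit message; `autotune_wgrad`, `run_with_cfg`, the cache dict, and the injectable `timer` are hypothetical names, and a real GPU harness would also need device synchronisation around the timed region.

```python
import time

# (chunk, group_m, num_xcd) candidates listed in the commit message.
WGRAD_CANDIDATES = [(8, 4, 8), (8, 0, 8), (4, 4, 8)]

_best_cfg_cache = {}  # shape_key -> winning config (hypothetical cache)


def autotune_wgrad(shape_key, run_with_cfg, candidates=WGRAD_CANDIDATES,
                   iters=3, timer=time.perf_counter):
    """Time each candidate once per shape and remember the fastest."""
    if shape_key in _best_cfg_cache:
        return _best_cfg_cache[shape_key]
    best_cfg, best_time = None, float("inf")
    for cfg in candidates:
        run_with_cfg(cfg)  # warm-up run, excluded from the measurement
        start = timer()
        for _ in range(iters):
            run_with_cfg(cfg)
        elapsed = (timer() - start) / iters
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    _best_cfg_cache[shape_key] = best_cfg
    return best_cfg
```

Later calls with the same `shape_key` return the cached config without re-running any candidate.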
… > 2^31 elems / > 4GB)
The buffer path capped addressing at int32 (flat shape pack at 2^31 elems; a single
32-bit-num_records SRD at 4GB), silently corrupting large MoE GEMMs. Per-tile i64
re-base on both sides:
- Output C: StoreCPerTensor + StoreCPerTensorCShuffle re-base per row-band via
extract_base_index + create_buffer_resource_from_addr, small i32 in-tile offset.
- Inputs A/B: make_fp8_buffer_tensor_rebased folds each tile's huge element base into
the i64 SRD base (readfirstlane-pinned so the SRD stays scalar), keeping the buffer
offset small int32; pass A/B full-rank. NT/NN fully covered; wgrad folds m_start
(per-group M_g*OUT_{M,N} stays int32). Removes the now-superseded _make_fp8_buf_nr.
Verified gfx950: NT/NN/wgrad 28.5dB; A=4.5e9 + C=2.26e9 (both > 2^31) 28.5dB; perf
within noise of baseline. (Dense gemm shares the StoreCPerTensor; its callers/int64
inputs are in the following commit.)
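The re-base idea can be checked with plain integer arithmetic. The sketch below is a simplified model, not the kernel code: `rebase_offset` and `band_rows` are illustrative names, and the row-band granularity of 256 is an assumption. The large part of the element offset moves into a 64-bit base while the in-band remainder stays well inside int32.

```python
INT32_MAX = 2**31 - 1


def rebase_offset(row, col, row_stride, band_rows):
    """Split a flat element offset into (64-bit band base, small in-band offset)."""
    band_start = (row // band_rows) * band_rows
    base = band_start * row_stride                    # may exceed int32; lives in the i64 base
    in_band = (row - band_start) * row_stride + col   # bounded by band_rows * row_stride
    return base, in_band


row, col, n = 300_000, 17, 8192
flat = row * n + col                                  # 2_457_600_017, past the int32 limit
base, off = rebase_offset(row, col, n, band_rows=256)
```

Here `base + off` reproduces `flat`, while `off` alone is about 1.8e6, so a 32-bit buffer offset is sufficient once the base has been folded into the descriptor.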
…lit launch cache
Dense int64 input addressing + a launch-cache optimisation (output C re-based 2D via the shared StoreCPerTensor from the grouped commit).
- int64 inputs: NT A[M,K]/B_T[N,K] K-contiguous -> fold the per-tile base into the i64 SRD base (unbounded). NN-B[K,N], TN A[K,M]+B[K,N] are contraction-strided -> fold the column base + compute the K-traversal in i64; single 32-bit SRD caps these at 4GB, can_handle declines > 4GB to fallback. _as_i8_flat passes full-rank int8. (A per-K-iter SRD-base advance removes the cap but a clean graph-replay min-of-8 bench showed ~2% on NN/TN, so it is not used; the 4GB cap covers every bench_gemm_turbo.py shape, max 3.49e9.)
- mode-split lazy-compiled launch cache: eager runs a one-time flyc.compile'd object (skips @flyc.jit's per-call drift-check + arg-key rebuild), capture runs the raw closure.
Verified gfx950 (graph-replay min-of-8): NT/NN/TN within +-0.7% of baseline; eager+graph 28.5dB; inputs to 4GB (NT 4.3GB, NN-B/TN 3.2GB) 28.5dB.
Optimizes the grouped fp8 backward (dgrad NN, variable-K wgrad TN) for MoE shapes where the balanced-tile assumption breaks, and fixes two correctness/perf bugs found in review.
- dgrad NN M-branch: small-output-M shapes underfill the device with the 256-row M-tile (few N-tiles when N = fwd-K). When the 128-row tiling fits one CU wave (G*ceil(pm/128)*ceil(N/256) <= num_cus) use a single BLOCK_M=128 config (+5..31% over every bm256 swizzle, boundary-swept); else the existing bm256 sweep.
- wgrad persist vs masked gated on PER-GROUP contraction (m_total/G <= 1536, not m_total) so high-G MoE with short per-expert M keeps the persistent kernel.
- wgrad skew load-balance (band-cyclic): the masked grid was group-contiguous, so an unbalanced token split let the largest group's tiles dominate wall-time (realistic 2:1 skew lost ~20%, heavy skew ~0.4x). Dispatch a group_m M-band per group before switching group -> every group size stays in flight, group_m B-stripe L2 reuse kept (balanced-neutral). Skew now flat ~0.86x balanced; 30:1 wgrad 1162 -> 1592 TF. On by default (env WG_INTERLEAVE=0 to disable).
- persist wgrad i64 SRD rebase: cumulative m_start/m_end*OUT overflowed int32 for large-G MoE (e.g. 256 experts, OUT_M=8192 -> m_total*OUT ~ 3.2e9); fold into the i64 base + per-group num_records (same scheme as masked). Verified SNR 55.6 dB.
Kernel-level vs Triton (MI355X, B=8 balanced): fwd 1.19x, dgrad 1.14x, wgrad 1.91x; under token skew wgrad matches Triton's robustness (~0.86x balanced).
Comments trimmed to <3 lines, no result caching, no dead code.
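The three dispatch rules in this commit can be written down as small predicates. This is a sketch under stated assumptions: `pm` is read as the per-group row count, the function names are invented for illustration, and `band_cyclic_order` models only the M dimension of the tile grid.

```python
from math import ceil


def pick_dgrad_block_m(num_groups, per_group_m, n, num_cus):
    """BLOCK_M=128 when the 128-row tiling fits one CU wave, else the BLOCK_M=256 sweep."""
    tiles_128 = num_groups * ceil(per_group_m / 128) * ceil(n / 256)
    return 128 if tiles_128 <= num_cus else 256


def wgrad_uses_persistent(m_total, num_groups, threshold=1536):
    """Gate on the per-group contraction length, not on m_total."""
    return m_total / num_groups <= threshold


def band_cyclic_order(tiles_per_group, group_m):
    """Round-robin over groups, emitting one band of group_m M-tiles per visit."""
    order, cursor = [], [0] * len(tiles_per_group)
    while any(c < t for c, t in zip(cursor, tiles_per_group)):
        for g, total in enumerate(tiles_per_group):
            band_end = min(cursor[g] + group_m, total)
            order.extend((g, m) for m in range(cursor[g], band_end))
            cursor[g] = band_end
    return order
```

For example, `band_cyclic_order([4, 2], 2)` visits group 0, then group 1, then group 0 again, so the larger group's tiles are spread across the schedule instead of being issued back to back.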
The masked and persistent wgrad kernels each inlined the same per-tile decode (group/block_m/block_n with band-cyclic / group_n band / group_m cluster / row-major) and the same i64 SRD rebase. Factor both into _wgrad_block_mn and _wgrad_rebase (single source of truth). Behavior unchanged: masked keeps the band-cyclic skew interleave (interleave=True), persistent keeps group_n/group_m (interleave=False); each kernel's own K-loop body is untouched. SNR 55.6 dB, perf and skew/large-G behavior unchanged.
- Rename primus_turbo/flydsl/utils/fp8_gemm_helper.py -> gemm_helper.py
- Update imports in gemm_fp8_kernel.py and gemm_fp8_grouped_kernel.py
- gemm_helper.py trims make_fp8_buffer_tensor_rebased + _as_index (now inlined / superseded by the int64 rebase path in the kernel)
- Rebase onto origin/main (7 commits: #383 scale pad bug fix, #366 AITER MXFP4 preshuffle fast path, #381 meta fix, #377 USP attention, #349 mxfp8 triton grouped gemm, #382 build deps, #380 ci skip)
Notable from main worth tracking:
- #383: scale pad slots now fill with 0 instead of E8M0_EXPONENT_BIAS(127) for both mxfp4 and mxfp8; may affect FlyDSL kernel scale correctness if kernel reads beyond valid scale range
- #366: AITER MXFP4 fast path adds K_MULTIPLE=32 guard, removes enable_preshuffle() in favor of use_preshuffle flag in Float4QuantConfig
Add BackendType.FLYDSL to three test parametrize lists:
- test_grouped_gemm_fp8_tensorwise
- test_grouped_gemm_fp8_tensorwise_deterministic
- test_grouped_gemm_fp8_tensorwise_quantized_tensor
Each gets a gfx950-only skip guard matching the pattern in test_gemm_fp8.py. FlyDSL backend is TENSORWISE-only (per can_handle), so no changes needed for ROWWISE/BLOCKWISE/MX_BLOCKWISE tests.
… NT/NN kernels
The non-persistent mk() factory in _autotune_np_dispatch passed nt_vmcnt=-1 to both _compile_grouped_nt and _compile_grouped_nn, suppressing the s_waitcnt vmcnt(N) instruction at the end of each K-loop iteration. Without this fence, the next iteration can start reading from the LDS ping-pong buffers (a_next/b_next) before the G2S buffer_load_lds operations from the current iteration have written to them, causing a data hazard and silent numerical corruption.
Symptom: NT forward kernel gave SNR -0.03 dB (vs 28.5 dB expected); NN dgrad with BLOCK_M=128 (small-M path) gave SNR -3 dB. NN BLOCK_M=256 happened to work because more MFMA computation hides the G2S latency, but this was not guaranteed.
Fix: use nt_vmcnt=3 (same value as the persistent kernel, already verified correct) in both NT and NN non-persistent factories.
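For readers unfamiliar with the SNR figures quoted here, the sketch below uses the standard signal-to-noise definition in decibels. It is an assumption that the PR's test harness uses this exact formula; `snr_db` is an illustrative name. An output unrelated to the reference lands near 0 dB or below, which matches the corrupted readings described above.

```python
import math


def snr_db(reference, result):
    """10 * log10(signal power / error power), assumed to match the PR's SNR metric."""
    signal = sum(r * r for r in reference)
    noise = sum((r - x) ** 2 for r, x in zip(reference, result))
    if noise == 0:
        return math.inf       # bit-exact match
    if signal == 0:
        return -math.inf      # all-zero reference with a non-zero error
    return 10.0 * math.log10(signal / noise)
```

As a quick check, a 10% error on a single non-zero element gives 20 dB, and returning all zeros for a non-zero reference gives exactly 0 dB.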
96278c5 to 16c3293
57fe12f to 24f55f4
Restore _as_index and make_fp8_buffer_tensor_rebased that were stripped in bfd4510 but are still imported and used by the grouped kernel.
Fix two bugs in StoreCPerTensor and StoreCPerTensorCShuffle:
1. buffer_store(mask=False) redirects voffset to 0x7FFFFFFF rather than using a HW predicate. When nrec was clamped to 0xFFFFFFFF, the CDNA HW OOB check passed (0x7FFFFFFF < 0xFFFFFFFF) and the invalid store fired at band_base+2GB. Fix: cap nrec at 0x7FFFFFFF so the sentinel is always >= nrec. Valid tile offsets are at most ~30 MB, well within the 2 GB cap.
2. band_base and nrec derive from the group scan (arith.select chain over group_offs buffer loads), which the compiler's divergence analysis marks as divergent. Without _readfirstlane_i32 the output SRD lands in VGPRs and every buffer_store is wrapped in a waterfall loop. Fix: pin band_base and nrec via _readfirstlane_i32 before passing to create_buffer_resource_from_addr, mirroring the input SRD treatment in make_fp8_buffer_tensor_rebased.
Fix misleading group_offs comments: the tensor is int64 [G+1] passed as an int32 view; _load_go reads only the low word at i32[2*idx] (offsets are < 2^31 so the high word is always 0).
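Both fixes can be reproduced with a toy model in plain Python. The bounds check below is deliberately simplified to "offset < num_records" (the real hardware rule has more terms), and the function names are illustrative, not the kernel's. The second helper shows why reading `i32[2*idx]` from a little-endian int64 array yields the offset when it is below 2^31.

```python
import struct

SENTINEL = 0x7FFFFFFF  # offset substituted for masked-off lanes, per the commit message


def store_fires(voffset, num_records):
    """Simplified bounds check: the store lands only when offset < num_records."""
    return voffset < num_records


def clamp_num_records(nrec):
    """Cap num_records so the sentinel offset is never in range."""
    return min(nrec, SENTINEL)


def load_group_off_low_word(raw_int64_bytes, idx):
    """Read group_offs[idx] through an int32 view: the low word sits at i32[2*idx]."""
    words = struct.unpack(f"<{len(raw_int64_bytes) // 4}i", raw_int64_bytes)
    return words[2 * idx]


masked_store_before_fix = store_fires(SENTINEL, 0xFFFFFFFF)                    # True: stray store
masked_store_after_fix = store_fires(SENTINEL, clamp_num_records(0xFFFFFFFF))  # False: dropped
raw = struct.pack("<3q", 0, 4096, 10_000)  # int64 group_offs with G = 2
```

A valid in-tile offset of about 30 MB still passes the check after the cap, so only the masked-off lanes are affected.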
24f55f4 to ccb450e
RuibinCheung requested changes on Jun 18, 2026
…2^32 cap
The contraction-traversal operands (NN B[K,N]; TN A[K,M] & B[K,N]; grouped NN
B and wgrad A & B) ride their K-stride offset on the buffer instruction's 32-bit
soffset, so the span k*BLOCK_K*{N,M} wraps once the operand exceeds ~4 GB fp8 --
the dispatcher declined these to a Triton fallback.
Add an i64-traverse mode to G2SLoader: instead of a fixed SRD base + 32-bit
soffset, fold the per-load K-offset into the i64 descriptor base (re-base via
make_fp8_buffer_tensor_rebased, soffset 0). The foldable operands (NT both,
NN A) are unchanged. Threaded through _compile_dense_nn/_tn (+ autotune
dispatch + wrapper auto-select on K*N / K*M >= 2^32) and the grouped NN / wgrad
compile layer; _wgrad_rebase now also returns the per-operand re-base tuples.
NT needs nothing: both operands are K-contiguous, so the per-tile base folds
once into the i64 SRD and the per-load offset stays ~128*K (no realistic cap).
Verified on MI355X: oversized correctness NN k*n=4.33e9 -> 73.1 dB, TN
k*m=4.43e9 -> 76.0 dB. Slowdown of the i64 path on in-cap shapes (same config)
is ~2-5% on compute-bound, up to ~11% on small memory-bound -- so dispatch uses
i64 only at/above 2^32 and keeps the cheaper int32 path below.
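The wrap-around and the dispatch threshold can be illustrated numerically. In this sketch fp8 is one byte per element, so element counts equal byte counts; `BLOCK_K = 128` and the function names are assumptions for illustration, not values taken from the kernel source.

```python
TWO_32 = 1 << 32


def wrapped_soffset(k_iter, block_k, stride):
    """K-stride offset as a 32-bit soffset would see it (fp8: one byte per element)."""
    return (k_iter * block_k * stride) % TWO_32


def use_i64_traverse(k, stride):
    """Dispatch rule from the commit message: take the i64 path at or above 2^32."""
    return k * stride >= TWO_32


BLOCK_K = 128                        # assumed K-tile depth
true_offset = 512 * BLOCK_K * 65536  # exactly 2**32
seen_by_hw = wrapped_soffset(512, BLOCK_K, 65536)
```

With N = 65536, the 512th K-step lands exactly on 2^32 and wraps to offset 0, silently aliasing the first tile. Operands below the threshold keep the cheaper int32 path.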
Comment on lines +314 to +315
    if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
        pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +436 to +437
    if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
        pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +679 to +680
    if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
        pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +430 to +434
    """FlyDSL fp8 grouped GEMM backend (gfx950, per-tensor / TENSORWISE only).

    M-grouped operator: forward (trans_b=True, NT) + dgrad (trans_b=False, NN).
    Uses the FlyDSL mfma_f32_16x16x128_f8f6f4 kernel (gfx950-only).
    """
Comment on lines +1589 to +1593
        group_offs: "torch.Tensor",
        trans_b: bool = False,
        out_dtype=torch.bfloat16,
        num_cu: "int | None" = -1,
    ) -> "torch.Tensor":
Comment on lines +2187 to +2190
        group_offs: "torch.Tensor",
        out_dtype=torch.bfloat16,
        num_cu: "int | None" = -1,
    ) -> "torch.Tensor":
Description
Adds a FlyDSL fp8 grouped GEMM backend for Primus-Turbo's grouped GEMM
(M-grouped forward/dgrad + variable-K wgrad), with full int64 addressing and
per-shape autotuning. On MI355X (gfx950) it beats the existing Triton backend
across all MoE shapes (kernel-level geomean: fwd 1.27×, dgrad 1.30×, wgrad 1.22×).
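For reference, a geometric mean of per-shape speedup ratios is computed as below. This is a generic sketch; the input values in the usage line are placeholders chosen for easy arithmetic, not measurements from this PR.

```python
import math


def geomean(ratios):
    """Geometric mean of positive speedup ratios (FlyDSL time vs Triton time per shape)."""
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


placeholder = geomean([1.0, 2.0, 4.0])  # placeholder inputs, not PR data
```

Unlike an arithmetic mean, the geometric mean treats a 2× speedup and a 0.5× slowdown as cancelling out, which is why it is the usual summary for ratios.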
Type of change
Changes
- … variable-K wgrad (TN); wired into the grouped GEMM backend dispatch.
- `num_cu` routing: full-device non-persistent (nt8w/nn8w) by default, persistent grid (capped to `num_cu`) for comm-overlap.
- `BLOCK_M=128` single config for small-output-M shapes that underfill the device (+16–31%), `BLOCK_M=256` swizzle sweep otherwise.
- … `m_total/G ≤ 1536`), fixing high-G MoE that the old `m_total` gate mis-routed.
Performance: FlyDSL vs Triton (MI355X / gfx950)
Hardware: AMD Instinct MI355X (gfx950)
GPUs tested: GPU 6 and GPU 7 (results cross-validated, variance < 2%)
Framework: Primus-Turbo (dev/kyle/flydsl_grp_gemm), rocm/primus:v26.3
Metric: kernel TFLOPS, pre-quantised fp8 inputs, warmup=200, median over 5×30 iters
B=8 balanced, tensorwise fp8, kernel-level (pure GEMM) warm-mean TFLOPS, SNR≥20.
Models: deepseek-v3 (h7168, i2048) / qwen3-235b (h4096, i1536) / gpt-oss (h2880, i2880); up-proj N = 2·inter.
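The shape conventions above translate into GEMM dimensions and a TFLOPS figure roughly as follows. The `(hidden, inter)` pairs are copied from the model list; mapping hidden to the contraction dimension K of the up-projection, and the `2*M*N*K` FLOP count, are the usual conventions and are assumed here rather than taken from the benchmark script.

```python
# (hidden, inter) per model, as listed above.
MODELS = {
    "deepseek-v3": (7168, 2048),
    "qwen3-235b": (4096, 1536),
    "gpt-oss": (2880, 2880),
}


def up_proj_dims(hidden, inter):
    """Return (N, K) for the up-projection: N = 2*inter, K = hidden (K mapping assumed)."""
    return 2 * inter, hidden


def gemm_tflops(m_total, n, k, seconds):
    """Kernel TFLOPS from the conventional 2*M*N*K multiply-add count."""
    return 2.0 * m_total * n * k / seconds / 1e12


dsv3_n, dsv3_k = up_proj_dims(*MODELS["deepseek-v3"])  # (4096, 7168)
```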
Summary (geomean over 12 MoE shapes)
fwd (kernel TFLOPS)
dgrad (kernel TFLOPS)
wgrad (kernel TFLOPS)
Checklist: