Skip to content

[feat] Add flydsl based grouped gemm#384

Merged
RuibinCheung merged 13 commits into
mainfrom
dev/kyle/flydsl_grp_gemm
Jun 22, 2026
Merged

[feat] Add flydsl based grouped gemm#384
RuibinCheung merged 13 commits into
mainfrom
dev/kyle/flydsl_grp_gemm

Conversation

@kyle-256

@kyle-256 kyle-256 commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Description

Adds a FlyDSL fp8 grouped GEMM backend for Primus-Turbo's grouped GEMM
(M-grouped forward/dgrad + variable-K wgrad), with full int64 addressing and
per-shape autotuning. On MI355X (gfx950) it beats the existing Triton backend
across all MoE shapes (kernel-level geomean: fwd 1.27×, dgrad 1.30×, wgrad 1.22×).

Type of change

  • Documentation change
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change
  • Infra/Build change
  • Code refactoring

Changes

  • FlyDSL fp8 grouped GEMM kernels: M-grouped forward (NT), dgrad (NN), and
    variable-K wgrad (TN); wired into the grouped GEMM backend dispatch.
  • Merged per-layout persistent / non-persistent NT & NN kernels with num_cu
    routing: full-device non-persistent (nt8w/nn8w) by default, persistent grid
    (capped to num_cu) for comm-overlap.
  • Per-shape online autotune (balanced-token-timed, hysteresis-guarded):
    • dgrad NN M-branch: BLOCK_M=128 single config for small-output-M shapes that
      underfill the device (+16–31%), BLOCK_M=256 swizzle sweep otherwise.
    • wgrad persist-vs-masked gated on per-group contraction (m_total/G ≤ 1536),
      fixing high-G MoE that the old m_total gate mis-routed.

Performance — FlyDSL vs Triton (MI355X / gfx950)

Hardware: AMD Instinct MI355X (gfx950)
GPUs tested: GPU 6 and GPU 7 (results cross-validated, variance < 2%)
Framework: Primus-Turbo (dev/kyle/flydsl_grp_gemm), rocm/primus:v26.3
Metric: kernel TFLOPS, pre-quantised fp8 inputs, warmup=200, median over 5×30 iters

B=8 balanced, tensorwise fp8, kernel-level (pure GEMM) warm-mean TFLOPS, SNR≥20.
Models: deepseek-v3 (h7168, i2048) / qwen3-235b (h4096, i1536) / gpt-oss (h2880, i2880); up-proj N = 2·inter.

Summary (geomean over 12 MoE shapes)

op FlyDSL Triton fly/tri
fwd 2357 1861 1.27×
dgrad 2411 1863 1.29×
wgrad 2136 1762 1.21×

fwd (kernel TFLOPS)

shape B M N K FlyDSL Triton fly/tri
deepseek-up 8 2048 4096 7168 2688 2369 1.13×
deepseek-up 8 4096 4096 7168 2693 2429 1.11×
deepseek-down 8 2048 7168 2048 2349 1889 1.24×
deepseek-down 8 4096 7168 2048 2208 1949 1.13×
qwen235b-up 8 2048 3072 4096 2694 1770 1.52×
qwen235b-up 8 4096 3072 4096 2548 2172 1.17×
qwen235b-down 8 2048 4096 1536 2054 1415 1.45×
qwen235b-down 8 4096 4096 1536 2138 1750 1.22×
gpt_oss-up 8 2048 5760 2880 2331 1795 1.30×
gpt_oss-up 8 4096 5760 2880 2257 1889 1.19×
gpt_oss-down 8 2048 2880 2880 2189 1379 1.59×
gpt_oss-down 8 4096 2880 2880 2258 1831 1.23×
geomean 2357 1861 1.27×

dgrad (kernel TFLOPS)

shape B M N K FlyDSL Triton fly/tri
deepseek-up 8 2048 4096 7168 2383 2165 1.10×
deepseek-up 8 4096 4096 7168 2392 2237 1.07×
deepseek-down 8 2048 7168 2048 2634 2351 1.12×
deepseek-down 8 4096 7168 2048 2652 2340 1.13×
qwen235b-up 8 2048 3072 4096 2507 1309 1.92×
qwen235b-up 8 4096 3072 4096 2306 2106 1.09×
qwen235b-down 8 2048 4096 1536 2282 1708 1.34×
qwen235b-down 8 4096 4096 1536 2622 1792 1.46×
gpt_oss-up 8 2048 5760 2880 2469 1406 1.76×
gpt_oss-up 8 4096 5760 2880 2376 2141 1.11×
gpt_oss-down 8 2048 2880 2880 2159 1406 1.54×
gpt_oss-down 8 4096 2880 2880 2216 1843 1.20×
geomean 2411 1863 1.29×

wgrad (kernel TFLOPS)

shape B M N K FlyDSL Triton fly/tri
deepseek-up 8 2048 4096 7168 2086 1857 1.12×
deepseek-up 8 4096 4096 7168 2439 2162 1.13×
deepseek-down 8 2048 7168 2048 1998 1783 1.12×
deepseek-down 8 4096 7168 2048 2377 2099 1.13×
qwen235b-up 8 2048 3072 4096 2183 1831 1.19×
qwen235b-up 8 4096 3072 4096 2425 2121 1.14×
qwen235b-down 8 2048 4096 1536 2056 1436 1.43×
qwen235b-down 8 4096 4096 1536 2245 1414 1.59×
gpt_oss-up 8 2048 5760 2880 1947 1526 1.28×
gpt_oss-up 8 4096 5760 2880 2123 1840 1.15×
gpt_oss-down 8 2048 2880 2880 1776 1536 1.16×
gpt_oss-down 8 4096 2880 2880 2085 1756 1.19×
geomean 2136 1762 1.21×

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes (SNR-gated correctness, all 288 shapes 0 ERROR/FAIL)

Copilot AI review requested due to automatic review settings June 16, 2026 11:45

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new FlyDSL-based FP8 backend for grouped GEMM (forward/dgrad and variable‑K wgrad) and extends FlyDSL dense FP8 GEMM infrastructure to support int64 SRD re-basing for large-tensor addressing and per-shape autotuning/caching behavior.

Changes:

  • Introduces FlyDSL FP8 grouped GEMM kernels (NT/NN) and variable‑K wgrad (TN) with per-shape autotune and CU-capped persistent routing.
  • Wires FlyDSL as a selectable backend for FP8 grouped GEMM dispatch (and variable‑K dispatch).
  • Updates FlyDSL dense FP8 GEMM primitives to support i64 SRD re-basing, updated loaders, and capture-vs-eager launch mode splitting.

Reviewed changes

Copilot reviewed 5 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py Adds FlyDSL backend entries for FP8 grouped GEMM and variable‑K grouped GEMM dispatch.
primus_turbo/pytorch/kernels/gemm/gemm_fp8_impl.py Tightens FlyDSL dense GEMM eligibility with additional i64 SRD re-base size caps for traversal operands.
primus_turbo/flydsl/utils/fp8_gemm_helper.py Adds SRD re-basing buffer helpers, output-store changes, and loader extensions (base offsets, readfirstlane pinning).
primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py New FlyDSL grouped GEMM kernel implementations + autotune/caching logic for fwd/dgrad and wgrad.
primus_turbo/flydsl/grouped_gemm/__init__.py Package marker for FlyDSL grouped GEMM module.
primus_turbo/flydsl/gemm/gemm_fp8_kernel.py Updates dense FlyDSL GEMM to use i64 re-basing helpers and adds eager-vs-capture launch handling.
Comments suppressed due to low confidence (1)

primus_turbo/flydsl/utils/fp8_gemm_helper.py:57

  • _readfirstlane_i32 (and its docstring) indicates an i32 SGPR pinning helper, but it’s being called with i64 values here (base and nr are cast to i64). Please clarify and/or enforce the intended bit-width: either implement a 64-bit-safe readfirstlane helper for addresses/num_records, or keep this helper strictly i32 and adjust the callers accordingly. As written, it’s unclear that >4GB addressing is safe.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread primus_turbo/flydsl/gemm/gemm_fp8_kernel.py Outdated
Comment thread primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py
Comment thread primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py
@kyle-256 kyle-256 force-pushed the dev/kyle/flydsl_grp_gemm branch from 5ac0623 to 2906e8d Compare June 16, 2026 13:32
Copilot AI review requested due to automatic review settings June 16, 2026 13:44

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 5 out of 6 changed files in this pull request and generated 2 comments.

Comment thread primus_turbo/pytorch/kernels/grouped_gemm/grouped_gemm_fp8_impl.py
Comment thread primus_turbo/flydsl/gemm/gemm_fp8_kernel.py
kyle-256 and others added 11 commits June 17, 2026 08:20
…ad(TN)

Adds non-persistent + persistent grouped FP8 GEMM kernels on gfx950
(mfma_f32_16x16x128_f8f6f4): per-group on-device group-major scan, L2-reuse
tile swizzle (XCD-remap / group_m / group_n band), per-shape online autotune,
vectorized CShuffle store, and CUDA-graph/eager dispatch mode-split (eager
flyc.compile skips @flyc.jit per-call drift-check; graph keeps the raw closure).
Reusable primitives consolidated into flydsl/utils/fp8_gemm_helper.py (renamed
from gemm_helper.py, shared with the dense kernel).
Routes per-tensor FP8 grouped GEMM (fwd/dgrad) to the FlyDSL backend via
GroupedGEMMFP8FlyDSLBackend; num_cu<=0 -> non-persistent full-device kernels,
num_cu>0 -> persistent (reserve CUs for comm-overlap).
…N kernels

Fold _compile_grouped_nt_8w / _compile_grouped_nn_8w into _compile_grouped_nt /
_compile_grouped_nn behind a `persistent` flag: const_expr selects the outer
scf.for tile loop (persistent, cap_cu reserves CUs) vs one-tile-per-WG + s_endpgm
over-launch guard (non-persistent, full-device default), the prelude barrier
(unconditional vs divergent wave_m==1), and the launch grid. Per-mode IR is
unchanged, so SNR and kernel TFLOPS are identical (verified on the dsv3/qwen/gpt
MoE shapes). Removes ~476 lines of duplicated kernel body.
…ed kernel

A per-shape sweep over the masked kernel's (chunk, group_m, num_xcd) shows it
matches or beats the old persistent scf.for wgrad on every MoE shape (worst 1.0%,
up to +13% on long-contraction). Replace the wgrad autotune with a 3-candidate
masked sweep {(8,4,8),(8,0,8),(4,4,8)} and delete _compile_grouped_tn_wgrad_persistent,
_wgrad_loop_body_pipe, and _wgrad_compile_cfg. wgrad is now one kernel per layout.
num_cu is ignored for wgrad (the masked kernel uses a full G*tiles grid and can't
reserve CUs for comm-overlap). Verified SNR 55.6 + TFLOPS on dsv3/qwen/gpt.
… > 2^31 elems / > 4GB)

The buffer path capped addressing at int32 (flat shape pack at 2^31 elems; a single
32-bit-num_records SRD at 4GB), silently corrupting large MoE GEMMs. Per-tile i64
re-base on both sides:
- Output C: StoreCPerTensor + StoreCPerTensorCShuffle re-base per row-band via
  extract_base_index + create_buffer_resource_from_addr, small i32 in-tile offset.
- Inputs A/B: make_fp8_buffer_tensor_rebased folds each tile's huge element base into
  the i64 SRD base (readfirstlane-pinned so the SRD stays scalar), keeping the buffer
  offset small int32; pass A/B full-rank. NT/NN fully covered; wgrad folds m_start
  (per-group M_g*OUT_{M,N} stays int32). Removes the now-superseded _make_fp8_buf_nr.

Verified gfx950: NT/NN/wgrad 28.5dB; A=4.5e9 + C=2.26e9 (both > 2^31) 28.5dB; perf
within noise of baseline. (Dense gemm shares the StoreCPerTensor; its callers/int64
inputs are in the following commit.)
…lit launch cache

Dense int64 input addressing + a launch-cache optimisation (output C re-based 2D via
the shared StoreCPerTensor from the grouped commit).
- int64 inputs: NT A[M,K]/B_T[N,K] K-contiguous -> fold the per-tile base into the i64
  SRD base (unbounded). NN-B[K,N], TN A[K,M]+B[K,N] are contraction-strided -> fold the
  column base + compute the K-traversal in i64; single 32-bit SRD caps these at 4GB,
  can_handle declines > 4GB to fallback. _as_i8_flat passes full-rank int8. (A per-K-iter
  SRD-base advance removes the cap but a clean graph-replay min-of-8 bench showed ~2% on
  NN/TN, so it is not used; the 4GB cap covers every bench_gemm_turbo.py shape, max 3.49e9.)
- mode-split lazy-compiled launch cache: eager runs a one-time flyc.compile'd object
  (skips @flyc.jit's per-call drift-check + arg-key rebuild), capture runs the raw closure.

Verified gfx950 (graph-replay min-of-8): NT/NN/TN within +-0.7% of baseline; eager+graph
28.5dB; inputs to 4GB (NT 4.3GB, NN-B/TN 3.2GB) 28.5dB.
Optimizes the grouped fp8 backward (dgrad NN, variable-K wgrad TN) for MoE
shapes where the balanced-tile assumption breaks, and fixes two correctness/
perf bugs found in review.

- dgrad NN M-branch: small-output-M shapes underfill the device with the 256-row
  M-tile (few N-tiles when N = fwd-K). When the 128-row tiling fits one CU wave
  (G*ceil(pm/128)*ceil(N/256) <= num_cus) use a single BLOCK_M=128 config (+5..31%
  over every bm256 swizzle, boundary-swept); else the existing bm256 sweep.
- wgrad persist vs masked gated on PER-GROUP contraction (m_total/G <= 1536, not
  m_total) so high-G MoE with short per-expert M keeps the persistent kernel.
- wgrad skew load-balance (band-cyclic): the masked grid was group-contiguous, so
  an unbalanced token split let the largest group's tiles dominate wall-time
  (realistic 2:1 skew lost ~20%, heavy skew ~0.4x). Dispatch a group_m M-band per
  group before switching group -> every group size stays in flight, group_m
  B-stripe L2 reuse kept (balanced-neutral). Skew now flat ~0.86x balanced; 30:1
  wgrad 1162 -> 1592 TF. On by default (env WG_INTERLEAVE=0 to disable).
- persist wgrad i64 SRD rebase: cumulative m_start/m_end*OUT overflowed int32 for
  large-G MoE (e.g. 256 experts, OUT_M=8192 -> m_total*OUT ~ 3.2e9); fold into the
  i64 base + per-group num_records (same scheme as masked). Verified SNR 55.6 dB.

Kernel-level vs Triton (MI355X, B=8 balanced): fwd 1.19x, dgrad 1.14x, wgrad
1.91x; under token skew wgrad matches Triton's robustness (~0.86x balanced).
Comments trimmed to <3 lines, no result caching, no dead code.
The masked and persistent wgrad kernels each inlined the same per-tile decode
(group/block_m/block_n with band-cyclic / group_n band / group_m cluster / row-
major) and the same i64 SRD rebase. Factor both into _wgrad_block_mn and
_wgrad_rebase (single source of truth). Behavior unchanged: masked keeps the
band-cyclic skew interleave (interleave=True), persistent keeps group_n/group_m
(interleave=False); each kernel's own K-loop body is untouched. SNR 55.6 dB,
perf and skew/large-G behavior unchanged.
- Rename primus_turbo/flydsl/utils/fp8_gemm_helper.py -> gemm_helper.py
- Update imports in gemm_fp8_kernel.py and gemm_fp8_grouped_kernel.py
- gemm_helper.py trims make_fp8_buffer_tensor_rebased + _as_index (now
  inlined / superseded by the int64 rebase path in the kernel)
- Rebase onto origin/main (7 commits: #383 scale pad bug fix, #366 AITER
  MXFP4 preshuffle fast path, #381 meta fix, #377 USP attention, #349
  mxfp8 triton grouped gemm, #382 build deps, #380 ci skip)

Notable from main worth tracking:
- #383: scale pad slots now fill with 0 instead of E8M0_EXPONENT_BIAS(127)
  for both mxfp4 and mxfp8 — may affect FlyDSL kernel scale correctness
  if kernel reads beyond valid scale range
- #366: AITER MXFP4 fast path adds K_MULTIPLE=32 guard, removes
  enable_preshuffle() in favor of use_preshuffle flag in Float4QuantConfig
Add BackendType.FLYDSL to three test parametrize lists:
- test_grouped_gemm_fp8_tensorwise
- test_grouped_gemm_fp8_tensorwise_deterministic
- test_grouped_gemm_fp8_tensorwise_quantized_tensor

Each gets a gfx950-only skip guard matching the pattern in test_gemm_fp8.py.
FlyDSL backend is TENSORWISE-only (per can_handle), so no changes needed for
ROWWISE/BLOCKWISE/MX_BLOCKWISE tests.
… NT/NN kernels

The non-persistent mk() factory in _autotune_np_dispatch passed nt_vmcnt=-1
to both _compile_grouped_nt and _compile_grouped_nn, suppressing the
s_waitcnt vmcnt(N) instruction at the end of each K-loop iteration.

Without this fence, the next iteration can start reading from the LDS
ping-pong buffers (a_next/b_next) before the G2S buffer_load_lds
operations from the current iteration have written to them, causing a
data hazard and silent numerical corruption.

Symptom: NT forward kernel gave SNR -0.03 dB (vs 28.5 dB expected);
NN dgrad with BLOCK_M=128 (small-M path) gave SNR -3 dB. NN BLOCK_M=256
happened to work because more MFMA computation hides the G2S latency,
but this was not guaranteed.

Fix: use nt_vmcnt=3 (same value as the persistent kernel, already
verified correct) in both NT and NN non-persistent factories.
@kyle-256 kyle-256 force-pushed the dev/kyle/flydsl_grp_gemm branch 2 times, most recently from 96278c5 to 16c3293 Compare June 17, 2026 10:02
Copilot AI review requested due to automatic review settings June 17, 2026 11:46

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 7 changed files in this pull request and generated 3 comments.

Comment thread primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py
Comment thread primus_turbo/flydsl/utils/gemm_helper.py
Comment thread primus_turbo/flydsl/gemm/gemm_fp8_kernel.py
Copilot AI review requested due to automatic review settings June 17, 2026 13:21
@kyle-256 kyle-256 force-pushed the dev/kyle/flydsl_grp_gemm branch from 57fe12f to 24f55f4 Compare June 17, 2026 13:21
Restore _as_index and make_fp8_buffer_tensor_rebased that were stripped
in bfd4510 but are still imported and used by the grouped kernel.

Fix two bugs in StoreCPerTensor and StoreCPerTensorCShuffle:

1. buffer_store(mask=False) redirects voffset to 0x7FFFFFFF rather than
   using a HW predicate. When nrec was clamped to 0xFFFFFFFF, the CDNA
   HW OOB check passed (0x7FFFFFFF < 0xFFFFFFFF) and the invalid store
   fired at band_base+2GB. Fix: cap nrec at 0x7FFFFFFF so the sentinel
   is always >= nrec. Valid tile offsets are at most ~30 MB, well within
   the 2 GB cap.

2. band_base and nrec derive from the group scan (arith.select chain
   over group_offs buffer loads), which the compiler's divergence analysis
   marks as divergent. Without _readfirstlane_i32 the output SRD lands in
   VGPRs and every buffer_store is wrapped in a waterfall loop. Fix: pin
   band_base and nrec via _readfirstlane_i32 before passing to
   create_buffer_resource_from_addr, mirroring the input SRD treatment in
   make_fp8_buffer_tensor_rebased.

Fix misleading group_offs comments: the tensor is int64 [G+1] passed as
an int32 view; _load_go reads only the low word at i32[2*idx] (offsets
are < 2^31 so the high word is always 0).
@kyle-256 kyle-256 force-pushed the dev/kyle/flydsl_grp_gemm branch from 24f55f4 to ccb450e Compare June 17, 2026 13:22

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 7 changed files in this pull request and generated 2 comments.

Comment thread primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py
Comment thread primus_turbo/flydsl/grouped_gemm/gemm_fp8_grouped_kernel.py
Comment thread primus_turbo/pytorch/kernels/gemm/gemm_fp8_impl.py Outdated
…2^32 cap

The contraction-traversal operands (NN B[K,N]; TN A[K,M] & B[K,N]; grouped NN
B and wgrad A & B) ride their K-stride offset on the buffer instruction's 32-bit
soffset, so the span k*BLOCK_K*{N,M} wraps once the operand exceeds ~4 GB fp8 --
the dispatcher declined these to a Triton fallback.

Add an i64-traverse mode to G2SLoader: instead of a fixed SRD base + 32-bit
soffset, fold the per-load K-offset into the i64 descriptor base (re-base via
make_fp8_buffer_tensor_rebased, soffset 0). The foldable operands (NT both,
NN A) are unchanged. Threaded through _compile_dense_nn/_tn (+ autotune
dispatch + wrapper auto-select on K*N / K*M >= 2^32) and the grouped NN / wgrad
compile layer; _wgrad_rebase now also returns the per-operand re-base tuples.

NT needs nothing: both operands are K-contiguous, so the per-tile base folds
once into the i64 SRD and the per-load offset stays ~128*K (no realistic cap).

Verified on MI355X: oversized correctness NN k*n=4.33e9 -> 73.1 dB, TN
k*m=4.43e9 -> 76.0 dB. Slowdown of the i64 path on in-cap shapes (same config)
is ~2-5% on compute-bound, up to ~11% on small memory-bound -- so dispatch uses
i64 only at/above 2^32 and keeps the cheaper int32 path below.
Copilot AI review requested due to automatic review settings June 22, 2026 03:48

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 7 changed files in this pull request and generated 6 comments.

Comment on lines +314 to +315
if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +436 to +437
if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +679 to +680
if backend == BackendType.FLYDSL and get_device_compute_capability() < (9, 5):
pytest.skip("FlyDSL fp8 grouped GEMM is gfx950-only")
Comment on lines +430 to +434
"""FlyDSL fp8 grouped GEMM backend (gfx950, per-tensor / TENSORWISE only).

M-grouped operator: forward (trans_b=True, NT) + dgrad (trans_b=False, NN).
Uses the FlyDSL mfma_f32_16x16x128_f8f6f4 kernel (gfx950-only).
"""
Comment on lines +1589 to +1593
group_offs: "torch.Tensor",
trans_b: bool = False,
out_dtype=torch.bfloat16,
num_cu: "int | None" = -1,
) -> "torch.Tensor":
Comment on lines +2187 to +2190
group_offs: "torch.Tensor",
out_dtype=torch.bfloat16,
num_cu: "int | None" = -1,
) -> "torch.Tensor":

@RuibinCheung RuibinCheung left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@RuibinCheung RuibinCheung merged commit c566ee9 into main Jun 22, 2026
8 of 9 checks passed
@RuibinCheung RuibinCheung deleted the dev/kyle/flydsl_grp_gemm branch June 22, 2026 07:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants