
[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly #3011

Merged: ptrendx merged 10 commits into NVIDIA:main from cael-ling:feature/nvfp4-rht-cast-fusion-swizzled-sf-output on Jun 5, 2026

cael-ling (Contributor) commented May 19, 2026

Description

Before this PR, every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a kEnableSwizzleSFOutput switch for that, but the framework never set the matching with_gemm_swizzled_scales flag on NVFP4 outputs: it was hard-coded to false with a TODO. This PR wires it through and saves ~25 us per quantize on LLaMA-class shapes (1.18x – 1.36x on the quant + swizzle path that te.Linear runs).
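For readers unfamiliar with the target layout, here is a minimal pure-Python sketch of what a standalone swizzle pass does to a scale-factor matrix. It assumes the cuBLASLt 1D block-scaling layout (scale factors grouped into 128x4 tiles, each tile stored as 32 groups of 16 entries that interleave rows r, r+32, r+64, r+96); swizzle_scales is a hypothetical helper, not a Transformer Engine API, and padding of non-multiple shapes is left out.

```python
from typing import List

TILE_ROWS, TILE_COLS = 128, 4  # assumed cuBLASLt scale-factor tile


def swizzle_scales(sf: List[List[int]]) -> List[int]:
    """Rearrange a row-major [rows, cols] scale-factor matrix into the
    assumed tiled GEMM layout. rows must be a multiple of 128 and cols a
    multiple of 4 (real code would pad first)."""
    rows, cols = len(sf), len(sf[0])
    assert rows % TILE_ROWS == 0 and cols % TILE_COLS == 0, "pad first"
    col_tiles = cols // TILE_COLS
    out = [0] * (rows * cols)
    for r in range(rows):
        for c in range(cols):
            # Tiles are stored row-major over the tile grid.
            tile = (r // TILE_ROWS) * col_tiles + (c // TILE_COLS)
            rr, cc = r % TILE_ROWS, c % TILE_COLS
            # Inside a 512-entry tile: 32 groups of 16 entries, each group
            # holding rows rr, rr+32, rr+64, rr+96 for all 4 columns.
            offset = (rr % 32) * 16 + (rr // 32) * 4 + cc
            out[tile * TILE_ROWS * TILE_COLS + offset] = sf[r][c]
    return out
```

With a single 128x4 tile whose entries are numbered r * 4 + c, row 32 lands right after the first four entries of row 0, which is the interleaving the fused kernel now writes directly instead of leaving it to a second pass.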

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Kernel side (transformer_engine/common/hadamard_transform/):

  • row_cast_col_hadamard_transform_cast_fusion.cu &
    group_row_cast_col_hadamard_transform_cast_fusion.cu: drive the
    existing kEnableSwizzleSFOutput template parameter from
    output.with_gemm_swizzled_scales. The grouped kernel additionally
    NVTE_CHECKs the flag is consistent across all tensors in a group
    (it honours a single boolean).
  • The graph-safe grouped variant already had this wired correctly --
    no change.

Framework side (transformer_engine/pytorch/csrc/):

  • New static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)
    mirroring the dispatch-time eligibility check in
    NVFP4Quantizer::quantize_impl (rows%64==0 && cols%128==0 && SM100/110).
  • NVFP4Quantizer::create_tensor, NVFP4Quantizer::convert_and_update_tensor,
    and bulk_allocate_nvfp4_tensors now set
    with_gemm_swizzled_scales = optimize_for_gemm && with_rht && shape_eligible.
    For the grouped allocator the flag is True only if every tensor in
    the group is eligible.
  • Belt-and-suspenders NVTE_CHECK(!out.with_gemm_swizzled_scales) at
    the entry of quantize_with_rht_unfused_helper. The framework gate
    already keeps user code from tripping it; this only fires if a future
    low-level caller bypasses the gate.
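The gating rule above can be summarized with a small Python model. It is a sketch under stated assumptions: QuantizerConfig and the function names are hypothetical, the SM100/110 check is reduced to a boolean argument, and the row alignment is the one given in this description (the review thread later tightens it for the grouped kernel).

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass(frozen=True)
class QuantizerConfig:
    # Hypothetical stand-in for the two NVFP4Quantizer fields the gate reads.
    optimize_for_gemm: bool
    with_rht: bool


def shape_eligible(rows: int, cols: int, arch_supported: bool) -> bool:
    # rows%64==0 && cols%128==0 && SM100/110; the arch check is passed in.
    return rows % 64 == 0 and cols % 128 == 0 and arch_supported


def with_gemm_swizzled_scales(q: QuantizerConfig, rows: int, cols: int,
                              arch_supported: bool = True) -> bool:
    # Single-tensor allocation (create_tensor / convert_and_update_tensor).
    return q.optimize_for_gemm and q.with_rht and shape_eligible(rows, cols, arch_supported)


def group_with_gemm_swizzled_scales(q: QuantizerConfig,
                                    shapes: Iterable[Tuple[int, int]],
                                    arch_supported: bool = True) -> bool:
    # Grouped allocation: one flag for the whole group, True only if every
    # tensor in the group is eligible.
    return q.optimize_for_gemm and q.with_rht and all(
        shape_eligible(r, c, arch_supported) for r, c in shapes)
```

For example, (8192, 5120) is eligible while (8192, 11328) is not, so a group containing both gets the flag cleared for every member.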

Performance

SM100a, bf16 input, rowwise + columnwise SF, RHT + post-RHT amax.
Per-quantize wall-clock median via torch.utils.benchmark.Timer.blocked_autorange.
Baseline = compact SF (pre-PR behaviour); SUT = this PR.
quant + swizzle = t = quantizer(x); tex.swizzle_scales_for_gemm_(t) --
exactly what te.Linear runs before its GEMM.

shape            baseline (us)   SUT (us)   saved (us)   speedup   note
(8192, 5120)     108.6           81.9       26.6         1.33x     eligible
(8192, 10240)    107.8           90.2       17.5         1.19x     eligible
(8192, 2560)     107.7           79.9       27.8         1.35x     eligible
(8192, 11328)    236.3           236.3      0.0          1.00x     ineligible
(8192, 3584)     106.0           78.6       27.4         1.35x     eligible
(5120, 8192)     101.2           76.0       25.3         1.33x     eligible
(10240, 8192)    107.8           90.4       17.4         1.19x     eligible
(2560, 8192)     101.4           74.9       26.4         1.35x     eligible
(11328, 8192)    114.4           93.2       21.2         1.23x     eligible
(3584, 8192)     101.6           74.9       26.7         1.36x     eligible
(4096, 16384)    100.2           75.0       25.2         1.34x     eligible
(14336, 16384)   232.1           197.5      34.6         1.18x     eligible
  • 11/12 shapes get 1.18x – 1.36x on the quant + swizzle path.
  • The single ineligible shape (8192, 11328) shows 1.00x as expected:
    the gate clamped, the unfused fallback ran, and the result is
    byte-identical to baseline (no regression, no crash).
  • quant_only timings are unchanged on all shapes within noise: writing
    swizzled SF inside the cast-fusion kernel is essentially free, so the
    entire win comes from eliminating the standalone swizzle pass.

Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py (also has a
--profile mode for ncu / nsys).
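The headline numbers can be re-derived from the table with a short script. This is only arithmetic on the rounded medians shown above, so recomputed values can differ from the table in the last digit (e.g. 108.6 - 81.9 = 26.7 us versus the listed 26.6 us); eligibility is recomputed from the shape rule (SM100/110 assumed) rather than read from the note column. Recomputed this way, the 11 eligible shapes span roughly 1.175x to 1.357x and save about 25.1 us on average, consistent with the ~25 us and 1.18x to 1.36x quoted above.

```python
# (rows, cols) -> (baseline_us, sut_us), copied from the table above.
RESULTS = {
    (8192, 5120): (108.6, 81.9),
    (8192, 10240): (107.8, 90.2),
    (8192, 2560): (107.7, 79.9),
    (8192, 11328): (236.3, 236.3),
    (8192, 3584): (106.0, 78.6),
    (5120, 8192): (101.2, 76.0),
    (10240, 8192): (107.8, 90.4),
    (2560, 8192): (101.4, 74.9),
    (11328, 8192): (114.4, 93.2),
    (3584, 8192): (101.6, 74.9),
    (4096, 16384): (100.2, 75.0),
    (14336, 16384): (232.1, 197.5),
}


def eligible(rows: int, cols: int) -> bool:
    # Shape part of the RHT cast-fusion gate.
    return rows % 64 == 0 and cols % 128 == 0


def summarize(results):
    speedups = {s: base / sut for s, (base, sut) in results.items()}
    fused = [s for s in results if eligible(*s)]
    saved = [results[s][0] - results[s][1] for s in fused]
    return {
        "n_fused": len(fused),
        "min_speedup": min(speedups[s] for s in fused),
        "max_speedup": max(speedups[s] for s in fused),
        "mean_saved_us": sum(saved) / len(saved),
        "speedups": speedups,
    }
```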

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ectly

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two
standalone swizzle kernels (rowwise + columnwise) whose only job was to
move scale factors into the layout cuBLAS LT consumes. The cast-fusion
kernel already had a `kEnableSwizzleSFOutput` switch for that, but the
framework never set the matching `with_gemm_swizzled_scales` flag on
NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through.

Changes:
* Single + grouped Hadamard cast-fusion kernels: drive
  `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`.
* NVFP4Quantizer create_tensor / convert_and_update_tensor /
  bulk_allocate_nvfp4_tensors: set the flag when
  `optimize_for_gemm && with_rht && shape eligible`, with eligibility
  in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion
  (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites.
* Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper
  in case a future low-level caller bypasses the gate.

The shape gate is part of this PR (not a follow-up) because LLaMA-class
shapes like (8192, 11328) have K%128==64. Without the gate the framework
would set the flag, dispatch would fall to the unfused path that can't
emit swizzled SF, and the process would abort. With the gate, ineligible
shapes silently fall back to the original code path.
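A toy model of the abort described above, assuming a supported SM, BF16 input, and optimize_for_gemm = with_rht = True. allocate_flag and quantize are hypothetical stand-ins for the allocation-time gate and the dispatch in quantize_impl, and NVTECheckError stands in for an NVTE_CHECK failure.

```python
class NVTECheckError(RuntimeError):
    """Hypothetical stand-in for an NVTE_CHECK failure."""


def fused_eligible(rows: int, cols: int) -> bool:
    # Shape part of the RHT cast-fusion predicate (SM100/110 assumed).
    return rows % 64 == 0 and cols % 128 == 0


def allocate_flag(rows: int, cols: int, gated: bool) -> bool:
    # Without the gate the flag is always set; with it, the flag follows
    # shape eligibility.
    return fused_eligible(rows, cols) if gated else True


def quantize(rows: int, cols: int, with_gemm_swizzled_scales: bool) -> str:
    if fused_eligible(rows, cols):
        return "fused"  # the cast-fusion kernel can emit either SF layout
    # The unfused fallback can only write compact scale factors.
    if with_gemm_swizzled_scales:
        raise NVTECheckError("unfused RHT path cannot emit GEMM-swizzled SF")
    return "unfused"
```

For (8192, 11328), 11328 % 128 == 64: the ungated flag makes quantize raise, while the gated flag routes it to the unfused path.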

Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median,
`quant + swizzle` path -- what te.Linear actually runs):

  (8192,  5120)    108.6 ->  81.9 us   1.33x   eligible
  (8192, 11328)    236.3 -> 236.3 us   1.00x   ineligible, gate clamped
  (11328, 8192)    114.4 ->  93.2 us   1.23x   eligible
  (14336,16384)    232.1 -> 197.5 us   1.18x   eligible

11/12 production-class shapes get 1.18x - 1.36x. The one ineligible
shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged
across all shapes -- the savings come entirely from eliminating the
standalone swizzle pass, not from a faster quant kernel.

Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py

Tests:
* new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py:
  byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases
  verifying the shape gate clamps correctly and that quantizer(x) on an
  ineligible shape does not raise.
* tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added
  optimize_for_gemm parametrization for the legacy grouped path.
* test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe
  variant already had the wiring).

Signed-off-by: Cael Ling <caell@nvidia.com>
cael-ling requested a review from ksivaman as a code owner May 19, 2026 03:49
greptile-apps (Bot) commented May 19, 2026

Greptile Summary

This PR wires up the existing kEnableSwizzleSFOutput compile-time switch in the NVFP4 RHT cast-fusion kernels by driving it from output.with_gemm_swizzled_scales, and updates the C++ framework to set that flag when the shape and hardware are eligible. The result is that the two standalone swizzle kernel passes (rowwise + columnwise) are eliminated for eligible shapes, saving ~25 µs per quantize call.

  • Kernel side: row_cast_col_hadamard_transform_cast_fusion.cu and group_row_cast_col_hadamard_transform_cast_fusion.cu now read with_gemm_swizzled_scales from the output tensor instead of using a hardcoded false, and the grouped variant adds an NVTE_CHECK enforcing consistency across all output tensors.
  • Framework side: A new NVFP4Quantizer::is_eligible_for_rht_cast_fusion static helper centralises the shape+SM eligibility gate; create_tensor, convert_and_update_tensor, and bulk_allocate_nvfp4_tensors all use it to set with_gemm_swizzled_scales; a belt-and-suspenders NVTE_CHECK in the unfused path prevents any future caller from bypassing the gate silently.
  • Tests: Fidelity tests for both the single-tensor and grouped paths are added, covering eligible shapes (byte-equal swizzled SF) as well as ineligible shapes (fallback to post-quantize inplace_swizzle_scale_for_gemm).

Confidence Score: 5/5

Safe to merge — the eligibility gate, belt-and-suspenders NVTE_CHECK, and post-quantize fallback together ensure no shape can silently end up with a mismatched SF layout.

The change is well-scoped: eligible shapes take the new fast path (swizzled SF from the fused kernel, standalone swizzle pass eliminated), while ineligible shapes continue through the existing unfused path and post-quantize inplace_swizzle_scale_for_gemm fallback unchanged. The group consistency NVTE_CHECK in the grouped kernel and the unfused-path NVTE_CHECK provide clear failure modes if any caller bypasses the framework gate. Test coverage exercises both eligible and ineligible shapes for single-tensor and grouped quantize, including byte-equal SF comparisons.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/quantizer.cpp Core framework changes: adds is_eligible_for_rht_cast_fusion helper, gates with_gemm_swizzled_scales in create_tensor/convert_and_update_tensor, and adds NVTE_CHECK guard in quantize_with_rht_unfused_helper. Logic is correct and consistent with quantize_impl's existing dtype+shape eligibility check.
transformer_engine/pytorch/csrc/extensions/cast.cpp Replaces hardcoded with_gemm_swizzled_scales=false in bulk_allocate_nvfp4_tensors with proper group-wide eligibility derivation and cross-quantizer consistency checks. The existing post-quantize inplace_swizzle_scale_for_gemm fallback (unchanged) correctly covers ineligible shapes.
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu One-line change: use_swizzle_sf_output now reads output_.with_gemm_swizzled_scales instead of hardcoded false. Clean and correct.
transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu Reads use_swizzle_sf_output from output_list[0]->with_gemm_swizzled_scales and enforces consistency across the group with an NVTE_CHECK loop. The empty-list guard is a nice defensive addition.
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py New test file covering shape-gate logic, fidelity (byte-equal SF after swizzle), and end-to-end behavior for both eligible and ineligible shapes. All three test functions carry the recipe_available skipif guard.
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py Extends existing group quantize test with optimize_for_gemm parametrization and applies swizzle_nvfp4_scale to the reference SF when the SUT emits swizzled SF, enabling byte-equal comparison.
transformer_engine/pytorch/csrc/common.h Adds is_eligible_for_rht_cast_fusion static method declaration with clear docstring. Minimal, correct.
benchmarks/benchmark_rht_cast_swizzle_fusion.py New benchmark script comparing baseline (compact SF) vs swizzle-fusion paths across production LLaMA shapes. Well-structured with quant_only and quant_plus_swizzle timing cells and CSV output.
benchmarks/profile_rht_cast_swizzle_fusion.py Profiling script that validates standalone swizzle kernels disappear from the timeline under optimize_for_gemm=True.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["quantizer(x) called\n(optimize_for_gemm=True, with_rht=True)"] --> B["create_tensor / convert_and_update_tensor\nis_eligible_for_rht_cast_fusion(shape)"]

    B -->|"rows%align==0 && cols%128==0\n&& SM 100/110"| C["with_gemm_swizzled_scales = True\n(output tensor allocated)"]
    B -->|"ineligible shape\n(e.g. cols%128 != 0)"| D["with_gemm_swizzled_scales = False\n(output tensor allocated)"]

    C --> E["quantize_impl\neligible_for_rht_cast_fusion = BF16 && shape+SM"]
    D --> F["quantize_impl\neligible_for_rht_cast_fusion = False"]

    E --> G["RHT cast-fusion kernel\nkEnableSwizzleSFOutput=True\nSwizzled SF baked in directly"]
    F --> H["quantize_with_rht_unfused_helper\nNVTE_CHECK(!with_gemm_swizzled_scales)\n→ passes (flag=False)"]

    G --> I["cast.cpp post-quantize check\noptimize_for_gemm && !with_gemm_swizzled_scales\n→ False → SKIP swizzle"]
    H --> J["cast.cpp post-quantize check\noptimize_for_gemm && !with_gemm_swizzled_scales\n→ True → inplace_swizzle_scale_for_gemm"]

    I --> K["Result: swizzled SF\n_with_gemm_swizzled_scales=True\n✅ ~25µs saved"]
    J --> L["Result: swizzled SF\n_with_gemm_swizzled_scales=True\n✅ No regression"]

Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +751 to +753
const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm &&
quantizer_cpp_list[0]->with_rht &&
all_tensors_rht_cast_fusion_eligible;

Contributor:

P2: optimize_for_gemm and with_rht are read only from the first quantizer, without validation

with_gemm_swizzled_scales is derived exclusively from quantizer_cpp_list[0], so if any later quantizer in the group has a different optimize_for_gemm or with_rht value, its tensors are silently allocated with the wrong SF layout. The shape-eligibility loop below correctly iterates every tensor, but there is no matching check that all quantizers agree on optimize_for_gemm/with_rht. The split-quantize path at line 1276 documents this assumption explicitly (// Assume all quantizers have identical config); the same note or an NVTE_CHECK loop here would make the contract visible and consistent.

Contributor Author:

with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout.

Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:

  • adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  • adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  • extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.

Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
Signed-off-by: Cael Ling <caell@nvidia.com>
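The consistency check this commit describes can be sketched in Python. This is a hypothetical mirror, not the C++ code: QuantizerConfig is an illustrative stand-in, and ValueError plays the role of NVTE_CHECK. As in the commit, the [0] reads go into group_optimize_for_gemm / group_with_rht locals so the same values feed both the check and the gate, and each error names the offending index and the disagreeing value.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class QuantizerConfig:
    # Hypothetical stand-in for the two NVFP4Quantizer fields being checked.
    optimize_for_gemm: bool
    with_rht: bool


def check_group_config(quantizers: Sequence[QuantizerConfig]) -> Tuple[bool, bool]:
    """Return (group_optimize_for_gemm, group_with_rht), raising if any
    quantizer disagrees with quantizer 0."""
    if not quantizers:
        raise ValueError("empty quantizer group")
    group_optimize_for_gemm = quantizers[0].optimize_for_gemm
    group_with_rht = quantizers[0].with_rht
    for i, q in enumerate(quantizers):
        if q.optimize_for_gemm != group_optimize_for_gemm:
            raise ValueError(
                f"quantizer {i}: optimize_for_gemm={q.optimize_for_gemm}, "
                f"expected {group_optimize_for_gemm}")
        if q.with_rht != group_with_rht:
            raise ValueError(
                f"quantizer {i}: with_rht={q.with_rht}, expected {group_with_rht}")
    return group_optimize_for_gemm, group_with_rht
```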
Comment on lines +722 to +732
// Quantization parameters. Like the NVFP4 split-quantize path
// (see split_quantize_nvfp4_impl in this file), we assume all
// quantizers in the group share an identical config and read
// group-wide flags from quantizer_cpp_list[0]. The grouped RHT
// cast-fusion kernel honours a single with_gemm_swizzled_scales
// boolean across the whole group, so optimize_for_gemm and with_rht
// must in particular agree across all quantizers; the NVTE_CHECK
// loop below enforces that for the fields the swizzled-SF gate
// depends on. (The other group-wide reads from [0] -- rowwise_usage,
// row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are
// pre-existing assumptions and out of scope for this PR.)

Member:

I'm not super happy about the style of those comments - they reference multiple other files, and while right now the comment matches the reality, it will easily drift. We should concentrate on commenting the invariants and assumptions needed for this file only.

Contributor Author:

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +744 to +753
// Only the RHT cast-fusion quant kernel supports direct swizzled SF
// emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 ->
// quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject
// a swizzled-flagged output, so we gate on with_rht to avoid silent
// data corruption / hard aborts on non-RHT paths. Additionally we
// require *all* tensors in the group to be shape-eligible for RHT
// cast-fusion, because the grouped kernel honours a single boolean
// and the unfused fallback rejects swizzled output (see NVTE_CHECK
// at group_row_cast_col_hadamard_transform_cast_fusion.cu and
// quantize_with_rht_unfused_helper).

Member:

Same.

Contributor Author:

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +377 to +383
* Matches the dispatch logic in NVFP4Quantizer::quantize_impl.
* The dtype check (BF16) is implicit -- with_rht=True requires
* BF16 input by construction, so callers gate on with_rht first.
* When false, the dispatch falls back to quantize_with_rht_unfused
* which cannot emit GEMM-swizzled SF; framework gates that opt
* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.

Member:

Again, this is mostly talking about the internal implementation choices rather than what that function actually does (which is covered by the first sentence).

Contributor Author:

New commit removed the prose about dispatch internals and caller responsibilities.

* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.
*/
static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols);

Member:

Shouldn't it take an arbitrary shape rather than assuming it will be 2D?

Contributor Author:

Good point. Changed the signature to take the full tensor shape (const std::vector<size_t>& shape) and moved the get_2d_dims(...) flatten inside the function. All four call sites (create_tensor, convert_and_update_tensor, quantize_impl, and the grouped path in cast.cpp) now pass the shape directly without pre-flattening. The bulk loop in cast.cpp also no longer calls get_2d_dims per iteration since the function takes care of it.
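In Python terms, the revised helper behaves roughly like the sketch below. Both functions are hypothetical mirrors of the C++ helpers, the SM100/110 check is a boolean argument, and the flatten rule is the one this thread attributes to get_2d_dims: all leading dimensions fold into rows, the last dimension is cols.

```python
from math import prod
from typing import Sequence, Tuple


def get_2d_dims(shape: Sequence[int]) -> Tuple[int, int]:
    # rows = product(shape[:-1]), cols = shape[-1]; a 1-D shape gives rows = 1.
    if len(shape) == 0:
        raise ValueError("shape must have at least one dimension")
    return prod(shape[:-1]), shape[-1]


def is_eligible_for_rht_cast_fusion(shape: Sequence[int],
                                    arch_supported: bool = True) -> bool:
    # Callers pass the full N-D shape; flattening happens here, once.
    rows, cols = get_2d_dims(shape)
    return rows % 64 == 0 and cols % 128 == 0 and arch_supported
```

A [4, 2048, 5120] activation flattens to (8192, 5120) and is eligible, so no call site needs to pre-flatten.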

Comment on lines +1764 to +1767
// Must mirror the eligibility check in NVFP4Quantizer::quantize_impl
// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.

Member:

Why does it have to mirror the check in that other function? Considering that both of these functions are in the same file and in the same class, can't we just call one from the other to keep a single source of truth?

Contributor Author:

Correct, the proper fix is to call one from the other. quantize_impl now delegates the shape/arch predicate to NVFP4Quantizer::is_eligible_for_rht_cast_fusion(...) instead of re-inlining the same check. The BF16 dtype guard stays as an explicit && at the call site because it's specific to quantize_impl (the allocation callers don't have an input tensor to check). I also replaced the hand-rolled rows = product(input.shape[:-1]) loop with get_2d_dims(input.shape()) so the flattening rule isn't duplicated either. The shape/arch eligibility now has a single source of truth.

// neither of which supports emitting SF in the GEMM-swizzled layout (their
// backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean
// error here instead of letting it abort deep inside the kernel with an
// opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that

Member:

Why do we mention JAX in the PyTorch source files?

Contributor Author:

New commit dropped the JAX reference and the surrounding narration. The remaining 2-line comment just explains why this NVTE_CHECK is here.

ptrendx (Member) commented May 19, 2026

Please also handle the convert_and_update_tensor path since it also needs changes.

bool all_tensors_rht_cast_fusion_eligible = true;
for (size_t i = 0; i < num_tensors; ++i) {
const auto [rows, cols] = get_2d_dims(shape_list[i]);
if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) {

Member:

The grouped kernel that supports the swizzle only runs when rows are divisible by 128, but this function allows rows divisible by 64.

Contributor Author:

Good catch — this was a real bug, not just an over-permissive style.

Before this fix, is_eligible_for_rht_cast_fusion(shape) used a single row-alignment constraint of rows % 64 == 0 (the single-tensor RHT cast-fusion kernel's entry check at row_cast_col_hadamard_transform_cast_fusion.cu:1161). The bulk-allocation path in cast.cpp was calling this same lax check, so shapes with rows in {64, 192, 320, ...} (all satisfying % 64 == 0) would pass eligibility, get with_gemm_swizzled_scales=True, and then hard-abort inside the grouped kernel, whose entry asserts first_logical_dim % 128 == 0 (graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385).

The fix adds a for_grouped_kernel parameter to is_eligible_for_rht_cast_fusion so callers select the constraint that matches the kernel they will actually invoke:

  • false (default): rows % 64 == 0, single-tensor kernel
  • true: rows % 128 == 0, grouped kernel

The bulk-allocation caller in cast.cpp passes /*for_grouped_kernel=*/true; the three single-tensor callers (create_tensor, convert_and_update_tensor, quantize_impl) keep the default false. Shapes with rows in {64, 192, 320, ...} now correctly fail the grouped-path eligibility and fall back to the unfused path instead of reaching the grouped kernel.
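A Python sketch of the two-alignment rule, assuming SM100/110 and using hypothetical mirrors of the C++ names. It shows the case that motivated the fix: rows of 64, 192, or 320 pass the single-tensor rule but fail the grouped one, so a group containing any of them gets the unswizzled allocation.

```python
from math import prod
from typing import Iterable, Sequence


def is_eligible_for_rht_cast_fusion(shape: Sequence[int],
                                    for_grouped_kernel: bool = False,
                                    arch_supported: bool = True) -> bool:
    rows, cols = prod(shape[:-1]), shape[-1]
    # Single-tensor kernel checks M % 64; the grouped kernel checks
    # first_logical_dim % 128.
    row_align = 128 if for_grouped_kernel else 64
    return rows % row_align == 0 and cols % 128 == 0 and arch_supported


def group_eligible(shapes: Iterable[Sequence[int]]) -> bool:
    # Bulk-allocation caller: every tensor must pass the grouped-kernel rule.
    return all(is_eligible_for_rht_cast_fusion(s, for_grouped_kernel=True)
               for s in shapes)
```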

// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.
return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 &&

Member:

Why is the rows % 64 == 0 a requirement here rather than rows % 128 == 0?

Contributor Author:

The 64 here is correct for the single-tensor cast-fusion kernel — its entry check is NVTE_CHECK(M % 64 == 0, ...) at row_cast_col_hadamard_transform_cast_fusion.cu:1161. The 128 you're thinking of is the grouped kernel's stricter requirement at graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385.

cael-ling and others added 2 commits May 19, 2026 20:14
Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
  eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
  `first_logical_dim % 128 == 0` at entry. Shapes with rows in
  {64, 192, 320, ...} would pass eligibility, set
  `with_gemm_swizzled_scales=True`, and then hard-abort inside the
  grouped kernel with an opaque NVTE_CHECK message. Adding a
  `for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion`
  selects the correct row alignment: 64 for the single-tensor kernel,
  128 for the grouped variant. Only the bulk-allocation caller passes
  `true`; the three single-tensor callers keep the default `false`.
Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
  (`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
  so the four call sites no longer pre-flatten and duplicate the
  flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
  `is_eligible_for_rht_cast_fusion` instead of inlining the same
  predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
  replaced with `get_2d_dims(input.shape())`. The shape/arch
  eligibility now has a single source of truth.
Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
  `create_tensor`, `convert_and_update_tensor`, and
  `quantize_with_rht_unfused_helper`. Removed cross-references to
  other functions/files, code narration of subsequent lines, the JAX
  reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
  brief sentence.

Signed-off-by: Cael Ling <caell@nvidia.com>
cael-ling (Contributor Author) commented:

In reply to "Please also handle the convert_and_update_tensor path since it also needs changes.":

Done. Both create_tensor and convert_and_update_tensor now have the same 2-line comment on the gating; removed the previous "See NVFP4Quantizer::create_tensor for the rationale" cross-reference. I also trimmed create_tensor's long rationale block (which referenced specific .cu/.cuh filenames and quantize_with_rht_unfused's internal behavior) in the same pass, so the two functions are consistent.

cael-ling requested a review from ptrendx May 21, 2026 01:17

// Swizzled SF is only valid when the RHT cast-fusion path runs;
// other quantize paths reject it.
const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht &&

Member:

This is then set for the out_cpp TensorWrapper (at the end of this function), but not in the actual Python object. See the handling of this in the MXFP8 quantizer:

  tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;

Contributor Author:

Good catch, fixed in the latest commit (pushed), mirroring the MXFP8 quantizer.

github-actions (Bot) added the community-contribution label Jun 3, 2026 (PRs from external contributors outside the core maintainers, representing community-driven work).
cael-ling requested a review from ptrendx June 3, 2026 09:55
ptrendx previously approved these changes Jun 3, 2026
ptrendx (Member) commented Jun 3, 2026

/te-ci pytorch

…tput

Signed-off-by: cael-ling <caell@nvidia.com>
ptrendx previously approved these changes Jun 4, 2026
Signed-off-by: Cael Ling <caell@nvidia.com>
cael-ling requested a review from ptrendx June 4, 2026 05:22
ptrendx (Member) commented Jun 4, 2026

/te-ci pytorch

ptrendx previously approved these changes Jun 4, 2026
The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks
that mainline NVIDIA#3076 split apart. _with_gemm_swizzled_scales only controls
WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM
swizzles lazily at call time; when True, the tensor is pre-swizzled and the
GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or
cols%128!=0) ended quantize with the flag False. NVIDIA#3076 then added a
post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles
ineligible shapes and flips the flag back to True, so under optimize_for_gemm
the end-to-end flag is now True for all shapes. The old False expectations
encoded pre-NVIDIA#3076 behavior and started failing CI on (64,144), (128,144),
(48,128).

Split into two self-consistent tests:
- shape_gate: probes make_empty() (runs create_tensor only -- no quantize,
  no fallback), so it observes the fused-kernel shape gate in isolation and
  keeps the original True/False eligibility table.
- end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes
  and must always yield _with_gemm_swizzled_scales=True (eligible via the
  fused cast-fusion kernel, ineligible via the NVIDIA#3076 swizzle fallback).

Signed-off-by: Cael Ling <caell@nvidia.com>
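The "where, not whether" distinction in this commit can be modelled in a few lines. This is a hypothetical sketch, not Transformer Engine code: flag_after_allocation stands for what make_empty() observes (create_tensor only), flag_after_quantize adds the post-quantize swizzle fallback, and swizzle_passes counts how many swizzle passes a tensor sees before its GEMM consumes it. SM100/110 and optimize_for_gemm = with_rht = True are assumed.

```python
def rht_fusion_eligible(rows: int, cols: int) -> bool:
    # Shape part of the fused-kernel gate.
    return rows % 64 == 0 and cols % 128 == 0


def flag_after_allocation(rows: int, cols: int) -> bool:
    # make_empty() runs create_tensor only: no quantize, no fallback.
    return rht_fusion_eligible(rows, cols)


def flag_after_quantize(rows: int, cols: int) -> bool:
    flag = flag_after_allocation(rows, cols)
    if not flag:
        # Post-quantize fallback swizzles eagerly and flips the flag.
        flag = True
    return flag


def swizzle_passes(tensor_flag: bool, eager_passes: int) -> int:
    # The GEMM swizzles lazily only when the tensor is not pre-swizzled,
    # so every path performs exactly one swizzle in total.
    return eager_passes + (0 if tensor_flag else 1)
```

Under this model the shapes that broke CI, (64, 144), (128, 144), and (48, 128), report False after allocation but True after quantize, which is why the original single test had to be split.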
ptrendx (Member) commented Jun 4, 2026

/te-ci pytorch

ptrendx merged commit 720ec27 into NVIDIA:main Jun 5, 2026
21 of 25 checks passed
francesco-bertolotti pushed a commit to francesco-bertolotti/TransformerEngine that referenced this pull request Jun 11, 2026
…ectly (NVIDIA#3011)

* [PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two
standalone swizzle kernels (rowwise + columnwise) whose only job was to
move scale factors into the layout cuBLAS LT consumes. The cast-fusion
kernel already had a `kEnableSwizzleSFOutput` switch for that, but the
framework never set the matching `with_gemm_swizzled_scales` flag on
NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through.
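For readers unfamiliar with what "the layout cuBLAS LT consumes" means, here is a minimal pure-Python sketch of a scale-factor swizzle. It assumes the cuBLAS 1D block-scaling layout as I understand it: 128x4 tiles of 512 entries, tiles ordered row-major, and tile-local offset `(r % 32) * 16 + (r // 32) * 4 + c`. This is an illustration under that assumption, not Transformer Engine's swizzle kernel. `swizzle_scales` and `ceil_div` are hypothetical names.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def swizzle_scales(sf, rows, cols):
    """Rearrange a row-major rows x cols scale-factor matrix (list of lists)
    into 128x4 tiles stored back to back (512 entries each).

    Assumed layout: tiles are ordered row-major over (row_tile, col_tile);
    inside a tile, element (r, c) lands at (r % 32) * 16 + (r // 32) * 4 + c.
    Partial tiles are zero-padded.
    """
    col_tiles = ceil_div(cols, 4)
    row_tiles = ceil_div(rows, 128)
    out = [0] * (row_tiles * col_tiles * 512)
    for i in range(rows):
        for j in range(cols):
            tile = (i // 128) * col_tiles + (j // 4)
            r, c = i % 128, j % 4
            out[tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c] = sf[i][j]
    return out


# Usage: a single 128x4 tile filled with the distinct values 1..512.
sf = [[i * 4 + j + 1 for j in range(4)] for i in range(128)]
swizzled = swizzle_scales(sf, 128, 4)
```

Without the fusion, a pass like this runs as a separate kernel after every quantize. With `kEnableSwizzleSFOutput`, the cast-fusion kernel writes scales directly at the swizzled offsets.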

Changes:
* Single + grouped Hadamard cast-fusion kernels: drive
  `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`.
* NVFP4Quantizer create_tensor / convert_and_update_tensor /
  bulk_allocate_nvfp4_tensors: set the flag when
  `optimize_for_gemm && with_rht && shape eligible`, with eligibility
  in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion
  (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites.
* Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper
  in case a future low-level caller bypasses the gate.

The shape gate is part of this PR (not a follow-up) because LLaMA-class
shapes like (8192, 11328) have K%128==64. Without the gate the framework
would set the flag, dispatch would fall to the unfused path that can't
emit swizzled SF, and the process would abort. With the gate, ineligible
shapes silently fall back to the original code path.
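The gate described above can be sketched as a pair of predicates. This is a hypothetical Python mirror of the C++ logic, not its real signature (the real helper later takes the full shape). Arch is modeled as an integer SM number.

```python
# Hypothetical Python mirror of the C++ gate; arch is modeled as an int.
SUPPORTED_SM = (100, 110)  # "SM100/110" per the commit message


def is_eligible_for_rht_cast_fusion(rows: int, cols: int, sm: int) -> bool:
    """Single-tensor fused-kernel shape/arch gate."""
    return rows % 64 == 0 and cols % 128 == 0 and sm in SUPPORTED_SM


def wants_swizzled_scales(optimize_for_gemm: bool, with_rht: bool,
                          rows: int, cols: int, sm: int) -> bool:
    """Value the quantizer would store in with_gemm_swizzled_scales."""
    return (optimize_for_gemm and with_rht
            and is_eligible_for_rht_cast_fusion(rows, cols, sm))


# (8192, 5120): 5120 % 128 == 0, so the fused kernel can emit swizzled SF.
print(wants_swizzled_scales(True, True, 8192, 5120, 100))   # True
# (8192, 11328): 11328 % 128 == 64, so the flag stays False.
print(wants_swizzled_scales(True, True, 8192, 11328, 100))  # False
```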

Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median,
`quant + swizzle` path -- what te.Linear actually runs):

  shape            before -> after     speedup note
  (8192,  5120)    108.6 ->  81.9 us   1.33x   eligible
  (8192, 11328)    236.3 -> 236.3 us   1.00x   ineligible, gate clamped
  (11328, 8192)    114.4 ->  93.2 us   1.23x   eligible
  (14336,16384)    232.1 -> 197.5 us   1.18x   eligible

11/12 production-class shapes get 1.18x - 1.36x. The one ineligible
shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged
across all shapes -- the savings come entirely from eliminating the
standalone swizzle pass, not from a faster quant kernel.
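As a sanity check on the figures above, speedup is `before / after` and the absolute saving is `before - after`. The small script below uses only the four rows listed and recomputes both. The 21 to 35 us savings on eligible shapes line up with the ~25 us per quantize quoted in the description.

```python
# Per-quantize medians in microseconds, copied from the table: (before, after).
RESULTS = {
    (8192, 5120): (108.6, 81.9),
    (8192, 11328): (236.3, 236.3),
    (11328, 8192): (114.4, 93.2),
    (14336, 16384): (232.1, 197.5),
}


def summarize(results):
    """Return (shape, speedup, saved_us) tuples, rounded like the table."""
    return [
        (shape, round(before / after, 2), round(before - after, 1))
        for shape, (before, after) in results.items()
    ]


for shape, speedup, saved in summarize(RESULTS):
    print(shape, f"{speedup:.2f}x", f"{saved} us saved")
```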

Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py

Tests:
* new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py:
  byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases
  verifying the shape gate clamps correctly and that quantizer(x) on an
  ineligible shape does not raise.
* tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added
  optimize_for_gemm parametrization for the legacy grouped path.
* test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe
  variant already had the wiring).

Signed-off-by: Cael Ling <caell@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] NVFP4 RHT cast-fusion: enforce group-wide quantizer config

Reviewer feedback: with_gemm_swizzled_scales was derived from
quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking
that other quantizers in the group agreed; if any later quantizer
had a different value, its tensors would be silently allocated with
the wrong SF layout.

Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
  * adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  * adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  * extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.

Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
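The group-wide check can be sketched as follows. `QuantizerConfig` and `group_gate_inputs` are hypothetical names, and `ValueError` stands in for `NVTE_CHECK`. The point is that the values read from quantizer 0 feed both the consistency check and the gate, and any disagreement names the offending index.

```python
from dataclasses import dataclass


@dataclass
class QuantizerConfig:
    """Hypothetical stand-in holding only the two fields the gate reads."""
    optimize_for_gemm: bool
    with_rht: bool


def group_gate_inputs(quantizers):
    """Read group-wide values from quantizer 0 and reject any disagreement.

    ValueError stands in for NVTE_CHECK in this sketch.
    """
    group_optimize_for_gemm = quantizers[0].optimize_for_gemm
    group_with_rht = quantizers[0].with_rht
    for i, q in enumerate(quantizers):
        if q.optimize_for_gemm != group_optimize_for_gemm:
            raise ValueError(
                f"tensor {i}: optimize_for_gemm={q.optimize_for_gemm}, "
                f"but tensor 0 has {group_optimize_for_gemm}")
        if q.with_rht != group_with_rht:
            raise ValueError(
                f"tensor {i}: with_rht={q.with_rht}, "
                f"but tensor 0 has {group_with_rht}")
    return group_optimize_for_gemm, group_with_rht
```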
Signed-off-by: Cael Ling <caell@nvidia.com>

* [PyTorch] NVFP4 RHT cast-fusion: address review feedback

Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
  eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
  `first_logical_dim % 128 == 0` at entry. Shapes with rows in
  {64, 192, 320, ...} would pass eligibility, set
  `with_gemm_swizzled_scales=True`, and then hard-abort inside the
  grouped kernel with an opaque NVTE_CHECK message. Adding a
  `for_grouped_kernel` parameter to `is_eligible_for_rht_cast_fusion`
  selects the correct row alignment: 64 for the single-tensor kernel,
  128 for the grouped variant. Only the bulk-allocation caller passes
  `true`; the three single-tensor callers keep the default `false`.

Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
  (`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
  so the four call sites no longer pre-flatten and duplicate the
  flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
  `is_eligible_for_rht_cast_fusion` instead of inlining the same
  predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
  replaced with `get_2d_dims(input.shape())`. The shape/arch
  eligibility now has a single source of truth.
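The combined effect of the functional fix and the flatten refactor can be sketched in Python. This assumes the flatten rule is "rows = product of leading dims, cols = last dim", matching the `rows = product(shape[:-1])` loop that `get_2d_dims` replaced. The function names mirror the C++ ones but are hypothetical, and the arch check is omitted.

```python
from math import prod


def get_2d_dims(shape):
    """Assumed flatten rule: rows = product of leading dims, cols = last dim."""
    return prod(shape[:-1]), shape[-1]


def is_eligible(shape, for_grouped_kernel=False):
    """Shape-only gate; the grouped kernel needs 128-row alignment, not 64."""
    rows, cols = get_2d_dims(shape)
    row_align = 128 if for_grouped_kernel else 64
    return rows % row_align == 0 and cols % 128 == 0


# rows = 2 * 96 = 192: passes the single-tensor gate but not the grouped one,
# which is exactly the case that used to abort inside the grouped kernel.
print(is_eligible((2, 96, 256)), is_eligible((2, 96, 256), for_grouped_kernel=True))
```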

Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
  `create_tensor`, `convert_and_update_tensor`, and
  `quantize_with_rht_unfused_helper`. Removed cross-references to
  other functions/files, code narration of subsequent lines, the JAX
  reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
  brief sentence.

Signed-off-by: Cael Ling <caell@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Sync _with_gemm_swizzled_scales onto Python NVFP4Tensor in update path

Signed-off-by: Cael Ling <caell@nvidia.com>

* Add license header to profile_rht_cast_swizzle_fusion.py

Signed-off-by: Cael Ling <caell@nvidia.com>

* [PyTorch] NVFP4 RHT cast-fusion: fix stale swizzle shape-gate test

The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks
that mainline NVIDIA#3076 split apart. _with_gemm_swizzled_scales only controls
WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM
swizzles lazily at call time; when True, the tensor is pre-swizzled and the
GEMM skips it. When this test landed, ineligible shapes (rows%64!=0 or
cols%128!=0) ended quantize with the flag False. NVIDIA#3076 then added a
post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles
ineligible shapes and flips the flag back to True, so under optimize_for_gemm
the end-to-end flag is now True for all shapes. The old False expectations
encoded pre-NVIDIA#3076 behavior and started failing CI on (64,144), (128,144),
(48,128).

Split into two self-consistent tests:
- shape_gate: probes make_empty() (runs create_tensor only -- no quantize,
  no fallback), so it observes the fused-kernel shape gate in isolation and
  keeps the original True/False eligibility table.
- end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes
  and must always yield _with_gemm_swizzled_scales=True (eligible via the
  fused cast-fusion kernel, ineligible via the NVIDIA#3076 swizzle fallback).
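A toy model makes the split easier to follow. Everything below is a hypothetical stand-in, not the Transformer Engine API. `make_empty` models `create_tensor` only, so its flag reflects the fused-kernel shape gate. `quantize` adds the post-quantize fallback, so under `optimize_for_gemm` the flag always ends up True. The with_rht and arch conditions are left out for brevity.

```python
from dataclasses import dataclass


def fused_kernel_eligible(rows: int, cols: int) -> bool:
    """Single-tensor shape gate for the RHT cast-fusion kernel."""
    return rows % 64 == 0 and cols % 128 == 0


@dataclass
class FakeNVFP4Tensor:
    with_gemm_swizzled_scales: bool


def make_empty(rows, cols, optimize_for_gemm=True):
    # Allocation only: no quantize, no fallback.
    return FakeNVFP4Tensor(optimize_for_gemm and fused_kernel_eligible(rows, cols))


def quantize(rows, cols, optimize_for_gemm=True):
    t = make_empty(rows, cols, optimize_for_gemm)
    if optimize_for_gemm and not t.with_gemm_swizzled_scales:
        # Models the eager post-quantize swizzle fallback: flag flips to True.
        t.with_gemm_swizzled_scales = True
    return t


# (64, 128) and (128, 256) are illustrative eligible shapes; the other three
# are the shapes that failed CI.
GATE_TABLE = {(64, 128): True, (128, 256): True,
              (64, 144): False, (128, 144): False, (48, 128): False}


def test_shape_gate():
    for shape, want in GATE_TABLE.items():
        assert make_empty(*shape).with_gemm_swizzled_scales is want


def test_end_to_end_swizzled():
    for shape in GATE_TABLE:
        assert quantize(*shape).with_gemm_swizzled_scales is True
```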

Signed-off-by: Cael Ling <caell@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: cael-ling <caell@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.
