Commit 867a706

[PyTorch] NVFP4 RHT cast-fusion: fix stale swizzle shape-gate test
The single test_nvfp4_rht_swizzle_fusion_shape_gate conflated two checks that mainline #3076 split apart.

_with_gemm_swizzled_scales only controls WHERE scale-factor swizzling happens, not WHETHER: when False, the GEMM swizzles lazily at call time; when True, the tensor is pre-swizzled and the GEMM skips it.

When this test landed, ineligible shapes (rows%64!=0 or cols%128!=0) ended quantize with the flag False. #3076 then added a post-quantize inplace_swizzle_scale_for_gemm fallback that eagerly swizzles ineligible shapes and flips the flag back to True, so under optimize_for_gemm the end-to-end flag is now True for all shapes. The old False expectations encoded pre-#3076 behavior and started failing CI on (64,144), (128,144), (48,128).

Split into two self-consistent tests:

- shape_gate: probes make_empty() (runs create_tensor only -- no quantize, no fallback), so it observes the fused-kernel shape gate in isolation and keeps the original True/False eligibility table.
- end_to_end_swizzled: quantizer(x) must never raise on ineligible shapes and must always yield _with_gemm_swizzled_scales=True (eligible via the fused cast-fusion kernel, ineligible via the #3076 swizzle fallback).

Signed-off-by: Cael Ling <caell@nvidia.com>
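The eligibility rule quoted in the message (rows%64==0 AND cols%128==0) can be checked by hand with plain arithmetic. The sketch below is only an illustration: `is_fused_eligible`, `SHAPES`, and `gate_table` are hypothetical names that do not exist in Transformer Engine, and only the modulo rule and the shapes are taken from this commit.

```python
# Toy reproduction of the shape-eligibility table from the commit message.
# Nothing here calls Transformer Engine; all names are hypothetical.

FUSED_ROW_MULTIPLE = 64    # rows must be a multiple of 64
FUSED_COL_MULTIPLE = 128   # cols must be a multiple of 128


def is_fused_eligible(rows: int, cols: int) -> bool:
    """Return True when a shape satisfies rows%64==0 AND cols%128==0."""
    return rows % FUSED_ROW_MULTIPLE == 0 and cols % FUSED_COL_MULTIPLE == 0


# The five parametrized test shapes plus the production shape cited in the diff.
SHAPES = [(64, 128), (128, 256), (64, 144), (128, 144), (48, 128), (8192, 11328)]

gate_table = {shape: is_fused_eligible(*shape) for shape in SHAPES}

for (rows, cols), eligible in gate_table.items():
    print(
        f"({rows}, {cols}): rows%64={rows % 64}, "
        f"cols%128={cols % 128}, eligible={eligible}"
    )
```

Under this rule the production shape (8192, 11328) fails on the column side only, since 11328 = 88 * 128 + 64.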
1 parent ebdd59a commit 867a706

1 file changed

Lines changed: 81 additions & 43 deletions

tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py

@@ -190,65 +190,103 @@ def test_nvfp4_rht_quantize_swizzle_fusion(
     )
 
 
+def _make_swizzle_fusion_quantizer() -> NVFP4Quantizer:
+    """RHT NVFP4 quantizer configured to request GEMM-swizzled SF output."""
+    quantizer = NVFP4Quantizer(
+        fp4_dtype=tex.DType.kFloat4E2M1,
+        rowwise=True,
+        columnwise=True,
+        with_amax_reduction=False,
+        amax_reduction_group=None,
+        with_rht=True,
+        with_post_rht_amax=True,
+        with_random_sign_mask=True,
+    )
+    quantizer.optimize_for_gemm = True
+    return quantizer
+
+
 @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
 @pytest.mark.parametrize(
     "M, N, expected_swizzled",
     [
-        # Eligible: rows%64==0 AND cols%128==0 -> framework keeps swizzled=True
-        # and dispatch lands on the RHT cast-fusion kernel.
+        # Eligible: rows%64==0 AND cols%128==0 -> the fused RHT cast-fusion
+        # kernel can bake the GEMM-swizzled SF in directly.
         (64, 128, True),
         (128, 256, True),
-        # Ineligible (cols%128 != 0) -> framework must clamp swizzled to False
-        # so the RHT-unfused fallback runs cleanly. This is the case hit in
-        # production at irregular shapes like (8192, 11328) where 11328%128==64.
+        # Ineligible (cols%128 != 0) -> the fused kernel is gated off, so
+        # create_tensor reports swizzled=False (the fused kernel did not bake
+        # it in). This is the case hit in production at irregular shapes like
+        # (8192, 11328) where 11328%128==64.
         (64, 144, False),
         (128, 144, False),
-        # Ineligible (rows%64 != 0) -> same clamping requirement.
+        # Ineligible (rows%64 != 0) -> same gating.
         (48, 128, False),
     ],
 )
 def test_nvfp4_rht_swizzle_fusion_shape_gate(M: int, N: int, expected_swizzled: bool) -> None:
-    """Framework gate must clamp ``with_gemm_swizzled_scales`` by shape eligibility.
-
-    Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can emit
-    GEMM-swizzled SF directly. When the input shape is ineligible for that
-    kernel (rows%64!=0 or cols%128!=0), dispatch falls back to
-    ``quantize_with_rht_unfused_helper`` whose backing kernels
-    (``nvte_quantize_v2`` and ``nvte_hadamard_transform``) cannot emit
-    swizzled SF. The framework gates in ``NVFP4Quantizer::create_tensor`` and
-    ``convert_and_update_tensor`` must therefore clamp the flag to False on
-    ineligible shapes -- otherwise the defense-in-depth ``NVTE_CHECK`` in the
-    unfused helper hard-aborts at user-facing code paths (regression observed
-    at irregular production shapes such as ``(8192, 11328)``).
-
-    The defense-in-depth ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper``
-    is kept as a safety net for direct low-level callers; this test
-    intentionally exercises the user-level entry point (``quantizer(x)``) and
-    asserts the framework gate makes the unfused-path crash unreachable from
-    user code.
+    """The fused-kernel shape gate must reflect cast-fusion eligibility.
+
+    Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can bake
+    GEMM-swizzled SF directly, and only for eligible shapes (rows%64==0 AND
+    cols%128==0 on SM 100/110). ``NVFP4Quantizer::create_tensor`` carries this
+    gate in ``_with_gemm_swizzled_scales``: it must be True only when the fused
+    kernel is selectable, and False on ineligible shapes so dispatch routes
+    them to the unfused path (which is what keeps the defense-in-depth
+    ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper`` from hard-aborting
+    on irregular production shapes such as ``(8192, 11328)``).
+
+    ``make_empty`` runs ``create_tensor`` only -- no quantize kernel, and in
+    particular no post-quantize ``inplace_swizzle_scale_for_gemm`` fallback --
+    so its ``_with_gemm_swizzled_scales`` observes the fused-kernel gate in
+    isolation. (End-to-end ``quantizer(x)`` differs: the fallback then swizzles
+    ineligible shapes too and flips the flag back to True; see
+    ``test_nvfp4_rht_swizzle_fusion_end_to_end_swizzled``.)
     """
-    te_dtype = tex.DType.kFloat4E2M1
-    device = "cuda"
+    quantizer = _make_swizzle_fusion_quantizer()
+    out = quantizer.make_empty([M, N], dtype=torch.bfloat16)
+    assert out._with_gemm_swizzled_scales is expected_swizzled, (
+        f"Fused-kernel shape gate expected _with_gemm_swizzled_scales={expected_swizzled} "
+        f"for shape ({M}, {N}) with optimize_for_gemm=True + with_rht=True, "
+        f"got {out._with_gemm_swizzled_scales}"
+    )
 
-    x = torch.randn((M, N), dtype=torch.bfloat16, device=device)
 
-    quantizer = NVFP4Quantizer(
-        fp4_dtype=te_dtype,
-        rowwise=True,
-        columnwise=True,
-        with_amax_reduction=False,
-        amax_reduction_group=None,
-        with_rht=True,
-        with_post_rht_amax=True,
-        with_random_sign_mask=True,
-    )
-    quantizer.optimize_for_gemm = True
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+    "M, N",
+    [
+        # Eligible shapes (fused cast-fusion kernel bakes swizzled SF in).
+        (64, 128),
+        (128, 256),
+        # Ineligible shapes (cols%128 != 0 / rows%64 != 0): the fused kernel is
+        # gated off, so they take the unfused quantize path and rely on the
+        # post-quantize swizzle fallback. (8192, 11328)-class production shapes.
+        (64, 144),
+        (128, 144),
+        (48, 128),
+    ],
+)
+def test_nvfp4_rht_swizzle_fusion_end_to_end_swizzled(M: int, N: int) -> None:
+    """End-to-end quantize must never crash and must always yield swizzled SF.
+
+    With ``optimize_for_gemm=True`` the contract is that the produced tensor
+    carries GEMM-swizzled SF regardless of shape, so the GEMM can skip its own
+    swizzle pass. Eligible shapes get this from the fused cast-fusion kernel in
+    one pass; ineligible shapes are gated off the fused kernel (avoiding the
+    ``NVTE_CHECK`` hard-abort in ``quantize_with_rht_unfused_helper``) and
+    instead get swizzled by the post-quantize ``inplace_swizzle_scale_for_gemm``
+    fallback added in mainline #3076. Either way ``quantizer(x)`` must not raise
+    and the final ``_with_gemm_swizzled_scales`` must be True.
+    """
+    quantizer = _make_swizzle_fusion_quantizer()
+    x = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
 
-    # Should not raise on ineligible shapes; the framework gate clamps the
-    # flag and the unfused fallback runs cleanly.
+    # Must not raise on ineligible shapes (gate keeps them off the fused
+    # kernel; the unfused path + swizzle fallback run cleanly).
     result = quantizer(x)
-    assert result._with_gemm_swizzled_scales is expected_swizzled, (
-        f"Framework shape gate expected _with_gemm_swizzled_scales={expected_swizzled} "
-        f"for shape ({M}, {N}) with optimize_for_gemm=True + with_rht=True, "
+    assert result._with_gemm_swizzled_scales is True, (
+        f"End-to-end quantize expected _with_gemm_swizzled_scales=True for shape "
+        f"({M}, {N}) with optimize_for_gemm=True + with_rht=True, "
         f"got {result._with_gemm_swizzled_scales}"
     )
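
The relationship between the two tests can be pictured with a small pure-Python model of the flag's lifecycle under optimize_for_gemm=True. This is a sketch of the behavior the commit message describes, not Transformer Engine code: `ModelTensor`, `model_make_empty`, `model_quantize`, and the `swizzled_by` field are hypothetical names invented for illustration, and the model assumes the post-#3076 fallback always runs on ineligible shapes.

```python
from dataclasses import dataclass


@dataclass
class ModelTensor:
    """Hypothetical stand-in for a quantized tensor; not a TE class."""
    rows: int
    cols: int
    with_gemm_swizzled_scales: bool
    swizzled_by: str  # "none", "fused_kernel", or "fallback"


def model_make_empty(rows: int, cols: int) -> ModelTensor:
    """Model of the allocation step: the flag mirrors the shape gate only."""
    eligible = rows % 64 == 0 and cols % 128 == 0
    return ModelTensor(rows, cols, eligible, "none")


def model_quantize(rows: int, cols: int) -> ModelTensor:
    """Model of end-to-end quantize, including the post-quantize fallback."""
    t = model_make_empty(rows, cols)
    if t.with_gemm_swizzled_scales:
        # Eligible shape: the fused kernel emits swizzled scales in one pass.
        t.swizzled_by = "fused_kernel"
    else:
        # Ineligible shape: the unfused path runs first, then the fallback
        # swizzles the scales and flips the flag back to True.
        t.with_gemm_swizzled_scales = True
        t.swizzled_by = "fallback"
    return t


SHAPES = [(64, 128), (128, 256), (64, 144), (128, 144), (48, 128)]

for m, n in SHAPES:
    gate = model_make_empty(m, n).with_gemm_swizzled_scales
    final = model_quantize(m, n)
    print(f"({m}, {n}): gate={gate}, end_to_end={final.with_gemm_swizzled_scales}, "
          f"via={final.swizzled_by}")
```

In this model the gate column varies by shape while the end-to-end column is always True, which is why a single expectation table could not serve both checks.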