@@ -190,65 +190,103 @@ def test_nvfp4_rht_quantize_swizzle_fusion(
190190 )
191191
192192
193+ def _make_swizzle_fusion_quantizer () -> NVFP4Quantizer :
194+ """RHT NVFP4 quantizer configured to request GEMM-swizzled SF output."""
195+ quantizer = NVFP4Quantizer (
196+ fp4_dtype = tex .DType .kFloat4E2M1 ,
197+ rowwise = True ,
198+ columnwise = True ,
199+ with_amax_reduction = False ,
200+ amax_reduction_group = None ,
201+ with_rht = True ,
202+ with_post_rht_amax = True ,
203+ with_random_sign_mask = True ,
204+ )
205+ quantizer .optimize_for_gemm = True
206+ return quantizer
207+
208+
193209@pytest .mark .skipif (not recipe_available , reason = reason_for_no_recipe )
194210@pytest .mark .parametrize (
195211 "M, N, expected_swizzled" ,
196212 [
197- # Eligible: rows%64==0 AND cols%128==0 -> framework keeps swizzled=True
198- # and dispatch lands on the RHT cast-fusion kernel .
213+ # Eligible: rows%64==0 AND cols%128==0 -> the fused RHT cast-fusion
214+ # kernel can bake the GEMM-swizzled SF in directly .
199215 (64 , 128 , True ),
200216 (128 , 256 , True ),
201- # Ineligible (cols%128 != 0) -> framework must clamp swizzled to False
202- # so the RHT-unfused fallback runs cleanly. This is the case hit in
203- # production at irregular shapes like (8192, 11328) where 11328%128==64.
217+ # Ineligible (cols%128 != 0) -> the fused kernel is gated off, so
218+ # create_tensor reports swizzled=False (the fused kernel did not bake
219+ # it in). This is the case hit in production at irregular shapes like
220+ # (8192, 11328) where 11328%128==64.
204221 (64 , 144 , False ),
205222 (128 , 144 , False ),
206- # Ineligible (rows%64 != 0) -> same clamping requirement .
223+ # Ineligible (rows%64 != 0) -> same gating .
207224 (48 , 128 , False ),
208225 ],
209226)
210227def test_nvfp4_rht_swizzle_fusion_shape_gate (M : int , N : int , expected_swizzled : bool ) -> None :
211- """Framework gate must clamp ``with_gemm_swizzled_scales`` by shape eligibility.
212-
213- Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can emit
214- GEMM-swizzled SF directly. When the input shape is ineligible for that
215- kernel (rows%64!=0 or cols%128!=0), dispatch falls back to
216- ``quantize_with_rht_unfused_helper`` whose backing kernels
217- (``nvte_quantize_v2`` and ``nvte_hadamard_transform``) cannot emit
218- swizzled SF. The framework gates in ``NVFP4Quantizer::create_tensor`` and
219- ``convert_and_update_tensor`` must therefore clamp the flag to False on
220- ineligible shapes -- otherwise the defense-in-depth ``NVTE_CHECK`` in the
221- unfused helper hard-aborts at user-facing code paths (regression observed
222- at irregular production shapes such as ``(8192, 11328)``).
223-
224- The defense-in-depth ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper``
225- is kept as a safety net for direct low-level callers; this test
226- intentionally exercises the user-level entry point (``quantizer(x)``) and
227- asserts the framework gate makes the unfused-path crash unreachable from
228- user code.
228+ """The fused-kernel shape gate must reflect cast-fusion eligibility.
229+
230+ Only ``row_cast_col_hadamard_transform_cast_fusion.cu`` can bake
231+ GEMM-swizzled SF directly, and only for eligible shapes (rows%64==0 AND
232+ cols%128==0 on SM 100/110). ``NVFP4Quantizer::create_tensor`` carries this
233+ gate in ``_with_gemm_swizzled_scales``: it must be True only when the fused
234+ kernel is selectable, and False on ineligible shapes so dispatch routes
235+ them to the unfused path (which is what keeps the defense-in-depth
236+ ``NVTE_CHECK`` in ``quantize_with_rht_unfused_helper`` from hard-aborting
237+ on irregular production shapes such as ``(8192, 11328)``).
238+
239+ ``make_empty`` runs ``create_tensor`` only -- no quantize kernel, and in
240+ particular no post-quantize ``inplace_swizzle_scale_for_gemm`` fallback --
241+ so its ``_with_gemm_swizzled_scales`` observes the fused-kernel gate in
242+ isolation. (End-to-end ``quantizer(x)`` differs: the fallback then swizzles
243+ ineligible shapes too and flips the flag back to True; see
244+ ``test_nvfp4_rht_swizzle_fusion_end_to_end_swizzled``.)
229245 """
230- te_dtype = tex .DType .kFloat4E2M1
231- device = "cuda"
246+ quantizer = _make_swizzle_fusion_quantizer ()
247+ out = quantizer .make_empty ([M , N ], dtype = torch .bfloat16 )
248+ assert out ._with_gemm_swizzled_scales is expected_swizzled , (
249+ f"Fused-kernel shape gate expected _with_gemm_swizzled_scales={ expected_swizzled } "
250+ f"for shape ({ M } , { N } ) with optimize_for_gemm=True + with_rht=True, "
251+ f"got { out ._with_gemm_swizzled_scales } "
252+ )
232253
233- x = torch .randn ((M , N ), dtype = torch .bfloat16 , device = device )
234254
235- quantizer = NVFP4Quantizer (
236- fp4_dtype = te_dtype ,
237- rowwise = True ,
238- columnwise = True ,
239- with_amax_reduction = False ,
240- amax_reduction_group = None ,
241- with_rht = True ,
242- with_post_rht_amax = True ,
243- with_random_sign_mask = True ,
244- )
245- quantizer .optimize_for_gemm = True
255+ @pytest .mark .skipif (not recipe_available , reason = reason_for_no_recipe )
256+ @pytest .mark .parametrize (
257+ "M, N" ,
258+ [
259+ # Eligible shapes (fused cast-fusion kernel bakes swizzled SF in).
260+ (64 , 128 ),
261+ (128 , 256 ),
262+ # Ineligible shapes (cols%128 != 0 / rows%64 != 0): the fused kernel is
263+ # gated off, so they take the unfused quantize path and rely on the
264+ # post-quantize swizzle fallback. (8192, 11328)-class production shapes.
265+ (64 , 144 ),
266+ (128 , 144 ),
267+ (48 , 128 ),
268+ ],
269+ )
270+ def test_nvfp4_rht_swizzle_fusion_end_to_end_swizzled (M : int , N : int ) -> None :
271+ """End-to-end quantize must never crash and must always yield swizzled SF.
272+
273+ With ``optimize_for_gemm=True`` the contract is that the produced tensor
274+ carries GEMM-swizzled SF regardless of shape, so the GEMM can skip its own
275+ swizzle pass. Eligible shapes get this from the fused cast-fusion kernel in
276+ one pass; ineligible shapes are gated off the fused kernel (avoiding the
277+ ``NVTE_CHECK`` hard-abort in ``quantize_with_rht_unfused_helper``) and
278+ instead get swizzled by the post-quantize ``inplace_swizzle_scale_for_gemm``
279+ fallback added in mainline #3076. Either way ``quantizer(x)`` must not raise
280+ and the final ``_with_gemm_swizzled_scales`` must be True.
281+ """
282+ quantizer = _make_swizzle_fusion_quantizer ()
283+ x = torch .randn ((M , N ), dtype = torch .bfloat16 , device = "cuda" )
246284
247- # Should not raise on ineligible shapes; the framework gate clamps the
248- # flag and the unfused fallback runs cleanly.
285+ # Must not raise on ineligible shapes (gate keeps them off the fused
286+ # kernel; the unfused path + swizzle fallback run cleanly) .
249287 result = quantizer (x )
250- assert result ._with_gemm_swizzled_scales is expected_swizzled , (
251- f"Framework shape gate expected _with_gemm_swizzled_scales={ expected_swizzled } "
252- f"for shape ({ M } , { N } ) with optimize_for_gemm=True + with_rht=True, "
288+ assert result ._with_gemm_swizzled_scales is True , (
289+ f"End-to-end quantize expected _with_gemm_swizzled_scales=True for shape "
290+ f"({ M } , { N } ) with optimize_for_gemm=True + with_rht=True, "
253291 f"got { result ._with_gemm_swizzled_scales } "
254292 )
0 commit comments