import json
import os
import statistics
-import sys
import time
import warnings
from typing import Any

import thunder
from thunder.dynamo.compiler import thunderfx
from thunder.benchmarks.layers_for_inference_benchmark import (
-    GroupedLinear,
+    GroupedSwiGLU,
    Llama4MoE,
-    NVFP4InferenceGroupedLinear,
-    NVFP4InferenceLinear,
+    NVFP4InferenceGroupedSwiGLU,
    nvfuser_f16a_nvfp4weight_scaled_grouped_mm,
-    nvfuser_f16a_nvfp4weight_scaled_mm,
+    FLOAT4_E2M1_MAX,
+    FLOAT8_E4M3_EPS,
+    FLOAT8_E4M3_MAX,
)
-from thunder.torch.custom_op import _register_custom_op
from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
from thunder.transforms.cudagraph import CUDAGraphTransform
+from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator

if TYPE_CHECKING:
    from typing import Any

LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E"


+# TODO: Add mm quantization once nvfuser implements nvfp4 gemm
+# Register nvfp4 custom ops with Thunder and nvFuser
+def _register_nvfp4_ops():
+    """Register nvfp4 custom operations with Thunder."""
+    # Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator
+    _nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm)
+
+    def nvfp4_grouped_mm_translator(
+        activation,
+        fp4_weight,
+        weight_scaling_factor,
+        global_scale,
+        offsets,
+        blockscale_offsets,
+        problem_sizes,
+        *,
+        fd,
+        lc_to_nv_map,
+    ):
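+        # Build the nvFuser definition for the custom op: quantize the high-precision
+        # activation to NVFP4 on the fly, then run a cutlass grouped GEMM against the
+        # pre-quantized FP4 expert weights.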
+        from nvfuser_direct import DataType
+        from thunder.executors.nvfuserex_impl import getnv
+
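+        # Map the Thunder proxies onto the tensors already registered in this FusionDefinition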
+        nv_act = getnv(activation, fd, lc_to_nv_map)
+        nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map)
+        nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map)
+        nv_alpha = getnv(global_scale, fd, lc_to_nv_map)
+        nv_offsets = getnv(offsets, fd, lc_to_nv_map)
+        nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map)
+        nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map)
+        # Dynamic shape support still has a concretization issue, so use static sizes here
+        m_size = activation.shape[0]
+        k_size = activation.shape[1]
+        k_tile_size = k_size // 16
+
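+        # Block-quantize the activation: view it as [m, k/16, 16], take the per-16-element
+        # block amax, scale it into FP4 range (amax / FLOAT4_E2M1_MAX), and clamp the block
+        # scales to the FP8 E4M3 representable range.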
+        reshaped_mat1 = fd.ops.reshape(nv_act, [m_size, k_tile_size, 16])
+        scale1 = fd.ops.abs(reshaped_mat1)
+        scale1 = fd.ops.max(scale1, 2)
+        scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX)
+        scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX)
+
+        broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True])
+        reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1)
+        reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX)
+
+        scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size])
+        fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn)
+        fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn)
+        layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets)
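+        # Grouped GEMM over the FP4 activation and FP4 weights with per-block FP8 scales,
+        # a global scale (alpha), and per-expert extents from problem_sizes/offsets; output is BF16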
+        out = fd.ops.cutlass_nvfp4_grouped_mm(
+            fp4_mat1,
+            nv_fp4_w,
+            layout_fp8_scale1,
+            nv_sf_w,
+            nv_alpha,
+            # NOTE: we might need to call contiguous on problem_sizes
+            nv_problem_sizes,
+            nv_offsets,
+            nv_blocksf_offsets,
+            DataType.BFloat16,
+        )
+        return out
+
+    _register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator)
+
+
# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
def _replace_with_custom_fn_if_matches_filter_with_name(
    model,
@@ -117,16 +182,18 @@ def _replace_llama4_moe(model: nn.Module) -> None:


def _quantize_llama4(model: nn.Module) -> None:
-    """Replace linear and moe with nvfp4 inference version."""
-    _replace_with_custom_fn_if_matches_filter_with_name(
-        model,
-        NVFP4InferenceLinear.from_linear,
-        lambda model, cur_fqn: isinstance(model, nn.Linear),
-    )
+    """Replace the MoE GroupedSwiGLU modules with their nvfp4 inference version.
+
+    Args:
+        model: The model to quantize
+
+    Note: GroupedSwiGLU is always quantized when this function is called.
+    """
+    # Always quantize GroupedSwiGLU when this function is called
    _replace_with_custom_fn_if_matches_filter_with_name(
        model,
-        NVFP4InferenceGroupedLinear.from_grouped_linear,
-        lambda model, cur_fqn: isinstance(model, GroupedLinear),
+        NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu,
+        lambda model, cur_fqn: isinstance(model, GroupedSwiGLU),
    )


@@ -150,7 +217,7 @@ class InferenceBenchmarkConfig:
    num_layers: int | None
    num_iterations: int
    warmup_iterations: int
-    enable_nvfp4: bool  # Enable NVFP4 quantization
+    enable_nvfp4: bool  # Register NVFP4 custom ops and quantize the MoE GroupedSwiGLU layers
    fx_report_folder: str | None
    enable_nv_linear: bool
    mode: str
@@ -670,7 +737,11 @@ def parse_args() -> argparse.Namespace:
        help="Specify the folder for thunderfx_benchmark_report.",
    )

-    parser.add_argument("--enable-nvfp4", action="store_true", help="Enable NVFP4 quantization for linear layers")
+    parser.add_argument(
+        "--enable-nvfp4",
+        action="store_true",
+        help="Enable NVFP4 quantization for the MoE GroupedSwiGLU layers (implemented via nvfuser grouped_mm)",
+    )
    parser.add_argument(
        "--enable-nv-linear",
        action="store_true",
@@ -682,6 +753,11 @@ def parse_args() -> argparse.Namespace:
        help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ... --profile` to record only the non-warmup iterations.",
    )

+    parser.add_argument(
+        "--thunder-trace",
+        action="store_true",
+        help="Print the last Thunder trace of each ThunderFX-compiled subgraph (thunder mode only)",
+    )
    parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
    parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
    parser.add_argument(
@@ -702,13 +778,13 @@ def main():
    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)

-    # TODO: Override the forward with nvfuser_direct based implementation like
-    # https://github.com/Lightning-AI/lightning-thunder/blob/8b72715d/thunder/tests/test_torch_library_custom_op.py#L250-L266 does.
-    # Note that the linked code is in a draft pull request of https://github.com/Lightning-AI/lightning-thunder/pull/2481
-    # so we might want to do it more clumsily by copying the code in the pull request for now.
+    # Register NVFP4 custom ops with nvfuser translators when enabled
    if args.enable_nvfp4:
-        sym_of_nvfp4_scaled_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm)  # noqa: F841
-        sym_of_nvfp4_scaled_grouped_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm)  # noqa: F841
+        try:
+            _register_nvfp4_ops()
+        except Exception as e:
+            # If registration fails (e.g., nvfuser not available), warn and continue
+            warnings.warn(f"Failed to register nvfp4 custom ops: {e}")

    config = InferenceBenchmarkConfig(
        model_name=args.model_name,
@@ -730,13 +806,17 @@ def main():
    )
    benchmark = InferenceBenchmark(config)

-    if args.enable_nvfp4:
-        msg = "NVFP4 kernels are not yet available. `--enable-nvfp4` runs only quantization but not benchmark"
-        warnings.warn(msg)
-        sys.exit(0)
-
    benchmark.run_benchmark()
    benchmark.print_results()
+
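+    # With --thunder-trace, print the final execution trace of every ThunderFX-compiled subgraph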
+    if args.thunder_trace and args.mode == "thunder":
+        backend = benchmark.model._backend
+        for subgraph_info in backend.subgraph_infos:
+            assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
+            assert len(subgraph_info.thunder_compiled_fns)
+            for thunder_fn in subgraph_info.thunder_compiled_fns:
+                print(thunder.last_traces(thunder_fn)[-1])
+
    if args.save_results:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json"