Skip to content

Commit 9d24da4

Browse files
committed
Merge branch 'main' into bl/update_view_groups
2 parents e0c385b + 4d3a3c3 commit 9d24da4

File tree

6 files changed

+315
-156
lines changed

6 files changed

+315
-156
lines changed

thunder/benchmarks/benchmark_inference.py

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import json
1818
import os
1919
import statistics
20-
import sys
2120
import time
2221
import warnings
2322
from typing import Any
@@ -40,16 +39,17 @@
4039
import thunder
4140
from thunder.dynamo.compiler import thunderfx
4241
from thunder.benchmarks.layers_for_inference_benchmark import (
43-
GroupedLinear,
42+
GroupedSwiGLU,
4443
Llama4MoE,
45-
NVFP4InferenceGroupedLinear,
46-
NVFP4InferenceLinear,
44+
NVFP4InferenceGroupedSwiGLU,
4745
nvfuser_f16a_nvfp4weight_scaled_grouped_mm,
48-
nvfuser_f16a_nvfp4weight_scaled_mm,
46+
FLOAT4_E2M1_MAX,
47+
FLOAT8_E4M3_EPS,
48+
FLOAT8_E4M3_MAX,
4949
)
50-
from thunder.torch.custom_op import _register_custom_op
5150
from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
5251
from thunder.transforms.cudagraph import CUDAGraphTransform
52+
from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator
5353

5454
if TYPE_CHECKING:
5555
from typing import Any
@@ -73,6 +73,71 @@
7373
LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E"
7474

7575

76+
# TODO: Add mm quantization once nvfuser implements nvfp4 gemm
77+
# Register nvfp4 custom ops with Thunder and nvFuser
78+
def _register_nvfp4_ops():
79+
"""Register nvfp4 custom operations with Thunder."""
80+
# Register f16a_nvfp4weight_scaled_grouped_mm as a Thunder custom op. The
# returned symbol is paired with the nvFuser translator defined below.
81+
_nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm)
82+
83+
# Translator for the grouped mm symbol: it quantizes the activation in blocks
# of 16 elements along its last dimension and then emits
# fd.ops.cutlass_nvfp4_grouped_mm. `fd` and `lc_to_nv_map` are keyword-only
# and are only forwarded to getnv / used to emit ops here.
# NOTE(review): `fd` is presumably the nvFuser fusion definition and
# `lc_to_nv_map` the Thunder-to-nvFuser value map -- confirm against the
# nvFuser executor.
def nvfp4_grouped_mm_translator(
84+
activation,
85+
fp4_weight,
86+
weight_scaling_factor,
87+
global_scale,
88+
offsets,
89+
blockscale_offsets,
90+
problem_sizes,
91+
*,
92+
fd,
93+
lc_to_nv_map,
94+
):
95+
# Function-local imports; presumably so nvfuser is only needed when the
# translator actually runs -- TODO confirm.
from nvfuser_direct import DataType
96+
from thunder.executors.nvfuserex_impl import getnv
97+
98+
# Resolve each argument through getnv (presumably to its nvFuser-side value).
nv_act = getnv(activation, fd, lc_to_nv_map)
99+
nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map)
100+
nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map)
101+
nv_alpha = getnv(global_scale, fd, lc_to_nv_map)
102+
nv_offsets = getnv(offsets, fd, lc_to_nv_map)
103+
nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map)
104+
nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map)
105+
# Dynamic shape support has some concretization issues; the sizes below are
# read directly from activation.shape (activation is indexed as 2-D here).
106+
m_size = activation.shape[0]
107+
k_size = activation.shape[1]
108+
# Number of 16-element blocks per row. Integer division: assumes k_size is a
# multiple of 16, otherwise the reshape below cannot match -- TODO confirm.
k_tile_size = k_size // 16
109+
110+
# Per-block scale: max(|x|) over each 16-element block, divided by
# FLOAT4_E2M1_MAX, then clamped to [FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX].
reshaped_mat1 = fd.ops.reshape(nv_act, [m_size, k_tile_size, 16])
111+
scale1 = fd.ops.abs(reshaped_mat1)
112+
scale1 = fd.ops.max(scale1, 2)
113+
scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX)
114+
scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX)
115+
116+
# Divide every block by its scale (scale broadcast along the last dimension)
# and clamp the result to [-FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX].
broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True])
117+
reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1)
118+
reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX)
119+
120+
# Flatten back to [m_size, k_size]; cast the scaled values to Float4_e2m1fn
# and the per-block scales to Float8_e4m3fn.
scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size])
121+
fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn)
122+
fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn)
123+
# NOTE(review): presumably rearranges the activation scale factors into the
# layout the grouped-mm kernel expects -- confirm against nvFuser docs.
layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets)
124+
out = fd.ops.cutlass_nvfp4_grouped_mm(
125+
fp4_mat1,
126+
nv_fp4_w,
127+
layout_fp8_scale1,
128+
nv_sf_w,
129+
nv_alpha,
130+
# NOTE: we might need to call contiguous on problem_sizes
131+
nv_problem_sizes,
132+
nv_offsets,
133+
nv_blocksf_offsets,
134+
# Presumably the output dtype of the grouped mm -- TODO confirm.
DataType.BFloat16,
135+
)
136+
return out
137+
138+
# Attach the translator to the symbol returned by _register_custom_op above.
_register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator)
139+
140+
76141
# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
77142
def _replace_with_custom_fn_if_matches_filter_with_name(
78143
model,
@@ -117,16 +182,18 @@ def _replace_llama4_moe(model: nn.Module) -> None:
117182

118183

119184
def _quantize_llama4(model: nn.Module) -> None:
120-
"""Replace linear and moe with nvfp4 inference version."""
121-
_replace_with_custom_fn_if_matches_filter_with_name(
122-
model,
123-
NVFP4InferenceLinear.from_linear,
124-
lambda model, cur_fqn: isinstance(model, nn.Linear),
125-
)
185+
"""Replace GroupedSwiGLU modules in the model with their nvfp4 inference version.
186+
187+
Args:
188+
model: The model to quantize
189+
190+
Note: GroupedSwiGLU is always quantized when this function is called.
191+
"""
192+
# Always quantize GroupedSwiGLU when this function is called
126193
_replace_with_custom_fn_if_matches_filter_with_name(
127194
model,
128-
NVFP4InferenceGroupedLinear.from_grouped_linear,
129-
lambda model, cur_fqn: isinstance(model, GroupedLinear),
195+
NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu,
196+
lambda model, cur_fqn: isinstance(model, GroupedSwiGLU),
130197
)
131198

132199

@@ -150,7 +217,7 @@ class InferenceBenchmarkConfig:
150217
num_layers: int | None
151218
num_iterations: int
152219
warmup_iterations: int
153-
enable_nvfp4: bool # Enable NVFP4 quantization
220+
enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE
154221
fx_report_folder: str | None
155222
enable_nv_linear: bool
156223
mode: str
@@ -670,7 +737,11 @@ def parse_args() -> argparse.Namespace:
670737
help="Specify the folder for thunderfx_benchmark_report.",
671738
)
672739

673-
parser.add_argument("--enable-nvfp4", action="store_true", help="Enable NVFP4 quantization for linear layers")
740+
parser.add_argument(
741+
"--enable-nvfp4",
742+
action="store_true",
743+
help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)",
744+
)
674745
parser.add_argument(
675746
"--enable-nv-linear",
676747
action="store_true",
@@ -682,6 +753,11 @@ def parse_args() -> argparse.Namespace:
682753
help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ... --profile` to record only the non-warmup iterations.",
683754
)
684755

756+
parser.add_argument(
757+
"--thunder-trace",
758+
action="store_true",
759+
help="Enable debug dump of thunder trace",
760+
)
685761
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
686762
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
687763
parser.add_argument(
@@ -702,13 +778,13 @@ def main():
702778
if args.save_results:
703779
os.makedirs(args.output_dir, exist_ok=True)
704780

705-
# TODO: Override the forward with nvfuser_direct based implementation like
706-
# https://github.com/Lightning-AI/lightning-thunder/blob/8b72715d/thunder/tests/test_torch_library_custom_op.py#L250-L266 does.
707-
# Note that the linked code is in a draft pull request of https://github.com/Lightning-AI/lightning-thunder/pull/2481
708-
# so we might want to do it more clumsily by copying the code in the pull request for now.
781+
# Register NVFP4 custom ops with nvfuser translators when enabled
709782
if args.enable_nvfp4:
710-
sym_of_nvfp4_scaled_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) # noqa: F841
711-
sym_of_nvfp4_scaled_grouped_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) # noqa: F841
783+
try:
784+
_register_nvfp4_ops()
785+
except Exception as e:
786+
# If registration fails (e.g., nvfuser not available), warn and continue
787+
warnings.warn(f"Failed to register nvfp4 custom ops: {e}")
712788

713789
config = InferenceBenchmarkConfig(
714790
model_name=args.model_name,
@@ -730,13 +806,17 @@ def main():
730806
)
731807
benchmark = InferenceBenchmark(config)
732808

733-
if args.enable_nvfp4:
734-
msg = "NVFP4 kernels are not yet available. `--enable-nvfp4` runs only quantization but not benchmark"
735-
warnings.warn(msg)
736-
sys.exit(0)
737-
738809
benchmark.run_benchmark()
739810
benchmark.print_results()
811+
812+
if args.thunder_trace and args.mode == "thunder":
813+
backend = benchmark.model._backend
814+
for subgraph_info in backend.subgraph_infos:
815+
assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
816+
assert len(subgraph_info.thunder_compiled_fns)
817+
for thunder_fn in subgraph_info.thunder_compiled_fns:
818+
print(thunder.last_traces(thunder_fn)[-1])
819+
740820
if args.save_results:
741821
timestamp = time.strftime("%Y%m%d_%H%M%S")
742822
filename = f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json"

0 commit comments

Comments
 (0)