Skip to content

Commit 1316c1e

Browse files
authored
Merge branch 'main' into expand_wheel_builds
2 parents 123f778 + 1b12177 commit 1316c1e

54 files changed

Lines changed: 4015 additions & 456 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""Benchmark MXFP8 graph-safe grouped MLP.
6+
7+
This mirrors ``benchmark_grouped_linear.py`` but targets the graph-safe TE ops
8+
path used by grouped MLP:
9+
10+
GroupedLinear -> ScaledSwiGLU -> GroupedLinear
11+
12+
The benchmark intentionally uses CUDA-device ``m_splits`` and MXFP8 only.
13+
14+
Example:
15+
16+
python benchmarks/linear/benchmark_graph_safe_grouped_linear.py
17+
18+
Forward-only:
19+
20+
python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --fwd-only
21+
22+
Nsight Systems:
23+
24+
(optionally: unset DEBUGINFOD_URLS)
25+
26+
nsys profile \
27+
--output=./benchmarks/linear/graph_safe_grouped_linear_mxfp8 \
28+
--force-overwrite true \
29+
--trace=cuda,nvtx,cudnn,cublas \
30+
python benchmarks/linear/benchmark_graph_safe_grouped_linear.py --profile
31+
"""
32+
33+
# Match the Qwen MXFP8 SFT launch toggles before importing TE.
34+
import os
35+
36+
# Launch toggles must be in place before Transformer Engine is imported.
# ``setdefault`` leaves any value the caller already exported untouched.
for _toggle_name in (
    "CUDA_DEVICE_MAX_CONNECTIONS",
    "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
    "NVTE_CUTEDSL_FUSED_GROUPED_MLP",
    "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL",
):
    os.environ.setdefault(_toggle_name, "1")
40+
41+
import argparse
42+
from contextlib import nullcontext
43+
44+
import pandas as pd
45+
import torch
46+
import torch.utils.benchmark as benchmark
47+
48+
import transformer_engine.pytorch as te
49+
import transformer_engine.pytorch.ops as te_ops
50+
from transformer_engine.common.recipe import MXFP8BlockScaling
51+
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
52+
53+
54+
# Probe MXFP8 support once at import time; main() raises with the reported
# reason when the flag is false.
MXFP8_AVAILABLE, REASON_FOR_NO_MXFP8 = FP8GlobalStateManager.is_mxfp8_available()
55+
56+
57+
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string of integers.

    Whitespace around each entry is ignored, and empty entries (for example
    from a trailing comma, or the blank tail of ``"8, 16, "``) are skipped.

    Args:
        value: Text such as ``"8,16"`` or ``"8, 16"``.

    Returns:
        The parsed integers, in their original order.

    Raises:
        ValueError: If a non-empty entry is not a valid integer literal.
    """
    # Strip before filtering so whitespace-only entries are dropped instead of
    # reaching ``int()`` and raising.
    entries = (entry.strip() for entry in value.split(","))
    return [int(entry) for entry in entries if entry]
60+
61+
62+
def make_uniform_splits(total_tokens: int, num_groups: int) -> list[int]:
    """Split tokens uniformly across groups.

    Args:
        total_tokens: Total number of tokens to distribute.
        num_groups: Number of groups; must be a positive integer.

    Returns:
        A list of ``num_groups`` equal entries that sum to ``total_tokens``.

    Raises:
        ValueError: If ``num_groups`` is not positive, or if ``total_tokens``
            is not divisible by ``num_groups``.
    """
    # Reject zero/negative group counts up front: zero would otherwise surface
    # as ZeroDivisionError and a negative count would silently yield [].
    if num_groups <= 0:
        raise ValueError(f"num_groups must be positive, got num_groups={num_groups}")
    if total_tokens % num_groups != 0:
        raise ValueError(
            "Uniform split requires total_tokens divisible by num_groups, "
            f"got total_tokens={total_tokens}, num_groups={num_groups}"
        )
    return [total_tokens // num_groups] * num_groups
70+
71+
72+
def build_grouped_mlp(
    *,
    num_groups: int,
    hidden_dim: int,
    ffn_hidden_dim: int,
    dtype: torch.dtype,
    single_grouped_weight: bool,
    accumulate_into_main_grad: bool,
    glu_interleave_size: int,
) -> te_ops.Sequential:
    """Assemble the ``GroupedLinear -> ScaledSwiGLU -> GroupedLinear`` op sequence.

    Both grouped linears are bias-free, live on the CUDA device, and are
    constructed inside ``te.quantized_model_init`` with an MXFP8 block-scaling
    recipe. The first maps ``hidden_dim`` to ``2 * ffn_hidden_dim``; the second
    maps ``ffn_hidden_dim`` back to ``hidden_dim``.

    Returns:
        A ``te_ops.Sequential`` holding the three ops in execution order.
    """
    # Keyword arguments that are identical for both grouped linears.
    shared_kwargs = dict(
        bias=False,
        device="cuda",
        dtype=dtype,
        single_grouped_weight=single_grouped_weight,
        accumulate_into_main_grad=accumulate_into_main_grad,
    )

    def grouped_linear(in_features: int, out_features: int):
        return te_ops.GroupedLinear(num_groups, in_features, out_features, **shared_kwargs)

    with te.quantized_model_init(enabled=True, recipe=MXFP8BlockScaling()):
        up_proj = grouped_linear(hidden_dim, 2 * ffn_hidden_dim)
        down_proj = grouped_linear(ffn_hidden_dim, hidden_dim)

    activation = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
    return te_ops.Sequential(up_proj, activation, down_proj)
110+
111+
112+
def init_main_grads(module: torch.nn.Module, value: float = 0.0) -> None:
    """Give every parameter an FP32 ``main_grad`` buffer filled with ``value``.

    A buffer is allocated (same shape and device as the parameter) only when
    the parameter has no ``main_grad`` yet; existing buffers are reused and
    overwritten in place.
    """
    with torch.no_grad():
        for weight in module.parameters():
            buffer = getattr(weight, "main_grad", None)
            if buffer is None:
                buffer = torch.empty(
                    weight.size(), device=weight.device, dtype=torch.float32
                )
                weight.main_grad = buffer
            # Reset the contents whether the buffer is new or reused.
            buffer.fill_(value)
121+
122+
123+
def zero_grads(module: torch.nn.Module, x: torch.Tensor, scales: torch.Tensor) -> None:
    """Drop ``.grad`` on the module parameters and on both input tensors.

    Only ``.grad`` attributes are cleared; any ``main_grad`` buffers attached
    to the parameters are left alone.
    """
    module.zero_grad(set_to_none=True)
    for tensor in (x, scales):
        tensor.grad = None
128+
129+
130+
def run_grouped_mlp_steps(
    module: torch.nn.Module,
    x: torch.Tensor,
    split_sizes: torch.Tensor,
    scales: torch.Tensor,
    grad_output: torch.Tensor,
    *,
    recipe: MXFP8BlockScaling,
    fwd_only: bool,
    num_steps: int,
    accumulate_into_main_grad: bool,
) -> torch.Tensor:
    """Run eager grouped MLP for a number of synthetic microbatches.

    Args:
        module: The grouped MLP to execute.
        x: Input activations.
        split_sizes: Per-group token counts; passed to the module twice.
        scales: Per-token scales passed to the module.
        grad_output: Gradient fed to ``backward`` on every step.
        recipe: Quantization recipe used for ``te.autocast``.
        fwd_only: Run forward passes under ``torch.no_grad`` and skip backward.
        num_steps: Number of microbatches to run; must be at least 1.
        accumulate_into_main_grad: Re-initialize ``main_grad`` buffers before
            the forward/backward steps.

    Returns:
        The output of the last step.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    # With zero steps there is no output to return; fail with a clear message
    # instead of an UnboundLocalError at the ``return``.
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got num_steps={num_steps}")

    quantization_context = te.autocast(enabled=True, recipe=recipe)

    # NOTE(review): the extra positional inputs (split_sizes, scales,
    # split_sizes) are presumably consumed by fc1, ScaledSwiGLU and fc2 in
    # that order -- confirm against the TE ops API.
    if fwd_only:
        with torch.no_grad(), quantization_context:
            for _ in range(num_steps):
                out = module(x, split_sizes, scales, split_sizes)
        return out

    zero_grads(module, x, scales)
    if accumulate_into_main_grad:
        init_main_grads(module)

    with quantization_context:
        for step in range(num_steps):
            # Context-manager form pops the NVTX range even if the step raises.
            with torch.cuda.nvtx.range(f"step_{step}"):
                out = module(x, split_sizes, scales, split_sizes)
                out.backward(grad_output)
    return out
162+
163+
164+
def benchmark_case(
    *,
    total_tokens: int,
    hidden_dim: int,
    ffn_hidden_dim: int,
    num_groups: int,
    dtype: torch.dtype,
    fwd_only: bool,
    single_grouped_weight: bool,
    accumulate_into_main_grad: bool,
    glu_interleave_size: int,
    num_microbatches: int,
    min_run_time: float,
    profile: bool,
) -> float:
    """Benchmark one grouped MLP shape.

    Builds the inputs and the module on the CUDA device, runs a 128-step
    warmup, prints which fused ops were selected, and then times
    ``num_microbatches`` steps per measurement with ``torch.utils.benchmark``.

    Args:
        total_tokens: Total token count, split uniformly across groups.
        hidden_dim: Model hidden size.
        ffn_hidden_dim: Feed-forward hidden size.
        num_groups: Number of groups in each grouped linear.
        dtype: Dtype of activations, scales and parameters.
        fwd_only: Time forward passes only.
        single_grouped_weight: Forwarded to ``build_grouped_mlp``.
        accumulate_into_main_grad: Forwarded to ``build_grouped_mlp`` and
            ``run_grouped_mlp_steps``.
        glu_interleave_size: Forwarded to ``build_grouped_mlp``.
        num_microbatches: Steps per timed call; must be at least 1.
        min_run_time: Minimum measurement time passed to ``blocked_autorange``.
        profile: Wrap the timed region in ``emit_nvtx``.

    Returns:
        Median time per microbatch in milliseconds.

    Raises:
        ValueError: If ``num_microbatches`` is smaller than 1, or if the tokens
            cannot be split uniformly across the groups.
    """
    # Validate before allocating tensors or warming up: a non-positive count
    # would otherwise only fail after the warmup, inside the timed statement.
    if num_microbatches < 1:
        raise ValueError(
            f"num_microbatches must be at least 1, got num_microbatches={num_microbatches}"
        )

    split_sizes_list = make_uniform_splits(total_tokens, num_groups)
    split_sizes = torch.tensor(split_sizes_list, dtype=torch.int64, device="cuda")
    x = torch.randn(
        (total_tokens, hidden_dim),
        dtype=dtype,
        device="cuda",
        requires_grad=not fwd_only,
    )
    scales = torch.ones(
        (total_tokens,),
        dtype=dtype,
        device="cuda",
        requires_grad=not fwd_only,
    )
    grad_output = torch.ones((total_tokens, hidden_dim), dtype=dtype, device="cuda")

    module = build_grouped_mlp(
        num_groups=num_groups,
        hidden_dim=hidden_dim,
        ffn_hidden_dim=ffn_hidden_dim,
        dtype=dtype,
        single_grouped_weight=single_grouped_weight,
        accumulate_into_main_grad=accumulate_into_main_grad,
        glu_interleave_size=glu_interleave_size,
    )
    recipe = MXFP8BlockScaling()

    print(
        "case:",
        f"tokens={total_tokens}",
        f"hidden={hidden_dim}",
        f"ffn_hidden={ffn_hidden_dim}",
        f"num_groups={num_groups}",
        f"fwd_only={fwd_only}",
        f"single_grouped_weight={single_grouped_weight}",
        f"accumulate_into_main_grad={accumulate_into_main_grad}",
        f"glu_interleave_size={glu_interleave_size}",
    )
    print(f"m_splits: {split_sizes_list}")

    # Warmup also forces the op-fuser to materialize the expected fused ops.
    run_grouped_mlp_steps(
        module,
        x,
        split_sizes,
        scales,
        grad_output,
        recipe=recipe,
        fwd_only=fwd_only,
        num_steps=128,
        accumulate_into_main_grad=accumulate_into_main_grad,
    )
    torch.cuda.synchronize()

    # NOTE(review): reads private attributes of the TE ops container; this may
    # break if the library's internals change.
    forward_ops = module._module_groups[0]._forward_ops
    print("forward fused op:", type(forward_ops[0][0]).__name__ if forward_ops else "none")
    if not fwd_only:
        backward_ops = module._module_groups[0]._backward_ops
        print("backward fused op:", type(backward_ops[0][0]).__name__ if backward_ops else "none")

    label = "graph_safe_grouped_mlp_mxfp8_swiglu"
    timing_context = (
        torch.autograd.profiler.emit_nvtx(record_shapes=True) if profile else nullcontext()
    )
    # The NVTX range is a context manager so it is closed even if timing raises.
    with timing_context, torch.cuda.nvtx.range(label):
        timing = benchmark.Timer(
            stmt=(
                "run_grouped_mlp_steps("
                "module, x, split_sizes, scales, grad_output, "
                "recipe=recipe, fwd_only=fwd_only, num_steps=num_microbatches, "
                "accumulate_into_main_grad=accumulate_into_main_grad)"
            ),
            globals={
                "run_grouped_mlp_steps": run_grouped_mlp_steps,
                "module": module,
                "x": x,
                "split_sizes": split_sizes,
                "scales": scales,
                "grad_output": grad_output,
                "recipe": recipe,
                "fwd_only": fwd_only,
                "num_microbatches": num_microbatches,
                "accumulate_into_main_grad": accumulate_into_main_grad,
            },
            num_threads=1,
        ).blocked_autorange(min_run_time=min_run_time)

    print(f"mxfp8_swiglu: {timing}\n")
    # ``timing.median`` covers one timed call of ``num_microbatches`` steps;
    # convert to milliseconds per microbatch.
    return timing.median * 1000 / num_microbatches
271+
272+
273+
def main() -> None:
    """Parse CLI options, benchmark every requested case, and print a summary table.

    One case is run per ``(num_groups, total_tokens)`` pair. Invalid option
    values are rejected through ``parser.error`` before any benchmark starts.

    Raises:
        RuntimeError: If MXFP8 or CUDA is not available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable NVTX profiling annotations")
    parser.add_argument(
        "--fwd-only",
        action="store_true",
        default=False,
        help="Benchmark forward only. Default benchmarks forward + backward.",
    )
    parser.add_argument(
        "--num-groups",
        type=str,
        default="8",
        help="Comma-separated local grouped GEMM/expert counts.",
    )
    parser.add_argument(
        "--token-dims",
        type=str,
        default="65536",
        help="Comma-separated total token counts to benchmark.",
    )
    parser.add_argument("--hidden-dim", type=int, default=7168)
    parser.add_argument("--ffn-hidden-dim", type=int, default=2048)
    parser.add_argument("--num-microbatches", type=int, default=32)
    parser.add_argument("--min-run-time", type=float, default=10.0)
    parser.add_argument("--glu-interleave-size", type=int, default=32)
    parser.add_argument(
        "--single-grouped-weight",
        action="store_true",
        default=False,
        help="Use one GroupedTensor parameter for each grouped linear.",
    )
    args = parser.parse_args()

    token_dims = parse_int_list(args.token_dims)
    num_groups_list = parse_int_list(args.num_groups)

    # Reject unusable values now: an empty list would silently benchmark
    # nothing, and a non-positive microbatch count only fails deep inside the
    # timed loop.
    if not token_dims:
        parser.error("--token-dims must list at least one token count")
    if not num_groups_list:
        parser.error("--num-groups must list at least one group count")
    if args.num_microbatches < 1:
        parser.error("--num-microbatches must be at least 1")

    if not MXFP8_AVAILABLE:
        raise RuntimeError(f"MXFP8 is not available: {REASON_FOR_NO_MXFP8}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")

    # Fixed for this benchmark; not exposed as CLI options.
    dtype = torch.bfloat16
    accumulate_into_main_grad = True

    print("Environment toggles:")
    for name in (
        "CUDA_DEVICE_MAX_CONNECTIONS",
        "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
        "NVTE_CUTEDSL_FUSED_GROUPED_MLP",
        "CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL",
    ):
        print(f"  {name}={os.environ.get(name)}")
    print("Recipe: MXFP8BlockScaling")
    print("Activation: ScaledSwiGLU")
    print(f"Default GLU interleave size: {args.glu_interleave_size}")
    print()

    data = []
    for num_groups in num_groups_list:
        for total_tokens in token_dims:
            timing_ms = benchmark_case(
                total_tokens=total_tokens,
                hidden_dim=args.hidden_dim,
                ffn_hidden_dim=args.ffn_hidden_dim,
                num_groups=num_groups,
                dtype=dtype,
                fwd_only=args.fwd_only,
                single_grouped_weight=args.single_grouped_weight,
                accumulate_into_main_grad=accumulate_into_main_grad,
                glu_interleave_size=args.glu_interleave_size,
                num_microbatches=args.num_microbatches,
                min_run_time=args.min_run_time,
                profile=args.profile,
            )
            data.append(
                [
                    total_tokens,
                    args.hidden_dim,
                    args.ffn_hidden_dim,
                    num_groups,
                    args.glu_interleave_size,
                    args.single_grouped_weight,
                    accumulate_into_main_grad,
                    "fwd" if args.fwd_only else "fwd_bwd",
                    timing_ms,
                ]
            )

    timing_col = "time_per_microbatch_ms"
    df = pd.DataFrame(
        data=data,
        columns=[
            "tokens",
            "hidden_dim",
            "ffn_hidden_dim",
            "num_groups",
            "glu_interleave_size",
            "single_grouped_weight",
            "accumulate_into_main_grad",
            "mode",
            timing_col,
        ],
    )
    print(df)
377+
378+
379+
if __name__ == "__main__":
380+
main()

build_tools/jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def setup_jax_extension(
103103

104104
setup_mpi_flags(include_dirs, cxx_flags)
105105

106+
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
107+
cxx_flags.append("-DNVTE_WITH_CUBLASMP")
108+
106109
# Define TE/JAX as a Pybind11Extension
107110
from pybind11.setup_helpers import Pybind11Extension
108111

0 commit comments

Comments
 (0)