Skip to content

Commit 2b26b69

Browse files
committed
Merge remote-tracking branch 'origin' into gtp_release
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
2 parents a532120 + 547d284 commit 2b26b69

201 files changed

Lines changed: 21118 additions & 5427 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/build-pytorch-wheel/Dockerfile

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,5 @@ RUN CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1'}) && \
4141
# Install PyTorch
4242
RUN export MATRIX_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1 $2'}) && \
4343
export MATRIX_TORCH_VERSION=$(echo $TORCH_VERSION | awk -F \. {'print $1 "." $2'}) && \
44-
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
45-
minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
46-
maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
47-
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
48-
) && \
49-
pip install --no-cache-dir torch==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
44+
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; versions = {'2.5': (118, 124), '2.6': (118, 126), '2.7': (118, 128), '2.8': (126, 129), '2.9': (126, 130)}; minv, maxv = versions[env['MATRIX_TORCH_VERSION']]; print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)") && \
45+
pip install --no-cache-dir torch==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ FP8 and MXFP8 have been tested extensively across different model architectures
381381
+------------+------------------+---------------------------------------------------------------------------------------------------------+
382382
| Model | Framework | Source |
383383
+============+==================+=========================================================================================================+
384-
| MPT-1.3B | Mosaic Composer | https://www.databricks.com/blog/coreweave-nvidia-h100-part-1 |
384+
| MPT-1.3B | Mosaic Composer | https://www.databricks.com/blog/coreweave-nvidia-h100-part-1 |
385385
+------------+------------------+---------------------------------------------------------------------------------------------------------+
386386
| LLama2-7B | Alibaba Pai | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ |
387387
+------------+------------------+---------------------------------------------------------------------------------------------------------+

benchmarks/benchmark_rht_cast.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch.utils.benchmark as benchmark
99

1010
import transformer_engine.pytorch as te
11-
import transformer_engine_torch as tex
1211
import transformer_engine.pytorch.cpp_extensions as ext
1312

1413
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
@@ -17,7 +16,7 @@
1716
permute_scale = False
1817

1918
TORCH_TO_TE_FLOAT_MAP = {
20-
torch.bfloat16: tex.DType.kBFloat16,
19+
torch.bfloat16: te.DType.kBFloat16,
2120
}
2221

2322

@@ -31,7 +30,7 @@ def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
3130

3231
# Quantize
3332
nvfp4_quantizer = NVFP4Quantizer(
34-
fp4_dtype=tex.DType.kFloat4E2M1,
33+
fp4_dtype=te.DType.kFloat4E2M1,
3534
rowwise=True,
3635
columnwise=True,
3736
with_amax_reduction=False,
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output.
6+
7+
For each shape we measure two paths and two builds:
8+
9+
* path = "quant_only": just NVFP4Quantizer(x)
10+
* path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t)
11+
(this is what te.Linear -> tex.generic_gemm does right before the
12+
cuBLAS LT NVFP4 GEMM dispatch)
13+
14+
* build = "baseline": optimize_for_gemm=False
15+
-> quant kernel emits compact SF;
16+
tex.swizzle_scales_for_gemm_ launches the standalone
17+
swizzle_{row,col}_scaling_kernel pass before GEMM.
18+
* build = "swizzle_fusion": optimize_for_gemm=True
19+
-> quant kernel emits GEMM-swizzled SF directly (via the
20+
kEnableSwizzleSFOutput compile-time switch in
21+
row_cast_col_hadamard_transform_cast_fusion.cu);
22+
tex.swizzle_scales_for_gemm_ early-returns and the standalone
23+
swizzle pass disappears from the timeline.
24+
25+
The wall-clock delta on the "quant_plus_swizzle" path is the production
26+
saving of this PR.
27+
"""
28+
29+
import argparse
30+
import torch
31+
import pandas as pd
32+
import torch.utils.benchmark as benchmark
33+
34+
import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order
35+
import transformer_engine_torch as tex
36+
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
37+
38+
39+
def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
40+
q = NVFP4Quantizer(
41+
fp4_dtype=tex.DType.kFloat4E2M1,
42+
rowwise=True,
43+
columnwise=True,
44+
with_amax_reduction=False,
45+
amax_reduction_group=None,
46+
with_rht=True,
47+
with_post_rht_amax=True,
48+
with_random_sign_mask=True,
49+
)
50+
q.optimize_for_gemm = optimize_for_gemm
51+
return q
52+
53+
54+
def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float:
55+
"""Returns median wall-clock per call in microseconds."""
56+
timing = benchmark.Timer(
57+
stmt=stmt,
58+
globals=globals_dict,
59+
num_threads=1,
60+
).blocked_autorange(min_run_time=min_run_time)
61+
return timing.median * 1e6
62+
63+
64+
def run_shape(shape, min_run_time: float):
65+
M, K = shape
66+
assert M % 16 == 0 and K % 16 == 0, "Shape must be divisible by 16"
67+
68+
x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda")
69+
q_base = make_quantizer(optimize_for_gemm=False)
70+
q_swf = make_quantizer(optimize_for_gemm=True)
71+
72+
# quant_only path
73+
quant_only_base_us = _bench(
74+
stmt="q(x)",
75+
globals_dict={"q": q_base, "x": x},
76+
min_run_time=min_run_time,
77+
)
78+
quant_only_swf_us = _bench(
79+
stmt="q(x)",
80+
globals_dict={"q": q_swf, "x": x},
81+
min_run_time=min_run_time,
82+
)
83+
84+
# quant_plus_swizzle path (this is what te.Linear actually runs)
85+
quant_plus_swizzle_base_us = _bench(
86+
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
87+
globals_dict={"q": q_base, "x": x, "tex": tex},
88+
min_run_time=min_run_time,
89+
)
90+
quant_plus_swizzle_swf_us = _bench(
91+
stmt="t = q(x); tex.swizzle_scales_for_gemm_(t)",
92+
globals_dict={"q": q_swf, "x": x, "tex": tex},
93+
min_run_time=min_run_time,
94+
)
95+
96+
saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us
97+
speedup = (
98+
quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us
99+
if quant_plus_swizzle_swf_us > 0
100+
else float("inf")
101+
)
102+
103+
print(
104+
f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, "
105+
f"SUT={quant_only_swf_us:.2f}us | "
106+
f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, "
107+
f"SUT={quant_plus_swizzle_swf_us:.2f}us "
108+
f"-> saved {saved_us:.2f}us ({speedup:.2f}x)"
109+
)
110+
111+
return {
112+
"shape": shape,
113+
"M": M,
114+
"K": K,
115+
"quant_only_base_us": quant_only_base_us,
116+
"quant_only_swf_us": quant_only_swf_us,
117+
"quant_plus_swizzle_base_us": quant_plus_swizzle_base_us,
118+
"quant_plus_swizzle_swf_us": quant_plus_swizzle_swf_us,
119+
"saved_us": saved_us,
120+
"speedup": speedup,
121+
}
122+
123+
124+
# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears):
125+
# ncu -f -o swizzle_fusion --set=full \
126+
# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \
127+
# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile
128+
129+
130+
if __name__ == "__main__":
131+
parser = argparse.ArgumentParser()
132+
parser.add_argument(
133+
"--profile",
134+
action="store_true",
135+
help="Run only one shape for use with ncu/nsys; longer min_run_time",
136+
)
137+
parser.add_argument(
138+
"--min-run-time",
139+
type=float,
140+
default=2.0,
141+
help="Minimum total measured time per cell in seconds (benchmark.Timer)",
142+
)
143+
parser.add_argument(
144+
"--csv",
145+
type=str,
146+
default="benchmark_rht_cast_swizzle_fusion.csv",
147+
help="CSV output path",
148+
)
149+
args = parser.parse_args()
150+
151+
if args.profile:
152+
print("Profiling mode enabled (single shape).")
153+
shapes = [(8192, 4096)]
154+
min_run_time = max(5.0, args.min_run_time)
155+
else:
156+
shapes = [
157+
# production-class shapes
158+
(8192, 5120),
159+
(8192, 10240),
160+
(8192, 2560),
161+
(8192, 11328),
162+
(8192, 3584),
163+
(5120, 8192),
164+
(10240, 8192),
165+
(2560, 8192),
166+
(11328, 8192),
167+
(3584, 8192),
168+
(4096, 16384),
169+
(14336, 16384),
170+
]
171+
min_run_time = args.min_run_time
172+
173+
print(
174+
"NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. "
175+
f"min_run_time={min_run_time}s per cell, BF16 input, "
176+
"rowwise+columnwise SF, RHT=True+post_rht_amax."
177+
)
178+
rows = []
179+
for shape in shapes:
180+
print(f"Running {shape} ...")
181+
rows.append(run_shape(shape, min_run_time))
182+
183+
df = pd.DataFrame(rows)
184+
pd.set_option("display.max_columns", None)
185+
pd.set_option("display.width", 200)
186+
print()
187+
print(df.to_string(index=False))
188+
df.to_csv(args.csv, index=False)
189+
print(f"\nWrote {args.csv}")

0 commit comments

Comments
 (0)