Skip to content

Commit 5fdfbec

Browse files
LeSingh1 and ksivaman authored
[PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture (#3065)
* [PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture

  GroupedLinear.forward hardcoded None for skip_fp8_weight_update, so the FP8 graph-capture skip tensor was never forwarded during CUDA graph replay. Mirror Linear.forward: when fp8_graph_capturing() is true, read quantization_state.skip_fp8_weight_update_tensor, force is_first_microbatch to False, and thread the tensor into the forward call (the slot _GroupedLinear.forward already unpacks).

  Fixes #3051

  Signed-off-by: LeSingh1 <sshaurya914@gmail.com>

* [PyTorch] Add CUDA graph FP8 weight-caching test for GroupedLinear

  Exercises skip_fp8_weight_update propagation in GroupedLinear during FP8 CUDA graph capture. With fp8_weight_caching enabled, graphed and eager runs only match when is_first_microbatch is threaded into the weight-update skip tensor for every microbatch, which the prior None hardcode prevented.

  Signed-off-by: LeSingh1 <sshaurya914@gmail.com>

---------

Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 4bf946d commit 5fdfbec

2 files changed

Lines changed: 97 additions & 0 deletions

File tree

tests/pytorch/test_cuda_graphs.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from transformer_engine.pytorch import (
1111
DotProductAttention,
12+
GroupedLinear,
1213
LayerNormLinear,
1314
LayerNormMLP,
1415
Linear,
@@ -216,6 +217,38 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
216217
return x
217218

218219

220+
class _GroupedLinearWrapper(torch.nn.Module):
221+
"""Adapt GroupedLinear to the [seqlen, batch, hidden] data used in this test.
222+
223+
GroupedLinear expects a 2D `[total_tokens, hidden]` input plus an `m_splits`
224+
list, so this wrapper flattens the leading dims, splits the tokens evenly
225+
across the GEMMs, and restores the original shape. It also forwards
226+
`is_first_microbatch` so the FP8 weight-caching path is exercised under CUDA
227+
graphs.
228+
"""
229+
230+
def __init__(self, hidden_size: int, num_gemms: int, params_dtype: torch.dtype) -> None:
231+
super().__init__()
232+
self.num_gemms = num_gemms
233+
self.grouped_linear = GroupedLinear(
234+
num_gemms,
235+
hidden_size,
236+
hidden_size,
237+
device="cuda",
238+
params_dtype=params_dtype,
239+
)
240+
241+
def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
242+
seqlen, batch, hidden = input_.shape
243+
x = input_.reshape(seqlen * batch, hidden)
244+
total_tokens = x.shape[0]
245+
assert total_tokens % self.num_gemms == 0, "tokens must split evenly across GEMMs"
246+
split = total_tokens // self.num_gemms
247+
m_splits = [split] * self.num_gemms
248+
out = self.grouped_linear(x, m_splits, **kwargs)
249+
return out.reshape(seqlen, batch, hidden)
250+
251+
219252
# Supported modules
220253
_test_cuda_graphs_modules: List[str] = [
221254
# Put linear first to test the case where the cuda context might not be set in
@@ -315,6 +348,15 @@ def _test_cuda_graphs(
315348
)
316349
for _ in range(num_layers)
317350
]
351+
elif module == "grouped_linear":
352+
modules = [
353+
_GroupedLinearWrapper(
354+
model_config.hidden_size,
355+
num_gemms=2,
356+
params_dtype=dtype,
357+
)
358+
for _ in range(num_layers)
359+
]
318360
elif module == "linear_op":
319361
modules = [
320362
te_ops.Sequential(
@@ -501,6 +543,52 @@ def test_make_graphed_callables_with_fp8_weight_caching(
501543
)
502544

503545

546+
# Per-tensor FP8 recipes that support GroupedLinear FP8 weight caching.
547+
_grouped_linear_fp8_weight_caching_recipes = []
548+
if fp8_available:
549+
_grouped_linear_fp8_weight_caching_recipes.append(recipe.DelayedScaling())
550+
_grouped_linear_fp8_weight_caching_recipes.append(recipe.Float8CurrentScaling())
551+
552+
553+
@pytest.mark.skipif(not fp8_available, reason="FP8 is not supported")
554+
@pytest.mark.parametrize("dtype", dtypes)
555+
@pytest.mark.parametrize("fp8_params", (False, True))
556+
@pytest.mark.parametrize("fp8_recipe", _grouped_linear_fp8_weight_caching_recipes, ids=recipe_id)
557+
def test_make_graphed_callables_grouped_linear_with_fp8_weight_caching(
558+
*,
559+
dtype: torch.dtype,
560+
fp8_params: bool,
561+
fp8_recipe: recipe.Recipe,
562+
model_config: str = "small",
563+
num_layers: int = 3,
564+
) -> None:
565+
"""GroupedLinear must thread `is_first_microbatch` into the FP8 weight-update
566+
skip tensor under CUDA graphs.
567+
568+
With `fp8_weight_caching` enabled, the graphed and non-graphed runs only match
569+
when `skip_fp8_weight_update` is propagated for every microbatch. Before the
570+
fix, GroupedLinear hardcoded it to `None`, so the cached FP8 weights diverged
571+
from the eager reference. This regresses if that propagation is dropped again.
572+
"""
573+
config = model_configs[model_config]
574+
kwargs = dict(
575+
module="grouped_linear",
576+
model_config=config,
577+
num_layers=num_layers,
578+
dtype=dtype,
579+
fp8=True,
580+
fp8_params=fp8_params,
581+
fp8_weight_caching=True,
582+
fp8_recipe=fp8_recipe,
583+
)
584+
graph_outputs_full = _test_cuda_graphs(graph_mode="full", **kwargs)
585+
graph_outputs_individual = _test_cuda_graphs(graph_mode="individual", **kwargs)
586+
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
587+
588+
assert_all_equal(outputs, graph_outputs_full)
589+
assert_all_equal(outputs, graph_outputs_individual)
590+
591+
504592
def generate_data_for_dot_product_attention(
505593
model_config: ModelConfig,
506594
dtype: torch.dtype,

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,6 +1684,15 @@ def forward(
16841684
is_grad_enabled = torch.is_grad_enabled()
16851685
num_gemms = self.num_gemms
16861686

1687+
if FP8GlobalStateManager.fp8_graph_capturing():
1688+
skip_fp8_weight_update = (
1689+
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
1690+
)
1691+
else:
1692+
skip_fp8_weight_update = None
1693+
if skip_fp8_weight_update is not None:
1694+
is_first_microbatch = False
1695+
16871696
# Make sure splits are in expected format
16881697
if not isinstance(m_splits, torch.Tensor):
16891698
# Convert list of ints to tensor for backward compatibility

0 commit comments

Comments
 (0)