Skip to content
185 changes: 115 additions & 70 deletions src/te_graph_runtime/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def _make_graphed_callables(
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
capture_stream: Optional[torch.cuda.Stream] = None,
) -> SingleOrTuple[Callable]:
"""
Helper method for `make_graphed_callables`
Expand Down Expand Up @@ -826,7 +827,7 @@ def _make_grad_outputs(outputs):
torch.empty_like(o) if o is not None and o.requires_grad else None for o in outputs
)

def _run_warmup_forward(func_idx, func, callable_idx):
def _run_warmup_forward(func_idx, func, callable_idx, register_discovery_hooks=True):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]

Expand Down Expand Up @@ -855,7 +856,7 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus

_call_capture_time_forward_pre_hooks(callable_idx, func, args, kwargs)
hooks = []
if isinstance(func, torch.nn.Module):
if register_discovery_hooks and isinstance(func, torch.nn.Module):
for module in func.modules():
hooks.append(module.register_forward_hook(hook_fn))
outputs = func(*args, **kwargs)
Expand All @@ -867,7 +868,7 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus

def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None:
static_input_surface = per_callable_static_input_surfaces[func_idx]
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)
outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
grad_outputs = _make_grad_outputs(outputs)

Expand All @@ -883,10 +884,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
# registered to these params are not wrongly triggered.
num_required_grad_sample_args = sum(arg.requires_grad for arg in flatten_sample_args[func_idx])
num_required_grad_sample_args = sum(
isinstance(arg, torch.Tensor) and arg.requires_grad
for arg in flatten_sample_args[func_idx]
)
required_grad_input_idx = []
for i, arg in enumerate(static_input_surface):
if arg.requires_grad:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
Expand Down Expand Up @@ -916,59 +920,95 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
module.backward_dw()
need_bwd_dw_graph[func_idx] = need_backward_dw

# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
def _run_warmup_iteration(warmup_iter, register_discovery_hooks):
if _order is None:
warmup_outputs = []
for func_idx, func in zip(warmup_func_idx, warmup_func):
outputs = _run_warmup_forward(
func_idx,
func,
func_idx,
register_discovery_hooks=register_discovery_hooks,
)
warmup_outputs.append((func_idx, func, outputs))
if is_training:
for func_idx, func, outputs in reversed(warmup_outputs):
_run_warmup_backward(func_idx, func, outputs, warmup_iter, func_idx)
return

per_fwd_outputs = {}
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]):
callable_idx = _prefix_num_layers[m_chunk] + l_no
per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
func = callables[callable_idx]
outputs = _run_warmup_forward(
per_callable_fwd_idx,
func,
callable_idx,
register_discovery_hooks=register_discovery_hooks,
)
per_fwd_outputs[per_callable_fwd_idx] = outputs
fwd_idx[m_chunk] += 1
elif ceil(c_id) == c_id:
if is_training:
m_chunk = -c_id - 1
for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
callable_idx = _prefix_num_layers[m_chunk] + l_no
per_callable_bwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[callable_idx]
outputs = per_fwd_outputs[per_callable_bwd_idx]
_run_warmup_backward(
per_callable_bwd_idx,
func,
outputs,
warmup_iter,
callable_idx,
)
bwd_idx[m_chunk] += 1

# Run warmup on the same stream as capture so workspace buffers
# stay in the same CUDA context and don't need re-allocation.
capture_stream = capture_stream or torch.cuda.Stream()
with torch.cuda.stream(capture_stream):
if pre_warmup_hook is not None:
pre_warmup_hook()

for warmup_iter in range(num_warmup_iters):
if _order is None:
warmup_outputs = []
for func_idx, func in zip(warmup_func_idx, warmup_func):
outputs = _run_warmup_forward(func_idx, func, func_idx)
warmup_outputs.append((func_idx, func, outputs))
if is_training:
for func_idx, func, outputs in reversed(warmup_outputs):
_run_warmup_backward(func_idx, func, outputs, warmup_iter, func_idx)
else:
per_fwd_outputs = {}
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]):
callable_idx = _prefix_num_layers[m_chunk] + l_no
per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
func = callables[callable_idx]
outputs = _run_warmup_forward(per_callable_fwd_idx, func, callable_idx)
per_fwd_outputs[per_callable_fwd_idx] = outputs
fwd_idx[m_chunk] += 1
elif ceil(c_id) == c_id:
if is_training:
m_chunk = -c_id - 1
for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
callable_idx = _prefix_num_layers[m_chunk] + l_no
per_callable_bwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[callable_idx]
outputs = per_fwd_outputs[per_callable_bwd_idx]
_run_warmup_backward(
per_callable_bwd_idx,
func,
outputs,
warmup_iter,
callable_idx,
)
bwd_idx[m_chunk] += 1
_run_warmup_iteration(warmup_iter, register_discovery_hooks=True)

# TE discovery temporarily registers forward hooks, and Dynamo guards
# compiled modules on hook state. Capture runs after those hooks are
# removed, so warm the capture-equivalent specialization as well.
compiled_callables = any(
getattr(func, "_compiled_call_impl", None) is not None
or hasattr(getattr(func, "forward", None), "_torchdynamo_orig_callable")
for func in callables
)
if num_warmup_iters > 0 and compiled_callables:
_run_warmup_iteration(
num_warmup_iters,
register_discovery_hooks=False,
)

if post_warmup_hook is not None:
post_warmup_hook()
torch.cuda.synchronize()

import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
Expand Down Expand Up @@ -1002,7 +1042,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
_call_capture_time_forward_pre_hooks(callable_idx, func, args, kwargs)
with _graph_context_wrapper(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool, stream=capture_stream):
outputs = func(*args, **kwargs)
_call_capture_time_forward_hooks(callable_idx, func, args, kwargs, outputs)
flatten_outputs, spec = _tree_flatten(outputs)
Expand All @@ -1019,7 +1059,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
if ceil(c_id) == c_id and need_bwd_dw_graph.get(per_callable_bwd_idx, False):
# Check if bwd graph has corresponding wgrad graph:
# Number of dgrad backward graphs should be equal to number of
# wgrad backward graphs.
Expand Down Expand Up @@ -1062,12 +1102,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
"The order diff of wgrad and dgrad must be 0.5, "
f"get {ceil(c_id) - c_id}."
)
if not need_bwd_dw_graph[per_callable_bwd_idx]:
if not need_bwd_dw_graph.get(per_callable_bwd_idx, False):
raise RuntimeError(
"No module needs wgrad computation but get float in order"
)
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
with _graph_context_wrapper(bwd_dw_graph, pool=mempool, stream=capture_stream):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
Expand Down Expand Up @@ -1108,9 +1148,9 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
func,
static_grad_outputs,
)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
bwd_graph, pool=mempool, stream=capture_stream
):
torch.autograd.backward(
tuple(
Expand Down Expand Up @@ -1204,7 +1244,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
zip(callables, sample_args, sample_kwargs, fwd_graphs)
):
_call_capture_time_forward_pre_hooks(func_idx, func, args, kwargs)
with _graph_context_wrapper(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool, stream=capture_stream):
outputs = func(*args, **kwargs)
_call_capture_time_forward_hooks(func_idx, func, args, kwargs, outputs)
graph_callables[func_idx] = func
Expand Down Expand Up @@ -1232,7 +1272,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
if is_training:
func = graph_callables[bwd_idx]
_call_capture_time_backward_pre_hooks(bwd_idx, func, static_grad_outputs)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
Expand All @@ -1244,8 +1284,8 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
grad_inputs = tuple(input.grad for input in inputs)
_call_capture_time_backward_hooks(bwd_idx, func, grad_inputs, static_grad_outputs)

if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
if need_bwd_dw_graph.get(bwd_idx, False):
with _graph_context_wrapper(bwd_dw_graph, pool=mempool, stream=capture_stream):
for module in visited_te_modules[bwd_idx]:
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
module.backward_dw()
Expand Down Expand Up @@ -1311,6 +1351,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i
for i in range(len_user_args):
if (
isinstance(static_input_surface[i], torch.Tensor)
and inputs[i] is not None
and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
):
static_input_surface[i].copy_(inputs[i])
Expand Down Expand Up @@ -1416,21 +1457,23 @@ def functionalized(*user_args, **user_kwargs):
user_kwargs.pop("cuda_graph_event")
else:
cuda_graph_event = None
# Check that required kwargs are provided
for key in kwargs_keys:
if key not in user_kwargs:
raise TypeError(
f"Graphed callable was initialized with kwarg {key} ,"
"but it was not provided in graph replay"
)

# Runs the autograd function with inputs == all inputs to
# the graph that might require grad (explicit user args +
# module parameters)
# Assumes module params didn't change since capture.
flatten_user_args, _ = _tree_flatten(user_args)
flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
# Reconstruct the same flattened arg order as capture time.
# User may pass some recorded kwargs as positional args, so
# check user_args first (by position), then user_kwargs.
user_pos_args = list(user_args)
kwarg_values = []
for key in kwargs_keys:
if key in user_kwargs:
kwarg_values.append(user_kwargs[key])
elif user_pos_args:
kwarg_values.append(user_pos_args.pop(0))
# else: key was a default not passed — skip (not a tensor)
flatten_user_kwargs, _ = _tree_flatten(kwarg_values)
func_args = tuple(flatten_user_kwargs) + module_params
out = Graphed.apply(
skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args
)
Expand Down Expand Up @@ -1624,6 +1667,7 @@ def make_graphed_callables(
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
capture_stream: Optional[torch.cuda.Stream] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
"""
Make CUDA graph version of Transformer Engine modules
Expand Down Expand Up @@ -1897,6 +1941,7 @@ def call_func(self, *args, **kwargs):
pre_warmup_hook=pre_warmup_hook,
post_warmup_hook=post_warmup_hook,
capture_time_hooks=capture_time_hooks,
capture_stream=capture_stream,
)

# Ensures warmup does not affect numerics for ops such as dropout.
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cuda_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,26 @@ def test_cuda_warmup_hooks_without_te(monkeypatch: pytest.MonkeyPatch) -> None:
assert calls == {"pre": 1, "post": 1}


def test_cuda_module_compile_is_stabilized_before_capture_without_te(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_force_no_te(monkeypatch)
model = torch.nn.Linear(4, 4).cuda()
model.compile(dynamic=False, options={"triton.cudagraphs": False})
sample = torch.randn(2, 4, device="cuda", requires_grad=True)

graphed = graph.make_graphed_callables(
model,
(sample,),
allow_unused_input=True,
num_warmup_iters=1,
)

x = torch.randn(2, 4, device="cuda", requires_grad=True)
graphed(x).sum().backward()
assert model.weight.grad is not None


def _assert_not_cuda_capturing() -> None:
is_current_stream_capturing = getattr(torch.cuda, "is_current_stream_capturing", None)
if is_current_stream_capturing is not None:
Expand Down