Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ def graph_pool_handle():
return _graph_pool_handle()


@contextlib.contextmanager
def _none_grad_context_wrapper(inputs):
"""
Wrapper to set the gradients of the inputs to None,
in case the backward pass makes grad accumulations.
"""
original_input_grads = []
for input in inputs:
original_input_grads.append(input.grad)
input.grad = None
yield
for input, original_grad in zip(inputs, original_input_grads):
input.grad = original_grad


@contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`.
Expand Down Expand Up @@ -434,13 +449,15 @@ def hook_fn(
for hook in hooks:
hook.remove()
if is_training:
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
only_inputs=True,
allow_unused=allow_unused_input,
)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
)
grad_inputs = tuple(input.grad for input in inputs)

# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
Expand All @@ -454,7 +471,9 @@ def hook_fn(
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
if grad_inputs[grad_inputs_idx] is None and grad_inputs_idx < num_required_grad_sample_args:
assert allow_unused_input, "The input tensor requires grad, but the grad is None after backward pass."
elif (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
Expand Down Expand Up @@ -606,15 +625,21 @@ def hook_fn(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(
i for i in static_input_surface if i.requires_grad
)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(
o for o in static_grad_outputs if o is not None
),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)

# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
Expand Down Expand Up @@ -695,15 +720,19 @@ def hook_fn(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(
o for o in static_grad_outputs if o is not None
),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)

if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]:
Expand Down
Loading