
Conversation

@githubnemo (Collaborator) commented Oct 22, 2025

Relevant issue: #2826

Gradient checkpointing differs from the normal forward/backward process in that it may execute forward
steps during the backward pass; it is therefore decoupled from the normal forward pass. Since we have
methods, such as activated LoRA, that depend on peft_forward_hooks during the normal forward, we have to make sure
that these hooks are also applied in the forward passes that gradient checkpointing executes independently.
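
For illustration, here is a minimal sketch (not part of this PR) showing that the checkpointed forward is executed a second time during backward; any hooks or per-call state that exist only for the duration of the original forward call are therefore missing during that recomputation:

```python
# Minimal sketch (not from this PR): with gradient checkpointing, the wrapped
# forward runs again during backward, so state set up only around the first
# forward call is absent during the recomputation.
import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(4, 4)
calls = []
linear.register_forward_pre_hook(lambda module, args: calls.append("forward"))

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(linear, x, use_reentrant=False)
print(len(calls))  # 1 -> the regular forward pass

out.sum().backward()
print(len(calls))  # 2 -> the forward was executed again during backward
```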

To this end we had several options:

  • Don't use hooks to communicate the aLoRA offsets and write them to the module directly. We could do
    that, but then we'd need a mechanism to clean up these values afterwards (probably involving hooks), and
    we would be introducing a non-general approach that might bite us in the future as more methods need
    parameter injection.
  • Don't support use_reentrant=True (the default for transformers) and use the context_fn parameter to
    inject parameters when using use_reentrant=False. torch.utils.checkpoint accepts a context_fn argument
    which returns two context managers, one for the normal forward and one for the checkpoint forward
    (a short sketch follows this list). In theory this could be a way to inject the variables into the module.
    In practice we would need to modify the keywords argument of every ._gradient_checkpointing_func
    attribute of every module to inject the context_fn callback and update it accordingly on every forward
    call. This is far less reliable than forward hooks.
  • Register forward hooks on the gradient checkpointing layers that apply the same hooks that
    enable_peft_forward_hooks does, but decoupled from enable_peft_forward_hooks. These hooks are removed
    once a full backward hook on the gradient checkpointing layer is called. We still need shared storage
    for the hook handles so that we don't rack up forward and backward hooks, but this storage is a general
    way of implementing forward hooks in gradient checkpointing. It also lets us control the flow without
    using private methods/variables. Since these forward hooks are only removed when backward is called, we
    disallow multiple forward calls in succession before a backward call (the user can still do that with
    gradient checkpointing disabled).
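
For reference, the context_fn mechanism mentioned in the second option looks roughly like the sketch below. This is a hedged illustration of the torch.utils.checkpoint API only; the inject_offsets context manager is hypothetical and stands in for whatever state we would need to set on the module and clean up again.

```python
# Hedged sketch of the context_fn mechanism from the second option.
# inject_offsets is hypothetical, not PEFT code.
import contextlib

import torch
from torch.utils.checkpoint import checkpoint


@contextlib.contextmanager
def inject_offsets():
    # hypothetically set per-call state on the module here ...
    yield
    # ... and clean it up again afterwards


def context_fn():
    # The first context manager wraps the original forward, the second one
    # wraps the recomputation that happens during backward.
    return inject_offsets(), inject_offsets()


layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```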

In this change I'm implementing the last option, forward hooks on gradient checkpointing layers.
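
A rough, simplified sketch of the idea follows; names and control flow are illustrative only and differ from the actual PR code:

```python
# Simplified sketch of the chosen approach (illustrative, not the exact PR code):
# when the gradient checkpointing layer runs forward, register the PEFT forward
# hooks on the target submodules and keep the handles in shared storage; a full
# backward hook on the same layer removes them again, so the hooks are still
# alive when checkpointing re-runs the forward during backward.
import torch


def attach_checkpoint_hooks(gc_layer, is_target, make_hook):
    handles = []  # shared storage so repeated forwards/backwards don't pile up hooks

    def forward_pre_hook(module, args):
        if handles:
            # hooks from a previous forward are still registered (e.g. during the
            # recomputation in backward); don't register them a second time
            return
        for submodule in module.modules():
            if is_target(submodule):
                handles.append(submodule.register_forward_pre_hook(make_hook(submodule)))

    def full_backward_hook(module, grad_input, grad_output):
        # backward is done and the recomputed forward has already happened: clean up
        while handles:
            handles.pop().remove()

    gc_layer.register_forward_pre_hook(forward_pre_hook)
    gc_layer.register_full_backward_hook(full_backward_hook)
```

Unlike this sketch, the actual PR raises an error when forward is called multiple times before backward instead of silently skipping re-registration.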

We already had tests but there were some issues that are improved in this PR:

  • parametrize use_reentrant so we check both, even though use_reentrant=True is the default;
    since use_reentrant=False has a consistency checker it might detect errors that we don't cover
  • check consistency between normal and checkpointed model runs: both runs must have the same set
    of non-zero gradients (see the sketch after this list). It is not sufficient to check that a gradient
    exists, it must be non-zero as well (gradient checkpointing can fail with zero grads)
  • call model.train() to enable gradient checkpointing
  • disable lora_dropout if set, to make gradients (a bit more) deterministic
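
In spirit, the consistency check looks like the sketch below; model and batch are placeholders, and the real test in the PEFT suite differs in detail:

```python
# Sketch of the gradient-consistency check (placeholder model/batch): a run with
# and without gradient checkpointing must yield the same set of parameters with
# non-zero gradients.
import torch


def nonzero_grad_params(model):
    return {
        name
        for name, param in model.named_parameters()
        if param.grad is not None and param.grad.abs().sum().item() > 0
    }


def run_and_collect_grads(model, batch, use_checkpointing):
    model.train()
    model.zero_grad()
    if use_checkpointing:
        model.gradient_checkpointing_enable()
    else:
        model.gradient_checkpointing_disable()
    loss = model(**batch).loss
    loss.backward()
    return nonzero_grad_params(model)


# assert run_and_collect_grads(model, batch, False) == run_and_collect_grads(model, batch, True)
```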

While testing I found that GPTBigCode doesn't support gradient checkpointing even though it claims to;
skipping it for now until that is fixed. BART doesn't work with the newest changes and fails at loss.backward()
with RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
Update: The BART issue is caused by the model not sharing the embed tokens module but just the weights. This in turn leads to get_input_embeddings returning one specific module which may not be invoked in the forward call at all (model.shared in this case - it has the same weights but it is a different nn.Module). Since enable_input_require_grads depends on the module returned by get_input_embeddings to have working forward hooks, the preparation for gradient checkpointing fails. This needs to be fixed in transformers either by targeting all tied weights or sharing the embedding modules instead of just the weights (like in T5).
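
For context, enable_input_require_grads in transformers works roughly as sketched below (paraphrased from memory, see transformers for the actual implementation), which explains why it has no effect when the module returned by get_input_embeddings is never invoked in forward:

```python
# Paraphrase of transformers.PreTrainedModel.enable_input_require_grads:
# a forward hook on get_input_embeddings() marks the embedding output as
# requiring grad. The model is assumed to be a transformers PreTrainedModel.
def enable_input_require_grads(model):
    def make_inputs_require_grads(module, inputs, output):
        output.requires_grad_(True)

    return model.get_input_embeddings().register_forward_hook(make_inputs_require_grads)


# If forward only ever calls a different, weight-tied embedding module (as in
# BART), this hook never fires, the checkpointed inputs never require grad, and
# loss.backward() fails with the "does not require grad" error.
```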

Also, I removed the gradient requirement in CPT's embedding.
This module is never called in a gradient path, so it is safe to set it
to requires_grad=False. This helps uphold the assumption that
all parameters that require a gradient need to receive one.

Pinging @kgreenewald as well.

nemo added 3 commits October 22, 2025 10:20

@BenjaminBossan (Member) left a comment:

Thanks for working on this complex issue and coming up with a solution. Generally, it LGTM, I just have a few comments, please check.

It would also be great if @kgreenewald could check whether this branch resolves the original issue.

"Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
"checkpointing. Disable gradient checkpointing or only call forward once per backward."
)
layer._peft_gradient_checkpointing_forward_hooks = []

Should we check for pre-existing _peft_gradient_checkpointing_forward_hooks? Right now, we know they don't exist, but since this is supposed to be a general solution (not aLoRA-specific), I'd say it's safer to check.

module._peft_gradient_checkpointing_forward_hooks.append(handle)

def backward_hook(name, module, *grad_output, **kwargs):
for _ in range(len(module._peft_gradient_checkpointing_forward_hooks)):

Would while module._peft_gradient_checkpointing_forward_hooks: be a bit easier to understand?

def forward_pre_hook(name, module, inputs, **kwargs):
for submodule in module.modules():
if isinstance(submodule, LoraLayer):
handle = submodule.register_forward_pre_hook(

This part is not obvious to me. We register forward_pre_hook as a forward pre hook and this function itself registers more forward pre hooks. Is this just a convenient way to recurse? Could we not make this explicit by looping over the submodules in https://github.com/huggingface/peft/pull/2860/files#diff-4331e1b00e557c2af6e682bc221d1517f554d7ca24b9149520237b431d8d797fR390 and directly register the final hooks?


The _enable_peft_forward_hooks is becoming quite complex at this point. It could be worth it to refactor it into the mixed batch part and into the aLoRA part. Not necessarily in this PR, but could be done later.

self.skipTest("AdaLora with RoBERTa does not work correctly")

if "bart" in model_id.lower():
# TODO: no backprop possible with Bart, not sure why

Comment can be updated, right?


model.gradient_checkpointing_enable()
# Disable lora_dropout to remove non-determinism in gradient creation
if "lora_dropout" in config_kwargs:

Wondering about non-LoRA PEFT methods.
