Fix #2826: implement gradient checkpoint callbacks #2860
Conversation
Relevant issue: #2826

Gradient checkpointing differs from the normal forward/backward process in that it may execute forward steps during the backward pass. It is therefore decoupled from the normal forward process. Since we have methods, such as activated LoRA, that depend on `peft_forward_hooks` in the normal forward, we have to make sure that these hooks are also applied in the forward passes that gradient checkpointing executes independently. To this end we had several options:

- Don't use hooks to communicate the aLoRA offsets, write them to the module directly. We could do that, but then we'd need a mechanism to clean up these values afterwards (probably involving hooks), and we would be introducing a non-general approach that might bite us in the future as more methods need parameter injection.
- Don't support `use_reentrant=True` (the default for transformers) and use the `context_fn` parameter to inject parameters when using `use_reentrant=False`. `torch.utils.checkpoint` supports a `context_fn` parameter which returns two context managers (one for the normal forward, one for the checkpoint forward). In theory this could be a way to inject the variables into the module. In practice we would need to modify the `keywords` argument of every module's `_gradient_checkpointing_func` attribute to inject the `context_fn` callback and update it accordingly on every forward call. This is far less reliable than forward hooks.
- Register forward hooks on the gradient checkpointing layers that apply the same hooks that `enable_peft_forward_hooks` does - but decoupled from `enable_peft_forward_hooks`. These hooks are removed once a full backward hook on the gradient checkpoint layer is called. We'd still need shared storage for the hook handles so that we don't rack up forward and backward hooks, but this storage is a general way of implementing forward hooks in gradient checkpointing. It also lets us control the flow without using private methods/variables. Since this adds forward hooks that are only removed when backward is called, we disallow multiple forward calls in succession before a backward call (the user can still do that with gradient checkpointing disabled).

In this change I'm implementing the last option, forward hooks on gradient checkpointing layers (see the sketch at the end of this description).

We already had tests, but there were some issues that are improved in this PR:

- Parametrize `use_reentrant` so we check both values, even though `use_reentrant=True` is the default; since `use_reentrant=False` has a consistency checker, it might detect errors that we don't cover.
- Check consistency between normal and checkpointed model runs: both runs must have the same set of non-zero gradients. It is not sufficient to check that a gradient exists, it must be non-zero as well (gradient checkpointing can fail with zero grads).
- Call `model.train()` to enable gradient checkpointing.
- Disable `lora_dropout` if set, to make gradients (a bit more) deterministic.

While testing I found that GPTBigCode doesn't support gradient checkpointing even though it claims to; it is skipped for now until that is fixed. BART doesn't work with the newest changes and fails at `loss.backward()` with `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`. This is still open.

Update: The BART issue is caused by the model not sharing the embed tokens module but only the weights. This in turn leads to `get_input_embeddings` returning one specific module which may not be invoked in the forward call at all (`model.shared` in this case - it has the same weights but it is a different `nn.Module`). Since `enable_input_require_grads` depends on the module returned by `get_input_embeddings` having working forward hooks, the preparation for gradient checkpointing fails. This needs to be fixed in transformers, either by targeting all tied weights or by sharing the embedding modules instead of just the weights (like in T5).

Also, I removed the gradient requirement in CPT's embedding. This module is never called in a gradient path, so it is safe to set it to `requires_grad=False`. This helps in upholding the assumption that all parameters that require a gradient need to receive one.

Pinging @kgreenewald as well.
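To make the intended control flow a bit more concrete, here is a minimal, self-contained sketch of the last option. It is illustrative only: the attribute name `_peft_gradient_checkpointing_forward_hooks` matches the diff, but the function, its signature and the `make_hook`/`is_target` parameters are made up for the example and are not the actual PEFT implementation.

```python
from torch import nn


def attach_checkpoint_hooks(layer: nn.Module, make_hook, is_target):
    """Sketch: keep PEFT forward hooks available during checkpointed re-runs.

    `layer` is a gradient-checkpointed layer, `make_hook(submodule)` returns the
    forward pre hook to install (e.g. the one injecting aLoRA offsets), and
    `is_target` selects which submodules to hook (e.g. LoRA layers).
    """

    def forward_pre_hook(module, args):
        # If hooks are already installed (e.g. the checkpoint re-run enters this
        # forward again before backward has finished), keep the existing ones.
        # The actual PR additionally errors out on repeated user-level forward
        # calls before backward, see the error message in the diff below.
        if getattr(module, "_peft_gradient_checkpointing_forward_hooks", None):
            return
        handles = []
        for submodule in module.modules():
            if is_target(submodule):
                handles.append(submodule.register_forward_pre_hook(make_hook(submodule)))
        # Shared storage on the layer so the backward hook can clean up later.
        module._peft_gradient_checkpointing_forward_hooks = handles

    def backward_hook(module, grad_input, grad_output):
        # Backward reaching this layer means its checkpointed forward re-run is
        # done, so the temporary submodule hooks can be removed again.
        while module._peft_gradient_checkpointing_forward_hooks:
            module._peft_gradient_checkpointing_forward_hooks.pop().remove()

    layer.register_forward_pre_hook(forward_pre_hook)
    layer.register_full_backward_hook(backward_hook)
```

In the real code, the equivalent of `make_hook` corresponds to the hooks that `enable_peft_forward_hooks` would otherwise register during the normal forward.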
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for working on this complex issue and coming up with a solution. Generally, it LGTM, I just have a few comments, please check.
It would also be great if @kgreenewald could check whether this branch resolves the original issue.
| "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient " | ||
| "checkpointing. Disable gradient checkpointing or only call forward once per backward." | ||
| ) | ||
| layer._peft_gradient_checkpointing_forward_hooks = [] |
Should we check for pre-existing `_peft_gradient_checkpointing_forward_hooks`? Right now, we know they don't exist, but since this is supposed to be a general solution (not aLoRA-specific), I'd say it's safer to check.
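For example, something along these lines (just a sketch, the error message is made up):

```python
# Sketch of the suggested check: don't silently overwrite handles that some
# other feature may already have stored on this layer.
if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", None):
    raise RuntimeError(
        "Layer already has PEFT gradient checkpointing forward hooks registered; "
        "remove them before registering new ones."
    )
layer._peft_gradient_checkpointing_forward_hooks = []
```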
module._peft_gradient_checkpointing_forward_hooks.append(handle)

def backward_hook(name, module, *grad_output, **kwargs):
    for _ in range(len(module._peft_gradient_checkpointing_forward_hooks)):
Would `while module._peft_gradient_checkpointing_forward_hooks:` be a bit easier to understand?
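I.e. roughly (sketch, assuming the loop body pops a handle and removes it):

```python
# Reads as "while there are stored handles left": pop each one and remove it.
while module._peft_gradient_checkpointing_forward_hooks:
    handle = module._peft_gradient_checkpointing_forward_hooks.pop()
    handle.remove()
```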
def forward_pre_hook(name, module, inputs, **kwargs):
    for submodule in module.modules():
        if isinstance(submodule, LoraLayer):
            handle = submodule.register_forward_pre_hook(
This part is not obvious to me. We register `forward_pre_hook` as a forward pre hook and this function itself registers more forward pre hooks. Is this just a convenient way to recurse? Could we not make this explicit by looping over the submodules in https://github.com/huggingface/peft/pull/2860/files#diff-4331e1b00e557c2af6e682bc221d1517f554d7ca24b9149520237b431d8d797fR390 and directly registering the final hooks?
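Roughly what I mean (sketch only; `final_pre_hook` stands for whatever hook would otherwise be registered lazily, `LoraLayer` as in the diff above):

```python
# Hypothetical explicit variant: walk the submodules of the checkpointed layer
# once at registration time and install the final hooks directly, instead of a
# forward pre hook that itself registers more forward pre hooks.
handles = []
for submodule in layer.modules():
    if isinstance(submodule, LoraLayer):
        handles.append(submodule.register_forward_pre_hook(final_pre_hook))
layer._peft_gradient_checkpointing_forward_hooks = handles
```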
`_enable_peft_forward_hooks` is becoming quite complex at this point. It could be worth refactoring it into a mixed-batch part and an aLoRA part. Not necessarily in this PR, but it could be done later.
self.skipTest("AdaLora with RoBERTa does not work correctly")

if "bart" in model_id.lower():
    # TODO: no backprop possible with Bart, not sure why
Comment can be updated, right?
model.gradient_checkpointing_enable()
# Disable lora_dropout to remove non-determinism in gradient creation
if "lora_dropout" in config_kwargs:
Wondering about non-LoRA PEFT methods.