Fix #2826: implement gradient checkpoint callbacks #2860
@@ -23,6 +23,7 @@

 import torch
 from torch import nn
+from transformers.modeling_layers import GradientCheckpointingLayer

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import (
@@ -357,7 +358,39 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
            return
        hook_handles = []
        if alora_offsets is not None:
-           for layer in self.modules():
+           for n, layer in self.named_modules():
+               # Gradient checkpointing layers are executed concurrently with the 'normal' forward call
+               # (in the backward step the gradient checkpointing layer's forward will be executed again).
+               # To be consistent with the normal forward we need to enable the pre hooks for this concurrent
+               # forward call as well.
+               #
+               # Note that this will lead to double application of whatever the callbacks do in normal forward.
+               # Make sure that whatever change is done can be applied more than once without harm (idempotency).
+               if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+
+                   def forward_pre_hook(name, module, inputs, **kwargs):
+                       for submodule in module.modules():
+                           if isinstance(submodule, LoraLayer):
+                               handle = submodule.register_forward_pre_hook(
+                                   partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
+                                   with_kwargs=True,
+                               )
+                               module._peft_gradient_checkpointing_forward_hooks.append(handle)

Review comment: This part is not obvious to me. We register […]

+                   def backward_hook(name, module, *grad_output, **kwargs):
+                       for _ in range(len(module._peft_gradient_checkpointing_forward_hooks)):
+                           module._peft_gradient_checkpointing_forward_hooks.pop().remove()

Review comment: Would […]

+                   if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
+                       raise ValueError(
+                           "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
+                           "checkpointing. Disable gradient checkpointing or only call forward once per backward."
+                       )
+                   layer._peft_gradient_checkpointing_forward_hooks = []

Review comment: Should we check for pre-existing […]

+                   handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
+                   layer._peft_gradient_checkpointing_forward_hooks.append(handle)
+                   handle = layer.register_full_backward_hook(partial(backward_hook, n))
+                   layer._peft_gradient_checkpointing_forward_hooks.append(handle)
                if isinstance(layer, LoraLayer):
                    pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
                    handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
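The comment about double application above is the crux of this hunk: when a checkpointed layer recomputes its forward during .backward(), every forward pre-hook on its submodules fires again. A minimal sketch of that behavior, using toy modules and torch.utils.checkpoint directly (Block and inner are made-up names, not PEFT code):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    # Toy stand-in for a gradient-checkpointed layer wrapping an adapter-carrying submodule.
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)

    def forward(self, x):
        return self.inner(x)


block = Block()
calls = []
# forward pre-hook on the submodule, analogous to the aLoRA offsets pre-hook
block.inner.register_forward_pre_hook(lambda module, inputs: calls.append("pre-hook"))

x = torch.randn(2, 4)
out = checkpoint(block, x, use_reentrant=False)  # forward pass #1
out.sum().backward()                             # triggers forward pass #2 (recomputation)
print(calls)  # ['pre-hook', 'pre-hook'] -> whatever the hook does happens twice, hence the idempotency requirement

This is the situation the forward_pre_hook/backward_hook pair above handles: the aLoRA pre-hooks needed during recomputation are registered when the checkpointed layer's forward runs and removed again once its backward has completed.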
@@ -1315,41 +1315,76 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
        # more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
        assert nb_trainable < nb_trainable_all

-   def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+   def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
+       # Note that certain configurations, such as activated lora with 'alora_invocation_tokens': [1000], do not
+       # generate gradients since the adapter is never activated, so this will be a no-op for this test. It is
+       # still a valid test, but it might be confusing to see a test pass if it is not supposed to.

        if config_cls == PrefixTuningConfig:
            return pytest.skip(f"Test not applicable for {config_cls}")

        if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
            # TODO: no gradients on the "dense" layer, other layers work, not sure why
            self.skipTest("AdaLora with RoBERTa does not work correctly")

        if "bart" in model_id.lower():
            # TODO: no backprop possible with Bart, not sure why
            self.skipTest("Bart does not work correctly")

Review comment: Comment can be updated, right?

        if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
            # TODO: no gradients on the "dense" layer, other layers work, not sure why
            self.skipTest("OFT with Deberta does not work correctly")

        if "gptbigcode" in model_id.lower():
            self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")

        with hub_online_once(model_id):
            model = self.transformers_class.from_pretrained(model_id)

            if not getattr(model, "supports_gradient_checkpointing", False):
                return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

-           model.gradient_checkpointing_enable()
            # Disable lora_dropout to remove non-determinism in gradient creation
            if "lora_dropout" in config_kwargs:
                del config_kwargs["lora_dropout"]

Review comment: Wondering about non-LoRA PEFT methods.

            config = config_cls(
                base_model_name_or_path=model_id,
                **config_kwargs,
            )
            model = get_peft_model(model, config)
            model = model.to(self.torch_device)
            params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

            # if we don't set this, gradient checkpointing is not activated.
            model.train(True)

            inputs = self.prepare_inputs_for_testing()

            # check if `training` works
-           output = model(**inputs)[0]
+           # invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing;
+           # note we're squaring the output for bigger gradients
+           output = model(**inputs)[0] ** 2

            loss = output.sum()
            loss.backward()

+           non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}

+           for name, param in params:
+               param.grad = None

+           # invocation with gradient checkpointing for comparison
+           model.prepare_model_for_gradient_checkpointing(model)
+           model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})

+           output = model(**inputs)[0] ** 2

+           loss = output.sum()
+           loss.backward()

+           non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
+           assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing

            for n, param in model.named_parameters():
                if "prompt_encoder." in n:  # prompt tuning methods
                    if not issubclass(config_cls, CPTConfig):
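Reduced to a self-contained sketch, the comparison this updated test performs looks roughly like the following. A toy nn.Sequential and a direct torch.utils.checkpoint call stand in for the transformers model and gradient_checkpointing_enable; the only property checked is the one the test asserts, namely that the same parameters receive non-zero gradients with and without checkpointing:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(4, 8)

# reference run without checkpointing; square the output for bigger gradients
(model(x) ** 2).sum().backward()
non_zero_normal = {n for n, p in model.named_parameters() if p.grad.abs().sum() > 0}

# reset gradients before the comparison run
for p in model.parameters():
    p.grad = None

# run with gradient checkpointing: the forward is recomputed during backward
(checkpoint(model, x, use_reentrant=False) ** 2).sum().backward()
non_zero_checkpointing = {n for n, p in model.named_parameters() if p.grad.abs().sum() > 0}

# the same parameters should end up with non-zero gradients in both runs
assert non_zero_normal == non_zero_checkpointing

The use_reentrant argument added to the test signature is forwarded to gradient_checkpointing_enable, presumably so the same comparison can be run for both the reentrant and the non-reentrant checkpointing implementations.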
Review comment: _enable_peft_forward_hooks is becoming quite complex at this point. It could be worth refactoring it into the mixed-batch part and the aLoRA part. Not necessarily in this PR, but could be done later.
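A purely hypothetical sketch of what that split could look like; the helper names _register_mixed_batch_hooks and _register_alora_hooks are invented for illustration and are not existing PEFT APIs:

from contextlib import contextmanager


class _ForwardHooksSketch:
    # Hypothetical helpers, one per concern mentioned in the review comment.
    def _register_mixed_batch_hooks(self, *args, **kwargs):
        return []  # would register the mixed-batch (adapter_names) hooks and return their handles

    def _register_alora_hooks(self, *args, **kwargs):
        return []  # would register the aLoRA offset hooks, including the gradient checkpointing ones

    @contextmanager
    def _enable_peft_forward_hooks(self, *args, **kwargs):
        handles = self._register_mixed_batch_hooks(*args, **kwargs) + self._register_alora_hooks(*args, **kwargs)
        try:
            yield
        finally:
            for handle in handles:
                handle.remove()

Each helper would own one concern and return its hook handles, keeping the enclosing context manager small.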