From 5e7f03dbb1b66be475be6c3eb511c61efb9ea7c7 Mon Sep 17 00:00:00 2001
From: nemo
Date: Wed, 22 Oct 2025 10:20:32 +0200
Subject: [PATCH 1/5] Remove gradient requirement in CPT's embedding

This module is never called in a gradient path, so it is safe to set it to
`requires_grad=False`. This helps uphold the assumption that every parameter
that requires a gradient actually receives one.
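
For illustration, a minimal sketch of that invariant (the names `model` and
`inputs` are placeholders, not part of this patch): after a backward pass,
every parameter that requires a gradient should actually have received one.

    loss = model(**inputs)[0].sum()
    loss.backward()
    for name, param in model.named_parameters():
        if param.requires_grad:
            # with the CPT embedding frozen, this should hold for all remaining parameters
            assert param.grad is not None, f"{name} requires grad but received none"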
---
 src/peft/tuners/cpt/model.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py
index 934a3b7928..6c4dc08e51 100644
--- a/src/peft/tuners/cpt/model.py
+++ b/src/peft/tuners/cpt/model.py
@@ -53,6 +53,8 @@ def __init__(self, config, word_embeddings):
             word_embedding_weights = word_embedding_weights.to(torch.float32)
             self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
 
+        self.embedding.requires_grad_(False)
+
         # Initialize delta embedding with zero weights
         self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
         self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)

From 0f487741430d876516a26aeb03728074366451c0 Mon Sep 17 00:00:00 2001
From: nemo
Date: Wed, 22 Oct 2025 10:22:20 +0200
Subject: [PATCH 2/5] Improve gradient checkpointing tests

- parametrize `use_reentrant` so that both values are tested, even though
  `use_reentrant=True` is the default; `use_reentrant=False` comes with a
  consistency checker that might detect errors our own checks don't cover
  (see the sketch at the end of this message)
- check consistency between the normal and the checkpointed model run: both
  runs must have the same set of non-zero gradients. It is not sufficient to
  check that a gradient exists, it must also be non-zero (gradient
  checkpointing can fail with all-zero grads)
- call `model.train()`, otherwise gradient checkpointing is not activated
- disable `lora_dropout`, if set, to make gradients (a bit more) deterministic

Also, while testing I found that GPTBigCode doesn't support gradient
checkpointing even though it claims to; skipping for now until this is fixed.

BART doesn't work with the newest changes and fails at `loss.backward()` with
`RuntimeError: element 0 of tensors does not require grad and does not have a
grad_fn`. This is still open.
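
As a rough illustration of the consistency checker mentioned in the first
bullet (a standalone sketch, not part of this patch, assuming a recent
PyTorch version): with `use_reentrant=False`, PyTorch records metadata of the
tensors saved during the original forward and compares it against the tensors
produced during recomputation, so a forward that behaves differently on
recomputation is reported instead of silently producing wrong gradients.

    import torch
    from torch.utils.checkpoint import checkpoint

    class Inconsistent(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.calls = 0

        def forward(self, x):
            # behaves differently when recomputed -> saved tensor metadata changes
            self.calls += 1
            if self.calls > 1:
                x = torch.cat([x, x], dim=0)
            return self.linear(x).relu()

    module = Inconsistent()
    x = torch.randn(2, 4, requires_grad=True)
    out = checkpoint(module, x, use_reentrant=False)
    try:
        out.sum().backward()
    except Exception as err:
        # expected: torch.utils.checkpoint.CheckpointError complaining that the
        # recomputed tensors have different metadata than during the forward pass
        print(type(err).__name__, err)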
---
 tests/test_custom_models.py          |  7 +++--
 tests/test_decoder_models.py         |  7 +++--
 tests/test_encoder_decoder_models.py |  7 +++--
 tests/testing_common.py              | 40 +++++++++++++++++++++++++---
 4 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index e30d7fe108..4d59f65553 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -1802,8 +1802,11 @@ def test_training_custom_models_layer_indexing(self, test_name, model_id, config
         pass
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
-    def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_custom_models_gradient_checkpointing(
+        self, test_name, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):

diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 06402d637b..07611dd9cb 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -526,9 +526,12 @@ def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwa
 
     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         _skip_if_not_conv1d_supported(model_id, config_cls)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
+        self._test_training_gradient_checkpointing(
+            model_id, config_cls, config_kwargs.copy(), use_reentrant=use_reentrant
+        )
 
     @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 1ec0aa0668..4f36ceb211 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -339,8 +339,11 @@ def test_training_encoder_decoders_layer_indexing(self, model_id, config_cls, co
 
     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_encoder_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_encoder_decoders_gradient_checkpointing(
+        self, model_id, config_cls, config_kwargs, use_reentrant
+    ):
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

diff --git a/tests/testing_common.py b/tests/testing_common.py
index dab9ee6e45..884443a883 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1315,7 +1315,11 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
         # more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
         assert nb_trainable < nb_trainable_all
 
-    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
+        # Note that certain configurations, such as activated lora with 'alora_invocation_tokens': [1000], do not
+        # generate gradients since the adapter is never activated, so this test is effectively a no-op for them.
+        # It is still a valid test, but it might be confusing to see it pass even though nothing is really checked.
+
         if config_cls == PrefixTuningConfig:
             return pytest.skip(f"Test not applicable for {config_cls}")
 
@@ -1323,17 +1327,26 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("AdaLora with RoBERTa does not work correctly")
 
+        if "bart" in model_id.lower():
+            # TODO: no backprop possible with Bart, not sure why
+            self.skipTest("Bart does not work correctly")
+
         if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
             # TODO: no gradients on the "dense" layer, other layers work, not sure why
             self.skipTest("OFT with Deberta does not work correctly")
 
+        if "gptbigcode" in model_id.lower():
+            self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")
+
         with hub_online_once(model_id):
             model = self.transformers_class.from_pretrained(model_id)
 
         if not getattr(model, "supports_gradient_checkpointing", False):
            return pytest.skip(f"Model {model_id} does not support gradient checkpointing")
 
-        model.gradient_checkpointing_enable()
+        # Disable lora_dropout to remove non-determinism in gradient creation
+        if "lora_dropout" in config_kwargs:
+            del config_kwargs["lora_dropout"]
 
         config = config_cls(
             base_model_name_or_path=model_id,
@@ -1341,15 +1354,36 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
         )
         model = get_peft_model(model, config)
         model = model.to(self.torch_device)
+        params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+
+        # If we don't set this, gradient checkpointing is not activated.
+        model.train(True)
 
         inputs = self.prepare_inputs_for_testing()
 
-        # check if `training` works
+        # invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing
+        output = model(**inputs)[0]
+
+        loss = output.sum()
+        loss.backward()
+
+        non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}
+
+        for name, param in params:
+            param.grad = None
+
+        # invocation with gradient checkpointing for comparison
+        model.prepare_model_for_gradient_checkpointing(model)
+        model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})
+
         output = model(**inputs)[0]
 
         loss = output.sum()
         loss.backward()
 
+        non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
+        assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing
+
         for n, param in model.named_parameters():
             if "prompt_encoder." in n:  # prompt tuning methods
                 if not issubclass(config_cls, CPTConfig):

From ed6fa3081d6ae571373e37d163c353b13be25068 Mon Sep 17 00:00:00 2001
From: nemo
Date: Wed, 22 Oct 2025 10:29:10 +0200
Subject: [PATCH 3/5] Fix #2826: implement gradient checkpoint callbacks

Gradient checkpointing differs from the normal forward/backward process in
that it may execute forward steps during the backward pass; it is therefore
decoupled from the normal forward pass. Since we have methods, such as
activated LoRA, that depend on `peft_forward_hooks` during the normal
forward, we have to make sure that these hooks are also applied in the
forward passes that gradient checkpointing executes independently.

To this end we had several options:

- Don't use hooks to communicate the alora offsets, write them to the module
  directly. We can do that, but then we'd need a mechanism to clean up these
  values afterwards (probably involving hooks) and we would be introducing a
  non-general mechanism which might bite us in the future once more methods
  need parameter injection.
- Don't support `use_reentrant=True` (the default for transformers) and use
  the `context_fn` parameter to inject parameters when using
  `use_reentrant=False`. `torch.utils.checkpoint` supports a `context_fn`
  argument which returns two context managers (one for the normal forward,
  one for the recomputation forward). In theory this could be a way to
  inject the variables into the module. In practice we would need to modify
  the `keywords` of every module's `._gradient_checkpointing_func` attribute
  to inject the `context_fn` callback and update those accordingly on every
  forward call (see the sketch at the end of this message). This is far less
  reliable than forward hooks.
- Register forward hooks on the gradient checkpointing layers that apply the
  same hooks that `enable_peft_forward_hooks` does - but decoupled from
  `enable_peft_forward_hooks`. These hooks are removed once a full backward
  hook on the gradient checkpointing layer is called. We'd still need shared
  storage for the hook handles so that we don't rack up forward and backward
  hooks, but this storage is a general way of implementing forward hooks in
  gradient checkpointing. It also lets us control the flow without using
  private methods/variables.

Since this adds forward hooks that are only removed when backward is called,
we disallow multiple forward calls in succession before a backward call (the
user can still do that with gradient checkpointing disabled).

We're implementing the last option, forward hooks on gradient checkpointing
layers.
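
For reference, a rough sketch of what the second (rejected) option would have
looked like. This is hypothetical code, not part of this patch: `model` and
`alora_offsets` are placeholders, `_offsets_ctx` is a made-up helper, and it
assumes that transformers stores the checkpoint callable per module as a
`functools.partial` in `_gradient_checkpointing_func` (here the partial is
simply replaced instead of patching its `keywords`); `context_fn` is only
supported by `torch.utils.checkpoint` with `use_reentrant=False`.

    import contextlib
    from functools import partial

    from torch.utils.checkpoint import checkpoint

    @contextlib.contextmanager
    def _offsets_ctx(module, alora_offsets):
        # made-up helper that would expose the offsets to the LoRA layers below `module`
        module._alora_offsets = alora_offsets
        try:
            yield
        finally:
            del module._alora_offsets

    def _make_context_fn(module, alora_offsets):
        def context_fn():
            # the first context manager wraps the normal forward,
            # the second one wraps the recomputation during backward
            return _offsets_ctx(module, alora_offsets), _offsets_ctx(module, alora_offsets)

        return context_fn

    # this would have to be refreshed on every forward call, for every checkpointed submodule
    for submodule in model.modules():
        if getattr(submodule, "gradient_checkpointing", False):
            submodule._gradient_checkpointing_func = partial(
                checkpoint, use_reentrant=False, context_fn=_make_context_fn(submodule, alora_offsets)
            )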
---
 src/peft/tuners/lora/model.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 2e76e13ee1..e2e9074ef8 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -23,6 +23,7 @@
 
 import torch
 from torch import nn
+from transformers.modeling_layers import GradientCheckpointingLayer
 
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import (
@@ -357,7 +358,38 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
                 return
             hook_handles = []
             if alora_offsets is not None:
-                for layer in self.modules():
+                for n, layer in self.named_modules():
+                    # gradient checkpointing layers are executed concurrently with the 'normal' forward call
+                    # (in the backward step the gradient checkpointing layer's forward will be executed again).
+                    # To be consistent with the normal forward we need to enable the pre hooks for this concurrent
+                    # forward call as well.
+                    #
+                    # Note that this will lead to double application of whatever the callbacks do in normal forward.
+                    # Make sure that whatever change is done can be applied more than once without harm (idempotency).
+                    if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+                        def forward_pre_hook(name, module, inputs, **kwargs):
+                            for submodule in module.modules():
+                                if isinstance(submodule, LoraLayer):
+                                    handle = submodule.register_forward_pre_hook(
+                                        partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
+                                        with_kwargs=True,
+                                    )
+                                    module._peft_gradient_checkpointing_forward_hooks.append(handle)
+
+                        def backward_hook(name, module, *grad_output, **kwargs):
+                            for _ in range(len(module._peft_gradient_checkpointing_forward_hooks)):
+                                module._peft_gradient_checkpointing_forward_hooks.pop().remove()
+
+                        if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
+                            raise ValueError(
+                                "Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
+                                "checkpointing. Disable gradient checkpointing or only call forward once per backward."
+                            )
+                        layer._peft_gradient_checkpointing_forward_hooks = []
+                        handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
+                        layer._peft_gradient_checkpointing_forward_hooks.append(handle)
+                        handle = layer.register_full_backward_hook(partial(backward_hook, n))
+                        layer._peft_gradient_checkpointing_forward_hooks.append(handle)
                     if isinstance(layer, LoraLayer):
                         pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
                         handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)

From 5650c769de471253bc814ce434c1096ed4e7ab9e Mon Sep 17 00:00:00 2001
From: nemo
Date: Wed, 22 Oct 2025 13:02:34 +0200
Subject: [PATCH 4/5] Make style

---
 src/peft/tuners/lora/model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index e2e9074ef8..8c4967fddd 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -367,6 +367,7 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
                     # Note that this will lead to double application of whatever the callbacks do in normal forward.
                     # Make sure that whatever change is done can be applied more than once without harm (idempotency).
                     if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:
+
                         def forward_pre_hook(name, module, inputs, **kwargs):
                             for submodule in module.modules():
                                 if isinstance(submodule, LoraLayer):

From 28a357102cff3f02e689a92e061fb28e3434afd1 Mon Sep 17 00:00:00 2001
From: nemo
Date: Thu, 23 Oct 2025 12:37:48 +0200
Subject: [PATCH 5/5] Address test flakiness
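
Presumably the plain `output.sum()` loss sometimes produced adapter gradients
so close to zero that they were classified as zero, which made the
non-zero-gradient set comparison flaky. Squaring the output before summing
scales each output gradient by the output value instead of a constant 1:

    # d/dx of x.sum() is 1 for every element, while d/dx of (x ** 2).sum() is 2 * x
    output = model(**inputs)[0] ** 2
    loss = output.sum()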

---
 tests/test_feature_extraction_models.py | 5 +++--
 tests/testing_common.py                 | 7 ++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index a5377827f4..f11876df6b 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -330,9 +330,10 @@ def test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
 
     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
-    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
+    @pytest.mark.parametrize("use_reentrant", [True, False])
+    def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
         skip_deberta_lora_tests(config_cls, model_id)
-        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
+        self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)
 
     @pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)

diff --git a/tests/testing_common.py b/tests/testing_common.py
index 884443a883..22df86600b 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -1361,8 +1361,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
 
         inputs = self.prepare_inputs_for_testing()
 
-        # invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing
-        output = model(**inputs)[0]
+        # invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing;
+        # note we're squaring the output for bigger gradients
+        output = model(**inputs)[0] ** 2
 
         loss = output.sum()
         loss.backward()
@@ -1376,7 +1377,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa
         model.prepare_model_for_gradient_checkpointing(model)
         model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})
 
-        output = model(**inputs)[0]
+        output = model(**inputs)[0] ** 2
 
         loss = output.sum()
         loss.backward()