2 changes: 2 additions & 0 deletions src/peft/tuners/cpt/model.py
@@ -53,6 +53,8 @@ def __init__(self, config, word_embeddings):
word_embedding_weights = word_embedding_weights.to(torch.float32)
self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

self.embedding.requires_grad_(False)

# Initialize delta embedding with zero weights
self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim)
self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32)
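The newly added `requires_grad_(False)` makes the intent of the surrounding code explicit: the copied word-embedding weights stay frozen and only the zero-initialized delta embedding is trained. A minimal, standalone sketch of that pattern (assumed to mirror the CPT design, not the actual `CPTEmbedding` code):

```python
import torch
from torch import nn

token_dim, num_virtual_tokens = 16, 4

# frozen copy standing in for the word-embedding rows of the virtual tokens
embedding = nn.Embedding(num_virtual_tokens, token_dim)
embedding.requires_grad_(False)

# trainable delta, zero-initialized so training starts exactly from the base embeddings
delta_embedding = nn.Embedding(num_virtual_tokens, token_dim)
nn.init.zeros_(delta_embedding.weight)

ids = torch.arange(num_virtual_tokens)
prompt = embedding(ids) + delta_embedding(ids)
prompt.sum().backward()

# only the delta receives gradients; the frozen base embedding does not
assert embedding.weight.grad is None
assert delta_embedding.weight.grad is not None
```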
35 changes: 34 additions & 1 deletion src/peft/tuners/lora/model.py
Review comment (Member): The `_enable_peft_forward_hooks` method is becoming quite complex at this point. It could be worth it to refactor it into the mixed-batch part and the aLoRA part. Not necessarily in this PR, but it could be done later.
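A rough sketch of the split suggested above (a sketch only; the helper names `_register_mixed_batch_hooks` and `_register_alora_hooks` are hypothetical and not part of the PR):

```python
from contextlib import contextmanager

@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
    hook_handles = []
    # mixed-batch inference: hooks that route each sample to the adapter listed in `adapter_names`
    hook_handles += self._register_mixed_batch_hooks(*args, **kwargs)
    # activated LoRA: hooks that pass `alora_offsets` to the LoraLayers (including the extra
    # forward run triggered by gradient checkpointing, as added in this PR)
    hook_handles += self._register_alora_hooks(*args, **kwargs)
    try:
        yield
    finally:
        for handle in hook_handles:
            handle.remove()
```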

@@ -23,6 +23,7 @@

import torch
from torch import nn
from transformers.modeling_layers import GradientCheckpointingLayer

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import (
@@ -357,7 +358,39 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
return
hook_handles = []
if alora_offsets is not None:
for layer in self.modules():
for n, layer in self.named_modules():
# gradient checkpointing layers are executed in addition to the 'normal' forward call
# (in the backward step the gradient checkpointing layer's forward will be executed again).
# To be consistent with the normal forward we need to enable the pre hooks for this additional
# forward call as well.
#
# Note that this will lead to double application of whatever the callbacks do in the normal forward.
# Make sure that whatever change is done can be applied more than once without harm (idempotency).
if isinstance(layer, GradientCheckpointingLayer) and layer.gradient_checkpointing:

def forward_pre_hook(name, module, inputs, **kwargs):
for submodule in module.modules():
if isinstance(submodule, LoraLayer):
handle = submodule.register_forward_pre_hook(
Review comment (Member): This part is not obvious to me. We register `forward_pre_hook` as a forward pre hook, and this function itself registers more forward pre hooks. Is this just a convenient way to recurse? Could we not make this explicit by looping over the submodules in https://github.com/huggingface/peft/pull/2860/files#diff-4331e1b00e557c2af6e682bc221d1517f554d7ca24b9149520237b431d8d797fR390 and directly register the final hooks?

partial(_alora_offsets_pre_forward_hook, alora_offsets=kwargs["alora_offsets"]),
with_kwargs=True,
)
module._peft_gradient_checkpointing_forward_hooks.append(handle)

def backward_hook(name, module, *grad_output, **kwargs):
for _ in range(len(module._peft_gradient_checkpointing_forward_hooks)):
Review comment (Member): Would `while module._peft_gradient_checkpointing_forward_hooks:` be a bit easier to understand?

module._peft_gradient_checkpointing_forward_hooks.pop().remove()

if getattr(layer, "_peft_gradient_checkpointing_forward_hooks", []):
raise ValueError(
"Multiple invocations of PEFT forward hooks before .backward() with enabled gradient "
"checkpointing. Disable gradient checkpointing or only call forward once per backward."
)
layer._peft_gradient_checkpointing_forward_hooks = []
Review comment (Member): Should we check for pre-existing `_peft_gradient_checkpointing_forward_hooks`? Right now we know they don't exist, but since this is supposed to be a general solution (not aLoRA-specific), I'd say it's safer to check.

handle = layer.register_forward_pre_hook(partial(forward_pre_hook, n, alora_offsets=alora_offsets))
layer._peft_gradient_checkpointing_forward_hooks.append(handle)
handle = layer.register_full_backward_hook(partial(backward_hook, n))
layer._peft_gradient_checkpointing_forward_hooks.append(handle)
if isinstance(layer, LoraLayer):
pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
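The idempotency note in the new comment can be reproduced in isolation: with gradient checkpointing, a module's forward (and therefore any forward pre-hook registered on it) runs once during the forward pass and once more when the checkpointed segment is recomputed during backward. A minimal, standalone illustration in plain PyTorch (not PEFT code):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)
calls = []
# the pre-hook only records that it ran; returning None leaves the inputs unchanged
layer.register_forward_pre_hook(lambda module, inputs: calls.append(1))

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()

# the hook fired twice: once in the forward pass and once when the checkpointed forward was
# re-run during backward, so whatever a hook changes must be safe to apply more than once
assert len(calls) == 2
```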
7 changes: 5 additions & 2 deletions tests/test_custom_models.py
@@ -1802,8 +1802,11 @@ def test_training_custom_models_layer_indexing(self, test_name, model_id, config
pass

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_custom_models_gradient_checkpointing(
self, test_name, model_id, config_cls, config_kwargs, use_reentrant
):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
7 changes: 5 additions & 2 deletions tests/test_decoder_models.py
@@ -526,9 +526,12 @@ def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwa

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
_skip_if_not_conv1d_supported(model_id, config_cls)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
self._test_training_gradient_checkpointing(
model_id, config_cls, config_kwargs.copy(), use_reentrant=use_reentrant
)

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
7 changes: 5 additions & 2 deletions tests/test_encoder_decoder_models.py
@@ -339,8 +339,11 @@ def test_training_encoder_decoders_layer_indexing(self, model_id, config_cls, co

@pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_encoder_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_encoder_decoders_gradient_checkpointing(
self, model_id, config_cls, config_kwargs, use_reentrant
):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("model_id", PEFT_ENCODER_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
5 changes: 3 additions & 2 deletions tests/test_feature_extraction_models.py
@@ -330,9 +330,10 @@ def test_training_layer_indexing(self, model_id, config_cls, config_kwargs):

@pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
@pytest.mark.parametrize("use_reentrant", [True, False])
def test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant):
skip_deberta_lora_tests(config_cls, model_id)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs, use_reentrant=use_reentrant)

@pytest.mark.parametrize("model_id", PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
43 changes: 39 additions & 4 deletions tests/testing_common.py
@@ -1315,41 +1315,76 @@ def _test_training_layer_indexing(self, model_id, config_cls, config_kwargs):
# more than 1 layer, i.e. setting layers_to_transform=[0] should target fewer layers
assert nb_trainable < nb_trainable_all

def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwargs, use_reentrant=True):
# Note that certain configurations, such as activated LoRA with 'alora_invocation_tokens': [1000], do not
# generate gradients because the adapter is never activated, so this test is effectively a no-op for them.
# It is still a valid test, but it might be confusing to see it pass when it is not actually exercising
# any gradients.

if config_cls == PrefixTuningConfig:
return pytest.skip(f"Test not applicable for {config_cls}")

if (config_cls == AdaLoraConfig) and ("roberta" in model_id.lower()):
# TODO: no gradients on the "dense" layer, other layers work, not sure why
self.skipTest("AdaLora with RoBERTa does not work correctly")

if "bart" in model_id.lower():
# TODO: no backprop possible with Bart, not sure why
Review comment (Member): Comment can be updated, right?

self.skipTest("Bart does not work correctly")

if (config_cls == OFTConfig) and ("deberta" in model_id.lower()):
# TODO: no gradients on the "dense" layer, other layers work, not sure why
self.skipTest("OFT with Deberta does not work correctly")

if "gptbigcode" in model_id.lower():
self.skipTest("GPTBigCode currently doesn't implement gradient checkpointing correctly.")

with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)

if not getattr(model, "supports_gradient_checkpointing", False):
return pytest.skip(f"Model {model_id} does not support gradient checkpointing")

model.gradient_checkpointing_enable()
# Disable lora_dropout to remove non-determinism in gradient creation
if "lora_dropout" in config_kwargs:
Review comment (Member): Wondering about non-LoRA PEFT methods.

del config_kwargs["lora_dropout"]

config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(self.torch_device)
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

# if we don't set this, gradient checkpointing is not activated.
model.train(True)

inputs = self.prepare_inputs_for_testing()

# check if `training` works
output = model(**inputs)[0]
# invocation to get the reference non-zero grads that are supposed to exist without gradient checkpointing;
# note we're squaring the output for bigger gradients
output = model(**inputs)[0] ** 2

loss = output.sum()
loss.backward()

non_zero_grad_params_normal = {n for n, p in params if p.grad.abs().sum() > 0}

for name, param in params:
param.grad = None

# invocation with gradient checkpointing for comparison
model.prepare_model_for_gradient_checkpointing(model)
model.gradient_checkpointing_enable({"use_reentrant": use_reentrant})

output = model(**inputs)[0] ** 2

loss = output.sum()
loss.backward()

non_zero_grad_params_checkpointing = {n for n, p in params if p.grad.abs().sum() > 0}
assert non_zero_grad_params_normal == non_zero_grad_params_checkpointing

for n, param in model.named_parameters():
if "prompt_encoder." in n: # prompt tuning methods
if not issubclass(config_cls, CPTConfig):
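For reference, the comparison the updated test performs can be sketched standalone (this mirrors the intent of `_test_training_gradient_checkpointing`, not its actual code):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
params = dict(model.named_parameters())
x = torch.randn(4, 8)

def nonzero_grad_names():
    return {n for n, p in params.items() if p.grad is not None and p.grad.abs().sum() > 0}

# reference run without checkpointing; squaring the output makes the gradients larger
(model(x) ** 2).sum().backward()
reference = nonzero_grad_names()
for p in params.values():
    p.grad = None

# the same computation with (non-reentrant) gradient checkpointing should leave the same set
# of parameters with non-zero gradients
(checkpoint(model, x, use_reentrant=False) ** 2).sum().backward()
assert nonzero_grad_names() == reference
```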