diff --git a/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py b/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
index fa91eb8fc3..8e47864076 100755
--- a/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
+++ b/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
@@ -41,7 +41,7 @@
 max_memory = f"{free_in_GB - 2}GB"
 
 n_gpus = torch.cuda.device_count()
-max_memory = {i: max_memory for i in range(n_gpus)}
+max_memory = dict.fromkeys(range(n_gpus), max_memory)
 
 model = AutoModelForCausalLM.from_pretrained(
     "facebook/opt-350m",
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 580cc41038..30e7a59625 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -749,9 +749,9 @@ def get_prompt(
                 # If we don't apply this, prefix-tuning fails to update cross-attn cache
                 past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
                 past_key_values.cross_attention_cache = DynamicCache()
-                past_key_values.is_updated = {
-                    layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
-                }
+                past_key_values.is_updated = dict.fromkeys(
+                    range(len(past_key_values.cross_attention_cache.key_cache)), False
+                )
             map_cache_to_layer_device_map(self.get_base_model(), past_key_values)  # no-op if not a Cache instance
             return past_key_values
         else:
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 1db720bceb..a04889005d 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -296,8 +296,16 @@ class LoraConfig(PeftConfig):
             into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is
             handled by a separate learnable parameter. This can improve the performance of LoRA especially at low
             ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
-            LoRA, so it is recommended to merge weights for inference. For more information, see 
+            LoRA, so it is recommended to merge weights for inference. For more information, see
             https://huggingface.co/papers/2402.09353.
+        use_sinelora (`bool`):
+            Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This applies a sine activation to the low-rank
+            adapter update, which can boost the effective rank of the low-rank matrices and enhance their capacity.
+            For more information, see https://arxiv.org/pdf/2403.19243.
+        sinelora_frequency (`float`):
+            The frequency factor for the sine activation. If not specified, it will be set to the default value of
+            200.
+        sinelora_scaling (`Optional[float]`):
+            The scaling factor for the sine activation. If not specified, it will be set to sqrt(in_features).
         layer_replication (`List[Tuple[int, int]]`):
             Build a new stack of layers by stacking the original model layers according to the ranges specified. This
             allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
@@ -494,6 +502,32 @@ class LoraConfig(PeftConfig):
             )
         },
     )
+    use_sinelora: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This applies a sine activation to the "
+                "low-rank adapter update, which can boost the effective rank of the low-rank matrices and enhance "
+                "their capacity. For more information, see https://arxiv.org/pdf/2403.19243."
+            )
+        },
+    )
+    sinelora_frequency: float = field(
+        default=200.0,
+        metadata={
+            "help": (
+                "The frequency factor for the sine activation. "
+                "If not specified, it will be set to the default value of 200."
+            )
+        },
+    )
+    sinelora_scaling: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The scaling factor for the sine activation. If not specified, it will be set to the default value "
+                "of sqrt(in_features)."
+            )
+        },
+    )
     # Enables replicating layers in a model to expand it to a larger model.
     layer_replication: Optional[list[tuple[int, int]]] = field(
         default=None,
diff --git a/src/peft/tuners/lora/eva.py b/src/peft/tuners/lora/eva.py
index 1bc75453b1..8c45449251 100644
--- a/src/peft/tuners/lora/eva.py
+++ b/src/peft/tuners/lora/eva.py
@@ -409,7 +409,7 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
     else:
         pbar = iter(cycle(dataloader))
         use_tqdm = False
-    convergence_dict = {k: False for k in hooks.keys()}
+    convergence_dict = dict.fromkeys(hooks.keys(), False)
     rank_dist = max_components.copy()
     for inputs in pbar:
         if device is not None:
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 1a070a498b..f5ccd66aa6 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -161,7 +161,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         self.in_features = in_features
         self.out_features = out_features
 
-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
         """Return a matching LoRA variant for this layer type.
 
         Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -173,6 +173,7 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari
         convention, and not here.
""" + return None def update_layer( @@ -184,17 +185,24 @@ def update_layer( init_lora_weights, use_rslora, use_dora: bool = False, + use_sinelora: bool = False, + sinelora_frequency: float = 200.0, + sinelora_scaling: Optional[float] = None, lora_bias: bool = False, ): # collect the kwargs kwargs = locals().copy() del kwargs["self"] + if use_sinelora: + self.sinelora_frequency = sinelora_frequency + self.sinelora_scaling = sinelora_scaling + # This code works for linear layers, override for other layer types if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") - lora_variant = self.resolve_lora_variant(use_dora=use_dora) + lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora) if lora_variant is not None: self.lora_variant[adapter_name] = lora_variant @@ -571,7 +579,10 @@ def __init__( init_lora_weights: Union[bool, str] = True, use_rslora: bool = False, use_dora: bool = False, + sinelora_frequency: float = 200.0, + sinelora_scaling: Optional[float] = None, lora_bias: bool = False, + use_sinelora: bool = False, **kwargs, ) -> None: super().__init__() @@ -588,16 +599,24 @@ def __init__( use_rslora=use_rslora, use_dora=use_dora, lora_bias=lora_bias, + use_sinelora=use_sinelora, + sinelora_frequency=sinelora_frequency, + sinelora_scaling=sinelora_scaling, ) self.is_target_conv_1d_layer = is_target_conv_1d_layer - def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]: - if not use_dora: - return None + def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]: + if use_dora: + from .variants import DoraLinearVariant + + return DoraLinearVariant() - from .variants import DoraLinearVariant + elif use_sinelora: + from .variants import SineLoraLinearVariant - return DoraLinearVariant() + return SineLoraLinearVariant() + + return None def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ @@ -773,6 +792,9 @@ def __init__( use_rslora: bool = False, use_dora: bool = False, lora_bias: bool = False, + use_sinelora: bool = False, + sinelora_frequency=200.0, + sinelora_scaling: Optional[float] = None, **kwargs, ) -> None: if lora_bias: @@ -792,28 +814,50 @@ def __init__( init_lora_weights=init_lora_weights, use_rslora=use_rslora, use_dora=use_dora, + use_sinelora=use_sinelora, lora_bias=lora_bias, + sinelora_frequency=sinelora_frequency, + sinelora_scaling=sinelora_scaling, ) - def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]: - if not use_dora: - return None + def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]: + if use_dora: + from .variants import DoraEmbeddingVariant - from .variants import DoraEmbeddingVariant + return DoraEmbeddingVariant() + elif use_sinelora: + from .variants import SineLoraEmbeddingVariant - return DoraEmbeddingVariant() + return SineLoraEmbeddingVariant() + else: + return None def update_layer( - self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + use_rslora, + use_dora, + lora_bias, + use_sinelora: bool = False, + sinelora_frequency: float = 200.0, + sinelora_scaling: Optional[float] = None, ): # collect the kwargs kwargs = locals().copy() del kwargs["self"] + if use_sinelora: + self.sinelora_frequency = sinelora_frequency + self.sinelora_scaling = sinelora_scaling 
+
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
 
-        lora_variant = self.resolve_lora_variant(use_dora=use_dora)
+        lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
@@ -988,7 +1032,7 @@ def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
             norm_type=base_layer.norm_type,
             scale_grad_by_freq=base_layer.scale_grad_by_freq,
             sparse=base_layer.sparse,
-        ) 
+        )
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         # TODO: no dtype conversion here, unlike in Linear, is that correct?
@@ -1068,7 +1112,16 @@ def __init__(
         )
 
     def update_layer(
-        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
+        self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora,
+        lora_bias,
+        use_sinelora: bool = False,
+        sinelora_frequency: float = 200.0,
+        sinelora_scaling: Optional[float] = None,
     ):
         # collect the kwargs
         kwargs = locals().copy()
@@ -1077,7 +1130,7 @@ def update_layer(
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
 
-        lora_variant = self.resolve_lora_variant(use_dora=use_dora)
+        lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
@@ -1326,13 +1379,12 @@ def __init__(self, *args, **kwargs):
             raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
         self.conv_fn = F.conv1d
 
-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
-        if not use_dora:
-            return None
-
-        from .variants import DoraConv1dVariant
-
-        return DoraConv1dVariant()
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
+        if use_dora:
+            from .variants import DoraConv1dVariant
+
+            return DoraConv1dVariant()
+        return None
 
 
 class Conv3d(_ConvNd):
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 07dd0d314b..145e2c2b21 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -200,6 +200,9 @@ def _create_and_replace(
             "init_lora_weights": lora_config.init_lora_weights,
             "use_rslora": lora_config.use_rslora,
             "use_dora": lora_config.use_dora,
+            "use_sinelora": lora_config.use_sinelora,
+            "sinelora_frequency": lora_config.sinelora_frequency,
+            "sinelora_scaling": lora_config.sinelora_scaling,
             "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
             "lora_bias": lora_config.lora_bias,
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
diff --git a/src/peft/tuners/lora/variants.py b/src/peft/tuners/lora/variants.py
index 13aa006b4a..bdaa3fbe5a 100644
--- a/src/peft/tuners/lora/variants.py
+++ b/src/peft/tuners/lora/variants.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import math
 from typing import Any
 
 import torch
@@ -315,3 +317,162 @@ class DoraConv3dVariant(_DoraConvNdVariant):
     def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
         dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
         _DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)
+
+
+class SineLoraLinearVariant(LoraVariant):
+    @staticmethod
+    def init(module: Linear, adapter_name: str, **kwargs) -> None:
+        module.sinelora_frequency = kwargs["sinelora_frequency"]
+        module.sinelora_scaling = kwargs["sinelora_scaling"]
+        if module.sinelora_scaling is None:
+            module.sinelora_scaling = math.sqrt(module.in_features)
+
+    @staticmethod
+    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
+        orig_dtype = orig_weight.dtype
+        lora_A = module.lora_A[active_adapter]
+        lora_B = module.lora_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_A.weight.T @ lora_B.weight.T)).T
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        delta_weight = delta_weight.to(orig_dtype)
+        unmerged_weight = orig_weight - delta_weight
+        return unmerged_weight
+
+    @staticmethod
+    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
+        orig_dtype = orig_weight.dtype
+        lora_A = module.lora_A[active_adapter]
+        lora_B = module.lora_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_A.weight.T @ lora_B.weight.T)).T
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        merged_weight = orig_weight + delta_weight
+        if not torch.isfinite(merged_weight).all():
+            raise ValueError(f"NaNs detected in merged weights for adapter {active_adapter}")
+        module._cache_store(f"{active_adapter}-delta_weight", delta_weight)
+        merged_weight = merged_weight.to(orig_dtype)
+        return merged_weight
+
+    @staticmethod
+    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
+        lora_A = module.lora_A[active_adapter]
+        lora_B = module.lora_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_A.weight.T @ lora_B.weight.T)).T
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        module._cache_store(f"{active_adapter}-delta_weight", delta_weight)
+        orig_weight.data += delta_weight
+
+    @staticmethod
+    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
+        lora_A = module.lora_A[active_adapter]
+        lora_B = module.lora_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        sine_output = (
+            x
+            @ torch.sin(module.sinelora_frequency * (lora_A.weight.T @ lora_B.weight.T))
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        result = result + sine_output
+        return result
+
+
+class SineLoraEmbeddingVariant(LoraVariant):
+    @staticmethod
+    def init(module: Embedding, adapter_name: str, **kwargs) -> None:
+        module.sinelora_frequency = kwargs["sinelora_frequency"]
+        sinelora_scaling = kwargs["sinelora_scaling"]
+        if sinelora_scaling is None:
+            sinelora_scaling = math.sqrt(module.in_features)
+        module.sinelora_scaling = sinelora_scaling
+
+    @staticmethod
+    def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
+        orig_dtype = orig_weight.dtype
+        lora_embedding_A = module.lora_embedding_A[active_adapter]
+        lora_embedding_B = module.lora_embedding_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_embedding_A.T @ lora_embedding_B.T))
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        delta_weight = delta_weight.to(orig_dtype)
+        unmerged_weight = orig_weight - delta_weight
+        return unmerged_weight
+
+    @staticmethod
+    def merge_safe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
+        orig_dtype = orig_weight.dtype
+        lora_embedding_A = module.lora_embedding_A[active_adapter]
+        lora_embedding_B = module.lora_embedding_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_embedding_A.T @ lora_embedding_B.T))
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        merged_weight = orig_weight + delta_weight
+        if not torch.isfinite(merged_weight).all():
+            raise ValueError(f"NaNs detected in merged weights for adapter {active_adapter}")
+        module._cache_store(f"{active_adapter}-delta_weight", delta_weight)
+        merged_weight = merged_weight.to(orig_dtype)
+        return merged_weight
+
+    @staticmethod
+    def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None:
+        lora_embedding_A = module.lora_embedding_A[active_adapter]
+        lora_embedding_B = module.lora_embedding_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        delta_weight = (
+            torch.sin(module.sinelora_frequency * (lora_embedding_A.T @ lora_embedding_B.T))
+            / module.sinelora_scaling
+            * lora_scaling
+        )
+        module._cache_store(f"{active_adapter}-delta_weight", delta_weight)
+        orig_weight.data += delta_weight
+
+    @staticmethod
+    def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
+        lora_embedding_A = module.lora_embedding_A[active_adapter]
+        lora_embedding_B = module.lora_embedding_B[active_adapter]
+        lora_scaling = module.scaling[active_adapter]
+        sine_output = module._embed(
+            x,
+            torch.sin(module.sinelora_frequency * (lora_embedding_A.T @ lora_embedding_B.T))
+            / module.sinelora_scaling
+            * lora_scaling,
+        )
+        result = result + sine_output
+        return result
diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py
index 7d472870df..e8571d7f71 100644
--- a/src/peft/utils/integrations.py
+++ b/src/peft/utils/integrations.py
@@ -145,7 +145,7 @@ def get_layer_device_map(model):
         return None
 
     if len(execution_device_map) == 1 and "" in execution_device_map:
-        return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)}
+        return dict.fromkeys(range(model.config.num_hidden_layers), execution_device_map[""])
 
     layer_device_map = {}
     for layer in execution_device_map:
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 3b1283a779..1a40ab4504 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -513,6 +513,13 @@
         TrainableTokensConfig,
         {"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False},
     ),
+    ###################
+    # LoRA + SineLoRA #
+    ###################
+    ("Vanilla MLP LoRA + SineLoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_sinelora": True}),
+    (
+        "Embedding + transformers Conv1D LoRA + SineLoRA",
+        "EmbConv1D",
+        LoraConfig,
+        {"target_modules": ["emb"], "use_sinelora": True},
+    ),
     ########
     # RandLora #
     ########
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 15040013e5..2f0bbb6f33 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -2664,7 +2664,7 @@ def fn(x, *args):
         if prepare_layer_inputs_keys is None:
             prepare_layer_inputs_fn = fn
         else:
-            prepare_layer_inputs_fn = {k: fn for k in prepare_layer_inputs_keys}
+            prepare_layer_inputs_fn = dict.fromkeys(prepare_layer_inputs_keys, fn)
 
         shuffled_dataset = dataset.shuffle(seed=0)
         dataloader = self.get_dataloader(shuffled_dataset)
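A minimal usage sketch of the options this diff adds to LoraConfig, assuming the SineLoRA wiring above; the checkpoint, target modules, and hyperparameter values are illustrative assumptions, not part of the patch.

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Sine-LoRA is enabled per adapter through the new LoraConfig fields.
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        use_sinelora=True,         # apply sin() to the low-rank update
        sinelora_frequency=200.0,  # frequency factor (matches the default)
        sinelora_scaling=None,     # None falls back to sqrt(in_features) per layer
    )

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

Because sinelora_scaling defaults to None, each adapted layer falls back to sqrt(in_features), mirroring the variant init above, and merging uses the sine-activated delta weight rather than the plain B @ A product.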