diff --git a/docs/source/package_reference/fourierft.md b/docs/source/package_reference/fourierft.md index 1d298a9042..641ad65441 100644 --- a/docs/source/package_reference/fourierft.md +++ b/docs/source/package_reference/fourierft.md @@ -20,7 +20,7 @@ rendered properly in your Markdown viewer. FourierFT currently has the following constraints: -- Only `nn.Linear` layers are supported. +- Only `nn.Linear` and `nn.Conv2d` layers are supported. - Quantized layers are not supported. If these constraints don't work for your use case, consider other methods instead.
diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py index dbbb80d8e0..c30115e84b 100644 --- a/src/peft/tuners/fourierft/config.py +++ b/src/peft/tuners/fourierft/config.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, Union +from typing import Literal, Optional, Union from peft.config import PeftConfig from peft.utils import PeftType @@ -78,6 +78,14 @@ class FourierFTConfig(PeftConfig): init_weights (`bool`): The initialization of the Fourier weights. Set this to False (the default) if the spectrum are initialized to a standard normal distribution. Set this to True if the spectrum are initialized to zeros. + alpha (`Optional[float]`): + The alpha value dynamically sets n_frequency = int(alpha * out_features * in_features). If alpha is set, + the n_frequency and n_frequency_pattern parameters must not be set. + ifft2_norm (`str`): + The normalization applied for the ifft2 operation. It has to be either `backward`, `forward` or `ortho`. + See the PyTorch documentation for torch.fft.ifft2 for more details + (https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`. + """ n_frequency: int = field( @@ -174,6 +182,17 @@ class FourierFTConfig(PeftConfig): ) }, ) + + ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field( + default="backward", + metadata={ + "help": ( + "The normalization applied for the ifft2 operation. " + "It has to be either `backward`, `forward` or `ortho`. See the PyTorch documentation for torch.fft.ifft2 for more details " + "(https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`." + ) + }, + ) init_weights: bool = field( default=False, metadata={ "help": ( @@ -185,6 +204,16 @@ class FourierFTConfig(PeftConfig): }, ) + alpha: Optional[float] = field( + default=None, + metadata={ + "help": ( + "The alpha value dynamically sets n_frequency = int(alpha * out_features * in_features). " + "If alpha is set, the n_frequency and n_frequency_pattern parameters must not be set." + ) + }, + ) + def __post_init__(self): super().__post_init__() self.peft_type = PeftType.FOURIERFT @@ -204,3 +233,9 @@ def __post_init__(self): # check for layers_to_transform and layers_pattern if self.layers_pattern and not self.layers_to_transform: raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ") + + if (self.alpha is not None) and (self.n_frequency != 1000): + raise ValueError("Don't set both alpha and n_frequency, as alpha overrides n_frequency.") + + if (self.alpha is not None) and (self.n_frequency_pattern != {}): + raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides n_frequency_pattern.")
diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index a03a57f118..7afe3e0244 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -27,7 +27,12 @@ class FourierFTLayer(BaseTunerLayer): # All names of layers that may contain (trainable) adapter weights adapter_layer_names = ("fourierft_spectrum",) # All names of other parameters that may contain adapter-related parameters - other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed") + other_param_names = ( + "fourierft_n_frequency", + "fourierft_scaling", + "fourierft_random_loc_seed", + "fourierft_ifft2_norm", + ) def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.base_layer = base_layer @@ -39,6 +44,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.fourierft_ifft2_norm = kwargs["ifft2_norm"] self.kwargs = kwargs base_layer = self.get_base_layer() @@ -48,6 +54,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: self.in_features, self.out_features = ( base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape ) + elif isinstance(base_layer, nn.Conv2d): + self.in_features = base_layer.in_channels + self.out_features = base_layer.out_channels else: raise ValueError(f"Unsupported layer type {type(base_layer)}") @@ -56,20 +65,22 @@ def update_layer( ): if n_frequency <= 0: raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}") - if n_frequency > self.in_features * self.out_features: + + if isinstance(self, FourierFTLinear): + max_freqs = self.in_features * self.out_features + else: + kW = self.base_layer.kernel_size[0] + kH = self.base_layer.kernel_size[1] + max_freqs = self.in_features * self.out_features * kW * kH + + if n_frequency > max_freqs: raise ValueError( f"`n_frequency` should be less than or equal to the product of the input and output dimensions " - f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}" + f"but the value passed is {n_frequency} and the product is {max_freqs}" ) self.fourierft_n_frequency[adapter_name] = n_frequency self.fourierft_random_loc_seed[adapter_name] = random_loc_seed - self.indices[adapter_name] = torch.randperm( - self.out_features * self.in_features, - generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), - )[:n_frequency] - self.indices[adapter_name] = torch.stack( - [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 - ) + self.set_indices(adapter_name, n_frequency) self.fourierft_scaling[adapter_name] = scaling # Actual trainable parameters self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True) @@ -91,29 +102,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: indices = self.indices[adapter].to(spectrum.device) dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device) dense_spectrum[indices[0, :], indices[1, :]] = 
spectrum.float() - delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] + delta_weight = ( + torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] + ) return delta_weight.to(spectrum.dtype) - -class FourierFTLinear(nn.Module, FourierFTLayer): - # FourierFT implemented in a dense layer - def __init__( - self, - base_layer, - adapter_name: str, - n_frequency: int = 1000, - scaling: float = 150.0, - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - init_weights: Union[bool, str] = False, - random_loc_seed: int = 777, - **kwargs, - ) -> None: - super().__init__() - FourierFTLayer.__init__(self, base_layer, **kwargs) - self.fan_in_fan_out = fan_in_fan_out - self._active_adapter = adapter_name - self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) - def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ Merge the active adapter weights into the base weights @@ -163,6 +156,41 @@ def unmerge(self) -> None: if active_adapter in self.fourierft_spectrum.keys(): self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + def set_indices(self, adapter_name: str, n_frequency: int): + self.indices[adapter_name] = torch.randperm( + self.out_features * self.in_features, + generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), + )[:n_frequency] + self.indices[adapter_name] = torch.stack( + [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 + ) + + +class FourierFTLinear(nn.Module, FourierFTLayer): + # FourierFT implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + n_frequency: int = 1000, + alpha: Optional[float] = None, + scaling: float = 150.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: Union[bool, str] = False, + random_loc_seed: int = 777, + **kwargs, + ) -> None: + super().__init__() + FourierFTLayer.__init__(self, base_layer, **kwargs) + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features) + + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) + def get_delta_weight(self, adapter) -> torch.Tensor: return super().get_delta_weight(adapter) @@ -191,3 +219,85 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: def __repr__(self) -> str: rep = super().__repr__() return "fourierft." 
+ rep + + +class FourierFTConv2D(nn.Module, FourierFTLayer): + # FourierFT implemented in a 2D convolutional (nn.Conv2d) layer + def __init__( + self, + base_layer, + adapter_name: str, + n_frequency: int = 1000, + alpha: Optional[float] = None, + scaling: float = 150.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_weights: Union[bool, str] = False, + random_loc_seed: int = 777, + **kwargs, + ) -> None: + super().__init__() + FourierFTLayer.__init__(self, base_layer, **kwargs) + + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + kW = base_layer.kernel_size[0] + kH = base_layer.kernel_size[1] + + # apply alpha patch + if alpha: + n_frequency = int(alpha * self.in_features * self.out_features * kW * kH) + self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) + + def set_indices(self, adapter_name: str, n_frequency: int): + kW = self.base_layer.kernel_size[0] + kH = self.base_layer.kernel_size[1] + self.indices[adapter_name] = torch.randperm( + self.out_features * self.in_features * kW * kH, + generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), + )[:n_frequency] + self.indices[adapter_name] = torch.stack( + [ + self.indices[adapter_name] // (self.in_features * kW), + self.indices[adapter_name] % (self.in_features * kW), + ], + dim=0, + ) + + def get_delta_weight(self, adapter) -> torch.Tensor: + kW = self.base_layer.kernel_size[0] + kH = self.base_layer.kernel_size[1] + spectrum = self.fourierft_spectrum[adapter] + indices = self.indices[adapter].to(spectrum.device) + dense_spectrum = torch.zeros(self.out_features * kH, self.in_features * kW, device=spectrum.device) + dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() + delta_weight = ( + torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] + ) + return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH)) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.fourierft_spectrum.keys(): + continue + + delta_w = self.get_delta_weight(active_adapter) + x = x.to(delta_w.dtype) + y = F.conv2d(x, delta_w, stride=self.base_layer.stride, padding=self.base_layer.padding) + result += y + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "fourierft." 
+ rep diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index 5347d90b17..abf4226a90 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -18,6 +18,7 @@ from itertools import chain import torch +from torch.nn import Conv2d from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer @@ -25,7 +26,7 @@ TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, ) -from .layer import FourierFTLayer, FourierFTLinear +from .layer import FourierFTConv2D, FourierFTLayer, FourierFTLinear class FourierFTModel(BaseTuner): @@ -71,11 +72,15 @@ def _create_and_replace( n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency) scaling = fourierft_config.scaling + alpha = fourierft_config.alpha + ifft2_norm = fourierft_config.ifft2_norm random_loc_seed = fourierft_config.random_loc_seed bias = hasattr(target, "bias") and target.bias is not None kwargs = { "n_frequency": n_frequency, + "alpha": alpha, "scaling": scaling, + "ifft2_norm": ifft2_norm, "fan_in_fan_out": fourierft_config.fan_in_fan_out, "init_weights": fourierft_config.init_weights, "random_loc_seed": fourierft_config.random_loc_seed, @@ -110,6 +115,7 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to False." ) kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = False + new_module = FourierFTLinear(target, adapter_name, **kwargs) elif isinstance(target_base_layer, Conv1D): kwargs["is_target_conv_1d_layer"] = True if not kwargs["fan_in_fan_out"]: @@ -117,12 +123,12 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs): "fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True." ) kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = True + new_module = FourierFTLinear(target, adapter_name, **kwargs) + elif isinstance(target_base_layer, Conv2d): + new_module = FourierFTConv2D(target, adapter_name, **kwargs) else: raise ValueError( f"Target module {target} is not supported. Currently, only the following modules are supported: " - "`torch.nn.Linear`." 
+ "`torch.nn.Linear`, `torch.nn.Conv2d`." ) - - new_module = FourierFTLinear(target, adapter_name, **kwargs) - return new_module diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 33cded4116..b4b5b2b019 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -663,6 +663,26 @@ "init_weights": True, }, ), + ( + "Conv2d 1 FourierFT", + "Conv2d", + FourierFTConfig, + { + "target_modules": ["conv2d"], + "n_frequency": 1000, + }, + ), + ( + "Conv2d 2 FourierFT", + "Conv2d", + FourierFTConfig, + { + "target_modules": ["conv2d", "lin0"], + "alpha": 0.01, + "init_weights": True, + "ifft2_norm": "ortho", + }, + ), ########## # VBLoRA # ########## diff --git a/tests/test_initialization.py b/tests/test_initialization.py index f37e0c2cbe..0d9f780f15 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -39,6 +39,7 @@ AdaLoraConfig, C3AConfig, EvaConfig, + FourierFTConfig, IA3Config, LoftQConfig, LoKrConfig, @@ -4499,3 +4500,41 @@ def test_key_mapping_save_old_load_new_vblora(self, old_model, new_model, tmp_pa def test_key_mapping_save_new_load_old_vblora(self, old_model, new_model, tmp_path): # save the new model, load it into the old model, should work without issues (forwards compatibility) self.check_vblora_load_no_warning(new_model, old_model, tmp_path) + + +class TestFourierFTInitialization: + torch_device = infer_device() + + def get_model(self, bias=True): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + # use layers that are large enough for the n_frequency and alpha values used in these tests + self.linear = nn.Linear(1000, 1000, bias=bias) + self.embed = nn.Embedding(1000, 1000) + self.conv2d = nn.Conv2d(100, 100, 3, bias=bias) + + def forward(self, x): + x_int = (100 * x).int() + x_4d = x.flatten().reshape(1, 100, 10, 10) + return self.linear(x), self.embed(x_int), self.conv2d(x_4d) + + return MyModule().eval().to(self.torch_device) + + def test_fourierft_set_alpha_and_n_frequency_raises(self): + torch.manual_seed(0) + + model = self.get_model() + msg = "Don't set both alpha and n_frequency, as alpha overrides n_frequency." + with pytest.raises(ValueError, match=msg): + config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency=2000) + get_peft_model(model, config) + + def test_fourierft_set_alpha_and_n_frequency_pattern_raises(self): + torch.manual_seed(0) + + model = self.get_model() + msg = "Don't set both alpha and n_frequency_pattern, as alpha overrides n_frequency_pattern." + with pytest.raises(ValueError, match=msg): + config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency_pattern={"linear": 2000}) + get_peft_model(model, config)
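For reviewers, here is a minimal usage sketch of the options this diff introduces (targeting a `nn.Conv2d` layer, plus the `alpha` and `ifft2_norm` settings). It is a sketch, not part of the PR: `SmallConvNet` is a made-up module used only for illustration, while `FourierFTConfig` and `get_peft_model` are the existing PEFT entry points.

```python
import torch
import torch.nn as nn

from peft import FourierFTConfig, get_peft_model


# Hypothetical toy model, defined only to show a Conv2d target alongside a Linear one.
class SmallConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.lin0 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv2d(x).mean(dim=(2, 3))  # global average pooling
        return self.lin0(x)


base_model = SmallConvNet()

# `alpha` derives n_frequency per layer from the layer's weight size, so it must not be
# combined with n_frequency / n_frequency_pattern. `ifft2_norm` is forwarded to
# torch.fft.ifft2 when the delta weight is reconstructed.
config = FourierFTConfig(
    target_modules=["conv2d", "lin0"],
    alpha=0.01,
    ifft2_norm="ortho",
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()

# Dummy forward pass to check that shapes still line up.
out = peft_model(torch.randn(2, 16, 8, 8))
print(out.shape)  # expected: torch.Size([2, 10])
```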