2 changes: 1 addition & 1 deletion docs/source/package_reference/fourierft.md
@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.

FourierFT currently has the following constraints:

-- Only `nn.Linear` layers are supported.
+- Only `nn.Linear` and `nn.Conv2d` layers are supported.
- Quantized layers are not supported.

If these constraints don't work for your use case, consider other methods instead.
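For illustration, here is a minimal sketch of targeting a `Conv2d` module once this change lands. The model and module names (`TinyConvNet`, `conv2d`, `lin0`) are made up for the example:

```python
import torch.nn as nn
from peft import FourierFTConfig, get_peft_model

class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.lin0 = nn.Linear(8, 32)

    def forward(self, x):
        x = self.conv2d(x).mean(dim=(2, 3))  # global average pooling
        return self.lin0(x)

# n_frequency must stay below in * out (times the kernel area for Conv2d):
# conv2d allows up to 3 * 8 * 3 * 3 = 216, lin0 up to 8 * 32 = 256.
config = FourierFTConfig(target_modules=["conv2d", "lin0"], n_frequency=100)
model = get_peft_model(TinyConvNet(), config)
model.print_trainable_parameters()
```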
37 changes: 36 additions & 1 deletion src/peft/tuners/fourierft/config.py
@@ -15,7 +15,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Literal, Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType
@@ -78,6 +78,14 @@ class FourierFTConfig(PeftConfig):
        init_weights (`bool`):
            The initialization of the Fourier weights. Set this to False (the default) to initialize the spectrum
            from a standard normal distribution, or to True to initialize the spectrum to zeros.
        alpha (`Optional[float]`):
            If set, the number of frequencies is derived dynamically as `n_frequency = int(alpha * out_features *
            in_features)`. When alpha is set, the `n_frequency` and `n_frequency_pattern` parameters must not be set.
        ifft2_norm (`str`):
            The normalization applied to the ifft2 operation. It has to be either `backward`, `forward` or `ortho`.
            See the PyTorch documentation for `torch.fft.ifft2`
            (https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html) for more details. The default
            value is `backward`.

"""

n_frequency: int = field(
@@ -174,6 +182,17 @@ class FourierFTConfig(PeftConfig):
)
},
)

ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field(
default_factory="backward",
Reviewer comment (Member): `dataclasses.field` takes a callable for `default_factory`; a plain string default should be passed via `default`.
Suggested change
-        default_factory="backward",
+        default="backward",
        metadata={
            "help": (
                "The normalization applied to the ifft2 operation. "
                "It has to be either `backward`, `forward` or `ortho`. See the PyTorch documentation for "
                "`torch.fft.ifft2` (https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html) for more "
                "details. The default value is `backward`."
)
},
)
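As a quick illustration of what the three options mean (this is standard `torch.fft.ifft2` behavior, not code from this PR): `backward` divides the inverse transform by the number of elements, `forward` applies no scaling on the inverse, and `ortho` divides by its square root.

```python
import torch

x = torch.zeros(4, 4)
x[0, 0] = 1.0  # a single spectral coefficient; n = 4 * 4 = 16
print(torch.fft.ifft2(x, norm="backward").real[0, 0].item())  # 0.0625 (1/16)
print(torch.fft.ifft2(x, norm="forward").real[0, 0].item())   # 1.0
print(torch.fft.ifft2(x, norm="ortho").real[0, 0].item())     # 0.25 (1/sqrt(16))
```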
init_weights: bool = field(
default=False,
metadata={
@@ -185,6 +204,16 @@ class FourierFTConfig(PeftConfig):
},
)

    alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "The alpha value dynamically sets `n_frequency = int(alpha * out_features * in_features)`. "
                "If alpha is set, the `n_frequency` and `n_frequency_pattern` parameters must not be set."
            )
        },
    )

def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.FOURIERFT
@@ -204,3 +233,9 @@ def __post_init__(self):
# check for layers_to_transform and layers_pattern
if self.layers_pattern and not self.layers_to_transform:
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")

if (self.alpha is not None) and (self.n_frequency != 1000):
raise ValueError("Don't set both alpha and n_frequency, as alpha overrides n_frequency.")

if (self.alpha is not None) and (self.n_frequency_pattern != {}):
raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides n_frequency_pattern.")
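A worked example of the new `alpha` behavior (layer sizes and the `q_proj` target are illustrative): for a 768x768 linear layer, `alpha=0.01` derives `int(0.01 * 768 * 768) = 5898` trainable coefficients, and combining `alpha` with an explicit `n_frequency` fails in `__post_init__` as implemented above.

```python
from peft import FourierFTConfig

# alpha alone is fine; n_frequency is derived per target layer at injection time.
config = FourierFTConfig(target_modules=["q_proj"], alpha=0.01)

# alpha plus a non-default n_frequency raises, since alpha overrides it.
try:
    FourierFTConfig(target_modules=["q_proj"], alpha=0.01, n_frequency=500)
except ValueError as err:
    print(err)
```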
172 changes: 141 additions & 31 deletions src/peft/tuners/fourierft/layer.py
@@ -27,7 +27,12 @@ class FourierFTLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = ("fourierft_spectrum",)
# All names of other parameters that may contain adapter-related parameters
-    other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")
+    other_param_names = (
+        "fourierft_n_frequency",
+        "fourierft_scaling",
+        "fourierft_random_loc_seed",
+        "fourierft_ifft2_norm",
+    )

def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
@@ -39,6 +44,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.fourierft_ifft2_norm = kwargs["ifft2_norm"]
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -48,6 +54,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.in_features, self.out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif isinstance(base_layer, nn.Conv2d):
self.in_features = base_layer.in_channels
self.out_features = base_layer.out_channels
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

@@ -56,20 +65,22 @@ def update_layer(
):
if n_frequency <= 0:
raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
-        if n_frequency > self.in_features * self.out_features:
+
+        if isinstance(self, FourierFTLinear):
+            max_freqs = self.in_features * self.out_features
+        else:
+            kW = self.base_layer.kernel_size[0]
+            kH = self.base_layer.kernel_size[1]
+            max_freqs = self.in_features * self.out_features * kW * kH
+
+        if n_frequency >= max_freqs:
            raise ValueError(
-                f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
-                f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
+                f"`n_frequency` should be less than the product of the input and output dimensions "
+                f"(including kernel dimensions for Conv2d) but the value passed is {n_frequency} and the product is {max_freqs}"
            )
self.fourierft_n_frequency[adapter_name] = n_frequency
self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
-        self.indices[adapter_name] = torch.randperm(
-            self.out_features * self.in_features,
-            generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
-        )[:n_frequency]
-        self.indices[adapter_name] = torch.stack(
-            [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
-        )
+        self.set_indices(adapter_name, n_frequency)
self.fourierft_scaling[adapter_name] = scaling
# Actual trainable parameters
self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
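Illustrative numbers for the `max_freqs` bound checked above (assumed layer sizes, not from this PR):

```python
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)
k0, k1 = conv.kernel_size
print(conv.in_channels * conv.out_channels * k0 * k1)  # 4608 admissible entries

lin = nn.Linear(768, 768)
print(lin.in_features * lin.out_features)  # 589824
```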
@@ -91,29 +102,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
-        delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
+        delta_weight = (
+            torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
+        )
return delta_weight.to(spectrum.dtype)
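A self-contained sketch of the spectral construction in `get_delta_weight` (sizes are illustrative; this mirrors the logic above rather than calling any PEFT API):

```python
import torch

out_features, in_features, n_frequency, scaling = 8, 8, 10, 150.0

# Pick n_frequency fixed spectral locations from a seeded permutation.
g = torch.Generator().manual_seed(777)
flat = torch.randperm(out_features * in_features, generator=g)[:n_frequency]
indices = torch.stack([flat // in_features, flat % in_features], dim=0)

# Scatter the trainable coefficients and invert the 2D FFT.
spectrum = torch.randn(n_frequency)  # stands in for the nn.Parameter
dense = torch.zeros(out_features, in_features)
dense[indices[0], indices[1]] = spectrum
delta_w = torch.fft.ifft2(dense, norm="backward").real * scaling
print(delta_w.shape)  # torch.Size([8, 8])
```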


-class FourierFTLinear(nn.Module, FourierFTLayer):
-    # FourierFT implemented in a dense layer
-    def __init__(
-        self,
-        base_layer,
-        adapter_name: str,
-        n_frequency: int = 1000,
-        scaling: float = 150.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        init_weights: Union[bool, str] = False,
-        random_loc_seed: int = 777,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-        FourierFTLayer.__init__(self, base_layer, **kwargs)
-        self.fan_in_fan_out = fan_in_fan_out
-        self._active_adapter = adapter_name
-        self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -163,6 +156,41 @@ def unmerge(self) -> None:
if active_adapter in self.fourierft_spectrum.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

def set_indices(self, adapter_name: str, n_frequency: int):
self.indices[adapter_name] = torch.randperm(
self.out_features * self.in_features,
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
)[:n_frequency]
self.indices[adapter_name] = torch.stack(
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
)


class FourierFTLinear(nn.Module, FourierFTLayer):
# FourierFT implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
n_frequency: int = 1000,
alpha: Optional[float] = None,
scaling: float = 150.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = False,
random_loc_seed: int = 777,
**kwargs,
) -> None:
super().__init__()
FourierFTLayer.__init__(self, base_layer, **kwargs)

# apply alpha patch
if alpha:
n_frequency = int(alpha * self.in_features * self.out_features)

self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def get_delta_weight(self, adapter) -> torch.Tensor:
return super().get_delta_weight(adapter)

Expand Down Expand Up @@ -191,3 +219,85 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
def __repr__(self) -> str:
rep = super().__repr__()
return "fourierft." + rep


class FourierFTConv2D(nn.Module, FourierFTLayer):
    # FourierFT implemented in a Conv2d layer
def __init__(
self,
base_layer,
adapter_name: str,
n_frequency: int = 1000,
alpha: Optional[float] = None,
scaling: float = 150.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = False,
random_loc_seed: int = 777,
**kwargs,
) -> None:
super().__init__()
FourierFTLayer.__init__(self, base_layer, **kwargs)

self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
kW = base_layer.kernel_size[0]
kH = base_layer.kernel_size[1]

# apply alpha patch
if alpha:
n_frequency = int(alpha * self.in_features * self.out_features * kW * kH)
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def set_indices(self, adapter_name: str, n_frequency: int):
kW = self.base_layer.kernel_size[0]
kH = self.base_layer.kernel_size[1]
self.indices[adapter_name] = torch.randperm(
self.out_features * self.in_features * kW * kH,
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
)[:n_frequency]
self.indices[adapter_name] = torch.stack(
[
self.indices[adapter_name] // (self.in_features * kW),
self.indices[adapter_name] % (self.in_features * kW),
],
dim=0,
)

def get_delta_weight(self, adapter) -> torch.Tensor:
kW = self.base_layer.kernel_size[0]
kH = self.base_layer.kernel_size[1]
spectrum = self.fourierft_spectrum[adapter]
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features * kH, self.in_features * kW, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = (
torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
)
return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH))
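The same construction for Conv2d, sketched with small illustrative sizes: the coefficients are scattered into an `(out * kH) x (in * kW)` grid, transformed with ifft2, and reshaped into a conv weight.

```python
import torch

out_c, in_c, kW, kH, n_frequency = 4, 3, 3, 3, 20
g = torch.Generator().manual_seed(777)
flat = torch.randperm(out_c * in_c * kW * kH, generator=g)[:n_frequency]
indices = torch.stack([flat // (in_c * kW), flat % (in_c * kW)], dim=0)

dense = torch.zeros(out_c * kH, in_c * kW)
dense[indices[0], indices[1]] = torch.randn(n_frequency)
delta_w = torch.fft.ifft2(dense, norm="backward").real.reshape(out_c, in_c, kW, kH)
print(delta_w.shape)  # torch.Size([4, 3, 3, 3])
```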

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.fourierft_spectrum.keys():
continue

delta_w = self.get_delta_weight(active_adapter)
x = x.to(delta_w.dtype)
y = F.conv2d(x, delta_w, stride=self.base_layer.stride, padding=self.base_layer.padding)
result += y

result = result.to(previous_dtype)
return result
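Why the unmerged path above agrees with merging (a quick check using plain `F.conv2d`, nothing PR-specific): convolution is linear in its weight, so adding the delta's output equals convolving with the summed weights.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)   # base conv weight
dw = torch.randn(4, 3, 3, 3)  # stands in for the FourierFT delta weight

unmerged = F.conv2d(x, w, padding=1) + F.conv2d(x, dw, padding=1)
merged = F.conv2d(x, w + dw, padding=1)
print(torch.allclose(unmerged, merged, atol=1e-5))  # True
```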

def __repr__(self) -> str:
rep = super().__repr__()
return "fourierft." + rep
16 changes: 11 additions & 5 deletions src/peft/tuners/fourierft/model.py
@@ -18,14 +18,15 @@
from itertools import chain

import torch
from torch.nn import Conv2d
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
)

-from .layer import FourierFTLayer, FourierFTLinear
+from .layer import FourierFTConv2D, FourierFTLayer, FourierFTLinear


class FourierFTModel(BaseTuner):
@@ -71,11 +72,15 @@ def _create_and_replace(

n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency)
scaling = fourierft_config.scaling
alpha = fourierft_config.alpha
ifft2_norm = fourierft_config.ifft2_norm
random_loc_seed = fourierft_config.random_loc_seed
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"n_frequency": n_frequency,
"alpha": alpha,
"scaling": scaling,
"ifft2_norm": ifft2_norm,
"fan_in_fan_out": fourierft_config.fan_in_fan_out,
"init_weights": fourierft_config.init_weights,
"random_loc_seed": fourierft_config.random_loc_seed,
@@ -110,19 +115,20 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs):
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = False
+        new_module = FourierFTLinear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, Conv1D):
kwargs["is_target_conv_1d_layer"] = True
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = True
+        new_module = FourierFTLinear(target, adapter_name, **kwargs)
+    elif isinstance(target_base_layer, Conv2d):
+        new_module = FourierFTConv2D(target, adapter_name, **kwargs)
else:
        raise ValueError(
            f"Target module {target} is not supported. Currently, only the following modules are supported: "
-            "`torch.nn.Linear`."
+            "`torch.nn.Linear`, `torch.nn.Conv2d`."
)

-    new_module = FourierFTLinear(target, adapter_name, **kwargs)

return new_module
20 changes: 20 additions & 0 deletions tests/test_custom_models.py
@@ -663,6 +663,26 @@
"init_weights": True,
},
),
(
"Conv2d 1 FourierFT",
"Conv2d",
FourierFTConfig,
{
"target_modules": ["conv2d"],
"n_frequency": 1000,
},
),
(
"Conv2d 2 FourierFT",
"Conv2d",
FourierFTConfig,
{
"target_modules": ["conv2d", "lin0"],
"alpha": 0.01,
"init_weights": True,
"ifft2_norm": "ortho",
},
),
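For reference, roughly the configuration that the second test case above exercises, constructed directly (illustrative):

```python
from peft import FourierFTConfig

config = FourierFTConfig(
    target_modules=["conv2d", "lin0"],
    alpha=0.01,          # n_frequency is derived per layer
    init_weights=True,   # zero-initialized spectrum, so the adapter starts as a no-op
    ifft2_norm="ortho",
)
```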
##########
# VBLoRA #
##########
Expand Down