loratorch/layers.py: 30 changes (15 additions, 15 deletions)
@@ -28,7 +28,7 @@ def set_param(curr_mod, name, param=None, mode='update'):

 class LoRALayer():
     def __init__(
-        self,
+        self,
         r: int,
         lora_alpha: int,
         fan_in_fan_out: bool = False,
@@ -47,28 +47,28 @@ def __init__(
     def register_lora_param(self):
         r"""Register LoRA matrix"""
         for param_name, lora_name in self.params_with_lora.items():
-            assert len(eval(f'self.{param_name}').size()) == 2
+            assert len(getattr(self, param_name).size()) == 2
             self.register_parameter(f'{lora_name}_lora_A',
-                nn.Parameter(eval(f'self.{param_name}').new_zeros((self.r, eval(f'self.{param_name}').size()[1])))
+                nn.Parameter(getattr(self, param_name).new_zeros((self.r, getattr(self, param_name).size()[1])))
                 )
             self.register_parameter(f'{lora_name}_lora_B',
-                nn.Parameter(eval(f'self.{param_name}').new_zeros((eval(f'self.{param_name}').size()[0], self.r)))
+                nn.Parameter(getattr(self, param_name).new_zeros((getattr(self, param_name).size()[0], self.r)))
                 )
-            eval(f'self.{param_name}').requires_grad = False
+            getattr(self, param_name).requires_grad = False

     def init_lora_param(self):
         for param_name, lora_name in self.params_with_lora.items():
             if hasattr(self, f'{lora_name}_lora_A'):
                 # initialize A the same way as the default for nn.Linear and B to zero
-                nn.init.kaiming_uniform_(eval(f'self.{lora_name}_lora_A'), a=math.sqrt(5))
-                nn.init.zeros_(eval(f'self.{lora_name}_lora_B'))
+                nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
+                nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

     def transpose(self, w: torch.Tensor):
         return w.transpose(0, 1) if self.fan_in_fan_out else w

     def merge_BA(self, param_name: str):
         lora_name = self.params_with_lora[param_name]
-        return self.transpose((eval(f'self.{lora_name}_lora_B') @ eval(f'self.{lora_name}_lora_A')).view(eval(f'self.{param_name}').shape))
+        return self.transpose((getattr(self, f'{lora_name}_lora_B') @ getattr(self, f'{lora_name}_lora_A')).view(getattr(self, param_name).shape))

     def merge_lora_param(self):
         r"""p_new = p + scaling * B @ A and keep differentiable to A and B"""
@@ -81,12 +81,12 @@ def merge_lora_param(self):
     def add_lora_data(self):
         r"""NOT differentiable"""
         for param_name, lora_name in self.params_with_lora.items():
-            eval(f'self.{param_name}').data += self.merge_BA(param_name) * self.scaling
+            getattr(self, param_name).data += self.merge_BA(param_name) * self.scaling

     def sub_lora_data(self):
         r"""NOT differentiable"""
         for param_name, lora_name in self.params_with_lora.items():
-            eval(f'self.{param_name}').data -= self.merge_BA(param_name) * self.scaling
+            getattr(self, param_name).data -= self.merge_BA(param_name) * self.scaling

     def lora_train(self, mode: bool = True):
         if mode:
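
A quick note on what add_lora_data() and sub_lora_data() compute (a minimal standalone sketch, not part of the diff; all sizes below are hypothetical): merge_BA(param_name) returns the low-rank update B @ A reshaped to the weight's shape, and the two methods apply or remove it in place as p.data += / -= scaling * (B @ A).

import torch

out_features, in_features, r, scaling = 8, 16, 4, 0.5
W = torch.randn(out_features, in_features)   # frozen base weight
A = torch.randn(r, in_features)              # plays the role of {lora_name}_lora_A
B = torch.randn(out_features, r)             # plays the role of {lora_name}_lora_B
W_orig = W.clone()

delta = (B @ A).view(W.shape)                # what merge_BA(param_name) returns

W.data += delta * scaling                    # add_lora_data(): p_new = p + scaling * B @ A
W.data -= delta * scaling                    # sub_lora_data(): restores the base weight
assert torch.allclose(W, W_orig)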
@@ -346,13 +346,13 @@ def __init__(

     def init_lora_param_qkv(self, enable_lora_bool):
         lora_name = self.params_with_lora['in_proj_weight']
-        nn.init.zeros_(eval(f'self.{lora_name}_lora_B'))
+        nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))
         dim = int(self.in_proj_weight.size()[1] / 3)
         for idx, enable in zip(range(3), enable_lora_bool):
             if enable:
-                nn.init.kaiming_uniform_(eval(f'self.{lora_name}_lora_A')[:,idx*dim:(idx+1)*dim], a=math.sqrt(5))
+                nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A')[:,idx*dim:(idx+1)*dim], a=math.sqrt(5))
             else:
-                nn.init.zeros_(eval(f'self.{lora_name}_lora_A')[:,idx*dim:(idx+1)*dim])
+                nn.init.zeros_(getattr(self, f'{lora_name}_lora_A')[:,idx*dim:(idx+1)*dim])

     def train(self, mode: bool = True):
         nn.MultiheadAttention.train(self, mode)
@@ -418,8 +418,8 @@ def zero_pad(self, x):
     def merge_BA(self, param_name: str):
         lora_name = self.params_with_lora[param_name]
         delta_w = F.conv1d(
-            eval(f'self.{lora_name}_lora_A').unsqueeze(0),
-            eval(f'self.{lora_name}_lora_B').unsqueeze(-1),
+            getattr(self, f'{lora_name}_lora_A').unsqueeze(0),
+            getattr(self, f'{lora_name}_lora_B').unsqueeze(-1),
             groups=sum(self.enable_lora)
         ).squeeze(0)
         return self.transpose(self.zero_pad(delta_w))
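
For the qkv merge_BA just above, the grouped 1-D convolution is a compact way to form the block-wise products B_i @ A_i for each enabled projection in one call. A hedged standalone sketch (shapes are assumptions, not taken from layers.py):

import torch
import torch.nn.functional as F

r, in_features, out_per_group, g = 4, 16, 8, 2      # g plays the role of sum(self.enable_lora)
A = torch.randn(g * r, in_features)                 # stacked A_1..A_g, like lora_A
B = torch.randn(g * out_per_group, r)               # stacked B_1..B_g, like lora_B

# groups=g makes conv1d multiply each (out_per_group, r) block of B
# with its own (r, in_features) block of A.
delta_w = F.conv1d(A.unsqueeze(0), B.unsqueeze(-1), groups=g).squeeze(0)

# Same result with explicit per-group matmuls.
blocks = [B[i*out_per_group:(i+1)*out_per_group] @ A[i*r:(i+1)*r] for i in range(g)]
assert torch.allclose(delta_w, torch.cat(blocks, dim=0), atol=1e-5)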
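
Overall the change is mechanical: every eval(f'self.{name}') becomes getattr(self, name), which performs the same attribute lookup without evaluating a string as Python code. A minimal sketch of the equivalence (the Tiny module and its weight are illustrative, not from layers.py):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 3))   # stand-in for a LoRA-wrapped parameter

m = Tiny()
param_name = 'weight'

w_old_style = eval(f'm.{param_name}')    # pattern removed by this diff
w_new_style = getattr(m, param_name)     # pattern the diff switches to

assert w_old_style is w_new_style        # same Parameter object
getattr(m, param_name).requires_grad = False   # in-place mutation works the same way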