diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py
index a6f170565e..d6047b2ac9 100644
--- a/src/peft/tuners/trainable_tokens/layer.py
+++ b/src/peft/tuners/trainable_tokens/layer.py
@@ -116,13 +116,18 @@ def update_layer(self, adapter_name, **kwargs):
         # onto the new values, we would get undefined behavior. By replacing the specific token values we always
         # get defined behavior.
         weight = self.get_base_layer().weight
-        embed_dim = self.get_base_layer().embedding_dim
+        base = self.get_base_layer()
+        embed_dim = getattr(base, "embedding_dim", None)
+        if embed_dim is None:
+            embed_dim = getattr(base, "in_features", None)
+        if embed_dim is None:
+            embed_dim = weight.shape[-1]
 
         if init_weights:
             if check_deepspeed_zero3_enabled():
                 values = self._collect_token_weights(weight, self.token_indices[adapter_name], embed_dim)
             else:
-                values = self.weight[self.token_indices[adapter_name]]
+                values = weight[self.token_indices[adapter_name]]
         else:
             # random init with matching dtype/device
             values = torch.randn(
@@ -230,9 +235,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
             )
         elif isinstance(self.base_layer, torch.nn.Linear):
             # Probably a tied adapter that wraps an LM head.
+            bias = getattr(self.base_layer, "bias", None)
             result = F.linear(
                 input=x,
                 weight=W,
+                bias=bias,
             )
         else:
             raise ValueError(
diff --git a/src/peft/tuners/trainable_tokens/model.py b/src/peft/tuners/trainable_tokens/model.py
index ff359370cb..e5f5da74be 100644
--- a/src/peft/tuners/trainable_tokens/model.py
+++ b/src/peft/tuners/trainable_tokens/model.py
@@ -41,7 +41,27 @@ def __getattr__(self, name: str):
     def _prepare_adapter_config(self, peft_config, model_config):
         # target_modules can be none which prompts us to infer the embedding layer name ourselves.
         if peft_config.target_modules is None:
-            peft_config.target_modules = _get_input_embeddings_name(self.model, "embed_tokens")
+            targets = _get_input_embeddings_name(self.model, "embed_tokens")
+            if isinstance(targets, str):
+                targets = [targets]
+
+            # If embeddings are untied, also include the output embedding (lm head) module name
+            try:
+                tied_cfg = model_config.get("tie_word_embeddings", False)
+                tied_keys = getattr(self.model, "_tied_weights_keys", None)
+                are_tied = bool(tied_cfg and tied_keys is not None)
+            except Exception:
+                are_tied = False
+
+            if not are_tied and hasattr(self.model, "get_output_embeddings"):
+                out_emb = self.model.get_output_embeddings()
+                if out_emb is not None:
+                    for name, module in self.model.named_modules():
+                        if module is out_emb:
+                            targets.append(name)
+                            break
+
+            peft_config.target_modules = list(dict.fromkeys(targets))
 
         return peft_config
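
A minimal usage sketch of the target-module inference touched in _prepare_adapter_config, assuming the public TrainableTokensConfig / get_peft_model API; the checkpoint name and token indices below are placeholders, not values taken from this diff.

    from transformers import AutoModelForCausalLM
    from peft import TrainableTokensConfig, get_peft_model

    # Placeholder checkpoint: any causal LM whose input and output embeddings are untied.
    model = AutoModelForCausalLM.from_pretrained("org/causal-lm-with-untied-embeddings")

    # target_modules is left as None so _prepare_adapter_config infers the embedding
    # module name(s); with this change the lm head module name is appended when the
    # embeddings are untied. The token indices are arbitrary example values.
    config = TrainableTokensConfig(target_modules=None, token_indices=[3, 5, 7])
    peft_model = get_peft_model(model, config)
    peft_model.print_trainable_parameters()

With untied embeddings, wrapping both the input embedding and the lm head keeps the newly trained token rows consistent between embedding lookup and the output projection, which is the point of extending the inferred target list.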