diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py
index abd19ed35..6bb481dc4 100644
--- a/QEfficient/base/pytorch_transforms.py
+++ b/QEfficient/base/pytorch_transforms.py
@@ -107,6 +107,9 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
+ # Handling the __init__ calls in the models
+ if hasattr(module, "__qeff_init__"):
+ module.__qeff_init__()
transformed = True
return model, transformed
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 6b5deb8db..0e5531c31 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1291,6 +1291,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
FP8DeQuantLinearToLinearTransform,
CustomOpsTransform,
KVCacheTransform,
+ KVCacheModuleMethodMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
diff --git a/QEfficient/transformers/models/plamo/__init__.py b/QEfficient/transformers/models/plamo/__init__.py
new file mode 100644
index 000000000..72ba36c8a
--- /dev/null
+++ b/QEfficient/transformers/models/plamo/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/plamo/modeling_plamo.py b/QEfficient/transformers/models/plamo/modeling_plamo.py
new file mode 100644
index 000000000..17b3270c6
--- /dev/null
+++ b/QEfficient/transformers/models/plamo/modeling_plamo.py
@@ -0,0 +1,536 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+from QEfficient.customop.rms_norm import CustomRMSNorm
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+
+
+class QEffPlamoConfig(PretrainedConfig): # type: ignore
+ model_type: str = "plamo"
+
+ def __init__(
+ self,
+ vocab_size: int = 32000,
+ hidden_size: int = 4096,
+ intermediate_size: int = 13312,
+ num_hidden_layers: int = 32,
+ num_attention_heads: int = 32,
+ num_key_value_heads: Optional[int] = None,
+ max_position_embeddings: int = 2048,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ tokenizer_class: str = "PlamoTokenizer",
+ pad_token_id: Optional[int] = None,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ n_shared_head: int = 8,
+ tie_word_embeddings: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+
+ self.n_shared_head = n_shared_head
+
+ super().__init__(
+ tokenizer_class=tokenizer_class,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class QEffPlamoRotaryEmbedding(torch.nn.Module):
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
+
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
+ )
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
+ return x_embed
+
+
+class QEffPlamoRMSNorm(nn.Module):
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(module.qk_dim)
+
+ if attention_mask is not None:
+ attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class QEffPlamoAttention(torch.nn.Module):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ batch_index: Optional[torch.Tensor] = None,
+ layer_idx: Optional[int] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
+
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
+ return t.repeat(1, repeat, 1, 1)[:, :target]
+
+ # expand shared kv
+ assert self.k_num_heads == self.v_num_heads
+ key_states = _expand_kv(key_states, self.config.n_shared_head, self.q_num_heads)
+ value_states = _expand_kv(value_states, self.config.n_shared_head, self.q_num_heads)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len = past_key_value.get_usable_length(kv_seq_len, layer_idx)
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ # query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class MLP(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
+
+
+class QEffPlamoDecoderLayer(torch.nn.Module):
+ def __qeff_init__(
+ self,
+ ):
+ self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ layer_idx: Optional[int] = None,
+ ) -> Tuple[Any, ...]:
+ # from LlamaDecoder
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+
+ # Self Attention
+ hidden_states_sa, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ batch_index=batch_index,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ layer_idx=layer_idx,
+ )
+
+ # Fully Connected
+ hidden_states_mlp = self.mlp(hidden_states)
+
+ # Residual
+ hidden_states = residual + hidden_states_sa + hidden_states_mlp
+
+ outputs: Any = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs # type: ignore
+
+
+class QEffPlamoDecoder(torch.nn.Module):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ output_hidden_states: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ batch_index: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if output_hidden_states else None
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if output_attentions else None
+ next_decoder_cache: Optional[Tuple[torch.Tensor, ...]] = () if use_cache else None
+ hidden_states = hidden_states
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ assert all_hidden_states is not None
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ layer_idx=idx,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ cache = layer_outputs[2 if output_attentions else 1]
+ assert cache is not None
+ assert next_decoder_cache is not None
+ next_decoder_cache = cache
+
+ if output_attentions:
+ assert layer_outputs[1] is not None
+ assert all_self_attns is not None
+ all_self_attns += (layer_outputs[1],)
+
+ return (hidden_states, all_hidden_states, all_self_attns, next_decoder_cache)
+
+
+class QEffPlamoPreTrainedModel(PreTrainedModel): # type: ignore
+ config_class = QEffPlamoConfig
+ _no_split_modules: List[str]
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PlamoDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+ def _init_weights(self, module: torch.nn.Module) -> None:
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module: torch.nn.Module, value: bool = False) -> None:
+ module.gradient_checkpointing = value # type: ignore
+
+
+class QEffPlamoModel(QEffPlamoPreTrainedModel):
+ def __qeff_init__(
+ self,
+ ):
+ self.norm = CustomRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ assert input_ids is not None
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ past_key_values = QEffDynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ attention_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ # decoder layers
+ layer_outputs = self.layers(
+ hidden_states=hidden_states,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=output_hidden_states,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ batch_index=batch_index,
+ )
+
+ hidden_states = layer_outputs[0]
+ all_hidden_states = layer_outputs[1]
+ all_self_attns = layer_outputs[2]
+ next_decoder_cache = layer_outputs[3]
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ assert all_hidden_states is not None
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class QEffPlamoForCausalLM(QEffPlamoPreTrainedModel):
+ def forward( # type: ignore
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ batch_index: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ assert input_ids is not None
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ batch_index=batch_index,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ # Cast to INT32 to avoid issue while running in ONNXRT
+ logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
+ hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.Tensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
+ reordered_past: Tuple[Any, ...] = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 333c734ba..3d8eac97a 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -245,6 +245,15 @@
QEffPhi3ForCausalLM,
QEffPhi3Model,
)
+from QEfficient.transformers.models.plamo.modeling_plamo import (
+ QEffPlamoAttention,
+ QEffPlamoDecoder,
+ QEffPlamoDecoderLayer,
+ QEffPlamoForCausalLM,
+ QEffPlamoModel,
+ QEffPlamoRMSNorm,
+ QEffPlamoRotaryEmbedding,
+)
from QEfficient.transformers.models.qwen2.modeling_qwen2 import (
QEffQwen2Attention,
QEffQwen2DecoderLayer,
@@ -485,5 +494,12 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
"get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder,
},
"InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
+ "PlamoForCausalLM": {"forward": QEffPlamoForCausalLM.forward},
+ "PlamoModel": {"forward": QEffPlamoModel.forward},
+ "PlamoDecoder": {"forward": QEffPlamoDecoder.forward},
+ "PlamoDecoderLayer": {"forward": QEffPlamoDecoderLayer.forward},
+ "Attention": {"forward": QEffPlamoAttention.forward},
+ "RMSNorm": {"forward": QEffPlamoRMSNorm.forward},
+ "RotaryEmbedding": {"forward": QEffPlamoRotaryEmbedding.forward},
}
_match_class_replace_method = {}
diff --git a/README.md b/README.md
index 685db6fe7..de12aee5b 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@
- [04/2025] [Granite 3.0 and 3.1 Language MOE Models] (https://huggingface.co/ibm-granite/granite-3.0-1b-a400m-base)
- [09/2024] [AWQ](https://arxiv.org/abs/2306.00978)/[GPTQ](https://arxiv.org/abs/2210.17323) 4-bit quantized models are supported
- [09/2024] Now we support [PEFT](https://huggingface.co/docs/peft/index) models
+- [04/2025] Added support for [PLaMo] (https://huggingface.co/pfnet/plamo-13b-instruct)
- [01/2025] Added support for [Ibm-Granite] (https://huggingface.co/ibm-granite/granite-3.1-8b-instruct)
- [01/2025] Added support for [Ibm-Granite-Guardian] (https://huggingface.co/ibm-granite/granite-guardian-3.1-8b)
- [09/2024] Added support for [Gemma-2-Family](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index efa2187b7..6d401d4c6 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -44,6 +44,7 @@
"neuralmagic/Qwen2-0.5B-Instruct-FP8", # fp8 quant method, static, with lm head ignored
"ibm-granite/granite-3.1-2b-instruct",
"ibm-granite/granite-guardian-3.1-2b",
+ "pfnet/plamo-13b-instruct",
]
test_models_qnn = [