
Commit b2dd328

ochougul, vbaddi (Vinayak Baddi), mamtsing, and quic-mamta authored
New PR for GPTOSS decode-only model (#603)
Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>
Signed-off-by: Mamta Singh <[email protected]>
Co-authored-by: Vinayak Baddi <[email protected]>
Co-authored-by: Mamta Singh <[email protected]>
1 parent 25236bb commit b2dd328

File tree

16 files changed, +1336 -46 lines changed

QEfficient/base/pytorch_transforms.py

Lines changed: 63 additions & 15 deletions
@@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
 
 class SplitGateUpWeightsTransform(PytorchTransform):
     """
-    split fused Gate+Up weights and copy into the model
+    Split fused Gate+Up weights and copy into the model.
+    Handles both standard MoE models and GptOss models.
 
     For every transformer layer inside `model`:
-      • expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
-      • copies halves into
-            <PREFIX>.experts.gate_proj  <-- Gate [E,H,I]
-            <PREFIX>.experts.up_proj    <-- Up   [E,H,I]
+      • expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
+      • copies halves into
+            <PREFIX>.experts.gate_proj  <-- Gate [E,H,I]
+            <PREFIX>.experts.up_proj    <-- Up   [E,H,I]
+
+    Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
+    Also handles bias terms when present.
     """
 
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         transformed = False
         model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
-
         if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
             return model, transformed
 
         model_tmp = model.language_model if hasattr(model, "language_model") else model
-
         num_layers = len(model_tmp.model.layers)
         delete_fused_key = True
         sd = model_tmp.state_dict()
+
         for layer_idx in range(num_layers):
+            # Determine if this is a GptOss model or standard MoE model
+            is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")
+
             # ---- build the textual prefix once per layer ----------
-            prefix = f"model.layers.{layer_idx}.feed_forward.experts."
+            if is_gpt_oss:
+                prefix = f"model.layers.{layer_idx}.mlp.experts."
+                experts = model_tmp.model.layers[layer_idx].mlp.experts
+            else:
+                prefix = f"model.layers.{layer_idx}.feed_forward.experts."
+                experts = model_tmp.model.layers[layer_idx].feed_forward.experts
 
             fused_key = prefix + "gate_up_proj"
             gate_key = prefix + "gate_proj"
             up_key = prefix + "up_proj"
 
-            # ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
-            fused = sd[fused_key]  # [E, H, 2I]  (no .weight here)
+            # Check if we have bias terms (GptOss case)
+            has_bias = fused_key + "_bias" in sd
+            if has_bias:
+                fused_bias_key = fused_key + "_bias"
+                gate_bias_key = gate_key + "_bias"
+                up_bias_key = up_key + "_bias"
+
+            # ---- split weights based on model type ----------------------
+            fused = sd[fused_key]  # [E, H, 2I]
             E, H, two_I = fused.shape
-            ffn_dim = two_I // 2
-            gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy
 
-            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
+            if is_gpt_oss:
+                # For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
+                gate = fused[..., ::2]  # [E, H, I] - even indices
+                up = fused[..., 1::2]  # [E, H, I] - odd indices
+            else:
+                # For standard MoE, gate/up are concatenated: [gate, up]
+                ffn_dim = two_I // 2
+                gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy
+
+            # Copy weights to model
             experts.gate_proj.data.copy_(gate)
             experts.up_proj.data.copy_(up)
 
+            # Handle bias if present
+            if has_bias:
+                fused_bias = sd[fused_bias_key]  # [E, 2I]
+
+                if is_gpt_oss:
+                    gate_bias = fused_bias[..., ::2]  # [E, I] - even indices
+                    up_bias = fused_bias[..., 1::2]  # [E, I] - odd indices
+                else:
+                    ffn_dim = fused_bias.shape[-1] // 2
+                    gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)
+
+                experts.gate_proj_bias.data.copy_(gate_bias)
+                experts.up_proj_bias.data.copy_(up_bias)
+
             # ---- update the state-dict so load_state_dict sees the right keys
             sd[gate_key] = gate
             sd[up_key] = up
 
+            if has_bias:
+                sd[gate_bias_key] = gate_bias
+                sd[up_bias_key] = up_bias
+
+            # Delete fused keys
             if delete_fused_key:
                 del sd[fused_key]
+                if has_bias:
+                    del sd[fused_bias_key]
 
-            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
+            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
             transformed = True
 
         if hasattr(model, "language_model"):
             model.language_model = model_tmp
         else:
             model = model_tmp
+
         return model, transformed
 
 
-VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
+# Keep the existing list of supported models
+VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}
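
Note for readers comparing the two layouts the transform now supports: Llama4-style MoE checkpoints store the fused tensor as gate and up halves concatenated along the last axis, while GptOss interleaves gate and up element-wise. Below is a minimal, self-contained sketch (illustrative only, not part of the diff; the toy shapes E=2, H=4, I=3 are made up) showing that both slicing schemes recover the same halves:

import torch

# Toy fused gate/up tensors: E experts, hidden H, intermediate I (values are arbitrary).
E, H, I = 2, 4, 3
gate_ref = torch.randn(E, H, I)
up_ref = torch.randn(E, H, I)

# Standard MoE layout: [gate | up] concatenated along the last dim -> split in half.
fused_concat = torch.cat([gate_ref, up_ref], dim=-1)  # [E, H, 2I]
gate_c, up_c = fused_concat.split(I, dim=-1)

# GptOss layout: gate/up interleaved [g0, u0, g1, u1, ...] -> strided slicing.
fused_interleaved = torch.stack([gate_ref, up_ref], dim=-1).reshape(E, H, 2 * I)
gate_i, up_i = fused_interleaved[..., ::2], fused_interleaved[..., 1::2]

assert torch.equal(gate_c, gate_ref) and torch.equal(up_c, up_ref)
assert torch.equal(gate_i, gate_ref) and torch.equal(up_i, up_ref)

The strided views [..., ::2] / [..., 1::2] are exactly what the GptOss branch above copies into gate_proj and up_proj.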

QEfficient/transformers/cache_utils.py

Lines changed: 119 additions & 0 deletions
@@ -537,3 +537,122 @@ def update(
             ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
             v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
         return k_out, v_out
+
+
+# This is a hack for now, until we get to merging this code with HybridCache class,
+# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and
+# ours are made to work with AIC
+class QEffHybridCacheForGPTOSS:
+    def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
+        self.max_cache_len = max_cache_len
+        self.batch_size = batch_size
+        self.sliding_window_len = sliding_window_len
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
+    @classmethod
+    def from_legacy_cache(
+        cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "HybridCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(
+            config,
+            batch_size=past_key_values[0][0].shape[0],
+            max_cache_len=past_key_values[1][0].shape[2],
+            sliding_window_len=past_key_values[0][0].shape[2],
+        )
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            is_sliding_layer = cache_kwargs.get("is_sliding")
+            sliding_window = cache_kwargs.get("sliding_window")
+            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
+
+            if is_sliding_layer:
+                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
+            else:
+                kv_position_ids = position_ids
+
+            if batch_index is not None:
+                if torch.onnx.is_in_onnx_export():
+                    invalid_scatter_index = torch.iinfo(torch.int32).max
+                    scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
+                else:
+                    scatter_position_ids = kv_position_ids
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], kv_position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            # Original Gather
+            ctx_len = self.key_cache[layer_idx].shape[2]
+            ctx_indices = torch.arange(ctx_len)[None, None, ...]
+            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+            invalid_mask = ctx_indices > gather_limit
+            if torch.onnx.is_in_onnx_export():
+                invalid_idx_value = torch.iinfo(torch.int32).max
+            else:
+                invalid_idx_value = 0
+            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+            if batch_index is not None:
+                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            else:
+                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+
+            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
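
As with the existing QEff caches, the new update method is a scatter-then-gather: incoming key/value states are written into position-derived cache slots, and reads are masked so slots beyond the largest valid position never reach attention. Below is a minimal sketch of the two index computations it relies on (the CtxScatterFunc/CtxGatherFunc ops are QEfficient internals and omitted here; the window size and positions are made-up toy values):

import torch

# Assumed toy values, not taken from the PR: a sliding-window layer with 4 cache slots.
sliding_window = 4

# Write side: sliding layers scatter into slot (position % window); -1 padding is left as-is.
position_ids = torch.tensor([[5]])  # decoding the token at absolute position 5
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
print(kv_position_ids)  # tensor([[1]]) -> position 5 lands in slot 5 % 4 = 1

# Read side: slots beyond the largest position seen so far are masked out,
# so uninitialized cache entries never reach attention.
position_ids = torch.tensor([[2]])  # only positions 0..2 have been written
ctx_indices = torch.arange(sliding_window)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
print(invalid_mask)  # tensor([[[False, False, False,  True]]])

For full-attention (non-sliding) layers the modulo is skipped and positions index the cache directly, which is why kv_position_ids falls back to position_ids in that branch.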

QEfficient/transformers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@
     ]
 )
 
+# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
 DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
 
 # Define a transformers layers to QEff layers dictionary
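
The added comment documents the intent of DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH: architectures in this set get per-layer sequence lengths (sliding-window or chunked attention) rather than a single context length. A hypothetical membership check, only to show how such a set is typically consulted (the helper name supports_per_layer_seq_len is illustrative and not part of QEfficient's API):

DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

def supports_per_layer_seq_len(config) -> bool:
    # Hypothetical helper: keys off the Hugging Face `model_type` string.
    return getattr(config, "model_type", None) in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH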
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
