Add Llama4 Multi-Modal Support #382

Draft · wants to merge 19 commits into base: main
3 changes: 2 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -20,7 +20,7 @@
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.base.pytorch_transforms import PytorchTransform, append_tranform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import constants, dump_qconfig
@@ -46,6 +46,7 @@ class QEFFBaseModel(ABC):
    def _transform_names(cls) -> List[str]:
        return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

    @append_tranform
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model
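
The `@append_tranform` decorator applied here (defined in `QEfficient/base/pytorch_transforms.py` below) inspects the class name of the model being wrapped and, for the Llama4 classes, appends `SplitGateUpWeightsTransform` to the instance's `_pytorch_transforms` list before `__init__` runs. A minimal standalone sketch of that registration pattern, using hypothetical stand-in classes rather than the real `QEFFBaseModel` and transform:

from typing import Callable, List

class SplitGateUpWeightsTransform:  # hypothetical stand-in for the transform added in this PR
    pass

VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration", "Llama4TextModel"]

def append_transform_sketch(func: Callable) -> Callable:
    """Mirrors the PR's append_tranform: register the split transform before __init__ executes."""
    def wrapper(self, model, *args, **kwargs):
        # Unwrap one level if the passed object wraps the HF model in a `.model` attribute.
        name = model.model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__
        if name in VLM_SPLIT_GATE_UP_WEIGHTS:
            self._pytorch_transforms.append(SplitGateUpWeightsTransform)
        return func(self, model, *args, **kwargs)
    return wrapper

class DummyQEFFModel:
    _pytorch_transforms: List[type] = []  # class-level list, as used by _transform_names()

    @append_transform_sketch
    def __init__(self, model) -> None:
        self.model = model

class Llama4TextModel:  # hypothetical empty stand-in for the HF model class
    pass

DummyQEFFModel(Llama4TextModel())
print(DummyQEFFModel._pytorch_transforms)  # list now contains SplitGateUpWeightsTransform
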
65 changes: 65 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -9,6 +9,8 @@

from torch import nn

from QEfficient.utils.logging_utils import logger


class PytorchTransform:
"""
Expand Down Expand Up @@ -110,3 +112,66 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = True

return model, transformed


class SplitGateUpWeightsTransform(PytorchTransform):
    """
    Split the fused Gate+Up projection weights and copy the halves back into the model.

    For every transformer layer inside `model`:
      • expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
      • copies the halves into
            <PREFIX>.experts.gate_proj <-- Gate [E, H, I]
            <PREFIX>.experts.up_proj   <-- Up   [E, H, I]
    """

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False

        model = model.language_model if hasattr(model, "language_model") else model

        num_layers = len(model.model.layers)
        delete_fused_key = True
        sd = model.state_dict()
        for layer_idx in range(num_layers):
            # ---- build the textual prefix once per layer ----------
            prefix = f"model.layers.{layer_idx}.feed_forward.experts."

            fused_key = prefix + "gate_up_proj"
            gate_key = prefix + "gate_proj"
            up_key = prefix + "up_proj"

            # ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
            fused = sd[fused_key]  # [E, H, 2I] (no .weight here)
            E, H, two_I = fused.shape
            ffn_dim = two_I // 2
            gate, up = fused.split(ffn_dim, dim=-1)  # views – no copy

            experts = model.model.layers[layer_idx].feed_forward.experts
            experts.gate_proj.data.copy_(gate)
            experts.up_proj.data.copy_(up)

            # ---- update the state-dict so load_state_dict sees the right keys
            sd[gate_key] = gate
            sd[up_key] = up

            if delete_fused_key:
                del sd[fused_key]

            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
            transformed = True
        return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = ["Llama4ForConditionalGeneration", "Llama4TextModel"]


def append_tranform(func):
    def wrapper(*args, **kwargs):
        model_class = args[1].model.__class__.__name__ if hasattr(args[1], "model") else args[1].__class__.__name__
        if model_class in VLM_SPLIT_GATE_UP_WEIGHTS:
            args[0]._pytorch_transforms.append(SplitGateUpWeightsTransform)
        return func(*args, **kwargs)

    return wrapper
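
For reference, the core operation of `SplitGateUpWeightsTransform` is a split of the fused expert weight along its last dimension. A self-contained sketch with toy sizes (E, H, I are arbitrary illustrative values, not the real Llama4 dimensions):

import torch

E, H, I = 4, 8, 16                     # toy sizes: experts, hidden dim, intermediate dim
fused = torch.randn(E, H, 2 * I)       # stands in for <PREFIX>.experts.gate_up_proj
gate, up = fused.split(I, dim=-1)      # two [E, H, I] views, no copy
assert gate.shape == up.shape == (E, H, I)
assert torch.equal(torch.cat([gate, up], dim=-1), fused)
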
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama4/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------