Add Llama4 Multi-Modal Support #382

Status: Open. Wants to merge 26 commits into base: main.

Commits (26)
ceee86a  Add Llama4 Multi-Modal Support (vbaddi, Apr 29, 2025)
693e474  nit: modeling changes (vbaddi, Apr 29, 2025)
ed85612  Adding Vision Part and Chunking (#383) (mohiso22, Apr 30, 2025)
77df9d2  nit: update moe implementation and add sample export/compile script (vbaddi, May 2, 2025)
b5eb805  nit: fix linter for example script (vbaddi, May 2, 2025)
eface81  nit: update pytorch transforms to map Llama4TextExperts (vbaddi, May 2, 2025)
b1eae7e  nit: update modeling with new freq apply computation and sample mm ex… (vbaddi, May 4, 2025)
5aca1ef  nit: update llama4 mm example script (vbaddi, May 4, 2025)
91eca48  nit: update modeling to avoid >2GiB issue in Onnx, rope max-position (vbaddi, May 6, 2025)
a661727  Added pytorch transform for the split_gate_up_weights and removed exa… (quic-amitraj, May 8, 2025)
3cc5511  Ruff Check and format (quic-amitraj, May 8, 2025)
6ca555e  Minor fixes-1 (quic-amitraj, May 8, 2025)
7482533  Added logger for new transform (quic-amitraj, May 8, 2025)
a83f23f  fixed Llama4 MOE accuracy bug (ochougul, May 18, 2025)
f98f4af  Updating index method in Wrappers (#410) (mohiso22, May 19, 2025)
1b36e90  nit: add position_ids to attn_scales instead of cache_position in use… (vbaddi, May 20, 2025)
1ebd9a8  Minor Fixes (#421) (mohiso22, May 21, 2025)
63b0b10  Updating Specialization and modeling auto files (mohiso22, May 21, 2025)
8e0a613  Fix for Multi Image Chunking (mohiso22, May 23, 2025)
c91dd9f  Adding SingleQPC (mohiso22, Jun 1, 2025)
aefa2b0  Rebase and Minor Fixes (mohiso22, Jun 9, 2025)
ba30051  Addressed Comments (mohiso22, Jun 9, 2025)
64ec975  Addressed Comments (mohiso22, Jun 10, 2025)
d15fe69  Addressed comments (quic-amitraj, Jun 10, 2025)
b7e1a46  Header and doc update (quic-rishinr, Jun 10, 2025)
d8a947a  Updated get_available_device_id method (quic-rishinr, Jun 10, 2025)
64 changes: 64 additions & 0 deletions QEfficient/base/pytorch_transforms.py
@@ -9,6 +9,8 @@

from torch import nn

from QEfficient.utils.logging_utils import logger


class PytorchTransform:
    """
@@ -110,3 +112,65 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
                transformed = True

        return model, transformed


class SplitGateUpWeightsTransform(PytorchTransform):
    """
    Split the fused Gate+Up projection weights and copy them into the model.

    For every transformer layer inside `model`:
      • expects <PREFIX>.experts.gate_up_proj in the *source* state dict `sd`, shaped [E, H, 2I]
      • copies the two halves into
            <PREFIX>.experts.gate_proj <-- Gate [E, H, I]
            <PREFIX>.experts.up_proj   <-- Up   [E, H, I]
    """

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        model_class = model.__class__.__name__

        if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
            return model, transformed

        model_tmp = model.language_model if hasattr(model, "language_model") else model

        num_layers = len(model_tmp.model.layers)
        delete_fused_key = True
        sd = model_tmp.state_dict()
        for layer_idx in range(num_layers):
            # ---- build the textual prefix once per layer ----------
            prefix = f"model.layers.{layer_idx}.feed_forward.experts."

            fused_key = prefix + "gate_up_proj"
            gate_key = prefix + "gate_proj"
            up_key = prefix + "up_proj"

            # ---- split [E, H, 2I] -> two [E, H, I] tensors ----------------
            fused = sd[fused_key]  # [E, H, 2I] (raw parameter, no .weight suffix)
            E, H, two_I = fused.shape
            ffn_dim = two_I // 2
            gate, up = fused.split(ffn_dim, dim=-1)  # views, no copy

            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
            experts.gate_proj.data.copy_(gate)
            experts.up_proj.data.copy_(up)

            # ---- update the state dict so load_state_dict sees the right keys
            sd[gate_key] = gate
            sd[up_key] = up

            if delete_fused_key:
                del sd[fused_key]

            logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
            transformed = True

        if hasattr(model, "language_model"):
            model.language_model = model_tmp
        else:
            model = model_tmp
        return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration"}
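For context, the core tensor operation in this transform can be checked in isolation. The sketch below (toy sizes and values are illustrative assumptions, not taken from the PR) splits a fused [E, H, 2I] gate_up_proj tensor along its last dimension into the two [E, H, I] projections of a gated MLP, exactly as the loop above does per layer:

import torch

# Toy dimensions: E experts, H hidden size, I intermediate (FFN) size.
E, H, I = 2, 4, 3
gate_up_proj = torch.randn(E, H, 2 * I)  # fused weight, shape [E, H, 2I]

# Same split the transform performs: halve the last dim into two views.
ffn_dim = gate_up_proj.shape[-1] // 2           # 2I // 2 == I
gate, up = gate_up_proj.split(ffn_dim, dim=-1)  # views, no copy

assert gate.shape == (E, H, I) and up.shape == (E, H, I)
# gate feeds the gating branch, up the linear branch of the gated MLP.

In the PR, the split is invoked through the same classmethod interface as the other PytorchTransform subclasses, i.e. model, transformed = SplitGateUpWeightsTransform.apply(model), and it is a no-op unless the model's class name appears in VLM_SPLIT_GATE_UP_WEIGHTS.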
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llama4/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------