44 changes: 44 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -487,6 +487,39 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        return self.w2_linear[expert_idx](x1)


class _QuantQwen3VLMoeTextDecoderLayer(QuantModule):
    def _setup(self):
        """Replace the Qwen3VLMoeTextDecoderLayer's MoE block with a Qwen3MoeSparseMoeBlock."""
        from accelerate import init_empty_weights
        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
            Qwen3VLMoeTextSparseMoeBlock,
        )

        if not isinstance(self.mlp, Qwen3VLMoeTextSparseMoeBlock):
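            # Some decoder layers may use a plain dense MLP rather than the fused MoE block;
            # in that case there is nothing to convert here.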
Contributor: I'm thinking of whether we want to directly add support to Qwen3VLMoeTextSparseMoeBlock; it feels a bit fragile to use a replacement module.

Author: I've refactored this part.

            return
        q_proj_weight = self.self_attn.q_proj.weight
        dtype, device = q_proj_weight.dtype, q_proj_weight.device

        def _copy_weight(module, weight):
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        with init_empty_weights():
            new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config)
        new_moe_layer.gate = self.mlp.gate
        experts = self.mlp.experts
        expert_dim = experts.expert_dim
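        # gate_up_proj stacks each expert's gate and up projections along its last dim,
        # i.e. (num_experts, hidden_size, 2 * expert_dim) with the gate half first; each
        # slice is transposed into the (out_features, in_features) layout nn.Linear expects.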
        for idx, expert in enumerate(new_moe_layer.experts):
            _copy_weight(expert.gate_proj, experts.gate_up_proj[idx, :, :expert_dim].T)
            _copy_weight(expert.up_proj, experts.gate_up_proj[idx, :, expert_dim:].T)
            _copy_weight(expert.down_proj, experts.down_proj[idx, :].T)

        delattr(self, "mlp")
        self.mlp = new_moe_layer


class _QuantDbrxFFN(_QuantSparseMoe):
    @property
    def num_experts(self):
@@ -577,6 +610,17 @@ def top_k(self, value):
    pass


try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextDecoderLayer

    if Qwen3VLMoeTextDecoderLayer not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextDecoderLayer: "hf.Qwen3VLMoeTextDecoderLayer"})(
            _QuantQwen3VLMoeTextDecoderLayer
        )
except ImportError:
    pass
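
For context, a minimal sketch of how this registration is exercised (not part of the diff; it assumes the public mtq.quantize entry point and the FP8_DEFAULT_CFG preset, with model being an already-loaded Qwen3-VL-MoE checkpoint and calibrate a placeholder forward loop supplied by the user):

import modelopt.torch.quantization as mtq

def calibrate(model):
    # Placeholder: run a few representative batches so activation quantizers can calibrate.
    ...

# During conversion, registered wrappers such as _QuantQwen3VLMoeTextDecoderLayer replace the
# matching modules, so the fused Qwen3-VL MoE block is rebuilt as a Qwen3MoeSparseMoeBlock
# before quantizers are attached to its expert linears.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate)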


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
