From 618a378c2c51aa90db3d0cc640839917a9002da5 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 10 Oct 2025 15:51:04 -0700 Subject: [PATCH 1/7] upd Signed-off-by: Qidong Su --- .../torch/quantization/plugins/huggingface.py | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a3fa6ef1a..8de45cf4c 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -487,7 +487,34 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: return self.w2_linear[expert_idx](x1) -class _QuantDbrxFFN(_QuantSparseMoe): +class _QuantQwen3VLMoeTextDecoderLayer(QuantModule): + def _setup(self): + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock + if not isinstance(self.mlp, Qwen3VLMoeTextSparseMoeBlock): + print(f"Skipping {type(self.mlp)}") + return + q_proj_weight = self.self_attn.q_proj.weight + dtype, device = q_proj_weight.dtype, q_proj_weight.device + def _copy_weight(module, weight): + module.to(dtype=dtype, device=device) + with torch.no_grad(): + module.weight.copy_(weight.detach()) + + new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config) + new_moe_layer.gate = self.mlp.gate + experts = self.mlp.experts + expert_dim = experts.expert_dim + for idx, expert in enumerate(new_moe_layer.experts): + _copy_weight(expert.gate_proj, experts.gate_up_proj[idx, :, :expert_dim].T) + _copy_weight(expert.up_proj, experts.gate_up_proj[idx, :, expert_dim:].T) + _copy_weight(expert.down_proj, experts.down_proj[idx, :].T) + + delattr(self, "mlp") + self.mlp = new_moe_layer + + +class _QuantDbrxFFN(_QuantMoeSparseMoe): @property def num_experts(self): return self.router.moe_num_experts @@ -577,6 +604,16 @@ def top_k(self, value): pass +try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextDecoderLayer + + if Qwen3VLMoeTextDecoderLayer not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3VLMoeTextDecoderLayer: "hf.Qwen3VLMoeTextDecoderLayer"})( + _QuantQwen3VLMoeTextDecoderLayer + ) +except ImportError: + pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. 
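[Note on the weight copy in patch 1] The slicing in `_copy_weight` implies the fused Qwen3-VL-MoE expert tensors are laid out as gate_up_proj[num_experts, hidden_size, 2 * expert_dim] and down_proj[num_experts, expert_dim, hidden_size], while nn.Linear stores its weight as [out_features, in_features], hence the `.T` on each slice. A minimal shape/value check under those assumed layouts (sizes below are illustrative, not the real Qwen3-VL-MoE config):

import torch
import torch.nn as nn

# Assumed fused layout, inferred from the slicing/transpose in the patch:
#   gate_up_proj: [num_experts, hidden_size, 2 * expert_dim]
num_experts, hidden_size, expert_dim = 4, 8, 16
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * expert_dim)

# nn.Linear(hidden_size, expert_dim) stores its weight as [expert_dim, hidden_size],
# so each fused per-expert slice must be transposed before copying.
gate_proj = nn.Linear(hidden_size, expert_dim, bias=False)
with torch.no_grad():
    gate_proj.weight.copy_(gate_up_proj[0, :, :expert_dim].T)  # [expert_dim, hidden_size]

# The Linear now reproduces the fused projection x @ gate_up_proj[0, :, :expert_dim].
x = torch.randn(2, hidden_size)
assert torch.allclose(gate_proj(x), x @ gate_up_proj[0, :, :expert_dim], atol=1e-5)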
From ba246535533cd749d525807915a9fb1f5c62cd80 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Mon, 8 Dec 2025 19:52:08 +0000 Subject: [PATCH 2/7] fix Signed-off-by: Qidong Su --- modelopt/torch/quantization/plugins/huggingface.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 8de45cf4c..fc6c47d74 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -490,17 +490,21 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class _QuantQwen3VLMoeTextDecoderLayer(QuantModule): def _setup(self): from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextSparseMoeBlock, + ) + if not isinstance(self.mlp, Qwen3VLMoeTextSparseMoeBlock): print(f"Skipping {type(self.mlp)}") return q_proj_weight = self.self_attn.q_proj.weight dtype, device = q_proj_weight.dtype, q_proj_weight.device + def _copy_weight(module, weight): module.to(dtype=dtype, device=device) with torch.no_grad(): module.weight.copy_(weight.detach()) - + new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config) new_moe_layer.gate = self.mlp.gate experts = self.mlp.experts @@ -509,10 +513,10 @@ def _copy_weight(module, weight): _copy_weight(expert.gate_proj, experts.gate_up_proj[idx, :, :expert_dim].T) _copy_weight(expert.up_proj, experts.gate_up_proj[idx, :, expert_dim:].T) _copy_weight(expert.down_proj, experts.down_proj[idx, :].T) - + delattr(self, "mlp") self.mlp = new_moe_layer - + class _QuantDbrxFFN(_QuantMoeSparseMoe): @property @@ -614,6 +618,7 @@ def top_k(self, value): except ImportError: pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. 
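[Usage note after patch 2] With the registration in place, the plugin is picked up automatically when ModelOpt converts a HuggingFace model, so quantizing a Qwen3-VL-MoE checkpoint should not need extra wiring. A rough usage sketch; the checkpoint id, loader class, and calibration loop are placeholders and not taken from this patch series:

import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint id and auto class; the exact repo name / loader for a
# Qwen3-VL-MoE model may differ (the VL wrapper may need a multimodal auto class).
model_id = "Qwen/Qwen3-VL-MoE-example"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(m):
    # Minimal single-batch calibration; a real run would iterate a calibration set.
    inputs = tokenizer("calibration sample", return_tensors="pt").to(m.device)
    m(**inputs)

# During conversion the registered plugin swaps the fused MoE block for a
# Qwen3MoeSparseMoeBlock built from per-expert nn.Linear layers, so the usual
# quantization configs apply unchanged.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)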
From 700eec42b92d4f2743f45ee4ede449a926fe99b2 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Mon, 8 Dec 2025 20:02:18 +0000 Subject: [PATCH 3/7] fix Signed-off-by: Qidong Su --- modelopt/torch/quantization/plugins/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index fc6c47d74..637e79645 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -518,7 +518,7 @@ def _copy_weight(module, weight): self.mlp = new_moe_layer -class _QuantDbrxFFN(_QuantMoeSparseMoe): +class _QuantDbrxFFN(_QuantSparseMoe): @property def num_experts(self): return self.router.moe_num_experts From e7d795b52cdcfa93136e2e962b0e17832e6cf7e2 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Mon, 8 Dec 2025 20:06:06 +0000 Subject: [PATCH 4/7] fix Signed-off-by: Qidong Su --- modelopt/torch/quantization/plugins/huggingface.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 637e79645..2023842ca 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -495,7 +495,6 @@ def _setup(self): ) if not isinstance(self.mlp, Qwen3VLMoeTextSparseMoeBlock): - print(f"Skipping {type(self.mlp)}") return q_proj_weight = self.self_attn.q_proj.weight dtype, device = q_proj_weight.dtype, q_proj_weight.device From 7dae299489fae866e3226e6994b43821823af894 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Mon, 8 Dec 2025 22:03:39 +0000 Subject: [PATCH 5/7] upd Signed-off-by: Qidong Su --- modelopt/torch/quantization/plugins/huggingface.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 2023842ca..fb55f049f 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -489,6 +489,8 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: class _QuantQwen3VLMoeTextDecoderLayer(QuantModule): def _setup(self): + """Modify the Qwen3VLMoeTextDecoderLayer by using Qwen3MoeSparseMoeBlock.""" + from accelerate import init_empty_weights from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeTextSparseMoeBlock, @@ -500,11 +502,12 @@ def _setup(self): dtype, device = q_proj_weight.dtype, q_proj_weight.device def _copy_weight(module, weight): - module.to(dtype=dtype, device=device) + module.to_empty(device=device) with torch.no_grad(): - module.weight.copy_(weight.detach()) + module.weight.data = weight.detach().data.to(dtype=dtype, device=device) - new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config) + with init_empty_weights(): + new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config) new_moe_layer.gate = self.mlp.gate experts = self.mlp.experts expert_dim = experts.expert_dim From 9bbeab815636cd801d6aca0d7fa76f8eb94fc01b Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 10 Dec 2025 21:40:09 +0000 Subject: [PATCH 6/7] refactor to directly impl qwen3_vl_moe Signed-off-by: Qidong Su --- .../torch/quantization/plugins/huggingface.py | 98 ++++++++++++++----- 1 file changed, 72 insertions(+), 26 deletions(-) diff --git 
a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index fb55f049f..fc8f6be25 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -487,37 +487,73 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: return self.w2_linear[expert_idx](x1) -class _QuantQwen3VLMoeTextDecoderLayer(QuantModule): +class _QuantQwen3VLMoeTextExperts(QuantModule): def _setup(self): - """Modify the Qwen3VLMoeTextDecoderLayer by using Qwen3MoeSparseMoeBlock.""" + """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers.""" from accelerate import init_empty_weights - from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextSparseMoeBlock, - ) - - if not isinstance(self.mlp, Qwen3VLMoeTextSparseMoeBlock): - return - q_proj_weight = self.self_attn.q_proj.weight - dtype, device = q_proj_weight.dtype, q_proj_weight.device + dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device def _copy_weight(module, weight): module.to_empty(device=device) with torch.no_grad(): module.weight.data = weight.detach().data.to(dtype=dtype, device=device) - + with init_empty_weights(): - new_moe_layer = Qwen3MoeSparseMoeBlock(self.self_attn.config) - new_moe_layer.gate = self.mlp.gate - experts = self.mlp.experts - expert_dim = experts.expert_dim - for idx, expert in enumerate(new_moe_layer.experts): - _copy_weight(expert.gate_proj, experts.gate_up_proj[idx, :, :expert_dim].T) - _copy_weight(expert.up_proj, experts.gate_up_proj[idx, :, expert_dim:].T) - _copy_weight(expert.down_proj, experts.down_proj[idx, :].T) + gate_proj = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + for _ in range(self.num_experts) + ] + ) + up_proj = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + for _ in range(self.num_experts) + ] + ) + down_proj = nn.ModuleList( + [ + nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + for _ in range(self.num_experts) + ] + ) + + for idx in range(self.num_experts): + _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :self.expert_dim].T) + _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim:].T) + _copy_weight(down_proj[idx], self.down_proj[idx, :].T) + + delattr(self, "gate_up_proj") + delattr(self, "down_proj") + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj + + def forward( + self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + next_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit[:]: + assert expert_idx.numel() == 1, expert_idx + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate = self.gate_proj[expert_idx](current_state) + up = self.up_proj[expert_idx](current_state) + gated_output = up * self.act_fn(gate) + out = self.down_proj[expert_idx](gated_output) + weighted_output = out * routing_weights[token_idx, expert_idx, 
None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + + return next_states - delattr(self, "mlp") - self.mlp = new_moe_layer class _QuantDbrxFFN(_QuantSparseMoe): @@ -609,13 +645,23 @@ def top_k(self, value): except ImportError: pass +try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock + + if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"})( + _QuantSparseMoe + ) +except ImportError: + pass + try: - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextDecoderLayer + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts - if Qwen3VLMoeTextDecoderLayer not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3VLMoeTextDecoderLayer: "hf.Qwen3VLMoeTextDecoderLayer"})( - _QuantQwen3VLMoeTextDecoderLayer + if Qwen3VLMoeTextExperts not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})( + _QuantQwen3VLMoeTextExperts ) except ImportError: pass From c846c65a1cb87f2c292198dcd03dbe62fe5eb5a4 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 10 Dec 2025 21:46:17 +0000 Subject: [PATCH 7/7] format Signed-off-by: Qidong Su --- .../torch/quantization/plugins/huggingface.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index fc8f6be25..1d1d02929 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -491,13 +491,14 @@ class _QuantQwen3VLMoeTextExperts(QuantModule): def _setup(self): """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers.""" from accelerate import init_empty_weights + dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device def _copy_weight(module, weight): module.to_empty(device=device) with torch.no_grad(): module.weight.data = weight.detach().data.to(dtype=dtype, device=device) - + with init_empty_weights(): gate_proj = nn.ModuleList( [ @@ -519,10 +520,10 @@ def _copy_weight(module, weight): ) for idx in range(self.num_experts): - _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :self.expert_dim].T) - _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim:].T) + _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.expert_dim].T) + _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim :].T) _copy_weight(down_proj[idx], self.down_proj[idx, :].T) - + delattr(self, "gate_up_proj") delattr(self, "down_proj") self.gate_proj = gate_proj @@ -530,7 +531,10 @@ def _copy_weight(module, weight): self.down_proj = down_proj def forward( - self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor + self, + hidden_states: torch.Tensor, + routing_weights: torch.Tensor, + router_indices: torch.Tensor, ) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) @@ -552,8 +556,7 @@ def forward( next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states - + return next_states class _QuantDbrxFFN(_QuantSparseMoe): @@ -649,9 +652,9 @@ def top_k(self, value): from 
transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"})( - _QuantSparseMoe - ) + QuantModuleRegistry.register( + {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"} + )(_QuantSparseMoe) except ImportError: pass
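[Closing note on the patch 6/7 refactor] The unpacked per-expert nn.Linear path in `_QuantQwen3VLMoeTextExperts` should be numerically equivalent to the fused batched projection it replaces. A small self-contained check under the fused layouts assumed above (the SiLU gate and the sizes are assumptions for illustration, not read from the transformers config):

import torch
import torch.nn as nn

torch.manual_seed(0)
num_experts, hidden_size, expert_dim = 2, 6, 10

# Assumed fused layouts, inferred from the slicing/transposes in _setup:
#   gate_up_proj: [num_experts, hidden_size, 2 * expert_dim]
#   down_proj:    [num_experts, expert_dim, hidden_size]
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * expert_dim)
down_proj = torch.randn(num_experts, expert_dim, hidden_size)

# Unpack expert 0 the way _setup does.
gate = nn.Linear(hidden_size, expert_dim, bias=False)
up = nn.Linear(hidden_size, expert_dim, bias=False)
down = nn.Linear(expert_dim, hidden_size, bias=False)
with torch.no_grad():
    gate.weight.copy_(gate_up_proj[0, :, :expert_dim].T)
    up.weight.copy_(gate_up_proj[0, :, expert_dim:].T)
    down.weight.copy_(down_proj[0].T)

x = torch.randn(3, hidden_size)
act = nn.SiLU()  # assumption: Qwen3 MoE experts use a SiLU-gated MLP

# Fused (batched-matmul style) path vs. unpacked per-expert Linear path.
fused = ((x @ gate_up_proj[0, :, expert_dim:]) * act(x @ gate_up_proj[0, :, :expert_dim])) @ down_proj[0]
unpacked = down(up(x) * act(gate(x)))
assert torch.allclose(fused, unpacked, atol=1e-5)

Decomposing the fused 3-D expert parameters into per-expert nn.Linear modules is what lets the standard per-linear quantizers attach to each expert's gate, up, and down projections individually, which the batched layout does not expose.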