add qwen3 moe #631
Open
HaloKim wants to merge 2 commits into arcee-ai:main from HaloKim:main
+218 −5
@@ -0,0 +1,188 @@
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1

import json
import logging
import os
from typing import List, Optional

import torch
import tqdm
import transformers

from mergekit.architecture import arch_info_for_config
from mergekit.architecture.json_definitions import NAME_TO_ARCH
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions

QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]


class Qwen3MoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "Qwen3 MoE"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        model_types = []
        for model_ref in (
            [config.base_model]
            + [e.source_model for e in config.experts]
            + [e.source_model for e in (config.shared_experts or [])]
        ):
            model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
            model_types.append(model_cfg.model_type)

        if len(set(model_types)) != 1:
            if explain:
                logging.warning(
                    "Qwen3 MoE requires all input models to have the same architecture"
                )
            return False

        if model_types[0] != "qwen3":
            if explain:
                logging.warning(
                    "Qwen3 MoE requires input models to be Qwen3 architecture"
                )
            return False

        return True

    def _generate_config(
        self,
        base_config: transformers.PretrainedConfig,
        num_experts: int,
        num_shared_experts: int = 0,
        experts_per_token: Optional[int] = None,
    ) -> dict:
        res = base_config.to_dict()
        res["architectures"] = ["Qwen3MoeForCausalLM"]
        res["model_type"] = "qwen3_moe"
        res["num_experts"] = num_experts
        res["num_experts_per_tok"] = experts_per_token or 2
        res["decoder_sparse_step"] = 1
        res["norm_topk_prob"] = True
        res["sliding_window"] = None
        res["use_sliding_window"] = False
        res["moe_intermediate_size"] = res["intermediate_size"]

        if num_shared_experts > 0:
            res["shared_expert_intermediate_size"] = res["intermediate_size"]

        if (res["num_experts"] & (res["num_experts"] - 1)) != 0:
            logging.warning(
                f"Your model has {res['num_experts']} experts, which is "
                "not a power of two. The model will not be usable in llama.cpp."
            )
        return res

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: List[torch.Tensor],
        shared_router_weights: Optional[List[torch.Tensor]] = None,
    ):
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # Create the output directory
        os.makedirs(out_path, exist_ok=True)

        out_dtype = select_dtype(config, base_cfg)
        out_cfg = self._generate_config(
            base_cfg,
            len(config.experts),
            len(config.shared_experts or []),
            config.experts_per_token,
        )
        if out_dtype is not None:
            out_cfg["torch_dtype"] = str(out_dtype).removeprefix("torch.")

        with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(out_cfg, f, indent=4)

        shared_def = config.shared_experts[0] if config.shared_experts else None

        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        shared_loader = (
            loaders.get(shared_def.source_model) if shared_def else base_loader
        )

        for weight_info in tqdm.tqdm(
            QWEN3_INFO.all_weights(base_cfg),
            desc="Weights",
        ):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                # Copy per-expert MLP weights
                for expert_idx, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace(
                        ".mlp.", f".mlp.experts.{expert_idx}."
                    )
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        is_residual="down_proj" in tensor_name,
                        output_name=expert_name,
                        out_dtype=out_dtype,
                        clone=merge_options.clone_tensors,
                    )

                # Copy shared expert weights, only when shared_experts are configured
                if shared_def is not None:
                    shared_expert_name = tensor_name.replace(
                        ".mlp.", ".mlp.shared_expert."
                    )
                    copy_tensor_out(
                        weight_info,
                        shared_loader,
                        writer,
                        expert=shared_def,
                        is_residual="down_proj" in tensor_name,
                        output_name=shared_expert_name,
                        out_dtype=out_dtype,
                        clone=merge_options.clone_tensors,
                    )
            else:
                # Copy non-MLP weights directly from the base model
                copy_tensor_out(
                    weight_info,
                    base_loader,
                    writer,
                    out_dtype=out_dtype,
                    clone=merge_options.clone_tensors,
                )

        # Save router (gate) weights
        for layer_idx, weight in enumerate(
            tqdm.tqdm(router_weights, desc="Router weights")
        ):
            writer.save_tensor(
                f"model.layers.{layer_idx}.mlp.gate.weight",
                weight.to(dtype=out_dtype).contiguous(),
                clone=merge_options.clone_tensors,
            )

            # Save shared expert gate weights, only when shared_experts are configured
            if shared_def is not None:
                if (
                    shared_router_weights is not None
                    and len(shared_router_weights) > layer_idx
                ):
                    shared_weight = shared_router_weights[layer_idx]
                else:
                    # No shared router weights provided; fall back to a dummy gate weight
                    shared_weight = torch.zeros_like(weight[:1, :])  # [1, hidden_size]

                writer.save_tensor(
                    f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight",
                    shared_weight.to(dtype=out_dtype).contiguous(),
                    clone=merge_options.clone_tensors,
                )

        writer.finalize()
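For intuition about the renaming scheme used in write_model, here is a minimal standalone sketch of the same string transformations, with no mergekit dependencies; the example tensor name and expert count are illustrative placeholders, while real names come from QWEN3_INFO.all_weights(base_cfg).

# Standalone illustration of the tensor renaming performed in write_model.
dense_name = "model.layers.0.mlp.gate_proj.weight"  # example dense Qwen3 name
num_experts = 2

# Per-expert copies: ".mlp." -> ".mlp.experts.{i}."
expert_names = [
    dense_name.replace(".mlp.", f".mlp.experts.{i}.") for i in range(num_experts)
]
# Optional shared expert copy: ".mlp." -> ".mlp.shared_expert."
shared_name = dense_name.replace(".mlp.", ".mlp.shared_expert.")

print(expert_names)
# ['model.layers.0.mlp.experts.0.gate_proj.weight',
#  'model.layers.0.mlp.experts.1.gate_proj.weight']
print(shared_name)
# model.layers.0.mlp.shared_expert.gate_proj.weight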
Bug: Layer Weights Naming Conflict
The Qwen3MoeModuleArchitecture.layer_weights method unconditionally adds shared expert and shared expert gate weight names. This conflicts with write_model's conditional writing of these weights, potentially causing them to be missing. Additionally, the declared names clash with write_model's path transformation, leading to malformed weight paths.
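To make the second point concrete, a small hypothetical illustration: the declared name below is an assumption about what Qwen3MoeModuleArchitecture.layer_weights might emit (that method is not shown in this diff); if such a name still contains ".mlp.", write_model's replacement is applied on top of it and produces a malformed path.

# Hypothetical illustration of the path clash described in the review.
declared = "model.layers.0.mlp.shared_expert.gate_proj.weight"  # assumed declared name
transformed = declared.replace(".mlp.", ".mlp.experts.0.")
print(transformed)
# model.layers.0.mlp.experts.0.shared_expert.gate_proj.weight  <- malformed path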