From 784c3ca07cb8daa80063f4f5965fdfbbb2e3d3b6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 17 Oct 2025 14:14:52 +0800 Subject: [PATCH 01/30] update --- swift/mcore_bridge/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 swift/mcore_bridge/__init__.py diff --git a/swift/mcore_bridge/__init__.py b/swift/mcore_bridge/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 954503035d663cd688a02473c8da6f9505f862b5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 23 Oct 2025 15:03:53 +0800 Subject: [PATCH 02/30] update --- ...53\351\200\237\345\274\200\345\247\213.md" | 2 +- docs/source_en/Megatron-SWIFT/Quick-start.md | 2 +- swift/mcore_bridge/__init__.py | 0 swift/megatron/__init__.py | 2 + swift/megatron/bridge/__init__.py | 1 + swift/megatron/bridge/auto.py | 13 ++ swift/megatron/init.py | 10 + swift/megatron/model/gpt/hf2mcore.py | 211 ++++++++++-------- swift/megatron/model/gpt_model.py | 3 + swift/megatron/train/sft.py | 2 - swift/megatron/utils/__init__.py | 1 - swift/megatron/utils/convert.py | 10 +- swift/megatron/utils/patcher.py | 8 - tests/megatron/test_auto.py | 3 + 14 files changed, 164 insertions(+), 104 deletions(-) delete mode 100644 swift/mcore_bridge/__init__.py create mode 100644 swift/megatron/bridge/__init__.py create mode 100644 swift/megatron/bridge/auto.py create mode 100644 tests/megatron/test_auto.py diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index f6e520c0da..bec829723e 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -65,7 +65,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 03b1558f3d..59ac888b65 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -65,7 +65,7 @@ Recommended Operating Environment: | torch | >=2.0 | 2.6.0/2.7.1 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | >=0.12 | 0.13 | | +| megatron_core | | 0.13 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.56.2 | | | modelscope | >=1.23 | | | diff --git a/swift/mcore_bridge/__init__.py b/swift/mcore_bridge/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 0a5a41ebc1..bca6522cc3 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -18,6 +18,7 @@ from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer from .tuners import LoraParallelLinear + from .bridge import AutoMcoreModel else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], @@ -26,6 +27,7 @@ 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], 'tuners': 
['LoraParallelLinear'], + 'bridge': ['AutoMcoreModel'], } import sys diff --git a/swift/megatron/bridge/__init__.py b/swift/megatron/bridge/__init__.py new file mode 100644 index 0000000000..653cea3525 --- /dev/null +++ b/swift/megatron/bridge/__init__.py @@ -0,0 +1 @@ +from .auto import AutoMcoreModel diff --git a/swift/megatron/bridge/auto.py b/swift/megatron/bridge/auto.py new file mode 100644 index 0000000000..8cea017155 --- /dev/null +++ b/swift/megatron/bridge/auto.py @@ -0,0 +1,13 @@ + +from megatron.training import get_args + +class AutoMcoreModel: + + @classmethod + def build_model(cls): + args = get_args() + model_meta = args.model_meta + model_info = args.model_info + megatron_model_meta = args.megatron_model_meta + logger.info(f'Creating mcore_model using model_dir: {model_info.model_dir}') + diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 33abf259a1..57c91400c2 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -378,6 +378,15 @@ def sharded_state_dict( TEGroupedLinear.sharded_state_dict = sharded_state_dict +def _patch_megatron_tokenizer(): + from megatron.training import global_vars + + def build_tokenizer(args): + return + + global_vars.build_tokenizer = build_tokenizer + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -660,6 +669,7 @@ def _patch_megatron(): _patch_compile_helpers() _patch_build_train_valid_test_datasets() _patch_mrope() + _patch_megatron_tokenizer() from swift.megatron import tuners # patch lora try: _patch_torch_FileSystemReader() diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index cc8960163b..6393c31936 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -4,124 +4,161 @@ import torch from megatron.training import get_args from torch import nn +from tqdm import tqdm -def set_mla_attn_state(args, mg_attn, hf_attn): - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) +def set_mla_attn_state(args, state_dict, prefix: str): + mg_state_dict = {} + mg_state_dict['linear_proj.weight'] = state_dict['o_proj.weight'] if args.q_lora_rank is None: - mg_attn.linear_q_proj.weight.data.copy_(hf_attn.q_proj.weight) + mg_state_dict['linear_q_proj.weight'] = state_dict['q_proj.weight'] else: - mg_attn.linear_q_down_proj.weight.data.copy_(hf_attn.q_a_proj.weight) - mg_attn.linear_q_up_proj.weight.data.copy_(hf_attn.q_b_proj.weight) - mg_attn.linear_kv_down_proj.weight.data.copy_(hf_attn.kv_a_proj_with_mqa.weight) - mg_attn.linear_kv_up_proj.weight.data.copy_(hf_attn.kv_b_proj.weight) + mg_state_dict['linear_q_down_proj.weight'] = state_dict['q_a_proj.weight'] + mg_state_dict['linear_q_up_proj.weight'] = state_dict['q_b_proj.weight'] + mg_state_dict['linear_kv_down_proj.weight'] = state_dict['kv_a_proj_with_mqa.weight'] + mg_state_dict['linear_kv_up_proj.weight'] = state_dict['kv_b_proj.weight'] if args.qk_layernorm: - mg_attn.linear_kv_up_proj.layer_norm_weight.data.copy_(hf_attn.kv_a_layernorm.weight) + mg_state_dict['linear_kv_up_proj.layer_norm_weight'] = state_dict['kv_a_layernorm.weight'] + return _add_prefix(mg_state_dict, prefix) -def set_attn_state(args, mg_attn, hf_attn): +def set_attn_state(args, state_dict, prefix: str): + mg_state_dict = {} num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - - # Copy weights - mg_attn.linear_qkv.weight.data.copy_( - torch.cat([ - 
hf_attn.q_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.k_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - hf_attn.v_proj.weight.reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size))) - mg_attn.linear_proj.weight.data.copy_(hf_attn.o_proj.weight) + mg_state_dict['linear_qkv.weight'] = torch.cat([ + state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['k_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['v_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size)) + mg_state_dict['linear_proj.weight'] = state_dict['o_proj.weight'] # Copy bias if args.add_qkv_bias: - mg_attn.linear_qkv.bias.data.copy_( - torch.cat([ - hf_attn.q_proj.bias.reshape((num_query_groups, -1)), - hf_attn.k_proj.bias.reshape((num_query_groups, -1)), - hf_attn.v_proj.bias.reshape((num_query_groups, -1)), - ], - dim=1).reshape(-1)) + mg_state_dict['linear_qkv.bias'] = torch.cat([ + state_dict['q_proj.bias'].reshape((num_query_groups, -1)), + state_dict['k_proj.bias'].reshape((num_query_groups, -1)), + state_dict['v_proj.bias'].reshape((num_query_groups, -1)), + ], + dim=1).reshape(-1) if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - mg_attn.q_layernorm.weight.data.copy_(q_norm.weight) - mg_attn.k_layernorm.weight.data.copy_(k_norm.weight) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) + mg_state_dict['q_layernorm.weight'] = state_dict['query_layernorm.weight'] if state_dict.get( + 'q_norm.weight') is None else state_dict['q_norm.weight'] + mg_state_dict['k_layernorm.weight'] = state_dict['key_layernorm.weight'] if state_dict.get( + 'k_norm.weight') is None else state_dict['k_norm.weight'] + + return _add_prefix(mg_state_dict, prefix) + + +def _set_mlp_state( + args, + state_dict, + prefix: str, + group_idx: Optional[int] = None, +): + mg_state_dict = {} + hf_grouped = False + # hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight + fc1_key = 'linear_fc1.weight' + fc2_key = 'linear_fc2.weight' else: - linear_fc1_weight = getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') + fc1_key = f'linear_fc1.weight{group_idx}' + fc2_key = f'linear_fc2.weight{group_idx}' if hf_grouped: linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj[group_idx].t()) linear_fc2_weight.data.copy_(hf_mlp.down_proj[group_idx].t()) else: - if hasattr(hf_mlp, 'gate_up_proj'): - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj.weight) + if 'gate_up_proj.weight' in state_dict: + mg_state_dict[fc1_key] = state_dict['gate_up_proj.weight'] else: - linear_fc1_weight.data.copy_(torch.cat([hf_mlp.gate_proj.weight, hf_mlp.up_proj.weight], dim=0)) - linear_fc2_weight.data.copy_(hf_mlp.down_proj.weight) - - -def _set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - mg_mlp.router.weight.data.copy_(hf_gate.weight) + mg_state_dict[fc1_key] = torch.cat([ + state_dict['gate_proj.weight'], + state_dict['up_proj.weight'], + ], dim=0) + mg_state_dict[fc2_key] = 
state_dict['down_proj.weight'] + return _add_prefix(mg_state_dict, prefix) + + +def _set_moe_state(args, state_dict, prefix: str): + mg_state_dict = {} + mg_state_dict['router.weight'] = state_dict['gate.weight'] if state_dict.get( + 'gate.wg.weight') is None else state_dict['gate.wg.weight'] if args.moe_router_enable_expert_bias: - mg_mlp.router.expert_bias.data.copy_(hf_gate.e_score_correction_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight) + mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] + + if args.moe_shared_expert_intermediate_size: + shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') + mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) + if 'shared_expert_gate.weight' in state_dict: + mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) + # hf_expert = hf_mlp.experts + # if hasattr(hf_expert, '__len__'): + # hf_expert = hf_expert[expert_idx] + mg_state_dict.update( + _set_mlp_state( + args, _remove_prefix(state_dict, f'experts.{expert_idx}.'), 'experts.', group_idx=expert_idx)) + return _add_prefix(mg_state_dict, prefix) -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) - else: - _set_mlp_state(mg_mlp, hf_mlp) +def _is_moe(state_dict): + for k, v in state_dict.items(): + if 'experts.' 
in k: + return True + return False -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] +def set_layer_state(args, state_dict, layer_idx: int, prefix: str): + mg_state_dict = {} if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight) + mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) + # set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) + mg_state_dict['input_layernorm.weight'] = state_dict['input_layernorm.weight'] + else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - mg_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight) + mg_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) + mg_state_dict['self_attention.linear_qkv.layer_norm_weight'] = state_dict['input_layernorm.weight'] - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) + mlp_state_dict = _remove_prefix(state_dict, 'mlp.') + is_moe = _is_moe(mlp_state_dict) + if is_moe: + mg_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) + else: + mg_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - mg_layer.pre_mlp_layernorm.weight.data.copy_(post_attention_layernorm_weight) + if is_moe: + mg_state_dict['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] else: - mg_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_attention_layernorm_weight) + mg_state_dict['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] + return _add_prefix(mg_state_dict, prefix) + + +def _remove_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + +def _add_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {f'{prefix}{k}': v for k, v in state_dict.items()} -def convert_hf2mcore(hf_model, mg_model): +def convert_hf2mcore(state_dict, prefix=''): args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) + mg_state_dict = {} + mg_state_dict['embedding.word_embeddings.weight'] = state_dict['model.embed_tokens.weight'] if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) + mg_state_dict['output_layer.weight'] = state_dict['lm_head.weight'] + mg_state_dict['decoder.final_layernorm.weight'] = state_dict['model.norm.weight'] + for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): + mg_state_dict.update( + set_layer_state(args, _remove_prefix(state_dict, f'model.layers.{layer_idx}.'), layer_idx, + f'decoder.layers.{layer_idx}.')) + return _add_prefix(mg_state_dict, prefix) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 9c37bda427..8058f83705 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -272,3 +272,6 @@ def forward( 
def get_input_tensor(self): return self.decoder.input_tensor + + def save_hf_checkpoint(self): + print() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 15b81434f9..d71eebdc68 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -10,7 +10,6 @@ from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments from ..trainers import MegatronTrainer -from ..utils import patch_megatron_tokenizer from .utils import build_streaming_dataloader logger = get_logger() @@ -35,7 +34,6 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No with torch.device('meta'): self.model, self.processor = args.get_model_processor(**kwargs) self._prepare_template() - patch_megatron_tokenizer(self.processor) args.init_model_args(self.tokenizer, self.processor.model_info.config) args.save_args(args.save) self.template.use_megatron = True diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 1afd505df0..1833ea043f 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .convert import convert_hf2mcore, convert_mcore2hf -from .patcher import patch_megatron_tokenizer from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index aa3202f580..c095b6d3cb 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -18,7 +18,7 @@ from swift.utils import get_logger, get_n_params_grads from ..argument import MegatronArguments from ..model import get_megatron_model_meta -from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard +from .patcher import patch_torch_dist_shard logger = get_logger() @@ -246,7 +246,6 @@ def convert_hf2mcore(args: ExportArguments) -> None: current_convert_kwargs['moe_grouped_gemm'] = True megatron_args = MegatronArguments( **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta @@ -256,7 +255,11 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') - megatron_model_meta.convert_hf2mcore(hf_model, mg_model) + incompatible_keys = mg_model.load_state_dict( + megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False) + missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] + assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' + assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del hf_model @@ -291,7 +294,6 @@ def convert_mcore2hf(args: ExportArguments) -> None: **current_convert_kwargs, save=args.output_dir if args.to_mcore else None, torch_dtype=args.torch_dtype) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index 49dec85dea..9fd1cbf169 100644 
--- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -7,14 +7,6 @@ logger = get_logger() -def patch_megatron_tokenizer(tokenizer): - - def build_tokenizer(args): - return tokenizer - - global_vars.build_tokenizer = build_tokenizer - - def patch_torch_dist_shard(thread_count): __init__ = TorchDistSaveShardedStrategy.__init__ diff --git a/tests/megatron/test_auto.py b/tests/megatron/test_auto.py new file mode 100644 index 0000000000..7d5c3af29b --- /dev/null +++ b/tests/megatron/test_auto.py @@ -0,0 +1,3 @@ + +from swift.megatron import AutoMcoreModel + From 29a487d5c44a03b3fc90f460e282fa6bf24f9e48 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 23 Oct 2025 16:00:00 +0800 Subject: [PATCH 03/30] update --- swift/megatron/init.py | 2 +- swift/megatron/model/gpt/hf2mcore.py | 48 ++++++++++++++----------- swift/megatron/model/mm_gpt/qwen3_vl.py | 19 ++++------ swift/megatron/model/register.py | 5 +-- swift/megatron/utils/convert.py | 2 +- 5 files changed, 39 insertions(+), 37 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 57c91400c2..a41ef062dd 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -382,7 +382,7 @@ def _patch_megatron_tokenizer(): from megatron.training import global_vars def build_tokenizer(args): - return + return 'dummy_tokenizer' global_vars.build_tokenizer = build_tokenizer diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 6393c31936..e6b9f980f2 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -42,10 +42,14 @@ def set_attn_state(args, state_dict, prefix: str): ], dim=1).reshape(-1) if args.qk_layernorm: - mg_state_dict['q_layernorm.weight'] = state_dict['query_layernorm.weight'] if state_dict.get( - 'q_norm.weight') is None else state_dict['q_norm.weight'] - mg_state_dict['k_layernorm.weight'] = state_dict['key_layernorm.weight'] if state_dict.get( - 'k_norm.weight') is None else state_dict['k_norm.weight'] + if 'q_norm.weight' in state_dict: + mg_state_dict['q_layernorm.weight'] = state_dict['q_norm.weight'] + else: + mg_state_dict['q_layernorm.weight'] = state_dict['query_layernorm.weight'] + if 'k_norm.weight' in state_dict: + mg_state_dict['k_layernorm.weight'] = state_dict['k_norm.weight'] + else: + mg_state_dict['k_layernorm.weight'] = state_dict['key_layernorm.weight'] return _add_prefix(mg_state_dict, prefix) @@ -55,10 +59,9 @@ def _set_mlp_state( state_dict, prefix: str, group_idx: Optional[int] = None, + hf_grouped: bool = False, ): mg_state_dict = {} - hf_grouped = False - # hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) if group_idx is None: fc1_key = 'linear_fc1.weight' fc2_key = 'linear_fc2.weight' @@ -66,8 +69,8 @@ def _set_mlp_state( fc1_key = f'linear_fc1.weight{group_idx}' fc2_key = f'linear_fc2.weight{group_idx}' if hf_grouped: - linear_fc1_weight.data.copy_(hf_mlp.gate_up_proj[group_idx].t()) - linear_fc2_weight.data.copy_(hf_mlp.down_proj[group_idx].t()) + mg_state_dict[fc1_key] = state_dict['gate_up_proj'][group_idx].t() + mg_state_dict[fc2_key] = state_dict['down_proj'][group_idx].t() else: if 'gate_up_proj.weight' in state_dict: mg_state_dict[fc1_key] = state_dict['gate_up_proj.weight'] @@ -82,8 +85,10 @@ def _set_mlp_state( def _set_moe_state(args, state_dict, prefix: str): mg_state_dict = {} - mg_state_dict['router.weight'] = state_dict['gate.weight'] if state_dict.get( - 'gate.wg.weight') is None else state_dict['gate.wg.weight'] + if 'gate.wg.weight' in state_dict: + 
mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] + else: + mg_state_dict['router.weight'] = state_dict['gate.weight'] if args.moe_router_enable_expert_bias: mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] @@ -97,12 +102,11 @@ def _set_moe_state(args, state_dict, prefix: str): if 'shared_expert_gate.weight' in state_dict: mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] for expert_idx in range(args.num_experts): - # hf_expert = hf_mlp.experts - # if hasattr(hf_expert, '__len__'): - # hf_expert = hf_expert[expert_idx] - mg_state_dict.update( - _set_mlp_state( - args, _remove_prefix(state_dict, f'experts.{expert_idx}.'), 'experts.', group_idx=expert_idx)) + expert_sd = _remove_prefix(state_dict, f'experts.') + hf_grouped = expert_sd is not None + if expert_sd is None: + expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') + mg_state_dict.update(_set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) return _add_prefix(mg_state_dict, prefix) @@ -113,7 +117,7 @@ def _is_moe(state_dict): return False -def set_layer_state(args, state_dict, layer_idx: int, prefix: str): +def set_layer_state(args, state_dict, prefix: str): mg_state_dict = {} if args.multi_latent_attention: mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) @@ -153,12 +157,14 @@ def _add_prefix(state_dict, prefix: str): def convert_hf2mcore(state_dict, prefix=''): args = get_args() mg_state_dict = {} - mg_state_dict['embedding.word_embeddings.weight'] = state_dict['model.embed_tokens.weight'] - if args.untie_embeddings_and_output_weights: + is_language_model = 'model.language_model.embed_tokens.weight' in state_dict + hf_prefix = 'model.language_model.' if is_language_model else 'model.' 
+ mg_state_dict['embedding.word_embeddings.weight'] = state_dict[f'{hf_prefix}embed_tokens.weight'] + if args.untie_embeddings_and_output_weights and 'lm_head.weight' in state_dict: mg_state_dict['output_layer.weight'] = state_dict['lm_head.weight'] - mg_state_dict['decoder.final_layernorm.weight'] = state_dict['model.norm.weight'] + mg_state_dict['decoder.final_layernorm.weight'] = state_dict[f'{hf_prefix}norm.weight'] for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): mg_state_dict.update( - set_layer_state(args, _remove_prefix(state_dict, f'model.layers.{layer_idx}.'), layer_idx, + set_layer_state(args, _remove_prefix(state_dict, f'{hf_prefix}layers.{layer_idx}.'), f'decoder.layers.{layer_idx}.')) return _add_prefix(mg_state_dict, prefix) diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index de222f75d7..564529d0f4 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -15,7 +15,7 @@ from swift.llm import ModelType, to_device from ..constant import MegatronModelType -from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore +from ..gpt.hf2mcore import _add_prefix, _remove_prefix, convert_hf2mcore from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf from ..mm_gpt_model import MultimodalGPTModel from ..register import register_megatron_model @@ -499,19 +499,14 @@ def __init__(self, *args, **kwargs): model_cls=Qwen3VLGPTModel, visual_cls=Qwen3Omni_Vit)) - -def convert_hf2mcore_qwen3_vl(hf_model, mg_model): - language_model = hf_model.model.language_model - mg_language_model = mg_model.language_model +def convert_hf2mcore_qwen3_vl(state_dict, prefix=''): args = get_args() - mg_language_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) + mg_state_dict = {} if args.untie_embeddings_and_output_weights: - mg_language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_language_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state_hf2mcore(args, mg_language_model, language_model, layer_idx) - mg_model.visual.visual.load_state_dict(hf_model.model.visual.state_dict()) - + mg_state_dict['language_model.output_layer.weight'] = state_dict['lm_head.weight'] + mg_state_dict.update(convert_hf2mcore(state_dict, 'language_model.')) + mg_state_dict.update(_add_prefix(_remove_prefix(state_dict, 'model.visual.'), 'visual.visual.')) + return _add_prefix(mg_state_dict, prefix) def convert_mcore2hf_qwen3_vl(hf_model, mg_model): language_model = hf_model.model.language_model diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 93f892c2e8..acaf4f6ec0 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Type +import torch import torch.nn as nn from transformers import PretrainedConfig @@ -17,8 +18,8 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - convert_mcore2hf: Callable[[nn.Module, nn.Module], None] - convert_hf2mcore: Callable[[nn.Module, nn.Module], None] + convert_mcore2hf: Callable[[Dict[str, torch.Tensor], str], Dict[str, torch.Tensor]] + convert_hf2mcore: Callable[[Dict[str, torch.Tensor], str], Dict[str, torch.Tensor]] model_cls: Type[nn.Module] convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] diff 
--git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index c095b6d3cb..8298e8b00d 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -256,7 +256,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') incompatible_keys = mg_model.load_state_dict( - megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False) + megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False, assign=True) missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' From 4be6ae18e7efa6b9e29f8eb7432b39b7a29324c1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 13:51:07 +0800 Subject: [PATCH 04/30] update --- swift/megatron/model/__init__.py | 3 +- swift/megatron/model/gpt/hf2mcore.py | 1 - swift/megatron/model/gpt/mcore2hf.py | 219 ++++++++++++++----------- swift/megatron/model/model_provider.py | 3 +- swift/megatron/utils/convert.py | 2 +- 5 files changed, 131 insertions(+), 97 deletions(-) diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 3c882c9864..35b8777147 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import gpt, mm_gpt +# from . import gpt, mm_gpt +from . import gpt from .constant import MegatronModelType from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index e6b9f980f2..ccfd265bae 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -121,7 +121,6 @@ def set_layer_state(args, state_dict, prefix: str): mg_state_dict = {} if args.multi_latent_attention: mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) - # set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) mg_state_dict['input_layernorm.weight'] = state_dict['input_layernorm.weight'] else: diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index eac8023801..8a5235921c 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -1,127 +1,162 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Optional +import torch from megatron.training import get_args +from .hf2mcore import _add_prefix, _remove_prefix from torch import nn +from tqdm import tqdm -def set_mla_attn_state(args, mg_attn, hf_attn): - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) +def set_mla_attn_state(args, state_dict, prefix: str): + hf_state_dict = {} + hf_state_dict['o_proj.weight'] = state_dict['linear_proj.weight'] if args.q_lora_rank is None: - hf_attn.q_proj.weight.data.copy_(mg_attn.linear_q_proj.weight) + hf_state_dict['q_proj.weight'] = state_dict['linear_q_proj.weight'] else: - hf_attn.q_a_proj.weight.data.copy_(mg_attn.linear_q_down_proj.weight) - hf_attn.q_b_proj.weight.data.copy_(mg_attn.linear_q_up_proj.weight) - hf_attn.kv_a_proj_with_mqa.weight.data.copy_(mg_attn.linear_kv_down_proj.weight) - hf_attn.kv_b_proj.weight.data.copy_(mg_attn.linear_kv_up_proj.weight) + hf_state_dict['q_a_proj.weight'] = state_dict['linear_q_down_proj.weight'] + hf_state_dict['q_b_proj.weight'] = state_dict['linear_q_up_proj.weight'] + hf_state_dict['kv_a_proj_with_mqa.weight'] = state_dict['linear_kv_down_proj.weight'] + hf_state_dict['kv_b_proj.weight'] = state_dict['linear_kv_up_proj.weight'] if args.qk_layernorm: - hf_attn.kv_a_layernorm.weight.data.copy_(mg_attn.linear_kv_up_proj.layer_norm_weight) + hf_state_dict['kv_a_layernorm.weight'] = state_dict['linear_kv_up_proj.layer_norm_weight'] + return _add_prefix(hf_state_dict, prefix) -def set_attn_state(args, mg_attn, hf_attn): +def set_attn_state(args, state_dict, prefix: str): + hf_state_dict = {} num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - # Copy weights - mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size)) - hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size)) - hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size)) - hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight) + mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) + q_dim = args.kv_channels * args.num_attention_heads + kv_dim = args.kv_channels * args.num_query_groups + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) + hf_state_dict['o_proj.weight'] = state_dict['linear_proj.weight'] # Copy bias if args.add_qkv_bias: - mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) - hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1)) - hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1)) - hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1)) - + mg_attn_bias = state_dict['linear_qkv.bias'].reshape((num_query_groups, -1)) + state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) + state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) + state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) if args.qk_layernorm: - q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') 
else hf_attn.q_norm - k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm - q_norm.weight.data.copy_(mg_attn.q_layernorm.weight) - k_norm.weight.data.copy_(mg_attn.k_layernorm.weight) - - -def _set_moe_state(args, mg_mlp, hf_mlp): - hf_gate = hf_mlp.gate - if hasattr(hf_gate, 'wg'): - hf_gate = hf_gate.wg - hf_gate.weight.data.copy_(mg_mlp.router.weight) - if args.moe_router_enable_expert_bias: - hf_gate.e_score_correction_bias.data.copy_(mg_mlp.router.expert_bias) - if mg_mlp.shared_experts is not None: - if hasattr(hf_mlp, 'shared_experts'): - hf_shared_expert = hf_mlp.shared_experts - elif hasattr(hf_mlp, 'shared_mlp'): - hf_shared_expert = hf_mlp.shared_mlp - else: - hf_shared_expert = hf_mlp.shared_expert - _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert) - if mg_mlp.shared_experts.gate_weight is not None: - hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight) - for expert_idx in range(args.num_experts): - hf_expert = hf_mlp.experts - if hasattr(hf_expert, '__len__'): - hf_expert = hf_expert[expert_idx] - _set_mlp_state(mg_mlp.experts, hf_expert, group_idx=expert_idx) - - -def _set_mlp_state(mg_mlp, hf_mlp, group_idx: Optional[int] = None): - hf_grouped = not isinstance(hf_mlp.down_proj, nn.Module) + hf_state_dict['q_norm.weight'] = state_dict['q_layernorm.weight'] + hf_state_dict['query_layernorm.weight'] = state_dict['q_layernorm.weight'] + hf_state_dict['k_norm.weight'] = hf_state_dict['k_layernorm.weight'] + hf_state_dict['key_layernorm.weight'] = hf_state_dict['k_layernorm.weight'] + + return _add_prefix(mg_state_dict, prefix) + + +def _set_mlp_state( + args, + state_dict, + prefix: str, + group_idx: Optional[int] = None, + hf_grouped: bool = False, +): + mg_state_dict = {} if group_idx is None: - linear_fc1_weight = mg_mlp.linear_fc1.weight - linear_fc2_weight = mg_mlp.linear_fc2.weight + fc1_key = 'linear_fc1.weight' + fc2_key = 'linear_fc2.weight' else: - linear_fc1_weight = getattr(mg_mlp.linear_fc1, f'weight{group_idx}') - linear_fc2_weight = getattr(mg_mlp.linear_fc2, f'weight{group_idx}') - + fc1_key = f'linear_fc1.weight{group_idx}' + fc2_key = f'linear_fc2.weight{group_idx}' if hf_grouped: - hf_mlp.gate_up_proj.data[group_idx] = linear_fc1_weight.t() - hf_mlp.down_proj.data[group_idx] = linear_fc2_weight.t() + mg_state_dict[fc1_key] = state_dict['gate_up_proj'][group_idx].t() + mg_state_dict[fc2_key] = state_dict['down_proj'][group_idx].t() else: - if hasattr(hf_mlp, 'gate_up_proj'): - hf_mlp.gate_up_proj.weight.data.copy_(linear_fc1_weight) + if 'gate_up_proj.weight' in state_dict: + mg_state_dict[fc1_key] = state_dict['gate_up_proj.weight'] else: - ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] - hf_mlp.gate_proj.weight.data.copy_(linear_fc1_weight[:ffn_hidden_size]) - hf_mlp.up_proj.weight.data.copy_(linear_fc1_weight[ffn_hidden_size:]) - hf_mlp.down_proj.weight.data.copy_(linear_fc2_weight) - - -def set_mlp_state(args, mg_mlp, hf_mlp): - if 'moe' in mg_mlp.__class__.__name__.lower(): - _set_moe_state(args, mg_mlp, hf_mlp) + mg_state_dict[fc1_key] = torch.cat([ + state_dict['gate_proj.weight'], + state_dict['up_proj.weight'], + ], dim=0) + mg_state_dict[fc2_key] = state_dict['down_proj.weight'] + return _add_prefix(mg_state_dict, prefix) + + +def _set_moe_state(args, state_dict, prefix: str): + mg_state_dict = {} + if 'gate.wg.weight' in state_dict: + mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] else: - _set_mlp_state(mg_mlp, hf_mlp) + mg_state_dict['router.weight'] = 
state_dict['gate.weight'] + if args.moe_router_enable_expert_bias: + mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] + + if args.moe_shared_expert_intermediate_size: + shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') + mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) + if 'shared_expert_gate.weight' in state_dict: + mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] + for expert_idx in range(args.num_experts): + expert_sd = _remove_prefix(state_dict, f'experts.') + hf_grouped = expert_sd is not None + if expert_sd is None: + expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') + mg_state_dict.update(_set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) + return _add_prefix(mg_state_dict, prefix) -def set_layer_state(args, mg_model, hf_model, layer_idx): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.layers[layer_idx] +def _is_moe(state_dict): + for k, v in state_dict.items(): + if 'experts.' in k: + return True + return False + +def set_layer_state(args, state_dict, prefix: str): + hf_state_dict = {} if args.multi_latent_attention: - set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight) + hf_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, ), 'self_attn.')) + hf_state_dict['input_layernorm.weight'] = mg_state_dict['input_layernorm.weight'] + else: - set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight) + hf_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attention.'), 'self_attn.')) + hf_state_dict['input_layernorm.weight'] = state_dict['self_attention.linear_qkv.layer_norm_weight'] - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) + mlp_state_dict = _remove_prefix(state_dict, 'mlp.') + is_moe = _is_moe(mlp_state_dict) + if is_moe: + hf_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) + else: + hf_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) - post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight - if 'moe' in mg_layer.mlp.__class__.__name__.lower(): - post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight) + if is_moe: + hf_state_dict['post_attention_layernorm.weight'] = state_dict['pre_mlp_layernorm.weight'] else: - post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight) + hf_state_dict['post_attention_layernorm.weight'] = state_dict['mlp.linear_fc1.layer_norm_weight'] + return _add_prefix(hf_state_dict, prefix) + + +def _remove_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + +def _add_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {f'{prefix}{k}': v for k, v in state_dict.items()} -def convert_mcore2hf(hf_model, mg_model): +def convert_mcore2hf(state_dict, prefix=''): args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) + hf_state_dict = {} + 
hf_state_dict['model.embed_tokens.weight'] = state_dict['embedding.word_embeddings.weight'] if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) - for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model.model, layer_idx) + hf_state_dict['lm_head.weight'] = state_dict['output_layer.weight'] + hf_state_dict[f'model.norm.weight'] = state_dict['decoder.final_layernorm.weight'] + for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): + hf_state_dict.update( + set_layer_state(args, _remove_prefix(state_dict, f'decoder.layers.{layer_idx}.' + ), f'model.layers.{layer_idx}.')) + return _add_prefix(hf_state_dict, prefix) diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 814cd82213..fba3706a6a 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -14,14 +14,13 @@ if TYPE_CHECKING: from .gpt_model import GPTModel - from .mm_gpt_model import MultimodalGPTModel # Code borrowed from NVIDIA/Megatron-LM def model_provider( pre_process=True, post_process=True, - vp_stage: Optional[int] = None) -> Union['GPTModel', 'MultimodalGPTModel', megatron.legacy.model.GPTModel]: + vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 8298e8b00d..c095b6d3cb 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -256,7 +256,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') incompatible_keys = mg_model.load_state_dict( - megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False, assign=True) + megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False) missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' From 4e6e41f156173917812ac17c4f5f24041e104d15 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 13:52:09 +0800 Subject: [PATCH 05/30] update --- swift/megatron/__init__.py | 6 +- swift/megatron/argument/train_args.py | 3 +- swift/megatron/bridge/__init__.py | 1 - swift/megatron/bridge/auto.py | 13 --- swift/megatron/{utils => }/convert.py | 14 +-- swift/megatron/model/config.py | 85 ------------------ swift/megatron/model/gpt/__init__.py | 9 +- swift/megatron/model/gpt/hf2mcore.py | 11 --- swift/megatron/model/gpt/mcore2hf.py | 1 - swift/megatron/model/gpt_bridge.py | 82 +++++++++++++++++ swift/megatron/model/register.py | 8 +- swift/megatron/utils/__init__.py | 3 +- swift/megatron/{model/gpt => utils}/config.py | 88 ++++++++++++++++++- 13 files changed, 189 insertions(+), 135 deletions(-) delete mode 100644 swift/megatron/bridge/__init__.py delete mode 100644 swift/megatron/bridge/auto.py rename swift/megatron/{utils => }/convert.py (97%) delete mode 100644 swift/megatron/model/config.py create mode 
100644 swift/megatron/model/gpt_bridge.py rename swift/megatron/{model/gpt => utils}/config.py (50%) diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index bca6522cc3..75c639a497 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -13,7 +13,8 @@ if TYPE_CHECKING: from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main - from .utils import convert_hf2mcore, convert_mcore2hf, prepare_mcore_model, adapter_state_dict_context + from .convert import convert_hf2mcore, convert_mcore2hf + from .utils import prepare_mcore_model, adapter_state_dict_context from .argument import MegatronTrainArguments, MegatronRLHFArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer @@ -22,7 +23,8 @@ else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], - 'utils': ['convert_hf2mcore', 'convert_mcore2hf', 'prepare_mcore_model', 'adapter_state_dict_context'], + 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], + 'utils': ['prepare_mcore_model', 'adapter_state_dict_context'], 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 61535def15..9aa3533a8f 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -10,6 +10,7 @@ from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master from ..model import get_megatron_model_meta from .megatron_args import MegatronArguments +from ..utils import convert_hf_config logger = get_logger() @@ -23,7 +24,7 @@ def init_model_args(self, tokenizer, config): if self.task_type == 'seq_cls': self.problem_type = self.problem_type or getattr(config, 'problem_type', None) logger.info(f'args.problem_type: {self.problem_type}') - kwargs = self.megatron_model_meta.convert_hf_config(config) + kwargs = convert_hf_config(config) if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 self.initialize_embedding = True diff --git a/swift/megatron/bridge/__init__.py b/swift/megatron/bridge/__init__.py deleted file mode 100644 index 653cea3525..0000000000 --- a/swift/megatron/bridge/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .auto import AutoMcoreModel diff --git a/swift/megatron/bridge/auto.py b/swift/megatron/bridge/auto.py deleted file mode 100644 index 8cea017155..0000000000 --- a/swift/megatron/bridge/auto.py +++ /dev/null @@ -1,13 +0,0 @@ - -from megatron.training import get_args - -class AutoMcoreModel: - - @classmethod - def build_model(cls): - args = get_args() - model_meta = args.model_meta - model_info = args.model_info - megatron_model_meta = args.megatron_model_meta - logger.info(f'Creating mcore_model using model_dir: {model_info.model_dir}') - diff --git a/swift/megatron/utils/convert.py b/swift/megatron/convert.py similarity index 97% rename from swift/megatron/utils/convert.py rename to swift/megatron/convert.py index c095b6d3cb..f2d5d11da6 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/convert.py @@ -12,13 +12,12 @@ from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from 
megatron.training.initialize import initialize_megatron from megatron.training.utils import get_ltor_masks_and_position_ids - from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device, to_float_dtype) from swift.utils import get_logger, get_n_params_grads -from ..argument import MegatronArguments -from ..model import get_megatron_model_meta -from .patcher import patch_torch_dist_shard +from .utils import convert_hf_config, patch_torch_dist_shard +from .argument import MegatronArguments +from .model import get_megatron_model_meta logger = get_logger() @@ -238,7 +237,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -255,8 +254,9 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') + bridge = megatron_model_meta.bridge_cls() incompatible_keys = mg_model.load_state_dict( - megatron_model_meta.convert_hf2mcore(hf_model.state_dict()), strict=False) + bridge.convert_hf2mcore(hf_model.state_dict()), strict=False) missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' @@ -277,7 +277,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' - kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + kwargs = convert_hf_config(processor.model_info.config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py deleted file mode 100644 index 7a68537efc..0000000000 --- a/swift/megatron/model/config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Any, Dict - -from swift.utils import get_logger - -logger = get_logger() -config_mapping = { - 'num_layers': ['num_hidden_layers'], - 'hidden_size': ['hidden_size'], - 'ffn_hidden_size': ['intermediate_size'], - 'num_attention_heads': ['num_attention_heads'], - 'num_query_groups': ['num_key_value_heads'], - 'max_position_embeddings': ['max_position_embeddings'], - 'norm_epsilon': ['rms_norm_eps'], - 'rotary_base': ['rope_theta'], - 'padded_vocab_size': ['vocab_size'], - 'attention_dropout': ['attention_dropout'], - 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], - 'swiglu': ['hidden_act'], - 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], - 'disable_bias_linear': ['mlp_bias'], - 'kv_channels': ['head_dim', 'v_head_dim'], - 'architectures': ['architectures'], - # moe - 'moe_ffn_hidden_size': ['moe_intermediate_size'], - 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], - 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'], - 'moe_router_pre_softmax': ['norm_topk_prob'], - # deepseek - 'q_lora_rank': ['q_lora_rank'], - 'kv_lora_rank': ['kv_lora_rank'], - 'moe_router_score_function': ['scoring_func'], - 'qk_head_dim': ['qk_nope_head_dim'], - 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], - 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], - 'qk_layernorm': ['use_qk_norm'], - # qwen3_next - 'linear_num_value_heads': ['linear_num_value_heads'], - 'linear_num_key_heads': ['linear_num_key_heads'], - 'linear_key_head_dim': ['linear_key_head_dim'], - 'linear_value_head_dim': ['linear_value_head_dim'], - 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], - 'full_attention_interval': ['full_attention_interval'], - # other - 'original_max_position_embeddings': ['original_max_position_embeddings'], - 'partial_rotary_factor': ['partial_rotary_factor'], - 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], - 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], -} - - -def convert_hf_config(config, _internal_call=False) -> Dict[str, Any]: - megatron_config = {} - for k, hf_keys in config_mapping.items(): - for hf_k in hf_keys: - if hasattr(config, hf_k): - hf_v = getattr(config, hf_k) - if hf_v is None: - continue - if k == 'rotary_base': - megatron_config[k] = int(hf_v) - elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: - megatron_config[k] = not hf_v - elif k == 'swiglu': - if hf_v == 'silu': - megatron_config[k] = True - else: - if k == 'kv_lora_rank': - megatron_config['multi_latent_attention'] = True - elif k == 'architectures': - if _internal_call: - k = 'llm_architectures' - megatron_config[k] = hf_v - break - for key in ['text_config', 'llm_config', 'thinker_config']: - if hasattr(config, key): - megatron_config.update(convert_hf_config(getattr(config, key), _internal_call=True)) - # compat llama3 - if getattr(config, 'rope_scaling', None) is not None: - if isinstance(config.rope_scaling, int): - megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}, - elif isinstance(config.rope_scaling, dict): - megatron_config['rope_scaling'] = config.rope_scaling - return megatron_config diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index f3eb68e0cd..02ee79de92 100644 --- a/swift/megatron/model/gpt/__init__.py +++ 
b/swift/megatron/model/gpt/__init__.py @@ -3,8 +3,7 @@ from ..constant import MegatronModelType from ..gpt_model import GPTModel from ..register import MegatronModelMeta, register_megatron_model -from . import qwen3_next -from .config import convert_gpt_hf_config +# from . import qwen3_next from .hf2mcore import convert_hf2mcore from .mcore2hf import convert_mcore2hf @@ -59,7 +58,7 @@ ModelType.ernie_thinking, ], model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, - convert_mcore2hf=convert_mcore2hf, - convert_hf2mcore=convert_hf2mcore, + # convert_hf_config=convert_gpt_hf_config, + # convert_mcore2hf=convert_mcore2hf, + # convert_hf2mcore=convert_hf2mcore, )) diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index ccfd265bae..8f05390265 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -141,17 +141,6 @@ def set_layer_state(args, state_dict, prefix: str): return _add_prefix(mg_state_dict, prefix) -def _remove_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - - -def _add_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {f'{prefix}{k}': v for k, v in state_dict.items()} - def convert_hf2mcore(state_dict, prefix=''): args = get_args() diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index 8a5235921c..e8f5cb1a23 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -3,7 +3,6 @@ import torch from megatron.training import get_args -from .hf2mcore import _add_prefix, _remove_prefix from torch import nn from tqdm import tqdm diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py new file mode 100644 index 0000000000..e4971177cb --- /dev/null +++ b/swift/megatron/model/gpt_bridge.py @@ -0,0 +1,82 @@ + +from typing import Any, Dict +import torch +from megatron.training import get_args +from swift.llm import get_model_tokenizer, deep_getattr +from swift.utils import disable_safe_ddp_context_use_barrier + +class GPTBridge: + lm_layers_prefix = 'model.layers' # hf + + def __init__(self): + self.args = get_args() + model_info = self.args.model_info + with torch.device('meta'), disable_safe_ddp_context_use_barrier(): + self.hf_model, _ = get_model_tokenizer( + model_info.model_dir, + model_type=model_info.model_type, + return_dummy_model=True) + self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) + + + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool = False): + src_key, tgt_key = hf_key, mg_key + if reverse: + src_key, tgt_key = tgt_key, src_key + res_state_dict[tgt_key] = state_dict[src_key] + + @staticmethod + def _remove_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _add_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {f'{prefix}{k}': v for k, v in state_dict.items()} + + def _set_layer_state(state_dict, hf_prefix: str, mg_prefix: str, reverse: bool = False): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict) + res = {} + if args.multi_latent_attention: + mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 
'self_attn.'), 'self_attention.')) + mg_state_dict['input_layernorm.weight'] = state_dict['input_layernorm.weight'] + + else: + mg_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) + mg_state_dict['self_attention.linear_qkv.layer_norm_weight'] = state_dict['input_layernorm.weight'] + + mlp_state_dict = _remove_prefix(state_dict, 'mlp.') + is_moe = _is_moe(mlp_state_dict) + if is_moe: + mg_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) + else: + mg_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) + + if is_moe: + mg_state_dict['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] + else: + mg_state_dict['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] + return _add_prefix(mg_state_dict, prefix) + + + def convert_hf2mcore(self, state_dict, prefix: str = '', reverse: bool = False): + res = {} + self._set_state_dict(state_dict, res, 'model.embed_tokens.weight', 'embedding.word_embeddings.weight', reverse) + if args.untie_embeddings_and_output_weights: + self._set_state_dict(state_dict, res, 'lm_head.weight', 'output_layer.weight', reverse) + self._set_state_dict(state_dict, res, 'model.norm.weight', 'decoder.final_layernorm.weight', reverse) + for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): + mg_state_dict.update( + self._set_layer_state(state_dict, f'model.layers.{layer_idx}.', + f'decoder.layers.{layer_idx}.', reverse)) + + + + def convert_mcore2hf(self, state_dict, prefix: str = ''): + return self.convert_hf2mcore(state_dict, prefix, True) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index acaf4f6ec0..3db6b8d45c 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -9,6 +9,7 @@ from swift.llm import MODEL_MAPPING from .model_provider import model_provider as model_provider_func +from .gpt_bridge import GPTBridge MEGATRON_MODEL_MAPPING = {} @@ -17,12 +18,9 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - - convert_mcore2hf: Callable[[Dict[str, torch.Tensor], str], Dict[str, torch.Tensor]] - convert_hf2mcore: Callable[[Dict[str, torch.Tensor], str], Dict[str, torch.Tensor]] - model_cls: Type[nn.Module] - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] + + bridge_cls: Type[GPTBridge] = GPTBridge get_transformer_layer_spec: Optional[Callable] = None model_provider: Callable[[], nn.Module] = model_provider_func visual_cls: Optional[Type[nn.Module]] = None diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 1833ea043f..4ee3ace3de 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .convert import convert_hf2mcore, convert_mcore2hf from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, tuners_sharded_state_dict) +from .config import convert_hf_config +from .patcher import patch_torch_dist_shard diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/utils/config.py similarity index 50% rename from swift/megatron/model/gpt/config.py rename to swift/megatron/utils/config.py index e779a827b6..15558c2954 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/utils/config.py @@ -1,11 +1,93 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict -from ..config import convert_hf_config +from swift.utils import get_logger +logger = get_logger() +config_mapping = { + 'num_layers': ['num_hidden_layers'], + 'hidden_size': ['hidden_size'], + 'ffn_hidden_size': ['intermediate_size'], + 'num_attention_heads': ['num_attention_heads'], + 'num_query_groups': ['num_key_value_heads'], + 'max_position_embeddings': ['max_position_embeddings'], + 'norm_epsilon': ['rms_norm_eps'], + 'rotary_base': ['rope_theta'], + 'padded_vocab_size': ['vocab_size'], + 'attention_dropout': ['attention_dropout'], + 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], + 'swiglu': ['hidden_act'], + 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], + 'disable_bias_linear': ['mlp_bias'], + 'kv_channels': ['head_dim', 'v_head_dim'], + 'architectures': ['architectures'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], + 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + # deepseek + 'q_lora_rank': ['q_lora_rank'], + 'kv_lora_rank': ['kv_lora_rank'], + 'moe_router_score_function': ['scoring_func'], + 'qk_head_dim': ['qk_nope_head_dim'], + 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], + 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], + 'qk_layernorm': ['use_qk_norm'], + # qwen3_next + 'linear_num_value_heads': ['linear_num_value_heads'], + 'linear_num_key_heads': ['linear_num_key_heads'], + 'linear_key_head_dim': ['linear_key_head_dim'], + 'linear_value_head_dim': ['linear_value_head_dim'], + 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], + 'full_attention_interval': ['full_attention_interval'], + # other + 'original_max_position_embeddings': ['original_max_position_embeddings'], + 'partial_rotary_factor': ['partial_rotary_factor'], + 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], + 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], +} -def convert_gpt_hf_config(config) -> Dict[str, Any]: - res = convert_hf_config(config) + +def _convert_config(config, _internal_call=False) -> Dict[str, Any]: + megatron_config = {} + for k, hf_keys in config_mapping.items(): + for hf_k in hf_keys: + if hasattr(config, hf_k): + hf_v = getattr(config, hf_k) + if hf_v is None: + continue + if k == 'rotary_base': + megatron_config[k] = int(hf_v) + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + megatron_config[k] = not hf_v + elif k == 'swiglu': + if hf_v == 'silu': + megatron_config[k] = True + else: + if k == 'kv_lora_rank': + megatron_config['multi_latent_attention'] = True + elif k == 'architectures': + if _internal_call: + k = 'llm_architectures' + megatron_config[k] = hf_v + break + for key in ['text_config', 'llm_config', 'thinker_config']: + if hasattr(config, key): + megatron_config.update(convert_hf_config(getattr(config, key), _internal_call=True)) + # compat llama3 + if getattr(config, 'rope_scaling', None) is not None: + if isinstance(config.rope_scaling, int): + megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'}, + elif isinstance(config.rope_scaling, dict): + megatron_config['rope_scaling'] = config.rope_scaling + return megatron_config + + + +def convert_hf_config(config) -> Dict[str, Any]: + res = _convert_config(config) 
architectures = res.get('architectures') if isinstance(architectures, list) and architectures: architectures = architectures[0] From 4900c15cd0845f9a17a18e4b17aa901ae4b0a2df Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 14:56:34 +0800 Subject: [PATCH 06/30] update --- swift/megatron/argument/train_args.py | 2 +- swift/megatron/convert.py | 6 +- swift/megatron/model/gpt/hf2mcore.py | 10 +- swift/megatron/model/gpt/mcore2hf.py | 8 +- swift/megatron/model/gpt_bridge.py | 183 ++++++++++++++++++++---- swift/megatron/model/mm_gpt/qwen3_vl.py | 2 + swift/megatron/model/model_provider.py | 7 +- swift/megatron/model/register.py | 2 +- swift/megatron/utils/__init__.py | 4 +- swift/megatron/utils/config.py | 1 - 10 files changed, 169 insertions(+), 56 deletions(-) diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 9aa3533a8f..74b7f2e055 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -9,8 +9,8 @@ from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master from ..model import get_megatron_model_meta -from .megatron_args import MegatronArguments from ..utils import convert_hf_config +from .megatron_args import MegatronArguments logger = get_logger() diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index f2d5d11da6..db9c80a246 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -12,12 +12,13 @@ from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from megatron.training.initialize import initialize_megatron from megatron.training.utils import get_ltor_masks_and_position_ids + from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device, to_float_dtype) from swift.utils import get_logger, get_n_params_grads -from .utils import convert_hf_config, patch_torch_dist_shard from .argument import MegatronArguments from .model import get_megatron_model_meta +from .utils import convert_hf_config, patch_torch_dist_shard logger = get_logger() @@ -255,8 +256,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() - incompatible_keys = mg_model.load_state_dict( - bridge.convert_hf2mcore(hf_model.state_dict()), strict=False) + incompatible_keys = mg_model.load_state_dict(bridge.convert_hf2mcore(hf_model.state_dict()), strict=False) missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 8f05390265..4c06eddbb8 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -102,7 +102,7 @@ def _set_moe_state(args, state_dict, prefix: str): if 'shared_expert_gate.weight' in state_dict: mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] for expert_idx in range(args.num_experts): - expert_sd = _remove_prefix(state_dict, f'experts.') + expert_sd = _remove_prefix(state_dict, 'experts.') hf_grouped = expert_sd is not None if expert_sd is None: expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') @@ 
-110,13 +110,6 @@ def _set_moe_state(args, state_dict, prefix: str): return _add_prefix(mg_state_dict, prefix) -def _is_moe(state_dict): - for k, v in state_dict.items(): - if 'experts.' in k: - return True - return False - - def set_layer_state(args, state_dict, prefix: str): mg_state_dict = {} if args.multi_latent_attention: @@ -141,7 +134,6 @@ def set_layer_state(args, state_dict, prefix: str): return _add_prefix(mg_state_dict, prefix) - def convert_hf2mcore(state_dict, prefix=''): args = get_args() mg_state_dict = {} diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index e8f5cb1a23..f417bafae7 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -96,7 +96,7 @@ def _set_moe_state(args, state_dict, prefix: str): if 'shared_expert_gate.weight' in state_dict: mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] for expert_idx in range(args.num_experts): - expert_sd = _remove_prefix(state_dict, f'experts.') + expert_sd = _remove_prefix(state_dict, 'experts.') hf_grouped = expert_sd is not None if expert_sd is None: expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') @@ -153,9 +153,9 @@ def convert_mcore2hf(state_dict, prefix=''): hf_state_dict['model.embed_tokens.weight'] = state_dict['embedding.word_embeddings.weight'] if args.untie_embeddings_and_output_weights: hf_state_dict['lm_head.weight'] = state_dict['output_layer.weight'] - hf_state_dict[f'model.norm.weight'] = state_dict['decoder.final_layernorm.weight'] + hf_state_dict['model.norm.weight'] = state_dict['decoder.final_layernorm.weight'] for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): hf_state_dict.update( - set_layer_state(args, _remove_prefix(state_dict, f'decoder.layers.{layer_idx}.' 
- ), f'model.layers.{layer_idx}.')) + set_layer_state(args, _remove_prefix(state_dict, f'decoder.layers.{layer_idx}.'), + f'model.layers.{layer_idx}.')) return _add_prefix(hf_state_dict, prefix) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index e4971177cb..32f907e6bc 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1,10 +1,13 @@ +from typing import Any, Dict, Optional -from typing import Any, Dict import torch from megatron.training import get_args -from swift.llm import get_model_tokenizer, deep_getattr +from tqdm import tqdm + +from swift.llm import deep_getattr, get_model_tokenizer from swift.utils import disable_safe_ddp_context_use_barrier + class GPTBridge: lm_layers_prefix = 'model.layers' # hf @@ -13,13 +16,10 @@ def __init__(self): model_info = self.args.model_info with torch.device('meta'), disable_safe_ddp_context_use_barrier(): self.hf_model, _ = get_model_tokenizer( - model_info.model_dir, - model_type=model_info.model_type, - return_dummy_model=True) + model_info.model_dir, model_type=model_info.model_type, return_dummy_model=True) self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) - - def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool = False): + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool): src_key, tgt_key = hf_key, mg_key if reverse: src_key, tgt_key = tgt_key, src_key @@ -37,46 +37,167 @@ def _add_prefix(state_dict, prefix: str): return state_dict return {f'{prefix}{k}': v for k, v in state_dict.items()} - def _set_layer_state(state_dict, hf_prefix: str, mg_prefix: str, reverse: bool = False): + @staticmethod + def _is_moe(state_dict): + for k, v in state_dict.items(): + if 'experts.' 
in k: + return True + return False + + def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, reverse: bool): src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict) + state_dict = self._remove_prefix(state_dict, src_prefix) + args = self.args res = {} - if args.multi_latent_attention: - mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) - mg_state_dict['input_layernorm.weight'] = state_dict['input_layernorm.weight'] + num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) + if reverse: + pass + else: + res['linear_qkv.weight'] = torch.cat([ + state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['k_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + state_dict['v_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size)) + self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + + # Copy bias + if args.add_qkv_bias: + if reverse: + pass + else: + res['linear_qkv.bias'] = torch.cat([ + state_dict['q_proj.bias'].reshape((num_query_groups, -1)), + state_dict['k_proj.bias'].reshape((num_query_groups, -1)), + state_dict['v_proj.bias'].reshape((num_query_groups, -1)), + ], + dim=1).reshape(-1) + if args.qk_layernorm: + if 'q_norm.weight' in state_dict: + res['q_layernorm.weight'] = state_dict['q_norm.weight'] + else: + res['q_layernorm.weight'] = state_dict['query_layernorm.weight'] + if 'k_norm.weight' in state_dict: + res['k_layernorm.weight'] = state_dict['k_norm.weight'] + else: + res['k_layernorm.weight'] = state_dict['key_layernorm.weight'] + + return self._add_prefix(res, tgt_prefix) + + def _set_moe_state(self, state_dict, prefix: str): + mg_state_dict = {} + if 'gate.wg.weight' in state_dict: + mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] + else: + mg_state_dict['router.weight'] = state_dict['gate.weight'] + if args.moe_router_enable_expert_bias: + mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] + + if args.moe_shared_expert_intermediate_size: + shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') + if not shared_expert_sd: + shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') + mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) + if 'shared_expert_gate.weight' in state_dict: + mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] + for expert_idx in range(args.num_experts): + expert_sd = _remove_prefix(state_dict, 'experts.') + hf_grouped = expert_sd is not None + if expert_sd is None: + expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') + mg_state_dict.update( + _set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) + return _add_prefix(mg_state_dict, prefix) + def _set_mlp_state( + self, + state_dict, + hf_prefix: str, + mg_prefix: str, + hf_mlp, + reverse: bool, + group_idx: Optional[int] = None, + hf_grouped: bool = False, + ): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + # Determines the keys for fc1 and fc2 in megatron 
+ if group_idx is None: + fc1_key = 'linear_fc1.weight' + fc2_key = 'linear_fc2.weight' else: - mg_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) - mg_state_dict['self_attention.linear_qkv.layer_norm_weight'] = state_dict['input_layernorm.weight'] + fc1_key = f'linear_fc1.weight{group_idx}' + fc2_key = f'linear_fc2.weight{group_idx}' + if hf_grouped: + res[fc1_key] = state_dict['gate_up_proj'][group_idx].t() + res[fc2_key] = state_dict['down_proj'][group_idx].t() + else: + if hasattr(hf_mlp, 'gate_up_proj'): + self._set_state_dict(state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) + else: + if reverse: + pass + else: + res[fc1_key] = torch.cat([ + state_dict['gate_proj.weight'], + state_dict['up_proj.weight'], + ], dim=0) + self._set_state_dict(state_dict, res, 'down_proj.weight', fc2_key, reverse) + return self._add_prefix(res, tgt_prefix) + + def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + mg_prefix = f'{mg_prefix}{layer_idx}.' + hf_layer = self.hf_layers[layer_idx] + hf_attn, hf_mlp = hf_layer.self_attn, hf_layer.mlp + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + if self.args.multi_latent_attention: + res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', reverse)) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) + else: + res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', hf_attn, reverse)) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', + 'self_attention.linear_qkv.layer_norm_weight', reverse) - mlp_state_dict = _remove_prefix(state_dict, 'mlp.') - is_moe = _is_moe(mlp_state_dict) + is_moe = self._is_moe(hf_mlp.state_dict()) if is_moe: - mg_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) + res.update(self._set_moe_state(state_dict, 'mlp.')) else: - mg_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) + res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', hf_mlp, reverse)) if is_moe: - mg_state_dict['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] + res['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] else: - mg_state_dict['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] - return _add_prefix(mg_state_dict, prefix) + res['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] + return self._add_prefix(res, tgt_prefix) - - def convert_hf2mcore(self, state_dict, prefix: str = '', reverse: bool = False): + def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) res = {} self._set_state_dict(state_dict, res, 'model.embed_tokens.weight', 'embedding.word_embeddings.weight', reverse) - if args.untie_embeddings_and_output_weights: + if self.args.untie_embeddings_and_output_weights: self._set_state_dict(state_dict, res, 'lm_head.weight', 'output_layer.weight', reverse) self._set_state_dict(state_dict, res, 'model.norm.weight', 'decoder.final_layernorm.weight', reverse) - for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, 
desc='Converting: '): - mg_state_dict.update( - self._set_layer_state(state_dict, f'model.layers.{layer_idx}.', - f'decoder.layers.{layer_idx}.', reverse)) - + for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + res.update(self._set_layer_state(state_dict, layer_idx, 'model.layers.', 'decoder.layers.', reverse)) + return self._add_prefix(res, tgt_prefix) + def convert_hf2mcore(self, state_dict): + return self._convert(state_dict, '', '', False) - def convert_mcore2hf(self, state_dict, prefix: str = ''): - return self.convert_hf2mcore(state_dict, prefix, True) + def convert_mcore2hf(self, state_dict): + return self._convert(state_dict, '', '', True) diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index 564529d0f4..b0ce9b8f02 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -499,6 +499,7 @@ def __init__(self, *args, **kwargs): model_cls=Qwen3VLGPTModel, visual_cls=Qwen3Omni_Vit)) + def convert_hf2mcore_qwen3_vl(state_dict, prefix=''): args = get_args() mg_state_dict = {} @@ -508,6 +509,7 @@ def convert_hf2mcore_qwen3_vl(state_dict, prefix=''): mg_state_dict.update(_add_prefix(_remove_prefix(state_dict, 'model.visual.'), 'visual.visual.')) return _add_prefix(mg_state_dict, prefix) + def convert_mcore2hf_qwen3_vl(hf_model, mg_model): language_model = hf_model.model.language_model mg_language_model = mg_model.language_model diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index fba3706a6a..392785e085 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -17,10 +17,9 @@ # Code borrowed from NVIDIA/Megatron-LM -def model_provider( - pre_process=True, - post_process=True, - vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, + post_process=True, + vp_stage: Optional[int] = None) -> Union['GPTModel', megatron.legacy.model.GPTModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 3db6b8d45c..c511caf501 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -8,8 +8,8 @@ from transformers import PretrainedConfig from swift.llm import MODEL_MAPPING -from .model_provider import model_provider as model_provider_func from .gpt_bridge import GPTBridge +from .model_provider import model_provider as model_provider_func MEGATRON_MODEL_MAPPING = {} diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 4ee3ace3de..91fbd0fbf8 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, - tuners_sharded_state_dict) from .config import convert_hf_config from .patcher import patch_torch_dist_shard +from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, + tuners_sharded_state_dict) diff --git a/swift/megatron/utils/config.py b/swift/megatron/utils/config.py index 15558c2954..fba9dee26c 100644 --- a/swift/megatron/utils/config.py +++ b/swift/megatron/utils/config.py @@ -85,7 +85,6 @@ def _convert_config(config, _internal_call=False) -> Dict[str, Any]: return megatron_config - def convert_hf_config(config) -> Dict[str, Any]: res = _convert_config(config) architectures = res.get('architectures') From 302fbf46c83da2e1612d10fbc04fba1bd347adb2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 16:06:55 +0800 Subject: [PATCH 07/30] update --- swift/megatron/convert.py | 6 +- swift/megatron/model/gpt_bridge.py | 114 ++++++++++++++++++---------- 2 files changed, 76 insertions(+), 44 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index db9c80a246..aae514fd16 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -314,7 +314,11 @@ def convert_mcore2hf(args: ExportArguments) -> None: logger.info('Megatron model created successfully.') if args.to_hf: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) - megatron_model_meta.convert_mcore2hf(hf_model, mg_model) + bridge = megatron_model_meta.bridge_cls() + incompatible_keys = hf_model.load_state_dict(bridge.convert_mcore2hf(mg_model.state_dict()), strict=False) + missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] + assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' + assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del mg_model diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 32f907e6bc..e606dd14ca 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -53,7 +53,12 @@ def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, r res = {} num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if reverse: - pass + mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) + q_dim = args.kv_channels * args.num_attention_heads // num_query_groups + kv_dim = args.kv_channels + res['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) + res['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) + res['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) else: res['linear_qkv.weight'] = torch.cat([ state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), @@ -66,7 +71,10 @@ def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, r # Copy bias if args.add_qkv_bias: if reverse: - pass + mg_attn_bias = state_dict['linear_qkv.bias'].reshape((num_query_groups, -1)) + res['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) + res['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) + res['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) else: res['linear_qkv.bias'] = 
torch.cat([ state_dict['q_proj.bias'].reshape((num_query_groups, -1)), @@ -75,43 +83,37 @@ def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, r ], dim=1).reshape(-1) if args.qk_layernorm: - if 'q_norm.weight' in state_dict: - res['q_layernorm.weight'] = state_dict['q_norm.weight'] - else: - res['q_layernorm.weight'] = state_dict['query_layernorm.weight'] - if 'k_norm.weight' in state_dict: - res['k_layernorm.weight'] = state_dict['k_norm.weight'] - else: - res['k_layernorm.weight'] = state_dict['key_layernorm.weight'] + hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' + hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' + self._set_state_dict(state_dict, res, hf_q_norm_key, 'q_layernorm.weight', reverse) + self._set_state_dict(state_dict, res, hf_k_norm_key, 'k_layernorm.weight', reverse) return self._add_prefix(res, tgt_prefix) - def _set_moe_state(self, state_dict, prefix: str): - mg_state_dict = {} - if 'gate.wg.weight' in state_dict: - mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] - else: - mg_state_dict['router.weight'] = state_dict['gate.weight'] - if args.moe_router_enable_expert_bias: - mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] - - if args.moe_shared_expert_intermediate_size: - shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') - mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) - if 'shared_expert_gate.weight' in state_dict: - mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] - for expert_idx in range(args.num_experts): - expert_sd = _remove_prefix(state_dict, 'experts.') - hf_grouped = expert_sd is not None - if expert_sd is None: - expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') - mg_state_dict.update( - _set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) - return _add_prefix(mg_state_dict, prefix) + def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_mlp, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 'gate.weight' + self._set_state_dict(state_dict, res, hf_gate_key, 'router.weight', reverse) + if self.args.moe_router_enable_expert_bias: + self._set_state_dict(state_dict, res, 'gate.e_score_correction_bias', 'router.expert_bias', reverse) + + if self.args.moe_shared_expert_intermediate_size: + for key in ['shared_expert', 'shared_experts', 'shared_mlp']: + if hasattr(hf_mlp, key): + hf_shared_expert_prefix = f'{key}.' + res.update(self._set_mlp_state(state_dict, hf_shared_expert_prefix, 'shared_experts.', hf_mlp, reverse)) + if hasattr(hf_mlp, 'shared_expert_gate'): + self._set_state_dict(state_dict, res, 'shared_expert_gate.weight', 'shared_experts.gate_weight', + reverse) + for expert_idx in range(self.args.num_experts): + hf_expert_prefix = f'experts.{expert_idx}.' if hasattr(hf_mlp.experts, '__len__') else 'experts.' 
+ res.update( + self._set_mlp_state(state_dict, hf_expert_prefix, 'experts.', hf_mlp, reverse, group_idx=expert_idx)) + return self._add_prefix(res, tgt_prefix) def _set_mlp_state( self, @@ -121,12 +123,12 @@ def _set_mlp_state( hf_mlp, reverse: bool, group_idx: Optional[int] = None, - hf_grouped: bool = False, ): src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix state_dict = self._remove_prefix(state_dict, src_prefix) + hf_grouped = not hasattr(hf_mlp.experts, '__len__') res = {} # Determines the keys for fc1 and fc2 in megatron if group_idx is None: @@ -143,7 +145,8 @@ def _set_mlp_state( self._set_state_dict(state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) else: if reverse: - pass + res['gate_proj.weight'] = state_dict[fc1_key][:self.args.ffn_hidden_size] + res['up_proj.weight'] = state_dict[fc1_key][self.args.ffn_hidden_size:] else: res[fc1_key] = torch.cat([ state_dict['gate_proj.weight'], @@ -152,6 +155,31 @@ def _set_mlp_state( self._set_state_dict(state_dict, res, 'down_proj.weight', fc2_key, reverse) return self._add_prefix(res, tgt_prefix) + def _set_mla_attn_state( + self, + state_dict, + hf_prefix: str, + mg_prefix: str, + hf_mlp, + reverse: bool, + ): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = {} + self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + if self.args.q_lora_rank is None: + self._set_state_dict(state_dict, res, 'q_proj.weight', 'linear_q_proj.weight', reverse) + else: + self._set_state_dict(state_dict, res, 'q_a_proj.weight', 'linear_q_down_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'q_b_proj.weight', 'linear_q_up_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'kv_a_proj_with_mqa.weight', 'linear_kv_down_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'kv_b_proj.weight', 'linear_kv_up_proj.weight', reverse) + if self.args.qk_layernorm: + self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.weight', reverse) + return self._add_prefix(res, tgt_prefix) + def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' mg_prefix = f'{mg_prefix}{layer_idx}.' 
@@ -172,14 +200,14 @@ def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix is_moe = self._is_moe(hf_mlp.state_dict()) if is_moe: - res.update(self._set_moe_state(state_dict, 'mlp.')) + res.update(self._set_moe_state(state_dict, 'mlp.', 'mlp.', hf_mlp, reverse)) + self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'pre_mlp_layernorm.weight', + reverse) else: res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', hf_mlp, reverse)) + self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'mlp.linear_fc1.layer_norm_weight', + reverse) - if is_moe: - res['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] - else: - res['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] return self._add_prefix(res, tgt_prefix) def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): From 33a4ef04c64c2c585855464e9c6683e491d7f1f5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 16:35:45 +0800 Subject: [PATCH 08/30] update --- swift/megatron/__init__.py | 2 -- swift/megatron/model/gpt_bridge.py | 13 ++++++++----- swift/megatron/model/gpt_model.py | 3 --- tests/megatron/test_auto.py | 3 --- tests/megatron/test_save.py | 4 +--- 5 files changed, 9 insertions(+), 16 deletions(-) delete mode 100644 tests/megatron/test_auto.py diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 75c639a497..a5a48f0897 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -19,7 +19,6 @@ from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer from .tuners import LoraParallelLinear - from .bridge import AutoMcoreModel else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], @@ -29,7 +28,6 @@ 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], 'tuners': ['LoraParallelLinear'], - 'bridge': ['AutoMcoreModel'], } import sys diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index e606dd14ca..05b580d41d 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -128,7 +128,9 @@ def _set_mlp_state( if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix state_dict = self._remove_prefix(state_dict, src_prefix) - hf_grouped = not hasattr(hf_mlp.experts, '__len__') + hf_grouped = False + if group_idx is not None and not hasattr(hf_mlp.experts, '__len__'): + hf_grouped = True res = {} # Determines the keys for fc1 and fc2 in megatron if group_idx is None: @@ -145,8 +147,9 @@ def _set_mlp_state( self._set_state_dict(state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) else: if reverse: - res['gate_proj.weight'] = state_dict[fc1_key][:self.args.ffn_hidden_size] - res['up_proj.weight'] = state_dict[fc1_key][self.args.ffn_hidden_size:] + ffn_hidden_size = state_dict[fc1_key].shape[0] // 2 + res['gate_proj.weight'] = state_dict[fc1_key][:ffn_hidden_size] + res['up_proj.weight'] = state_dict[fc1_key][ffn_hidden_size:] else: res[fc1_key] = torch.cat([ state_dict['gate_proj.weight'], @@ -177,7 +180,7 @@ def _set_mla_attn_state( self._set_state_dict(state_dict, res, 'kv_a_proj_with_mqa.weight', 'linear_kv_down_proj.weight', reverse) self._set_state_dict(state_dict, res, 'kv_b_proj.weight', 'linear_kv_up_proj.weight', 
reverse) if self.args.qk_layernorm: - self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.weight', reverse) + self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.layer_norm_weight', reverse) return self._add_prefix(res, tgt_prefix) def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): @@ -191,7 +194,7 @@ def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix state_dict = self._remove_prefix(state_dict, src_prefix) res = {} if self.args.multi_latent_attention: - res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', reverse)) + res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', hf_mlp, reverse)) self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) else: res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', hf_attn, reverse)) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 8058f83705..9c37bda427 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -272,6 +272,3 @@ def forward( def get_input_tensor(self): return self.decoder.input_tensor - - def save_hf_checkpoint(self): - print() diff --git a/tests/megatron/test_auto.py b/tests/megatron/test_auto.py deleted file mode 100644 index 7d5c3af29b..0000000000 --- a/tests/megatron/test_auto.py +++ /dev/null @@ -1,3 +0,0 @@ - -from swift.megatron import AutoMcoreModel - diff --git a/tests/megatron/test_save.py b/tests/megatron/test_save.py index cfc78182ae..c19b7792e6 100644 --- a/tests/megatron/test_save.py +++ b/tests/megatron/test_save.py @@ -12,7 +12,7 @@ def get_mg_model_tokenizer(): _, processor = get_model_tokenizer(model_id, load_model=False) megatron_model_meta = get_megatron_model_meta(processor.model_meta.model_type) model_info = processor.model_info - kwargs = megatron_model_meta.convert_hf_config(model_info.config) + kwargs = convert_hf_config(model_info.config) megatron_args = MegatronArguments( **kwargs, seq_length=1, @@ -22,7 +22,6 @@ def get_mg_model_tokenizer(): save='mcore-hf-test', no_load_optim=True, no_load_rng=True) - patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() initialize_megatron(args_defaults=extra_args) mg_model = megatron_model_meta.model_provider() @@ -57,5 +56,4 @@ def test_save(): from swift.utils import set_default_ddp_config from swift.megatron.argument import MegatronArguments from swift.megatron.model import get_megatron_model_meta - from swift.megatron.utils import patch_megatron_tokenizer test_save() From 8e8fdb1173145d2453d408267cab33f91863db10 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 24 Oct 2025 16:55:49 +0800 Subject: [PATCH 09/30] update --- swift/megatron/model/gpt/__init__.py | 8 +- swift/megatron/model/gpt/hf2mcore.py | 150 ----------------------- swift/megatron/model/gpt/mcore2hf.py | 161 ------------------------- swift/megatron/model/gpt/qwen3_next.py | 13 +- swift/megatron/model/gpt_bridge.py | 9 +- swift/megatron/model/register.py | 16 ++- 6 files changed, 24 insertions(+), 333 deletions(-) delete mode 100644 swift/megatron/model/gpt/hf2mcore.py delete mode 100644 swift/megatron/model/gpt/mcore2hf.py diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index 02ee79de92..9dd370ccfd 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -3,9 +3,7 
@@ from ..constant import MegatronModelType from ..gpt_model import GPTModel from ..register import MegatronModelMeta, register_megatron_model -# from . import qwen3_next -from .hf2mcore import convert_hf2mcore -from .mcore2hf import convert_mcore2hf +from . import qwen3_next register_megatron_model( MegatronModelMeta( @@ -57,8 +55,4 @@ ModelType.deepseek_v3_1, ModelType.ernie_thinking, ], - model_cls=GPTModel, - # convert_hf_config=convert_gpt_hf_config, - # convert_mcore2hf=convert_mcore2hf, - # convert_hf2mcore=convert_hf2mcore, )) diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py deleted file mode 100644 index 4c06eddbb8..0000000000 --- a/swift/megatron/model/gpt/hf2mcore.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Optional - -import torch -from megatron.training import get_args -from torch import nn -from tqdm import tqdm - - -def set_mla_attn_state(args, state_dict, prefix: str): - mg_state_dict = {} - mg_state_dict['linear_proj.weight'] = state_dict['o_proj.weight'] - if args.q_lora_rank is None: - mg_state_dict['linear_q_proj.weight'] = state_dict['q_proj.weight'] - else: - mg_state_dict['linear_q_down_proj.weight'] = state_dict['q_a_proj.weight'] - mg_state_dict['linear_q_up_proj.weight'] = state_dict['q_b_proj.weight'] - mg_state_dict['linear_kv_down_proj.weight'] = state_dict['kv_a_proj_with_mqa.weight'] - mg_state_dict['linear_kv_up_proj.weight'] = state_dict['kv_b_proj.weight'] - if args.qk_layernorm: - mg_state_dict['linear_kv_up_proj.layer_norm_weight'] = state_dict['kv_a_layernorm.weight'] - return _add_prefix(mg_state_dict, prefix) - - -def set_attn_state(args, state_dict, prefix: str): - mg_state_dict = {} - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - mg_state_dict['linear_qkv.weight'] = torch.cat([ - state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), - state_dict['k_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), - state_dict['v_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size)) - mg_state_dict['linear_proj.weight'] = state_dict['o_proj.weight'] - - # Copy bias - if args.add_qkv_bias: - mg_state_dict['linear_qkv.bias'] = torch.cat([ - state_dict['q_proj.bias'].reshape((num_query_groups, -1)), - state_dict['k_proj.bias'].reshape((num_query_groups, -1)), - state_dict['v_proj.bias'].reshape((num_query_groups, -1)), - ], - dim=1).reshape(-1) - if args.qk_layernorm: - if 'q_norm.weight' in state_dict: - mg_state_dict['q_layernorm.weight'] = state_dict['q_norm.weight'] - else: - mg_state_dict['q_layernorm.weight'] = state_dict['query_layernorm.weight'] - if 'k_norm.weight' in state_dict: - mg_state_dict['k_layernorm.weight'] = state_dict['k_norm.weight'] - else: - mg_state_dict['k_layernorm.weight'] = state_dict['key_layernorm.weight'] - - return _add_prefix(mg_state_dict, prefix) - - -def _set_mlp_state( - args, - state_dict, - prefix: str, - group_idx: Optional[int] = None, - hf_grouped: bool = False, -): - mg_state_dict = {} - if group_idx is None: - fc1_key = 'linear_fc1.weight' - fc2_key = 'linear_fc2.weight' - else: - fc1_key = f'linear_fc1.weight{group_idx}' - fc2_key = f'linear_fc2.weight{group_idx}' - if hf_grouped: - mg_state_dict[fc1_key] = state_dict['gate_up_proj'][group_idx].t() - mg_state_dict[fc2_key] = state_dict['down_proj'][group_idx].t() - else: - if 'gate_up_proj.weight' in state_dict: - 
mg_state_dict[fc1_key] = state_dict['gate_up_proj.weight'] - else: - mg_state_dict[fc1_key] = torch.cat([ - state_dict['gate_proj.weight'], - state_dict['up_proj.weight'], - ], dim=0) - mg_state_dict[fc2_key] = state_dict['down_proj.weight'] - return _add_prefix(mg_state_dict, prefix) - - -def _set_moe_state(args, state_dict, prefix: str): - mg_state_dict = {} - if 'gate.wg.weight' in state_dict: - mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] - else: - mg_state_dict['router.weight'] = state_dict['gate.weight'] - if args.moe_router_enable_expert_bias: - mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] - - if args.moe_shared_expert_intermediate_size: - shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') - mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) - if 'shared_expert_gate.weight' in state_dict: - mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] - for expert_idx in range(args.num_experts): - expert_sd = _remove_prefix(state_dict, 'experts.') - hf_grouped = expert_sd is not None - if expert_sd is None: - expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') - mg_state_dict.update(_set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) - return _add_prefix(mg_state_dict, prefix) - - -def set_layer_state(args, state_dict, prefix: str): - mg_state_dict = {} - if args.multi_latent_attention: - mg_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) - mg_state_dict['input_layernorm.weight'] = state_dict['input_layernorm.weight'] - - else: - mg_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attn.'), 'self_attention.')) - mg_state_dict['self_attention.linear_qkv.layer_norm_weight'] = state_dict['input_layernorm.weight'] - - mlp_state_dict = _remove_prefix(state_dict, 'mlp.') - is_moe = _is_moe(mlp_state_dict) - if is_moe: - mg_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) - else: - mg_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) - - if is_moe: - mg_state_dict['pre_mlp_layernorm.weight'] = state_dict['post_attention_layernorm.weight'] - else: - mg_state_dict['mlp.linear_fc1.layer_norm_weight'] = state_dict['post_attention_layernorm.weight'] - return _add_prefix(mg_state_dict, prefix) - - -def convert_hf2mcore(state_dict, prefix=''): - args = get_args() - mg_state_dict = {} - is_language_model = 'model.language_model.embed_tokens.weight' in state_dict - hf_prefix = 'model.language_model.' if is_language_model else 'model.' 
- mg_state_dict['embedding.word_embeddings.weight'] = state_dict[f'{hf_prefix}embed_tokens.weight'] - if args.untie_embeddings_and_output_weights and 'lm_head.weight' in state_dict: - mg_state_dict['output_layer.weight'] = state_dict['lm_head.weight'] - mg_state_dict['decoder.final_layernorm.weight'] = state_dict[f'{hf_prefix}norm.weight'] - for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): - mg_state_dict.update( - set_layer_state(args, _remove_prefix(state_dict, f'{hf_prefix}layers.{layer_idx}.'), - f'decoder.layers.{layer_idx}.')) - return _add_prefix(mg_state_dict, prefix) diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py deleted file mode 100644 index f417bafae7..0000000000 --- a/swift/megatron/model/gpt/mcore2hf.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Optional - -import torch -from megatron.training import get_args -from torch import nn -from tqdm import tqdm - - -def set_mla_attn_state(args, state_dict, prefix: str): - hf_state_dict = {} - hf_state_dict['o_proj.weight'] = state_dict['linear_proj.weight'] - if args.q_lora_rank is None: - hf_state_dict['q_proj.weight'] = state_dict['linear_q_proj.weight'] - else: - hf_state_dict['q_a_proj.weight'] = state_dict['linear_q_down_proj.weight'] - hf_state_dict['q_b_proj.weight'] = state_dict['linear_q_up_proj.weight'] - hf_state_dict['kv_a_proj_with_mqa.weight'] = state_dict['linear_kv_down_proj.weight'] - hf_state_dict['kv_b_proj.weight'] = state_dict['linear_kv_up_proj.weight'] - if args.qk_layernorm: - hf_state_dict['kv_a_layernorm.weight'] = state_dict['linear_kv_up_proj.layer_norm_weight'] - return _add_prefix(hf_state_dict, prefix) - - -def set_attn_state(args, state_dict, prefix: str): - hf_state_dict = {} - num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) - q_dim = args.kv_channels * args.num_attention_heads - kv_dim = args.kv_channels * args.num_query_groups - hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) - hf_state_dict['o_proj.weight'] = state_dict['linear_proj.weight'] - - # Copy bias - if args.add_qkv_bias: - mg_attn_bias = state_dict['linear_qkv.bias'].reshape((num_query_groups, -1)) - state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) - state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) - state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) - if args.qk_layernorm: - hf_state_dict['q_norm.weight'] = state_dict['q_layernorm.weight'] - hf_state_dict['query_layernorm.weight'] = state_dict['q_layernorm.weight'] - hf_state_dict['k_norm.weight'] = hf_state_dict['k_layernorm.weight'] - hf_state_dict['key_layernorm.weight'] = hf_state_dict['k_layernorm.weight'] - - return _add_prefix(mg_state_dict, prefix) - - -def _set_mlp_state( - args, - state_dict, - prefix: str, - group_idx: Optional[int] = None, - hf_grouped: bool = False, -): - mg_state_dict = {} - if group_idx is None: - fc1_key = 'linear_fc1.weight' - fc2_key = 'linear_fc2.weight' - else: - fc1_key = f'linear_fc1.weight{group_idx}' - fc2_key = f'linear_fc2.weight{group_idx}' - if hf_grouped: - 
mg_state_dict[fc1_key] = state_dict['gate_up_proj'][group_idx].t() - mg_state_dict[fc2_key] = state_dict['down_proj'][group_idx].t() - else: - if 'gate_up_proj.weight' in state_dict: - mg_state_dict[fc1_key] = state_dict['gate_up_proj.weight'] - else: - mg_state_dict[fc1_key] = torch.cat([ - state_dict['gate_proj.weight'], - state_dict['up_proj.weight'], - ], dim=0) - mg_state_dict[fc2_key] = state_dict['down_proj.weight'] - return _add_prefix(mg_state_dict, prefix) - - -def _set_moe_state(args, state_dict, prefix: str): - mg_state_dict = {} - if 'gate.wg.weight' in state_dict: - mg_state_dict['router.weight'] = state_dict['gate.wg.weight'] - else: - mg_state_dict['router.weight'] = state_dict['gate.weight'] - if args.moe_router_enable_expert_bias: - mg_state_dict['router.expert_bias'] = state_dict['gate.e_score_correction_bias'] - - if args.moe_shared_expert_intermediate_size: - shared_expert_sd = _remove_prefix(state_dict, 'shared_expert.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_experts.') - if not shared_expert_sd: - shared_expert_sd = _remove_prefix(state_dict, 'shared_mlp.') - mg_state_dict.update(_set_mlp_state(args, shared_expert_sd, 'shared_experts.')) - if 'shared_expert_gate.weight' in state_dict: - mg_state_dict['shared_experts.gate_weight'] = state_dict['shared_expert_gate.weight'] - for expert_idx in range(args.num_experts): - expert_sd = _remove_prefix(state_dict, 'experts.') - hf_grouped = expert_sd is not None - if expert_sd is None: - expert_sd = _remove_prefix(state_dict, f'experts.{expert_idx}.') - mg_state_dict.update(_set_mlp_state(args, expert_sd, 'experts.', group_idx=expert_idx, hf_grouped=hf_grouped)) - return _add_prefix(mg_state_dict, prefix) - - -def _is_moe(state_dict): - for k, v in state_dict.items(): - if 'experts.' 
in k: - return True - return False - - -def set_layer_state(args, state_dict, prefix: str): - hf_state_dict = {} - if args.multi_latent_attention: - hf_state_dict.update(set_mla_attn_state(args, _remove_prefix(state_dict, ), 'self_attn.')) - hf_state_dict['input_layernorm.weight'] = mg_state_dict['input_layernorm.weight'] - - else: - hf_state_dict.update(set_attn_state(args, _remove_prefix(state_dict, 'self_attention.'), 'self_attn.')) - hf_state_dict['input_layernorm.weight'] = state_dict['self_attention.linear_qkv.layer_norm_weight'] - - mlp_state_dict = _remove_prefix(state_dict, 'mlp.') - is_moe = _is_moe(mlp_state_dict) - if is_moe: - hf_state_dict.update(_set_moe_state(args, mlp_state_dict, 'mlp.')) - else: - hf_state_dict.update(_set_mlp_state(args, mlp_state_dict, 'mlp.')) - - if is_moe: - hf_state_dict['post_attention_layernorm.weight'] = state_dict['pre_mlp_layernorm.weight'] - else: - hf_state_dict['post_attention_layernorm.weight'] = state_dict['mlp.linear_fc1.layer_norm_weight'] - return _add_prefix(hf_state_dict, prefix) - - -def _remove_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - - -def _add_prefix(state_dict, prefix: str): - if not prefix: - return state_dict - return {f'{prefix}{k}': v for k, v in state_dict.items()} - - -def convert_mcore2hf(state_dict, prefix=''): - args = get_args() - hf_state_dict = {} - hf_state_dict['model.embed_tokens.weight'] = state_dict['embedding.word_embeddings.weight'] - if args.untie_embeddings_and_output_weights: - hf_state_dict['lm_head.weight'] = state_dict['output_layer.weight'] - hf_state_dict['model.norm.weight'] = state_dict['decoder.final_layernorm.weight'] - for layer_idx in tqdm(range(args.num_layers), dynamic_ncols=True, desc='Converting: '): - hf_state_dict.update( - set_layer_state(args, _remove_prefix(state_dict, f'decoder.layers.{layer_idx}.'), - f'model.layers.{layer_idx}.')) - return _add_prefix(hf_state_dict, prefix) diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 1d291b6a4c..4b5919c1c5 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -22,9 +22,8 @@ from swift.llm import ModelType from swift.utils import get_logger from ..constant import MegatronModelType -from ..gpt_model import GPTModel +from ..gpt_bridge import GPTBridge from ..register import MegatronModelMeta, register_megatron_model -from .config import convert_gpt_hf_config try: from flashattn_hopper.flash_attn_interface import _flash_attn_forward @@ -473,6 +472,12 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): return block_spec +class Qwen3NextBridge(GPTBridge): + + def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + print() + + def convert_mcore2hf_qwen3_next(hf_model, mg_model): from .mcore2hf import set_mlp_state, set_attn_state args = get_args() @@ -537,9 +542,5 @@ def convert_hf2mcore_qwen3_next(hf_model, mg_model): ModelType.qwen3_next, ModelType.qwen3_next_thinking, ], - model_cls=GPTModel, - convert_hf_config=convert_gpt_hf_config, get_transformer_layer_spec=get_qwen3_next_transformer_layer_spec, - convert_mcore2hf=convert_mcore2hf_qwen3_next, - convert_hf2mcore=convert_hf2mcore_qwen3_next, )) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 05b580d41d..bc6c0cec45 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ 
-1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Dict, Optional import torch from megatron.training import get_args @@ -180,7 +180,8 @@ def _set_mla_attn_state( self._set_state_dict(state_dict, res, 'kv_a_proj_with_mqa.weight', 'linear_kv_down_proj.weight', reverse) self._set_state_dict(state_dict, res, 'kv_b_proj.weight', 'linear_kv_up_proj.weight', reverse) if self.args.qk_layernorm: - self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.layer_norm_weight', reverse) + self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.layer_norm_weight', + reverse) return self._add_prefix(res, tgt_prefix) def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): @@ -227,8 +228,8 @@ def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): res.update(self._set_layer_state(state_dict, layer_idx, 'model.layers.', 'decoder.layers.', reverse)) return self._add_prefix(res, tgt_prefix) - def convert_hf2mcore(self, state_dict): + def convert_hf2mcore(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self._convert(state_dict, '', '', False) - def convert_mcore2hf(self, state_dict): + def convert_mcore2hf(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self._convert(state_dict, '', '', True) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index c511caf501..877a09db92 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,14 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from argparse import ArgumentParser from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Callable, List, Optional, Type -import torch import torch.nn as nn -from transformers import PretrainedConfig from swift.llm import MODEL_MAPPING +from .constant import MLLMMegatronModelType from .gpt_bridge import GPTBridge +from .gpt_model import GPTModel +from .mm_gpt_model import MultimodalGPTModel from .model_provider import model_provider as model_provider_func MEGATRON_MODEL_MAPPING = {} @@ -18,8 +19,8 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - model_cls: Type[nn.Module] + is_multimodal: bool = False bridge_cls: Type[GPTBridge] = GPTBridge get_transformer_layer_spec: Optional[Callable] = None model_provider: Callable[[], nn.Module] = model_provider_func @@ -27,6 +28,10 @@ class MegatronModelMeta: extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None + @property + def model_cls(self): + return MultimodalGPTModel if self.is_multimodal else GPTModel + def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False): megatron_model_type = megatron_model_meta.megatron_model_type @@ -35,7 +40,8 @@ def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: model_meta.support_megatron = True if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING: raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.') - + if megatron_model_type in MLLMMegatronModelType.__dict__: + megatron_model_meta.is_multimodal = True MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta From 385f5bcb322f7e9b87775aa0bcdffd1e6a3e6223 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 26 Oct 2025 20:24:19 +0800 Subject: [PATCH 10/30] update --- 
swift/llm/model/patcher.py | 2 +- swift/megatron/model/gpt/qwen3_next.py | 73 ++++++------------------ swift/megatron/model/gpt_bridge.py | 77 ++++++++++++++++++-------- 3 files changed, 70 insertions(+), 82 deletions(-) diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py index 1ab31aa8f7..58e2065f1e 100644 --- a/swift/llm/model/patcher.py +++ b/swift/llm/model/patcher.py @@ -467,7 +467,7 @@ def patch_tp_plan(load_model: bool): transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ: yield return - logger.info('Patch tp_plan.') + logger.info_once('Patch tp_plan.') WORLD_SIZE = os.environ.get('WORLD_SIZE') os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE os.environ.pop('WORLD_SIZE') diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 4b5919c1c5..2dd9026dd9 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -3,7 +3,6 @@ from typing import Optional, Tuple, Union import torch -from megatron.core import mpu from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb @@ -474,65 +473,24 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): class Qwen3NextBridge(GPTBridge): - def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): - print() - - -def convert_mcore2hf_qwen3_next(hf_model, mg_model): - from .mcore2hf import set_mlp_state, set_attn_state - args = get_args() - hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - lm_head_weight = hf_model.score.weight if args.task_type == 'seq_cls' else hf_model.lm_head.weight - lm_head_weight.data.copy_(mg_model.output_layer.weight) - hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight - 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = mg_layer.self_attention - - if layer_type == 'linear_attention': - hf_layer.linear_attn.load_state_dict(mg_attn.state_dict(), strict=False) - hf_layer.input_layernorm.weight.data.copy_(mg_layer.input_layernorm.weight - 1) - elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - hf_layer.input_layernorm.weight.data.copy_(mg_attn.linear_qkv.layer_norm_weight - 1) - if args.qk_layernorm: - hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight - 1) - hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight - 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - hf_layer.post_attention_layernorm.weight.data.copy_(mg_layer.pre_mlp_layernorm.weight - 1) - - -def convert_hf2mcore_qwen3_next(hf_model, mg_model): - from .hf2mcore import set_mlp_state, set_attn_state - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(hf_model.model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight + 1) - for layer_idx in range(args.num_layers): - layer_type = args.layer_types[layer_idx] - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] - mg_attn = 
mg_layer.self_attention + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): + if hf_key in { + 'model.norm.weight', 'q_norm.weight', 'k_norm.weight', 'input_layernorm.weight', + 'post_attention_layernorm.weight' + }: + offset = -1 if reverse else 1 + else: + assert 'norm' not in hf_key, f'hf_key: {hf_key}' # just check + return super()._set_state_dict(state_dict, res_state_dict, hf_key, mg_key, reverse, offset) + def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): + layer_type = self.args.layer_types[layer_idx] if layer_type == 'linear_attention': - mg_attn.load_state_dict(hf_layer.linear_attn.state_dict(), strict=False) - mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight + 1) + res = self._replace_prefix(state_dict, 'linear_attn.', 'self_attention.', reverse) + self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) elif layer_type == 'full_attention': - hf_attn = hf_layer.self_attn - set_attn_state(args, mg_attn, hf_attn) - mg_attn.linear_qkv.layer_norm_weight.data.copy_(hf_layer.input_layernorm.weight + 1) - if args.qk_layernorm: - mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight + 1) - mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight + 1) - - set_mlp_state(args, mg_layer.mlp, hf_layer.mlp) - mg_layer.pre_mlp_layernorm.weight.data.copy_(hf_layer.post_attention_layernorm.weight + 1) + res = super()._set_layer_attn(state_dict, layer_idx, reverse) + return res register_megatron_model( @@ -543,4 +501,5 @@ def convert_hf2mcore_qwen3_next(hf_model, mg_model): ModelType.qwen3_next_thinking, ], get_transformer_layer_spec=get_qwen3_next_transformer_layer_spec, + bridge_cls=Qwen3NextBridge, )) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index bc6c0cec45..76241d4017 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -9,7 +9,7 @@ class GPTBridge: - lm_layers_prefix = 'model.layers' # hf + lm_layers_prefix = 'model.layers' # HF model def __init__(self): self.args = get_args() @@ -19,11 +19,13 @@ def __init__(self): model_info.model_dir, model_type=model_info.model_type, return_dummy_model=True) self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) - def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool): + def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): src_key, tgt_key = hf_key, mg_key if reverse: src_key, tgt_key = tgt_key, src_key res_state_dict[tgt_key] = state_dict[src_key] + if offset: + res_state_dict[tgt_key] = res_state_dict[tgt_key] + offset @staticmethod def _remove_prefix(state_dict, prefix: str): @@ -37,6 +39,20 @@ def _add_prefix(state_dict, prefix: str): return state_dict return {f'{prefix}{k}': v for k, v in state_dict.items()} + @staticmethod + def _filter_prefix(state_dict, prefix: str): + if not prefix: + return state_dict + return {k: v for k, v in state_dict.items() if k.startswith(prefix)} + + @staticmethod + def _replace_prefix(state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + res = GPTBridge._remove_prefix(state_dict, src_prefix) + return GPTBridge._add_prefix(res, tgt_prefix) + @staticmethod def _is_moe(state_dict): for k, v in state_dict.items(): @@ -44,18 +60,19 @@ def 
_is_moe(state_dict): return True return False - def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, reverse: bool): + def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix state_dict = self._remove_prefix(state_dict, src_prefix) + hf_attn = self.hf_layers[layer_idx].self_attn args = self.args res = {} num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if reverse: mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) - q_dim = args.kv_channels * args.num_attention_heads // num_query_groups - kv_dim = args.kv_channels + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups res['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) res['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) res['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) @@ -90,11 +107,12 @@ def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_attn, r return self._add_prefix(res, tgt_prefix) - def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_mlp, reverse: bool): + def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix state_dict = self._remove_prefix(state_dict, src_prefix) + hf_mlp = self.hf_layers[layer_idx].mlp res = {} hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 'gate.weight' self._set_state_dict(state_dict, res, hf_gate_key, 'router.weight', reverse) @@ -105,14 +123,14 @@ def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, hf_mlp, rev for key in ['shared_expert', 'shared_experts', 'shared_mlp']: if hasattr(hf_mlp, key): hf_shared_expert_prefix = f'{key}.' - res.update(self._set_mlp_state(state_dict, hf_shared_expert_prefix, 'shared_experts.', hf_mlp, reverse)) + res.update(self._set_mlp_state(state_dict, hf_shared_expert_prefix, 'shared_experts.', layer_idx, reverse)) if hasattr(hf_mlp, 'shared_expert_gate'): self._set_state_dict(state_dict, res, 'shared_expert_gate.weight', 'shared_experts.gate_weight', reverse) for expert_idx in range(self.args.num_experts): hf_expert_prefix = f'experts.{expert_idx}.' if hasattr(hf_mlp.experts, '__len__') else 'experts.' 
res.update( - self._set_mlp_state(state_dict, hf_expert_prefix, 'experts.', hf_mlp, reverse, group_idx=expert_idx)) + self._set_mlp_state(state_dict, hf_expert_prefix, 'experts.', layer_idx, reverse, group_idx=expert_idx)) return self._add_prefix(res, tgt_prefix) def _set_mlp_state( @@ -120,7 +138,7 @@ def _set_mlp_state( state_dict, hf_prefix: str, mg_prefix: str, - hf_mlp, + layer_idx: int, reverse: bool, group_idx: Optional[int] = None, ): @@ -128,6 +146,7 @@ def _set_mlp_state( if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix state_dict = self._remove_prefix(state_dict, src_prefix) + hf_mlp = self.hf_layers[layer_idx].mlp hf_grouped = False if group_idx is not None and not hasattr(hf_mlp.experts, '__len__'): hf_grouped = True @@ -163,7 +182,7 @@ def _set_mla_attn_state( state_dict, hf_prefix: str, mg_prefix: str, - hf_mlp, + layer_idx: int, reverse: bool, ): src_prefix, tgt_prefix = hf_prefix, mg_prefix @@ -184,37 +203,44 @@ def _set_mla_attn_state( reverse) return self._add_prefix(res, tgt_prefix) - def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): - hf_prefix = f'{hf_prefix}{layer_idx}.' - mg_prefix = f'{mg_prefix}{layer_idx}.' - hf_layer = self.hf_layers[layer_idx] - hf_attn, hf_mlp = hf_layer.self_attn, hf_layer.mlp - src_prefix, tgt_prefix = hf_prefix, mg_prefix - if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) + def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): res = {} if self.args.multi_latent_attention: - res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', hf_mlp, reverse)) + res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, reverse)) self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) else: - res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', hf_attn, reverse)) + res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, reverse)) self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'self_attention.linear_qkv.layer_norm_weight', reverse) + return res + def _set_layer_mlp(self, state_dict, layer_idx: int, reverse: bool): + hf_mlp = self.hf_layers[layer_idx].mlp + res = {} is_moe = self._is_moe(hf_mlp.state_dict()) if is_moe: - res.update(self._set_moe_state(state_dict, 'mlp.', 'mlp.', hf_mlp, reverse)) + res.update(self._set_moe_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'pre_mlp_layernorm.weight', reverse) else: - res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', hf_mlp, reverse)) + res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'mlp.linear_fc1.layer_norm_weight', reverse) + return res + def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + mg_prefix = f'{mg_prefix}{layer_idx}.' 
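# A minimal standalone sketch of the prefix bookkeeping used throughout GPTBridge
# (it mirrors the _remove_prefix/_add_prefix/_replace_prefix helpers above); the
# function names and the tiny dict are local to this example, not part of the diff.
def remove_prefix(state_dict, prefix):
    # keep only keys under `prefix`, with the prefix stripped
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

def add_prefix(state_dict, prefix):
    return {f'{prefix}{k}': v for k, v in state_dict.items()}

def replace_prefix(state_dict, hf_prefix, mg_prefix, reverse=False):
    src, tgt = (mg_prefix, hf_prefix) if reverse else (hf_prefix, mg_prefix)
    return add_prefix(remove_prefix(state_dict, src), tgt)

# hf -> mg direction: {'linear_attn.A_log': 0} becomes {'self_attention.A_log': 0}
# print(replace_prefix({'linear_attn.A_log': 0}, 'linear_attn.', 'self_attention.'))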
+ src_prefix, tgt_prefix = hf_prefix, mg_prefix + if reverse: + src_prefix, tgt_prefix = tgt_prefix, src_prefix + state_dict = self._remove_prefix(state_dict, src_prefix) + res = self._set_layer_attn(state_dict, layer_idx, reverse) + res.update(self._set_layer_mlp(state_dict, layer_idx, reverse)) return self._add_prefix(res, tgt_prefix) def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + """reverse: False: hf -> mg; True: mg -> hf""" src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: src_prefix, tgt_prefix = tgt_prefix, src_prefix @@ -222,7 +248,10 @@ def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): res = {} self._set_state_dict(state_dict, res, 'model.embed_tokens.weight', 'embedding.word_embeddings.weight', reverse) if self.args.untie_embeddings_and_output_weights: - self._set_state_dict(state_dict, res, 'lm_head.weight', 'output_layer.weight', reverse) + hf_lm_head_key = 'lm_head.weight' + if reverse and self.args.task_type == 'seq_cls': + hf_lm_head_key = 'score.weight' + self._set_state_dict(state_dict, res, hf_lm_head_key, 'output_layer.weight', reverse) self._set_state_dict(state_dict, res, 'model.norm.weight', 'decoder.final_layernorm.weight', reverse) for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): res.update(self._set_layer_state(state_dict, layer_idx, 'model.layers.', 'decoder.layers.', reverse)) From 4e2cee2c27c3a775ee308a52199291a5e4c3ac92 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 27 Oct 2025 00:40:48 +0800 Subject: [PATCH 11/30] update --- swift/megatron/argument/megatron_args.py | 2 + swift/megatron/argument/train_args.py | 7 ++- swift/megatron/convert.py | 10 +--- swift/megatron/model/gpt_bridge.py | 20 +++++++ swift/megatron/trainers/base.py | 21 +++++-- swift/megatron/utils/lazy_tensor.py | 74 ++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 15 deletions(-) create mode 100644 swift/megatron/utils/lazy_tensor.py diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 4cd2fc7a18..7b9a2efdd2 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -94,6 +94,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): torch_dtype: Optional[torch.dtype] = None padding_free: bool = True mlp_padding_free: bool = False + load_hf_checkpoint: bool = False + save_hf_checkpoint: bool = False # streaming dataloader dataloader_persistent_workers: bool = True dataloader_prefetch_factor: int = 10 diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 74b7f2e055..cf4676bf97 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -76,8 +76,13 @@ def __post_init__(self): if self.num_workers > 1: self.num_workers = 1 logger.info('Using streaming dataset, setting args.num_workers to 1.') - if self.load is None and self.no_initialization: + if self.load is None and self.no_initialization and not self.load_hf_checkpoint: raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') if self.cached_dataset and self.context_parallel_size > 1: raise ValueError('`cached_dataset` does not support context parallelism.') + + def get_model_kwargs(self): + res = super().get_model_kwargs() + res['download_model'] = self.load_hf_checkpoint + return res diff --git a/swift/megatron/convert.py 
b/swift/megatron/convert.py index aae514fd16..7d94f96fc0 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -256,10 +256,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() - incompatible_keys = mg_model.load_state_dict(bridge.convert_hf2mcore(hf_model.state_dict()), strict=False) - missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] - assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' - assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' + bridge.load_state_dict(mg_model, bridge.convert_hf2mcore(hf_model.state_dict())) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del hf_model @@ -315,10 +312,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: if args.to_hf: hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) bridge = megatron_model_meta.bridge_cls() - incompatible_keys = hf_model.load_state_dict(bridge.convert_mcore2hf(mg_model.state_dict()), strict=False) - missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] - assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' - assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' + bridge.load_state_dict(hf_model, bridge.convert_mcore2hf(mg_model.state_dict())) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del mg_model diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 76241d4017..f0a869cde9 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -262,3 +262,23 @@ def convert_hf2mcore(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor def convert_mcore2hf(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self._convert(state_dict, '', '', True) + + def load_state_dict(self, model, state_dict) -> None: + """The model can be either hf_model or mg_model""" + incompatible_keys = model.load_state_dict(state_dict, strict=False) + missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] + assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' + assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' + + def load_from_hf_checkpoint(self, mg_model, hf_model_dir: str) -> None: + """按照mg_model的模型结构, 加载需要的参数,并进行scatter""" + print() + + def get_hf_state_dict(self, mg_models) -> Dict[str, torch.Tensor]: + """获取完整的hf state_dict""" + print() + + def save_hf_checkpoint(self, mg_models, output_dir: str) -> None: + """保存mg_model的hf格式checkpoint""" + state_dict = get_hf_state_dict(mg_models) + # rank0 save() diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 52a6f7077f..5ea4613846 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -47,6 +47,7 @@ def __init__(self, args, template): self.stimer = StragglerDetector() self.unwrapped_models = [] self.peft_models = [] + self.bridge = args.megatron_model_meta.bridge_cls() logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = 
JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate @@ -251,13 +252,16 @@ def load_state_dict(self, state_dict, strict: bool = True, *args, **kwargs): def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - def new_model_provider_func(*args, **kwargs): - model = model_provider_func(*args, **kwargs) + args = get_args() + + def new_model_provider_func(*_args, **kwargs): + model = model_provider_func(*_args, **kwargs) + if args.load_hf_checkpoint: + bridge.load_from_hf_checkpoint(model, args.model_info.model_dir) self.unwrapped_models.append(model) self.peft_models.append(prepare_mcore_model(model)) return model - args = get_args() self._init_multimodal_full(args) with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( @@ -724,9 +728,14 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear return report_memory_flag - def save_checkpoint(self, *args, **kwargs): - with adapter_state_dict_context(): - return self._origin_save_checkpoint(*args, **kwargs) + def save_checkpoint(self, iteration, *_args, **kwargs): + args = get_args() + if args.save_hf_checkpoint: + ouput_dir = os.path.join(args.save, f'checkpoint-{iteration}') + bridge.save_hf_checkpoint(self.unwrapped_models, ouput_dir) + else: + with adapter_state_dict_context(): + return self._origin_save_checkpoint(iteration, *_args, **kwargs) def _patch_megatron(self): # support max_epochs diff --git a/swift/megatron/utils/lazy_tensor.py b/swift/megatron/utils/lazy_tensor.py new file mode 100644 index 0000000000..f594b26dbf --- /dev/null +++ b/swift/megatron/utils/lazy_tensor.py @@ -0,0 +1,74 @@ +import os +from functools import partial + +import json +import safetensors.torch + + +class LazyTensor: + + def __init__(self, tensor=None, loader=None): + """You need to provide a tensor or loader""" + self.tensor = tensor + self.loader = loader + + def load(self): + if self.tensor is None: + self.tensor = self.loader() + self.loader = None + return self.tensor + + +class SafetensorsLazyLoader: + + def __init__(self, hf_model_dir: str): + self.hf_model_dir = hf_model_dir + self._weight_map = {} + self._file_handles = {} + self._load_index() + + def _open_file(self, filename: str): + """Open a safetensors file if not already open.""" + if filename not in self._file_handles: + file_path = os.path.join(self.hf_model_dir, filename) + self._file_handles[filename] = safetensors.torch.safe_open(file_path, framework='pt') + return self._file_handles[filename] + + def _load_index(self): + """Load the model index file to get weight map.""" + index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') + + if os.path.exists(index_path): + with open(index_path, 'r') as f: + self._index_file = json.load(f) + self._weight_map = self._index_file.get('weight_map', {}) + else: + # Single file model + safetensors_file = os.path.join(self.hf_model_dir, 'model.safetensors') + if os.path.exists(safetensors_file): + # All weights are in single file + with safetensors.torch.safe_open(safetensors_file, framework='pt') as f: + for key in f.keys(): + self._weight_map[key] = 'model.safetensors' + + def get_state_dict(self): + res = {} + for k in self._weight_map.keys(): + res[k] = LazyTensor(loader=partial(self._load_tensor, key=k)) + return res + + def _load_tensor(self, key): + filename = self._weight_map[key] + file_handle = self._open_file(filename) + return 
file_handle.get_tensor(key) + + def close(self): + for f in self._file_handles: + f.close() + self._file_handles.clear() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() From 8bec0083176d09895ac422ad35fa165428eb6eea Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 00:35:42 +0800 Subject: [PATCH 12/30] update --- swift/megatron/argument/megatron_args.py | 2 + swift/megatron/convert.py | 22 +- swift/megatron/model/gpt_bridge.py | 320 ++++++++++++++--------- swift/megatron/trainers/base.py | 12 +- swift/megatron/utils/__init__.py | 1 + swift/megatron/utils/io_utils.py | 153 +++++++++++ swift/megatron/utils/lazy_tensor.py | 74 ------ 7 files changed, 373 insertions(+), 211 deletions(-) create mode 100644 swift/megatron/utils/io_utils.py delete mode 100644 swift/megatron/utils/lazy_tensor.py diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 7b9a2efdd2..62f1164605 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -125,6 +125,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): layer_types: Optional[List[str]] = None # qwen3_vl, qwen3_omni mrope_interleaved: Optional[bool] = None + # hf saver + max_shard_size: str = '5GB' @staticmethod def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]: diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 7d94f96fc0..4135795b92 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -256,7 +256,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: mg_model = megatron_model_meta.model_provider() logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() - bridge.load_state_dict(mg_model, bridge.convert_hf2mcore(hf_model.state_dict())) + bridge.load_weights(mg_model, args.model_info.model_dir) if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) del hf_model @@ -310,24 +310,14 @@ def convert_mcore2hf(args: ExportArguments) -> None: mg_model = peft_model.merge_and_unload() logger.info('Megatron model created successfully.') if args.to_hf: - hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) bridge = megatron_model_meta.bridge_cls() - bridge.load_state_dict(hf_model, bridge.convert_mcore2hf(mg_model.state_dict())) + logger.info('Converting weights and saving the model...') + bridge.save_weights([mg_model], args.output_dir) + logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') if args.test_convert_precision: + args.model = args.output_dir + hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - del mg_model - logger.info('Successfully transferred MG model weights to HF model.') - ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load - logger.info('Saving the model...') - save_checkpoint( - hf_model, - processor, - args.output_dir, - safe_serialization=args.safe_serialization, - model_dirs=[ckpt_dir, args.model_dir], - max_shard_size=args.max_shard_size, - additional_saved_files=hf_model.model_meta.additional_saved_files) - logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') elif args.to_mcore: if args.thread_count is None: checkpoint_size = 
sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9 diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index f0a869cde9..cc264fca90 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1,11 +1,14 @@ -from typing import Dict, Optional +from functools import partial +from typing import Dict, Literal, Optional, Union import torch +from megatron.core import mpu from megatron.training import get_args from tqdm import tqdm -from swift.llm import deep_getattr, get_model_tokenizer +from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint from swift.utils import disable_safe_ddp_context_use_barrier +from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver class GPTBridge: @@ -15,17 +18,82 @@ def __init__(self): self.args = get_args() model_info = self.args.model_info with torch.device('meta'), disable_safe_ddp_context_use_barrier(): - self.hf_model, _ = get_model_tokenizer( + self.hf_model, self.processor = get_model_tokenizer( model_info.model_dir, model_type=model_info.model_type, return_dummy_model=True) self.hf_layers = deep_getattr(self.hf_model, self.lm_layers_prefix) + self.tp_size = self.args.tensor_model_parallel_size + self.pp_size = self.args.pipeline_model_parallel_size + self.etp_size = self.args.expert_tensor_parallel_size + self.ep_size = self.args.expert_model_parallel_size - def _set_state_dict(self, state_dict, res_state_dict, hf_key: str, mg_key: str, reverse: bool, offset: float = 0): - src_key, tgt_key = hf_key, mg_key - if reverse: - src_key, tgt_key = tgt_key, src_key - res_state_dict[tgt_key] = state_dict[src_key] + self.tp_group = mpu.get_tensor_model_parallel_group() + self.pp_group = mpu.get_pipeline_model_parallel_group() + self.etp_group = mpu.get_expert_tensor_parallel_group() + self.ep_group = mpu.get_expert_model_parallel_group() + + self.tp_rank = mpu.get_tensor_model_parallel_rank() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + self.etp_rank = mpu.get_expert_tensor_parallel_rank() + self.ep_rank = mpu.get_expert_model_parallel_group() + + @staticmethod + def _get_tp_split_dim(mg_key: str) -> Optional[int]: + key, suffix = mg_key.rsplit('.', 2)[-2:] + if suffix == 'layer_norm_weight': + return + if key in {'word_embeddings', 'output_layer', 'linear_qkv', 'linear_fc1'}: + return 0 + elif key in {'linear_proj', 'linear_fc2'}: + return 1 + + def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): + tp_dim = self._get_tp_split_dim(mg_key) + hf_weight = hf_weight.to(mg_param.device) + if tp_dim is not None and self.tp_size > 1: + if self.tp_rank == 0: + splited_weights = list(hf_weight.chunk(self.tp_size, dim=tp_dim)) + else: + splited_weights = None + tensor = torch.empty_like(mg_param) + torch.distributed.scatter( + tensor, + splited_weights, + src=0, + group=self.tp_group, + ) + else: + tensor = hf_weight + if offset: + tensor = tensor + offset + mg_param.data.copy_(tensor) + + def _get_weights(self, mg_weight, mg_key, offset: int = 0): + tp_dim = self._get_tp_split_dim(mg_key) + if tp_dim is not None and self.tp_size > 1: + gather_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] + torch.distributed.gather( + mg_weight, + gather_list, + dst=0, + group=self.tp_group, + ) + tensor = torch.cat(gather_list, dim=tp_dim) + else: + tensor = mg_weight if offset: - res_state_dict[tgt_key] = res_state_dict[tgt_key] + offset + tensor = tensor + offset + return tensor + + def 
_set_state_dict(self, mg_module, mg_key: str, hf_state_dict, hf_key: str, reverse: bool, offset: float = 0): + mg_param = deep_getattr(mg_module, mg_key) + if reverse: + hf_state_dict[hf_key] = self._get_weights(mg_param.data, mg_key, offset) + else: + if mg_param is None: + assert self.pp_size > 1, f'mg_module: {mg_module}, mg_key: {mg_key}' + return + hf_weight = hf_state_dict[hf_key].load() + self._set_weights(mg_param, hf_weight, mg_key, offset) @staticmethod def _remove_prefix(state_dict, prefix: str): @@ -60,52 +128,55 @@ def _is_moe(state_dict): return True return False - def _set_attn_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): - src_prefix, tgt_prefix = hf_prefix, mg_prefix + def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool): if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) + hf_state_dict = {} + else: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_attn = self.hf_layers[layer_idx].self_attn args = self.args - res = {} num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if reverse: - mg_attn_weight = state_dict['linear_qkv.weight'].reshape((num_query_groups, -1, args.hidden_size)) + mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ 0] // num_query_groups - res['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) - res['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) - res['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) else: - res['linear_qkv.weight'] = torch.cat([ - state_dict['q_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), - state_dict['k_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), - state_dict['v_proj.weight'].reshape((num_query_groups, -1, args.hidden_size)), + linear_qkv_weight = torch.cat([ + hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), ], - dim=1).reshape((-1, args.hidden_size)) - self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + dim=1).reshape((-1, args.hidden_size)) + self._set_weights(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', reverse) # Copy bias if args.add_qkv_bias: if reverse: - mg_attn_bias = state_dict['linear_qkv.bias'].reshape((num_query_groups, -1)) - res['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) - res['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) - res['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) + mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) + hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) + hf_state_dict['k_proj.bias'] = 
mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) else: - res['linear_qkv.bias'] = torch.cat([ - state_dict['q_proj.bias'].reshape((num_query_groups, -1)), - state_dict['k_proj.bias'].reshape((num_query_groups, -1)), - state_dict['v_proj.bias'].reshape((num_query_groups, -1)), + linear_qkv_bias = torch.cat([ + hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), ], - dim=1).reshape(-1) + dim=1).reshape(-1) + self._set_weights(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') + if args.qk_layernorm: hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' self._set_state_dict(state_dict, res, hf_q_norm_key, 'q_layernorm.weight', reverse) self._set_state_dict(state_dict, res, hf_k_norm_key, 'k_layernorm.weight', reverse) - - return self._add_prefix(res, tgt_prefix) + if reverse: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): src_prefix, tgt_prefix = hf_prefix, mg_prefix @@ -135,53 +206,57 @@ def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: def _set_mlp_state( self, - state_dict, + mg_mlp, + hf_state_dict, hf_prefix: str, - mg_prefix: str, layer_idx: int, reverse: bool, group_idx: Optional[int] = None, ): - src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) + hf_state_dict = {} + else: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_mlp = self.hf_layers[layer_idx].mlp hf_grouped = False - if group_idx is not None and not hasattr(hf_mlp.experts, '__len__'): - hf_grouped = True - res = {} - # Determines the keys for fc1 and fc2 in megatron - if group_idx is None: - fc1_key = 'linear_fc1.weight' - fc2_key = 'linear_fc2.weight' - else: + if group_idx is not None: + if not hasattr(hf_mlp.experts, '__len__'): + hf_grouped = True fc1_key = f'linear_fc1.weight{group_idx}' fc2_key = f'linear_fc2.weight{group_idx}' + else: + fc1_key = 'linear_fc1.weight' + fc2_key = 'linear_fc2.weight' if hf_grouped: - res[fc1_key] = state_dict['gate_up_proj'][group_idx].t() - res[fc2_key] = state_dict['down_proj'][group_idx].t() + res[fc1_key] = hf_state_dict['gate_up_proj'][group_idx].t() + res[fc2_key] = hf_state_dict['down_proj'][group_idx].t() else: if hasattr(hf_mlp, 'gate_up_proj'): - self._set_state_dict(state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) + self._set_state_dict(hf_state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) else: if reverse: - ffn_hidden_size = state_dict[fc1_key].shape[0] // 2 - res['gate_proj.weight'] = state_dict[fc1_key][:ffn_hidden_size] - res['up_proj.weight'] = state_dict[fc1_key][ffn_hidden_size:] + ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] + fc1_weight = deep_getattr(mg_mlp, fc1_key) + hf_state_dict['gate_proj.weight'] = fc1_weight[:ffn_hidden_size] + hf_state_dict['up_proj.weight'] = fc1_weight[ffn_hidden_size:] else: - res[fc1_key] = torch.cat([ - state_dict['gate_proj.weight'], - state_dict['up_proj.weight'], - ], dim=0) - self._set_state_dict(state_dict, res, 'down_proj.weight', fc2_key, reverse) - 
return self._add_prefix(res, tgt_prefix) + fc1_weight = torch.cat([ + hf_state_dict['gate_proj.weight'].load(), + hf_state_dict['up_proj.weight'].load(), + ], + dim=0) + self._set_weights(deep_getattr(mg_mlp, fc1_key), fc1_weight, 'linear_qkv.weight') + self._set_state_dict(mg_mlp, fc2_key, hf_state_dict, 'down_proj.weight', reverse) + if reverse: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict def _set_mla_attn_state( self, + mg_model, + mg_prefix: str, state_dict, hf_prefix: str, - mg_prefix: str, layer_idx: int, reverse: bool, ): @@ -203,82 +278,91 @@ def _set_mla_attn_state( reverse) return self._add_prefix(res, tgt_prefix) - def _set_layer_attn(self, state_dict, layer_idx: int, reverse: bool): - res = {} + def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): + mg_attn = mg_layer.self_attention if self.args.multi_latent_attention: - res.update(self._set_mla_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, reverse)) - self._set_state_dict(state_dict, res, 'input_layernorm.weight', 'input_layernorm.weight', reverse) + hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, reverse)) + self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', reverse) else: - res.update(self._set_attn_state(state_dict, 'self_attn.', 'self_attention.', layer_idx, reverse)) - self._set_state_dict(state_dict, res, 'input_layernorm.weight', - 'self_attention.linear_qkv.layer_norm_weight', reverse) - return res + hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, reverse)) + self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict, + 'input_layernorm.weight', reverse) + return hf_state_dict - def _set_layer_mlp(self, state_dict, layer_idx: int, reverse: bool): + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): hf_mlp = self.hf_layers[layer_idx].mlp - res = {} is_moe = self._is_moe(hf_mlp.state_dict()) + mg_mlp = mg_layer.mlp if is_moe: - res.update(self._set_moe_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) - self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'pre_mlp_layernorm.weight', + hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, reverse)) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', reverse) else: - res.update(self._set_mlp_state(state_dict, 'mlp.', 'mlp.', layer_idx, reverse)) - self._set_state_dict(state_dict, res, 'post_attention_layernorm.weight', 'mlp.linear_fc1.layer_norm_weight', - reverse) - return res + hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, reverse)) + self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, + 'post_attention_layernorm.weight', reverse) + return hf_state_dict - def _set_layer_state(self, state_dict, layer_idx: int, hf_prefix: str, mg_prefix: str, reverse: bool): + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' - mg_prefix = f'{mg_prefix}{layer_idx}.' 
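# A small sketch of the gate/up fusion handled by _set_mlp_state above: HF keeps
# mlp.gate_proj.weight and mlp.up_proj.weight separate, while Megatron fuses them
# into mlp.linear_fc1.weight. Shapes are illustrative only.
import torch

hidden_size, ffn_hidden_size = 8, 32
gate_proj = torch.randn(ffn_hidden_size, hidden_size)  # hf: mlp.gate_proj.weight
up_proj = torch.randn(ffn_hidden_size, hidden_size)    # hf: mlp.up_proj.weight

# hf -> mg: concatenate along dim 0 -> [2 * ffn_hidden_size, hidden_size]
linear_fc1 = torch.cat([gate_proj, up_proj], dim=0)
# mg -> hf: split back at ffn_hidden_size
assert torch.equal(linear_fc1[:ffn_hidden_size], gate_proj)
assert torch.equal(linear_fc1[ffn_hidden_size:], up_proj)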
- src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) - res = self._set_layer_attn(state_dict, layer_idx, reverse) - res.update(self._set_layer_mlp(state_dict, layer_idx, reverse)) - return self._add_prefix(res, tgt_prefix) + hf_state_dict = {} + else: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, reverse)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, reverse)) + if reverse: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict - def _convert(self, state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): + def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): """reverse: False: hf -> mg; True: mg -> hf""" - src_prefix, tgt_prefix = hf_prefix, mg_prefix if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) - res = {} - self._set_state_dict(state_dict, res, 'model.embed_tokens.weight', 'embedding.word_embeddings.weight', reverse) + hf_state_dict = {} + else: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, 'model.embed_tokens.weight', + reverse) if self.args.untie_embeddings_and_output_weights: hf_lm_head_key = 'lm_head.weight' if reverse and self.args.task_type == 'seq_cls': hf_lm_head_key = 'score.weight' - self._set_state_dict(state_dict, res, hf_lm_head_key, 'output_layer.weight', reverse) - self._set_state_dict(state_dict, res, 'model.norm.weight', 'decoder.final_layernorm.weight', reverse) - for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): - res.update(self._set_layer_state(state_dict, layer_idx, 'model.layers.', 'decoder.layers.', reverse)) - return self._add_prefix(res, tgt_prefix) - - def convert_hf2mcore(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - return self._convert(state_dict, '', '', False) - - def convert_mcore2hf(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - return self._convert(state_dict, '', '', True) + self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) + self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', reverse) + if reverse: + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + else: + yield - def load_state_dict(self, model, state_dict) -> None: - """The model can be either hf_model or mg_model""" - incompatible_keys = model.load_state_dict(state_dict, strict=False) - missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')] - assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}' - assert len(missing_keys) == 0, f'missing_keys: {missing_keys}' + for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + mg_layer = mg_model.decoder.layers[layer_idx] + hf_state_dict = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) + if reverse: + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + else: + yield - def load_from_hf_checkpoint(self, mg_model, hf_model_dir: str) -> None: - """按照mg_model的模型结构, 加载需要的参数,并进行scatter""" - print() + def 
load_weights(self, mg_model, hf_model_dir: str) -> None: + with SafetensorLazyLoader(hf_model_dir) as loader: + state_dict = loader.get_state_dict() + list(self._convert(mg_model, state_dict, '', False)) - def get_hf_state_dict(self, mg_models) -> Dict[str, torch.Tensor]: - """获取完整的hf state_dict""" - print() + def export_weights(self, mg_models): + state_dict = {} + for mg_model in mg_models: + yield from self._convert(mg_model, state_dict, '', True) - def save_hf_checkpoint(self, mg_models, output_dir: str) -> None: - """保存mg_model的hf格式checkpoint""" - state_dict = get_hf_state_dict(mg_models) - # rank0 save() + def save_weights(self, mg_models, output_dir: str) -> None: + """Save the mg_model checkpoint in HF format""" + saver = StreamingSafetensorSaver(save_dir=output_dir, max_shard_size=self.args.max_shard_size) + for k, v in self.export_weights(mg_models): + saver.add_tensor(k, v) + saver.finalize() + # TODO: new_special_tokens + self.hf_model.config.save_pretrained(output_dir) + save_checkpoint( + None, + self.processor, + output_dir, + model_dirs=[self.hf_model.model_info.model_dir], + additional_saved_files=self.hf_model.model_meta.additional_saved_files) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 5ea4613846..f480889ebd 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -47,7 +47,7 @@ def __init__(self, args, template): self.stimer = StragglerDetector() self.unwrapped_models = [] self.peft_models = [] - self.bridge = args.megatron_model_meta.bridge_cls() + self._bridge = None logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last') # for evaluate @@ -62,6 +62,12 @@ def _get_mean_metric(): } self.megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + @property + def bridge(self): + if self._bridge is None: + self._bridge = self.args.megatron_model_meta.bridge_cls() + return self._bridge + @contextmanager def _get_iters(self, train_dataset, val_dataset): origin_initialize_megatron = training.initialize_megatron @@ -257,7 +263,7 @@ def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **k def new_model_provider_func(*_args, **kwargs): model = model_provider_func(*_args, **kwargs) if args.load_hf_checkpoint: - bridge.load_from_hf_checkpoint(model, args.model_info.model_dir) + self.bridge.load_weights(model, args.model_info.model_dir) self.unwrapped_models.append(model) self.peft_models.append(prepare_mcore_model(model)) return model @@ -732,7 +738,7 @@ def save_checkpoint(self, iteration, *_args, **kwargs): args = get_args() if args.save_hf_checkpoint: ouput_dir = os.path.join(args.save, f'checkpoint-{iteration}') - bridge.save_hf_checkpoint(self.unwrapped_models, ouput_dir) + self.bridge.save_weights(self.unwrapped_models, ouput_dir) else: with adapter_state_dict_context(): return self._origin_save_checkpoint(iteration, *_args, **kwargs) diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 91fbd0fbf8..4bccac1021 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
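# Condensed usage of the bridge API as wired up in convert.py above; an initialized
# Megatron environment, `megatron_model_meta` and `mg_model` are assumed to exist,
# and the directory paths are placeholders.
bridge = megatron_model_meta.bridge_cls()
# HF safetensors -> Megatron parameters (lazily loaded, scattered across TP ranks)
bridge.load_weights(mg_model, '/path/to/hf_model_dir')
# Megatron parameters -> sharded HF safetensors plus config/processor files
bridge.save_weights([mg_model], '/path/to/output_dir')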
from .config import convert_hf_config +from .io_utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .patcher import patch_torch_dist_shard from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py new file mode 100644 index 0000000000..ef53c7e5c3 --- /dev/null +++ b/swift/megatron/utils/io_utils.py @@ -0,0 +1,153 @@ +import os +from functools import partial + +import json +from safetensors.torch import save_file, safe_open +from swift.utils import is_master + +class LazyTensor: + + def __init__(self, tensor=None, loader=None): + """You need to provide a tensor or loader""" + self.tensor = tensor + self.loader = loader + + def load(self): + if self.tensor is None: + return self.loader() + return self.tensor + + +class SafetensorLazyLoader: + + def __init__(self, hf_model_dir: str): + self.hf_model_dir = hf_model_dir + self._weight_map = {} + self._file_handles = {} + self._load_index() + + def _open_file(self, filename: str): + """Open a safetensors file if not already open.""" + if filename not in self._file_handles: + file_path = os.path.join(self.hf_model_dir, filename) + self._file_handles[filename] = safe_open(file_path, framework='pt') + return self._file_handles[filename] + + def _load_index(self): + """Load the model index file to get weight map.""" + index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') + + if os.path.exists(index_path): + with open(index_path, 'r') as f: + self._index_file = json.load(f) + self._weight_map = self._index_file.get('weight_map', {}) + else: + # Single file model + safetensors_file = os.path.join(self.hf_model_dir, 'model.safetensors') + if os.path.exists(safetensors_file): + with safe_open(safetensors_file, framework='pt') as f: + for key in f.keys(): + self._weight_map[key] = 'model.safetensors' + + def get_state_dict(self): + res = {} + for k in self._weight_map.keys(): + res[k] = LazyTensor(loader=partial(self._load_tensor, key=k)) + return res + + def _load_tensor(self, key): + filename = self._weight_map[key] + file_handle = self._open_file(filename) + return file_handle.get_tensor(key) + + def close(self): + self._file_handles.clear() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + + + +class StreamingSafetensorSaver: + def __init__(self, save_dir, max_shard_size='5GB') -> None: + if not is_master(): + return + # max_shard_size: GiB + self.save_dir = save_dir + os.makedirs(save_dir, exist_ok=True) + if isinstance(max_shard_size, str): + if max_shard_size.endswith('GB'): + max_shard_size = int(max_shard_size[:-2]) + else: + raise ValueError(f"Invalid max_shard_size: {max_shard_size}") + self.max_shard_size = max_shard_size * 1000 ** 3 + self.current_shard = {} + self.current_shard_size = 0 + self.total_size = 0 + self.shard_index = 1 + self.weight_map = {} + + def add_tensor(self, name, tensor): + if not is_master(): + return + tensor_size = tensor.numel() * tensor.element_size() + if self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard: + self._save_current_shard() + + self.current_shard[name] = tensor.cpu().contiguous() + self.current_shard_size += tensor_size + + def _save_current_shard(self, shard_filename: str = None): + if not self.current_shard: + return + if shard_filename is None: + shard_filename = f'model-{self.shard_index:05d}-of-?????.safetensors' + 
shard_path = os.path.join(self.save_dir, shard_filename) + save_file(self.current_shard, str(shard_path)) + for key in self.current_shard.keys(): + self.weight_map[key] = shard_filename + + self.total_size += self.current_shard_size + self.current_shard = {} + self.current_shard_size = 0 + self.shard_index += 1 + + def finalize(self): + if not is_master(): + return + if self.current_shard: + self._save_current_shard() + total_shards = self.shard_index - 1 + # rename `?????` + for i in range(1, total_shards + 1): + old_path = os.path.join(self.save_dir, f"model-{i:05d}-of-?????.safetensors") + if total_shards == 1: + new_name = f'model.safetensors' + else: + new_name = f"model-{i:05d}-of-{total_shards:05d}.safetensors" + new_path = os.path.join(self.save_dir, new_name) + if os.path.exists(old_path): + os.rename(old_path, new_path) + + updated_weight_map = {} + for key, filename in self.weight_map.items(): + new_filename = filename.replace('?????', f'{total_shards:05d}') + updated_weight_map[key] = new_filename + + self._save_index(updated_weight_map) + + def _save_index(self, weight_map): + index = { + "metadata": { + "total_size": self.total_size + }, + "weight_map": weight_map + } + + index_path = os.path.join(self.save_dir, "model.safetensors.index.json") + with open(index_path, 'w') as f: + json.dump(index, f, indent=2) diff --git a/swift/megatron/utils/lazy_tensor.py b/swift/megatron/utils/lazy_tensor.py deleted file mode 100644 index f594b26dbf..0000000000 --- a/swift/megatron/utils/lazy_tensor.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -from functools import partial - -import json -import safetensors.torch - - -class LazyTensor: - - def __init__(self, tensor=None, loader=None): - """You need to provide a tensor or loader""" - self.tensor = tensor - self.loader = loader - - def load(self): - if self.tensor is None: - self.tensor = self.loader() - self.loader = None - return self.tensor - - -class SafetensorsLazyLoader: - - def __init__(self, hf_model_dir: str): - self.hf_model_dir = hf_model_dir - self._weight_map = {} - self._file_handles = {} - self._load_index() - - def _open_file(self, filename: str): - """Open a safetensors file if not already open.""" - if filename not in self._file_handles: - file_path = os.path.join(self.hf_model_dir, filename) - self._file_handles[filename] = safetensors.torch.safe_open(file_path, framework='pt') - return self._file_handles[filename] - - def _load_index(self): - """Load the model index file to get weight map.""" - index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json') - - if os.path.exists(index_path): - with open(index_path, 'r') as f: - self._index_file = json.load(f) - self._weight_map = self._index_file.get('weight_map', {}) - else: - # Single file model - safetensors_file = os.path.join(self.hf_model_dir, 'model.safetensors') - if os.path.exists(safetensors_file): - # All weights are in single file - with safetensors.torch.safe_open(safetensors_file, framework='pt') as f: - for key in f.keys(): - self._weight_map[key] = 'model.safetensors' - - def get_state_dict(self): - res = {} - for k in self._weight_map.keys(): - res[k] = LazyTensor(loader=partial(self._load_tensor, key=k)) - return res - - def _load_tensor(self, key): - filename = self._weight_map[key] - file_handle = self._open_file(filename) - return file_handle.get_tensor(key) - - def close(self): - for f in self._file_handles: - f.close() - self._file_handles.clear() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): 
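# Minimal usage sketch of the StreamingSafetensorSaver defined above (assumed to run
# on the master rank): tensors are buffered into shards of at most max_shard_size,
# and finalize() renames the '?????' placeholders and writes model.safetensors.index.json.
import torch
from swift.megatron.utils import StreamingSafetensorSaver

saver = StreamingSafetensorSaver(save_dir='./hf-output', max_shard_size='5GB')
saver.add_tensor('model.embed_tokens.weight', torch.zeros(16, 8))
saver.add_tensor('model.norm.weight', torch.ones(8))
saver.finalize()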
- self.close() From d036165f6f333cb9e299e5e766ed339e20726e3e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 12:49:34 +0800 Subject: [PATCH 13/30] update --- swift/cli/_megatron/export.py | 5 ++ swift/cli/_megatron/main.py | 3 +- swift/cli/main.py | 4 +- swift/megatron/__init__.py | 6 +- swift/megatron/argument/__init__.py | 1 + swift/megatron/argument/export_args.py | 57 +++++++++++++++++++ swift/megatron/argument/megatron_base_args.py | 45 +++++++++++++++ swift/megatron/argument/train_args.py | 36 +----------- swift/megatron/convert.py | 46 ++++++++------- swift/megatron/export/__init__.py | 1 + swift/megatron/export/export.py | 54 ++++++++++++++++++ swift/megatron/model/gpt_bridge.py | 8 +-- swift/megatron/utils/io_utils.py | 28 ++++----- 13 files changed, 216 insertions(+), 78 deletions(-) create mode 100644 swift/cli/_megatron/export.py create mode 100644 swift/megatron/argument/export_args.py create mode 100644 swift/megatron/argument/megatron_base_args.py create mode 100644 swift/megatron/export/__init__.py create mode 100644 swift/megatron/export/export.py diff --git a/swift/cli/_megatron/export.py b/swift/cli/_megatron/export.py new file mode 100644 index 0000000000..3eca73ca8c --- /dev/null +++ b/swift/cli/_megatron/export.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.megatron import megatron_export_main + +if __name__ == '__main__': + megatron_export_main() diff --git a/swift/cli/_megatron/main.py b/swift/cli/_megatron/main.py index a83dcac8b1..0fc2de0c92 100644 --- a/swift/cli/_megatron/main.py +++ b/swift/cli/_megatron/main.py @@ -10,11 +10,12 @@ 'pt': 'swift.cli._megatron.pt', 'sft': 'swift.cli._megatron.sft', 'rlhf': 'swift.cli._megatron.rlhf', + 'export': 'swift.cli._megatron.export', } def cli_main(): - return swift_cli_main(ROUTE_MAPPING) + return swift_cli_main(ROUTE_MAPPING, is_megatron=True) if __name__ == '__main__': diff --git a/swift/cli/main.py b/swift/cli/main.py index 8924b03bab..99fa05bc73 100644 --- a/swift/cli/main.py +++ b/swift/cli/main.py @@ -53,7 +53,7 @@ def _compat_web_ui(argv): logger.warning('Please use `swift app`.') -def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: +def cli_main(route_mapping: Optional[Dict[str, str]] = None, is_megatron: bool = False) -> None: route_mapping = route_mapping or ROUTE_MAPPING argv = sys.argv[1:] _compat_web_ui(argv) @@ -62,7 +62,7 @@ def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None: file_path = importlib.util.find_spec(route_mapping[method_name]).origin torchrun_args = get_torchrun_args() python_cmd = sys.executable - if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}: + if not is_megatron and (torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}): args = [python_cmd, file_path, *argv] else: args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv] diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index a5a48f0897..9815a0ea87 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -13,18 +13,20 @@ if TYPE_CHECKING: from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main + from .export import megatron_export_main from .convert import convert_hf2mcore, convert_mcore2hf from .utils import prepare_mcore_model, adapter_state_dict_context - from .argument import MegatronTrainArguments, MegatronRLHFArguments + from .argument import MegatronTrainArguments, MegatronRLHFArguments, 
MegatronExportArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model from .trainers import MegatronTrainer, MegatronDPOTrainer from .tuners import LoraParallelLinear else: _import_structure = { 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], + 'export': ['megatron_export_main'], 'convert': ['convert_hf2mcore', 'convert_mcore2hf'], 'utils': ['prepare_mcore_model', 'adapter_state_dict_context'], - 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'], + 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'], 'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'], 'tuners': ['LoraParallelLinear'], diff --git a/swift/megatron/argument/__init__.py b/swift/megatron/argument/__init__.py index a2ad08daa3..e6d577e1ad 100644 --- a/swift/megatron/argument/__init__.py +++ b/swift/megatron/argument/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .export_args import MegatronExportArguments from .megatron_args import MegatronArguments from .rlhf_args import MegatronRLHFArguments from .train_args import MegatronTrainArguments diff --git a/swift/megatron/argument/export_args.py b/swift/megatron/argument/export_args.py new file mode 100644 index 0000000000..6f051213b9 --- /dev/null +++ b/swift/megatron/argument/export_args.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from dataclasses import dataclass + +from swift.llm import HfConfigFactory +from swift.llm.argument.base_args import to_abspath +from swift.utils import get_logger +from .megatron_base_args import MegatronBaseArguments + +logger = get_logger() + + +@dataclass +class MegatronExportArguments(MegatronBaseArguments): + to_hf: bool = False + to_mcore: bool = False + test_convert_precision: bool = False + test_convert_dtype: str = 'float32' + exist_ok: bool = False + + def _init_save(self): + if self.save is None: + ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}' + ckpt_dir, ckpt_name = os.path.split(ckpt_dir) + if self.to_mcore: + suffix = 'mcore' + elif self.to_hf: + suffix = 'hf' + self.save = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}') + + self.save = to_abspath(self.save) + if not self.exist_ok and os.path.exists(self.save): + raise FileExistsError(f'args.save: `{self.save}` already exists.') + logger.info(f'args.save: `{self.save}`') + + def __post_init__(self): + super().__post_init__() + self._init_save() + self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype) + if self.to_hf or self.to_mcore: + self._init_convert() + + def _init_convert(self): + convert_kwargs = { + 'no_save_optim': True, + 'no_save_rng': True, + 'no_load_optim': True, + 'no_load_rng': True, + 'finetune': True, + 'attention_backend': 'unfused', + 'device_map': 'cpu', + 'padding_free': False, + } + for k, v in convert_kwargs.items(): + setattr(self, k, v) + if self.model_info.is_moe_model: + self.moe_grouped_gemm = True diff --git a/swift/megatron/argument/megatron_base_args.py b/swift/megatron/argument/megatron_base_args.py new file mode 100644 index 0000000000..6f978469e4 --- /dev/null +++ b/swift/megatron/argument/megatron_base_args.py @@ -0,0 +1,45 @@ +import os +from dataclasses import dataclass + +from swift.llm import BaseArguments +from swift.utils import get_logger +from ..model import get_megatron_model_meta +from ..utils 
import convert_hf_config +from .megatron_args import MegatronArguments + +logger = get_logger() + + +@dataclass +class MegatronBaseArguments(MegatronArguments, BaseArguments): + + def __post_init__(self): + self.sequence_parallel_size = self.context_parallel_size + if self.packing: + self.padding_free = True + BaseArguments.__post_init__(self) + self.megatron_model_meta = get_megatron_model_meta(self.model_type) + self.seq_length = self.seq_length or self.packing_length or self.max_length + if self.streaming: + self.dataloader_type = 'external' + if self.num_workers > 1: + self.num_workers = 1 + logger.info('Using streaming dataset, setting args.num_workers to 1.') + + def init_model_args(self, tokenizer, config): + if self.task_type == 'seq_cls': + self.problem_type = self.problem_type or getattr(config, 'problem_type', None) + logger.info(f'args.problem_type: {self.problem_type}') + kwargs = convert_hf_config(config) + if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): + kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 + self.initialize_embedding = True + logger.info(f'megatron_config: {kwargs}') + for k, v in kwargs.items(): + if getattr(self, k) is None: + setattr(self, k, v) + MegatronArguments.__post_init__(self) + self.extra_args = self.parse_to_megatron() + self.extra_args['model_info'] = self.model_info + self.extra_args['model_meta'] = self.model_meta + self.extra_args['megatron_model_meta'] = self.megatron_model_meta diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index cf4676bf97..4c0b19ab55 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -8,36 +8,16 @@ from swift.llm import BaseArguments from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master -from ..model import get_megatron_model_meta -from ..utils import convert_hf_config -from .megatron_args import MegatronArguments +from .megatron_base_args import MegatronBaseArguments logger = get_logger() @dataclass -class MegatronTrainArguments(MegatronArguments, BaseArguments): +class MegatronTrainArguments(MegatronBaseArguments): add_version: bool = True load_args: bool = False - def init_model_args(self, tokenizer, config): - if self.task_type == 'seq_cls': - self.problem_type = self.problem_type or getattr(config, 'problem_type', None) - logger.info(f'args.problem_type: {self.problem_type}') - kwargs = convert_hf_config(config) - if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): - kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 - self.initialize_embedding = True - logger.info(f'megatron_config: {kwargs}') - for k, v in kwargs.items(): - if getattr(self, k) is None: - setattr(self, k, v) - MegatronArguments.__post_init__(self) - self.extra_args = self.parse_to_megatron() - self.extra_args['model_info'] = self.model_info - self.extra_args['model_meta'] = self.model_meta - self.extra_args['megatron_model_meta'] = self.megatron_model_meta - def _init_save(self): init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout) if self.save is None: @@ -60,22 +40,12 @@ def _init_ckpt_dir(self, adapters=None): self.model = old_args.get('model') def __post_init__(self): - self.sequence_parallel_size = self.context_parallel_size - if self.packing: - self.padding_free = True self.load = to_abspath(self.load, check_path_exist=True) - BaseArguments.__post_init__(self) - 
self.megatron_model_meta = get_megatron_model_meta(self.model_type) + super().__post_init__() if len(self.dataset) == 0 and len(self.cached_dataset) == 0: raise ValueError(f'self.dataset: {self.dataset}, self.cached_dataset: {self.cached_dataset}. ' 'Please input the training dataset.') self._init_save() - self.seq_length = self.seq_length or self.packing_length or self.max_length - if self.streaming: - self.dataloader_type = 'external' - if self.num_workers > 1: - self.num_workers = 1 - logger.info('Using streaming dataset, setting args.num_workers to 1.') if self.load is None and self.no_initialization and not self.load_hf_checkpoint: raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 4135795b92..3c3eb2e47c 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +from megatron.core import mpu from megatron.training import get_args from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint @@ -37,9 +38,10 @@ def _test_params_sum(model): logger.warning(f'n: {n}, sum: {sum_}') else: total_sum += sum_ - logger.info(f'n_parameter: {n_parameter}') - logger.info(f'total_sum: {total_sum}') - logger.info(f'zero_count: {zero_count}') + cond = mpu.get_data_parallel_rank() == 0 + logger.info_if(f'n_parameter: {n_parameter}', cond=cond) + logger.info_if(f'total_sum: {total_sum}', cond=cond) + logger.info_if(f'zero_count: {zero_count}', cond=cond) def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @@ -144,29 +146,29 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]: def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float32): - hf_model.eval() - mg_model.eval() - _test_params_sum(hf_model) + template.set_mode('train') _test_params_sum(mg_model) - template.set_mode('train') - template.register_post_encode_hook([hf_model]) is_multimodal = template.model_meta.is_multimodal inputs = get_examples(is_multimodal) inputs = template.encode(inputs) inputs = to_device(template.data_collator([inputs]), 'cuda') - - HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) mg_language_model = mg_model.language_model if is_multimodal else mg_model share_embedding = mg_language_model.share_embeddings_and_output_weights - model_arch = hf_model.model_meta.model_arch - ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] - - hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) - with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): - inputs.pop('text_position_ids', None) - hf_logits = hf_model(**inputs).logits - hf_model.to('cpu') + if hf_model is not None: + hf_model.eval() + _test_params_sum(hf_model) + template.register_post_encode_hook([hf_model]) + HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) + model_arch = hf_model.model_meta.model_arch + ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] + hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) + with torch.inference_mode(), _model_cpu_forward_context( + hf_modules, torch_dtype, share_embedding=share_embedding): + inputs.pop('text_position_ids', None) + hf_logits = hf_model(**inputs).logits + 
hf_logits = hf_logits.to('cuda') + hf_model.to('cpu') input_ids = inputs['input_ids'] attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True) @@ -186,6 +188,11 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): mg_logits = mg_model( input_ids=input_ids, attention_mask=attention_mask, packed_seq_params=packed_seq_params, **kwargs) + if args.tensor_model_parallel_size > 1: + from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region + mg_logits = gather_from_tensor_model_parallel_region(mg_logits) + if hf_model is None: + return args = get_args() if args.task_type == 'seq_cls': mg_logits = mg_logits[:, -1] @@ -257,10 +264,9 @@ def convert_hf2mcore(args: ExportArguments) -> None: logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() bridge.load_weights(mg_model, args.model_info.model_dir) + logger.info('Successfully transferred HF model weights to MG model.') if args.test_convert_precision: test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - del hf_model - logger.info('Successfully transferred HF model weights to MG model.') args.save_args() logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) diff --git a/swift/megatron/export/__init__.py b/swift/megatron/export/__init__.py new file mode 100644 index 0000000000..97590932f0 --- /dev/null +++ b/swift/megatron/export/__init__.py @@ -0,0 +1 @@ +from .export import megatron_export_main diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py new file mode 100644 index 0000000000..1546ccac91 --- /dev/null +++ b/swift/megatron/export/export.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import List, Optional, Union + +from megatron.training import initialize_megatron + +from swift.llm import SwiftPipeline, prepare_model_template +from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_master +from ..argument import MegatronExportArguments +from ..convert import test_convert_precision +from ..utils import prepare_mcore_model + +logger = get_logger() + + +class MegatronExport(SwiftPipeline): + args_class = MegatronExportArguments + args: args_class + + def run(self): + args = self.args + if args.to_hf: + self.convert_mcore2hf() + elif args.to_mcore: + self.convert_hf2mcore() + + def convert_mcore2hf(self) -> None: + print() + + def convert_hf2mcore(self) -> None: + args = self.args + _, template = prepare_model_template(args, load_model=False) + self.processor = template.processor + args.init_model_args(self.tokenizer, self.processor.model_info.config) + megatron_model_meta = args.megatron_model_meta + extra_args_provider = megatron_model_meta.extra_args_provider + initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) + + mg_model = megatron_model_meta.model_provider() + logger.info('Megatron model created successfully.') + bridge = megatron_model_meta.bridge_cls() + bridge.load_weights(mg_model, args.model_info.model_dir) + logger.info('Successfully transferred HF model weights to MG model.') + if args.test_convert_precision: + with disable_safe_ddp_context_use_barrier(): + hf_model = prepare_model_template(args, device_map='cpu')[0] if is_master() else None + test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + args.save_args() + logger.info('Saving the model...') + mg_save_checkpoint(1, [mg_model], None, None, 0) + logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') + + +def megatron_export_main(args: Optional[Union[List[str], MegatronExportArguments]] = None): + return MegatronExport(args).main() diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index cc264fca90..0db3f22321 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -51,7 +51,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): hf_weight = hf_weight.to(mg_param.device) if tp_dim is not None and self.tp_size > 1: if self.tp_rank == 0: - splited_weights = list(hf_weight.chunk(self.tp_size, dim=tp_dim)) + splited_weights = [t.contiguous() for t in hf_weight.chunk(self.tp_size, dim=tp_dim)] else: splited_weights = None tensor = torch.empty_like(mg_param) @@ -70,7 +70,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): def _get_weights(self, mg_weight, mg_key, offset: int = 0): tp_dim = self._get_tp_split_dim(mg_key) if tp_dim is not None and self.tp_size > 1: - gather_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] + gather_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] if self.tp_rank == 0 else None torch.distributed.gather( mg_weight, gather_list, @@ -336,9 +336,9 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): mg_layer = mg_model.decoder.layers[layer_idx] - hf_state_dict = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) if reverse: - yield from 
list(self._add_prefix(hf_state_dict, hf_prefix).items()) + yield from list(self._add_prefix(res, hf_prefix).items()) else: yield diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index ef53c7e5c3..5590aa5556 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -2,9 +2,11 @@ from functools import partial import json -from safetensors.torch import save_file, safe_open +from safetensors.torch import safe_open, save_file + from swift.utils import is_master + class LazyTensor: def __init__(self, tensor=None, loader=None): @@ -70,9 +72,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() - - class StreamingSafetensorSaver: + def __init__(self, save_dir, max_shard_size='5GB') -> None: if not is_master(): return @@ -83,8 +84,8 @@ def __init__(self, save_dir, max_shard_size='5GB') -> None: if max_shard_size.endswith('GB'): max_shard_size = int(max_shard_size[:-2]) else: - raise ValueError(f"Invalid max_shard_size: {max_shard_size}") - self.max_shard_size = max_shard_size * 1000 ** 3 + raise ValueError(f'Invalid max_shard_size: {max_shard_size}') + self.max_shard_size = max_shard_size * 1000**3 self.current_shard = {} self.current_shard_size = 0 self.total_size = 0 @@ -124,11 +125,11 @@ def finalize(self): total_shards = self.shard_index - 1 # rename `?????` for i in range(1, total_shards + 1): - old_path = os.path.join(self.save_dir, f"model-{i:05d}-of-?????.safetensors") + old_path = os.path.join(self.save_dir, f'model-{i:05d}-of-?????.safetensors') if total_shards == 1: - new_name = f'model.safetensors' + new_name = 'model.safetensors' else: - new_name = f"model-{i:05d}-of-{total_shards:05d}.safetensors" + new_name = f'model-{i:05d}-of-{total_shards:05d}.safetensors' new_path = os.path.join(self.save_dir, new_name) if os.path.exists(old_path): os.rename(old_path, new_path) @@ -141,13 +142,8 @@ def finalize(self): self._save_index(updated_weight_map) def _save_index(self, weight_map): - index = { - "metadata": { - "total_size": self.total_size - }, - "weight_map": weight_map - } - - index_path = os.path.join(self.save_dir, "model.safetensors.index.json") + index = {'metadata': {'total_size': self.total_size}, 'weight_map': weight_map} + + index_path = os.path.join(self.save_dir, 'model.safetensors.index.json') with open(index_path, 'w') as f: json.dump(index, f, indent=2) From 5d7233d2779dec8bd8e1b56cf50993d9fdb310e3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 14:05:33 +0800 Subject: [PATCH 14/30] update --- swift/megatron/convert.py | 2 +- swift/megatron/export/export.py | 3 ++- swift/megatron/model/gpt_bridge.py | 15 +++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 3c3eb2e47c..591b579e8d 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -182,6 +182,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float mg_language_model.config.fp8 = None # compat fp8 mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) kwargs = {k: v for k, v in inputs.items() if k not in ['input_ids', 'attention_mask', 'labels']} + args = get_args() if 'position_ids' not in kwargs: kwargs['position_ids'] = position_ids with torch.inference_mode(), _model_cpu_forward_context( @@ -193,7 +194,6 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float mg_logits = gather_from_tensor_model_parallel_region(mg_logits) if hf_model is None: return - 
args = get_args() if args.task_type == 'seq_cls': mg_logits = mg_logits[:, -1] mean_diff = (mg_logits - hf_logits).abs().mean().item() diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 1546ccac91..3edb85ebe8 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union from megatron.training import initialize_megatron +from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from swift.llm import SwiftPipeline, prepare_model_template from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_master @@ -44,7 +45,7 @@ def convert_hf2mcore(self) -> None: with disable_safe_ddp_context_use_barrier(): hf_model = prepare_model_template(args, device_map='cpu')[0] if is_master() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - args.save_args() + args.save_args(args.save) logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 0db3f22321..cdd1fd2987 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -46,7 +46,7 @@ def _get_tp_split_dim(mg_key: str) -> Optional[int]: elif key in {'linear_proj', 'linear_fc2'}: return 1 - def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): + def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_slices=()): tp_dim = self._get_tp_split_dim(mg_key) hf_weight = hf_weight.to(mg_param.device) if tp_dim is not None and self.tp_size > 1: @@ -54,7 +54,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): splited_weights = [t.contiguous() for t in hf_weight.chunk(self.tp_size, dim=tp_dim)] else: splited_weights = None - tensor = torch.empty_like(mg_param) + tensor = torch.empty_like(mg_param.data[mg_slices]) torch.distributed.scatter( tensor, splited_weights, @@ -65,7 +65,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0): tensor = hf_weight if offset: tensor = tensor + offset - mg_param.data.copy_(tensor) + mg_param.data[mg_slices].copy_(tensor) def _get_weights(self, mg_weight, mg_key, offset: int = 0): tp_dim = self._get_tp_split_dim(mg_key) @@ -245,7 +245,14 @@ def _set_mlp_state( hf_state_dict['up_proj.weight'].load(), ], dim=0) - self._set_weights(deep_getattr(mg_mlp, fc1_key), fc1_weight, 'linear_qkv.weight') + linear_fc1_weight = deep_getattr(mg_mlp, fc1_key) + gate_slices = (slice(None, linear_fc1_weight.shape[0] // 2), ) + up_slices = (slice(linear_fc1_weight.shape[0] // 2, None), ) + self._set_weights( + linear_fc1_weight, hf_state_dict['gate_proj.weight'].load(), fc1_key, mg_slices=gate_slices) + self._set_weights( + linear_fc1_weight, hf_state_dict['up_proj.weight'].load(), fc1_key, mg_slices=up_slices) + self._set_state_dict(mg_mlp, fc2_key, hf_state_dict, 'down_proj.weight', reverse) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) From b8c1746b6d6630cde0ec98a830da7d979db8c394 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 17:10:02 +0800 Subject: [PATCH 15/30] update --- swift/llm/argument/base_args/base_args.py | 15 +++++---- swift/megatron/argument/export_args.py | 1 - swift/megatron/convert.py | 5 ++- swift/megatron/export/export.py | 22 +++++++++++-- swift/megatron/model/gpt_bridge.py | 40 
++++++++++++----------- 5 files changed, 51 insertions(+), 32 deletions(-) diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index bc18ba573b..738a447e7c 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -306,12 +306,13 @@ def get_model_processor(self, **kwargs): if self.tuner_backend == 'unsloth': return load_by_unsloth(self) - kwargs.update(self.get_model_kwargs()) + res = self.get_model_kwargs() + res.update(kwargs) # compat rlhf - kwargs['model_id_or_path'] = model or self.model - kwargs['model_type'] = model_type or self.model_type - kwargs['model_revision'] = model_revision or self.model_revision - kwargs['task_type'] = task_type or self.task_type - kwargs['num_labels'] = num_labels or self.num_labels + res['model_id_or_path'] = model or self.model + res['model_type'] = model_type or self.model_type + res['model_revision'] = model_revision or self.model_revision + res['task_type'] = task_type or self.task_type + res['num_labels'] = num_labels or self.num_labels - return get_model_tokenizer(**kwargs) + return get_model_tokenizer(**res) diff --git a/swift/megatron/argument/export_args.py b/swift/megatron/argument/export_args.py index 6f051213b9..64cdd68aaf 100644 --- a/swift/megatron/argument/export_args.py +++ b/swift/megatron/argument/export_args.py @@ -48,7 +48,6 @@ def _init_convert(self): 'no_load_rng': True, 'finetune': True, 'attention_backend': 'unfused', - 'device_map': 'cpu', 'padding_free': False, } for k, v in convert_kwargs.items(): diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 591b579e8d..1208d60252 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -30,7 +30,7 @@ def _test_params_sum(model): n_parameter = 0 for n, p in model.named_parameters(): n_parameter += 1 - sum_ = p.cuda().float().abs().sum().cpu().item() + sum_ = p.to(device='cuda', dtype=torch.float32).abs().sum().cpu().item() if sum_ == 0: zero_count += 1 logger.warning(f'n: {n}, sum: {sum_}') @@ -321,8 +321,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: bridge.save_weights([mg_model], args.output_dir) logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.') if args.test_convert_precision: - args.model = args.output_dir - hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision) + hf_model, template = prepare_model_template(args, model=args.output_dir) test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) elif args.to_mcore: if args.thread_count is None: diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 3edb85ebe8..00554c1c1d 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union from megatron.training import initialize_megatron +from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from swift.llm import SwiftPipeline, prepare_model_template @@ -25,7 +26,24 @@ def run(self): self.convert_hf2mcore() def convert_mcore2hf(self) -> None: - print() + args = self.args + _, template = prepare_model_template(args, load_model=False) + self.processor = template.processor + args.init_model_args(self.tokenizer, self.processor.model_info.config) + megatron_model_meta = args.megatron_model_meta + extra_args_provider = megatron_model_meta.extra_args_provider + 
initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) + + mg_model = megatron_model_meta.model_provider() + load_checkpoint([mg_model], None, None, strict=True) + logger.info('Converting weights and saving the model...') + bridge = megatron_model_meta.bridge_cls() + bridge.save_weights([mg_model], args.save) + logger.info(f'Successfully saved HF model weights in `{args.save}`.') + if args.test_convert_precision: + with disable_safe_ddp_context_use_barrier(): + hf_model = prepare_model_template(args, model=args.save, device_map='cpu')[0] if is_master() else None + test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) def convert_hf2mcore(self) -> None: args = self.args @@ -48,7 +66,7 @@ def convert_hf2mcore(self) -> None: args.save_args(args.save) logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) - logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.') + logger.info(f'Successfully saved Megatron model weights in `{args.save}`.') def megatron_export_main(args: Optional[Union[List[str], MegatronExportArguments]] = None): diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index cdd1fd2987..6ebeefab22 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -7,7 +7,7 @@ from tqdm import tqdm from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint -from swift.utils import disable_safe_ddp_context_use_barrier +from swift.utils import disable_safe_ddp_context_use_barrier, is_master from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver @@ -70,14 +70,13 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli def _get_weights(self, mg_weight, mg_key, offset: int = 0): tp_dim = self._get_tp_split_dim(mg_key) if tp_dim is not None and self.tp_size > 1: - gather_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] if self.tp_rank == 0 else None - torch.distributed.gather( + tensor_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] + torch.distributed.all_gather( + tensor_list, mg_weight, - gather_list, - dst=0, group=self.tp_group, ) - tensor = torch.cat(gather_list, dim=tp_dim) + tensor = torch.cat(tensor_list, dim=tp_dim) else: tensor = mg_weight if offset: @@ -137,7 +136,8 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int args = self.args num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if reverse: - mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size)) + mg_attn_weight = self._get_weights(mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') + mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ 0] // num_query_groups hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) @@ -156,7 +156,8 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int # Copy bias if args.add_qkv_bias: if reverse: - mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1)) + mg_attn_bias = self._get_weights(mg_attn.linear_qkv.bias.data, 'linear_qkv.bias') + mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) hf_state_dict['k_proj.bias'] = 
mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) @@ -235,10 +236,10 @@ def _set_mlp_state( self._set_state_dict(hf_state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) else: if reverse: - ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0] fc1_weight = deep_getattr(mg_mlp, fc1_key) - hf_state_dict['gate_proj.weight'] = fc1_weight[:ffn_hidden_size] - hf_state_dict['up_proj.weight'] = fc1_weight[ffn_hidden_size:] + hf_state_dict['gate_proj.weight'] = self._get_weights(fc1_weight[:fc1_weight.shape[0] // 2], + fc1_key) + hf_state_dict['up_proj.weight'] = self._get_weights(fc1_weight[fc1_weight.shape[0] // 2:], fc1_key) else: fc1_weight = torch.cat([ hf_state_dict['gate_proj.weight'].load(), @@ -365,11 +366,12 @@ def save_weights(self, mg_models, output_dir: str) -> None: for k, v in self.export_weights(mg_models): saver.add_tensor(k, v) saver.finalize() - # TODO: new_special_tokens - self.hf_model.config.save_pretrained(output_dir) - save_checkpoint( - None, - self.processor, - output_dir, - model_dirs=[self.hf_model.model_info.model_dir], - additional_saved_files=self.hf_model.model_meta.additional_saved_files) + if is_master(): + # TODO: new_special_tokens + self.hf_model.config.save_pretrained(output_dir) + save_checkpoint( + None, + self.processor, + output_dir, + model_dirs=[self.hf_model.model_info.model_dir], + additional_saved_files=self.hf_model.model_meta.additional_saved_files) From 43830cb0b6ffff8c2da441bd3d9b162a445c66ce Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 17:11:32 +0800 Subject: [PATCH 16/30] update --- tests/megatron/export/test_export.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/megatron/export/test_export.py diff --git a/tests/megatron/export/test_export.py b/tests/megatron/export/test_export.py new file mode 100644 index 0000000000..fa31528950 --- /dev/null +++ b/tests/megatron/export/test_export.py @@ -0,0 +1,27 @@ +from swift.megatron import MegatronExportArguments, megatron_export_main + + +def test_to_mcore(): + megatron_export_main( + MegatronExportArguments( + model='Qwen/Qwen2.5-7B-Instruct', + save='Qwen2.5-7B-Instruct-mcore', + to_mcore=True, + exist_ok=True, + tensor_model_parallel_size=2, + test_convert_precision=True)) + + +def test_to_hf(): + megatron_export_main( + MegatronExportArguments( + load='Qwen2.5-7B-Instruct-mcore', + to_hf=True, + exist_ok=True, + tensor_model_parallel_size=2, + test_convert_precision=True)) + + +if __name__ == '__main__': + # test_to_mcore() + test_to_hf() From 9a968b5dc2d2eb2a338766152e273b7d58dc5f4a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 18:14:13 +0800 Subject: [PATCH 17/30] update --- swift/megatron/argument/train_args.py | 4 +-- swift/megatron/convert.py | 12 ++++---- swift/megatron/export/export.py | 20 +++++++++---- swift/megatron/model/gpt_bridge.py | 43 +++++++++++++++++++-------- swift/megatron/train/sft.py | 4 +-- swift/megatron/trainers/rlhf_mixin.py | 29 ------------------ swift/megatron/utils/__init__.py | 2 +- swift/megatron/utils/io_utils.py | 15 +++++----- swift/megatron/utils/utils.py | 32 ++++++++++++++++++++ 9 files changed, 98 insertions(+), 63 deletions(-) diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 4c0b19ab55..2d3acb1313 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -7,7 +7,7 @@ from swift.llm import BaseArguments from 
swift.llm.argument.base_args import to_abspath -from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master +from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_last_rank from .megatron_base_args import MegatronBaseArguments logger = get_logger() @@ -26,7 +26,7 @@ def _init_save(self): if self.add_version: self.save = add_version_to_work_dir(self.save) logger.info(f'args.save: {self.save}') - if is_master(): + if is_last_rank(): os.makedirs(self.save, exist_ok=True) def _init_ckpt_dir(self, adapters=None): diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 1208d60252..d95c9ff3ab 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -19,7 +19,7 @@ from swift.utils import get_logger, get_n_params_grads from .argument import MegatronArguments from .model import get_megatron_model_meta -from .utils import convert_hf_config, patch_torch_dist_shard +from .utils import convert_hf_config, forward_step_helper, patch_torch_dist_shard logger = get_logger() @@ -67,7 +67,7 @@ def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @contextmanager def _model_cpu_forward_context(modules, torch_dtype=None, device=None, share_embedding: bool = False): - origin_torch_dtype = next(modules[0].parameters()).dtype + origin_torch_dtype = next(modules[-1].parameters()).dtype def _to_cuda_hook(module, args): if device is not None or torch_dtype is not None: @@ -181,14 +181,16 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float # attention_mask = None mg_language_model.config.fp8 = None # compat fp8 mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) - kwargs = {k: v for k, v in inputs.items() if k not in ['input_ids', 'attention_mask', 'labels']} + kwargs = inputs.copy() + kwargs.pop('labels') + kwargs.update({'input_ids': input_ids, 'attention_mask': attention_mask, 'packed_seq_params': packed_seq_params}) args = get_args() if 'position_ids' not in kwargs: kwargs['position_ids'] = position_ids with torch.inference_mode(), _model_cpu_forward_context( mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): - mg_logits = mg_model( - input_ids=input_ids, attention_mask=attention_mask, packed_seq_params=packed_seq_params, **kwargs) + # TODO: test pp tie_weights + mg_logits = forward_step_helper(mg_model, kwargs, dtype=mg_torch_dtype) if args.tensor_model_parallel_size > 1: from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region mg_logits = gather_from_tensor_model_parallel_region(mg_logits) diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 00554c1c1d..7717876b36 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -1,12 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import List, Optional, Union +import torch.distributed as dist +from megatron.core import mpu from megatron.training import initialize_megatron from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from swift.llm import SwiftPipeline, prepare_model_template -from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_master +from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank from ..argument import MegatronExportArguments from ..convert import test_convert_precision from ..utils import prepare_mcore_model @@ -34,15 +36,19 @@ def convert_mcore2hf(self) -> None: extra_args_provider = megatron_model_meta.extra_args_provider initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) - mg_model = megatron_model_meta.model_provider() + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) load_checkpoint([mg_model], None, None, strict=True) logger.info('Converting weights and saving the model...') bridge = megatron_model_meta.bridge_cls() bridge.save_weights([mg_model], args.save) logger.info(f'Successfully saved HF model weights in `{args.save}`.') + dist.barrier() if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): - hf_model = prepare_model_template(args, model=args.save, device_map='cpu')[0] if is_master() else None + hf_model = prepare_model_template( + args, model=args.save, device_map='cpu')[0] if is_last_rank() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) def convert_hf2mcore(self) -> None: @@ -54,15 +60,19 @@ def convert_hf2mcore(self) -> None: extra_args_provider = megatron_model_meta.extra_args_provider initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args.extra_args) - mg_model = megatron_model_meta.model_provider() + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() bridge.load_weights(mg_model, args.model_info.model_dir) logger.info('Successfully transferred HF model weights to MG model.') + dist.barrier() if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): - hf_model = prepare_model_template(args, device_map='cpu')[0] if is_master() else None + hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + dist.barrier() args.save_args(args.save) logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 6ebeefab22..a9ec4d6b9a 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -7,7 +7,7 @@ from tqdm import tqdm from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint -from swift.utils import disable_safe_ddp_context_use_barrier, is_master +from swift.utils import disable_safe_ddp_context_use_barrier, is_last_rank from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver @@ -83,14 +83,25 @@ def _get_weights(self, mg_weight, 
mg_key, offset: int = 0): tensor = tensor + offset return tensor - def _set_state_dict(self, mg_module, mg_key: str, hf_state_dict, hf_key: str, reverse: bool, offset: float = 0): + def _set_state_dict(self, + mg_module, + mg_key: str, + hf_state_dict, + hf_key: str, + reverse: bool, + offset: float = 0, + pre_process: bool = False, + post_process: bool = False): mg_param = deep_getattr(mg_module, mg_key) if reverse: hf_state_dict[hf_key] = self._get_weights(mg_param.data, mg_key, offset) else: - if mg_param is None: - assert self.pp_size > 1, f'mg_module: {mg_module}, mg_key: {mg_key}' - return + if mg_param is None and self.pp_size > 1: + if pre_process and not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_module.vp_stage): + return + elif post_process and not mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_module.vp_stage): + return + assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' hf_weight = hf_state_dict[hf_key].load() self._set_weights(mg_param, hf_weight, mg_key, offset) @@ -329,22 +340,30 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): hf_state_dict = {} else: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, 'model.embed_tokens.weight', - reverse) + self._set_state_dict( + mg_model, + 'embedding.word_embeddings.weight', + hf_state_dict, + 'model.embed_tokens.weight', + reverse, + pre_process=True) if self.args.untie_embeddings_and_output_weights: hf_lm_head_key = 'lm_head.weight' if reverse and self.args.task_type == 'seq_cls': hf_lm_head_key = 'score.weight' - self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) - self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', reverse) + self._set_state_dict( + mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse, post_process=True) + self._set_state_dict( + mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', reverse, post_process=True) if reverse: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) else: yield - for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + for layer_idx in tqdm(range(len(mg_model.decoder.layers)), dynamic_ncols=True, desc='Converting: '): mg_layer = mg_model.decoder.layers[layer_idx] - res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) + hf_layer_number = mg_layer.layer_number - 1 + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', hf_layer_number, reverse) if reverse: yield from list(self._add_prefix(res, hf_prefix).items()) else: @@ -366,7 +385,7 @@ def save_weights(self, mg_models, output_dir: str) -> None: for k, v in self.export_weights(mg_models): saver.add_tensor(k, v) saver.finalize() - if is_master(): + if is_last_rank(): # TODO: new_special_tokens self.hf_model.config.save_pretrained(output_dir) save_checkpoint( diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index d71eebdc68..7a7cf85f10 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -7,7 +7,7 @@ from swift.llm import TEMPLATE_MAPPING from swift.llm.train import SwiftSft -from swift.utils import get_logger, is_master, plot_images +from swift.utils import get_logger, is_last_rank, plot_images from ..argument import MegatronTrainArguments from ..trainers import MegatronTrainer 
from .utils import build_streaming_dataloader @@ -67,7 +67,7 @@ def run(self): self.trainer.train(train_dataset, val_dataset, data_collator) finally: # Visualization - if is_master(): + if is_last_rank(): images_dir = os.path.join(args.save, 'images') logger.info(f'images_dir: {images_dir}') plot_images(images_dir, args.tensorboard_dir) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 5e65b64060..026f01e8f7 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -53,35 +53,6 @@ def null_ref_context(self): for m in self.peft_models: m.set_adapter('default') - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): args = get_args() per_token_logps = -output_tensor diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 4bccac1021..bb05457980 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -3,5 +3,5 @@ from .config import convert_hf_config from .io_utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .patcher import patch_torch_dist_shard -from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model, +from .utils import (adapter_state_dict_context, copy_original_module_weight, forward_step_helper, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index 5590aa5556..4489722f40 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -1,10 +1,11 @@ import os from functools import partial +from typing import Literal import json from safetensors.torch import safe_open, save_file -from swift.utils import is_master +from swift.utils import is_last_rank, is_master class LazyTensor: @@ -74,12 +75,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): class StreamingSafetensorSaver: - def __init__(self, save_dir, max_shard_size='5GB') -> None: - if not is_master(): - return + def __init__(self, save_dir, max_shard_size='5GB', save_rank: Literal['master', 'last'] = 'last') -> None: # max_shard_size: GiB self.save_dir = save_dir - os.makedirs(save_dir, exist_ok=True) if isinstance(max_shard_size, str): if max_shard_size.endswith('GB'): max_shard_size = int(max_shard_size[:-2]) @@ -91,9 +89,12 @@ def __init__(self, save_dir, max_shard_size='5GB') -> None: self.total_size = 0 
self.shard_index = 1 self.weight_map = {} + self.is_save_rank = is_last_rank() if save_rank == 'last' else is_master() + if self.is_save_rank: + os.makedirs(save_dir, exist_ok=True) def add_tensor(self, name, tensor): - if not is_master(): + if not is_save_rank: return tensor_size = tensor.numel() * tensor.element_size() if self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard: @@ -118,7 +119,7 @@ def _save_current_shard(self, shard_filename: str = None): self.shard_index += 1 def finalize(self): - if not is_master(): + if not is_save_rank: return if self.current_shard: self._save_current_shard() diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 6d4ae82228..a8e0a43ac5 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -3,9 +3,11 @@ from copy import deepcopy from typing import Optional, Tuple +import torch import torch.distributed as dist from megatron.core import mpu from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default @@ -267,3 +269,33 @@ def copy_ref_adapter_weight(model, ref_adapter_name: str): sub_module = module.modules_to_save if 'default' in sub_module and ref_adapter_name in sub_module: sub_module[ref_adapter_name].load_state_dict(sub_module['default'].state_dict()) + + +def forward_step_helper(model, inputs, dtype=None): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + dtype = dtype or args.params_dtype + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor From 15e5d074132ca4492c2e3216cce15b650125ae77 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 20:39:42 +0800 Subject: [PATCH 18/30] support pp tp --- swift/megatron/convert.py | 3 +- swift/megatron/model/gpt_bridge.py | 108 +++++++++++++++++------------ swift/megatron/utils/io_utils.py | 4 +- 3 files changed, 68 insertions(+), 47 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index d95c9ff3ab..0244aa403b 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -193,7 +193,8 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float mg_logits = forward_step_helper(mg_model, kwargs, 
dtype=mg_torch_dtype) if args.tensor_model_parallel_size > 1: from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region - mg_logits = gather_from_tensor_model_parallel_region(mg_logits) + if mg_logits is not None: + mg_logits = gather_from_tensor_model_parallel_region(mg_logits) if hf_model is None: return if args.task_type == 'seq_cls': diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index a9ec4d6b9a..ec2ce760dc 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -2,6 +2,7 @@ from typing import Dict, Literal, Optional, Union import torch +import torch.distributed as dist from megatron.core import mpu from megatron.training import get_args from tqdm import tqdm @@ -47,6 +48,7 @@ def _get_tp_split_dim(mg_key: str) -> Optional[int]: return 1 def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_slices=()): + # tp tp_dim = self._get_tp_split_dim(mg_key) hf_weight = hf_weight.to(mg_param.device) if tp_dim is not None and self.tp_size > 1: @@ -55,10 +57,10 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli else: splited_weights = None tensor = torch.empty_like(mg_param.data[mg_slices]) - torch.distributed.scatter( + dist.scatter( tensor, splited_weights, - src=0, + src=dist.get_global_rank(self.tp_group, 0), group=self.tp_group, ) else: @@ -68,10 +70,27 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli mg_param.data[mg_slices].copy_(tensor) def _get_weights(self, mg_weight, mg_key, offset: int = 0): + # pp + if self.pp_size > 1: + if mg_weight is None: + output = [None] * self.pp_size + dist.all_gather_object(output, None, group=self.pp_group) + src_idx = self._find_not_none_index(output) + assert len(src_idx) == 1, f'src_idx: {src_idx}' + src_idx = src_idx[0] + shape, dtype = output[src_idx] + mg_weight = torch.empty(shape, device='cuda', dtype=dtype) + dist.broadcast(mg_weight, src=dist.get_global_rank(self.pp_group, src_idx), group=self.pp_group) + else: + output = [None] * self.pp_size + meta_data = (mg_weight.shape, mg_weight.dtype) + dist.all_gather_object(output, meta_data, group=self.pp_group) + dist.broadcast(mg_weight, src=dist.get_global_rank(self.pp_group, self.pp_rank), group=self.pp_group) + # tp tp_dim = self._get_tp_split_dim(mg_key) if tp_dim is not None and self.tp_size > 1: tensor_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] - torch.distributed.all_gather( + dist.all_gather( tensor_list, mg_weight, group=self.tp_group, @@ -83,24 +102,19 @@ def _get_weights(self, mg_weight, mg_key, offset: int = 0): tensor = tensor + offset return tensor - def _set_state_dict(self, - mg_module, - mg_key: str, - hf_state_dict, - hf_key: str, - reverse: bool, - offset: float = 0, - pre_process: bool = False, - post_process: bool = False): + @staticmethod + def _find_not_none_index(lst): + res = [] + for i, x in enumerate(lst): + if x is not None: + res.append(i) + return res + + def _set_state_dict(self, mg_module, mg_key: str, hf_state_dict, hf_key: str, reverse: bool, offset: float = 0): mg_param = deep_getattr(mg_module, mg_key) if reverse: - hf_state_dict[hf_key] = self._get_weights(mg_param.data, mg_key, offset) + hf_state_dict[hf_key] = self._get_weights(None if mg_param is None else mg_param.data, mg_key, offset) else: - if mg_param is None and self.pp_size > 1: - if pre_process and not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_module.vp_stage): - 
return - elif post_process and not mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_module.vp_stage): - return assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' hf_weight = hf_state_dict[hf_key].load() self._set_weights(mg_param, hf_weight, mg_key, offset) @@ -147,7 +161,8 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int args = self.args num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if reverse: - mg_attn_weight = self._get_weights(mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') + mg_attn_weight = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.weight.data, + 'linear_qkv.weight') mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ 0] // num_query_groups @@ -167,7 +182,8 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int # Copy bias if args.add_qkv_bias: if reverse: - mg_attn_bias = self._get_weights(mg_attn.linear_qkv.bias.data, 'linear_qkv.bias') + mg_attn_bias = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.bias.data, + 'linear_qkv.bias') mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) @@ -248,9 +264,10 @@ def _set_mlp_state( else: if reverse: fc1_weight = deep_getattr(mg_mlp, fc1_key) - hf_state_dict['gate_proj.weight'] = self._get_weights(fc1_weight[:fc1_weight.shape[0] // 2], - fc1_key) - hf_state_dict['up_proj.weight'] = self._get_weights(fc1_weight[fc1_weight.shape[0] // 2:], fc1_key) + hf_state_dict['gate_proj.weight'] = self._get_weights( + None if fc1_weight is None else fc1_weight[:fc1_weight.shape[0] // 2], fc1_key) + hf_state_dict['up_proj.weight'] = self._get_weights( + None if fc1_weight is None else fc1_weight[fc1_weight.shape[0] // 2:], fc1_key) else: fc1_weight = torch.cat([ hf_state_dict['gate_proj.weight'].load(), @@ -298,7 +315,7 @@ def _set_mla_attn_state( return self._add_prefix(res, tgt_prefix) def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): - mg_attn = mg_layer.self_attention + mg_attn = None if mg_layer is None else mg_layer.self_attention if self.args.multi_latent_attention: hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, reverse)) self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', reverse) @@ -311,7 +328,7 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): hf_mlp = self.hf_layers[layer_idx].mlp is_moe = self._is_moe(hf_mlp.state_dict()) - mg_mlp = mg_layer.mlp + mg_mlp = None if mg_layer is None else mg_layer.mlp if is_moe: hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, reverse)) self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', @@ -340,30 +357,33 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): hf_state_dict = {} else: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - self._set_state_dict( - mg_model, - 'embedding.word_embeddings.weight', - hf_state_dict, - 'model.embed_tokens.weight', - reverse, - 
pre_process=True) - if self.args.untie_embeddings_and_output_weights: - hf_lm_head_key = 'lm_head.weight' - if reverse and self.args.task_type == 'seq_cls': - hf_lm_head_key = 'score.weight' - self._set_state_dict( - mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse, post_process=True) - self._set_state_dict( - mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', reverse, post_process=True) + if reverse or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, + 'model.embed_tokens.weight', reverse) + if reverse or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if self.args.untie_embeddings_and_output_weights: + hf_lm_head_key = 'lm_head.weight' + if reverse and self.args.task_type == 'seq_cls': + hf_lm_head_key = 'score.weight' + self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) + self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', + reverse) if reverse: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) else: yield - for layer_idx in tqdm(range(len(mg_model.decoder.layers)), dynamic_ncols=True, desc='Converting: '): - mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer_number = mg_layer.layer_number - 1 - res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', hf_layer_number, reverse) + for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + start_idx = mg_model.decoder.layers[0].layer_number - 1 + mg_layer_avaiable = (start_idx <= layer_idx <= mg_model.decoder.layers[-1].layer_number - 1) + if mg_layer_avaiable: + mg_layer = mg_model.decoder.layers[layer_idx - start_idx] + else: + if reverse: + mg_layer = None + else: + continue + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) if reverse: yield from list(self._add_prefix(res, hf_prefix).items()) else: diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index 4489722f40..602be8fa0f 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -94,7 +94,7 @@ def __init__(self, save_dir, max_shard_size='5GB', save_rank: Literal['master', os.makedirs(save_dir, exist_ok=True) def add_tensor(self, name, tensor): - if not is_save_rank: + if not self.is_save_rank: return tensor_size = tensor.numel() * tensor.element_size() if self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard: @@ -119,7 +119,7 @@ def _save_current_shard(self, shard_filename: str = None): self.shard_index += 1 def finalize(self): - if not is_save_rank: + if not self.is_save_rank: return if self.current_shard: self._save_current_shard() From 856c52ec070fdc1f0b5b1da0d01783b54e2d67d3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 21:36:03 +0800 Subject: [PATCH 19/30] update --- swift/megatron/export/export.py | 8 ++-- swift/megatron/model/gpt_bridge.py | 75 +++++++++++++++++++----------- swift/megatron/trainers/base.py | 32 ++----------- swift/megatron/utils/__init__.py | 2 +- swift/megatron/utils/patcher.py | 49 ++++++++++++++++++- 5 files changed, 105 insertions(+), 61 deletions(-) diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 7717876b36..8d2582f0ec 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -11,7 
+11,7 @@ from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank from ..argument import MegatronExportArguments from ..convert import test_convert_precision -from ..utils import prepare_mcore_model +from ..utils import patch_load_base_checkpoint, prepare_mcore_model logger = get_logger() @@ -39,7 +39,8 @@ def convert_mcore2hf(self) -> None: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process) - load_checkpoint([mg_model], None, None, strict=True) + with patch_load_base_checkpoint(): + load_checkpoint([mg_model], None, None, strict=True) logger.info('Converting weights and saving the model...') bridge = megatron_model_meta.bridge_cls() bridge.save_weights([mg_model], args.save) @@ -50,6 +51,7 @@ def convert_mcore2hf(self) -> None: hf_model = prepare_model_template( args, model=args.save, device_map='cpu')[0] if is_last_rank() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) + dist.barrier() def convert_hf2mcore(self) -> None: args = self.args @@ -72,7 +74,7 @@ def convert_hf2mcore(self) -> None: with disable_safe_ddp_context_use_barrier(): hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) - dist.barrier() + dist.barrier() args.save_args(args.save) logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index ec2ce760dc..85f34ff0f0 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -63,6 +63,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli src=dist.get_global_rank(self.tp_group, 0), group=self.tp_group, ) + del splited_weights else: tensor = hf_weight if offset: @@ -70,34 +71,46 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli mg_param.data[mg_slices].copy_(tensor) def _get_weights(self, mg_weight, mg_key, offset: int = 0): + # tp + tp_dim = self._get_tp_split_dim(mg_key) + tensor = mg_weight + if tensor is not None and tp_dim is not None and self.tp_size > 1: + if tp_dim == 0: + # save memory + tensor_shape = list(tensor.shape) + tensor_shape[0] *= self.tp_size + output = tensor.new_empty(tensor_shape) + dist.all_gather_into_tensor( + output, + tensor, + group=self.tp_group, + ) + tensor = output + else: + output = [torch.empty_like(tensor) for _ in range(self.tp_size)] + dist.all_gather( + output, + tensor, + group=self.tp_group, + ) + tensor = torch.cat(output, dim=tp_dim) + del output # pp if self.pp_size > 1: - if mg_weight is None: + if tensor is None: output = [None] * self.pp_size dist.all_gather_object(output, None, group=self.pp_group) src_idx = self._find_not_none_index(output) assert len(src_idx) == 1, f'src_idx: {src_idx}' src_idx = src_idx[0] shape, dtype = output[src_idx] - mg_weight = torch.empty(shape, device='cuda', dtype=dtype) - dist.broadcast(mg_weight, src=dist.get_global_rank(self.pp_group, src_idx), group=self.pp_group) + tensor = torch.empty(shape, device='cuda', dtype=dtype) + dist.broadcast(tensor, src=dist.get_global_rank(self.pp_group, src_idx), group=self.pp_group) else: output = [None] * self.pp_size - meta_data = (mg_weight.shape, mg_weight.dtype) + meta_data = (tensor.shape, tensor.dtype) dist.all_gather_object(output, 
meta_data, group=self.pp_group) - dist.broadcast(mg_weight, src=dist.get_global_rank(self.pp_group, self.pp_rank), group=self.pp_group) - # tp - tp_dim = self._get_tp_split_dim(mg_key) - if tp_dim is not None and self.tp_size > 1: - tensor_list = [torch.empty_like(mg_weight) for _ in range(self.tp_size)] - dist.all_gather( - tensor_list, - mg_weight, - group=self.tp_group, - ) - tensor = torch.cat(tensor_list, dim=tp_dim) - else: - tensor = mg_weight + dist.broadcast(tensor, src=dist.get_global_rank(self.pp_group, self.pp_rank), group=self.pp_group) if offset: tensor = tensor + offset return tensor @@ -360,19 +373,11 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): if reverse or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, 'model.embed_tokens.weight', reverse) - if reverse or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): - if self.args.untie_embeddings_and_output_weights: - hf_lm_head_key = 'lm_head.weight' - if reverse and self.args.task_type == 'seq_cls': - hf_lm_head_key = 'score.weight' - self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) - self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', - reverse) if reverse: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + hf_state_dict = {} else: yield - for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): start_idx = mg_model.decoder.layers[0].layer_number - 1 mg_layer_avaiable = (start_idx <= layer_idx <= mg_model.decoder.layers[-1].layer_number - 1) @@ -383,11 +388,25 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): mg_layer = None else: continue - res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) + hf_state_dict = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) if reverse: - yield from list(self._add_prefix(res, hf_prefix).items()) + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + hf_state_dict = {} else: yield + if reverse or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if self.args.untie_embeddings_and_output_weights: + hf_lm_head_key = 'lm_head.weight' + if reverse and self.args.task_type == 'seq_cls': + hf_lm_head_key = 'score.weight' + self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) + self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', + reverse) + if reverse: + yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + hf_state_dict = {} + else: + yield def load_weights(self, mg_model, hf_model_dir: str) -> None: with SafetensorLazyLoader(hf_model_dir) as loader: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index f480889ebd..335d656cdd 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -11,7 +11,6 @@ import torch import torch.nn from megatron.core import mpu -from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel import get_forward_backward_func @@ -32,7 +31,7 @@ from swift.plugin import MeanMetric from 
swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger -from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model +from ..utils import adapter_state_dict_context, copy_original_module_weight, patch_merge_fn, prepare_mcore_model from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) @@ -133,29 +132,6 @@ def new_cyclic_iter(self, iterable): def _replace_data_iterator(self, data_iterator, model): return data_iterator - @staticmethod - def _patch_merge_fn(state_dict_model): - # https://github.com/NVIDIA/Megatron-LM/issues/1380 - - def sh_ten_merge_fn(sub_state_dict): - with torch.no_grad(): - shared_storage = sub_state_dict[0].untyped_storage() - if all(shared_storage.data_ptr() == tensor.untyped_storage().data_ptr() for tensor in sub_state_dict): - element_size = sub_state_dict[0].element_size() - total_numel = sum(tensor.numel() for tensor in sub_state_dict) - if shared_storage.nbytes() == total_numel * element_size: - dim_0 = sum(tensor.shape[0] for tensor in sub_state_dict) - shape = (dim_0, ) + sub_state_dict[0].shape[1:] - combined_tensor = torch.empty( - shape, dtype=sub_state_dict[0].dtype, - device=sub_state_dict[0].device).set_(shared_storage, 0, shape) - return combined_tensor - return torch.cat(sub_state_dict) - - for v in state_dict_model.values(): - if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__: - v.merge_fn = sh_ten_merge_fn - def _load_adapter_base_checkpoint(self, *_args, **kwargs): adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter' sharded_state_dict = kwargs.get('sharded_state_dict') @@ -176,7 +152,7 @@ def _load_adapter_base_checkpoint(self, *_args, **kwargs): v.key = v.key.replace(f'.{adapter_name}.', '.default.') state_dict_model[k] = v sharded_state_dict[model_k] = state_dict_model - self._patch_merge_fn(state_dict_model) + patch_merge_fn(state_dict_model) # TODO: check res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs) for model_k in model_keys: state_dict = res[0][model_k] @@ -192,7 +168,7 @@ def _load_base_checkpoint(self, *_args, **kwargs): model_keys = [k for k in sharded_state_dict.keys() if k.startswith('model')] if self.args.train_type == 'full': for k in model_keys: - self._patch_merge_fn(sharded_state_dict[k]) + patch_merge_fn(sharded_state_dict[k]) return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) mapping = {} for model_k in model_keys: @@ -218,7 +194,7 @@ def _load_base_checkpoint(self, *_args, **kwargs): v.key = v.key.replace('.modules_to_save.default', '') state_dict_model[k] = v sharded_state_dict[model_k] = state_dict_model - self._patch_merge_fn(state_dict_model) + patch_merge_fn(state_dict_model) res = checkpointing.origin__load_base_checkpoint(*_args, **kwargs) for model_k in model_keys: state_dict = res[0][model_k] diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index bb05457980..ea806b8a45 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -2,6 +2,6 @@ from .config import convert_hf_config from .io_utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver -from .patcher import patch_torch_dist_shard +from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard from .utils import (adapter_state_dict_context, copy_original_module_weight, forward_step_helper, prepare_mcore_model, 
tuners_sharded_state_dict) diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py index 9fd1cbf169..3c35a42124 100644 --- a/swift/megatron/utils/patcher.py +++ b/swift/megatron/utils/patcher.py @@ -1,6 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from contextlib import contextmanager + +import torch +from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy -from megatron.training import get_args, global_vars, initialize, training +from megatron.training import checkpointing from swift.utils import get_logger @@ -15,3 +19,46 @@ def __new_init__(*args, **kwargs): return __init__(*args, **kwargs) TorchDistSaveShardedStrategy.__init__ = __new_init__ + + +def patch_merge_fn(state_dict_model): + # https://github.com/NVIDIA/Megatron-LM/issues/1380 + + def sh_ten_merge_fn(sub_state_dict): + with torch.no_grad(): + shared_storage = sub_state_dict[0].untyped_storage() + if all(shared_storage.data_ptr() == tensor.untyped_storage().data_ptr() for tensor in sub_state_dict): + element_size = sub_state_dict[0].element_size() + total_numel = sum(tensor.numel() for tensor in sub_state_dict) + if shared_storage.nbytes() == total_numel * element_size: + dim_0 = sum(tensor.shape[0] for tensor in sub_state_dict) + shape = (dim_0, ) + sub_state_dict[0].shape[1:] + combined_tensor = torch.empty( + shape, dtype=sub_state_dict[0].dtype, + device=sub_state_dict[0].device).set_(shared_storage, 0, shape) + return combined_tensor + return torch.cat(sub_state_dict) + + for v in state_dict_model.values(): + if isinstance(v, ShardedTensorFactory) and 'apply_swiglu_sharded_factory' in v.merge_fn.__qualname__: + v.merge_fn = sh_ten_merge_fn + + +@contextmanager +def patch_load_base_checkpoint(): + origin__load_base_checkpoint = checkpointing._load_base_checkpoint + + def _load_base_checkpoint(*_args, **kwargs): + sharded_state_dict = kwargs.get('sharded_state_dict') + if sharded_state_dict is None: + return origin__load_base_checkpoint(*_args, **kwargs) + model_keys = [k for k in sharded_state_dict.keys() if k.startswith('model')] # compat vpp + for k in model_keys: + patch_merge_fn(sharded_state_dict[k]) + return origin__load_base_checkpoint(*_args, **kwargs) + + checkpointing._load_base_checkpoint = _load_base_checkpoint + try: + yield + finally: + checkpointing._load_base_checkpoint = origin__load_base_checkpoint From 8344a92fb6447241dabd42d2b5ccad479c9b7ea4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 22:38:13 +0800 Subject: [PATCH 20/30] update --- swift/megatron/argument/export_args.py | 3 + swift/megatron/model/gpt_bridge.py | 137 +++++++++++++++---------- 2 files changed, 85 insertions(+), 55 deletions(-) diff --git a/swift/megatron/argument/export_args.py b/swift/megatron/argument/export_args.py index 64cdd68aaf..a389d765d2 100644 --- a/swift/megatron/argument/export_args.py +++ b/swift/megatron/argument/export_args.py @@ -39,6 +39,9 @@ def __post_init__(self): self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype) if self.to_hf or self.to_mcore: self._init_convert() + if self.model_info.is_moe_model is not None and self.tensor_model_parallel_size > 1: + self.sequence_parallel = True + logger.info('Settting args.sequence_parallel: True') def _init_convert(self): convert_kwargs = { diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 85f34ff0f0..f009c5581f 100644 --- 
a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -47,21 +47,24 @@ def _get_tp_split_dim(mg_key: str) -> Optional[int]: elif key in {'linear_proj', 'linear_fc2'}: return 1 - def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_slices=()): - # tp + def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: float = 0, is_expert: bool = False, mg_slices=()): + # tp/etp tp_dim = self._get_tp_split_dim(mg_key) hf_weight = hf_weight.to(mg_param.device) - if tp_dim is not None and self.tp_size > 1: - if self.tp_rank == 0: - splited_weights = [t.contiguous() for t in hf_weight.chunk(self.tp_size, dim=tp_dim)] + tp_size = self.etp_size if is_expert else self.tp_size + tp_rank = self.etp_rank if is_expert else self.tp_rank + tp_group = self.etp_group if is_expert else self.tp_group + if tp_dim is not None and tp_size > 1: + if tp_rank == 0: + splited_weights = [t.contiguous() for t in hf_weight.chunk(tp_size, dim=tp_dim)] else: splited_weights = None tensor = torch.empty_like(mg_param.data[mg_slices]) dist.scatter( tensor, splited_weights, - src=dist.get_global_rank(self.tp_group, 0), - group=self.tp_group, + src=dist.get_global_rank(tp_group, 0), + group=tp_group, ) del splited_weights else: @@ -70,28 +73,30 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: int = 0, mg_sli tensor = tensor + offset mg_param.data[mg_slices].copy_(tensor) - def _get_weights(self, mg_weight, mg_key, offset: int = 0): - # tp + def _get_weights(self, mg_weight, mg_key, offset: int = 0, is_expert: bool = False): + # tp/etp tp_dim = self._get_tp_split_dim(mg_key) tensor = mg_weight - if tensor is not None and tp_dim is not None and self.tp_size > 1: + tp_size = self.etp_size if is_expert else self.tp_size + tp_group = self.etp_group if is_expert else self.tp_group + if tensor is not None and tp_dim is not None and tp_size > 1: if tp_dim == 0: # save memory tensor_shape = list(tensor.shape) - tensor_shape[0] *= self.tp_size + tensor_shape[0] *= tp_size output = tensor.new_empty(tensor_shape) dist.all_gather_into_tensor( output, tensor, - group=self.tp_group, + group=tp_group, ) tensor = output else: - output = [torch.empty_like(tensor) for _ in range(self.tp_size)] + output = [torch.empty_like(tensor) for _ in range(tp_size)] dist.all_gather( output, tensor, - group=self.tp_group, + group=tp_group, ) tensor = torch.cat(output, dim=tp_dim) del output @@ -123,14 +128,22 @@ def _find_not_none_index(lst): res.append(i) return res - def _set_state_dict(self, mg_module, mg_key: str, hf_state_dict, hf_key: str, reverse: bool, offset: float = 0): + def _set_state_dict(self, + mg_module, + mg_key: str, + hf_state_dict, + hf_key: str, + reverse: bool, + offset: float = 0, + is_expert: bool = False): mg_param = deep_getattr(mg_module, mg_key) if reverse: - hf_state_dict[hf_key] = self._get_weights(None if mg_param is None else mg_param.data, mg_key, offset) + hf_state_dict[hf_key] = self._get_weights(None if mg_param is None else mg_param.data, mg_key, offset, + is_expert) else: assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' hf_weight = hf_state_dict[hf_key].load() - self._set_weights(mg_param, hf_weight, mg_key, offset) + self._set_weights(mg_param, hf_weight, mg_key, offset, is_expert) @staticmethod def _remove_prefix(state_dict, prefix: str): @@ -213,37 +226,47 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int if args.qk_layernorm: hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') 
else 'query_layernorm.weight' hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' - self._set_state_dict(state_dict, res, hf_q_norm_key, 'q_layernorm.weight', reverse) - self._set_state_dict(state_dict, res, hf_k_norm_key, 'k_layernorm.weight', reverse) + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, reverse) + self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, reverse) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_moe_state(self, state_dict, hf_prefix: str, mg_prefix: str, layer_idx: int, reverse: bool): - src_prefix, tgt_prefix = hf_prefix, mg_prefix + def _set_moe_state( + self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + reverse: bool, + ): if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) + hf_state_dict = {} + else: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_mlp = self.hf_layers[layer_idx].mlp - res = {} hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 'gate.weight' - self._set_state_dict(state_dict, res, hf_gate_key, 'router.weight', reverse) + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, reverse) if self.args.moe_router_enable_expert_bias: - self._set_state_dict(state_dict, res, 'gate.e_score_correction_bias', 'router.expert_bias', reverse) + self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, 'gate.e_score_correction_bias', reverse) if self.args.moe_shared_expert_intermediate_size: for key in ['shared_expert', 'shared_experts', 'shared_mlp']: if hasattr(hf_mlp, key): hf_shared_expert_prefix = f'{key}.' - res.update(self._set_mlp_state(state_dict, hf_shared_expert_prefix, 'shared_experts.', layer_idx, reverse)) + hf_state_dict.update( + self._set_mlp_state(mg_mlp.shared_experts, hf_state_dict, hf_shared_expert_prefix, layer_idx, reverse)) if hasattr(hf_mlp, 'shared_expert_gate'): - self._set_state_dict(state_dict, res, 'shared_expert_gate.weight', 'shared_experts.gate_weight', + self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', reverse) for expert_idx in range(self.args.num_experts): hf_expert_prefix = f'experts.{expert_idx}.' if hasattr(hf_mlp.experts, '__len__') else 'experts.' 
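# [Editor's sketch — not part of the patch above.] `_set_moe_state` and its siblings keep
# switching between fully-qualified HF keys and keys local to the current submodule via
# `self._remove_prefix` / `self._add_prefix`, whose bodies live elsewhere in gpt_bridge.py.
# A minimal, assumed version of such helpers is shown below purely to clarify the key
# handling in the hunk above (hypothetical standalone functions, not the project's API):
def _remove_prefix(state_dict: dict, prefix: str) -> dict:
    # keep only the entries under `prefix` and strip it,
    # e.g. 'mlp.gate.weight' -> 'gate.weight' for prefix='mlp.'
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def _add_prefix(state_dict: dict, prefix: str) -> dict:
    # reattach the prefix before merging the result back into the caller's dict
    return {prefix + k: v for k, v in state_dict.items()}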
- res.update( - self._set_mlp_state(state_dict, hf_expert_prefix, 'experts.', layer_idx, reverse, group_idx=expert_idx)) - return self._add_prefix(res, tgt_prefix) + hf_state_dict.update( + self._set_mlp_state( + mg_mlp.experts, hf_state_dict, hf_expert_prefix, layer_idx, reverse, group_idx=expert_idx)) + if reverse: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict def _set_mlp_state( self, @@ -260,6 +283,7 @@ def _set_mlp_state( hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_mlp = self.hf_layers[layer_idx].mlp hf_grouped = False + is_expert = group_idx is not None if group_idx is not None: if not hasattr(hf_mlp.experts, '__len__'): hf_grouped = True @@ -272,30 +296,33 @@ def _set_mlp_state( res[fc1_key] = hf_state_dict['gate_up_proj'][group_idx].t() res[fc2_key] = hf_state_dict['down_proj'][group_idx].t() else: - if hasattr(hf_mlp, 'gate_up_proj'): - self._set_state_dict(hf_state_dict, res, 'gate_up_proj.weight', fc1_key, reverse) + if reverse: + fc1_weight = deep_getattr(mg_mlp, fc1_key) + gate_proj_weight = self._get_weights( + None if fc1_weight is None else fc1_weight[:fc1_weight.shape[0] // 2], fc1_key, is_expert=is_expert) + up_proj_weight = self._get_weights( + None if fc1_weight is None else fc1_weight[fc1_weight.shape[0] // 2:], fc1_key, is_expert=is_expert) + if hasattr(hf_mlp, 'gate_up_proj'): + hf_state_dict['gate_up_proj'] = torch.concat([gate_proj_weight, up_proj_weight], dim=0) + else: + hf_state_dict['gate_proj.weight'] = gate_proj_weight + hf_state_dict['up_proj.weight'] = up_proj_weight else: - if reverse: - fc1_weight = deep_getattr(mg_mlp, fc1_key) - hf_state_dict['gate_proj.weight'] = self._get_weights( - None if fc1_weight is None else fc1_weight[:fc1_weight.shape[0] // 2], fc1_key) - hf_state_dict['up_proj.weight'] = self._get_weights( - None if fc1_weight is None else fc1_weight[fc1_weight.shape[0] // 2:], fc1_key) + linear_fc1_weight = deep_getattr(mg_mlp, fc1_key) + gate_slices = (slice(None, linear_fc1_weight.shape[0] // 2), ) + up_slices = (slice(linear_fc1_weight.shape[0] // 2, None), ) + if hasattr(hf_mlp, 'gate_up_proj'): + gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load() + gate_proj_weight = gate_up_proj_weight[gate_slices] + up_proj_weight = gate_up_proj_weight[up_slices] else: - fc1_weight = torch.cat([ - hf_state_dict['gate_proj.weight'].load(), - hf_state_dict['up_proj.weight'].load(), - ], - dim=0) - linear_fc1_weight = deep_getattr(mg_mlp, fc1_key) - gate_slices = (slice(None, linear_fc1_weight.shape[0] // 2), ) - up_slices = (slice(linear_fc1_weight.shape[0] // 2, None), ) - self._set_weights( - linear_fc1_weight, hf_state_dict['gate_proj.weight'].load(), fc1_key, mg_slices=gate_slices) - self._set_weights( - linear_fc1_weight, hf_state_dict['up_proj.weight'].load(), fc1_key, mg_slices=up_slices) - - self._set_state_dict(mg_mlp, fc2_key, hf_state_dict, 'down_proj.weight', reverse) + gate_proj_weight = hf_state_dict['gate_proj.weight'].load() + up_proj_weight = hf_state_dict['up_proj.weight'].load() + self._set_weights( + linear_fc1_weight, gate_proj_weight, fc1_key, is_expert=is_expert, mg_slices=gate_slices) + self._set_weights(linear_fc1_weight, up_proj_weight, fc1_key, is_expert=is_expert, mg_slices=up_slices) + + self._set_state_dict(mg_mlp, fc2_key, hf_state_dict, 'down_proj.weight', reverse, is_expert=is_expert) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict @@ -388,9 +415,9 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: 
str, reverse: bool): mg_layer = None else: continue - hf_state_dict = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) if reverse: - yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} else: yield From c294248c2bd9b43adcb003032762e842a2370060 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 28 Oct 2025 23:08:54 +0800 Subject: [PATCH 21/30] update --- swift/megatron/convert.py | 4 +-- swift/megatron/model/gpt_bridge.py | 42 ++++++++++++++++++++++-------- swift/megatron/train/sft.py | 10 ++----- swift/megatron/utils/__init__.py | 4 +-- swift/megatron/utils/utils.py | 12 +++++++++ 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 0244aa403b..97975d49af 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -19,7 +19,7 @@ from swift.utils import get_logger, get_n_params_grads from .argument import MegatronArguments from .model import get_megatron_model_meta -from .utils import convert_hf_config, forward_step_helper, patch_torch_dist_shard +from .utils import convert_hf_config, forward_step_helper, get_padding_to, patch_torch_dist_shard logger = get_logger() @@ -152,7 +152,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float is_multimodal = template.model_meta.is_multimodal inputs = get_examples(is_multimodal) inputs = template.encode(inputs) - inputs = to_device(template.data_collator([inputs]), 'cuda') + inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to()), 'cuda') mg_language_model = mg_model.language_model if is_multimodal else mg_model share_embedding = mg_language_model.share_embeddings_and_output_weights if hf_model is not None: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index f009c5581f..c422fcb79c 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -35,7 +35,7 @@ def __init__(self): self.tp_rank = mpu.get_tensor_model_parallel_rank() self.pp_rank = mpu.get_pipeline_model_parallel_rank() self.etp_rank = mpu.get_expert_tensor_parallel_rank() - self.ep_rank = mpu.get_expert_model_parallel_group() + self.ep_rank = mpu.get_expert_model_parallel_rank() @staticmethod def _get_tp_split_dim(mg_key: str) -> Optional[int]: @@ -260,10 +260,32 @@ def _set_moe_state( self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', reverse) for expert_idx in range(self.args.num_experts): - hf_expert_prefix = f'experts.{expert_idx}.' if hasattr(hf_mlp.experts, '__len__') else 'experts.' + mg_experts = mg_mlp.experts + start_idx = mg_experts.num_local_experts * self.ep_rank + expert_available = (start_idx <= expert_idx < start_idx + mg_experts.num_local_experts) + if expert_available: + group_idx = expert_idx - start_idx + else: + group_idx = None + if reverse: + mg_experts = None + else: + continue + if hasattr(hf_mlp.experts, '__len__'): + hf_expert_prefix = f'experts.{expert_idx}.' + hf_group_idx = None + else: + hf_expert_prefix = 'experts.' 
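# [Editor's sketch — illustration only, not part of the commit.] Under expert parallelism
# each EP rank owns a contiguous block of experts, so the loop above translates a global HF
# expert index into a local Megatron index (the `weight{i}` suffix) on the owning rank and
# skips all other ranks. A self-contained version of that arithmetic, assuming the even
# partitioning the code above relies on:
def locate_expert(expert_idx: int, num_experts: int, ep_size: int, ep_rank: int):
    num_local_experts = num_experts // ep_size
    start_idx = num_local_experts * ep_rank
    if start_idx <= expert_idx < start_idx + num_local_experts:
        return expert_idx - start_idx  # local index on this EP rank
    return None  # expert lives on another EP rank


# e.g. 64 experts over 4 EP ranks: global expert 37 is local expert 5 on ep_rank 2
assert locate_expert(37, 64, 4, 2) == 5
assert locate_expert(37, 64, 4, 0) is None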
+ hf_group_idx = expert_idx hf_state_dict.update( self._set_mlp_state( - mg_mlp.experts, hf_state_dict, hf_expert_prefix, layer_idx, reverse, group_idx=expert_idx)) + mg_experts, + hf_state_dict, + hf_expert_prefix, + layer_idx, + reverse, + group_idx=group_idx, + hf_group_idx=hf_group_idx)) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict @@ -276,25 +298,23 @@ def _set_mlp_state( layer_idx: int, reverse: bool, group_idx: Optional[int] = None, + hf_group_idx: Optional[int] = None, ): if reverse: hf_state_dict = {} else: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) hf_mlp = self.hf_layers[layer_idx].mlp - hf_grouped = False is_expert = group_idx is not None if group_idx is not None: - if not hasattr(hf_mlp.experts, '__len__'): - hf_grouped = True fc1_key = f'linear_fc1.weight{group_idx}' fc2_key = f'linear_fc2.weight{group_idx}' else: fc1_key = 'linear_fc1.weight' fc2_key = 'linear_fc2.weight' - if hf_grouped: - res[fc1_key] = hf_state_dict['gate_up_proj'][group_idx].t() - res[fc2_key] = hf_state_dict['down_proj'][group_idx].t() + if hf_group_idx is not None: + res[fc1_key] = hf_state_dict['gate_up_proj'][hf_group_idx].t() + res[fc2_key] = hf_state_dict['down_proj'][hf_group_idx].t() else: if reverse: fc1_weight = deep_getattr(mg_mlp, fc1_key) @@ -407,8 +427,8 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): yield for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): start_idx = mg_model.decoder.layers[0].layer_number - 1 - mg_layer_avaiable = (start_idx <= layer_idx <= mg_model.decoder.layers[-1].layer_number - 1) - if mg_layer_avaiable: + mg_layer_available = (start_idx <= layer_idx < mg_model.decoder.layers[-1].layer_number) + if mg_layer_available: mg_layer = mg_model.decoder.layers[layer_idx - start_idx] else: if reverse: diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 7a7cf85f10..0d604ba778 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -10,6 +10,7 @@ from swift.utils import get_logger, is_last_rank, plot_images from ..argument import MegatronTrainArguments from ..trainers import MegatronTrainer +from ..utils import get_padding_to from .utils import build_streaming_dataloader logger = get_logger() @@ -40,15 +41,8 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No self.trainer = self.prepare_trainer() def _get_data_collator(self): - args = self.args data_collator = self.template.data_collator - padding_to = None - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - padding_to = args.tensor_model_parallel_size - if args.context_parallel_size > 1: - padding_to = (padding_to or 1) * args.context_parallel_size - if args.fp8_format: - padding_to = max((padding_to or 1) * 8, 16) + padding_to = get_padding_to() logger.info(f'padding_to: {padding_to}') data_collator = partial(data_collator, padding_to=padding_to) return data_collator diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index ea806b8a45..e996e7362d 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -3,5 +3,5 @@ from .config import convert_hf_config from .io_utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver from .patcher import patch_load_base_checkpoint, patch_merge_fn, patch_torch_dist_shard -from .utils import (adapter_state_dict_context, copy_original_module_weight, forward_step_helper, 
prepare_mcore_model, - tuners_sharded_state_dict) +from .utils import (adapter_state_dict_context, copy_original_module_weight, forward_step_helper, get_padding_to, + prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index a8e0a43ac5..89b77888c3 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -299,3 +299,15 @@ def forward_step_helper(model, inputs, dtype=None): output_tensor = None return output_tensor + + +def get_padding_to(): + args = get_args() + padding_to = None + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + padding_to = args.tensor_model_parallel_size + if args.context_parallel_size > 1: + padding_to = (padding_to or 1) * args.context_parallel_size + if args.fp8: + padding_to = max((padding_to or 1) * 8, 16) + return padding_to From 51b64114473f64315554c5b45bcd4466405d7a35 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Oct 2025 16:47:37 +0800 Subject: [PATCH 22/30] update --- swift/llm/template/base.py | 1 + swift/megatron/argument/megatron_args.py | 3 --- swift/megatron/argument/train_args.py | 3 +++ swift/megatron/convert.py | 24 +++++++++++------------- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 017674ce6b..1e09358108 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1685,6 +1685,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res["attention_mask"].dtype}' for i, seq_len in enumerate(seq_lens): res['attention_mask'][i, :, seq_len:] = 0 + res['attention_mask'] = ~res['attention_mask'] for key, pad_value in zip(keys, pad_values): if key not in res: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 62f1164605..3e5fde66e8 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -456,12 +456,9 @@ def __post_init__(self): self.seq_length = self.max_position_embeddings if self.position_embedding_type is None: self.position_embedding_type = 'rope' - if self.tensorboard_dir is None and self.save is not None: - self.tensorboard_dir = f'{self.save}/runs' self._init_moe() self._init_mixed_precision() - self.tensorboard_dir = to_abspath(self.tensorboard_dir) self.megatron_extra_kwargs = json_parse_to_dict(self.megatron_extra_kwargs) self._init_no_rope_fusion() diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 2d3acb1313..b0c85bb20f 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -46,6 +46,9 @@ def __post_init__(self): raise ValueError(f'self.dataset: {self.dataset}, self.cached_dataset: {self.cached_dataset}. 
' 'Please input the training dataset.') self._init_save() + if self.tensorboard_dir is None and self.save is not None: + self.tensorboard_dir = f'{self.save}/runs' + self.tensorboard_dir = to_abspath(self.tensorboard_dir) if self.load is None and self.no_initialization and not self.load_hf_checkpoint: raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 97975d49af..4fac4da5b1 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -12,7 +12,6 @@ from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint from megatron.training.initialize import initialize_megatron -from megatron.training.utils import get_ltor_masks_and_position_ids from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device, to_float_dtype) @@ -152,7 +151,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float is_multimodal = template.model_meta.is_multimodal inputs = get_examples(is_multimodal) inputs = template.encode(inputs) - inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to()), 'cuda') + hf_inputs = to_device(template.data_collator([inputs]), 'cuda') mg_language_model = mg_model.language_model if is_multimodal else mg_model share_embedding = mg_language_model.share_embeddings_and_output_weights if hf_model is not None: @@ -165,13 +164,13 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) with torch.inference_mode(), _model_cpu_forward_context( hf_modules, torch_dtype, share_embedding=share_embedding): - inputs.pop('text_position_ids', None) - hf_logits = hf_model(**inputs).logits + hf_inputs.pop('text_position_ids', None) + hf_logits = hf_model(**hf_inputs).logits hf_logits = hf_logits.to('cuda') hf_model.to('cpu') - input_ids = inputs['input_ids'] - attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True) + template.use_megatron = True + mg_inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to()), 'cuda') packed_seq_params = None mg_torch_dtype = torch_dtype # thd @@ -181,16 +180,14 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float # attention_mask = None mg_language_model.config.fp8 = None # compat fp8 mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) - kwargs = inputs.copy() - kwargs.pop('labels') - kwargs.update({'input_ids': input_ids, 'attention_mask': attention_mask, 'packed_seq_params': packed_seq_params}) + for key in ['labels', 'num_samples', 'attention_mask_2d', 'text_position_ids']: + mg_inputs.pop(key, None) + mg_inputs.update({'packed_seq_params': packed_seq_params}) args = get_args() - if 'position_ids' not in kwargs: - kwargs['position_ids'] = position_ids with torch.inference_mode(), _model_cpu_forward_context( mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): # TODO: test pp tie_weights - mg_logits = forward_step_helper(mg_model, kwargs, dtype=mg_torch_dtype) + mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=mg_torch_dtype) if args.tensor_model_parallel_size > 1: from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region if mg_logits is not None: @@ 
-203,10 +200,11 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float max_diff = (mg_logits - hf_logits).abs().max().item() print(f'mean_diff: {mean_diff}, max_diff: {max_diff}') else: + mg_logits = mg_logits[:, :hf_logits.shape[1]] token_mean_diff = (mg_logits - hf_logits).abs().mean(dim=-1) mean_diff = token_mean_diff.mean().item() max_diff = (mg_logits - hf_logits).abs().max().item() - loss_mask = (torch.roll(inputs['labels'], -1) != -100) + loss_mask = (torch.roll(hf_inputs['labels'], -1) != -100) mean_diff_with_loss = token_mean_diff[loss_mask].mean().item() max_diff_with_loss = (mg_logits - hf_logits)[loss_mask].abs().max().item() print(f'token_mean_diff: {token_mean_diff}') From 20d9fb36b1846e99bc1ada929d04c60a0ad63b6a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Oct 2025 17:54:09 +0800 Subject: [PATCH 23/30] update --- swift/megatron/convert.py | 15 +- swift/megatron/model/gpt_bridge.py | 291 ++++++++++++++++-------- swift/megatron/model/mm_gpt/qwen3_vl.py | 5 +- 3 files changed, 202 insertions(+), 109 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 4fac4da5b1..e882498b23 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -65,19 +65,18 @@ def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @contextmanager -def _model_cpu_forward_context(modules, torch_dtype=None, device=None, share_embedding: bool = False): +def _model_cpu_forward_context(modules, torch_dtype=None, compute_device=None, share_embedding: bool = False, target_device='cpu'): origin_torch_dtype = next(modules[-1].parameters()).dtype - def _to_cuda_hook(module, args): - if device is not None or torch_dtype is not None: - module.to(device=device, dtype=torch_dtype) + if compute_device is not None or torch_dtype is not None: + module.to(device=compute_device, dtype=torch_dtype) args = to_float_dtype(args, dtype=torch_dtype) return args def _to_cpu_hook(module, args, output): if share_embedding and module is modules[0]: return - module.to(device='cpu', dtype=origin_torch_dtype) + module.to(device=target_device, dtype=origin_torch_dtype) hooks = [] for module in modules: @@ -162,8 +161,12 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float model_arch = hf_model.model_meta.model_arch ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) + if hf_model.device == 'cpu': + compute_device = 'cuda' + else: + compute_device = None with torch.inference_mode(), _model_cpu_forward_context( - hf_modules, torch_dtype, share_embedding=share_embedding): + hf_modules, torch_dtype, compute_device, share_embedding=share_embedding): hf_inputs.pop('text_position_ids', None) hf_logits = hf_model(**hf_inputs).logits hf_logits = hf_logits.to('cuda') diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index c422fcb79c..eff8f54e25 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Dict, Literal, Optional, Union import torch @@ -15,8 +14,11 @@ class GPTBridge: lm_layers_prefix = 'model.layers' # HF model - def __init__(self): + def __init__(self, disable_tqmd: bool = False): self.args = get_args() + self.disable_tqmd = disable_tqmd + self._target_device = None + self.only_last_rank = False model_info = self.args.model_info with torch.device('meta'), 
disable_safe_ddp_context_use_barrier(): self.hf_model, self.processor = get_model_tokenizer( @@ -42,12 +44,20 @@ def _get_tp_split_dim(mg_key: str) -> Optional[int]: key, suffix = mg_key.rsplit('.', 2)[-2:] if suffix == 'layer_norm_weight': return - if key in {'word_embeddings', 'output_layer', 'linear_qkv', 'linear_fc1'}: + if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: return 0 - elif key in {'linear_proj', 'linear_fc2'}: + elif key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: + # linear_fc1 shape [2, X, Y] return 1 - def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: float = 0, is_expert: bool = False, mg_slices=()): + def _set_weights( + self, + mg_param: torch.Tensor, + hf_weight: torch.Tensor, + mg_key: str, + offset: float = 0, + is_expert: bool = False, + ): # tp/etp tp_dim = self._get_tp_split_dim(mg_key) hf_weight = hf_weight.to(mg_param.device) @@ -59,7 +69,7 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: float = 0, is_e splited_weights = [t.contiguous() for t in hf_weight.chunk(tp_size, dim=tp_dim)] else: splited_weights = None - tensor = torch.empty_like(mg_param.data[mg_slices]) + tensor = torch.empty_like(mg_param.data) dist.scatter( tensor, splited_weights, @@ -71,9 +81,9 @@ def _set_weights(self, mg_param, hf_weight, mg_key: str, offset: float = 0, is_e tensor = hf_weight if offset: tensor = tensor + offset - mg_param.data[mg_slices].copy_(tensor) + mg_param.data.copy_(tensor) - def _get_weights(self, mg_weight, mg_key, offset: int = 0, is_expert: bool = False): + def _get_weights(self, mg_weight: torch.Tensor, mg_key: str, offset: int = 0, is_expert: bool = False): # tp/etp tp_dim = self._get_tp_split_dim(mg_key) tensor = mg_weight @@ -100,34 +110,43 @@ def _get_weights(self, mg_weight, mg_key, offset: int = 0, is_expert: bool = Fal ) tensor = torch.cat(output, dim=tp_dim) del output - # pp - if self.pp_size > 1: + # pp/ep + for parallel_state in ['ep', 'pp']: + if parallel_state == 'pp' and self.pp_size > 1: + parallel_group = self.pp_group + parallel_rank = self.pp_rank + elif parallel_state == 'ep' and is_expert and self.ep_size > 1: + parallel_group = self.ep_group + parallel_rank = self.ep_rank + else: + continue + src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda') + dist.all_reduce(src_rank, group=parallel_group) + src_rank = dist.get_global_rank(parallel_group, src_rank.item()) + meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') + dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3} + dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} if tensor is None: - output = [None] * self.pp_size - dist.all_gather_object(output, None, group=self.pp_group) - src_idx = self._find_not_none_index(output) - assert len(src_idx) == 1, f'src_idx: {src_idx}' - src_idx = src_idx[0] - shape, dtype = output[src_idx] + dist.broadcast(meta_data, src=src_rank, group=parallel_group) + shape = meta_data[1:1 + meta_data[0]].tolist() + dtype = dtype_mapping_r[meta_data[-1].item()] tensor = torch.empty(shape, device='cuda', dtype=dtype) - dist.broadcast(tensor, src=dist.get_global_rank(self.pp_group, src_idx), group=self.pp_group) + dist.broadcast(tensor, src=src_rank, group=parallel_group) else: - output = [None] * self.pp_size - meta_data = (tensor.shape, tensor.dtype) - dist.all_gather_object(output, meta_data, group=self.pp_group) - dist.broadcast(tensor, src=dist.get_global_rank(self.pp_group, self.pp_rank), group=self.pp_group) + 
meta_data[0] = tensor.ndim + meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') + meta_data[-1] = dtype_mapping[tensor.dtype] + dist.broadcast(meta_data, src=src_rank, group=parallel_group) + dist.broadcast(tensor, src=src_rank, group=parallel_group) + assert tensor is not None, f'mg_key: {mg_key}' if offset: tensor = tensor + offset + if self._target_device is not None: + tensor = tensor.to(device=self._target_device) + if self._only_last_rank and not is_last_rank(): + tensor = None return tensor - @staticmethod - def _find_not_none_index(lst): - res = [] - for i, x in enumerate(lst): - if x is not None: - res.append(i) - return res - def _set_state_dict(self, mg_module, mg_key: str, @@ -189,12 +208,15 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int if reverse: mg_attn_weight = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') - mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size) - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size) - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size) + if mg_attn_weight is not None: + mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, + args.hidden_size).clone() + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size).clone() + del mg_attn_weight else: linear_qkv_weight = torch.cat([ hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), @@ -210,10 +232,11 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int if reverse: mg_attn_bias = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.bias.data, 'linear_qkv.bias') - mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) - hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1) - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1) - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1) + if mg_attn_bias is not None: + mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) + hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() else: linear_qkv_bias = torch.cat([ hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), @@ -254,95 +277,158 @@ def _set_moe_state( for key in ['shared_expert', 'shared_experts', 'shared_mlp']: if hasattr(hf_mlp, key): hf_shared_expert_prefix = f'{key}.' 
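# [Editor's sketch — assumptions, not part of the patch.] The `meta_data` handshake added to
# `_get_weights` above packs a tensor's rank, shape and dtype into one fixed-size int64
# buffer, so ranks that do not own the weight can allocate a matching receive buffer before
# the broadcast of the actual data. A standalone, CPU-only illustration of that encoding
# (the real code broadcasts the buffer over the PP/EP process group):
import torch

DTYPE_TO_ID = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
ID_TO_DTYPE = {v: k for k, v in DTYPE_TO_ID.items()}


def encode_meta(tensor: torch.Tensor, buf_len: int = 10) -> torch.Tensor:
    buf = torch.zeros(buf_len, dtype=torch.int64)
    buf[0] = tensor.ndim                                                     # slot 0: ndim
    buf[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64)   # slots 1..ndim: shape
    buf[-1] = DTYPE_TO_ID[tensor.dtype]                                      # last slot: dtype id
    return buf


def decode_meta(buf: torch.Tensor):
    ndim = int(buf[0])
    return buf[1:1 + ndim].tolist(), ID_TO_DTYPE[int(buf[-1])]


meta = encode_meta(torch.zeros(4, 8, dtype=torch.bfloat16))
assert decode_meta(meta) == ([4, 8], torch.bfloat16)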
+ shared_expert = getattr(hf_mlp, key) hf_state_dict.update( - self._set_mlp_state(mg_mlp.shared_experts, hf_state_dict, hf_shared_expert_prefix, layer_idx, reverse)) + self._set_mlp_state( + mg_mlp.shared_experts, + hf_state_dict, + hf_shared_expert_prefix, + layer_idx, + reverse, + hf_mlp=shared_expert)) if hasattr(hf_mlp, 'shared_expert_gate'): self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', reverse) - for expert_idx in range(self.args.num_experts): + for ep_rank in range(self.ep_size): mg_experts = mg_mlp.experts - start_idx = mg_experts.num_local_experts * self.ep_rank - expert_available = (start_idx <= expert_idx < start_idx + mg_experts.num_local_experts) - if expert_available: - group_idx = expert_idx - start_idx - else: - group_idx = None + expert_available = ep_rank == self.ep_rank + if not expert_available: if reverse: mg_experts = None else: continue - if hasattr(hf_mlp.experts, '__len__'): - hf_expert_prefix = f'experts.{expert_idx}.' - hf_group_idx = None - else: - hf_expert_prefix = 'experts.' - hf_group_idx = expert_idx hf_state_dict.update( - self._set_mlp_state( - mg_experts, - hf_state_dict, - hf_expert_prefix, - layer_idx, - reverse, - group_idx=group_idx, - hf_group_idx=hf_group_idx)) + self._set_expert_state(mg_experts, hf_state_dict, 'experts.', layer_idx, reverse, ep_rank)) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_mlp_state( + def _set_expert_state( self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool, - group_idx: Optional[int] = None, - hf_group_idx: Optional[int] = None, + ep_rank: int, ): if reverse: hf_state_dict = {} else: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - hf_mlp = self.hf_layers[layer_idx].mlp - is_expert = group_idx is not None - if group_idx is not None: - fc1_key = f'linear_fc1.weight{group_idx}' - fc2_key = f'linear_fc2.weight{group_idx}' + hf_experts = self.hf_layers[layer_idx].mlp.experts + num_local_experts = self.args.num_experts // self.ep_size + # if hf_group_idx is not None: + # res[fc1_key] = hf_state_dict['gate_up_proj'][hf_group_idx].t() + # res[fc2_key] = hf_state_dict['down_proj'][hf_group_idx].t() + # else: + if reverse: + if mg_mlp is None: + fc1_weight = None + else: + fc1_weight = torch.concat([getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)], + dim=0) + fc1_weight = fc1_weight.view(num_local_experts * 2, -1, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weights(fc1_weight, 'linear_fc1.weight', is_expert=True) + del fc1_weight + if gate_up_proj_weight is not None: + gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1, gate_up_proj_weight.shape[-1]) + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + if hasattr(hf_experts[i], 'gate_up_proj'): + hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].view( + -1, gate_up_proj_weight.shape[-1]).clone() + else: + hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone() + hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone() + del gate_up_proj_weight + else: + fc1_weight = mg_mlp.linear_fc1.weight0 + fc1_weight = fc1_weight.new_empty(num_local_experts * 2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + if hasattr(hf_experts[0], 'gate_up_proj'): + gate_up_proj_weight = torch.concat([ + hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight'].load() + for i in 
range(num_local_experts) + ], + dim=0) + else: + weight_list = [] + start_idx = ep_rank * num_local_experts + for i in range(num_local_experts): + gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load() + up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load() + weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0)) + gate_up_proj_weight = torch.concat(weight_list, dim=0) + del weight_list + self._set_weights(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=True) + fc1_weight = fc1_weight.view(num_local_experts, -1, fc1_weight.shape[-1]) + for i in range(num_local_experts): + getattr(mg_mlp.linear_fc1, f'weight{i}').data.copy_(fc1_weight[i].view(-1, fc1_weight.shape[-1])) + del fc1_weight + if reverse: + if mg_mlp is None: + fc2_weight = None + else: + fc2_weight = torch.concat([getattr(mg_mlp.linear_fc2, f'weight{i}') for i in range(num_local_experts)], + dim=0) + fc2_weight = fc2_weight.view(num_local_experts * 2, -1, fc2_weight.shape[1]) + down_proj_weight = self._get_weights(fc2_weight, 'linear_fc2.weight', is_expert=True) + del fc2_weight + if down_proj_weight is not None: + down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1]) + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].view( + -1, down_proj_weight.shape[-1]).clone() else: - fc1_key = 'linear_fc1.weight' - fc2_key = 'linear_fc2.weight' - if hf_group_idx is not None: - res[fc1_key] = hf_state_dict['gate_up_proj'][hf_group_idx].t() - res[fc2_key] = hf_state_dict['down_proj'][hf_group_idx].t() + fc2_weight = mg_mlp.linear_fc2.weight0 + fc2_weight = fc2_weight.new_empty(num_local_experts * fc2_weight.shape[0], fc2_weight.shape[1]) + down_proj_weight = torch.concat([ + hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() + for i in range(num_local_experts) + ], + dim=0) + self._set_weights(fc2_weight, down_proj_weight, 'linear_fc2.weight', is_expert=True) + fc2_weight = fc2_weight.view(num_local_experts, -1, fc2_weight.shape[-1]) + for i in range(num_local_experts): + getattr(mg_mlp.linear_fc2, f'weight{i}').data.copy_(fc2_weight[i]) + if reverse: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool, hf_mlp=None): + if reverse: + hf_state_dict = {} else: - if reverse: - fc1_weight = deep_getattr(mg_mlp, fc1_key) - gate_proj_weight = self._get_weights( - None if fc1_weight is None else fc1_weight[:fc1_weight.shape[0] // 2], fc1_key, is_expert=is_expert) - up_proj_weight = self._get_weights( - None if fc1_weight is None else fc1_weight[fc1_weight.shape[0] // 2:], fc1_key, is_expert=is_expert) - if hasattr(hf_mlp, 'gate_up_proj'): - hf_state_dict['gate_up_proj'] = torch.concat([gate_proj_weight, up_proj_weight], dim=0) - else: - hf_state_dict['gate_proj.weight'] = gate_proj_weight - hf_state_dict['up_proj.weight'] = up_proj_weight + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + if hf_mlp is None: + hf_mlp = self.hf_layers[layer_idx].mlp + if reverse: + if mg_mlp is None: + fc1_weight = None else: - linear_fc1_weight = deep_getattr(mg_mlp, fc1_key) - gate_slices = (slice(None, linear_fc1_weight.shape[0] // 2), ) - up_slices = (slice(linear_fc1_weight.shape[0] // 2, None), ) + fc1_weight = mg_mlp.linear_fc1.weight + fc1_weight = fc1_weight.view(2, fc1_weight.shape[0] // 
2, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weights(None if fc1_weight is None else fc1_weight, 'linear_fc1.weight') + if gate_up_proj_weight is not None: if hasattr(hf_mlp, 'gate_up_proj'): - gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load() - gate_proj_weight = gate_up_proj_weight[gate_slices] - up_proj_weight = gate_up_proj_weight[up_slices] + hf_state_dict['gate_up_proj'] = gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]).clone() else: - gate_proj_weight = hf_state_dict['gate_proj.weight'].load() - up_proj_weight = hf_state_dict['up_proj.weight'].load() - self._set_weights( - linear_fc1_weight, gate_proj_weight, fc1_key, is_expert=is_expert, mg_slices=gate_slices) - self._set_weights(linear_fc1_weight, up_proj_weight, fc1_key, is_expert=is_expert, mg_slices=up_slices) - - self._set_state_dict(mg_mlp, fc2_key, hf_state_dict, 'down_proj.weight', reverse, is_expert=is_expert) + hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone() + hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone() + else: + fc1_weight = mg_mlp.linear_fc1.weight + fc1_weight = fc1_weight.new_empty(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + if hasattr(hf_mlp, 'gate_up_proj'): + gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load() + else: + gate_proj_weight = hf_state_dict['gate_proj.weight'].load() + up_proj_weight = hf_state_dict['up_proj.weight'].load() + gate_up_proj_weight = torch.concat([gate_proj_weight, up_proj_weight], dim=0) + gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1]) + self._set_weights(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight') + mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1])) + self._set_state_dict(mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', reverse) if reverse: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict @@ -425,7 +511,8 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): hf_state_dict = {} else: yield - for layer_idx in tqdm(range(self.args.num_layers), dynamic_ncols=True, desc='Converting: '): + for layer_idx in tqdm( + range(self.args.num_layers), dynamic_ncols=True, desc='Converting: ', disable=self.disable_tqmd): start_idx = mg_model.decoder.layers[0].layer_number - 1 mg_layer_available = (start_idx <= layer_idx < mg_model.decoder.layers[-1].layer_number) if mg_layer_available: @@ -460,7 +547,9 @@ def load_weights(self, mg_model, hf_model_dir: str) -> None: state_dict = loader.get_state_dict() list(self._convert(mg_model, state_dict, '', False)) - def export_weights(self, mg_models): + def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False): + self._target_device = target_device + self._only_last_rank = only_last_rank state_dict = {} for mg_model in mg_models: yield from self._convert(mg_model, state_dict, '', True) @@ -468,7 +557,7 @@ def export_weights(self, mg_models): def save_weights(self, mg_models, output_dir: str) -> None: """Save the mg_model checkpoint in HF format""" saver = StreamingSafetensorSaver(save_dir=output_dir, max_shard_size=self.args.max_shard_size) - for k, v in self.export_weights(mg_models): + for k, v in self.export_weights(mg_models, target_device='cpu', only_last_rank=True): saver.add_tensor(k, v) saver.finalize() if is_last_rank(): diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index b0ce9b8f02..619e3d5857 100644 --- 
a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -99,9 +99,8 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config): media_inputs = processor.image_processor(images=images, return_tensors='pt') media_inputs = to_device(media_inputs, input_ids.device) pixel_values = media_inputs['pixel_values'].type(dtype) - image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])[0] + image_embeds, deepstack_visual_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0. - deepstack_visual_embeds = None visual_pos_masks = None else: if pixel_values is None: @@ -469,6 +468,8 @@ def forward( def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor): + if visual_pos_masks is None: + return hidden_states + visual_embeds.mean() * 0 visual_pos_masks = visual_pos_masks.to(hidden_states.device) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds From 74b74567957b18cb9c04f22f6c086928c902d5b7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Oct 2025 18:27:41 +0800 Subject: [PATCH 24/30] update --- swift/megatron/convert.py | 20 ++++++++++++-------- swift/megatron/init.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index e882498b23..49bc5ad1b3 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -6,6 +6,7 @@ from typing import Any, Dict import torch +import torch.distributed as dist import torch.nn as nn from megatron.core import mpu from megatron.training import get_args @@ -65,8 +66,13 @@ def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): @contextmanager -def _model_cpu_forward_context(modules, torch_dtype=None, compute_device=None, share_embedding: bool = False, target_device='cpu'): +def _model_cpu_forward_context(modules, + torch_dtype=None, + compute_device=None, + share_embedding: bool = False, + target_device='cpu'): origin_torch_dtype = next(modules[-1].parameters()).dtype + def _to_cuda_hook(module, args): if compute_device is not None or torch_dtype is not None: module.to(device=compute_device, dtype=torch_dtype) @@ -155,18 +161,15 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float share_embedding = mg_language_model.share_embeddings_and_output_weights if hf_model is not None: hf_model.eval() - _test_params_sum(hf_model) + if dist.get_world_size() == 1: + _test_params_sum(hf_model) template.register_post_encode_hook([hf_model]) HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) model_arch = hf_model.model_meta.model_arch ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else [] hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) - if hf_model.device == 'cpu': - compute_device = 'cuda' - else: - compute_device = None with torch.inference_mode(), _model_cpu_forward_context( - hf_modules, torch_dtype, compute_device, share_embedding=share_embedding): + hf_modules, torch_dtype, share_embedding=share_embedding): hf_inputs.pop('text_position_ids', None) hf_logits = hf_model(**hf_inputs).logits hf_logits = hf_logits.to('cuda') @@ -187,8 +190,9 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float 
mg_inputs.pop(key, None) mg_inputs.update({'packed_seq_params': packed_seq_params}) args = get_args() + mg_device = next(mg_language_model.parameters()).device with torch.inference_mode(), _model_cpu_forward_context( - mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): + mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device): # TODO: test pp tie_weights mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=mg_torch_dtype) if args.tensor_model_parallel_size > 1: diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d1892f1cba..a1c6272d18 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -518,6 +518,16 @@ def _worker(plan_shard): FileSystemReader.read_data = read_data +def _patch_validate_non_overlapping_shards_metadata(): + # too slow + from torch.distributed._shard.sharded_tensor import api + + def validate_non_overlapping_shards_metadata(*args, **kwargs): + pass + + api.validate_non_overlapping_shards_metadata = validate_non_overlapping_shards_metadata + + def _patch_TELinear(): from megatron.core.extensions.transformer_engine import TELinear @@ -681,6 +691,11 @@ def _patch_megatron(): logger.info('Patch FileSystemReader successfully applied.') except Exception: pass + try: + _patch_validate_non_overlapping_shards_metadata() + except Exception: + logger.warning('Patch validate_non_overlapping_shards_metadata failed.') + pass try: _patch_peft_BaseTuner() _patch_peft_ModulesToSaveWrapper() From 296a16fd423b13665a7ab28398fa8d0fc82de848 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Oct 2025 20:21:05 +0800 Subject: [PATCH 25/30] support vpp --- swift/megatron/convert.py | 4 ++-- swift/megatron/export/export.py | 6 ++---- swift/megatron/model/gpt_bridge.py | 20 ++++++++++++++------ swift/megatron/train/sft.py | 2 +- swift/megatron/trainers/base.py | 4 ++-- swift/megatron/utils/utils.py | 6 +++--- 6 files changed, 24 insertions(+), 18 deletions(-) diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 49bc5ad1b3..94d3d45a73 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -175,8 +175,9 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float hf_logits = hf_logits.to('cuda') hf_model.to('cpu') + args = get_args() template.use_megatron = True - mg_inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to()), 'cuda') + mg_inputs = to_device(template.data_collator([inputs], padding_to=get_padding_to(args)), 'cuda') packed_seq_params = None mg_torch_dtype = torch_dtype # thd @@ -189,7 +190,6 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float for key in ['labels', 'num_samples', 'attention_mask_2d', 'text_position_ids']: mg_inputs.pop(key, None) mg_inputs.update({'packed_seq_params': packed_seq_params}) - args = get_args() mg_device = next(mg_language_model.parameters()).device with torch.inference_mode(), _model_cpu_forward_context( mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device): diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py index 8d2582f0ec..c1dc6c626e 100644 --- a/swift/megatron/export/export.py +++ b/swift/megatron/export/export.py @@ -44,8 +44,6 @@ def convert_mcore2hf(self) -> None: logger.info('Converting weights and saving the model...') bridge = megatron_model_meta.bridge_cls() bridge.save_weights([mg_model], args.save) - logger.info(f'Successfully saved HF model weights in 
`{args.save}`.') - dist.barrier() if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): hf_model = prepare_model_template( @@ -68,8 +66,8 @@ def convert_hf2mcore(self) -> None: logger.info('Megatron model created successfully.') bridge = megatron_model_meta.bridge_cls() bridge.load_weights(mg_model, args.model_info.model_dir) - logger.info('Successfully transferred HF model weights to MG model.') dist.barrier() + logger.info('Successfully transferred HF model weights to MG model.') if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None @@ -78,7 +76,7 @@ def convert_hf2mcore(self) -> None: args.save_args(args.save) logger.info('Saving the model...') mg_save_checkpoint(1, [mg_model], None, None, 0) - logger.info(f'Successfully saved Megatron model weights in `{args.save}`.') + logger.info_if(f'Successfully saved Megatron model weights in `{args.save}`.', cond=is_last_rank()) def megatron_export_main(args: Optional[Union[List[str], MegatronExportArguments]] = None): diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index eff8f54e25..dff30d04cc 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -7,9 +7,10 @@ from tqdm import tqdm from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint -from swift.utils import disable_safe_ddp_context_use_barrier, is_last_rank +from swift.utils import disable_safe_ddp_context_use_barrier, is_last_rank, get_logger from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver +logger = get_logger() class GPTBridge: lm_layers_prefix = 'model.layers' # HF model @@ -497,12 +498,14 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): + def _convert(self, mg_models, hf_state_dict, hf_prefix: str, reverse: bool): """reverse: False: hf -> mg; True: mg -> hf""" if reverse: hf_state_dict = {} else: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + mg_models = iter(mg_models) + mg_model = next(mg_models) if reverse or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, 'model.embed_tokens.weight', reverse) @@ -522,6 +525,12 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): mg_layer = None else: continue + if reverse and self.pp_size > 1: + has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda') + dist.all_reduce(has_model, group=self.pp_group) + if not has_model: + mg_model = next(mg_models) + continue res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) if reverse: yield from list(self._add_prefix(res, hf_prefix).items()) @@ -545,14 +554,12 @@ def _convert(self, mg_model, hf_state_dict, hf_prefix: str, reverse: bool): def load_weights(self, mg_model, hf_model_dir: str) -> None: with SafetensorLazyLoader(hf_model_dir) as loader: state_dict = loader.get_state_dict() - list(self._convert(mg_model, state_dict, '', False)) + list(self._convert([mg_model], state_dict, '', False)) def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False): self._target_device = target_device self._only_last_rank = only_last_rank - 
state_dict = {} - for mg_model in mg_models: - yield from self._convert(mg_model, state_dict, '', True) + yield from self._convert(mg_models, {}, '', True) def save_weights(self, mg_models, output_dir: str) -> None: """Save the mg_model checkpoint in HF format""" @@ -569,3 +576,4 @@ def save_weights(self, mg_models, output_dir: str) -> None: output_dir, model_dirs=[self.hf_model.model_info.model_dir], additional_saved_files=self.hf_model.model_meta.additional_saved_files) + logger.info_info(f'Successfully saved HF model weights in `{output_dir}`.', cond=is_last_rank()) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 0d604ba778..31978b343c 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -42,7 +42,7 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No def _get_data_collator(self): data_collator = self.template.data_collator - padding_to = get_padding_to() + padding_to = get_padding_to(self.args) logger.info(f'padding_to: {padding_to}') data_collator = partial(data_collator, padding_to=padding_to) return data_collator diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 335d656cdd..f54ec0a6e6 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -713,8 +713,8 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear def save_checkpoint(self, iteration, *_args, **kwargs): args = get_args() if args.save_hf_checkpoint: - ouput_dir = os.path.join(args.save, f'checkpoint-{iteration}') - self.bridge.save_weights(self.unwrapped_models, ouput_dir) + output_dir = os.path.join(args.save, f'checkpoint-{iteration}') + self.bridge.save_weights(self.unwrapped_models, output_dir) else: with adapter_state_dict_context(): return self._origin_save_checkpoint(iteration, *_args, **kwargs) diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 89b77888c3..48b9c01f52 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -301,13 +301,13 @@ def forward_step_helper(model, inputs, dtype=None): return output_tensor -def get_padding_to(): - args = get_args() +def get_padding_to(args): padding_to = None if args.tensor_model_parallel_size > 1 and args.sequence_parallel: padding_to = args.tensor_model_parallel_size if args.context_parallel_size > 1: padding_to = (padding_to or 1) * args.context_parallel_size - if args.fp8: + fp8_format = getattr(args, 'fp8_format', None) or getattr(args, 'fp8', None) + if fp8_format is not None: padding_to = max((padding_to or 1) * 8, 16) return padding_to From aa15b972cc33513ab2575a2c9fbfdd5211feb5e0 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Oct 2025 21:06:55 +0800 Subject: [PATCH 26/30] update --- swift/megatron/model/gpt_bridge.py | 5 ++-- swift/megatron/trainers/base.py | 26 +++++++++++++++++++ swift/megatron/tuners/lora.py | 41 ++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index dff30d04cc..591cd721e4 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -7,11 +7,12 @@ from tqdm import tqdm from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint -from swift.utils import disable_safe_ddp_context_use_barrier, is_last_rank, get_logger +from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank from ..utils import LazyTensor, 
SafetensorLazyLoader, StreamingSafetensorSaver logger = get_logger() + class GPTBridge: lm_layers_prefix = 'model.layers' # HF model @@ -576,4 +577,4 @@ def save_weights(self, mg_models, output_dir: str) -> None: output_dir, model_dirs=[self.hf_model.model_info.model_dir], additional_saved_files=self.hf_model.model_meta.additional_saved_files) - logger.info_info(f'Successfully saved HF model weights in `{output_dir}`.', cond=is_last_rank()) + logger.info_if(f'Successfully saved HF model weights in `{output_dir}`.', cond=is_last_rank()) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index f54ec0a6e6..dfa48f8050 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import collections import os +import shutil import time from abc import ABC, abstractmethod from contextlib import contextmanager @@ -31,6 +32,7 @@ from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger +from ..tuners import LoraParallelLinear from ..utils import adapter_state_dict_context, copy_original_module_weight, patch_merge_fn, prepare_mcore_model from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) @@ -710,11 +712,35 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear return report_memory_flag + def merge_lora_adapters(self): + """Merge LoRA adapters into base model weights for vLLM inference.""" + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Merge all active adapters + module.merge() + + def unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Unmerge to restore separate LoRA weights for training + module.unmerge() + def save_checkpoint(self, iteration, *_args, **kwargs): args = get_args() if args.save_hf_checkpoint: + if args.train_type == 'lora': + self.merge_lora_adapters() output_dir = os.path.join(args.save, f'checkpoint-{iteration}') self.bridge.save_weights(self.unwrapped_models, output_dir) + if is_last_rank(): + args_path = os.path.join(os.path.dirname(output_dir), 'args.json') + if os.path.exists(args_path): + shutil.copy(args_path, os.path.join(output_dir, 'args.json')) + if args.train_type == 'lora': + self.unmerge_lora_adapters() else: with adapter_state_dict_context(): return self._origin_save_checkpoint(iteration, *_args, **kwargs) diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index f9ad78ef50..e41e29326a 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -422,6 +422,47 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if origin_device.type == 'cpu': self.to(device=origin_device) + def unmerge(self) -> None: + """ + Unmerge all merged adapter weights from the base weights. + This method reverses the merge operation by subtracting the LoRA delta weights + from the base layer weights, restoring the original base weights. 
+ """ + if not self.merged: + # No adapters to unmerge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == 'cpu': + self.to(device=get_current_device()) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + # Subtract the delta weight to unmerge + orig_weight.data -= delta_weight + + # Clear the merged adapters list + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + + def __getattr__(self, key: str): + try: + return super().__getattr__(key) + except AttributeError: + if 'base_layer' in dir(self): + return getattr(self.base_layer, key) + raise + def dispatch_megatron( target: torch.nn.Module, From 333d42ead4e096408ab8a6d7c6ed439d7b18204f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Oct 2025 15:55:43 +0800 Subject: [PATCH 27/30] support lora --- swift/megatron/argument/megatron_args.py | 3 + swift/megatron/model/gpt_bridge.py | 505 +++++++++++++---------- swift/megatron/trainers/base.py | 4 +- swift/megatron/tuners/lora.py | 8 - swift/megatron/utils/io_utils.py | 13 +- 5 files changed, 311 insertions(+), 222 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 3e5fde66e8..98fed1e4be 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -96,6 +96,7 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): mlp_padding_free: bool = False load_hf_checkpoint: bool = False save_hf_checkpoint: bool = False + merge_lora: Optional[bool] = None # streaming dataloader dataloader_persistent_workers: bool = True dataloader_prefetch_factor: int = 10 @@ -456,6 +457,8 @@ def __post_init__(self): self.seq_length = self.max_position_embeddings if self.position_embedding_type is None: self.position_embedding_type = 'rope' + if self.merge_lora is None: + self.merge_lora = self.save_hf_checkpoint self._init_moe() self._init_mixed_precision() diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 591cd721e4..86dda71ef7 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1,3 +1,4 @@ +from copy import copy from typing import Dict, Literal, Optional, Union import torch @@ -8,6 +9,7 @@ from swift.llm import deep_getattr, get_model_tokenizer, save_checkpoint from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank +from ..tuners import LoraParallelLinear from ..utils import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver logger = get_logger() @@ -20,7 +22,8 @@ def __init__(self, disable_tqmd: bool = False): self.args = get_args() self.disable_tqmd = disable_tqmd self._target_device = None - self.only_last_rank = False + self._only_last_rank = False + self._peft_target_modules = set() model_info = self.args.model_info with torch.device('meta'), disable_safe_ddp_context_use_barrier(): self.hf_model, self.processor = get_model_tokenizer( @@ -40,19 +43,31 @@ def __init__(self, disable_tqmd: bool = False): self.pp_rank = mpu.get_pipeline_model_parallel_rank() self.etp_rank = 
mpu.get_expert_tensor_parallel_rank() self.ep_rank = mpu.get_expert_model_parallel_rank() + self.save_peft_checkpoint = self.args.train_type == 'lora' and not self.args.merge_lora @staticmethod def _get_tp_split_dim(mg_key: str) -> Optional[int]: - key, suffix = mg_key.rsplit('.', 2)[-2:] - if suffix == 'layer_norm_weight': - return - if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: - return 0 - elif key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: - # linear_fc1 shape [2, X, Y] - return 1 - - def _set_weights( + if 'lora_A' not in mg_key and 'lora_B' not in mg_key: + key, suffix = mg_key.rsplit('.', 2)[-2:] + if suffix == 'layer_norm_weight': + return + if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: + return 0 + elif key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: + # linear_fc1 shape [2, X, Y] + return 1 + else: + mg_key_splited = mg_key.rsplit('.', 3) + key, lora_name = mg_key_splited[:2] + if lora_name == 'lora_A': + if key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: + # linear_fc1 shape [2, X, Y] + return 1 + elif lora_name == 'lora_B': + if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: + return 0 + + def _set_weight( self, mg_param: torch.Tensor, hf_weight: torch.Tensor, @@ -85,7 +100,7 @@ def _set_weights( tensor = tensor + offset mg_param.data.copy_(tensor) - def _get_weights(self, mg_weight: torch.Tensor, mg_key: str, offset: int = 0, is_expert: bool = False): + def _get_weight(self, mg_weight: torch.Tensor, mg_key: str, offset: int = 0, is_expert: bool = False): # tp/etp tp_dim = self._get_tp_split_dim(mg_key) tensor = mg_weight @@ -154,17 +169,39 @@ def _set_state_dict(self, mg_key: str, hf_state_dict, hf_key: str, - reverse: bool, + to_mcore: bool, offset: float = 0, is_expert: bool = False): - mg_param = deep_getattr(mg_module, mg_key) - if reverse: - hf_state_dict[hf_key] = self._get_weights(None if mg_param is None else mg_param.data, mg_key, offset, - is_expert) - else: - assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' - hf_weight = hf_state_dict[hf_key].load() - self._set_weights(mg_param, hf_weight, mg_key, offset, is_expert) + module_key, param_key = mg_key.rsplit('.', 1) + sub_module = deep_getattr(mg_module, module_key) + if isinstance(sub_module, + LoraParallelLinear) and self.save_peft_checkpoint and param_key != 'layer_norm_weight': + hf_module_key, hf_param_key = hf_key.rsplit('.', 1) + lora_A_key = f'{module_key}.lora_A.default.{param_key}' + lora_B_key = f'{module_key}.lora_B.default.{param_key}' + lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') + lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') + hf_lora_A_key = f'{hf_module_key}.lora_A.default.{hf_param_key}' + hf_lora_B_key = f'{hf_module_key}.lora_B.default.{hf_param_key}' + lora_A = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) + lora_B = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert) + if lora_A is not None: + self._peft_target_modules.add(hf_module_key) + hf_state_dict[hf_lora_A_key] = lora_A + hf_state_dict[hf_lora_B_key] = lora_B + elif to_mcore or not self.save_peft_checkpoint: + if isinstance(sub_module, LoraParallelLinear): + mg_param = deep_getattr(sub_module, f'base_layer.{param_key}') + else: + mg_param = deep_getattr(sub_module, param_key) + if to_mcore: + assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' + hf_weight = hf_state_dict[hf_key].load() + self._set_weight(mg_param, hf_weight, mg_key, offset, is_expert) + else: + weight = 
self._get_weight(None if mg_param is None else mg_param.data, mg_key, offset, is_expert) + if weight is not None: + hf_state_dict[hf_key] = weight @staticmethod def _remove_prefix(state_dict, prefix: str): @@ -184,14 +221,6 @@ def _filter_prefix(state_dict, prefix: str): return state_dict return {k: v for k, v in state_dict.items() if k.startswith(prefix)} - @staticmethod - def _replace_prefix(state_dict, hf_prefix: str, mg_prefix: str, reverse: bool): - src_prefix, tgt_prefix = hf_prefix, mg_prefix - if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - res = GPTBridge._remove_prefix(state_dict, src_prefix) - return GPTBridge._add_prefix(res, tgt_prefix) - @staticmethod def _is_moe(state_dict): for k, v in state_dict.items(): @@ -199,61 +228,83 @@ def _is_moe(state_dict): return True return False - def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool): - if reverse: - hf_state_dict = {} - else: + def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} hf_attn = self.hf_layers[layer_idx].self_attn args = self.args num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) - if reverse: - mg_attn_weight = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.weight.data, - 'linear_qkv.weight') - if mg_attn_weight is not None: - mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) - q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ - 0] // num_query_groups - hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, - args.hidden_size).clone() - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size).clone() - del mg_attn_weight - else: + if to_mcore: linear_qkv_weight = torch.cat([ hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), ], dim=1).reshape((-1, args.hidden_size)) - self._set_weights(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') - self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', reverse) + self._set_weight(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') + else: + q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ + 0] // num_query_groups + is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, + LoraParallelLinear) and not mg_attn.linear_qkv.merged + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group) + if is_lora: + lora_A = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_A['default'].weight.data, + 'linear_qkv.lora_A.default.weight') + lora_B = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_B['default'].weight.data, + 'linear_qkv.lora_B.default.weight') + if lora_A is not None: + self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) + for key in ['q_proj', 'k_proj', 'v_proj']: + hf_state_dict[f'{key}.lora_A.default.weight'] = lora_A.clone() + lora_B = 
lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) + hf_state_dict['q_proj.lora_B.default.weight'] = lora_B[:, :q_dim, :].reshape( + -1, lora_B.shape[-1]).clone() + hf_state_dict['k_proj.lora_B.default.weight'] = lora_B[:, q_dim:-kv_dim, :].reshape( + -1, lora_B.shape[-1]).clone() + hf_state_dict['v_proj.lora_B.default.weight'] = lora_B[:, -kv_dim:, :].reshape( + -1, lora_B.shape[-1]).clone() + else: + mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data, + 'linear_qkv.weight') + if mg_attn_weight is not None: + mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size)) + hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone() + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, + q_dim:-kv_dim, :].reshape(-1, + args.hidden_size).clone() + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, + args.hidden_size).clone() + del mg_attn_weight + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) # Copy bias if args.add_qkv_bias: - if reverse: - mg_attn_bias = self._get_weights(None if mg_attn is None else mg_attn.linear_qkv.bias.data, - 'linear_qkv.bias') - if mg_attn_bias is not None: - mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) - hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() - else: + if to_mcore: linear_qkv_bias = torch.cat([ hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), ], dim=1).reshape(-1) - self._set_weights(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') - + self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') + else: + mg_attn_bias = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data, + 'linear_qkv.bias') + if mg_attn_bias is not None: + mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) + hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() if args.qk_layernorm: hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight' hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight' - self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, reverse) - self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, reverse) - if reverse: + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore) + self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore) + if not to_mcore: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict @@ -263,17 +314,18 @@ def _set_moe_state( hf_state_dict, hf_prefix: str, layer_idx: int, - reverse: bool, + to_mcore: bool, ): - if reverse: - hf_state_dict = {} - else: + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_mlp = self.hf_layers[layer_idx].mlp hf_gate_key = 'gate.wg.weight' if hasattr(hf_mlp.gate, 'wg') else 
'gate.weight' - self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, reverse) + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) if self.args.moe_router_enable_expert_bias: - self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, 'gate.e_score_correction_bias', reverse) + self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, 'gate.e_score_correction_bias', to_mcore) if self.args.moe_shared_expert_intermediate_size: for key in ['shared_expert', 'shared_experts', 'shared_mlp']: @@ -286,22 +338,22 @@ def _set_moe_state( hf_state_dict, hf_shared_expert_prefix, layer_idx, - reverse, + to_mcore, hf_mlp=shared_expert)) if hasattr(hf_mlp, 'shared_expert_gate'): self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', - reverse) + to_mcore) for ep_rank in range(self.ep_size): mg_experts = mg_mlp.experts expert_available = ep_rank == self.ep_rank if not expert_available: - if reverse: - mg_experts = None - else: + if to_mcore: continue + else: + mg_experts = None hf_state_dict.update( - self._set_expert_state(mg_experts, hf_state_dict, 'experts.', layer_idx, reverse, ep_rank)) - if reverse: + self._set_expert_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank)) + if not to_mcore: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict @@ -311,40 +363,21 @@ def _set_expert_state( hf_state_dict, hf_prefix: str, layer_idx: int, - reverse: bool, + to_mcore: bool, ep_rank: int, ): - if reverse: - hf_state_dict = {} - else: + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} hf_experts = self.hf_layers[layer_idx].mlp.experts num_local_experts = self.args.num_experts // self.ep_size # if hf_group_idx is not None: # res[fc1_key] = hf_state_dict['gate_up_proj'][hf_group_idx].t() # res[fc2_key] = hf_state_dict['down_proj'][hf_group_idx].t() # else: - if reverse: - if mg_mlp is None: - fc1_weight = None - else: - fc1_weight = torch.concat([getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)], - dim=0) - fc1_weight = fc1_weight.view(num_local_experts * 2, -1, fc1_weight.shape[1]) - gate_up_proj_weight = self._get_weights(fc1_weight, 'linear_fc1.weight', is_expert=True) - del fc1_weight - if gate_up_proj_weight is not None: - gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1, gate_up_proj_weight.shape[-1]) - for i in range(num_local_experts): - hf_i = i + ep_rank * num_local_experts - if hasattr(hf_experts[i], 'gate_up_proj'): - hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].view( - -1, gate_up_proj_weight.shape[-1]).clone() - else: - hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone() - hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone() - del gate_up_proj_weight - else: + if to_mcore: + # linear_fc1 fc1_weight = mg_mlp.linear_fc1.weight0 fc1_weight = fc1_weight.new_empty(num_local_experts * 2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) if hasattr(hf_experts[0], 'gate_up_proj'): @@ -362,19 +395,51 @@ def _set_expert_state( weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0)) gate_up_proj_weight = torch.concat(weight_list, dim=0) del weight_list - self._set_weights(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=True) + self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight', is_expert=True) fc1_weight = 
fc1_weight.view(num_local_experts, -1, fc1_weight.shape[-1]) for i in range(num_local_experts): getattr(mg_mlp.linear_fc1, f'weight{i}').data.copy_(fc1_weight[i].view(-1, fc1_weight.shape[-1])) del fc1_weight - if reverse: + # linear_fc2 + fc2_weight = mg_mlp.linear_fc2.weight0 + fc2_weight = fc2_weight.new_empty(num_local_experts * fc2_weight.shape[0], fc2_weight.shape[1]) + down_proj_weight = torch.concat([ + hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() + for i in range(num_local_experts) + ], + dim=0) + self._set_weight(fc2_weight, down_proj_weight, 'linear_fc2.weight', is_expert=True) + fc2_weight = fc2_weight.view(num_local_experts, -1, fc2_weight.shape[-1]) + for i in range(num_local_experts): + getattr(mg_mlp.linear_fc2, f'weight{i}').data.copy_(fc2_weight[i]) + else: + if mg_mlp is None: + fc1_weight = None + else: + fc1_weight = torch.concat([getattr(mg_mlp.linear_fc1, f'weight{i}') for i in range(num_local_experts)], + dim=0) + fc1_weight = fc1_weight.view(num_local_experts * 2, -1, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weight(fc1_weight, 'linear_fc1.weight', is_expert=True) + del fc1_weight + if gate_up_proj_weight is not None: + gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1, gate_up_proj_weight.shape[-1]) + for i in range(num_local_experts): + hf_i = i + ep_rank * num_local_experts + if hasattr(hf_experts[i], 'gate_up_proj'): + hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].view( + -1, gate_up_proj_weight.shape[-1]).clone() + else: + hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone() + hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone() + del gate_up_proj_weight + # linear_fc2 if mg_mlp is None: fc2_weight = None else: fc2_weight = torch.concat([getattr(mg_mlp.linear_fc2, f'weight{i}') for i in range(num_local_experts)], dim=0) fc2_weight = fc2_weight.view(num_local_experts * 2, -1, fc2_weight.shape[1]) - down_proj_weight = self._get_weights(fc2_weight, 'linear_fc2.weight', is_expert=True) + down_proj_weight = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=True) del fc2_weight if down_proj_weight is not None: down_proj_weight = down_proj_weight.view(num_local_experts, -1, down_proj_weight.shape[-1]) @@ -382,43 +447,18 @@ def _set_expert_state( hf_i = i + ep_rank * num_local_experts hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].view( -1, down_proj_weight.shape[-1]).clone() - else: - fc2_weight = mg_mlp.linear_fc2.weight0 - fc2_weight = fc2_weight.new_empty(num_local_experts * fc2_weight.shape[0], fc2_weight.shape[1]) - down_proj_weight = torch.concat([ - hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load() - for i in range(num_local_experts) - ], - dim=0) - self._set_weights(fc2_weight, down_proj_weight, 'linear_fc2.weight', is_expert=True) - fc2_weight = fc2_weight.view(num_local_experts, -1, fc2_weight.shape[-1]) - for i in range(num_local_experts): - getattr(mg_mlp.linear_fc2, f'weight{i}').data.copy_(fc2_weight[i]) - if reverse: + if not to_mcore: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool, hf_mlp=None): - if reverse: - hf_state_dict = {} - else: + def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool, hf_mlp=None): + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + 
hf_state_dict = {} if hf_mlp is None: hf_mlp = self.hf_layers[layer_idx].mlp - if reverse: - if mg_mlp is None: - fc1_weight = None - else: - fc1_weight = mg_mlp.linear_fc1.weight - fc1_weight = fc1_weight.view(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) - gate_up_proj_weight = self._get_weights(None if fc1_weight is None else fc1_weight, 'linear_fc1.weight') - if gate_up_proj_weight is not None: - if hasattr(hf_mlp, 'gate_up_proj'): - hf_state_dict['gate_up_proj'] = gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1]).clone() - else: - hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone() - hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone() - else: + if to_mcore: fc1_weight = mg_mlp.linear_fc1.weight fc1_weight = fc1_weight.new_empty(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) if hasattr(hf_mlp, 'gate_up_proj'): @@ -428,93 +468,133 @@ def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, up_proj_weight = hf_state_dict['up_proj.weight'].load() gate_up_proj_weight = torch.concat([gate_proj_weight, up_proj_weight], dim=0) gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1]) - self._set_weights(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight') + self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight') mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1])) - self._set_state_dict(mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', reverse) - if reverse: + else: + is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, + LoraParallelLinear) and not mg_mlp.linear_fc1.merged + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group) + if is_lora: + if mg_mlp is None: + lora_A = None + lora_B = None + else: + lora_A = mg_mlp.linear_fc1.lora_A['default'].weight + lora_B = mg_mlp.linear_fc1.lora_B['default'].weight + lora_B = lora_B.view(2, lora_B.shape[0] // 2, lora_B.shape[1]) + lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.default.weight') + lora_B = self._get_weight(lora_B, 'linear_fc1.lora_B.default.weight') + if lora_A is not None: + if hasattr(hf_mlp, 'gate_up_proj'): + self._peft_target_modules.update({'gate_up_proj'}) + hf_state_dict['gate_up_proj.lora_A.default.weight'] = lora_A.clone() + hf_state_dict['gate_up_proj.lora_B.default.weight'] = lora_B.clone() + else: + self._peft_target_modules.update({'gate_proj', 'up_proj'}) + hf_state_dict['gate_proj.lora_A.default.weight'] = lora_A.clone() + hf_state_dict['up_proj.lora_A.default.weight'] = lora_A.clone() + hf_state_dict['gate_proj.lora_B.default.weight'] = lora_B[0].clone() + hf_state_dict['up_proj.lora_B.default.weight'] = lora_B[1].clone() + else: + if mg_mlp is None: + fc1_weight = None + else: + fc1_weight = mg_mlp.linear_fc1.weight + fc1_weight = fc1_weight.view(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1]) + gate_up_proj_weight = self._get_weight(None if fc1_weight is None else fc1_weight, 'linear_fc1.weight') + if gate_up_proj_weight is not None: + if hasattr(hf_mlp, 'gate_up_proj'): + hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.view( + -1, gate_up_proj_weight.shape[-1]).clone() + else: + hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone() + hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone() + self._set_state_dict(mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore) + if not to_mcore: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict def 
_set_mla_attn_state( self, - mg_model, - mg_prefix: str, - state_dict, + mg_attn, + hf_state_dict, hf_prefix: str, layer_idx: int, - reverse: bool, + to_mcore: bool, ): - src_prefix, tgt_prefix = hf_prefix, mg_prefix - if reverse: - src_prefix, tgt_prefix = tgt_prefix, src_prefix - state_dict = self._remove_prefix(state_dict, src_prefix) - res = {} - self._set_state_dict(state_dict, res, 'o_proj.weight', 'linear_proj.weight', reverse) + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) if self.args.q_lora_rank is None: - self._set_state_dict(state_dict, res, 'q_proj.weight', 'linear_q_proj.weight', reverse) + self._set_state_dict(mg_attn, 'linear_q_proj.weight', hf_state_dict, 'q_proj.weight', to_mcore) else: - self._set_state_dict(state_dict, res, 'q_a_proj.weight', 'linear_q_down_proj.weight', reverse) - self._set_state_dict(state_dict, res, 'q_b_proj.weight', 'linear_q_up_proj.weight', reverse) - self._set_state_dict(state_dict, res, 'kv_a_proj_with_mqa.weight', 'linear_kv_down_proj.weight', reverse) - self._set_state_dict(state_dict, res, 'kv_b_proj.weight', 'linear_kv_up_proj.weight', reverse) + self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'q_a_proj.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'q_b_proj.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_kv_down_proj.weight', hf_state_dict, 'kv_a_proj_with_mqa.weight', + to_mcore) + self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore) if self.args.qk_layernorm: - self._set_state_dict(state_dict, res, 'kv_a_layernorm.weight', 'linear_kv_up_proj.layer_norm_weight', - reverse) - return self._add_prefix(res, tgt_prefix) + self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight', + to_mcore) + if not to_mcore: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict - def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): + def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): mg_attn = None if mg_layer is None else mg_layer.self_attention if self.args.multi_latent_attention: - hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, reverse)) - self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', reverse) + hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) else: - hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, reverse)) + hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict, - 'input_layernorm.weight', reverse) + 'input_layernorm.weight', to_mcore) return hf_state_dict - def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, reverse: bool): + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): hf_mlp = self.hf_layers[layer_idx].mlp is_moe = self._is_moe(hf_mlp.state_dict()) mg_mlp = None if mg_layer is None else mg_layer.mlp if is_moe: - 
hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, reverse)) + hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', - reverse) + to_mcore) else: - hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, reverse)) + hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, 'mlp.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', reverse) + 'post_attention_layernorm.weight', to_mcore) return hf_state_dict - def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, reverse: bool): + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' - if reverse: - hf_state_dict = {} - else: + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) - hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, reverse)) - hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, reverse)) - if reverse: + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + if not to_mcore: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _convert(self, mg_models, hf_state_dict, hf_prefix: str, reverse: bool): - """reverse: False: hf -> mg; True: mg -> hf""" - if reverse: - hf_state_dict = {} - else: + def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): + if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} mg_models = iter(mg_models) mg_model = next(mg_models) - if reverse or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if not to_mcore or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): self._set_state_dict(mg_model, 'embedding.word_embeddings.weight', hf_state_dict, - 'model.embed_tokens.weight', reverse) - if reverse: + 'model.embed_tokens.weight', to_mcore) + if to_mcore: + yield + else: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} - else: - yield for layer_idx in tqdm( range(self.args.num_layers), dynamic_ncols=True, desc='Converting: ', disable=self.disable_tqmd): start_idx = mg_model.decoder.layers[0].layer_number - 1 @@ -522,59 +602,66 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, reverse: bool): if mg_layer_available: mg_layer = mg_model.decoder.layers[layer_idx - start_idx] else: - if reverse: - mg_layer = None - else: + if to_mcore: continue - if reverse and self.pp_size > 1: + else: + mg_layer = None + if not to_mcore and self.pp_size > 1: has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda') dist.all_reduce(has_model, group=self.pp_group) if not has_model: mg_model = next(mg_models) continue - res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, reverse) - if reverse: + res = self._set_layer_state(mg_layer, hf_state_dict, 'model.layers.', layer_idx, to_mcore) + if to_mcore: + yield + else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} - else: - yield - if reverse or 
mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if not to_mcore or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): if self.args.untie_embeddings_and_output_weights: hf_lm_head_key = 'lm_head.weight' - if reverse and self.args.task_type == 'seq_cls': + if not to_mcore and self.args.task_type == 'seq_cls': hf_lm_head_key = 'score.weight' - self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, reverse) + self._set_state_dict(mg_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore) self._set_state_dict(mg_model, 'decoder.final_layernorm.weight', hf_state_dict, 'model.norm.weight', - reverse) - if reverse: + to_mcore) + if to_mcore: + yield + else: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} - else: - yield def load_weights(self, mg_model, hf_model_dir: str) -> None: with SafetensorLazyLoader(hf_model_dir) as loader: state_dict = loader.get_state_dict() - list(self._convert([mg_model], state_dict, '', False)) + list(self._convert([mg_model], state_dict, '', True)) def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False): self._target_device = target_device self._only_last_rank = only_last_rank - yield from self._convert(mg_models, {}, '', True) + yield from self._convert(mg_models, {}, '', False) def save_weights(self, mg_models, output_dir: str) -> None: """Save the mg_model checkpoint in HF format""" - saver = StreamingSafetensorSaver(save_dir=output_dir, max_shard_size=self.args.max_shard_size) + saver = StreamingSafetensorSaver( + save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=self.save_peft_checkpoint) + self._peft_target_modules = set() for k, v in self.export_weights(mg_models, target_device='cpu', only_last_rank=True): saver.add_tensor(k, v) saver.finalize() if is_last_rank(): - # TODO: new_special_tokens - self.hf_model.config.save_pretrained(output_dir) - save_checkpoint( - None, - self.processor, - output_dir, - model_dirs=[self.hf_model.model_info.model_dir], - additional_saved_files=self.hf_model.model_meta.additional_saved_files) - logger.info_if(f'Successfully saved HF model weights in `{output_dir}`.', cond=is_last_rank()) + if self.save_peft_checkpoint: + peft_config = copy(mg_models[0].peft_config['default']) + peft_config.target_modules = self._peft_target_modules + peft_config.save_pretrained(output_dir) + else: + # TODO: new_special_tokens + self.hf_model.config.save_pretrained(output_dir) + save_checkpoint( + None, + self.processor, + output_dir, + model_dirs=[self.hf_model.model_info.model_dir], + additional_saved_files=self.hf_model.model_meta.additional_saved_files) + logger.info_if(f'Successfully saved `safetensors` model weights in `{output_dir}`.', cond=is_last_rank()) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index dfa48f8050..9dbff94f95 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -731,7 +731,7 @@ def unmerge_lora_adapters(self): def save_checkpoint(self, iteration, *_args, **kwargs): args = get_args() if args.save_hf_checkpoint: - if args.train_type == 'lora': + if args.train_type == 'lora' and args.merge_lora: self.merge_lora_adapters() output_dir = os.path.join(args.save, f'checkpoint-{iteration}') self.bridge.save_weights(self.unwrapped_models, output_dir) @@ -739,7 +739,7 @@ def save_checkpoint(self, iteration, *_args, **kwargs): args_path = 
os.path.join(os.path.dirname(output_dir), 'args.json') if os.path.exists(args_path): shutil.copy(args_path, os.path.join(output_dir, 'args.json')) - if args.train_type == 'lora': + if args.train_type == 'lora' and args.merge_lora: self.unmerge_lora_adapters() else: with adapter_state_dict_context(): diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index e41e29326a..69b3a0a4ed 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -455,14 +455,6 @@ def unmerge(self) -> None: if origin_device.type == 'cpu': self.to(device=origin_device) - def __getattr__(self, key: str): - try: - return super().__getattr__(key) - except AttributeError: - if 'base_layer' in dir(self): - return getattr(self.base_layer, key) - raise - def dispatch_megatron( target: torch.nn.Module, diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index 602be8fa0f..a93b80701b 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -75,8 +75,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): class StreamingSafetensorSaver: - def __init__(self, save_dir, max_shard_size='5GB', save_rank: Literal['master', 'last'] = 'last') -> None: - # max_shard_size: GiB + def __init__( + self, + save_dir, + max_shard_size: str = '5GB', + save_rank: Literal['master', 'last'] = 'last', + is_peft_format: bool = False, + ) -> None: self.save_dir = save_dir if isinstance(max_shard_size, str): if max_shard_size.endswith('GB'): @@ -90,6 +95,7 @@ def __init__(self, save_dir, max_shard_size='5GB', save_rank: Literal['master', self.shard_index = 1 self.weight_map = {} self.is_save_rank = is_last_rank() if save_rank == 'last' else is_master() + self.is_peft_format = is_peft_format if self.is_save_rank: os.makedirs(save_dir, exist_ok=True) @@ -97,7 +103,8 @@ def add_tensor(self, name, tensor): if not self.is_save_rank: return tensor_size = tensor.numel() * tensor.element_size() - if self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard: + if (self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard + and not self.is_peft_format): self._save_current_shard() self.current_shard[name] = tensor.cpu().contiguous() From e26587d5b98486b45ec5babae40cc4e0378f2799 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Oct 2025 16:04:00 +0800 Subject: [PATCH 28/30] support lora --- swift/megatron/utils/io_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index a93b80701b..fefe179c61 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -114,7 +114,10 @@ def _save_current_shard(self, shard_filename: str = None): if not self.current_shard: return if shard_filename is None: - shard_filename = f'model-{self.shard_index:05d}-of-?????.safetensors' + if self.is_peft_format: + shard_filename = 'adapter_model.safetensors' + else: + shard_filename = f'model-{self.shard_index:05d}-of-?????.safetensors' shard_path = os.path.join(self.save_dir, shard_filename) save_file(self.current_shard, str(shard_path)) for key in self.current_shard.keys(): @@ -130,6 +133,8 @@ def finalize(self): return if self.current_shard: self._save_current_shard() + if self.is_peft_format: + return total_shards = self.shard_index - 1 # rename `?????` for i in range(1, total_shards + 1): @@ -142,12 +147,13 @@ def finalize(self): if os.path.exists(old_path): os.rename(old_path, new_path) - 
updated_weight_map = {} - for key, filename in self.weight_map.items(): - new_filename = filename.replace('?????', f'{total_shards:05d}') - updated_weight_map[key] = new_filename + if total_shards > 1: + updated_weight_map = {} + for key, filename in self.weight_map.items(): + new_filename = filename.replace('?????', f'{total_shards:05d}') + updated_weight_map[key] = new_filename - self._save_index(updated_weight_map) + self._save_index(updated_weight_map) def _save_index(self, weight_map): index = {'metadata': {'total_size': self.total_size}, 'weight_map': weight_map} From 26a46ea7fa64b1b69a6a31c3cd2a3b945c627d4a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Oct 2025 17:11:38 +0800 Subject: [PATCH 29/30] fix lora --- swift/megatron/model/gpt_bridge.py | 72 +++++++++++++++--------------- swift/megatron/trainers/base.py | 3 +- swift/megatron/utils/io_utils.py | 11 +++-- 3 files changed, 47 insertions(+), 39 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 86dda71ef7..9027493f79 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -24,6 +24,7 @@ def __init__(self, disable_tqmd: bool = False): self._target_device = None self._only_last_rank = False self._peft_target_modules = set() + self._is_peft_format = False model_info = self.args.model_info with torch.device('meta'), disable_safe_ddp_context_use_barrier(): self.hf_model, self.processor = get_model_tokenizer( @@ -43,7 +44,6 @@ def __init__(self, disable_tqmd: bool = False): self.pp_rank = mpu.get_pipeline_model_parallel_rank() self.etp_rank = mpu.get_expert_tensor_parallel_rank() self.ep_rank = mpu.get_expert_model_parallel_rank() - self.save_peft_checkpoint = self.args.train_type == 'lora' and not self.args.merge_lora @staticmethod def _get_tp_split_dim(mg_key: str) -> Optional[int]: @@ -60,12 +60,14 @@ def _get_tp_split_dim(mg_key: str) -> Optional[int]: mg_key_splited = mg_key.rsplit('.', 3) key, lora_name = mg_key_splited[:2] if lora_name == 'lora_A': - if key in {'linear_proj', 'linear_fc1', 'linear_fc2'}: - # linear_fc1 shape [2, X, Y] + if key in {'linear_proj', 'linear_fc2'}: return 1 elif lora_name == 'lora_B': if key in {'word_embeddings', 'output_layer', 'linear_qkv'}: return 0 + elif key in {'linear_fc1'}: + # linear_fc1 shape [2, X, Y] + return 1 def _set_weight( self, @@ -174,22 +176,21 @@ def _set_state_dict(self, is_expert: bool = False): module_key, param_key = mg_key.rsplit('.', 1) sub_module = deep_getattr(mg_module, module_key) - if isinstance(sub_module, - LoraParallelLinear) and self.save_peft_checkpoint and param_key != 'layer_norm_weight': + if isinstance(sub_module, LoraParallelLinear) and self._is_peft_format and param_key != 'layer_norm_weight': hf_module_key, hf_param_key = hf_key.rsplit('.', 1) lora_A_key = f'{module_key}.lora_A.default.{param_key}' lora_B_key = f'{module_key}.lora_B.default.{param_key}' lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') - hf_lora_A_key = f'{hf_module_key}.lora_A.default.{hf_param_key}' - hf_lora_B_key = f'{hf_module_key}.lora_B.default.{hf_param_key}' + hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}' + hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}' lora_A = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) lora_B = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert) if lora_A is not None: self._peft_target_modules.add(hf_module_key) 
hf_state_dict[hf_lora_A_key] = lora_A hf_state_dict[hf_lora_B_key] = lora_B - elif to_mcore or not self.save_peft_checkpoint: + elif to_mcore or not self._is_peft_format: if isinstance(sub_module, LoraParallelLinear): mg_param = deep_getattr(sub_module, f'base_layer.{param_key}') else: @@ -248,25 +249,24 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ 0] // num_query_groups is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, - LoraParallelLinear) and not mg_attn.linear_qkv.merged + LoraParallelLinear) and self._is_peft_format if self.pp_size > 1: dist.all_reduce(is_lora, group=self.pp_group) if is_lora: lora_A = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_A['default'].weight.data, - 'linear_qkv.lora_A.default.weight') + 'linear_qkv.lora_A.weight') lora_B = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_B['default'].weight.data, - 'linear_qkv.lora_B.default.weight') + 'linear_qkv.lora_B.weight') if lora_A is not None: self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) for key in ['q_proj', 'k_proj', 'v_proj']: - hf_state_dict[f'{key}.lora_A.default.weight'] = lora_A.clone() + hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) - hf_state_dict['q_proj.lora_B.default.weight'] = lora_B[:, :q_dim, :].reshape( - -1, lora_B.shape[-1]).clone() - hf_state_dict['k_proj.lora_B.default.weight'] = lora_B[:, q_dim:-kv_dim, :].reshape( - -1, lora_B.shape[-1]).clone() - hf_state_dict['v_proj.lora_B.default.weight'] = lora_B[:, -kv_dim:, :].reshape( - -1, lora_B.shape[-1]).clone() + hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() + hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, + q_dim:-kv_dim, :].reshape(-1, + lora_B.shape[-1]).clone() + hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone() else: mg_attn_weight = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') @@ -291,7 +291,7 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int ], dim=1).reshape(-1) self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') - else: + elif not self._is_peft_format: mg_attn_bias = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data, 'linear_qkv.bias') if mg_attn_bias is not None: @@ -472,7 +472,7 @@ def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1])) else: is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, - LoraParallelLinear) and not mg_mlp.linear_fc1.merged + LoraParallelLinear) and self._is_peft_format if self.pp_size > 1: dist.all_reduce(is_lora, group=self.pp_group) if is_lora: @@ -483,19 +483,19 @@ def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int, lora_A = mg_mlp.linear_fc1.lora_A['default'].weight lora_B = mg_mlp.linear_fc1.lora_B['default'].weight lora_B = lora_B.view(2, lora_B.shape[0] // 2, lora_B.shape[1]) - lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.default.weight') - lora_B = self._get_weight(lora_B, 'linear_fc1.lora_B.default.weight') + lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.weight') + lora_B = self._get_weight(lora_B, 
'linear_fc1.lora_B.weight') if lora_A is not None: if hasattr(hf_mlp, 'gate_up_proj'): self._peft_target_modules.update({'gate_up_proj'}) - hf_state_dict['gate_up_proj.lora_A.default.weight'] = lora_A.clone() - hf_state_dict['gate_up_proj.lora_B.default.weight'] = lora_B.clone() + hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.clone() else: self._peft_target_modules.update({'gate_proj', 'up_proj'}) - hf_state_dict['gate_proj.lora_A.default.weight'] = lora_A.clone() - hf_state_dict['up_proj.lora_A.default.weight'] = lora_A.clone() - hf_state_dict['gate_proj.lora_B.default.weight'] = lora_B[0].clone() - hf_state_dict['up_proj.lora_B.default.weight'] = lora_B[1].clone() + hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone() + hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone() + hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone() else: if mg_mlp is None: fc1_weight = None @@ -630,28 +630,30 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): yield else: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - hf_state_dict = {} def load_weights(self, mg_model, hf_model_dir: str) -> None: with SafetensorLazyLoader(hf_model_dir) as loader: state_dict = loader.get_state_dict() list(self._convert([mg_model], state_dict, '', True)) - def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False): + def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False, is_peft_format: bool = False): self._target_device = target_device self._only_last_rank = only_last_rank - yield from self._convert(mg_models, {}, '', False) + self._is_peft_format = is_peft_format + hf_prefix = 'base_model.model.' 
if is_peft_format else '' + yield from self._convert(mg_models, {}, hf_prefix, False) - def save_weights(self, mg_models, output_dir: str) -> None: + def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False) -> None: """Save the mg_model checkpoint in HF format""" saver = StreamingSafetensorSaver( - save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=self.save_peft_checkpoint) + save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) self._peft_target_modules = set() - for k, v in self.export_weights(mg_models, target_device='cpu', only_last_rank=True): + for k, v in self.export_weights( + mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format): saver.add_tensor(k, v) saver.finalize() if is_last_rank(): - if self.save_peft_checkpoint: + if is_peft_format: peft_config = copy(mg_models[0].peft_config['default']) peft_config.target_modules = self._peft_target_modules peft_config.save_pretrained(output_dir) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9dbff94f95..766e21b57e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -734,7 +734,8 @@ def save_checkpoint(self, iteration, *_args, **kwargs): if args.train_type == 'lora' and args.merge_lora: self.merge_lora_adapters() output_dir = os.path.join(args.save, f'checkpoint-{iteration}') - self.bridge.save_weights(self.unwrapped_models, output_dir) + save_peft_format = args.train_type == 'lora' and not args.merge_lora + self.bridge.save_weights(self.unwrapped_models, output_dir, is_peft_format=save_peft_format) if is_last_rank(): args_path = os.path.join(os.path.dirname(output_dir), 'args.json') if os.path.exists(args_path): diff --git a/swift/megatron/utils/io_utils.py b/swift/megatron/utils/io_utils.py index fefe179c61..cbe9505efd 100644 --- a/swift/megatron/utils/io_utils.py +++ b/swift/megatron/utils/io_utils.py @@ -23,8 +23,9 @@ def load(self): class SafetensorLazyLoader: - def __init__(self, hf_model_dir: str): + def __init__(self, hf_model_dir: str, is_peft_format: bool = False): self.hf_model_dir = hf_model_dir + self.is_peft_format = is_peft_format self._weight_map = {} self._file_handles = {} self._load_index() @@ -45,12 +46,16 @@ def _load_index(self): self._index_file = json.load(f) self._weight_map = self._index_file.get('weight_map', {}) else: + if self.is_peft_format: + safetensors_fname = 'adapter_model.safetensors' + else: + safetensors_fname = 'model.safetensors' # Single file model - safetensors_file = os.path.join(self.hf_model_dir, 'model.safetensors') + safetensors_file = os.path.join(self.hf_model_dir, safetensors_fname) if os.path.exists(safetensors_file): with safe_open(safetensors_file, framework='pt') as f: for key in f.keys(): - self._weight_map[key] = 'model.safetensors' + self._weight_map[key] = safetensors_fname def get_state_dict(self): res = {} From 763d0913b946ec5c1b2434256ac4a1f4f96d08ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 31 Oct 2025 00:47:49 +0800 Subject: [PATCH 30/30] update --- swift/llm/argument/base_args/base_args.py | 5 +- swift/llm/infer/utils.py | 3 +- swift/megatron/argument/export_args.py | 5 + swift/megatron/argument/megatron_args.py | 28 +++- swift/megatron/argument/megatron_base_args.py | 2 + swift/megatron/export/export.py | 30 +++- swift/megatron/model/gpt_bridge.py | 130 ++++++++++++------ swift/megatron/trainers/base.py | 6 +- 8 files changed, 157 insertions(+), 52 deletions(-) diff --git 
a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 2f337317e9..31c647224f 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -220,7 +220,10 @@ def from_pretrained(cls, checkpoint_dir: str): def _init_ckpt_dir(self, adapters=None): # compat megatron model = self.model or getattr(self, 'mcore_model', None) or getattr(self, 'load', None) - adapters = adapters or self.adapters or getattr(self, 'mcore_adapters', None) + adapters = adapters or self.adapters or getattr(self, 'mcore_adapters', None) or getattr( + self, 'adapter_load', None) + if isinstance(adapters, str): + adapters = [adapters] self.ckpt_dir = get_ckpt_dir(model, adapters) if self.ckpt_dir and self.load_args: self.load_args_from_ckpt() diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py index 49f35fc646..49bc8b53a2 100644 --- a/swift/llm/infer/utils.py +++ b/swift/llm/infer/utils.py @@ -142,11 +142,12 @@ def prepare_adapter(args, model, adapters=None): def prepare_model_template(args, **kwargs): + adapters = kwargs.get('adapters') model, processor = args.get_model_processor(**kwargs) template = args.get_template(processor) if model is not None: if template.use_model: template.model = model - model = prepare_adapter(args, model) + model = prepare_adapter(args, model, adapters=adapters) update_generation_config_eos_token(model.generation_config, template) return model, template diff --git a/swift/megatron/argument/export_args.py b/swift/megatron/argument/export_args.py index a389d765d2..e3492e6371 100644 --- a/swift/megatron/argument/export_args.py +++ b/swift/megatron/argument/export_args.py @@ -5,6 +5,7 @@ from swift.llm import HfConfigFactory from swift.llm.argument.base_args import to_abspath from swift.utils import get_logger +from .megatron_args import MegatronArguments from .megatron_base_args import MegatronBaseArguments logger = get_logger() @@ -37,6 +38,10 @@ def __post_init__(self): super().__post_init__() self._init_save() self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype) + extra_config = MegatronArguments.load_args_config(self.adapter_load or self.load) + extra_config['adapter_load'] = self.adapter_load + for k, v in extra_config.items(): + setattr(self, k, v) if self.to_hf or self.to_mcore: self._init_convert() if self.model_info.is_moe_model is not None and self.tensor_model_parallel_size > 1: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 98fed1e4be..203167ec32 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -9,7 +9,6 @@ import torch from transformers.utils.versions import require_version -from swift.llm.argument.base_args import to_abspath from swift.utils import get_dist_setting, get_logger, json_parse_to_dict logger = get_logger() @@ -96,6 +95,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): mlp_padding_free: bool = False load_hf_checkpoint: bool = False save_hf_checkpoint: bool = False + model: Optional[str] = None + adapters: List[str] = field(default_factory=list) merge_lora: Optional[bool] = None # streaming dataloader dataloader_persistent_workers: bool = True @@ -459,6 +460,12 @@ def __post_init__(self): self.position_embedding_type = 'rope' if self.merge_lora is None: self.merge_lora = self.save_hf_checkpoint + if self.adapters or self.adapter_load or self.ref_adapter_load: + if self.train_type == 'full': + self.train_type 
= 'lora'
+                logger.info('Setting args.train_type: lora')
+            if self.adapters:
+                self._load_adapter_config()
         self._init_moe()
         self._init_mixed_precision()
 
@@ -502,3 +509,22 @@ def parse_to_megatron(self):
         # parameter conflict
         extra_args.pop('loss_scale', None)
         return extra_args
+
+    def _load_adapter_config(self):
+        assert len(self.adapters) == 1, 'Currently only support one adapter'
+        adapter_path = self.adapters[0]
+        adapter_config_path = os.path.join(adapter_path, 'adapter_config.json')
+        adapter_config = {}
+        if os.path.exists(adapter_config_path):
+            with open(adapter_config_path, 'r') as f:
+                adapter_config = json.load(f)
+        mapping = {'r': 'lora_rank', 'bias': 'lora_bias'}
+        for k in ['lora_alpha', 'lora_dropout', 'use_rslora']:
+            mapping[k] = k
+        for k, v in adapter_config.items():
+            if k not in mapping:
+                continue
+            k = mapping[k]
+            if v != getattr(self, k):
+                setattr(self, k, v)
+                logger.info(f'Setting {k}: {v}')
diff --git a/swift/megatron/argument/megatron_base_args.py b/swift/megatron/argument/megatron_base_args.py
index 6f978469e4..0f496aac42 100644
--- a/swift/megatron/argument/megatron_base_args.py
+++ b/swift/megatron/argument/megatron_base_args.py
@@ -1,6 +1,8 @@
 import os
 from dataclasses import dataclass
 
+import json
+
 from swift.llm import BaseArguments
 from swift.utils import get_logger
 from ..model import get_megatron_model_meta
diff --git a/swift/megatron/export/export.py b/swift/megatron/export/export.py
index c1dc6c626e..6ff56a52b0 100644
--- a/swift/megatron/export/export.py
+++ b/swift/megatron/export/export.py
@@ -11,7 +11,7 @@
 from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank
 from ..argument import MegatronExportArguments
 from ..convert import test_convert_precision
-from ..utils import patch_load_base_checkpoint, prepare_mcore_model
+from ..utils import adapter_state_dict_context, patch_load_base_checkpoint, prepare_mcore_model
 
 logger = get_logger()
 
@@ -41,13 +41,25 @@ def convert_mcore2hf(self) -> None:
         mg_model = megatron_model_meta.model_provider(pre_process=pre_process, post_process=post_process)
         with patch_load_base_checkpoint():
             load_checkpoint([mg_model], None, None, strict=True)
+        if args.adapter_load is not None:
+            prepare_mcore_model(mg_model)
+            with adapter_state_dict_context():
+                load_checkpoint([mg_model], None, None, load_arg='adapter_load', strict=False)
         logger.info('Converting weights and saving the model...')
         bridge = megatron_model_meta.bridge_cls()
-        bridge.save_weights([mg_model], args.save)
+        save_peft_format = args.train_type == 'lora' and not args.merge_lora
+        bridge.save_weights([mg_model], args.save, is_peft_format=save_peft_format)
+        if is_last_rank():
+            args_path = os.path.join(os.path.dirname(args.save), 'args.json')
+            if os.path.exists(args_path):
+                shutil.copy(args_path, os.path.join(args.save, 'args.json'))
         if args.test_convert_precision:
             with disable_safe_ddp_context_use_barrier():
-                hf_model = prepare_model_template(
-                    args, model=args.save, device_map='cpu')[0] if is_last_rank() else None
+                if save_peft_format:
+                    kwargs = {'adapters': [args.save]}
+                else:
+                    kwargs = {'model': args.save}
+                hf_model = prepare_model_template(args, device_map='cpu', **kwargs)[0] if is_last_rank() else None
                 test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
         dist.barrier()
 
@@ -67,15 +79,21 @@ def convert_hf2mcore(self) -> None:
         bridge = megatron_model_meta.bridge_cls()
         bridge.load_weights(mg_model, args.model_info.model_dir)
         dist.barrier()
+        if args.adapters:
+            prepare_mcore_model(mg_model)
+ assert len(args.adapters) == 1, 'Currently only support one adapter' + bridge.load_weights(mg_model, args.adapters[0], is_peft_format=True) logger.info('Successfully transferred HF model weights to MG model.') if args.test_convert_precision: with disable_safe_ddp_context_use_barrier(): hf_model = prepare_model_template(args, device_map='cpu')[0] if is_last_rank() else None test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype) dist.barrier() - args.save_args(args.save) + if is_last_rank(): + args.save_args(args.save) logger.info('Saving the model...') - mg_save_checkpoint(1, [mg_model], None, None, 0) + with adapter_state_dict_context(): + mg_save_checkpoint(1, [mg_model], None, None, 0) logger.info_if(f'Successfully saved Megatron model weights in `{args.save}`.', cond=is_last_rank()) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 9027493f79..a599d6e639 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -79,7 +79,7 @@ def _set_weight( ): # tp/etp tp_dim = self._get_tp_split_dim(mg_key) - hf_weight = hf_weight.to(mg_param.device) + hf_weight = hf_weight.to(device=mg_param.device, dtype=mg_param.dtype) tp_size = self.etp_size if is_expert else self.tp_size tp_rank = self.etp_rank if is_expert else self.tp_rank tp_group = self.etp_group if is_expert else self.tp_group @@ -177,20 +177,31 @@ def _set_state_dict(self, module_key, param_key = mg_key.rsplit('.', 1) sub_module = deep_getattr(mg_module, module_key) if isinstance(sub_module, LoraParallelLinear) and self._is_peft_format and param_key != 'layer_norm_weight': - hf_module_key, hf_param_key = hf_key.rsplit('.', 1) - lora_A_key = f'{module_key}.lora_A.default.{param_key}' - lora_B_key = f'{module_key}.lora_B.default.{param_key}' - lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') - lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') - hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}' - hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}' - lora_A = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) - lora_B = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert) - if lora_A is not None: - self._peft_target_modules.add(hf_module_key) - hf_state_dict[hf_lora_A_key] = lora_A - hf_state_dict[hf_lora_B_key] = lora_B - elif to_mcore or not self._is_peft_format: + if to_mcore: + hf_module_key, hf_param_key = hf_key.rsplit('.', 1) + lora_A_key = f'{module_key}.lora_A.default.{param_key}' + lora_B_key = f'{module_key}.lora_B.default.{param_key}' + mg_lora_A = deep_getattr(mg_module, f'{lora_A_key}') + mg_lora_B = deep_getattr(mg_module, f'{lora_B_key}') + hf_lora_A = hf_state_dict[f'{hf_module_key}.lora_A.{hf_param_key}'].load() + hf_lora_B = hf_state_dict[f'{hf_module_key}.lora_B.{hf_param_key}'].load() + self._set_weight(mg_lora_A, hf_lora_A, lora_A_key, offset, is_expert) + self._set_weight(mg_lora_B, hf_lora_B, lora_B_key, offset, is_expert) + else: + hf_module_key, hf_param_key = hf_key.rsplit('.', 1) + lora_A_key = f'{module_key}.lora_A.default.{param_key}' + lora_B_key = f'{module_key}.lora_B.default.{param_key}' + lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data') + lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data') + hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}' + hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}' + lora_A = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert) + lora_B = self._get_weight(lora_B_tensor, 
lora_B_key, offset, is_expert) + if lora_A is not None: + self._peft_target_modules.add(hf_module_key) + hf_state_dict[hf_lora_A_key] = lora_A + hf_state_dict[hf_lora_B_key] = lora_B + elif not self._is_peft_format: if isinstance(sub_module, LoraParallelLinear): mg_param = deep_getattr(sub_module, f'base_layer.{param_key}') else: @@ -238,13 +249,28 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int args = self.args num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads) if to_mcore: - linear_qkv_weight = torch.cat([ - hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), - ], - dim=1).reshape((-1, args.hidden_size)) - self._set_weight(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') + if isinstance(mg_attn.linear_qkv, LoraParallelLinear): + lora_A = hf_state_dict['q_proj.lora_A.weight'].load() + assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( + lora_A == hf_state_dict['v_proj.lora_A.weight'].load() + ).all(), 'Need to ensure QKV\'s lora_A are consistent' + q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load() + lora_B = torch.cat([ + q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])), + hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + ], + dim=1).reshape((-1, q_lora_B.shape[-1])) + self._set_weight(mg_attn.linear_qkv.lora_A['default'].weight, lora_A, 'linear_qkv.lora_A.weight') + self._set_weight(mg_attn.linear_qkv.lora_B['default'].weight, lora_B, 'linear_qkv.lora_B.weight') + else: + linear_qkv_weight = torch.cat([ + hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)), + ], + dim=1).reshape((-1, args.hidden_size)) + self._set_weight(mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight') else: q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[ 0] // num_query_groups @@ -254,9 +280,9 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int dist.all_reduce(is_lora, group=self.pp_group) if is_lora: lora_A = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_A['default'].weight.data, - 'linear_qkv.lora_A.weight') + 'linear_qkv.lora_A.default.weight') lora_B = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.lora_B['default'].weight.data, - 'linear_qkv.lora_B.weight') + 'linear_qkv.lora_B.default.weight') if lora_A is not None: self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) for key in ['q_proj', 'k_proj', 'v_proj']: @@ -282,7 +308,7 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore) # Copy bias - if args.add_qkv_bias: + if args.add_qkv_bias and not self._is_peft_format: if to_mcore: linear_qkv_bias = torch.cat([ hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), @@ -291,7 +317,7 @@ def 
_set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
                              ],
                              dim=1).reshape(-1)
             self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias')
-        elif not self._is_peft_format:
+        else:
             mg_attn_bias = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data,
                                             'linear_qkv.bias')
             if mg_attn_bias is not None:
@@ -459,17 +485,34 @@ def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int,
         if hf_mlp is None:
             hf_mlp = self.hf_layers[layer_idx].mlp
         if to_mcore:
-            fc1_weight = mg_mlp.linear_fc1.weight
-            fc1_weight = fc1_weight.new_empty(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1])
-            if hasattr(hf_mlp, 'gate_up_proj'):
-                gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load()
+            if isinstance(mg_mlp.linear_fc1, LoraParallelLinear):
+                mg_lora_B = mg_mlp.linear_fc1.lora_B['default'].weight
+                mg_lora_B = mg_lora_B.new_empty(2, mg_lora_B.shape[0] // 2, mg_lora_B.shape[-1])
+                if hasattr(hf_mlp, 'gate_up_proj'):
+                    lora_A = hf_state_dict['gate_up_proj.lora_A.weight'].load()
+                    lora_B = hf_state_dict['gate_up_proj.lora_B.weight'].load()
+                else:
+                    lora_A = hf_state_dict['gate_proj.lora_A.weight'].load()
+                    assert (lora_A == hf_state_dict['up_proj.lora_A.weight'].load()
+                            ).all(), 'Need to ensure lora_A consistency between gate_proj and up_proj'
+                    gate_lora_B = hf_state_dict['gate_proj.lora_B.weight'].load()
+                    up_lora_B = hf_state_dict['up_proj.lora_B.weight'].load()
+                    lora_B = torch.stack([gate_lora_B, up_lora_B], dim=0)
+                self._set_weight(mg_mlp.linear_fc1.lora_A['default'].weight, lora_A, 'linear_fc1.lora_A.default.weight')
+                self._set_weight(mg_lora_B, lora_B, 'linear_fc1.lora_B.default.weight')
+                mg_mlp.linear_fc1.lora_B['default'].weight.data.copy_(mg_lora_B.view(-1, mg_lora_B.shape[-1]))
             else:
-                gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
-                up_proj_weight = hf_state_dict['up_proj.weight'].load()
-                gate_up_proj_weight = torch.concat([gate_proj_weight, up_proj_weight], dim=0)
-            gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1])
-            self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight')
-            mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1]))
+                fc1_weight = mg_mlp.linear_fc1.weight
+                fc1_weight = fc1_weight.new_empty(2, fc1_weight.shape[0] // 2, fc1_weight.shape[1])
+                if hasattr(hf_mlp, 'gate_up_proj'):
+                    gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load()
+                    gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1])
+                else:
+                    gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
+                    up_proj_weight = hf_state_dict['up_proj.weight'].load()
+                    gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0)
+                self._set_weight(fc1_weight, gate_up_proj_weight, 'linear_fc1.weight')
+                mg_mlp.linear_fc1.weight.data.copy_(fc1_weight.view(-1, fc1_weight.shape[-1]))
         else:
             is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
                                                               LoraParallelLinear) and self._is_peft_format
@@ -483,8 +526,8 @@ def _set_mlp_state(self, mg_mlp, hf_state_dict, hf_prefix: str, layer_idx: int,
                 lora_A = mg_mlp.linear_fc1.lora_A['default'].weight
                 lora_B = mg_mlp.linear_fc1.lora_B['default'].weight
                 lora_B = lora_B.view(2, lora_B.shape[0] // 2, lora_B.shape[1])
-                lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.weight')
-                lora_B = self._get_weight(lora_B, 'linear_fc1.lora_B.weight')
+                lora_A = self._get_weight(lora_A, 'linear_fc1.lora_A.default.weight')
+                lora_B = self._get_weight(lora_B, 'linear_fc1.lora_B.default.weight')
                 if lora_A is not 
None: if hasattr(hf_mlp, 'gate_up_proj'): self._peft_target_modules.update({'gate_up_proj'}) @@ -631,15 +674,19 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool): else: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - def load_weights(self, mg_model, hf_model_dir: str) -> None: - with SafetensorLazyLoader(hf_model_dir) as loader: + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False): + self._is_peft_format = is_peft_format + with SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader: state_dict = loader.get_state_dict() - list(self._convert([mg_model], state_dict, '', True)) + hf_prefix = 'base_model.model.' if is_peft_format else '' + list(self._convert([mg_model], state_dict, hf_prefix, True)) def export_weights(self, mg_models, target_device=None, only_last_rank: bool = False, is_peft_format: bool = False): + # TODO: modules_to_save self._target_device = target_device self._only_last_rank = only_last_rank self._is_peft_format = is_peft_format + self._peft_target_modules = set() hf_prefix = 'base_model.model.' if is_peft_format else '' yield from self._convert(mg_models, {}, hf_prefix, False) @@ -647,7 +694,6 @@ def save_weights(self, mg_models, output_dir: str, is_peft_format: bool = False) """Save the mg_model checkpoint in HF format""" saver = StreamingSafetensorSaver( save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format) - self._peft_target_modules = set() for k, v in self.export_weights( mg_models, target_device='cpu', only_last_rank=True, is_peft_format=is_peft_format): saver.add_tensor(k, v) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 766e21b57e..b483977214 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -243,7 +243,11 @@ def new_model_provider_func(*_args, **kwargs): if args.load_hf_checkpoint: self.bridge.load_weights(model, args.model_info.model_dir) self.unwrapped_models.append(model) - self.peft_models.append(prepare_mcore_model(model)) + peft_model = prepare_mcore_model(model) + if args.load_hf_checkpoint and args.train_type == 'lora' and args.adapters: + assert len(args.adapters) == 1, 'Currently only support one adapter' + self.bridge.load_weights(model, args.adapters[0], is_peft_format=True) + self.peft_models.append(peft_model) return model self._init_multimodal_full(args)
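Note on the attention LoRA mapping: the gpt_bridge.py hunks above convert between Megatron's fused linear_qkv adapter and the separate q_proj/k_proj/v_proj adapters that the PEFT checkpoint expects. lora_A is simply shared by the three projections, while lora_B is regrouped per query group before slicing. The following self-contained sketch (illustrative sizes and variable names, not the repository's API) reproduces that round trip so the reshapes in _set_attn_state are easier to verify.

    import torch

    num_query_groups = 4      # hypothetical GQA group count
    q_heads_per_group = 2     # hypothetical query heads per group
    head_dim = 8              # hypothetical head dimension
    rank = 4                  # hypothetical LoRA rank

    q_dim = q_heads_per_group * head_dim   # query rows contributed by each group
    kv_dim = head_dim                      # key/value rows contributed by each group

    # HF-side LoRA B matrices (out_features x rank) for the separate projections.
    q_B = torch.randn(num_query_groups * q_dim, rank)
    k_B = torch.randn(num_query_groups * kv_dim, rank)
    v_B = torch.randn(num_query_groups * kv_dim, rank)

    # HF -> Megatron: interleave q/k/v rows per query group into the fused linear_qkv layout.
    fused_B = torch.cat([
        q_B.reshape(num_query_groups, -1, rank),
        k_B.reshape(num_query_groups, -1, rank),
        v_B.reshape(num_query_groups, -1, rank),
    ], dim=1).reshape(-1, rank)

    # Megatron -> HF: view per group again and slice the q/k/v blocks back out.
    grouped = fused_B.reshape(num_query_groups, -1, rank)
    q_back = grouped[:, :q_dim, :].reshape(-1, rank)
    k_back = grouped[:, q_dim:-kv_dim, :].reshape(-1, rank)
    v_back = grouped[:, -kv_dim:, :].reshape(-1, rank)

    assert torch.equal(q_back, q_B)
    assert torch.equal(k_back, k_B)
    assert torch.equal(v_back, v_B)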
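Likewise, _set_mlp_state maps gate_proj/up_proj adapters onto the fused linear_fc1 parameter by viewing lora_B as [2, ffn_hidden, rank] (gate first, up second) and sharing lora_A between the two projections. A minimal sketch under the same illustrative assumptions:

    import torch

    ffn_hidden_size = 16   # hypothetical FFN hidden size
    rank = 4               # hypothetical LoRA rank

    gate_B = torch.randn(ffn_hidden_size, rank)
    up_B = torch.randn(ffn_hidden_size, rank)

    # HF -> Megatron: stack gate and up, then flatten into the fused [2 * ffn, rank] parameter.
    fused_B = torch.stack([gate_B, up_B], dim=0).reshape(-1, rank)

    # Megatron -> HF: view as [2, ffn, rank]; index 0 recovers gate_proj, index 1 recovers up_proj.
    split_B = fused_B.view(2, fused_B.shape[0] // 2, fused_B.shape[1])
    assert torch.equal(split_B[0], gate_B)
    assert torch.equal(split_B[1], up_B)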